Source code for ada.utils.streamlit_configs

import os
import streamlit as st
import glob
import json
import numpy as np
import ada.utils.experimentation as xp
from copy import deepcopy


def configure_network(default_file, on_sidebar=True):
    network_params = xp.load_json_dict(default_file)
    stmod = st.sidebar if on_sidebar else st
    stmod.header("Network configuration")
    stmod.subheader("Learning parameters")
    stmod.markdown("$\\lambda$ controls the weight of the critic")
    adapt_lambda = stmod.checkbox(
        "Use adaptive lambda", value=network_params["train_params"]["adapt_lambda"]
    )
    lambda_init = stmod.number_input(
        "Final (max) lambda",
        value=float(network_params["train_params"]["lambda_init"]),
        step=1e-3,
        format="%.3f",
    )
    adapt_lr = stmod.checkbox(
        "Use adaptive learning rate", value=network_params["train_params"]["adapt_lr"]
    )
    init_lr = stmod.number_input(
        "Initial learning rate",
        value=float(network_params["train_params"]["init_lr"]),
        step=1e-4,
        format="%.4f",
    )
    nb_init_epochs = stmod.number_input(
        "Number of warmup epochs",
        min_value=0,
        value=network_params["train_params"]["nb_init_epochs"],
    )
    nb_adapt_epochs = stmod.number_input(
        "Number of adaptation epochs",
        min_value=1,
        value=network_params["train_params"]["nb_adapt_epochs"],
    )
    stmod.subheader("Architecture details")
    stmod.markdown(
        "Choose the widths of the feature extractor's hidden layers. The width of "
        "its final layer is used for both the task classifier and the critic. "
        "The critic can have more than one layer, all of the same width."
    )
    hidden_size = stmod.text_input(
        "Feature hidden layers",
        value=network_params["archi_params"]["feature"]["hidden_size"],
    )
    hidden_size = json.loads(hidden_size)
    critic_layers = stmod.text_input(
        "Critic hidden layers",
        value=network_params["archi_params"]["critic"]["hidden_size"],
    )
    critic_layers = json.loads(critic_layers)
    network_params["train_params"]["adapt_lambda"] = adapt_lambda
    network_params["train_params"]["lambda_init"] = lambda_init
    network_params["train_params"]["adapt_lr"] = adapt_lr
    network_params["train_params"]["init_lr"] = init_lr
    network_params["train_params"]["nb_init_epochs"] = nb_init_epochs
    network_params["train_params"]["nb_adapt_epochs"] = nb_adapt_epochs
    network_params["archi_params"]["feature"]["hidden_size"] = hidden_size
    network_params["archi_params"]["critic"]["hidden_size"] = critic_layers
    return network_params
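
# Usage sketch (illustrative, not part of the original module). The path
# "configs/network_default.json" is a hypothetical example; any JSON file with
# the "train_params" and "archi_params" keys read above would work. Launched
# with `streamlit run app.py`, the function renders a sidebar form and returns
# the defaults with the user's overrides applied:
#
#     import streamlit as st
#     import ada.utils.streamlit_configs as stc
#
#     params = stc.configure_network("configs/network_default.json", on_sidebar=True)
#     st.json(params["train_params"])
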
def configure_dataset(default_dir, on_sidebar=True):
    stmod = st.sidebar if on_sidebar else st
    stmod.header("Dataset")
    json_files = glob.glob(f"{default_dir}/*.json")
    all_params_files = [(f, xp.load_json_dict(f)) for f in json_files]
    toy_files = [
        f for (f, p) in all_params_files if p.get("dataset_group", "none") == "toy"
    ]
    dataset = stmod.selectbox("Dataset", toy_files, index=0)
    default_params = xp.load_json_dict(dataset)
    if on_sidebar:
        return default_params
    toy_params = deepcopy(default_params)

    # centers position
    default_centers = np.array([[-0.5, 0.0], [0.5, 0.0]])
    param_centers = default_params["cluster"].get("centers", default_centers.tolist())
    new_centers_st = stmod.text_input(
        "Position of class centers (source)", param_centers
    )
    new_centers = json.loads(new_centers_st)
    n_clusters = len(new_centers)
    stmod.markdown(f"{n_clusters} classes.")
    toy_params["cluster"]["centers"] = new_centers
    toy_params["cluster"]["n_clusters"] = n_clusters

    # centers radii
    radius0 = default_params["cluster"]["radius"]
    same_radius = stmod.checkbox(
        "Use same variance everywhere (class/dimension)",
        value=isinstance(radius0, float),
    )
    if same_radius:
        if not isinstance(radius0, float):
            radius0 = np.array(radius0).flatten()[0]
        radius = stmod.number_input(
            "Class variance",
            step=10 ** (np.floor(np.log10(radius0))),
            value=radius0,
            format="%.4f",
        )
        toy_params["cluster"]["radius"] = radius
    else:
        if isinstance(radius0, float):
            radii = (np.ones_like(new_centers) * radius0).tolist()
        else:
            radii = radius0
        new_radius_st = stmod.text_input(
            "Variance of each class along each dimension", radii
        )
        new_radius = json.loads(new_radius_st)
        shape_variance = np.array(new_radius).shape
        shape_clusters = np.array(new_centers).shape
        if shape_variance == shape_clusters:
            stmod.markdown(
                ":heavy_check_mark: Shape of variance values matches the shape of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: shape of variances doesn't match the shape of clusters."
            )
        toy_params["cluster"]["radius"] = new_radius

    # class balance
    proba_classes = default_params["cluster"]["proba_classes"]
    if n_clusters == 2:
        new_proba_classes = stmod.number_input(
            "Probability of class 1",
            step=10 ** (np.floor(np.log10(proba_classes))),
            value=proba_classes,
            format="%.4f",
        )
    else:
        new_proba_classes_st = stmod.text_input(
            "Weight or probability of each class (will be normalized to sum to 1)",
            proba_classes,
        )
        new_proba_classes = json.loads(new_proba_classes_st)
        nb_probas = len(new_proba_classes)
        if nb_probas == n_clusters:
            stmod.markdown(
                ":heavy_check_mark: Number of class probabilities matches the number of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: number of class probabilities doesn't match the number of clusters."
            )
    toy_params["cluster"]["proba_classes"] = new_proba_classes

    # target shift
    default_cond_shift = default_params["shift"]["data_shift"]
    if n_clusters == 2:
        cond_shift = stmod.checkbox(
            "Class-conditional shift", value="cond" in default_cond_shift
        )
    else:
        cond_shift = False
    if cond_shift:
        rotation0 = default_params["shift"]["re"]
        if isinstance(rotation0, float):
            default_r0 = rotation0
            default_r1 = rotation0
        else:
            default_r0 = default_params["shift"]["re"][0]
            default_r1 = default_params["shift"]["re"][1]
        re0 = stmod.slider(
            "Rotation class 0",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_r0,
        )
        re1 = stmod.slider(
            "Rotation class 1",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_r1,
        )
        transl0 = default_params["shift"]["te"]
        if isinstance(transl0, float):
            default_t0 = transl0
            default_t1 = transl0
        else:
            default_t0 = default_params["shift"]["te"][0]
            default_t1 = default_params["shift"]["te"][1]
        te0 = stmod.slider(
            "Translation class 0",
            min_value=-3.0,
            max_value=3.0,
            value=default_t0,
        )
        te1 = stmod.slider(
            "Translation class 1",
            min_value=-3.0,
            max_value=3.0,
            value=default_t1,
        )
        toy_params["shift"]["re"] = [re0, re1]
        toy_params["shift"]["te"] = [te0, te1]
    else:
        re = stmod.slider(
            "Rotation",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_params["shift"]["re"],
        )
        te = stmod.slider(
            "Translation",
            min_value=-3.0,
            max_value=3.0,
            value=default_params["shift"]["te"],
        )
        toy_params["shift"]["re"] = re
        toy_params["shift"]["te"] = te

    test_view_data(toy_params)

    # choose a new (unique) name for the dataset and save
    data_hash = xp.param_to_hash(toy_params)
    default_hash = xp.param_to_hash(default_params)
    default_name = toy_params["dataset_name"]
    if default_hash != data_hash:
        if toy_params["dataset_name"] == default_params["dataset_name"]:
            default_name = data_hash
    data_name = st.text_input("Choose a (unique) name for your dataset", default_name)
    data_name = data_name.replace(" ", "_")
    toy_params["dataset_name"] = data_name
    data_file = os.path.join(default_dir, f"{data_name}.json")
    if os.path.exists(data_file):
        st.text(f"A dataset with this name already exists: {data_file}")
    else:
        if st.button("Save dataset"):
            with open(data_file, "w") as fd:
                fd.write(json.dumps(toy_params))
            default_params = deepcopy(toy_params)
            st.text(f"Configuration saved to {data_file}")
    return toy_params, data_hash
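
# Usage sketch (illustrative, not part of the original module).
# "configs/datasets" is a hypothetical directory of dataset JSON files, each
# carrying a "dataset_group" key; only files tagged "toy" are offered in the
# selectbox above. Note the asymmetric return: with on_sidebar=True only the
# selected file's defaults come back, otherwise a (params, hash) pair:
#
#     toy_params, data_hash = configure_dataset("configs/datasets", on_sidebar=False)
#     st.write(f"Dataset hash: {data_hash}")
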
def test_view_data(data_params):
    from ada.utils.plotting import colored_scattered_plot2x2
    from ada.datasets.toys import CausalBlobs, get_datashift_params

    target_shift = get_datashift_params(**data_params["shift"])
    source_data = CausalBlobs(
        ".tmp_view",
        n_samples=data_params["n_samples"],
        transform=get_datashift_params("no_shift"),
        cluster_params=data_params["cluster"],
    )
    target_data = CausalBlobs(
        ".tmp_view",
        n_samples=data_params["n_samples"],
        transform=target_shift,
        cluster_params=data_params["cluster"],
    )
    X_s, y_s = source_data.data, source_data.targets
    X_t, y_t = target_data.data, target_data.targets
    fig, ax = colored_scattered_plot2x2(X_s, X_t, y_s, y_t, set_aspect_equal=True)
    fig.set_tight_layout(tight=None)
    st.write("View data")
    st.pyplot(fig)
    source_data.delete_from_disk()
    target_data.delete_from_disk()
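
# Usage sketch (illustrative, not part of the original module). test_view_data
# expects the same dict shape that configure_dataset returns: it materializes
# source and target CausalBlobs datasets under ".tmp_view", renders a scatter
# plot of both, then deletes the temporary files:
#
#     test_view_data(toy_params)  # toy_params as returned by configure_dataset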