import json
import numpy as np
import pprint

def create_config(ds_src, ds_tgt, class_name,
                  maxiters=1000, tol=1e-6, lr=1, eps=1):
    """Build the config dict for one (source, target, class) experiment.

    Args:
        ds_src: Source dataset name. It is interpolated into the source glob
            path and compared against ``'stl'`` to set ``normalize_source``.
        ds_tgt: Target dataset name. It is interpolated into the target path
            and compared against ``'stl'`` to set ``normalize_target``.
        class_name: Target class identifier, interpolated into the target
            encoding filename ``encoding_{class_name}.npy``.
        maxiters, tol, lr, eps: Copied verbatim into ``solver_args``.

    Returns:
        A new JSON-serializable dict. Every call builds fresh nested lists
        and dicts, so callers may mutate the result safely.
    """
    source_glob = "data_nlp/sources/{}/encoding_*.npy".format(ds_src)
    target_file = "data_nlp/targets/{}/encoding_{}.npy".format(ds_tgt, class_name)

    solver_args = dict(maxiters=maxiters, tol=tol, eps=eps, lr=lr)

    # NOTE(review): none of the dataset names used by this file's __main__
    # sweep is 'stl', so both normalize flags come out False there --
    # confirm this comparison is intended.
    normalize_source = (ds_src == 'stl')
    normalize_target = (ds_tgt == 'stl')

    config = {}
    config["base_dir"] = ""
    config["sources"] = [source_glob]
    config["targets"] = [target_file]
    config["target_dist"] = [1]
    config["solver"] = "simplex"
    config["solver_args"] = solver_args
    config["normalize_source"] = normalize_source
    config["normalize_target"] = normalize_target
    config["stat_kwargs"] = {"alphas": [0.1]}
    return config
# Number of classes in each target dataset. The __main__ sweep emits one
# config per class index in range(nclasses[ds_tgt]).
nclasses = dict(
    sst=2,
    emoji=20,
    emotion=4,
    yelp=5,
    dydae=7,
)

if __name__ == "__main__":
    import os

    # Sweep every (source, target, target-class, epsilon) combination and
    # write one JSON config file per combination.
    SOURCES = TARGETS = ('sst', 'emoji', 'emotion', 'yelp', 'dydae')
    EPSILONS = [10]
    # Alternative epsilon sweep:
    # EPSILONS = [1, 0.1, 0.01, 0.001, 0.0001]

    OUT_DIR = 'configs'
    # open(..., 'w') does not create missing parent directories, so make
    # sure the output directory exists before writing any file.
    os.makedirs(OUT_DIR, exist_ok=True)

    for ds_src in SOURCES:
        for ds_tgt in TARGETS:
            for tgt in range(nclasses[ds_tgt]):
                print(f'{ds_src}_{ds_tgt}_class_encoding_{tgt}')
                for eps in EPSILONS:
                    d = create_config(ds_src, ds_tgt, tgt, eps=eps)
                    fname = f'{ds_src}_{ds_tgt}_class_encoding_{tgt}_eps_{eps}.json'
                    path = os.path.join(OUT_DIR, fname)
                    with open(path, 'w', encoding='utf-8') as f:
                        json.dump(d, f, indent=4, sort_keys=True)
