import json
import numpy as np
import pprint

def get_classes(ds):
    """Return the class names used for dataset ``ds``.

    Args:
        ds: Dataset key: ``'pets'``, ``'cifar10'``, ``'3db'``, ``'in'`` or
            ``'stl'``.

    Returns:
        A new list of class-name strings. ``'cifar10'``, ``'3db'`` and
        ``'in'`` share the same ten names; ``'stl'`` uses nine of them
        (no ``'frog'``).

    Raises:
        ValueError: If ``ds`` is not one of the known dataset keys.
    """
    # Exact comparisons: ``ds in ('pets')`` would be a *substring* test,
    # since parentheses without a trailing comma do not form a tuple
    # (so 'pet', 's' or '' would wrongly match).
    if ds == 'pets':
        return [
            "cat",
            "dog",
        ]
    elif ds in ('cifar10', '3db', 'in'):
        return [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]
    elif ds == 'stl':
        # 'monkey' is deliberately not listed (presumably because it has no
        # counterpart in the other datasets' class lists -- confirm).
        return [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "horse",
            "ship",
            "truck",
        ]
    # Consistent with get_sources/get_targets: fail loudly instead of
    # returning None, which would only crash later at the caller.
    raise ValueError('Unknown dataset')

def get_sources(ds, class_name):
    """Return glob patterns for the source ``.npy`` files of a dataset/class.

    Args:
        ds: Source dataset key: ``'in'``, ``'cifar10'``, ``'pets'``,
            ``'3db'`` or ``'stl'``.
        class_name: Class whose files are wanted (unused for ``'stl'``).

    Returns:
        A one-element list holding a glob pattern under
        ``data/sources/<ds>``.

    Raises:
        ValueError: If ``ds`` is not a known source dataset.
    """
    root = f"data/sources/{ds}"

    # Ordered (dataset key, path template) pairs; first match wins.
    # NOTE(review): the 'stl' template does not include the class name, so
    # every class shares the same source files -- confirm this is intended.
    layouts = (
        ('in', "{root}/{cls}/*.npy"),
        ('cifar10', "{root}/{cls}/cluster_*.npy"),
        ('pets', "{root}/{cls}/{cls}-*.npy"),
        ('3db', "{root}/{cls}/h*_z*/env*/model*.npy"),
        ('stl', "{root}/cluster_*.npy"),
    )
    for key, template in layouts:
        if ds == key:
            return [template.format(root=root, cls=class_name)]

    raise ValueError('Unknown source')

def get_targets(ds, class_name):
    """Return the target ``.npy`` path for ``class_name`` in dataset ``ds``.

    Args:
        ds: Target dataset key: ``'cifar10'``, ``'pets'``, ``'3db'`` or
            ``'stl'``.
        class_name: Class whose target file is wanted.

    Returns:
        A one-element list ``['data/targets/<ds>/<class_name>/train.npy']``.

    Raises:
        ValueError: If ``ds`` is not a supported target dataset.
    """
    supported = ('cifar10', 'pets', '3db', 'stl')
    if ds not in supported:
        raise ValueError('Unknown target')

    target_path = "data/targets/{}/{}/train.npy".format(ds, class_name)
    return [target_path]

def create_config(ds_src, ds_tgt, class_name,
                  maxiters=1000, tol=1e-6, lr=1, eps=1):
    """Build a solver configuration dict for one source/target/class triple.

    Args:
        ds_src: Source dataset key (passed to ``get_sources``).
        ds_tgt: Target dataset key (passed to ``get_targets``).
        class_name: Class to build the config for.
        maxiters, tol, lr, eps: Copied verbatim into ``solver_args``.

    Returns:
        A JSON-serialisable dict using the ``"simplex"`` solver, with a single
        target weighted ``[1]``. ``normalize_source`` / ``normalize_target``
        are True exactly when the respective dataset is ``'stl'``.
    """
    source_paths = get_sources(ds_src, class_name)
    target_paths = get_targets(ds_tgt, class_name)

    # target_dist below is hard-coded to a single weight, so exactly one
    # target file is required.
    assert len(target_paths) == 1

    solver_args = dict(maxiters=maxiters, tol=tol, eps=eps, lr=lr)

    return dict(
        base_dir="",
        sources=source_paths,
        targets=target_paths,
        target_dist=[1],
        solver="simplex",
        solver_args=solver_args,
        normalize_source=(ds_src == 'stl'),
        normalize_target=(ds_tgt == 'stl'),
    )


if __name__ == "__main__":
    # Write one JSON config per (source dataset, target dataset, shared
    # class, epsilon) combination into the configs/ directory.
    import os

    SOURCES = ('in', 'cifar10', 'pets', 'stl', '3db')
    TARGETS = ('cifar10', 'pets', 'stl', '3db')
    EPSILONS = [1, 0.1, 0.01, 0.001, 0.0001]
    OUT_DIR = 'configs'

    # open(..., 'w') raises FileNotFoundError if the directory is missing.
    os.makedirs(OUT_DIR, exist_ok=True)

    for ds_src in SOURCES:
        # Source classes do not depend on the target; compute once per source.
        src_classes = get_classes(ds_src)
        for ds_tgt in TARGETS:
            tgt_classes = get_classes(ds_tgt)

            # Only classes present in both datasets get configs (target order kept).
            classes = [c for c in tgt_classes if c in src_classes]
            for tgt in classes:
                print(f'{ds_src}_{ds_tgt}_class_{tgt}')
                for eps in EPSILONS:
                    d = create_config(ds_src, ds_tgt, tgt, eps=eps)
                    out_path = f'{OUT_DIR}/{ds_src}_{ds_tgt}_class_{tgt}_eps_{eps}.json'
                    with open(out_path, 'w', encoding='utf-8') as f:
                        json.dump(d, f, indent=4, sort_keys=True)
