import json
from experiment import Experiment
from interval_models import MDPSpec
from datetime import datetime
from config.base import base_config as cfg_standards
import sys

# Benchmark configs to run, as (config file stem, large flag) pairs.
# When the flag is set, "-large" is appended to the experiment's config name.
filenames = [
    ('avoid', True),
    ('intercept', True),
    ('collision', False),
    ('evade', True),
]

# Leading component of every experiment name built by this script.
GLOBAL_TITLE = "NeurIPS"

# Dry runs use only two repetitions so the whole pipeline can be smoke-tested.
DRY_RUN = False
NUM_RUNS = 20 if not DRY_RUN else 2

# Forwarded as the first argument of Experiment.execute.
MULTITHREAD = True

# Pool of candidate seeds; each experiment uses the first NUM_RUNS of them.
all_seeds = [*range(100)]
run_seeds = all_seeds[:NUM_RUNS] # TODO: FIX THIS IF ADDING RUNS WITH OTHER SEEDS

# Main sweep: for every learning method and every benchmark config file, build
# the run configuration and launch one Experiment per policy setting.
for method in ['qrnn', 'kmeans']:
    for filename, large in filenames:
        temp = 1
        with open('data/input/cfgs/'+filename+'.json') as f:
            load_file = json.load(f)

        # Each config file must hold exactly one named entry, and that entry
        # exactly one config dict; anything else is treated as a setup error.
        assert len(load_file) == 1, load_file

        for name in load_file:
            assert len(load_file[name]) == 1, load_file[name]
            for cfg in load_file[name]:
                # Shared base settings override whatever the file specified.
                cfg.update(cfg_standards)

                if large:
                    cfg['name'] += '-large'

                # The collision benchmark uses its own "sl" range; every other
                # benchmark shares the [0.1, 0.4] range.
                if 'collision' in cfg['name'].lower():
                    cfg.update({
                            # All evaluation points lie inside p_bounds below.
                            "p_evals": [{"sl": 0.6}, {"sl": 0.65}, {"sl": 0.75}, {"sl": 0.8}], # for collision.prism
                            "p_bounds":{"sl":[0.6,0.8]}, # for collision.prism
                            # One initial value (the midpoint of the bounds) per run.
                            "p_init": [{"sl": (0.6 + 0.8) / 2} for _ in range(NUM_RUNS)],
                    })
                else:
                    cfg.update({
                            "p_evals": [{"sl": 0.1}, {"sl": 0.2}, {"sl": 0.3}, {"sl": 0.4}],
                            "p_bounds":{"sl":[0.1,0.4]},
                            # One initial value (the midpoint of the bounds) per run.
                            "p_init": [{"sl": (0.1 + 0.4) / 2} for _ in range(NUM_RUNS)],
                    })

                # Shrink the workload so a dry run finishes quickly.
                if DRY_RUN:
                    cfg['batch_dim'] = 4
                    cfg['rounds'] = 2

                # Extra optimiser settings applied only for the 'qrnn' method.
                if method == 'qrnn':
                    a = {
                        "weight_decay" : 0.004,
                        "clipnorm" : 0.1,
                        'clipvalue' : 0.1,
                    }
                    cfg.update(a)

                # for loss in ['cce', 'kld']:
                loss = 'cce'

                cfg['method'] = method

                for pi_setting in [{'policy' : 'qmdp'}]:
                # for pi_setting in [{'policy' : 'qmdp'}, {'policy' : 'fib'}]:

                    cfg.update(pi_setting)

                    cfg['a_loss'] = loss
                    cfg['temperature'] = temp

                    # base ** bottleneck_dim, with base 3 for 'tern'
                    # quantization and base 2 otherwise; used only to label
                    # the experiment.
                    max_k = (3 if cfg['quantization'].lower() == 'tern' else 2)**cfg['bottleneck_dim']

                    # Normal run
                    cfg['dynamic_uncertainty'] = True
                    spec = MDPSpec.Rminmax
                    cfg['specification'] = spec.value

                    # Experiment name, assembled piecewise for readability:
                    # title[-DRY-RUN]/method/train mode/maxk=K/loss/name-spec-policy
                    run_tag = "-DRY-RUN" if DRY_RUN else ""
                    train_mode = "train_deterministic" if cfg["train_deterministic"] else "train_stochastic"
                    exp_name = (
                        f'{GLOBAL_TITLE}{run_tag}/{method}/{train_mode}'
                        f'/maxk={max_k}/{loss}'
                        f'/{cfg["name"]}-{str(spec.name)}-{pi_setting["policy"]}'
                    )
                    exp = Experiment(exp_name, cfg, NUM_RUNS)
                    try:
                        exp.execute(MULTITHREAD, run_seeds)
                    except Exception as e:
                        # Best effort: one failing experiment must not abort
                        # the rest of the sweep ...
                        print("Run failed: ", e)
                        # ... except in a dry run, whose purpose is to surface
                        # failures; re-raise the active exception as-is.
                        if DRY_RUN:
                            raise
