import json
from experiment import Experiment
from interval_models import MDPSpec
from datetime import datetime
from config.base import base_config as cfg_standards
import sys

# Config file stems to run, each paired with a flag selecting its "-large" variant.
filenames = [
    ("collision", False),
    ("evade", True),
]

GLOBAL_TITLE = "NeurIPS"

# Quick smoke-test switch; among other things it cuts the number of runs to 2.
DRY_RUN = False

NUM_RUNS = 20 if not DRY_RUN else 2
MULTITHREAD = True

# Pool of available seeds; experiments use the first NUM_RUNS of them.
all_seeds = [*range(100)]

# TODO: revisit this selection if runs with other seeds are ever added.
run_seeds = all_seeds[0:NUM_RUNS]

def _sl_setting(lo, hi, evals):
    """Return the ``p_evals``/``p_bounds``/``p_init`` config entries for the "sl" parameter.

    Args:
        lo: lower bound of the "sl" interval.
        hi: upper bound of the "sl" interval.
        evals: values of "sl" at which to evaluate.

    Returns:
        A dict to merge into an experiment config. ``p_init`` holds one
        midpoint initialisation ``(lo + hi) / 2`` per run (``NUM_RUNS`` entries).

    Raises:
        ValueError: if an evaluation point lies outside ``[lo, hi]`` (guards
            against typos such as ``75`` for ``0.75``).
    """
    out_of_bounds = [v for v in evals if not lo <= v <= hi]
    if out_of_bounds:
        raise ValueError(f"p_evals {out_of_bounds} outside p_bounds [{lo}, {hi}]")
    return {
        "p_evals": [{"sl": v} for v in evals],
        "p_bounds": {"sl": [lo, hi]},
        "p_init": [{"sl": (lo + hi) / 2} for _ in range(NUM_RUNS)],
    }


def _max_k(cfg):
    """Return (3 if quantization is 'tern' else 2) ** bottleneck_dim for *cfg*."""
    levels = 3 if cfg['quantization'].lower() == 'tern' else 2
    return levels ** cfg['bottleneck_dim']


def _run_experiment(cfg, method, loss, policy, train_dir):
    """Build and execute one Rminmax experiment for *cfg*.

    Sets ``dynamic_uncertainty`` and ``specification`` on *cfg* in place, then
    names the experiment
    ``<title>/<method>/<train_dir>/maxk=<k>/<loss>/<cfg name>-Rminmax-<policy>``.
    Exceptions raised by ``execute`` are printed and the sweep continues,
    except in a dry run, where they are re-raised.
    """
    cfg['dynamic_uncertainty'] = True
    spec = MDPSpec.Rminmax
    cfg['specification'] = spec.value
    title = GLOBAL_TITLE + ("-DRY-RUN" if DRY_RUN else "")
    path = (f'{title}/{method}/{train_dir}/maxk={_max_k(cfg)}/{loss}/'
            f'{cfg["name"]}-{spec.name}-{policy}')
    exp = Experiment(path, cfg, NUM_RUNS)
    try:
        exp.execute(MULTITHREAD, run_seeds)
    except Exception as e:
        print("Run failed: ", e)
        if DRY_RUN:
            raise


# Bottleneck (memory) configurations swept in stage 1 for every environment.
memories = [
    {"quantization": "tern", "bottleneck_dim": 1},
    {"quantization": "sign", "bottleneck_dim": 2},
    {"quantization": "sign", "bottleneck_dim": 3},
    {"quantization": "sign", "bottleneck_dim": 4},
]
methods = ['qbn']    # also explored: 'qrnn'
policies = ['qmdp']  # also explored: 'fib'
loss = 'cce'         # also explored: 'kld'
temp = 1             # stage-1 temperature; also explored: 0.25, 0.5, 0.75
stochastic_taus = [-1, 0.5]  # temperatures for the stage-2 stochastic runs

for filename, large in filenames:
    with open(f'data/input/cfgs/{filename}.json', encoding='utf-8') as f:
        load_file = json.load(f)

    # Each config file must hold exactly one named entry with exactly one config.
    assert len(load_file) == 1, load_file

    for name in load_file:
        assert len(load_file[name]) == 1, load_file[name]
        for cfg in load_file[name]:
            cfg.update(cfg_standards)

            if large:
                cfg['name'] += '-large'
            if 'collision' in cfg['name'].lower():
                # for collision.prism
                cfg.update(_sl_setting(0.6, 0.8, [0.6, 0.65, 0.75, 0.8]))
            else:
                cfg.update(_sl_setting(0.1, 0.4, [0.1, 0.2, 0.3, 0.4]))

            if DRY_RUN:
                cfg['batch_dim'] = 4
                cfg['rounds'] = 2

            for method in methods:
                # Stage 1: sweep bottleneck settings at the base temperature.
                for memory_setting in memories:
                    cfg.update(memory_setting)
                    cfg['method'] = method
                    for policy in policies:
                        cfg['policy'] = policy
                        cfg['a_loss'] = loss
                        cfg['temperature'] = temp
                        train_dir = ("train_deterministic" if cfg["train_deterministic"]
                                     else "train_stochastic")
                        _run_experiment(cfg, method, loss, policy, train_dir)

                # Stage 2: stochastic training with a tern/2 bottleneck at several
                # temperatures. A copy is used so these overrides (notably
                # train_deterministic=False) do not leak into later methods.
                stochastic_policy = policies[-1]
                stochastic_cfg = {
                    **cfg,
                    "quantization": "tern",
                    "bottleneck_dim": 2,
                    "train_deterministic": False,
                    "policy": stochastic_policy,
                }
                for tau in stochastic_taus:
                    stochastic_cfg['temperature'] = tau
                    _run_experiment(stochastic_cfg, method, loss, stochastic_policy,
                                    f"train_stochastic/tau={tau}")
