import json
from experiment import Experiment
from interval_models import MDPSpec
from datetime import datetime
from config.base import base_config as cfg_standards
import sys

# Command line: <cfg name (file under data/input/cfgs/, without .json)> <large: "true"/"false">
# Fail with a usage message instead of an IndexError when arguments are missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <cfg_name> <large: true|false>")

print(sys.argv[1], sys.argv[2], sys.argv[2].lower() == "true")

# List of (config file name, whether to run the "-large" variant of that config).
filenames = [(sys.argv[1], sys.argv[2].lower() == "true")]

GLOBAL_TITLE = "NeurIPS"

# Smoke-test switch: fewer runs here, and a reduced batch_dim/rounds in the main loop.
DRY_RUN = False

NUM_RUNS = 2 if DRY_RUN else 20
MULTITHREAD = True

all_seeds = list(range(100))

# TODO: FIX THIS IF ADDING RUNS WITH OTHER SEEDS
# Slicing would silently return fewer seeds than runs, so refuse instead.
if NUM_RUNS > len(all_seeds):
    raise ValueError(f"NUM_RUNS={NUM_RUNS} exceeds the {len(all_seeds)} available seeds")
run_seeds = all_seeds[:NUM_RUNS]

import traceback  # local import keeps this block self-contained

# For each requested config: load it, merge in the standard settings, attach the
# benchmark-specific parameter ranges, then run every (method, policy) combination
# twice: once with dynamic uncertainty and once as the static BASELINE.
for filename, large in filenames:
    # for temp in [0.25, 0.5, 0.75]:
    temp = 1
    with open('data/input/cfgs/' + filename + '.json', encoding='utf-8') as f:
        load_file = json.load(f)

    # Each config file is expected to contain exactly one named entry...
    assert len(load_file) == 1, load_file

    for name in load_file:
        # ...holding exactly one config dict.
        assert len(load_file[name]) == 1, load_file[name]
        for cfg in load_file[name]:
            # NOTE: keys present in cfg_standards overwrite same-named keys from the file.
            cfg.update(cfg_standards)

            if large:
                cfg['name'] += '-large'

            # Bounds and evaluation points of the uncertain parameter "sl" depend on the benchmark.
            if 'collision' in cfg['name'].lower():
                # for collision.prism; evaluation points must lie inside [sl_lo, sl_hi]
                sl_lo, sl_hi = 0.6, 0.8
                sl_evals = [0.6, 0.65, 0.75, 0.8]
            else:
                sl_lo, sl_hi = 0.1, 0.4
                sl_evals = [0.1, 0.2, 0.3, 0.4]
            cfg.update({
                "p_evals": [{"sl": v} for v in sl_evals],
                "p_bounds": {"sl": [sl_lo, sl_hi]},
                # every run starts from the midpoint of the bounds
                "p_init": [{"sl": (sl_lo + sl_hi) / 2} for _ in range(NUM_RUNS)],
            })

            if DRY_RUN:
                cfg['batch_dim'] = 4
                cfg['rounds'] = 2

            # for loss in ['cce', 'kld']:
            loss = 'cce'

            # for method in ['qbn', 'qrnn']:
            for method in ['qbn']:
                cfg['method'] = method

                for pi_setting in [{'policy': 'qmdp'}, {'policy': 'fib'}]:
                    cfg.update(pi_setting)
                    cfg['a_loss'] = loss
                    cfg['temperature'] = temp

                    # 3 values per bottleneck dim for 'tern' quantization, else 2
                    # (presumably the number of distinct quantized codes; only used in the output path).
                    max_k = (3 if cfg['quantization'].lower() == 'tern' else 2) ** cfg['bottleneck_dim']

                    spec = MDPSpec.Rminmax
                    train_mode = "train_deterministic" if cfg["train_deterministic"] else "train_stochastic"
                    exp_prefix = f'{GLOBAL_TITLE}{"-DRY-RUN" if DRY_RUN else ""}/{method}/{train_mode}/maxk={max_k}/{loss}'

                    # Normal run (dynamic uncertainty) first, then the static-uncertainty baseline.
                    for dynamic_uncertainty, tag in ((True, ''), (False, '-BASELINE')):
                        cfg['dynamic_uncertainty'] = dynamic_uncertainty
                        cfg['specification'] = spec.value
                        exp = Experiment(f'{exp_prefix}/{cfg["name"]}-{str(spec.name)}{tag}-{pi_setting["policy"]}', cfg, NUM_RUNS)
                        try:
                            exp.execute(MULTITHREAD, run_seeds)
                        except Exception as e:
                            # Keep sweeping on failure, but don't lose the traceback.
                            print("Run failed: ", e)
                            traceback.print_exc()
                            if DRY_RUN:
                                raise