# -*- coding: utf-8 -*-
import hydra
from omegaconf import OmegaConf
import trainer
import MAP
import os
import tracemalloc
import policy_initial
from omegaconf import DictConfig
import gen_circle
import warnings

# Suppress every UserWarning process-wide (presumably noisy library warnings).
# Note this also hides UserWarnings raised by this project's own code.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Start tracing Python memory allocations as soon as this module is imported.
# NOTE(review): no tracemalloc snapshot is taken in this file; presumably it is
# inspected elsewhere (e.g. in trainer) — confirm, otherwise it only adds overhead.
tracemalloc.start()


# Values of ``cfg.sample_type`` to sweep over.
_SAMPLE_TYPES = (9,)

# Episode presets (values of ``cfg.d``) to sweep over.
_EPISODE_PRESET_IDS = (0,)

# Hyper-parameters written into the config for each episode preset ``cfg.d``.
# Keys ending in ``_knn`` map to ``cfg.cb.knn`` / ``cfg.cic.knn``.
_EPISODE_PRESETS = {
    # short episode
    1: {
        "ep_len": 5,
        "sk_num": 16,
        "index_dim": 4,
        "cont_lr": 1e-4,
        "agent_lr": 1e-4,
        "cb_knn": 40,
        "cic_knn": 40,
        "train_iter": 2000,
    },
    # long episode
    0: {
        "ep_len": 10,
        "sk_num": 32,
        "index_dim": 5,
        "cont_lr": 3e-5,
        "agent_lr": 3e-5,
        "cb_knn": 200,
        "cic_knn": 200,
        "train_iter": 10000,
    },
}

# Algorithms trained, in this order, for every sample_type / preset combination.
_ALGORITHMS = ("CB", "cic", "diayn", "dads", "large_cb")


def _apply_episode_preset(cfg: DictConfig) -> None:
    """Write the hyper-parameters of episode preset ``cfg.d`` into *cfg* in place.

    Raises:
        ValueError: if ``cfg.d`` is not a key of ``_EPISODE_PRESETS``. Failing
            here prevents policy initialisation and training from running with
            unset or stale hyper-parameters.
    """
    try:
        preset = _EPISODE_PRESETS[cfg.d]
    except KeyError:
        raise ValueError(
            f"unsupported episode preset d={cfg.d!r}; "
            f"expected one of {sorted(_EPISODE_PRESETS)}"
        ) from None

    cfg.ep_len = preset["ep_len"]
    cfg.sk_num = preset["sk_num"]
    cfg.index_dim = preset["index_dim"]
    cfg.cont_lr = preset["cont_lr"]
    cfg.agent_lr = preset["agent_lr"]
    cfg.cb.knn = preset["cb_knn"]
    cfg.cic.knn = preset["cic_knn"]
    cfg.train_iter = preset["train_iter"]


@hydra.main(config_path="configuration/.", config_name="config", version_base="1.1")
def modify_and_run(cfg: DictConfig):
    """Train each algorithm in ``_ALGORITHMS`` for every configured combination.

    Hydra loads ``configuration/config`` and passes it in as *cfg*. For each
    ``sample_type`` in ``_SAMPLE_TYPES`` and each preset ``d`` in
    ``_EPISODE_PRESET_IDS``, this function:

    1. writes the preset's hyper-parameters into *cfg*;
    2. calls ``policy_initial.policy_initialization`` once;
    3. trains each algorithm for ``cfg.train_iter`` iterations on a newly
       constructed ``MAP.Env``.

    *cfg* is mutated in place throughout.

    Raises:
        ValueError: if a preset id has no entry in ``_EPISODE_PRESETS``.
    """
    # gen_circle.gen()  # optional one-off step; uncomment to run it before training

    # With version_base 1.1, Hydra switches into its per-run output directory.
    # Return to the launch directory so relative paths resolve as expected.
    os.chdir(hydra.utils.get_original_cwd())
    # Disable struct mode so keys absent from the loaded config can be assigned.
    OmegaConf.set_struct(cfg, False)

    for sample_type in _SAMPLE_TYPES:
        cfg.sample_type = sample_type
        for d in _EPISODE_PRESET_IDS:
            cfg.d = d
            _apply_episode_preset(cfg)

            policy_initial.policy_initialization(cfg)

            for algorithm in _ALGORITHMS:
                cfg.algorithm = algorithm
                # A fresh environment per algorithm, presumably so that runs
                # do not share environment state — confirm against MAP.Env.
                env = MAP.Env(cfg.sample_type, 0)

                mult_sk_rl = trainer.Train(cfg, env)
                mult_sk_rl.train(cfg.train_iter)


if __name__ == "__main__":
    # The @hydra.main decorator parses the command line and supplies ``cfg``,
    # so the entry point is invoked without arguments.
    modify_and_run()
    