from scripts_ import *
import argparse
import utils
import absl.app as app

import configs

from typing import Type, Dict, Any, Optional

# The random seed is the only setting read from the command line; every other
# experiment choice is made by editing this script.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=6)
args = parser.parse_args()

SEED = args.seed

# When CONTINUE_N_ITER is set, the experiments below overwrite
# ``trainargs.n_iter`` with it and every trainer receives ``_continue=True``.
CONTINUE_N_ITER: Optional[int] = None  # for debugging only
CONTINUE = CONTINUE_N_ITER is not None

# Presumably seeds the global RNGs used by training — see Train.seed.
Train.seed(SEED)

# approach name -> (Train subclass, Test subclass).
# 'cdl' and 'cdlattn' deliberately share the same CDL classes.
SCRIPTS = dict(
    gru=(TrainGRU, TestGRU),
    cdl=(TrainCDL, TestCDL),
    cdlattn=(TrainCDL, TestCDL),
    mlp=(TrainMLP, TestMLP),
    ooc=(TrainOOC, TestOOC),
    full=(TrainFull, TestFull),
    ticsa=(TrainTICSA, TestTICSA),
    gnn=(TrainGNN, TestGNN),
)

# Identifier handed to every trainer; it depends only on the seed.
RUN_ID = f"seed-{SEED}"


def get_trainer(approach: str, config: configs.Config) -> Train:
    """Construct (without running) the trainer for ``approach``.

    Every trainer receives ``config.env_id``, the module-level ``RUN_ID``,
    ``config.trainargs``, an approach-specific argument object taken from
    ``config``, ``env_options=config.env_option`` and ``_continue=CONTINUE``.

    Args:
        approach: one of 'ooc', 'full', 'mlp', 'cdl', 'cdlattn', 'gru',
            'ticsa' or 'gnn' (the keys of ``SCRIPTS``).
        config: experiment configuration the constructor arguments are
            read from.

    Returns:
        The newly constructed trainer; callers run it with ``.main()``.

    Raises:
        ValueError: if ``approach`` is not a recognised name. (This used to be
            ``assert False``, which ``python -O`` strips and which then
            surfaced as an ``UnboundLocalError``.)
    """
    if approach == 'ooc':
        return TrainOOC(
            config.env_id, RUN_ID, config.trainargs, config.ooc_args,
            p_edge=config.p_edge, causal_threshold=config.cmi_thres,
            n_iter_stable=config.ooc_n_iter_stable,
            env_options=config.env_option,
            _continue=CONTINUE)
    if approach == 'full':
        # 'full' reuses the OOC argument object.
        return TrainFull(
            config.env_id, RUN_ID, config.trainargs, config.ooc_args,
            env_options=config.env_option, _continue=CONTINUE)
    if approach == 'mlp':
        return TrainMLP(
            config.env_id, RUN_ID, config.trainargs, config.mlp_args,
            env_options=config.env_option, _continue=CONTINUE)
    if approach in ('cdl', 'cdlattn'):
        # Both names build the same trainer here; any difference between
        # them has to come from ``config.cdl_args`` (e.g. its ``kernel``).
        return TrainCDL(
            config.env_id, RUN_ID, config.trainargs, config.cdl_args,
            causal_threshold=config.cmi_thres,
            env_options=config.env_option, _continue=CONTINUE)
    if approach == 'gru':
        return TrainGRU(
            config.env_id, RUN_ID, config.trainargs, config.gru_args,
            causal_threshold=config.fcit_thres, n_job_fcit=config.njob_fcit,
            env_options=config.env_option, _continue=CONTINUE)
    if approach == 'ticsa':
        # NOTE(review): the config attribute is spelled ``tisca_args`` (not
        # ``ticsa_args``); kept as-is to match configs — confirm there.
        return TrainTICSA(
            config.env_id, RUN_ID, config.trainargs, config.tisca_args,
            norm_penalty=1.0,
            env_options=config.env_option, _continue=CONTINUE)
    if approach == 'gnn':
        return TrainGNN(
            config.env_id, RUN_ID, config.trainargs, config.gnn_args,
            env_options=config.env_option, _continue=CONTINUE)
    raise ValueError(
        f"unknown approach {approach!r}; "
        f"expected one of {', '.join(SCRIPTS)}")


def run_test(approach: str, path: str, config: configs.Config,
             label: str = 'test',
             env_options: Optional[Dict[str, Any]] = None) -> None:
    """Evaluate a trained run with the Test class registered for ``approach``.

    Args:
        approach: key into ``SCRIPTS``; its second entry is the Test class.
        path: location of the trained run (callers pass
            ``str(trainer.path)``).
        config: experiment configuration; only ``config.testargs`` is read.
        label: name given to this evaluation, e.g. 'test', 'test-new' or
            'test-unseen'; passed through to the Test constructor.
        env_options: extra environment options for the evaluation. Defaults
            to a fresh empty dict on every call.

    Raises:
        KeyError: if ``approach`` is not a key of ``SCRIPTS``.
    """
    if env_options is None:
        # Build a new dict per call: a shared ``{}`` default would carry any
        # mutation made by the tester into later evaluations.
        env_options = {}
    test_type: Type[Test] = SCRIPTS[approach][1]
    test_type(path, config.testargs, env_options, label).main()


def experiment_1(_):
    """Train, then evaluate, each selected approach on one environment.

    The environment and the approaches are chosen by commenting/uncommenting
    the lines below. Each approach is trained, then evaluated on the run at
    ``trainer.path`` under the label 'test'. For the 'block' and 'mouse'
    environments, it is additionally evaluated with ``ood=True`` under the
    label 'test-new'.

    Args:
        _: unused; the argument list that ``absl.app.run`` passes to main.
    """
    # ------ Select the environment here (keep exactly one line active) ------
    # cfg = configs.config_block(2)
    cfg = configs.config_block(5)
    # cfg = configs.config_block(10)
    # cfg = configs.config_mouse('444')
    # cfg = configs.CONFIG_CMS
    # cfg = configs.CONFIG_DZB

    if CONTINUE:
        assert CONTINUE_N_ITER is not None
        cfg.trainargs.n_iter = CONTINUE_N_ITER

    # CDL kernel: this line selects the attention kernel. To use the pooling
    # kernel instead, comment it out and uncomment the 'max' line below it.
    cfg.cdl_args.kernel = 'attn'
    # cfg.cdl_args.kernel = 'max'

    # ------ List the approaches to run here ------
    approaches = (
        'ooc',  # OOCDM
        # 'full',  # OOFULL
        # 'cdl',  # CDL
        # 'mlp',  # MLP
        # 'gru',  # FCIT+GRU
        # 'ticsa',  # TICSA
        # 'gnn',  # GNN
    )
    for approach in approaches:
        trainer = get_trainer(approach, cfg)
        # Comment out the next line to only evaluate already-trained models.
        trainer.main()
        run_path = str(trainer.path)
        # Drop the trainer before evaluation (presumably to free its memory).
        del trainer
        run_test(approach, run_path, cfg, 'test', {})
        if cfg.env_id in ('block', 'mouse'):
            run_test(approach, run_path, cfg, 'test-new', {'ood': True})
    

def experiment_2(_):
    """Seen/unseen-task experiment on the mouse environment for 'ooc' and 'full'.

    For each approach, it starts from a fresh ``configs.config_mouse('seen')``
    and runs two phases:
      1. train with that configuration, then evaluate under 'test' (no extra
         options) and under 'test-unseen' with ``task='unseen'``;
      2. set ``env_option['task'] = 'unseen'`` on the same config, retrain,
         and evaluate under 'test'.

    Args:
        _: unused; the argument list that ``absl.app.run`` passes to main.
    """
    def train_and_locate(name, cfg):
        # Train one approach and return where its run was written; the
        # trainer object is released when this helper returns.
        trainer = get_trainer(name, cfg)
        trainer.main()
        return str(trainer.path)

    for approach in ("ooc", "full"):
        # A fresh config per approach, so the in-place task change made in
        # phase 2 does not carry over to the next approach.
        cfg = configs.config_mouse('seen')
        if CONTINUE:
            assert CONTINUE_N_ITER is not None
            cfg.trainargs.n_iter = CONTINUE_N_ITER

        # Phase 1: train on the default task, evaluate on both tasks.
        run_path = train_and_locate(approach, cfg)
        run_test(approach, run_path, cfg, 'test', {})
        run_test(approach, run_path, cfg, 'test-unseen', {'task': 'unseen'})

        # Phase 2: retrain on the unseen task and evaluate.
        # NOTE(review): both phases use the same RUN_ID; confirm trainer.path
        # differs by task so phase 2 does not overwrite phase 1.
        cfg.env_option['task'] = 'unseen'
        run_path = train_and_locate(approach, cfg)
        run_test(approach, run_path, cfg, 'test', {})


if __name__ == "__main__":
    # ------ Select the experiment here ------
    _experiment = experiment_1
    # _experiment = experiment_2

    # Only a placeholder argv is handed to absl, presumably so it does not
    # re-parse the real command line, whose --seed flag argparse has already
    # consumed at import time.
    app.run(_experiment, argv=['_'])
