from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *
from counterfactual.methods.uncertain import UncertainCF


# Experiment driver: for every dataset listed in `configs`, load the pretrained
# baseline predictor, fit the enabled counterfactual-explanation methods against
# it, and pass their outputs to `evaluate`. Methods other than CounteRGAN are
# kept below as commented-out toggles; uncomment one to include it in a run.
if __name__ == "__main__" and not in_jupyter():
    # Load every config via `load_json` so file handles are always closed
    # (the previous `json.load(open(...))` calls leaked them).
    adult_config = load_json("counterfactual/configs/adult.json")
    student_config = load_json("counterfactual/configs/student.json")
    home_config = load_json("counterfactual/configs/home.json")
    breast_config = load_json("counterfactual/configs/extra/breast_cancer.json")
    student_performance_config = load_json("counterfactual/configs/extra/student_performance.json")
    titanic_config = load_json("counterfactual/configs/extra/titanic.json")
    credit_card_config = load_json("counterfactual/configs/extra/credit_card.json")
    german_credit_config = load_json("counterfactual/configs/extra/german_credit.json")

    # Trainer settings shared by every `train(...)` call. NOTE: this dict is
    # mutated per method below (e.g. `max_epochs`), so each method section must
    # set every trainer key it relies on rather than inheriting a prior value.
    t_config = load_json("counterfactual/configs/trainer.json")

    # One entry per dataset: checkpoint paths and epoch numbers for the
    # CounterfactualNet ("c_net") model and the baseline predictor, plus the
    # dataset's model/data config.
    configs = [
        {
            "data_name": "adult",
            "c_net_path": "saved_weights/adult/c_net/epoch=126-step=24256.ckpt",
            "c_net_epoch": 127,
            "baseline_path": "saved_weights/adult/baseline/epoch=55-step=10695.ckpt",
            "baseline_epoch": 56,
            "config": adult_config
        },
        {
            "data_name": "student",
            "c_net_path": "saved_weights/student/c_net/epoch=115-step=22155.ckpt",
            "c_net_epoch": 116,
            "baseline_path": "saved_weights/student/baseline/epoch=98-step=18908.ckpt",
            "baseline_epoch": 99,
            "config": student_config
        },
        {
            "data_name": "home",
            "c_net_path": "saved_weights/home/c_net/epoch=564-step=35029.ckpt",
            "c_net_epoch": 565,
            "baseline_path": "saved_weights/home/baseline/epoch=92-step=5765.ckpt",
            "baseline_epoch": 93,
            "config": home_config
        },
        {
            "data_name": "breast",
            "c_net_path": "saved_weights/extra/breast/c_net/epoch=440-step=1763.ckpt",
            "c_net_epoch": 441,
            "baseline_path": "saved_weights/extra/breast/baseline/epoch=383-step=1535.ckpt",
            "baseline_epoch": 384,
            "config": breast_config
        },
        {
            "data_name": "student_performance",
            "c_net_path": "saved_weights/extra/student_performance/c_net/epoch=451-step=1807.ckpt",
            "c_net_epoch": 452,
            "baseline_path": "saved_weights/extra/student_performance/baseline/epoch=287-step=1151.ckpt",
            "baseline_epoch": 288,
            "config": student_performance_config
        },
        {
            "data_name": "titanic",
            "c_net_path": "saved_weights/extra/titanic/c_net/epoch=61-step=371.ckpt",
            "c_net_epoch": 62,
            "baseline_path": "saved_weights/extra/titanic/baseline/epoch=63-step=383.ckpt",
            "baseline_epoch": 64,
            "config": titanic_config
        },
        {
            "data_name": "credit",
            "c_net_path": "saved_weights/extra/credit/c_net/epoch=90-step=2001.ckpt",
            "c_net_epoch": 91,
            "baseline_path": "saved_weights/extra/credit/baseline/epoch=361-step=7963.ckpt",
            "baseline_epoch": 362,
            "config": credit_card_config
        },
        {
            "data_name": "german",
            "c_net_path": "saved_weights/extra/german/c_net/epoch=21-step=131.ckpt",
            "c_net_epoch": 22,
            "baseline_path": "saved_weights/extra/german/baseline/epoch=19-step=119.ckpt",
            "baseline_epoch": 20,
            "config": german_credit_config
        },
    ]

    # `None` means "do not call seed_everything", i.e. an unseeded run.
    # seeds = [0, 21, 113, 13]
    seeds = [None]

    # perf_counter is monotonic, so it is the right clock for elapsed time.
    start_time = time.perf_counter()

    for seed in seeds:
        if seed is not None:
            seed_everything(seed=seed)
        for config in configs:
            data_name = config['data_name']

            # print("dealing ", data_name)
            # model = load_model(config['c_net_path'], config['c_net_epoch'], module=CounterfactualModel2Optimizers)
            # result = model_cf_gen(model)
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="CounterfactualNet", seed=seed)

            # load baseline model
            model = load_model(config['baseline_path'], config['baseline_epoch'])

            # CounteRGAN: train one generator module per target class, then
            # combine them. Training order (1. first, then 0.) matches the
            # original script so any RNG consumption happens in the same order.
            t_config['max_epochs'] = 50
            t_config['max_steps'] = 2000
            cfgan_modules = {}
            for target_class in (1., 0.):
                cf_result = train(
                    CounteRGANTrainingModule(config['config'], model=model, target_class=target_class),
                    t_config,
                    logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{data_name}/cfgan")
                )
                cfgan_modules[target_class] = cf_result['module']
            cf_method = CounteRGAN(cfgan_modules[0.], cfgan_modules[1.])
            result = model_cf_gen(cf_method)
            result["cat_idx"] = len(model.continous_cols)
            evaluate(result, dataset_name=data_name, cf_name="CounteRGAN", seed=seed)

            # UncertainCF
            # cf_method = UncertainCF(config['config'], model)
            # result = model_cf_gen(cf_method)
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="UncertainCF", seed=seed)

            # VAE-CF
            # config['config']['validity_reg'] = 0.2
            # t_config['max_epochs'] = 5
            # cf_result = train(
            #     VAE_CF(config['config'], model=model),
            #     t_config,
            #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{data_name}/vae")
            # )

            # result = model_cf_gen(cf_result['module'])
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="VAE-CF", seed=seed)

            # C-CHVAE
            # config['config']['batch_size'] = 32
            # t_config['max_epochs'] = 5
            # cf_result = train(
            #     CCHVAE(config['config'], model=model),
            #     t_config,
            #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{data_name}/ccvae")
            # )

            # result = model_cf_gen(cf_result['module'], is_parallel=False)
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="C-CHVAE", seed=seed)

            # VanillaCF
            # result = cf_gen_parallel(CFExplainer=VanillaCF, cf_params={'model': model})
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="VanillaCF", seed=seed)

            # DiverseCF
            # result = cf_gen_parallel(CFExplainer=DiverseCF, cf_params={'model': model})
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="DiverseCF", seed=seed)

            # train AE first (required by ProtoCF below)
            # t_config['max_epochs'] = 10
            # t_config['gpus'] = 0
            # result = train(AE(config['config']), t_config)
            # ae = result['module']

            # # ProtoCF
            # result = cf_gen_parallel(CFExplainer=ProtoCF, cf_params={
            #     'model': model, 'ae': ae, 'train_loader': DataLoader(model.train_dataset, batch_size=128, shuffle=True)
            # })
            # result["cat_idx"] = len(model.continous_cols)
            # evaluate(result, dataset_name=data_name, cf_name="ProtoCF", seed=seed)

    print("total time: ", time.perf_counter() - start_time)
