from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *


if __name__ == "__main__":
    # Experiment driver: for every dataset entry in `configs`, train a
    # counterfactual model, generate counterfactuals with `model_cf_gen`, and
    # score them with `evaluate`. Datasets and model variants are toggled by
    # (un)commenting the entries/sections below.

    # Each entry pairs a dataset name (used in log paths and for evaluation)
    # with the hyper-parameter dict handed to the module constructor.
    configs = [
        # {
        #     "data_name": "adult",
        #     "config": load_json("counterfactual/configs/adult.json")
        # },
        # {
        #     "data_name": "home",
        #     "config": load_json("counterfactual/configs/home.json")
        # },
        # {
        #     "data_name": "student",
        #     "config": load_json("counterfactual/configs/student.json")
        # },
        {
            "data_name": "credit_card",
            "config": load_json("counterfactual/configs/extra/credit_card.json")
        },
    ]

    # Trainer settings shared by every run. Use a context manager so the file
    # handle is closed once the JSON has been parsed.
    with open("counterfactual/configs/trainer.json") as fp:
        t_config = json.load(fp)
    # lrs = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
    t_config['max_epochs'] = 100
    # NOTE(review): presumably forwarded to the Lightning trainer, where
    # gpus=0 means CPU-only training — confirm against `train`.
    t_config['gpus'] = 0

    for config in configs:
        # config["config"]['lambda_2'] = 0.1
        # train(
        #     module=CounterfactualModel2Optimizers(config["config"]),
        #     t_configs=t_config,
        #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/cf_2opt_update")
        # )

        # training_res = train(
        #     module=CounterfactualModel2OptsNoPass(config["config"]),
        #     t_configs=t_config,
        #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/cf_2opt_nopass")
        # )
        # model = training_res['module']
        # result = model_cf_gen(model)
        # result["cat_idx"] = len(model.continous_cols)
        # evaluate(result, dataset_name=config['data_name'], cf_name="CounterfactualNet-NoPass")

        # training_res = train(
        #     module=CounterfactualModelSeparate(config["config"]),
        #     t_configs=t_config,
        #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/cf_2opt_separate")
        # )
        # model = training_res['module']
        # result = model_cf_gen(model)
        # result["cat_idx"] = len(model.continous_cols)
        # evaluate(result, dataset_name=config['data_name'], cf_name="CounterfactualNet-Separate")

        # training_res = train(
        #     module=CounterfactualModelPosthoc(config["config"]),
        #     t_configs=t_config,
        #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/cf_2opt_posthoc")
        # )
        # model = training_res['module']
        # result = model_cf_gen(model)
        # result["cat_idx"] = len(model.continous_cols)
        # evaluate(result, dataset_name=config['data_name'], cf_name="CounterfactualNet-Posthoc")

        # The override must go into the model's hyper-parameter dict
        # (config["config"]), because that dict is what the module receives.
        # Setting it on the outer entry dict would not affect training.
        config["config"]['loss_2'] = 'l1'
        training_res = train(
            module=CounterfactualModel2Optimizers(config["config"]),
            t_configs=t_config,
            logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/cf_2opt")
        )
        model = training_res['module']
        result = model_cf_gen(model)
        # NOTE(review): presumably the column index where categorical features
        # begin, since they follow the continuous ones — confirm in `evaluate`.
        result["cat_idx"] = len(model.continous_cols)
        evaluate(result, dataset_name=config['data_name'], cf_name="CounterfactualNet-loss2=l1")

        # train(
        #     module=BaselineModel(config["config"]),
        #     t_configs=t_config,
        #     logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['data_name']}/baseline")
        # )
