from typing import NamedTuple

from counterfactual.import_essentials import *
from counterfactual.train import train
from counterfactual.net import BaselineModel
from counterfactual.evaluate import load_model, model_cf_gen, evaluate
from counterfactual.baseline import CCHVAE
from counterfactual.utils import load_json


class _Experiment(NamedTuple):
    """One C-CHVAE train/generate/evaluate run on top of a pretrained baseline."""

    # Dataset name: key into the configs dict built in ``main`` and the
    # ``dataset_name`` passed to ``evaluate``.
    dataset_name: str
    # Path to the pretrained baseline checkpoint passed to ``load_model``.
    checkpoint: str
    # Second positional argument to ``load_model``; in every entry it equals the
    # checkpoint's ``epoch=`` value + 1 (presumably the epoch count -- confirm).
    ckpt_epochs: int
    # TestTubeLogger experiment name under ``log/``.
    log_name: str
    # Whether to call ``evaluate`` on the generated counterfactuals.
    run_eval: bool = True


# Runs execute in this order.
_EXPERIMENTS = (
    _Experiment(
        "adult",
        "saved_weights/adult/baseline/epoch=55-step=10695.ckpt",
        56,
        "adult/vae",
    ),
    _Experiment(
        "home",
        "saved_weights/home/baseline/epoch=92-step=5765.ckpt",
        93,
        "home/vae",
    ),
    _Experiment(
        "student",
        "saved_weights/student/baseline/epoch=98-step=18908.ckpt",
        99,
        "student/vae",
    ),
    _Experiment(
        "breast",
        "saved_weights/extra/breast/baseline/epoch=383-step=1535.ckpt",
        384,
        # Was "extra/heart/vae" -- a copy-paste leftover that filed breast-cancer
        # logs under the wrong experiment name.
        "extra/breast/vae",
    ),
    _Experiment(
        "student_performance",
        "saved_weights/extra/student_performance/baseline/epoch=287-step=1151.ckpt",
        288,
        "extra/student/vae",
        # Evaluation was disabled (commented out) in the original script;
        # NOTE(review): reason not visible here -- confirm before re-enabling.
        run_eval=False,
    ),
    _Experiment(
        "titanic",
        "saved_weights/extra/titanic/baseline/epoch=63-step=383.ckpt",
        64,
        "extra/titanic/vae",
    ),
)


def _load_config(path):
    """Read a JSON config file and close the handle afterwards.

    Args:
        path: path to the JSON file.

    Returns:
        The decoded JSON object.
    """
    with open(path) as fh:
        return json.load(fh)


def _run_cchvae(config, t_config, exp, batch_size=32):
    """Train C-CHVAE against a pretrained baseline, generate CFs, and evaluate.

    Args:
        config: dataset config dict; its ``batch_size`` is overwritten in place.
        t_config: trainer config dict, passed unchanged to ``train``.
        exp: the ``_Experiment`` describing checkpoint, log name and dataset.
        batch_size: batch size written into ``config`` before building CCHVAE.

    Returns:
        The object returned by ``model_cf_gen``, with ``"cat_idx"`` set.
    """
    model = load_model(exp.checkpoint, exp.ckpt_epochs)
    config['batch_size'] = batch_size
    cf_result = train(
        CCHVAE(config, model=model),
        t_config,
        logger=pl_loggers.TestTubeLogger(Path('log/'), name=exp.log_name),
    )

    result = model_cf_gen(cf_result['module'], is_parallel=False)
    # Number of continuous columns; presumably categorical features start at
    # this index in the encoded input -- confirm against ``evaluate``.
    result["cat_idx"] = len(model.continous_cols)
    if exp.run_eval:
        evaluate(result, dataset_name=exp.dataset_name, cf_name="C-CHVAE")
    return result


def main():
    """Train and evaluate C-CHVAE on every dataset listed in ``_EXPERIMENTS``."""
    # Load every config up front so a missing file fails before any training.
    configs = {
        "adult": _load_config("counterfactual/configs/adult.json"),
        "student": _load_config("counterfactual/configs/student.json"),
        "home": _load_config("counterfactual/configs/home.json"),
        "breast": load_json("counterfactual/configs/extra/breast_cancer.json"),
        "student_performance": load_json(
            "counterfactual/configs/extra/student_performance.json"
        ),
        "titanic": load_json("counterfactual/configs/extra/titanic.json"),
    }

    t_config = _load_config("counterfactual/configs/trainer.json")
    t_config['max_epochs'] = 5

    for exp in _EXPERIMENTS:
        _run_cchvae(configs[exp.dataset_name], t_config, exp)


if __name__ == "__main__":
    main()
