from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *


def evaluate_robustness(configs, thresholds: list):
    """Score saved CounterfactualNet results for every config and threshold.

    For each entry of ``configs`` the c_net checkpoint is loaded once. Then,
    for each value in ``thresholds``, the model's ``threshold`` attribute is
    set, the saved results in ``results/<name>/CounterfactualNet_result.pt``
    are loaded and scored, and the scores are written to
    ``results/<name>/robustness/threshold=<threshold>.csv``. The output
    directory is created if it does not exist yet.

    Args:
        configs: iterable of dicts providing at least ``name``,
            ``c_net_path`` and ``c_net_iter``.
        thresholds: values assigned, one after another, to the loaded
            model's ``threshold`` attribute.

    Returns:
        A nested dict ``{name: {"<threshold>": {"c_net_cf": metrics}}}``,
        where ``metrics`` holds ``diffs``, ``total_nums`` and ``robustness``.
    """
    # Function-scope import so this block does not rely on the star imports
    # at the top of the file happening to re-export ``os``.
    import os

    def calculate_robustness(cf_module: "DataModule", cf_result: dict):
        """Return robustness metrics for one saved counterfactual result.

        ``cf_result`` must provide the keys ``x``, ``cf`` and ``cf_y``; they
        are forwarded to ``cf_module.check_cont_robustness``.
        NOTE(review): callers pass the object returned by ``load_model`` —
        confirm it really is a ``DataModule``.
        """
        diffs, total_nums = cf_module.check_cont_robustness(
            cf_result['x'], cf_result['cf'], cf_result['cf_y'])
        # ``diffs`` must expose ``.item()`` (presumably a scalar tensor);
        # convert it once and reuse the Python number.
        diff_count = diffs.item()
        return {
            "diffs": diff_count,
            "total_nums": total_nums,
            "robustness": 1.0 - diff_count / total_nums
        }

    r = {}
    for config in configs:
        # baseline = load_model(config['baseline_path'], config['baseline_iter'])
        c_net = load_model(config['c_net_path'], config['c_net_iter'], module=CounterfactualModel2Optimizers)
        r[config['name']] = {}
        for threshold in thresholds:
            # baseline.threshold = threshold
            c_net.threshold = threshold

            # vanilla_cf = torch.load(f"results/{config['name']}/VanillaCF_result.pt")
            # vanilla_robustness = calculate_robustness(baseline, vanilla_cf)

            # diverse_cf = torch.load(f"results/{config['name']}/DiverseCF_result.pt")
            # diverse_robustness = calculate_robustness(baseline, diverse_cf)

            # proto_cf = torch.load(f"results/{config['name']}/ProtoCF_result.pt")
            # proto_robustness = calculate_robustness(baseline, proto_cf)

            # vae_cf = torch.load(f"results/{config['name']}/VAE-CF_result.pt")
            # vae_robustness = calculate_robustness(baseline, vae_cf)

            # NOTE(review): the same file is reloaded for every threshold.
            # It could be hoisted out of this loop if check_cont_robustness
            # does not modify its inputs in place — confirm before changing.
            c_net_cf = torch.load(f"results/{config['name']}/CounterfactualNet_result.pt")
            c_net_robustness = calculate_robustness(c_net, c_net_cf)

            r[config['name']][f"{threshold}"] = {
                # "vanilla_cf": vanilla_robustness,
                # "diverse_cf": diverse_robustness,
                # "proto_cf": proto_robustness,
                # "vae_cf": vae_robustness,
                "c_net_cf": c_net_robustness
            }

            csv_path = f"results/{config['name']}/robustness/threshold={threshold}.csv"
            # to_csv does not create missing parent directories, so make sure
            # the per-dataset ``robustness`` folder exists before writing.
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            pd.DataFrame.from_dict(r[config['name']][f"{threshold}"], orient="index",).to_csv(csv_path)
    return r


if __name__ == "__main__" and not in_jupyter():
    # Script entry point (skipped when in_jupyter() is truthy).
    # One entry per dataset: the baseline and c_net checkpoint paths together
    # with the iteration value handed to the loader alongside each path.
    configs = [
        dict(
            name="adult",
            baseline_path="saved_weights/adult/baseline/epoch=55-step=10695.ckpt",
            baseline_iter=56,
            c_net_path="saved_weights/adult/c_net/epoch=126-step=24256.ckpt",
            c_net_iter=127,
        ),
        dict(
            name="home",
            baseline_path="saved_weights/home/baseline/epoch=92-step=5765.ckpt",
            baseline_iter=93,
            c_net_path="saved_weights/home/c_net/epoch=564-step=35029.ckpt",
            c_net_iter=565,
        ),
        dict(
            name="student",
            baseline_path="saved_weights/student/baseline/epoch=98-step=18908.ckpt",
            baseline_iter=99,
            c_net_path="saved_weights/student/c_net/epoch=115-step=22155.ckpt",
            c_net_iter=116,
        ),
        # Extra datasets, currently disabled:
        # dict(
        #     name="extra/student_performance",
        #     baseline_path="saved_weights/extra/student_performance/baseline/epoch=287-step=1151.ckpt",
        #     baseline_iter=288,
        #     c_net_path="saved_weights/extra/student_performance/c_net/epoch=462-step=1851.ckpt",
        #     c_net_iter=463,
        # ),
        # dict(
        #     name="extra/titanic",
        #     baseline_path="saved_weights/extra/titanic/baseline/epoch=63-step=383.ckpt",
        #     baseline_iter=64,
        #     c_net_path="saved_weights/extra/titanic/c_net/epoch=61-step=371.ckpt",
        #     c_net_iter=62,
        # ),
        # dict(
        #     name="extra/breast",
        #     baseline_path="saved_weights/extra/breast/baseline/epoch=383-step=1535.ckpt",
        #     baseline_iter=384,
        #     c_net_path="saved_weights/extra/breast/c_net/epoch=371-step=1487.ckpt",
        #     c_net_iter=372,
        # ),
    ]

    # Evaluate every configured dataset at each of the four thresholds.
    r = evaluate_robustness(
        configs,
        thresholds=[1.0, 2.0, 5.0, 10.0],
    )