from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *
from pytorch_lightning.metrics.functional.classification import accuracy


def l_inf_proj(delta: torch.Tensor, eps: float, cat_idx: Optional[int]=None):
    """Clip a perturbation into the L-infinity box [-eps, eps].

    Args:
        delta: perturbation tensor. When ``cat_idx`` is given it is indexed as
            ``delta[:, :cat_idx]``, so it must have at least two dimensions.
        eps: half-width of the box.
        cat_idx: if ``None``, every element is clipped and a NEW tensor is
            returned (``delta`` is not modified). Otherwise only columns
            ``[:cat_idx]`` are clipped, ``delta`` is modified IN PLACE, and the
            same tensor object is returned; columns ``cat_idx:`` are left as-is.

    Returns:
        The clipped perturbation.
    """
    if cat_idx is not None:
        # Only the leading block of columns is bounded; the remaining columns
        # keep whatever values they already had.
        leading = delta[:, :cat_idx]
        delta[:, :cat_idx] = torch.clamp(leading, min=-eps, max=eps)
        return delta
    return torch.clamp(delta, min=-eps, max=eps)

def fsgm(x, y, pred_fn, epsilon, cat_idx):
    """Craft a single-step fast-gradient-sign (FGSM) perturbation for ``x``.

    (The name ``fsgm`` is kept for backward compatibility with callers.)

    Starting from a uniform random point in [-epsilon, epsilon], take one
    signed-gradient step of size ``1.25 * epsilon``. The step goes in the
    direction that increases the binary cross-entropy between
    ``pred_fn(x + delta)`` and ``y``. The result is then projected with
    ``l_inf_proj``.

    Args:
        x: input batch; the perturbation has the same shape.
        y: targets for ``F.binary_cross_entropy``. Presumably values in
            [0, 1] with the same shape as ``pred_fn``'s output; TODO confirm.
        pred_fn: differentiable callable mapping inputs to probabilities.
        epsilon: L-infinity radius of the perturbation.
        cat_idx: forwarded to ``l_inf_proj``. With ``None`` every column is
            clipped; otherwise only columns ``[:cat_idx]`` are clipped.

    Returns:
        Detached perturbation tensor with the same shape as ``x``.
    """
    alpha = epsilon * 1.25
    # Random start inside the eps-box.
    delta = torch.zeros_like(x).uniform_(-epsilon, epsilon)
    delta.requires_grad_(True)
    loss = F.binary_cross_entropy(pred_fn(x + delta), y)
    # Differentiate w.r.t. delta only. Unlike loss.backward(), this does not
    # accumulate gradients into the model parameters behind pred_fn (or x).
    (grad,) = torch.autograd.grad(loss, delta)
    with torch.no_grad():
        # Gradient *ascent* step: move delta in the direction that raises the loss.
        delta = l_inf_proj(delta + alpha * grad.sign(), eps=epsilon, cat_idx=cat_idx)
    return delta.detach()


def pgd(x, y, pred_fn, epsilon, cat_idx, max_steps=20):
    """Craft a projected-gradient-descent (PGD) perturbation for ``x``.

    Starting from a uniform random point in [-epsilon, epsilon], take
    ``max_steps`` signed-gradient steps that increase the binary
    cross-entropy between ``pred_fn(x + delta)`` and ``y``. After each step
    the perturbation is projected with ``l_inf_proj``.

    Args:
        x: input batch; the perturbation has the same shape.
        y: targets for ``F.binary_cross_entropy``. Presumably values in
            [0, 1] with the same shape as ``pred_fn``'s output; TODO confirm.
        pred_fn: differentiable callable mapping inputs to probabilities.
        epsilon: L-infinity radius of the perturbation.
        cat_idx: forwarded to ``l_inf_proj``. With ``None`` every column is
            clipped; otherwise only columns ``[:cat_idx]`` are clipped.
        max_steps: number of gradient steps. Must be non-zero, because it
            divides the step size.

    Returns:
        Detached perturbation tensor with the same shape as ``x``.
    """
    # Per-step size chosen so all steps together span 2.5 * epsilon.
    alpha = epsilon * 2.5 / max_steps
    # Random start inside the eps-box.
    delta = torch.zeros_like(x).uniform_(-epsilon, epsilon)

    for _ in range(max_steps):
        delta.requires_grad_(True)
        loss = F.binary_cross_entropy(pred_fn(x + delta), y)
        # Gradient w.r.t. delta only: does not accumulate .grad into the
        # model parameters behind pred_fn on every step (loss.backward() did).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # Ascent step followed by projection; yields a fresh leaf tensor.
            delta = l_inf_proj(delta + alpha * grad.sign(), eps=epsilon, cat_idx=cat_idx)
    return delta.detach()


def adversarial_robustness(val_dataset, cat_idx, pred_fn, epsilon, seeds=None):
    """Measure how often ``pred_fn``'s rounded predictions survive a PGD attack.

    For each seed, the global RNGs are re-seeded and a PGD perturbation of
    radius ``epsilon`` is crafted against the full validation set. The
    rounded clean predictions are then compared with the rounded predictions
    on the perturbed inputs using ``accuracy``.

    Args:
        val_dataset: dataset whose full slice ``val_dataset[:]`` unpacks into
            ``(features, labels)``.
        cat_idx: forwarded to ``pgd`` / ``l_inf_proj``; only columns
            ``[:cat_idx]`` are clipped to the eps-box.
        pred_fn: differentiable callable mapping inputs to probabilities.
        epsilon: L-infinity attack radius.
        seeds: seeds to repeat the attack with. Defaults to
            ``[0, 7, 31, 1024, 31415926]``.

    Returns:
        A list with one ``accuracy`` value per seed. The clean predictions
        play the role of the target, so presumably each value is the fraction
        of predictions left unchanged by the attack.
    """
    if seeds is None:
        seeds = [0, 7, 31, 1024, 31415926]
    robustness = []
    val_X, label = val_dataset[:]
    for seed in seeds:
        seed_everything(seed)
        # `fsgm` has the same signature and is a cheaper single-step alternative.
        delta = pgd(val_X, label, pred_fn, epsilon=epsilon, cat_idx=cat_idx)
        # Scoring only needs forward passes; avoid building autograd graphs
        # over the whole validation set.
        with torch.no_grad():
            clean_pred = torch.round(pred_fn(val_X))
            adv_pred = torch.round(pred_fn(val_X + delta))
        robustness.append(accuracy(clean_pred, adv_pred))
    return robustness
    


# Checkpoints to evaluate, one entry per dataset. The evaluation loop reads
# "name", "cf_net_org_path" and "cf_net_upt_path"; "baseline_path" is only
# referenced by the commented-out baseline evaluation.
configs = [
    dict(
        name="adult",
        baseline_path="log/adult/baseline/version_0/checkpoints/epoch=89-step=17189.ckpt",
        cf_net_org_path="saved_weights/adult/c_net/epoch=126-step=24256.ckpt",
        cf_net_upt_path="log/adult/cf_2opt_update/version_1/checkpoints/epoch=92-step=17762.ckpt",
    ),
    # Uncomment to evaluate further datasets:
    # dict(
    #     name="home",
    #     baseline_path="log/home/baseline/version_0/checkpoints/epoch=97-step=6075.ckpt",
    #     cf_net_org_path="saved_weights/home/c_net/epoch=564-step=35029.ckpt",
    #     cf_net_upt_path="log/home/cf_2opt_update/version_1/checkpoints/epoch=95-step=5951.ckpt",
    # ),
    # dict(
    #     name="student",
    #     baseline_path="log/student/baseline/version_0/checkpoints/epoch=85-step=16425.ckpt",
    #     cf_net_org_path="saved_weights/student/c_net/epoch=115-step=22155.ckpt",
    #     cf_net_upt_path="log/student/cf_2opt_update/version_1/checkpoints/epoch=85-step=16425.ckpt",
    # ),
    # NOTE(review): the two legacy entries below use an older key schema
    # (baseline_iter / c_net_path / c_net_iter) and lack cf_net_org_path /
    # cf_net_upt_path, so the evaluation loop would raise KeyError on them.
    # {
    #     "name": "home",
    #     "baseline_path":
    #         "saved_weights/home/baseline/epoch=92-step=5765.ckpt",
    #         "baseline_iter": 93,
    #         "c_net_path": "saved_weights/home/c_net/epoch=564-step=35029.ckpt",
    #         "c_net_iter": 565
    # },
    # {
    #     "name": "student",
    #     "baseline_path":
    #         "saved_weights/student/baseline/epoch=98-step=18908.ckpt",
    #         "baseline_iter": 99,
    #         "c_net_path": "saved_weights/student/c_net/epoch=115-step=22155.ckpt",
    #         "c_net_iter": 116
    # },
]


# Per-dataset result tables, appended in the same order as `configs`.
robust_results_list = []

for config in configs:
    # Original vs. updated counterfactual network for this dataset.
    # baseline_model = load_model(config['baseline_path'], module=BaselineModel)
    cf_net_org = load_model(config['cf_net_org_path'], module=CounterfactualModel2Optimizers)
    cf_net_upt = load_model(config['cf_net_upt_path'], module=CounterfactualModel2Optimizers)

    # The attacks need a plain prediction function. Each network's call
    # returns a 2-tuple, and only its first element is used here.
    # pred_fn_baseline = lambda x: baseline_model(x)
    def pred_fn_cfnet_org(x):
        prediction, _unused = cf_net_org(x)
        return prediction

    def pred_fn_cfnet_upt(x):
        prediction, _unused = cf_net_upt(x)
        return prediction

    # Both networks are attacked on the original network's validation split.
    # cat_idx is the number of continuous columns; l_inf_proj clips columns
    # [:cat_idx] (presumably the continuous features come first; TODO confirm).
    val_dataset = cf_net_org.val_dataset
    cat_idx = len(cf_net_org.continous_cols)
    robust_result = {}
    # Attack radii 0.01, 0.02, ..., 0.20.
    eps_list = np.arange(1, 21) * 0.01
    for eps in eps_list:
        # robustness_baseline = adversarial_robustness(
        #     val_dataset, cat_idx, pred_fn=pred_fn_baseline, epsilon=eps)
        robust_result[f'eps={eps}'] = {
            # 'baseline': robustness_baseline,
            'cf_net_org': adversarial_robustness(
                val_dataset, cat_idx, pred_fn=pred_fn_cfnet_org, epsilon=eps),
            'cf_net_upt': adversarial_robustness(
                val_dataset, cat_idx, pred_fn=pred_fn_cfnet_upt, epsilon=eps),
        }

    # One column per eps, one row per model; each cell is the per-seed list.
    robust_results_list.append(
        pd.DataFrame.from_dict(robust_result, orient='columns'))

for config, result_df in zip(configs, robust_results_list):
    print(config['name'], ":")
    print(tabulate(result_df, headers='keys', tablefmt='psql'))

