import numpy as np
from tqdm import tqdm
import torch
from torch.nn import functional as F
from utils.infonce import get_infonce
from utils.mia2 import get_membership_attack_prob2
from utils.mia import get_membership_attack_prob

def evaluate_summary(model, retrain_model, result_model, statistics, loaders, args):
    """Run the full evaluation report for an unlearning experiment.

    Evaluates the unlearned model, compares it against the original and
    retrained ("gold") models, runs the mutual-information evaluation that
    matches ``args.test_mode`` and finally prints the cost statistics.

    Args:
        model: Original model; used as the "orig" reference in the
            activation-distance / divergence comparison.
        retrain_model: Retrained model; used as the "gold" reference.
        result_model: Unlearned model under evaluation.
        statistics: Object exposing ``total_flops`` and ``elapsed_time``,
            or ``None`` to skip the cost line.
        loaders: Mapping of loaders/datasets forwarded to the evaluators.
        args: Run configuration; ``args.test_mode`` selects the MI evaluation.

    Raises:
        NotImplementedError: If ``args.test_mode`` is ``'sub_class'``.
    """
    # The GPU cache is released between stages. Evaluation of the original
    # and retrained models is currently disabled (calls kept for reference):
    torch.cuda.empty_cache()
    # evaluate_model(model, loaders, "Original Model", args)
    torch.cuda.empty_cache()
    # evaluate_model(retrain_model, loaders, "Retrain Model", args)
    torch.cuda.empty_cache()

    evaluate_model(result_model, loaders, "Unlearned Model", args)
    torch.cuda.empty_cache()

    calculate_activation_dist_with_divergence(model, retrain_model, result_model, loaders, args)
    torch.cuda.empty_cache()

    mode = args.test_mode
    if mode == 'sub_class':
        raise NotImplementedError
    if mode == 'class':
        evaluate_mi_class(args)
    elif mode == 'sample':
        evaluate_mi_sample(args)

    if statistics is None:
        return
    print(f"Total FLOPS: {flops_easy_view(statistics.total_flops)}, Elapsed time: {statistics.elapsed_time:.2f} sec")

@torch.no_grad()
def evaluate_model(model, loaders, model_name, args):
    """Log accuracy and membership-inference metrics for one model.

    Accuracy is measured on the forget/remain splits of the train and test
    data. Afterwards ``loaders['train_remain_test_loader']`` is REPLACED in
    place by a loader over a random subset of ``loaders['train_remain_test_set']``
    (at most as many samples as the test sets hold) before the two
    membership-inference attacks are run. Callers that reuse ``loaders``
    later see the subsampled loader.

    Args:
        model: Model to evaluate; switched to eval mode.
        loaders: Mapping providing the loaders and datasets read below.
        model_name: Label used in the printed and logged lines.
        args: Run configuration; ``args.device`` is used by the accuracy
            helper and ``args.logger`` receives the results.
    """
    print(f"Evaluating {model_name}...")
    model.eval()

    # Overall train/test accuracy is not computed here; these placeholders
    # are logged as 0.00.
    acc_train = 0
    acc_test = 0
    acc_train_forget = eval(model, loaders['train_forget_test_loader'], args)
    acc_train_remain = eval(model, loaders['train_remain_test_loader'], args)
    acc_test_forget = eval(model, loaders['test_forget_loader'], args)
    acc_test_remain = eval(model, loaders['test_remain_loader'], args)

    remain_size = len(loaders['test_remain_set'])
    try:
        forget_size = len(loaders['test_forget_set'])
    except (KeyError, TypeError):
        # The forget test set is missing or None: fall back to the remain
        # set alone. Anything else (e.g. Ctrl-C) is no longer swallowed.
        forget_size = None

    if forget_size is None:
        length = remain_size
        print(remain_size)
    else:
        length = forget_size + remain_size
        print(forget_size, remain_size)

    print("LENGTH", length)
    # Subsample the remaining-train data down to `length` samples for the
    # membership-inference attacks.
    remain_set = loaders['train_remain_test_set']
    train_remain_test_set = torch.utils.data.Subset(remain_set, torch.randperm(len(remain_set))[:length])
    loaders['train_remain_test_loader'] = torch.utils.data.DataLoader(train_remain_test_set, batch_size=100, shuffle=False, num_workers=4)

    mia = get_membership_attack_prob(
        loaders['train_remain_test_loader'], loaders['train_forget_test_loader'], loaders['test_loader'], model,
    )
    mia2 = get_membership_attack_prob2(loaders, model)

    args.logger.info(f"{model_name} Train: {acc_train:.2f}, Train_forget: {acc_train_forget:.2f}, Train_remain: {acc_train_remain:.2f}")
    args.logger.info(f"{model_name} Test: {acc_test:.2f}, Test_forget: {acc_test_forget:.2f}, Test_remain: {acc_test_remain:.2f}")
    args.logger.info(f"{model_name} MIA: (confidence) {mia['confidence'] * 100} (entropy) {mia['entropy'] * 100}")
    # The second attack's scores are reported as (100 - score).
    args.logger.info(f"{model_name} MIA2: (confidence) {100 - mia2['confidence'] * 100} (entropy) {100 - mia2['entropy'] * 100}")


@torch.no_grad()
def calculate_activation_dist_with_divergence(orig_model, gold_model, unlearn_model, loaders, args=None):
    """Log activation distance and JS divergence of two models w.r.t. the gold model.

    Activation distance follows "Can Bad Teaching Induce Forgetting?
    Unlearning in Deep Networks Using an Incompetent Teacher" (AAAI-23);
    code adapted from
    https://github.com/vikram2000b/bad-teaching-unlearning/blob/main/metrics.py

    For each of the four loaders (train-forget, test-forget, train-remain,
    test-remain) this computes, between the gold model and each of the
    original / unlearned models:
      * the per-sample L2 distance between softmax outputs, and
      * the divergence returned by ``JSDiv`` over the softmax outputs.
    The values logged as "Std" are mean squared deviations (no square root
    is taken). A loader that is ``None`` or yields no batches is reported
    as ``(0, 0)``.

    Args:
        orig_model: Original model; called as ``model(x, get_embeddings=True)``.
        gold_model: Reference (retrained) model, same call convention.
        unlearn_model: Unlearned model, same call convention.
        loaders: Mapping with 'train_forget_test_loader',
            'train_remain_test_loader', 'test_forget_loader' and
            'test_remain_loader'.
        args: Run configuration. Despite the ``None`` default it is required:
            ``args.device`` and ``args.logger`` are used.
    """
    for net in (orig_model, gold_model, unlearn_model):
        net.eval()

    # if args.test_mode == 'sub_class':
    #     rt_train_loader = loaders['train_adjacent_test_loader']

    # Order matters: the log lines below index the results as
    # [train_forget, test_forget, train_remain, test_remain].
    loader_pack = [
        loaders['train_forget_test_loader'],
        loaders['test_forget_loader'],
        loaders['train_remain_test_loader'],
        loaders['test_remain_loader'],
    ]

    dist_orig_result, dist_unlearn_result = [], []
    div_orig_result, div_unlearn_result = [], []
    all_results = (dist_unlearn_result, dist_orig_result, div_unlearn_result, div_orig_result)

    for loader in loader_pack:
        distances_orig, distances_unlearn = [], []
        orig_prob, gold_prob, unlearn_prob = [], [], []

        if loader is not None:
            for inputs, _targets in loader:
                inputs = inputs.to(args.device)
                orig_outputs, _ = orig_model(inputs, get_embeddings=True)
                gold_outputs, _ = gold_model(inputs, get_embeddings=True)
                unlearn_outputs, _ = unlearn_model(inputs, get_embeddings=True)

                gold_soft = F.softmax(gold_outputs, dim=1)

                diff_orig = torch.sqrt(torch.sum(torch.square(gold_soft - F.softmax(orig_outputs, dim=1)), dim=1))
                distances_orig.append(diff_orig.detach().cpu())

                diff_unlearn = torch.sqrt(torch.sum(torch.square(gold_soft - F.softmax(unlearn_outputs, dim=1)), dim=1))
                distances_unlearn.append(diff_unlearn.detach().cpu())

                # Adding a constant to the logits does not change softmax
                # mathematically; the expression is kept as-is so results stay
                # numerically identical to earlier runs. Moved to CPU per batch
                # so the probabilities do not pile up on the device.
                orig_prob.append(F.softmax(orig_outputs + 1e-7, dim=1).cpu())
                gold_prob.append(F.softmax(gold_outputs + 1e-7, dim=1).cpu())
                unlearn_prob.append(F.softmax(unlearn_outputs + 1e-7, dim=1).cpu())

        if not distances_orig:
            # Missing loader, or a loader that produced no batches.
            for result in all_results:
                result.append((0, 0))
            continue

        orig_prob = torch.cat(orig_prob, dim=0)
        gold_prob = torch.cat(gold_prob, dim=0)
        unlearn_prob = torch.cat(unlearn_prob, dim=0)
        js_div_un, js_std_un = JSDiv(gold_prob, unlearn_prob)
        js_div_orig, js_std_orig = JSDiv(gold_prob, orig_prob)

        distances_unlearn = torch.cat(distances_unlearn, dim=0)
        distances_unlearn_std = ((distances_unlearn - distances_unlearn.mean()) ** 2).mean()
        distances_orig = torch.cat(distances_orig, dim=0)
        # Fixed: the spread of the original model's distances was previously
        # computed from the unlearned model's distances.
        distances_orig_std = ((distances_orig - distances_orig.mean()) ** 2).mean()

        dist_unlearn_result.append((distances_unlearn.mean(), distances_unlearn_std))
        dist_orig_result.append((distances_orig.mean(), distances_orig_std))
        div_unlearn_result.append((js_div_un, js_std_un))
        div_orig_result.append((js_div_orig, js_std_orig))

    reports = (
        ("orig   ", dist_orig_result, div_orig_result),
        ('unlearn', dist_unlearn_result, div_unlearn_result),
    )
    for idx, dist, div in reports:
        args.logger.info(f"Activate Distance :[Gold, {idx}] | Train_forget: {dist[0][0]:.3f} Std: {dist[0][1]:.3f} | Test_forget: {dist[1][0]:.3f} Std: {dist[1][1]:.3f} | Train_Remain: {dist[2][0]:.3f} Std: {dist[2][1]:.3f} | Test_Remain: {dist[3][0]:.3f} Std: {dist[3][1]:.3f}")
        args.logger.info(f"JS divergence     :[Gold, {idx}] | Train_forget: {div[0][0]:.3f} Std: {div[0][1]:.3f} | Test_forget: {div[1][0]:.3f} Std: {div[1][1]:.3f} | Train_Remain: {div[2][0]:.3f} Std: {div[2][1]:.3f} | Test_Remain: {div[3][0]:.3f} Std: {div[3][1]:.3f}")

def JSDiv(p, q):
    """Return a Jensen-Shannon-style divergence between ``p`` and ``q`` and its spread.

    Both inputs are 2-D tensors of probabilities with one distribution per
    row. ``m`` is their element-wise average.

    Returns:
        A pair ``(divergence, spread)``: the batch-mean divergence and the
        mean squared deviation of the per-row divergences from that mean
        (a variance; no square root is taken).

    NOTE(review): ``F.kl_div(input, target)`` expects log-probabilities as
    ``input`` and computes KL(target || input), so each term below is
    KL(m || p) rather than the KL(p || m) of the textbook Jensen-Shannon
    divergence. Left unchanged so previously reported numbers stay
    comparable; confirm this is intended.
    """
    mixture = (p + q) / 2
    # A small epsilon keeps log() finite for zero probabilities.
    log_p = torch.log(p + 1e-9)
    log_q = torch.log(q + 1e-9)

    mean_div = 0.5 * F.kl_div(log_p, mixture, reduction='batchmean') + 0.5 * F.kl_div(log_q, mixture, reduction='batchmean')

    elementwise = 0.5 * F.kl_div(log_p, mixture, reduction='none') + 0.5 * F.kl_div(log_q, mixture, reduction='none')
    per_row = elementwise.sum(axis=1)

    spread = ((per_row - mean_div) ** 2).mean()
    return mean_div, spread

@torch.no_grad()
def eval(model, loader, args):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Args:
        model: Callable mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(inputs, targets)`` batches, or ``None``.
        args: Run configuration; batches are moved to ``args.device``.

    Returns:
        ``100 * correct / total``; ``0`` when ``loader`` is ``None`` and
        ``0.0`` when it yields no samples.
    """
    if loader is None:
        return 0

    total = 0
    correct = 0
    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        outputs = model(inputs)
        predicted = torch.argmax(outputs, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    if total == 0:
        # Empty loader: report 0 like the missing-loader case instead of
        # dividing by zero.
        return 0.0
    return 100. * correct / total

@torch.no_grad()
def eval_class(model, loader, args, class_idx):
    """Return the top-1 accuracy (in percent) restricted to one target class.

    Only samples whose target equals ``class_idx`` are fed to the model.

    Args:
        model: Callable mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(inputs, targets)`` batches, or ``None``.
        args: Run configuration; batches are moved to ``args.device``.
        class_idx: Target label to evaluate.

    Returns:
        ``100 * correct / total`` over the selected samples; ``0`` when
        ``loader`` is ``None`` and ``0.0`` when no sample of the class is
        present.
    """
    if loader is None:
        return 0

    total = 0
    correct = 0
    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        # only evaluate the specific class
        mask = targets == class_idx
        inputs, targets = inputs[mask], targets[mask]
        if len(targets) == 0:
            continue

        outputs = model(inputs)
        predicted = torch.argmax(outputs, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    print("total: ", total, "correct: ", correct)
    if total == 0:
        # The class never appeared in the loader: avoid dividing by zero.
        return 0.0
    return 100. * correct / total

def _mi_class_baselines(model_name, data_name, class_idx_unlearn):
    """Return ``(layers, original, gold)`` reference entries, or ``None`` if unsupported.

    ``original`` and ``gold`` are ``(per_layer_values, auc, mean)`` tuples of
    hard-coded reference numbers for the original and gold models.
    """
    if model_name == "ResNet18":
        layers = [4, 5]
        # ResNet18 references depend on ``class_idx_unlearn``. The source
        # annotates each case with the binary entropy (in bits) of the
        # corresponding class fraction, quoted next to each entry.
        table = {
            # fraction 1/10 -> 0.4689955935892812
            ("cifar10", 1): (([0.4639, 0.4576], 0.4607, 0.4607), ([0.4608, 0.3149], 0.3879, 0.3879)),
            # fraction 2/10 -> 0.7219280948873623
            ("cifar10", 2): (([0.7172, 0.7066], 0.7119, 0.7119), ([0.6512, 0.4520], 0.5516, 0.5516)),
            # fraction 1/100 -> 0.08079313589591118
            ("cifar100", 1): (([0.0741, 0.0666], 0.0704, 0.0704), ([0.0718, 0.0408], 0.0563, 0.0563)),
            # fraction 5/100 -> 0.2863969571159562
            ("cifar100", 5): (([0.2777, 0.2420], 0.2599, 0.2599), ([0.2689, 0.1146], 0.1918, 0.1918)),
            # fraction 20/100 -> 0.7219280948873623
            ("cifar100", 20): (([0.6996, 0.5779], 0.6388, 0.6388), ([0.6831, 0.2401], 0.4616, 0.4616)),
        }
        entry = table.get((data_name, class_idx_unlearn))
    elif model_name == "ResNet50":
        layers = [4, 5, 6]
        table = {
            "cifar10": (([0.4639, 0.4629, 0.4512], 0.4602, 0.4533), ([0.4607, 0.4312, 0.3178], 0.4102, 0.4032)),
            "cifar100": (([0.0741, 0.0742, 0.0703], 0.0732, 0.0729), ([0.0732, 0.0694, 0.0488], 0.0652, 0.0638)),
            # NOTE(review): only one reference value is listed here although
            # three layers are evaluated -- confirm which layer(s) it covers.
            "imagenet": (([0.1679], 0.1679, 0.1679), ([0.1609], 0.1609, 0.1609)),
        }
        entry = table.get(data_name)
    elif model_name == "ViT":
        layers = [10, 11, 12]
        table = {
            "cifar10": (([0.4630, 0.4612, 0.4435], 0.4572, 0.4559), ([0.4613, 0.4595, 0.4155], 0.4490, 0.4454)),
            "cifar100": (([0.0696, 0.0684, 0.0623], 0.0672, 0.0668), ([0.0683, 0.0654, 0.0581], 0.0643, 0.0639)),
            "imagenet": (([0.2719, 0.2652, 0.2088], 0.2528, 0.2486), ([0.2714, 0.2597, 0.2020], 0.2482, 0.2444)),
        }
        entry = table.get(data_name)
    else:
        return None

    if entry is None:
        return None
    original, gold = entry
    return layers, original, gold


def evaluate_mi_class(args):
    """Print the mutual-information evaluation for class-wise unlearning.

    Runs ``get_infonce`` once per evaluated layer on the checkpoint derived
    from ``args`` and prints the results next to hard-coded reference values
    for the original and gold models, together with the ratio
    ``(unlearn - gold) / (original - gold)`` for both the mean and the AUC.

    Side effects: sets ``args.train_mode`` to ``'mi'`` and overwrites
    ``args.layer`` and ``args.ckpt`` for every evaluated layer.

    Raises:
        NotImplementedError: If no reference values exist for the
            combination of ``args.model_name``, ``args.data_name`` and
            ``args.class_idx_unlearn``. Raised before any evaluation runs.
    """
    baselines = _mi_class_baselines(args.model_name, args.data_name, args.class_idx_unlearn)
    if baselines is None:
        raise NotImplementedError(
            f"No MI reference values for model={args.model_name!r}, "
            f"data={args.data_name!r}, class_idx_unlearn={args.class_idx_unlearn!r}"
        )
    layers, (orig_values, ori_auc, ori_mean), (gold_values, gold_auc, gold_mean) = baselines

    args.train_mode = 'mi'
    answers = []
    for layer in layers:
        args.layer = layer
        # lr to 0.1 -> 1e-1, 0.01 -> 1e-2, 0.001 -> 1e-3
        args.ckpt = f"final_checkpoints/{args.data_name}/{args.test_mode}/{args.model_name}_{args.method}_{args.data_name}_{args.test_mode}_{args.class_idx}_{args.class_idx_unlearn}_{args.remain_epochs}_{args.forget_epochs}_{args.remain_batch_size}_{args.forget_batch_size}_{args.optimizer}_{to_scientific_notation(args.lr)}{args.rnd_seed}.pth"
        answers.append(get_infonce(args, pretrained_model=None))

    print(f"Original: {orig_values}, AUC: {ori_auc:.4f}, Mean: {ori_mean:.4f}")
    print(f"Gold: {gold_values}, AUC: {gold_auc:.4f}, Mean: {gold_mean:.4f}")

    if len(layers) > 1:
        unlearn_mean = sum(answers) / len(answers)
        # Trapezoidal average over the layers: the two end points get half weight.
        unlearn_AUC = (2 * sum(answers) - answers[-1] - answers[0]) / (2 * len(answers) - 2)
        print(f"Mean: {unlearn_mean:.4f}, AUC: {unlearn_AUC:.4f}")
        print(f"Mean Ratio: {(unlearn_mean - gold_mean) / (ori_mean - gold_mean):.4f}, AUC Ratio: {(unlearn_AUC - gold_auc) / (ori_auc - gold_auc):.4f}")
    elif len(layers) == 1:
        print(f"Mean: {answers[0]:.4f}, AUC: {answers[0]:.4f}")
        print(f"Mean Ratio: {(answers[0] - gold_mean) / (ori_mean - gold_mean):.4f}, AUC Ratio: {(answers[0] - gold_auc) / (ori_auc - gold_auc):.4f}")

def evaluate_mi_sub_class(args):
    """Print the mutual-information evaluation for sub-class unlearning.

    Picks hard-coded reference values for the original and gold models by
    looking for a known keyword in ``args.sub_class_name``, runs
    ``get_infonce`` at layer 5 on the checkpoint derived from ``args`` and
    prints the result with the ratio ``(answer - gold) / (ori - gold)``.

    Side effects: sets ``args.test_mode`` to ``'sub_class'``,
    ``args.train_mode`` to ``'mi'``, replaces ``args.sub_class_name`` by its
    first element, and sets ``args.layer`` and ``args.ckpt``.

    Raises:
        NotImplementedError: If ``args.sub_class_name`` contains none of the
            known keywords. Raised before any evaluation runs.
    """
    args.test_mode = 'sub_class'

    # (keyword, original, gold) -- checked in this order, first match wins.
    references = (
        ('baby', 0.05647, 0.1634),
        ('lamp', 0.2589, 0.5603),
        ('mushroom', 0.3980, 0.6102),
        ('rocket', 0.3825, 0.6191),
        ('sea', 0.1636, 0.2606),
    )
    for keyword, ori, gold in references:
        if keyword in args.sub_class_name:
            break
    else:
        raise NotImplementedError(
            f"No MI reference values for sub_class_name={args.sub_class_name!r}"
        )

    args.train_mode = 'mi'
    args.sub_class_name = args.sub_class_name[0]

    args.layer = 5
    args.ckpt = f"final_checkpoints/{args.data_name}/{args.test_mode}/{args.model_name}_{args.method}_{args.sub_class_name}_{args.remain_epochs}_{args.forget_epochs}_{args.remain_batch_size}_{args.forget_batch_size}_{args.optimizer}_{to_scientific_notation(args.lr)}{args.rnd_seed}.pth"
    answer = get_infonce(args, pretrained_model=None)

    print(f"Original: {ori:.4f}, Gold: {gold:.4f}")
    print(f"Mean: {answer:.4f}, AUC: {answer:.4f}")
    print(f"Mean Ratio: {(answer - gold) / (ori - gold):.4f}", f"AUC Ratio: {(answer - gold) / (ori - gold):.4f}")
    
    
def evaluate_mi_sample(args):
    """Print the mutual-information evaluation for sample-wise unlearning.

    Runs ``get_infonce`` once per evaluated layer on the checkpoint derived
    from ``args`` and prints the results next to hard-coded reference values
    for the original and gold models, together with the ratio
    ``(unlearn - gold) / (original - gold)`` for both the mean and the AUC.

    Side effects: sets ``args.test_mode`` to ``'sample'``,
    ``args.train_mode`` to ``'mi'`` and overwrites ``args.layer`` and
    ``args.ckpt`` for every evaluated layer.

    Raises:
        NotImplementedError: If ``args.data_name`` has no reference values.
            Raised before any evaluation runs.
    """
    args.test_mode = 'sample'

    if args.data_name == "cifar10":
        layers = [4, 5]
        orig_values, ori_auc, ori_mean = [2.7144, 1.8794], 2.2969, 2.2969
        gold_values, gold_auc, gold_mean = [2.306, 1.625], 1.9655, 1.9655
    elif args.data_name == "cifar100":
        layers = [4, 5]
        orig_values, ori_auc, ori_mean = [1.9618, 1.2048], 1.5833, 1.5833
        gold_values, gold_auc, gold_mean = [1.669, 0.8519], 1.26, 1.26
    else:
        # Previously this fell through and crashed on an unbound `layers`.
        raise NotImplementedError(
            f"No MI reference values for data={args.data_name!r} in sample mode"
        )

    args.train_mode = 'mi'
    answers = []
    for layer in layers:
        args.layer = layer
        args.ckpt = f"final_checkpoints/{args.data_name}/{args.test_mode}/{args.model_name}_{args.method}_{args.data_name}_{args.test_mode}_{args.sample_unlearn_per_class}_{args.remain_epochs}_{args.forget_epochs}_{args.remain_batch_size}_{args.forget_batch_size}_{args.optimizer}_{to_scientific_notation(args.lr)}{args.rnd_seed}.pth"
        answers.append(get_infonce(args, pretrained_model=None))

    print(f"Original: {orig_values}, AUC: {ori_auc:.4f}, Mean: {ori_mean:.4f}")
    print(f"Gold: {gold_values}, AUC: {gold_auc:.4f}, Mean: {gold_mean:.4f}")

    if len(layers) > 1:
        unlearn_mean = sum(answers) / len(answers)
        # Trapezoidal average over the layers: the two end points get half weight.
        unlearn_AUC = (2 * sum(answers) - answers[-1] - answers[0]) / (2 * len(answers) - 2)
        print(f"Mean: {unlearn_mean:.4f}, AUC: {unlearn_AUC:.4f}")
        print(f"Mean Ratio: {(unlearn_mean - gold_mean) / (ori_mean - gold_mean):.4f}, AUC Ratio: {(unlearn_AUC - gold_auc) / (ori_auc - gold_auc):.4f}")
    elif len(layers) == 1:
        print(f"Mean: {answers[0]:.4f}, AUC: {answers[0]:.4f}")
        print(f"Mean Ratio: {(answers[0] - gold_mean) / (ori_mean - gold_mean):.4f}, AUC Ratio: {(answers[0] - gold_auc) / (ori_auc - gold_auc):.4f}")

def flops_easy_view(total_flops):
    """Format a FLOP count with a metric prefix, e.g. ``1500`` -> ``"1.50 KFLOPS"``.

    The value is divided by 1000 until it drops below 1000 or the largest
    known prefix ('Y') is reached; anything larger is expressed in 'Y'.

    Args:
        total_flops: Number of floating-point operations.

    Returns:
        The scaled value with two decimals followed by the prefixed unit.
    """
    prefixes = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']
    last = len(prefixes) - 1
    i = 0
    # Stop at the last prefix so very large counts cannot index past the list.
    while total_flops >= 1000 and i < last:
        total_flops /= 1000
        i += 1
    return f"{total_flops:.2f} {prefixes[i]}FLOPS"


def to_scientific_notation(num):
    """Format ``num`` compactly in scientific notation, e.g. ``0.01`` -> ``"1e-2"``.

    The mantissa is rounded to one decimal, with a trailing ``.0`` dropped.
    The exponent keeps its sign and loses only its zero padding, so
    ``1e-10`` yields ``"1e-10"`` and ``1.0`` yields ``"1e+0"``.

    Args:
        num: Finite number to format.

    Returns:
        A string of the form ``"<mantissa>e<sign><exponent>"``.
    """
    mantissa, exponent = f"{num:.1e}".split('e')
    if mantissa.endswith('.0'):
        mantissa = mantissa[:-2]
    # Parse the exponent instead of deleting every '0' character: that would
    # turn "-10" into "-1" and "+00" into "+".
    return f"{mantissa}e{int(exponent):+d}"