import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.svm import SVC
from tqdm.auto import tqdm
from sklearn.utils import shuffle


class ConfidenceMIA:
    """Confidence-based membership inference attack against a classifier.

    An RBF-kernel SVM is trained on a single feature per example -- the
    softmax probability the target model assigns to the example's true
    label -- to separate examples labelled as members (1) from non-members (0).

    Args:
        target_model: Model under attack. It is called as ``model(data)`` and
            the result must expose a ``.logits`` attribute.
        device: Device the target model and every batch are moved to.
        verbose: When true, progress messages are printed from ``fit``.
    """

    def __init__(self, target_model, device, verbose=True):
        self.target_model = target_model
        self.device = device
        self.target_model.to(self.device)
        self.model = SVC(C=3, gamma="auto", kernel="rbf")
        self.verbose = verbose
        self.__fit = False

    def __calculate_confidence(self, loader):
        """Return the true-label softmax probability of every example.

        Args:
            loader: Iterable of ``(data, target)`` batches that supports
                ``len()``. ``target`` is used as an index along the class
                dimension of the logits.

        Returns:
            Tensor with one row per example and a single column holding the
            probability assigned to that example's target class.
        """
        confidences = []
        self.target_model.eval()
        with torch.no_grad():
            for data, target in tqdm(loader, total=len(loader)):
                data, target = data.to(self.device), target.to(self.device)
                probs = torch.softmax(self.target_model(data).logits, dim=1)
                # Keep only the true-class probability of each batch instead of
                # accumulating the full (examples x classes) matrix.
                confidences.append(torch.gather(probs, 1, target[:, None]))
        return torch.cat(confidences)

    def fit(
        self,
        train_examples_loader,
        non_train_examples_loader,
        eval_examples_loader=None,
    ):
        """Train the attack SVM and optionally score it on held-out examples.

        Args:
            train_examples_loader: Batches of examples labelled as members (1).
            non_train_examples_loader: Batches labelled as non-members (0).
            eval_examples_loader: Optional batches to run the fitted attack on.

        Returns:
            ``None`` when ``eval_examples_loader`` is ``None``. Otherwise the
            mean predicted label on the eval examples, i.e. the fraction
            predicted to be members; this is an accuracy only when every eval
            example is a true member.
        """
        cnf_train_examples = self.__calculate_confidence(train_examples_loader)
        cnf_non_train_examples = self.__calculate_confidence(non_train_examples_loader)
        # The eval loader is optional; only touch it when one was supplied.
        cnf_eval_examples = None
        if eval_examples_loader is not None:
            cnf_eval_examples = self.__calculate_confidence(eval_examples_loader)

        X_train = (
            torch.cat([cnf_train_examples, cnf_non_train_examples], dim=0).cpu().numpy()
        )
        Y_train = np.concatenate(
            [
                np.ones(cnf_train_examples.shape[0]),
                np.zeros(cnf_non_train_examples.shape[0]),
            ]
        )

        X_train, Y_train = shuffle(X_train, Y_train)
        self.model.fit(X_train, Y_train)

        self.__fit = True
        if self.verbose:
            print("Attack Model Trained")

        if cnf_eval_examples is None:
            return None

        X_test = cnf_eval_examples.cpu().numpy()
        acc = self.model.predict(X_test).mean()
        if self.verbose:
            print("Eval Accuracy: ", acc)
        return acc

    def calculate(self, forget_examples_loader):
        """Return the fraction of examples the attack labels as non-members.

        Args:
            forget_examples_loader: Batches of ``(data, target)`` to attack.

        Returns:
            ``1 - mean(predicted label)`` over all examples in the loader.

        Raises:
            RuntimeError: If ``fit`` has not completed on this instance.
        """
        if not self.__fit:
            raise RuntimeError(
                "fit membership inference model first before calculating score"
            )
        cnf_forget_examples = (
            self.__calculate_confidence(forget_examples_loader).cpu().numpy()
        )
        return 1 - self.model.predict(cnf_forget_examples).mean()


class JSDivergence:
    """Jensen-Shannon divergence between two models' predictions on a loader.

    Args:
        retrained_model: Reference model. Called as ``model(data)``; the
            result must expose ``.logits``.
        unlearned_model: Model compared against the reference, same calling
            convention.
        device: Device both models and every input batch are moved to.
    """

    def __init__(self, retrained_model, unlearned_model, device):
        self.retrained_model = retrained_model
        self.unlearned_model = unlearned_model
        self.device = device
        self.retrained_model.to(self.device)
        self.unlearned_model.to(self.device)
        self.metric = JensenShannonDivergence()

    def __calculate_logits(self, model, loader):
        """Return softmax probabilities of ``model`` for every example.

        Despite the name, the returned tensor holds probabilities (softmax
        over dim 1 of the logits), concatenated over all batches.
        """
        probs = []
        model.eval()
        with torch.no_grad():
            # Targets are not needed here, so they are not copied to the device.
            for data, _ in tqdm(loader, total=len(loader)):
                out = model(data.to(self.device))
                probs.append(torch.softmax(out.logits, dim=1))
        return torch.cat(probs)

    def calculate(self, test_loader):
        """Return the divergence between the two models as a Python float.

        Args:
            test_loader: Iterable of ``(data, target)`` batches supporting
                ``len()``. It is iterated twice, once per model, so it must
                yield examples in the same order both times for the rows to
                line up.
        """
        probs_un = self.__calculate_logits(self.unlearned_model, test_loader)
        probs_re = self.__calculate_logits(self.retrained_model, test_loader)
        return self.metric(probs_re, probs_un).item()


class JensenShannonDivergence(nn.Module):
    """Jensen-Shannon divergence between batches of distributions, in bits.

    ``JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)`` with
    ``m = 0.5 * (p + q)``. The result is divided by ``ln 2`` so that, with the
    default ``batchmean`` reduction, it lies in ``[0, 1]``.

    Args:
        reduction: Reduction mode forwarded to ``torch.nn.functional.kl_div``.
    """

    def __init__(self, reduction="batchmean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, p, q):
        """Compute the divergence between ``p`` and ``q``.

        Args:
            p: Tensor of non-negative weights; the last dimension is
                renormalised to sum to one.
            q: Tensor with the same shape as ``p``.

        Returns:
            Tensor holding the divergence, reduced according to
            ``self.reduction``.
        """
        # Clamp away exact zeros so the logarithms stay finite, then
        # renormalise each distribution along the last dimension.
        epsilon = 1e-20
        p = p.clamp(min=epsilon)
        q = q.clamp(min=epsilon)
        p = p / p.sum(dim=-1, keepdim=True)
        q = q / q.sum(dim=-1, keepdim=True)
        log_m = (0.5 * (p + q)).log()
        # F.kl_div(input, target) evaluates KL(target || exp(input)), so the
        # mixture goes in as the log-space input and p / q as the target.
        kl_pm = F.kl_div(log_m, p, reduction=self.reduction)
        kl_qm = F.kl_div(log_m, q, reduction=self.reduction)
        jsd = 0.5 * (kl_pm + kl_qm)
        # Convert from nats to bits.
        return jsd / torch.log(torch.tensor(2.0, device=jsd.device))


class ToWAccuracy:
    """Tug-of-War score built from accuracy gaps between two models.

    The score is the product, over the forget, retain and test loaders, of
    ``1 - |acc_unlearned - acc_retrained|``.

    Args:
        unlearned_model: Model under evaluation. Called as ``model(data)``;
            the result must expose ``.logits``.
        retrained_model: Reference model with the same calling convention.
        device: Device both models and every batch are moved to.
    """

    def __init__(self, unlearned_model, retrained_model, device):
        self.unlearned_model = unlearned_model
        self.retrained_model = retrained_model
        self.unlearned_model.to(device)
        self.retrained_model.to(device)
        self.device = device

    def __calculate_accuracy(self, loader):
        """Return ``(unlearned_accuracy, retrained_accuracy)`` on ``loader``.

        Accuracy is computed over the samples the loader actually yields, so
        it stays correct when the loader uses a sampler or drops the last
        batch and therefore does not cover ``loader.dataset`` exactly.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.retrained_model.eval()
        self.unlearned_model.eval()
        un_correct = 0
        re_correct = 0
        seen = 0
        with torch.no_grad():
            for data, target in tqdm(loader, total=len(loader)):
                data, target = data.to(self.device), target.to(self.device)
                out_un = self.unlearned_model(data)
                out_re = self.retrained_model(data)
                pred_un = out_un.logits.argmax(dim=1, keepdim=True)
                pred_re = out_re.logits.argmax(dim=1, keepdim=True)
                un_correct += pred_un.eq(target.view_as(pred_un)).sum().item()
                re_correct += pred_re.eq(target.view_as(pred_re)).sum().item()
                seen += pred_un.size(0)

        if seen == 0:
            raise ValueError("cannot compute accuracy: loader yielded no samples")
        return un_correct / seen, re_correct / seen

    def calculate(self, forget_loader, retain_loader, test_loader):
        """Return the Tug-of-War score over the three data splits.

        Args:
            forget_loader: Batches of ``(data, target)`` from the forget set.
            retain_loader: Batches from the retain set.
            test_loader: Batches from the test set.

        Returns:
            Product of ``1 - |accuracy gap|`` over the three loaders.
        """
        un_forget_acc, re_forget_acc = self.__calculate_accuracy(forget_loader)
        un_retain_acc, re_retain_acc = self.__calculate_accuracy(retain_loader)
        un_test_acc, re_test_acc = self.__calculate_accuracy(test_loader)
        tow_score = (
            (1 - abs(un_forget_acc - re_forget_acc))
            * (1 - abs(un_retain_acc - re_retain_acc))
            * (1 - abs(un_test_acc - re_test_acc))
        )
        return tow_score


class ToWMIA:
    """Tug-of-War score whose forget-set term uses membership inference.

    The score is the product of ``1 - |gap|`` for three gaps between the
    unlearned and retrained models: the MIA score on the forget loader, the
    accuracy on the retain loader and the accuracy on the test loader.

    Args:
        unlearned_model: Model under evaluation. Called as ``model(data)``;
            the result must expose ``.logits``.
        retrained_model: Reference model with the same calling convention.
        mia_unlearned: Attack used for the unlearned model's forget score.
            Its ``calculate`` method raises if it has not been fitted.
        mia_retrained: Attack used for the retrained model's forget score.
        device: Device both models and every batch are moved to.
    """

    def __init__(
        self,
        unlearned_model,
        retrained_model,
        mia_unlearned: ConfidenceMIA,
        mia_retrained: ConfidenceMIA,
        device,
    ):
        self.unlearned_model = unlearned_model
        self.retrained_model = retrained_model
        self.unlearned_model.to(device)
        self.retrained_model.to(device)
        self.device = device
        self.un_mia = mia_unlearned
        self.re_mia = mia_retrained

    def __calculate_accuracy(self, loader):
        """Return ``(unlearned_accuracy, retrained_accuracy)`` on ``loader``.

        Accuracy is computed over the samples the loader actually yields, so
        it stays correct when the loader uses a sampler or drops the last
        batch and therefore does not cover ``loader.dataset`` exactly.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.retrained_model.eval()
        self.unlearned_model.eval()
        un_correct = 0
        re_correct = 0
        seen = 0
        with torch.no_grad():
            for data, target in tqdm(loader, total=len(loader)):
                data, target = data.to(self.device), target.to(self.device)
                out_un = self.unlearned_model(data)
                out_re = self.retrained_model(data)
                pred_un = out_un.logits.argmax(dim=1, keepdim=True)
                pred_re = out_re.logits.argmax(dim=1, keepdim=True)
                un_correct += pred_un.eq(target.view_as(pred_un)).sum().item()
                re_correct += pred_re.eq(target.view_as(pred_re)).sum().item()
                seen += pred_un.size(0)

        if seen == 0:
            raise ValueError("cannot compute accuracy: loader yielded no samples")
        return un_correct / seen, re_correct / seen

    def calculate(self, forget_loader, retain_loader, test_loader):
        """Return the MIA-based Tug-of-War score.

        Args:
            forget_loader: Batches passed to both attacks' ``calculate``.
            retain_loader: Batches of ``(data, target)`` from the retain set.
            test_loader: Batches of ``(data, target)`` from the test set.

        Returns:
            Product of ``1 - |gap|`` over the forget MIA scores and the retain
            and test accuracies.
        """
        un_forget_mia = self.un_mia.calculate(forget_loader)
        re_forget_mia = self.re_mia.calculate(forget_loader)
        un_retain_acc, re_retain_acc = self.__calculate_accuracy(retain_loader)
        un_test_acc, re_test_acc = self.__calculate_accuracy(test_loader)
        tow_score = (
            (1 - abs(un_forget_mia - re_forget_mia))
            * (1 - abs(un_retain_acc - re_retain_acc))
            * (1 - abs(un_test_acc - re_test_acc))
        )
        return tow_score


def l2_distance_between_models(model_a, model_b):
    """Return the Euclidean distance between two models' parameters.

    Parameters are paired in ``parameters()`` order and the squared
    differences are summed over all of them before taking the square root.

    Note:
        Both models are moved to the CPU in place and are left there.

    Args:
        model_a: First model; must provide ``to`` and ``parameters``.
        model_b: Second model with the same number and shapes of parameters.

    Returns:
        The L2 distance as a Python float.

    Raises:
        ValueError: If the models differ in parameter count or in the shape
            of any paired parameter.
    """
    model_a.to("cpu")
    model_b.to("cpu")
    params_a = list(model_a.parameters())
    params_b = list(model_b.parameters())
    # zip() would silently stop at the shorter model and under-report the
    # distance, so mismatched architectures are rejected explicitly.
    if len(params_a) != len(params_b):
        raise ValueError(
            "models have different numbers of parameters: "
            f"{len(params_a)} vs {len(params_b)}"
        )

    squared_sum = 0.0
    with torch.no_grad():
        for index, (param_a, param_b) in enumerate(zip(params_a, params_b)):
            # Guard against shapes that would otherwise broadcast silently.
            if param_a.shape != param_b.shape:
                raise ValueError(
                    f"parameter {index} has mismatched shapes: "
                    f"{tuple(param_a.shape)} vs {tuple(param_b.shape)}"
                )
            squared_sum += (param_a - param_b).pow(2).sum().item()
    return squared_sum**0.5
