import torch
import torch.nn as nn
from MIA.MIA import MIA
from torchmetrics import ROC
from torchmetrics.functional.classification import binary_auroc
from torchmetrics.classification import BinaryROC, BinaryAccuracy
from torchmetrics.functional.classification import binary_accuracy
from sklearn import metrics
import numpy as np
import copy


class LogLoss(nn.Module):
    """Per-sample negative log-likelihood computed from class probabilities.

    ``input`` is used as-is (no softmax is applied here), so it is expected to
    already hold per-class probabilities of shape ``(B, num_classes)``. The
    probability of the target class is clamped from below by ``small_value``
    before taking the log, which keeps the loss finite for zero probabilities.
    """

    def __init__(self, small_value=1e-8):
        super(LogLoss, self).__init__()
        self.small_value = small_value

    def forward(self, input, target):
        """Return ``-log(max(input[i, target[i]], small_value))`` for every row.

        Args:
            input: probabilities of shape ``(B, num_classes)``.
            target: class indices of shape ``(B,)`` or ``(B, 1)``; cast to long.

        Returns:
            Tensor of shape ``(B,)``. The batch dimension is kept even when
            ``B == 1``.
        """
        indices = target.long().reshape(-1, 1)
        # squeeze(1) instead of squeeze(): a bare squeeze collapses a batch of
        # one into a 0-d tensor, which breaks callers that index/unsqueeze it.
        selected_probabilities = input.gather(1, indices).squeeze(1)
        selected_probabilities = torch.clamp(selected_probabilities, min=self.small_value)
        return -torch.log(selected_probabilities)


class Merlin(MIA):
    """Loss-increase-ratio membership inference attack.

    For every sample, ``num_runs`` copies are perturbed with Gaussian noise. The
    membership score is the fraction of perturbed copies whose ``LogLoss``
    exceeds the loss of the clean sample. In ``"attack"`` mode the score
    threshold is calibrated on a shadow model by maximising Youden's
    J = TPR - FPR. When ``multi_label`` is set, one threshold is calibrated
    per class label.
    """

    def __init__(self, name, threshold, metric, shadow_model_path, mia_mode="attack", norm_type="l2", epsilon=1,
                 sigma=0.1, num_runs=100, load_epoch=200, total_epoch=200, device='cuda', **_):
        """
        Args:
            name, threshold, metric, mia_mode: forwarded to ``MIA.__init__``.
                ``mia_mode == "attack"`` enables shadow calibration and
                membership predictions.
            shadow_model_path: shadow checkpoint path without the ``.ckpt`` suffix.
            norm_type: accepted for interface compatibility; not used by this class.
            epsilon: multiplier on ``sigma`` for the noise used in ``infer``.
            sigma: standard deviation of the Gaussian input noise.
            num_runs: number of perturbed copies per sample.
            load_epoch, total_epoch: when they differ,
                ``{shadow_model_path}_{load_epoch}.ckpt`` is loaded instead of
                ``{shadow_model_path}.ckpt``.
            device: device the shadow model and its data are moved to in ``fit``.
        """
        super().__init__(name, threshold, metric, mia_mode)
        self.epsilon = epsilon
        self.sigma = sigma
        self.num_runs = num_runs
        # Reset regardless of the constructor argument; fit() computes it.
        self.threshold = None
        # When True, fit() calibrates one threshold per class label.
        self.multi_label = True
        self.num_class = None
        self.label_threshold = None
        self.shadow_model = None
        self.attack = mia_mode == "attack"
        self.device = device

        if load_epoch == total_epoch:
            self.shadow_model_path = f'{shadow_model_path}.ckpt'
        else:
            self.shadow_model_path = f'{shadow_model_path}_{load_epoch}.ckpt'

        # Accumulated across infer() calls and consumed by output().
        self.infer_original_label = None
        self.infer_score = None
        self.infer_mn_label = None
        self.count = 1

    def fit(self, model, fit_data_loaders, **kwargs):
        """Calibrate the membership threshold(s) on the first shadow split.

        Args:
            model: template network. It is deep-copied and loaded with the
                shadow checkpoint; the passed-in model is not modified.
            fit_data_loaders: mapping with ``"shadow_member"`` and
                ``"shadow_nonmember"`` entries. Only element ``[0]`` of each
                is used.
            **kwargs: ignored.

        Outside ``"attack"`` mode this only resets ``self.threshold`` to None.
        Otherwise it sets ``self.threshold`` to the global Youden threshold. In
        multi-label mode it also sets ``self.label_threshold`` (indexed by class
        label), falling back to the global threshold for classes that have no
        shadow samples.
        """
        if not self.attack:
            self.threshold = None
            return

        member_loader = fit_data_loaders["shadow_member"][0]
        nonmember_loader = fit_data_loaders["shadow_nonmember"][0]

        self.shadow_model = self._load_shadow_model(model)

        member_ratios, member_labels = self._compute_increase_ratios(self.shadow_model, member_loader, self.device)
        nonmember_ratios, nonmember_labels = self._compute_increase_ratios(
            self.shadow_model, nonmember_loader, self.device)

        all_ratios = torch.cat([member_ratios, nonmember_ratios])
        all_labels = torch.cat([member_labels, nonmember_labels])
        # 1 = member, 0 = non-member
        is_member = torch.cat([torch.ones_like(member_ratios), torch.zeros_like(nonmember_ratios)])

        # Global threshold: used directly in single-threshold mode and as the
        # fallback for classes absent from the shadow data.
        self.threshold = self._youden_threshold(all_ratios, is_member)

        if self.multi_label:
            self.num_class = int(all_labels.max().item()) + 1
            self.label_threshold = [self.threshold] * self.num_class
            for c in range(self.num_class):
                in_class = all_labels == c
                if in_class.any():
                    self.label_threshold[c] = self._youden_threshold(all_ratios[in_class], is_member[in_class])

    def _load_shadow_model(self, model):
        """Return a deep copy of ``model`` on ``self.device`` loaded with the shadow checkpoint."""
        shadow_model = copy.deepcopy(model).to(self.device)
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(self.shadow_model_path, map_location=self.device)
        if "dpsgd" in self.shadow_model_path:
            # Parameter names in these checkpoints may carry a "_module." prefix
            # (presumably from a DP-SGD wrapper — confirm); strip it before loading.
            prefix = '_module.'
            state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        shadow_model.load_state_dict(state_dict)
        return shadow_model

    @staticmethod
    def _youden_threshold(scores, is_member):
        """Return the ROC threshold that maximises TPR - FPR (Youden's J)."""
        roc = ROC(task="binary")
        fpr, tpr, thresholds = roc(scores, is_member.int())
        return thresholds[(tpr - fpr).argmax()].item()

    def _loss_increase_ratio(self, model, data, labels, noise_std):
        """Return, per sample, the fraction of noisy copies whose loss is higher.

        ``num_runs`` copies of every sample are perturbed with zero-mean
        Gaussian noise of standard deviation ``noise_std``. The result has
        shape ``(B,)`` with values in ``[0, 1]``.
        """
        loss_fn = LogLoss()
        batch_size = data.size(0)
        with torch.no_grad():
            # reshape keeps a (B,) shape even if the loss squeezed a batch of one.
            original_losses = loss_fn(model(data), labels).reshape(batch_size)

            # Each sample is repeated num_runs times consecutively, so the
            # flattened losses reshape cleanly to (B, num_runs). Works for any
            # input rank.
            repeated_data = data.repeat_interleave(self.num_runs, dim=0)
            repeated_labels = labels.repeat_interleave(self.num_runs)
            perturbed_data = repeated_data + noise_std * torch.randn_like(repeated_data)
            perturbed_losses = loss_fn(model(perturbed_data), repeated_labels).reshape(batch_size, self.num_runs)

        return (perturbed_losses > original_losses.unsqueeze(1)).float().mean(dim=1)

    def _compute_increase_ratios(self, model, data_loader, device):
        """Compute loss-increase ratios for every sample yielded by ``data_loader``.

        Each batch is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (class
        labels). Noise std is ``self.sigma``.

        NOTE(review): ``infer`` uses ``epsilon * sigma``. The two only match
        for ``epsilon == 1`` — confirm this asymmetry is intended.

        Returns:
            Tuple ``(ratios, labels)`` of concatenated 1-D tensors.
        """
        model.eval()
        ratios = []
        class_labels = []
        for batch in data_loader:
            data, labels = batch[0].to(device), batch[1].to(device)
            ratios.append(self._loss_increase_ratio(model, data, labels, self.sigma))
            class_labels.append(labels)
        return torch.cat(ratios), torch.cat(class_labels)

    def infer(self, model, data, labels=None):
        """Score a batch and, in attack mode, predict membership.

        Args:
            model: target model; its parameters' device is used for all tensors.
            data: input tensor of shape ``(B, ...)``.
            labels: class labels of shape ``(B,)``. Required, because the loss
                needs the true class.

        Returns:
            ``(membership_vector, None)`` in attack mode, where
            ``membership_vector`` has shape ``(B,)`` and holds 1 for predicted
            members and 0 for non-members. ``(None, None)`` otherwise.

        Raises:
            ValueError: if ``labels`` is None.

        Side effects: appends scores, class labels and ground-truth membership
        flags to the ``infer_*`` attributes consumed by ``output()``. The batch
        is assumed to hold members in its first half and non-members in the
        rest. For an odd ``B`` the extra sample is recorded as a non-member —
        NOTE(review): confirm this matches how batches are built.
        """
        if labels is None:
            raise ValueError("Merlin.infer requires the ground-truth class labels")

        model.eval()
        device = next(model.parameters()).device
        data, labels = data.to(device), labels.to(device)
        batch_size = data.size(0)

        loss_increase_ratio = self._loss_increase_ratio(model, data, labels, self.epsilon * self.sigma)

        if self.infer_score is None:
            self.infer_score = loss_increase_ratio
            self.infer_original_label = labels
        else:
            self.infer_score = torch.cat([self.infer_score, loss_increase_ratio])
            self.infer_original_label = torch.cat([self.infer_original_label, labels])

        # Length always equals batch_size so scores and flags stay aligned.
        m_nm_labels = torch.zeros(batch_size, dtype=torch.long, device=device)
        m_nm_labels[:batch_size // 2] = 1
        if self.infer_mn_label is None:
            self.infer_mn_label = m_nm_labels
        else:
            self.infer_mn_label = torch.cat([self.infer_mn_label, m_nm_labels])

        self.count += 1

        if not self.attack:
            return None, None

        if self.multi_label:
            per_sample_threshold = torch.as_tensor(self.label_threshold, device=device)[labels.long()]
            membership_vector = (loss_increase_ratio >= per_sample_threshold).long()
        else:
            membership_vector = (loss_increase_ratio >= self.threshold).long()
        return membership_vector, None  # 0: non-member, 1: member

    @staticmethod
    def _tpr_at_fpr(fpr, tpr, max_fpr):
        """Return the best TPR with FPR <= ``max_fpr`` (NaN if no ROC point qualifies)."""
        within = tpr[fpr <= max_fpr]
        return float(np.max(within)) if within.size else float("nan")

    @staticmethod
    def _best_accuracy_threshold(fpr, tpr, thresholds, y_true):
        """Return the ROC threshold whose ``score >= threshold`` rule maximises accuracy.

        Correct-prediction counts are recovered exactly from the ROC rates (tp =
        tpr * P, tn = N - fpr * N). NaN rates, which sklearn produces when a
        class is absent, correspond to a count of zero.
        """
        n_pos = int(np.sum(y_true == 1))
        n_neg = int(y_true.size) - n_pos
        true_pos = np.rint(np.nan_to_num(tpr) * n_pos)
        true_neg = n_neg - np.rint(np.nan_to_num(fpr) * n_neg)
        best = int(np.argmax(true_pos + true_neg))  # first maximum on ties
        return float(thresholds[best])

    def output(self):
        """Summarise all scores collected by ``infer()``.

        For each class label present, this computes AUC, TPR at FPR <= 0.1% and
        FPR <= 0.01%, and the accuracy-maximising threshold. Those thresholds
        classify the collected samples. They are also stored in
        ``self.label_threshold``, indexed by class label; labels not seen here
        keep their previous threshold, or ``self.threshold`` if they had none.

        Raises:
            RuntimeError: if ``infer()`` has not been called.
        """
        if self.infer_score is None:
            raise RuntimeError("output() requires at least one prior call to infer()")

        class_labels = self.infer_original_label
        scores = self.infer_score
        mn_truth = self.infer_mn_label
        result = mn_truth.clone()

        present_labels = torch.unique(class_labels)
        previous = list(self.label_threshold or [])
        num_slots = max(len(previous), int(present_labels.max().item()) + 1)
        new_thresholds = [previous[c] if c < len(previous) else self.threshold for c in range(num_slots)]

        auc_list = []
        tpr_at_fpr_1e3 = []
        tpr_at_fpr_1e4 = []

        for label in present_labels:
            in_class = class_labels == label
            y_score = scores[in_class]
            y_true = mn_truth[in_class]
            auc_list.append(binary_auroc(y_score, y_true, thresholds=None))

            y_score_np = y_score.cpu().numpy().ravel()
            y_true_np = y_true.cpu().numpy().ravel()
            fpr, tpr, thresholds = metrics.roc_curve(y_true=y_true_np, y_score=y_score_np)
            tpr_at_fpr_1e3.append(self._tpr_at_fpr(fpr, tpr, 0.001))
            tpr_at_fpr_1e4.append(self._tpr_at_fpr(fpr, tpr, 0.0001))

            best_threshold = self._best_accuracy_threshold(fpr, tpr, thresholds, y_true_np)
            new_thresholds[int(label)] = best_threshold
            # Same ">=" rule the threshold was selected with.
            result[in_class] = (y_score >= best_threshold).to(result.dtype)

        self.label_threshold = new_thresholds

        member_pred = result[mn_truth == 1]
        nonmember_pred = result[mn_truth == 0]

        tp = torch.sum(member_pred)
        fn = member_pred.shape[0] - tp
        fp = torch.sum(nonmember_pred)  # non-members predicted as members
        tn = nonmember_pred.shape[0] - fp

        return {
            "auc": torch.stack(auc_list).mean().cpu(),
            "best_accuracy": (tp + tn) / (tp + fn + tn + fp),
            "predict": result,
            "member_pred": member_pred,
            "nonmember_pred": nonmember_pred,
            "tpr01fpr": np.mean(tpr_at_fpr_1e3),
            "tpr001fpr": np.mean(tpr_at_fpr_1e4),
            "tp": tp,
            "fn": fn,
            "tn": tn,
            "fp": fp,
        }

