import torch
import inspect
from util.CalculateRiskScore import calculate_risk_score
from util.MetricsCalculation import entropy, confidence, modified_entropy, grad_norm
from util.ModelParser import parse_model
import copy
from MIA.MIA import MIA
from torchmetrics.functional.classification import binary_auroc
from torchmetrics.classification import BinaryROC, BinaryAccuracy
from torchmetrics.functional.classification import binary_accuracy
from sklearn import metrics 
import numpy as np
import os



# Registry resolving the string ``metric`` accepted by ``MetricMIA`` to the
# per-sample scoring function imported from ``util.MetricsCalculation``.
metric_map = dict(
    entropy=entropy,
    confidence=confidence,
    modified_entropy=modified_entropy,
)

def minmax_scale_tensor(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Linearly rescale ``x`` so its global minimum maps to 0 and its maximum to ~1.

    ``eps`` is added to the denominator so a constant tensor maps to all
    zeros instead of dividing by zero.

    Args:
        x: Tensor to rescale; min/max are taken over all elements.
        eps: Small constant added to ``max - min`` (default ``1e-12``).

    Returns:
        Tensor with the same shape as ``x``. An empty ``x`` yields an empty
        (cloned) tensor instead of raising.
    """
    if x.numel() == 0:
        # torch.min / torch.max raise on empty tensors; there is nothing to scale.
        return x.clone()
    min_val = torch.min(x)
    max_val = torch.max(x)
    return (x - min_val) / (max_val - min_val + eps)


class MetricMIA(MIA):
    """Metric-based membership inference attack with per-class thresholds.

    A per-sample score (a function from ``metric_map`` when ``metric`` is a
    string) is computed from a model's softmax output and the class labels.

    * In ``"attack"`` mode, ``fit`` calibrates one threshold per class on a
      shadow model's member / non-member data, and ``infer`` predicts member
      (1) when a sample's score is >= the threshold of its class.
    * In every mode, ``infer`` records scores, class labels and membership
      labels. It assumes each batch holds members in its first half and
      non-members in its second half. ``output`` then reports per-class AUC,
      TPR at low FPR, and accuracy using thresholds chosen on the recorded
      scores.
    """

    def __init__(self, name, metric, shadow_model_path, mia_mode="attack", num_classes=10, threshold=0.5, load_epoch=200, total_epoch=200, device='cuda', **kwargs):
        """Configure the attack.

        Args:
            name: Attack name, forwarded to ``MIA``.
            metric: Key of ``metric_map`` (``'entropy'``, ``'confidence'``,
                ``'modified_entropy'``). Non-string values are only forwarded
                to ``MIA``.
            shadow_model_path: Shadow checkpoint path prefix. ``.ckpt`` is
                appended, preceded by ``_{load_epoch}`` when
                ``load_epoch != total_epoch``.
            mia_mode: ``"attack"`` enables shadow calibration in ``fit`` and
                predictions in ``infer``. Any other value only records scores.
            num_classes: Number of classes (one threshold per class).
            threshold: Forwarded to ``MIA``.
            load_epoch: Epoch of the shadow checkpoint to load.
            total_epoch: Final training epoch (its checkpoint has no suffix).
            device: Device for models and accumulated tensors.
            **kwargs: Not used by this class.
        """
        super().__init__(name, threshold, metric, mia_mode)
        self.score_type = metric
        print(f"score:{self.score_type}")

        if isinstance(metric, str):
            print(f"Using metric: {metric}")
            assert metric in metric_map, f"Unknown metric: {metric}"
            self.metric = metric_map[metric]
        # NOTE(review): for a non-string ``metric`` this class relies on the
        # ``MIA`` base class to provide ``self.metric`` — confirm.

        # Not read anywhere in this class.
        self.shadow_train_metrics = None
        self.shadow_test_metrics = None
        self.shadow_train_labels = None
        self.shadow_test_labels = None

        suffix = '' if load_epoch == total_epoch else f'_{load_epoch}'
        self.shadow_model_path = f'{shadow_model_path}{suffix}.ckpt'

        self.device = device
        self.attack = mia_mode == "attack"
        self.num_classes = num_classes

        # Accumulated across ``infer`` calls and consumed by ``output``.
        self.infer_original_label = None
        self.infer_score = None
        self.infer_mn_label = None
        self.count = 1
        # A list until an attack-mode ``fit`` replaces it with a per-class
        # tensor; ``output`` only extends it while it is still a list.
        self.thresholds = []

    @torch.no_grad()
    def _model_prediction(self, model, dataloader):
        """Run ``model`` over ``dataloader`` and return ``(softmax_probs, labels)``.

        Each dataloader item is unpacked as ``(inputs, labels, *rest)``. Both
        results live on ``self.device``. Labels are concatenated onto an empty
        default-dtype tensor, so integer labels come back as floating point.
        An empty dataloader yields two empty tensors.
        """
        model.eval()
        prob_chunks = []
        label_chunks = []
        for inputs, labels, *_ in dataloader:
            logits = model(inputs.to(self.device))
            prob_chunks.append(torch.softmax(logits, dim=1).to(self.device))
            label_chunks.append(labels.to(self.device))
        # A single cat (instead of one per batch) avoids quadratic copying; the
        # leading empty tensor preserves the previous dtype promotion and the
        # empty-dataloader result.
        empty = torch.tensor([], device=self.device)
        return_outputs = torch.cat([empty, *prob_chunks], dim=0)
        return_labels = torch.cat([empty, *label_chunks], dim=0)
        return (return_outputs, return_labels)

    @torch.no_grad()
    def _model_infer_prediction(self, model, inputs):
        """Return softmax probabilities of ``model`` on one batch, on ``self.device``.

        The result carries no autograd history; no graph is built for it.
        """
        model.eval()
        outputs = model(inputs.to(self.device))
        return torch.softmax(outputs, dim=1).to(self.device)

    def _model_performance(self, model, train_dataloader, test_dataloader):
        """Return ``_model_prediction`` results for the train and test loaders."""
        train_outputs = self._model_prediction(model, train_dataloader)
        test_outputs = self._model_prediction(model, test_dataloader)
        return train_outputs, test_outputs

    def _calculate_threshold(self, train_metrics, test_metrics):
        """Return the score threshold that best separates members from non-members.

        Every observed score is a candidate ``v``. Its accuracy is the mean of
        two fractions: ``train_metrics`` (members) with score >= ``v``, and
        ``test_metrics`` (non-members) with score < ``v``. The first candidate,
        in ``train ++ test`` order, with the highest accuracy is returned as a
        0-dim tensor. Returns ``0`` when either input is empty. Assumes finite
        scores.
        """
        train_flat = train_metrics.reshape(-1)
        test_flat = test_metrics.reshape(-1)
        n_train = train_flat.numel()
        n_test = test_flat.numel()
        if n_train == 0 or n_test == 0:
            # The ratios would be 0/0 (NaN) and no candidate could ever win.
            return 0

        candidates = torch.cat([train_flat, test_flat], dim=0)
        sorted_train, _ = torch.sort(train_flat)
        sorted_test, _ = torch.sort(test_flat)
        # searchsorted (left side) counts elements strictly below each
        # candidate: two O(n log n) passes instead of an O(n^2) Python loop.
        tr_ratio = (n_train - torch.searchsorted(sorted_train, candidates)) / n_train
        te_ratio = torch.searchsorted(sorted_test, candidates) / n_test
        acc = 0.5 * (tr_ratio + te_ratio)
        # argmax returns the first maximal index, i.e. the same candidate a
        # left-to-right scan with a strict ``>`` update would keep.
        return candidates[torch.argmax(acc)]

    def _threshold(self):
        """Fit one threshold per class from the shadow member/non-member scores.

        Stores a ``(num_classes,)`` tensor on ``self.device`` in
        ``self.thresholds``. A class missing from either shadow split gets 0.
        """
        thresholds = torch.zeros(self.num_classes, device=self.device)
        for cls in range(self.num_classes):
            member_scores = self.s_tr_metric[self.s_tr_labels == cls]
            nonmember_scores = self.s_te_metric[self.s_te_labels == cls]
            thresholds[cls] = self._calculate_threshold(member_scores, nonmember_scores)
        self.thresholds = thresholds

    def fit(self, model, fit_data_loaders, **kwargs):
        """Calibrate per-class thresholds on a shadow model (attack mode only).

        ``model`` supplies the architecture: it is deep-copied and loaded with
        the weights at ``self.shadow_model_path``. The first loaders of
        ``fit_data_loaders["shadow_member"]`` and
        ``fit_data_loaders["shadow_nonmember"]`` provide the shadow member and
        non-member data.
        """
        if not self.attack:
            # NOTE(review): sets ``threshold`` (singular), not ``thresholds`` —
            # presumably clears the base-class attribute; confirm intent.
            self.threshold = None
            return

        shadow_train_data = fit_data_loaders["shadow_member"][0]
        shadow_test_data = fit_data_loaders["shadow_nonmember"][0]

        self.shadow_model = copy.deepcopy(model).to(self.device)
        state_dict = torch.load(self.shadow_model_path, map_location=self.device)
        if "dpsgd" in self.shadow_model_path:
            # These checkpoints carry a ``_module.`` key prefix (presumably from
            # a DP-SGD wrapper module — confirm); strip it to match ``model``.
            prefix = '_module.'
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        self.shadow_model.load_state_dict(state_dict)

        s_train_performance, s_test_performance = self._model_performance(
            self.shadow_model, shadow_train_data, shadow_test_data)
        self.s_tr_outputs, self.s_tr_labels = s_train_performance
        self.s_te_outputs, self.s_te_labels = s_test_performance

        self.s_tr_metric = self.metric(self.s_tr_outputs, self.s_tr_labels)
        self.s_te_metric = self.metric(self.s_te_outputs, self.s_te_labels)

        self._threshold()

    def infer(self, model, data_batch, labels=None):
        """Score one batch, record it for ``output``, and optionally predict membership.

        The first ``len(data_batch) // 2`` samples are recorded as members (1)
        and the next ``len(data_batch) // 2`` as non-members (0). ``labels``
        (class labels of ``data_batch``) is effectively required. It is passed
        to the metric, stored for ``output``, and in attack mode used to pick
        each sample's class threshold.

        Returns:
            ``(predictions, None)`` in attack mode, where ``predictions`` is a
            1-D tensor holding 1 where the score is >= the class threshold and
            0 otherwise; ``(None, None)`` in any other mode.
        """
        model.eval()
        half = len(data_batch) // 2

        infer_metric = self.metric(self._model_infer_prediction(model, data_batch), labels)

        if self.infer_score is None:
            self.infer_score = infer_metric
            self.infer_original_label = labels
        else:
            self.infer_score = torch.cat([torch.atleast_1d(self.infer_score), torch.atleast_1d(infer_metric)])
            self.infer_original_label = torch.cat([torch.atleast_1d(self.infer_original_label), torch.atleast_1d(labels)])

        # NOTE(review): with an odd batch size the last sample gets no
        # membership label, so ``infer_mn_label`` falls out of step with
        # ``infer_score`` — confirm batches are always even-sized.
        m_nm_labels = torch.cat([
            torch.ones(half, dtype=torch.long),
            torch.zeros(half, dtype=torch.long),
        ]).to(self.device)
        if self.infer_mn_label is None:
            self.infer_mn_label = m_nm_labels
        else:
            self.infer_mn_label = torch.cat([self.infer_mn_label, m_nm_labels])

        self.count += 1

        if not self.attack:
            return None, None
        preds = [1 if score >= self.thresholds[label] else 0
                 for label, score in zip(labels, infer_metric)]
        return torch.tensor(preds, device=self.device), None

    def output(self):
        """Evaluate the recorded scores with per-class thresholds.

        For each class label seen by ``infer``, this method:

        * computes AUROC and the max TPR at FPR <= 1e-3 and FPR <= 1e-4;
        * picks the ``roc_curve`` threshold with the highest accuracy on that
          class;
        * predicts member (1) for that class's samples whose score is >= that
          threshold.

        The chosen thresholds are appended to ``self.thresholds`` while it is
        a list.

        Returns:
            A dict with these keys:

            * ``auc``: mean per-class AUC.
            * ``best_accuracy``: overall accuracy of the per-class predictions.
            * ``predict``: the full prediction vector.
            * ``member_pred`` / ``nonmember_pred``: predictions split by true
              membership.
            * ``tpr01fpr``: mean TPR at 0.1% FPR.
            * ``tpr001fpr``: mean TPR at 0.01% FPR.
            * ``tp``, ``fn``, ``tn``, ``fp``: confusion counts.
        """
        scores = self.infer_score
        class_labels = self.infer_original_label
        mn_labels = self.infer_mn_label

        auc_list = []
        tpr_at_fpr_1e3 = []
        tpr_at_fpr_1e4 = []
        chosen_thresholds = []
        result = mn_labels.clone()

        # Iterate over the labels actually present. ``torch.arange(min, max)``
        # skipped the largest label, leaving its rows of ``result`` equal to
        # the ground truth.
        for exact_label in torch.unique(class_labels):
            in_class = class_labels == exact_label
            y_score = scores[in_class]
            y_true = mn_labels[in_class]

            auc_list.append(binary_auroc(y_score, y_true, thresholds=None))

            fpr, tpr, thresholds = metrics.roc_curve(
                y_true=np.array(y_true.cpu()).ravel(),
                y_score=np.array(y_score.cpu()).ravel(),
            )
            # fpr[0] is 0, so both selections are non-empty.
            tpr_at_fpr_1e3.append(np.max(tpr[fpr <= 0.001]))
            tpr_at_fpr_1e4.append(np.max(tpr[fpr <= 0.0001]))

            acc = torch.stack([
                binary_accuracy((y_score >= th).int(), y_true, threshold=0.0)
                for th in thresholds
            ])
            best_threshold = thresholds[acc.argmax()]
            chosen_thresholds.append(best_threshold)
            # ``>=`` matches how roc_curve and the accuracy scan above define a
            # positive prediction.
            result[in_class] = (y_score >= best_threshold).to(device=result.device, dtype=result.dtype)

        if isinstance(self.thresholds, list):
            # After an attack-mode ``fit`` this is the per-class tensor used by
            # ``infer``; it must not be clobbered (and has no ``append``).
            self.thresholds.extend(chosen_thresholds)

        member_pred = result[mn_labels == 1]
        nonmember_pred = result[mn_labels == 0]

        tp = torch.sum(member_pred)            # members predicted member
        fn = member_pred.shape[0] - tp         # members predicted non-member
        fp = torch.sum(nonmember_pred)         # non-members predicted member
        tn = nonmember_pred.shape[0] - fp      # non-members predicted non-member

        return {
            "auc": torch.mean(torch.tensor(auc_list)),
            "best_accuracy": (tp + tn) / (tp + fn + tn + fp),
            "predict": result,
            "member_pred": member_pred,
            "nonmember_pred": nonmember_pred,
            "tpr01fpr": np.mean(tpr_at_fpr_1e3),
            "tpr001fpr": np.mean(tpr_at_fpr_1e4),
            "tp": tp,
            "fn": fn,
            "tn": tn,
            "fp": fp,
        }
        

        


