import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.stats import norm
from MIA.MIA import MIA
from util.ModelParser import parse_model
import os
from torch.distributions import Normal
import sklearn.metrics as metrics
import numpy as np
from torch.utils.data import Subset, ConcatDataset
from util.MetricsCalculation import get_metric
import torch.nn.functional as F

# [13]Membership inference attacks from first principles
# https://github.com/tensorflow/privacy


class LiRA(MIA):
    def __init__(self, name,  threshold, metric, model_class=None, mia_mode="attack",load_epoch=None,total_epoch=200, **kwargs):
        """Initialise the bookkeeping state of the LiRA attack.

        Args:
            name: Forwarded to the ``MIA`` base-class constructor.
            threshold: Forwarded to the base-class constructor.
            metric: Forwarded to the base-class constructor.
            model_class: Optional model class; stored on the instance only.
            mia_mode: ``"attack"`` enables attack mode; any other value
                disables it.
            load_epoch: Accepted for interface compatibility; not read here.
            total_epoch: Accepted for interface compatibility; not read here.
            **kwargs: Accepted and ignored.
        """
        super().__init__(name, threshold, metric, mia_mode)

        self.shadow_models = []  # Store shadow models
        self.member_scores = []  # Training data confidence
        self.non_member_scores = []  # Non-training data confidence
        self.model_class = model_class  # Allow passing a model class if necessary
        self.shadow_mu = None
        self.shadow_sigma = None
        self.fit_config = {}
        # NOTE(review): the `threshold` argument is forwarded to the base
        # class above, but this attribute is then reset to 0.5 here; `fit`
        # and `infer` overwrite it again once scores are available.
        self.threshold = 0.5
        self.pred_reslut = None
        self.true_result = None
        self.attack = mia_mode == "attack"

        # Running buffers filled batch by batch in `infer`.
        self.infer_score = torch.tensor([])
        self.infer_label = None
        self.count = 1

        # Evaluation results. These must be plain `None` placeholders: a
        # trailing comma would turn each of them into the tuple `(None,)`.
        self.auc = None
        self.tpr = None
        self.fpr = None
        self.fpr_tpr_0001 = None
        self.fpr_tpr_001 = None
        self.fpr_tpr_01 = None

    


    def fit(self, model, fit_data_loaders , epochs=10, num_shadow_models=5, **kwargs):
        

        shadow_test_data_loader = fit_data_loaders["shadow_nonmember"]
        shadow_train_data_loader = fit_data_loaders["shadow_member"]
        
        target_train_data_loader = fit_data_loaders["member_train"]
        target_test_data_loader = fit_data_loaders["nonmember_train"]


        # if attack: use shadow_*_data_loader, else use target_*_data_loader

        # TODO: This part need to be discussed in the paper

        self.device = next(model.parameters()).device  
        self.infer_score = self.infer_score.to(self.device)
            
        self.shadow_test_loader = self.out_data = shadow_test_data_loader
        self.shadow_train_loaders = shadow_train_data_loader
        self.num_shadow_models = num_shadow_models


        for key, value in kwargs.items():
            self.fit_config[key] = value

        # get shadow models
        self.shadow_model_list = self._get_shadow_models()

        if self.attack:

            dataloader, total_labels = self._concate_dataloader(
                member_loader=shadow_train_data_loader[-1],
                non_member_loader=shadow_test_data_loader[-1],
            )

            # all_data = []
            # with torch.no_grad():
            #     for batch in data:
            #         all_data.append(batch[0]) 
            #     full_tensor = torch.cat(all_data, dim=0)
            #     data = full_tensor


        # with torch.no_grad():
        #     for batch in data_loader:
        #         data, y_true, _ = batch
        #         y_true = torch.tensor(y_true)
        #         outputs = model(data.to(self.device))
        #         res = self._stability_scores(outputs, y_true)
        #         logscore = torch.cat([torch.atleast_1d(logscore),torch.atleast_1d(res)])
        # return logscore

            with torch.no_grad():
                total_lira_score = None
                for batch in  dataloader:
                    data, label, _ = batch
                    y_true = label.to(self.device)

                    outputs = model(data.to(self.device))
                    res = self._stability_scores(outputs, y_true)
                    
                    shadow_res = None
                    for i in range((self.num_shadow_models -1)):
                        shadow_model = self.shadow_model_list[i]
                        shadow_model.eval()
                        outputs = shadow_model(data.to(self.device))
                        compute_score = self._stability_scores(outputs, y_true)
                        if shadow_res is None:
                            shadow_res = compute_score.unsqueeze(-1)
                        else:
                            shadow_res = torch.cat([shadow_res, compute_score.unsqueeze(-1)],dim=1)
                    

                    def lira_function(res):
                        # summation = torch.cat([ self.shadow_out.repeat(res.shape[0],1), shadow_res],dim=1 )
                        summation = shadow_res
                        mu_out = summation.mean(dim=1)
                        std_out = summation.std(dim=1)
                        dist_out = Normal(loc=mu_out, scale=std_out+ 1e-30)
                        return -dist_out.log_prob(res)

                    lira_score = lira_function(res) 
                    if total_lira_score is None:
                        total_lira_score = lira_score
                    else:
                        total_lira_score = torch.cat([total_lira_score, lira_score])
            
                self.threshold, *_ = self.best_threshold(score=total_lira_score,m_nm_label=total_labels)

        

    def _get_shadow_models(self):
        """
        Get shadow model: If self.args.load_madel exists and a saved model exists, load it;
             Otherwise, save after training.
        """
        load_epoch = self.fit_config.get("load_epoch", 200)
        total_epoch = self.fit_config.get("total_epoch", 200)
        n_shdow = self.num_shadow_models
        shadow_model_list = []
        model2load = self.fit_config.get("model2load", None)
        dataset = self.fit_config.get("dataset", None)
        shadow_model_type = self.fit_config.get("shadow_model_type", "resnet")
        shadow_model_name = "resnet"
        normalize = self.fit_config.get("normalize", True)

        #  file suffix based on load_ epoch (e.g. use. ckpt for 200 rounds, otherwise use _ {epoch}. ckpt)
        #  TODO: fix this hard code here
        if load_epoch == total_epoch:
            suffix = '.ckpt'
        else:
            suffix = f'_{load_epoch}.ckpt'
        
        for i in range(n_shdow):
            # Save path of shadow model: args.model2load/{i}/model_flename 
            model_dir = os.path.join(model2load, str(i))
            model_filename = shadow_model_name + suffix
            model_path = os.path.join(model_dir, model_filename)
            
            if  os.path.exists(model_path):
                print(f"Load the saved shadow model:{model_path}")
                # Initialize the model and load the previously saved state dictionary
                model = parse_model(dataset, arch=shadow_model_type, normalize=normalize)
                model.to(self.device)
                
                state_dict = torch.load(model_path, weights_only=True)
                from opacus.validators import ModuleValidator
                if "dpsgd" in model_path:
                    errors = ModuleValidator.validate(model, strict=False)
                    if errors:
                        model = ModuleValidator.fix(model)

                    new_state_dict = {}
                    for k, v in state_dict.items():
                        name = k
                        if name.startswith('_module.'):
                            name = name[8:]
                        new_state_dict[name] = v
                    state_dict = new_state_dict
                model.load_state_dict(state_dict)
            else:
                print(f"Model file not found, start training shadow model, index:{i}")
                # Train Shadow Model
                NotImplementedError("Training shadow model is not implemented yet. Please implement the training logic.")
            shadow_model_list.append(model)
        
        return shadow_model_list

    def best_threshold(self, m_nm_label, score):
        """Pick the ROC threshold that maximises accuracy and report the AUC.

        A sample is predicted as member when ``score > threshold``.

        Args:
            m_nm_label: Ground-truth membership labels (1 member, 0 non-member).
                Tensors are moved to the CPU first.
            score: Attack scores aligned with ``m_nm_label``. Tensors are moved
                to the CPU first.

        Returns:
            tuple: ``(best_threshold, auc, fpr, tpr)`` where ``fpr`` and ``tpr``
            are the arrays returned by ``sklearn.metrics.roc_curve``.
        """
        labels = m_nm_label.cpu() if isinstance(m_nm_label, torch.Tensor) else m_nm_label
        scores = score.cpu() if isinstance(score, torch.Tensor) else score

        fpr, tpr, thresholds = metrics.roc_curve(labels, scores)

        # Accuracy obtained at every candidate cut-off of the ROC curve.
        accuracies = [
            metrics.accuracy_score(labels, (scores > cut))
            for cut in thresholds
        ]
        top = np.argmax(accuracies)

        auc = metrics.roc_auc_score(y_true=labels, y_score=scores)

        print(f"best evaluation dataset acc:{accuracies[top]} with auc:{auc}")
        return float(thresholds[top]), auc, fpr, tpr

    
    def _compute_logscore(self, model, data_loader):
        model.eval()
        correct = torch.tensor([]).to(self.device)
        # get full data result
        probs = torch.tensor([]).to(self.device)
        labels = torch.tensor([]).to(self.device)
        logscore= torch.tensor([]).to(self.device)

        with torch.no_grad():
            for batch in data_loader:
                data, y_true, _ = batch
                y_true = torch.tensor(y_true)
                outputs = model(data.to(self.device))
                res = self._stability_scores(outputs, y_true)
                logscore = torch.cat([torch.atleast_1d(logscore),torch.atleast_1d(res)])
        return logscore




    def _stability_scores(self, opredictions, labels):
        predictions = opredictions - torch.max(opredictions, dim=-1, keepdims=True)[0]
        predictions = torch.exp(predictions)
        predictions = predictions / torch.sum(predictions, dim=-1, keepdims=True)

        COUNT = predictions.size(0)

        y_wrong = torch.sum(predictions, dim=-1)
        idx = torch.arange(COUNT)
        y_true=predictions[idx, labels[:COUNT]]
        predictions[idx, labels[:COUNT]] = 0.0
        y_wrong = torch.sum(predictions, dim=-1) 
        log_score = torch.log(y_true + 1e-45) - torch.log(y_wrong + 1e-45)
        return log_score


    def _compute_scores(self, model, data_loader):
        model.eval()
        logscore= torch.tensor([]).to(self.device)
    
        with torch.no_grad():
            for batch in data_loader:
                data, y_true, _ = batch
                y_true = y_true.to(self.device)
                outputs = model(data.to(self.device))
                res = self._stability_scores(outputs, y_true)
                logscore = torch.cat([torch.atleast_1d(logscore),torch.atleast_1d(res)])

        return logscore, logscore


    def _concate_dataloader(self, member_loader, non_member_loader):
      
        # balanced model 

        member_dataset_len = len(member_loader.dataset)
        non_member_dataset_len = len(non_member_loader.dataset)

        total_len = min(member_dataset_len, non_member_dataset_len)


        member_idx = torch.randperm(member_dataset_len).tolist()[:total_len]
        non_member_idx = torch.randperm(non_member_dataset_len).tolist()[:total_len]

        validate_dataset = torch.utils.data.ConcatDataset([
            Subset(member_loader.dataset, member_idx),
            Subset(non_member_loader.dataset, non_member_idx)  
])
        valdate_dataloader = torch.utils.data.DataLoader(
            validate_dataset,
            batch_size=256,
            shuffle=False,
            num_workers=4,
        )

        validation_label = torch.tensor([torch.tensor(1)] * total_len + [torch.tensor(0)] * total_len)
        return valdate_dataloader, validation_label



    def infer(self, model, data_batch, label_batch):
        
        model.eval()
        batch_size = len(data_batch)
        m_nm_size = batch_size // 2
        
        with torch.no_grad():
            y_true = label_batch.to(self.device)
            # TODO: main error 
        
            outputs = model(data_batch.to(self.device))
            res = self._stability_scores(outputs, y_true)
            
            shadow_res = None
            for i in range((self.num_shadow_models -1)):
                shadow_model = self.shadow_model_list[i]
                shadow_model.eval()
                outputs = shadow_model(data_batch.to(self.device))
                compute_score = self._stability_scores(outputs, y_true)
                if shadow_res is None:
                    shadow_res = compute_score.unsqueeze(-1)
                else:
                    shadow_res = torch.cat([shadow_res, compute_score.unsqueeze(-1)],dim=1)
            


        def lira_function(res):
            # summation = torch.cat([ self.shadow_out.repeat(res.shape[0],1), shadow_res],dim=1 )
            summation = shadow_res
            mu_out = summation.mean(dim=1)
            std_out = summation.std(dim=1)
            dist_out = Normal(loc=mu_out, scale=std_out+ 1e-30)
            return -dist_out.log_prob(res)

        lira_score = lira_function(res) 

        # save result
        self.infer_score = torch.cat([torch.atleast_1d(self.infer_score),torch.atleast_1d(lira_score)])
        m_nm_labels = torch.tensor([torch.tensor(1)] * m_nm_size + [torch.tensor(0)] * m_nm_size)

        if self.infer_label is None:
            m_nm_total = m_nm_labels
        else:
            m_nm_total = torch.cat([self.infer_label, m_nm_labels])

        self.infer_label = m_nm_total

        # get accuracy:
        
        if self.attack:
            y_pred = lira_score >= self.threshold
            auc = metrics.roc_auc_score(y_true=m_nm_labels.cpu(), y_score=-lira_score.cpu())
            print(f"batch acc: {metrics.accuracy_score(m_nm_labels.cpu(), y_pred.cpu())} with auc {auc} ")

        else:
            threshold, auc, fpr, tpr = self.best_threshold(m_nm_label=m_nm_total,
                            score=self.infer_score)
            self.auc = auc
            self.threshold = threshold

            self.tpr = tpr
            self.fpr = fpr

            y_pred = lira_score < self.threshold
            
        
        self.count += 1
        return y_pred, lira_score

    def output(self):
        # result_member_idx = torch.where(self.infer_label == 1)
        # result_non_member_idx = torch.where(self.infer_label == 0)

        # result_member = self.infer_score[result_member_idx] < self.threshold
        # result_non_member = self.infer_score[result_non_member_idx] < self.threshold

        threshold, auc, fpr, tpr = self.best_threshold(m_nm_label=self.infer_label,
                            score=self.infer_score)
        
        self.fpr_tpr_001 = np.max(tpr[fpr <= 0.001])
        self.fpr_tpr_01 = np.max(tpr[fpr <= 0.01])
        self.fpr_tpr_0001 = np.max(tpr[fpr <= 0.0001])

        print(f"fpr_tpr_0001:{self.fpr_tpr_0001} \n fpr_tpr_001:{self.fpr_tpr_001} \n fpr_tpr_01:{self.fpr_tpr_01}")

        result = self.infer_label.clone()
        result[torch.where(self.infer_score > threshold)]= 1 
        result[torch.where(self.infer_score <= threshold)]= 0 
        

        member_pred = result[torch.where(self.infer_label == 1)]
        nonmember_pred = result[torch.where(self.infer_label == 0)]


        tp = torch.sum(member_pred)
        fn = member_pred.shape[0] - tp
        tn = torch.sum(nonmember_pred)
        fp = nonmember_pred.shape[0]  - tn

        return {
            "auc": auc,
            "best_accuracy": (tp+fp)/(tp+fn+tn+fp),
            "predict": result,
            "member_pred": member_pred,
            "nonmember_pred": nonmember_pred,
            "tpr01fpr": self.fpr_tpr_001 ,
            "tpr001fpr": self.fpr_tpr_0001  ,
            "tp": tp,
            "fn": fn,
            "tn": tn,
            "fp": fp,
        }