"""
Paper: Membership Inference Attacks Against Machine Learning Models
Link: https://arxiv.org/pdf/1610.05820
Code: https://github.com/csong27/membership-inference
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from sympy.stats.rv import probability
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import numpy as np
import sys
import os
from types import SimpleNamespace
import random

from MIA.MIA import MIA
from util.OptimParser import parse_optim
from util.SeqParser import parse_seq
from util.ModelParser import parse_model
from util.DataParser import parse_data
from sklearn.metrics import accuracy_score

# Put the MIABench directory at the front of the import search path: the working
# directory itself when its path already contains 'MIABench', otherwise the
# 'MIABench' sub-directory of the working directory.
# NOTE(review): the project imports earlier in this file execute before this
# statement, so it cannot help resolve them — confirm the intended order.
sys.path.insert(
    0,
    os.getcwd() if 'MIABench' in os.getcwd() else os.path.join(os.getcwd(), 'MIABench'),
)


def update_args_with_defaults(args):
    """Fill in every ShadowMIA setting that ``args`` does not define yet.

    Attributes already present on ``args`` are left untouched; each missing one
    is set to its default. ``args`` is modified in place and also returned.
    """
    optim_defaults = dict(name="sgd", lr=0.1, momentum=0.9, weight_decay=5.0e-4)
    lr_schedule_defaults = dict(
        name="jump",
        min_jump_pt=100,
        jump_freq=50,
        start_v=0.1,
        power=0.1,
    )
    fallback = dict(
        epochs=100,
        optim=optim_defaults,
        lr_schedule=lr_schedule_defaults,
        dataset="cifar10",
        # Use the GPU whenever one is visible to torch.
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        shadow_model_type='resnet',
        n_shadow=5,
        normalize=False,
        samples=None,
        load_model=False,
    )
    for option, default in fallback.items():
        if hasattr(args, option):
            continue  # caller supplied this setting explicitly
        setattr(args, option, default)
    return args


class ShadowMIA(MIA):
    """Shadow-model membership inference attack.

    Shadow models are trained (or loaded from disk) on data whose membership is
    known. Their softmax outputs, tagged 1 for shadow-training data and 0 for
    shadow-test data, are used to train one binary attack model per class label.
    At inference time the attack model matching each sample's label classifies
    the target model's softmax output.
    """

    def __init__(self, name="ShadowMIA", threshold=None, metric=None, mia_mode="attack", **_):
        """
        name: name of this method
        threshold: float, the threshold to identify member or non-member
        metric: metric function, you can obtain the metric number by metric(model, data)
        mia_mode: must be "attack"; any other mode fails the assertion below

        All four values are forwarded to the MIA base class; extra keyword
        arguments are accepted and ignored.
        """
        super().__init__(name, threshold, metric, mia_mode)
        self.args = None
        self.shadow_train_loaders = []  # Store the "train loader" for each "shadow model"
        self.shadow_test_loaders = []  # Store the "test loader" for each "shadow model"
        self.shadow_model_list = []  # Save multiple trained shadow models
        self.attack_model_dict = {}  # Store "attack models" for each label
        assert self.mia_mode == "attack", "ShadowMIA only supports attack mode."

    def fit(self, model, fit_data_loaders, **kwargs):
        '''
        Obtain the shadow models, then train one attack model per class label.

        model: model under MIA (not used by this method)
        fit_data_loaders: mapping with the keys "shadow_member" and
            "shadow_nonmember"; each value is indexed by shadow-model number and
            yields that shadow model's member / non-member data loader
        kwargs: attack settings; missing ones are filled in by
            update_args_with_defaults
        '''
        self.args = update_args_with_defaults(SimpleNamespace(**kwargs))
        print("args", self.args)
        shadow_train_loaders = fit_data_loaders["shadow_member"]
        shadow_test_loaders = fit_data_loaders["shadow_nonmember"]
        assert shadow_train_loaders is not None, "shadow_train_loaders should not be None in attack mode"
        assert shadow_test_loaders is not None, "shadow_test_loaders should not be None in attack mode"
        self.shadow_train_loaders = shadow_train_loaders
        self.shadow_test_loaders = shadow_test_loaders
        self.shadow_model_list = self._get_shadow_models()
        self.attack_model_dict = self._train_attack_model()

    def infer(self, model, data, label):
        '''
        Predict membership for one batch.

        model: model under MIA
        data: batch of data, you can obtain the logit by model(data)
        label: true target of batch of data

        Returns (preds, outputs): outputs holds the attack-model scores with one
        row per sample, preds is the argmax of each row.

        Raises KeyError when a label has no trained attack model and ValueError
        when the batch is empty.
        '''
        model = model.to(self.args.device)
        data = data.to(self.args.device)
        label = label.to(self.args.device)

        batch_size = data.shape[0]
        model.eval()

        with torch.no_grad():
            attack_inputs = F.softmax(model(data), dim=1)
            outputs = None
            # Run each per-label attack model once on all rows carrying that label
            # instead of once per sample.
            for label_value in torch.unique(label):
                key = int(label_value.item())
                if key not in self.attack_model_dict:
                    raise KeyError(f"no attack model was trained for label {key}")
                rows = (label == label_value).nonzero(as_tuple=True)[0]
                scores = self.attack_model_dict[key](attack_inputs[rows])
                if outputs is None:
                    # Allocate lazily so the score width comes from the attack model.
                    outputs = scores.new_empty((batch_size, scores.shape[1]))
                outputs[rows] = scores
            if outputs is None:
                raise ValueError("infer() received an empty batch")
            preds = torch.argmax(outputs, dim=1)

        # Debug aid: assumes the batch is laid out as members followed by an equal
        # number of non-members. That layout cannot exist for an odd batch size
        # (accuracy_score would raise on the length mismatch), so skip it then.
        # NOTE(review): infer() has no ground-truth membership; confirm callers
        # really order batches this way or drop this printout.
        if batch_size % 2 == 0:
            mn_label = [1] * (batch_size // 2) + [0] * (batch_size // 2)
            acc = accuracy_score(mn_label, preds.cpu())
            print(acc)
        return preds, outputs

    def _get_shadow_models(self):
        """
        Get the shadow models: if self.args.load_model is set and a saved checkpoint
        exists, load it; otherwise train the model and save its state dict.

        Reads self.args.model2load (checkpoint root directory), self.args.load_epoch
        and self.args.total_epoch, which have no defaults and must be supplied to
        fit(). Returns a list with self.args.n_shadow models.
        """
        shadow_model_list = []
        # File suffix based on load_epoch: ".ckpt" when it equals total_epoch,
        # otherwise "_{load_epoch}.ckpt".
        #  TODO: fix this hard code here
        if self.args.load_epoch == self.args.total_epoch:
            suffix = '.ckpt'
        else:
            suffix = f'_{self.args.load_epoch}.ckpt'

        for i in range(self.args.n_shadow):
            # Save path of shadow model: args.model2load/{i}/model_filename
            # NOTE(review): the file name is always "resnet", whatever
            # args.shadow_model_type is — confirm this is intended.
            model_dir = os.path.join(self.args.model2load, str(i))
            model_filename = "resnet" + suffix
            model_path = os.path.join(model_dir, model_filename)

            if self.args.load_model and os.path.exists(model_path):
                print(f"Load the saved shadow model:{model_path}")
                # Initialize the model and load the previously saved state dictionary
                model = parse_model(self.args.dataset, arch=self.args.shadow_model_type, normalize=self.args.normalize)
                # map_location lets a checkpoint saved on another device (e.g. CUDA)
                # load on the device configured for this run.
                state_dict = torch.load(model_path, map_location=self.args.device, weights_only=True)
                if "dpsgd" in model_path:
                    # Imported here so opacus is only required for DP-SGD checkpoints.
                    from opacus.validators import ModuleValidator

                    errors = ModuleValidator.validate(model, strict=False)
                    if errors:
                        model = ModuleValidator.fix(model)

                    # Drop the "_module." key prefix found in these checkpoints.
                    prefix = '_module.'
                    state_dict = {
                        (k[len(prefix):] if k.startswith(prefix) else k): v
                        for k, v in state_dict.items()
                    }
                model.load_state_dict(state_dict)
                # Move after any ModuleValidator.fix so replaced layers end up on the device too.
                model.to(self.args.device)
            else:
                print(f"Model file not found, start training shadow model, index:{i}")
                # Train Shadow Model
                model = train_model(self.args, self.shadow_train_loaders[i], self.args.shadow_model_type)
                # save
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_path)
            shadow_model_list.append(model)

        return shadow_model_list

    def _train_attack_model(self):
        """Train one attack model per class label; returns {label: attack model}."""
        attack_train_loader_dict = self._get_attack_model_train_loader()
        return {
            label: train_model(self.args, loader, None, is_attack=True)
            for label, loader in attack_train_loader_dict.items()
        }

    def _collect_attack_samples(self, shadow_model, dataloader, target_label):
        """Run shadow_model over dataloader and gather attack-model training rows.

        Returns (probabilities, targets, labels): a list with one softmax tensor per
        batch, target_label repeated once per sample, and the class label of every
        sample as plain Python values.
        """
        probabilities, targets, labels = [], [], []
        # Batches may carry extra trailing items (e.g. sample indices); ignore them,
        # as train_model does.
        for inputs, label, *_ in dataloader:
            inputs = inputs.to(self.args.device)
            with torch.no_grad():
                batch_probs = F.softmax(shadow_model(inputs), dim=1)
            probabilities.append(batch_probs)  # no sorting or label
            targets.extend([target_label] * batch_probs.size(0))
            labels.extend(label.tolist())
        return probabilities, targets, labels

    def _get_attack_model_train_loader(self):
        """
        Create the training sets for the attack models (using multiple shadow models
        and shadow datasets).

        Returns {class label: DataLoader} where each loader yields
        (softmax output, membership target) pairs for samples of that class;
        the membership target is 1 for shadow-train data and 0 for shadow-test data.
        """
        batch_size = self.shadow_train_loaders[0].batch_size
        new_data = []  # data for "attack model"
        new_target = []  # target of "attack model", means member/nonmember
        labels = []  # label of "shadow model"
        for i in range(self.args.n_shadow):
            shadow_model = self.shadow_model_list[i]
            shadow_model.eval()

            # Shadow-train data are members (target 1), shadow-test data are
            # non-members (target 0); members are processed first.
            sources = (
                (self.shadow_train_loaders[i], 1),
                (self.shadow_test_loaders[i], 0),
            )
            for dataloader, target_label in sources:
                probs, targets, class_labels = self._collect_attack_samples(
                    shadow_model, dataloader, target_label
                )
                new_data.extend(probs)
                new_target.extend(targets)
                labels.extend(class_labels)

        # Concat all data and labels together
        new_data = torch.cat(new_data)
        new_target = torch.tensor(new_target)
        labels = torch.tensor(labels)

        # create attack dataset
        attack_train_dataset = TensorDataset(new_data, new_target)

        # store DataLoader for each label (original dataset label)
        attack_train_loader_dict = {}
        for label in torch.unique(labels):
            # Find the index corresponding to the current label
            indices = (labels == label).nonzero(as_tuple=True)[0]
            subset_dataset = Subset(attack_train_dataset, indices.tolist())
            loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)
            attack_train_loader_dict[int(label.item())] = loader

        return attack_train_loader_dict


# Training function
def train_model(args, train_loader, model_type, is_attack=False):
    """Train and return either a shadow model or an attack model.

    args: settings namespace; reads ``device`` always and, for shadow models,
        ``dataset``, ``normalize``, ``optim``, ``epochs`` and (if present)
        ``lr_schedule``
    train_loader: loader whose batches start with (inputs, targets)
    model_type: architecture name handed to parse_model (ignored for attack models)
    is_attack: when True, train a LinearModel with two outputs using Adam
        (lr 0.01) for 100 epochs and no learning-rate schedule

    The returned model is on ``args.device`` and still in training mode.
    """
    # Peek at one batch: its last dimension sizes the attack model's input layer.
    peek_inputs, _, *_ = next(iter(train_loader))
    feature_shape = peek_inputs.shape

    if not is_attack:
        model = parse_model(dataset=args.dataset, arch=model_type, normalize=args.normalize)
        optimizer = parse_optim(policy=args.optim, params=model.named_parameters())
        epochs = args.epochs
    else:
        print("attack model")
        model = LinearModel(feature_shape[-1], 2)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        epochs = 100

    criterion = nn.CrossEntropyLoss()

    # Only shadow models follow the configured learning-rate schedule.
    lr_func = None
    if hasattr(args, "lr_schedule") and not is_attack:
        lr_func = parse_seq(**args.lr_schedule)

    model.to(args.device)
    model.train()
    for epoch in range(epochs):
        seen = 0
        hits = 0
        loss_sum = 0.0
        for step, (batch_inputs, batch_targets, *_) in enumerate(train_loader):
            if lr_func is not None:
                # The schedule is queried at a fractional epoch position.
                lr_now = lr_func(epoch + step / len(train_loader))
                for group in optimizer.param_groups:
                    group['lr'] = lr_now

            batch_inputs = batch_inputs.to(args.device)
            batch_targets = batch_targets.to(args.device)

            logits = model(batch_inputs)
            loss = criterion(logits, batch_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics for the progress line below.
            loss_sum += loss.item()
            guesses = torch.max(logits, 1)[1]
            seen += batch_targets.size(0)
            hits += (guesses == batch_targets).sum().item()

        accuracy_pct = 100 * hits / seen
        # Report every 10th epoch and the final one.
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f'Epoch {epoch}, Loss: {loss_sum:.3f}, Train Accuracy: {accuracy_pct:.2f}%')
    return model

class DNNModel(nn.Module):
    """Four-layer fully connected network: input_dim -> 512 -> 256 -> 128 -> num_classes.

    ReLU follows each of the first three linear layers, dropout (p=0.2) follows
    the first two activations, and the final layer returns raw scores with no
    activation applied.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=512)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(in_features=256, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return the raw class scores for ``x``."""
        hidden = self.dropout1(F.relu(self.fc1(x)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)

class NNModel(nn.Module):
    """Three-layer fully connected network: input_dim -> 256 -> 128 -> num_classes.

    ReLU follows the first two linear layers, dropout (p=0.2) follows the first
    activation only, and the final layer returns raw scores with no activation.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=256)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return the raw class scores for ``x``."""
        hidden = self.dropout1(F.relu(self.fc1(x)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
class LinearModel(nn.Module):
    """Single linear layer mapping input_dim features to num_classes raw scores."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw class scores for ``x`` (no activation is applied)."""
        return self.fc1(x)