"""
Paper: ML-Leaks: Model and Data Independent Membership Inference Attacks and
    Defenses on Machine Learning Models
Links: https://arxiv.org/pdf/1806.01246
Code: https://github.com/AhmedSalem2/ML-Leaks
"""
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import numpy as np
import sys
import os

from MIA.MIA import MIA
from util.OptimParser import parse_optim
from util.SeqParser import parse_seq
from util.ModelParser import parse_model

# Automatically add the project root directory to sys.path so that the
# `MIA` package can be imported regardless of the working directory.
# NOTE(review): this runs *after* `from MIA.MIA import MIA` and the `util.*`
# imports near the top of the file, so those imports must already resolve
# without this path tweak; only the repeated import below can benefit from it.
if 'MIABench' in os.getcwd():  
    # 'MIABench' appears somewhere in the cwd path: use the cwd itself
    # (presumably the checkout root — TODO confirm).
    sys.path.insert(0, os.getcwd())
else:  
    # Otherwise assume a 'MIABench' subdirectory of the cwd is the root.
    sys.path.insert(0, os.path.join(os.getcwd(), 'MIABench'))

# Repeats the import above, now with the root on sys.path.
from MIA.MIA import MIA


def update_args_with_defaults(args):
    """Populate ``args`` with the default MLLeaks settings it does not define.

    Only attributes that are absent are added; anything already present on
    ``args`` (even if it is ``None``) is left untouched. ``args`` is mutated
    in place and also returned for convenience.
    """
    default_settings = {
        "epochs": 100,
        "optim": {
            "name": "sgd",
            "lr": 0.1,
            "momentum": 0.9,
            "weight_decay": 5.0e-4,
        },
        "lr_schedule": {
            "name": "jump",
            "min_jump_pt": 100,
            "jump_freq": 50,
            "start_v": 0.1,
            "power": 0.1,
        },
        # Must match the dataset the shadow data comes from.
        "dataset": "cifar10",
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # One shadow model is built per architecture listed here (adversary 1).
        "shadow_model_types": ['resnet18', 'resnet34'],
        # Which adversary to simulate: 1, 2 or 3.
        "adversary": 1,
        # Adversary 3 only, in [0, 1]; values closer to 0 flag fewer samples
        # as members.
        "threshold": 0.1,
        "normalize": False,
        "samples": None,
        "load_model": False,
    }
    missing_keys = [key for key in default_settings if not hasattr(args, key)]
    for key in missing_keys:
        setattr(args, key, default_settings[key])
    return args


class MLLeaks(MIA):
    """ML-Leaks membership inference attack.

    Adversary 1 and 2 (handled identically here): shadow models are trained
    (or loaded) on shadow member data. Their sorted top-``topX`` softmax
    probabilities on shadow member data (label 1) and shadow non-member data
    (label 0) form the training set of a small attack classifier.

    Adversary 3: no shadow models. A sample is predicted as a member when the
    target model's top-1 probability exceeds a threshold derived from the
    model's top-1 probabilities on random inputs.
    """

    def __init__(self, name="MLLeaks", threshold=None, metric=None, mia_mode="attack", **_):
        """
        name: name of this method
        threshold: float, the threshold to identify member or non-member
        metric: metric function, you can obtain the metric number by metric(model, data)
        mia_mode: must be "attack"; any other mode fails the assertion below.
        """
        super().__init__(name, threshold, metric, mia_mode)

        self.args = None
        self.shadow_train_loaders = []  # member ("train") loader of each shadow model
        self.shadow_test_loaders = []  # non-member ("test") loader of each shadow model
        self.shadow_model_list = []  # trained or loaded shadow models
        self.attack_model = None
        self.topX = 3  # number of sorted top probabilities fed to the attack model
        assert self.mia_mode == "attack", "MLLeaks only supports attack mode."

    def fit(self, model, fit_data_loaders, **kwargs):
        '''
        Prepare the attack.

        model: model under MIA (not used while fitting; shadow models stand in for it)
        fit_data_loaders: mapping with per-shadow-model loader lists under the keys
            "shadow_member" and "shadow_nonmember" (only read for adversary 1/2)
        kwargs: attack configuration; missing settings are filled in by
            update_args_with_defaults
        '''
        self.args = update_args_with_defaults(SimpleNamespace(**kwargs))
        print("args", self.args)
        if self.args.adversary == 3:
            # Adversary 3 is purely threshold based: nothing to train.
            return

        # Adversary 2 currently follows exactly the same procedure as adversary 1.
        shadow_train_loaders = fit_data_loaders["shadow_member"]
        shadow_test_loaders = fit_data_loaders["shadow_nonmember"]
        assert shadow_train_loaders is not None, "shadow_train_loaders should not be None in attack mode"
        assert shadow_test_loaders is not None, "shadow_test_loaders should not be None in attack mode"
        self.shadow_train_loaders = shadow_train_loaders
        self.shadow_test_loaders = shadow_test_loaders

        self.shadow_model_list = self._get_shadow_models()
        self.attack_model = self._train_attack_model()

    def infer(self, model, data, label):
        '''
        Predict membership for one batch.

        model: model under MIA
        data: batch of data, you can obtain the logit by model(data)
        label: true labels of the batch (not used by any adversary here)

        Returns (preds, outputs):
          - adversary 3: preds is a CPU tensor of 0/1 decisions, outputs is None;
          - adversary 1/2: preds is the argmax of the attack model's outputs
            (1 = member) and outputs are the attack model's raw outputs.
        '''
        device = self.args.device
        model = model.to(device)
        data = data.to(device)
        model.eval()

        if self.args.adversary == 3:
            # Random inputs shaped like `data` calibrate the threshold: it is the
            # (1 - t) percentile of their top-1 probabilities, t = args.threshold.
            random_data = torch.tensor(generate_random_data(data), dtype=torch.float32, device=device)
            with torch.no_grad():
                random_top1 = self._top1_probs(model, random_data)
                target_top1 = self._top1_probs(model, data)
            threshold = np.percentile(random_top1, 100 * (1 - self.args.threshold))
            preds = torch.tensor(np.where(target_top1 > threshold, 1, 0))
            return preds, None

        # Adversary 1 and 2: classify the sorted top-k probabilities.
        self.attack_model.eval()
        with torch.no_grad():
            probabilities = F.softmax(model(data), dim=1)
            top_probabilities, _ = torch.topk(probabilities, self.topX, dim=1, sorted=True)
            outputs = self.attack_model(top_probabilities)
            preds = torch.argmax(outputs, dim=1)
        return preds, outputs

    @staticmethod
    def _top1_probs(model, inputs):
        """Return each sample's highest softmax probability as a numpy array.

        Uses torch's softmax, which is numerically stable for large logits
        (a plain exp/sum overflows to NaN).
        """
        return F.softmax(model(inputs), dim=1).max(dim=1).values.detach().cpu().numpy()

    def _get_shadow_models(self):
        """
        Build one shadow model per entry of ``args.shadow_model_types``.

        Model i lives at ``args.model2load/{i}/resnet{suffix}``. If
        ``args.load_model`` is set and that file exists, it is loaded;
        otherwise the model is trained on ``shadow_train_loaders[i]`` and saved
        there.
        """
        # Checkpoint suffix: '.ckpt' when load_epoch equals total_epoch,
        # '_{load_epoch}.ckpt' otherwise. Both settings are optional
        # (missing == equal -> '.ckpt').
        # TODO: fix this hard code here
        load_epoch = getattr(self.args, "load_epoch", None)
        total_epoch = getattr(self.args, "total_epoch", None)
        suffix = '.ckpt' if load_epoch == total_epoch else f'_{load_epoch}.ckpt'
        # File-name stem kept as "resnet" so existing checkpoints still resolve;
        # the architecture itself always comes from shadow_model_types.
        shadow_model_name = "resnet"

        shadow_model_list = []
        for i, arch in enumerate(self.args.shadow_model_types):
            model_dir = os.path.join(self.args.model2load, str(i))
            model_path = os.path.join(model_dir, shadow_model_name + suffix)

            if self.args.load_model and os.path.exists(model_path):
                print(f"Load the saved shadow model:{model_path}")
                model = self._load_shadow_model(arch, model_path)
            else:
                print(f"Model file not found, start training shadow model, index:{i}")
                # Train with the same architecture the loading branch would build,
                # so the saved checkpoint can be reloaded later.
                model = train_model(self.args, self.shadow_train_loaders[i], arch)
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_path)
            shadow_model_list.append(model)

        return shadow_model_list

    def _load_shadow_model(self, arch, model_path):
        """Instantiate architecture ``arch`` and load its weights from ``model_path``."""
        model = parse_model(self.args.dataset, arch=arch, normalize=self.args.normalize)
        # map_location lets checkpoints saved on another device (e.g. GPU) load here.
        state_dict = torch.load(model_path, map_location=self.args.device, weights_only=True)

        if "dpsgd" in model_path:
            # Only DP-SGD checkpoints need opacus: make the module opacus-compatible
            # and strip the '_module.' prefix its wrapper adds to parameter names.
            from opacus.validators import ModuleValidator
            if ModuleValidator.validate(model, strict=False):
                model = ModuleValidator.fix(model)
            prefix = '_module.'
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        model.load_state_dict(state_dict)
        # Move to the device last: ModuleValidator.fix may have swapped in new layers.
        return model.to(self.args.device)

    def _train_attack_model(self):
        """Train the attack model on features produced by the shadow models."""
        attack_train_loader = self._get_attack_model_train_loader()
        return train_model(self.args, attack_train_loader, None, is_attack=True)

    def _get_attack_model_train_loader(self):
        """
        Create the attack model's training loader from the shadow models.

        For shadow model i, the sorted top-``topX`` softmax probabilities on
        ``shadow_train_loaders[i]`` are labelled 1 (member) and those on
        ``shadow_test_loaders[i]`` are labelled 0 (non-member). True class
        labels are not used, so only the first item of each batch is read.
        """
        if not self.shadow_model_list:
            raise ValueError("No shadow models available to build the attack training set.")
        batch_size = self.shadow_train_loaders[0].batch_size
        device = self.args.device

        features = []  # attack-model inputs
        memberships = []  # attack-model targets
        for i, shadow_model in enumerate(self.shadow_model_list):
            shadow_model.eval()
            sources = ((self.shadow_train_loaders[i], 1), (self.shadow_test_loaders[i], 0))
            for loader, membership in sources:
                for inputs, *_ in loader:
                    with torch.no_grad():
                        probabilities = F.softmax(shadow_model(inputs.to(device)), dim=1)
                        top_probabilities, _ = torch.topk(probabilities, self.topX, dim=1, sorted=True)
                    # Keep features on the CPU next to the targets; train_model moves
                    # each batch to the device, so GPU memory is not held here.
                    features.append(top_probabilities.cpu())
                    memberships.extend([membership] * top_probabilities.size(0))

        attack_train_dataset = TensorDataset(torch.cat(features), torch.tensor(memberships))
        return DataLoader(attack_train_dataset, batch_size=batch_size, shuffle=True)


def generate_random_data(data):
    """
    Draw random inputs with the same shape as ``data`` (used by adversary 3).

    - 2-D (tabular) input: every column is sampled independently.
        * Columns whose values are all 0/1 get fair coin flips (0 or 1).
        * Every other column is treated as continuous and sampled uniformly
          between that column's minimum and maximum.
    - Any other rank (e.g. image batches): every entry is drawn from the
      uniform distribution on [0, 1).

    data: torch.Tensor (any device; may require grad) or numpy array-like.
    Returns a float64 numpy array with the same shape as ``data``.
    """
    is_tensor = isinstance(data, torch.Tensor)
    if not is_tensor:
        data = np.asarray(data)

    if data.ndim != 2:
        # Images and any non-tabular shape: i.i.d. uniform noise in [0, 1).
        return np.random.uniform(0, 1, size=data.shape)

    # detach() so tensors that require grad can be converted to numpy.
    values = data.detach().cpu().numpy() if is_tensor else data
    n_rows, n_cols = values.shape
    random_data = np.empty(values.shape, dtype=float)
    for col in range(n_cols):
        column = values[:, col]
        if set(np.unique(column)).issubset({0, 1}):
            # binary column
            random_data[:, col] = np.random.choice([0, 1], size=n_rows)
        else:
            # continuous column
            random_data[:, col] = np.random.uniform(column.min(), column.max(), size=n_rows)
    return random_data


# Training function
def train_model(args, train_loader, model_type, is_attack=False):
    """Train either a shadow model or the membership attack model.

    args: namespace providing ``device``. Shadow models additionally read
        ``dataset``, ``normalize``, ``optim``, ``epochs`` and, if present,
        ``lr_schedule`` (a per-batch learning-rate schedule). The attack model
        reads the optional overrides ``attack_epochs`` (default 200),
        ``attack_lr`` (default 0.01, Adam) and ``attack_l2`` (default 1e-6,
        coefficient of an explicit L2 penalty).
    train_loader: yields batches whose first two items are (inputs, targets);
        any extra items are ignored.
    model_type: architecture name for ``parse_model`` (unused when is_attack).
    is_attack: train a 2-class MLPModel attack classifier instead.

    Returns the trained model, on ``args.device`` and still in train mode.
    Raises ValueError if ``train_loader`` yields no batches.
    """
    # Peek at one batch: the attack MLP's input width is the feature width.
    try:
        inputs, *_ = next(iter(train_loader))
    except StopIteration:
        raise ValueError("train_loader yielded no batches; nothing to train on.") from None

    if is_attack:
        print("attack model")
        model = MLPModel(inputs.shape[-1], 2)
        optimizer = torch.optim.Adam(model.parameters(), lr=getattr(args, "attack_lr", 0.01))
        epochs = getattr(args, "attack_epochs", 200)
        l2_coef = getattr(args, "attack_l2", 1e-6)
    else:
        model = parse_model(dataset=args.dataset, arch=model_type, normalize=args.normalize)
        optimizer = parse_optim(policy=args.optim, params=model.named_parameters())
        epochs = args.epochs
        l2_coef = 0.0

    criterion = nn.CrossEntropyLoss()

    use_schedule = hasattr(args, "lr_schedule") and not is_attack
    lr_func = parse_seq(**args.lr_schedule) if use_schedule else None
    # Only needed (and only required to exist) when a schedule is active.
    num_batches = len(train_loader) if lr_func is not None else None

    model.to(args.device)
    model.train()
    for epoch in range(epochs):
        total = 0
        correct = 0
        running_loss = 0.0
        for idx, (inputs, targets, *_) in enumerate(train_loader):
            if lr_func is not None:
                # The schedule is evaluated at the fractional epoch position.
                lr_this_batch = lr_func(epoch + idx / num_batches)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_batch

            inputs, targets = inputs.to(args.device), targets.to(args.device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if is_attack:
                # Explicit L2 penalty on all attack-model parameters.
                l2_reg = sum(torch.norm(param, 2) ** 2 for param in model.parameters())
                loss = loss + l2_coef * l2_reg

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        train_accuracy = 100 * correct / total if total else 0.0
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f'Epoch {epoch}, Loss: {running_loss:.3f}, Train Accuracy: {train_accuracy:.2f}%')
    return model

class DNNModel(nn.Module):
    """Fully connected classifier: input_dim -> 512 -> 256 -> 128 -> num_classes.

    ReLU follows each hidden layer, and dropout (p=0.2) follows the first two
    hidden activations. ``forward`` returns raw logits (no softmax/sigmoid).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Submodules are created in the same order, so the parameter names and
        # the random initialisation sequence are unchanged.
        self.fc1 = nn.Linear(input_dim, 512)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

    def forward(self, x):
        hidden = self.dropout1(F.relu(self.fc1(x)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)
    
class MLPModel(nn.Module):
    """One-hidden-layer perceptron: input_dim -> 64 (tanh) -> num_classes.

    ``forward`` returns raw logits; no output activation is applied.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.tanh(hidden)
        return self.fc2(hidden)

class LinearModel(nn.Module):
    """Single affine layer mapping input_dim features to num_classes logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc1(x)