import torch
import sys
import os
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from scipy.stats import norm
from sklearn.metrics import roc_curve, auc, accuracy_score
import math
import matplotlib.pyplot as plt

from torch.distributions import Normal
from sympy.stats.rv import probability
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
from tqdm import tqdm
from types import SimpleNamespace
import random

from MIA.MIA import MIA
from util.OptimParser import parse_optim
from util.SeqParser import parse_seq
from util.DataParser import parse_data
from util.DataParser import parse_data_mia, parse_shadow_data
from util.DeviceParser import parse_device
# from util.MIAParser import parse_mia
from util.ModelParser import parse_model
from util.ParamParser import *
from util.Eval import AverageCalculator
from util.MetricsCalculation import accuracy, precision_p, precision_n, recall_p, recall_n, f1_p, f1_n, auc, tpr_fpr

from sklearn.metrics import roc_curve, auc, accuracy_score


def update_args_with_defaults(args):
    """Fill in every RMIA setting that ``args`` does not define yet.

    Attributes that already exist on ``args`` are never overwritten; only the
    missing ones receive the fallback value listed below.

    Args:
        args: Any object supporting ``hasattr``/``setattr`` (the caller in this
            file passes a ``SimpleNamespace``).

    Returns:
        The same ``args`` object, modified in place.
    """
    fallback_optim = {
        "name": "sgd",
        "lr": 0.1,
        "momentum": 0.9,
        "weight_decay": 5.0e-4,
    }
    fallback_lr_schedule = {
        "name": "jump",
        "min_jump_pt": 100,
        "jump_freq": 50,
        "start_v": 0.1,
        "power": 0.1,
    }
    # Prefer the GPU whenever one is visible to torch.
    fallback_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Key order is kept stable so the attribute order on ``args`` (and hence
    # its printed representation) does not change.
    fallbacks = {
        "epochs": 100,
        "optim": fallback_optim,
        "lr_schedule": fallback_lr_schedule,
        "dataset": "cifar10",
        "device": fallback_device,
        "shadow_model_type": 'resnet',
        "n_shadow": 5,
        "normalize": False,
        "samples": None,
        "load_model": False,
    }

    for name, fallback in fallbacks.items():
        if hasattr(args, name):
            continue  # caller supplied this one explicitly
        setattr(args, name, fallback)
    return args


def get_accuracy(model, data_loader, device='cuda'):
    """Return the top-1 accuracy of ``model`` over ``data_loader`` in percent.

    The model is switched to eval mode and stays in eval mode afterwards.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores.
        data_loader: Iterable yielding ``(images, labels)`` pairs.
        device: Device the inputs and labels are moved to before comparison.

    Returns:
        ``correct / total * 100`` as a float.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            scores = model(images.to(device))
            # Index of the highest score along the class dimension.
            guesses = scores.max(dim=1).indices
            seen += labels.size(0)
            hits += (guesses == labels.to(device)).sum().item()
    return hits / seen * 100


def train_model_with_loader(model, train_loader, training_epochs=100, lr=0.01, device='cuda'):
    """Train ``model`` with Adam and cross-entropy over ``train_loader``.

    NOTE: this module defines ``train_model_with_loader`` a second time
    further down; at import time the later definition replaces this one.

    Args:
        model: Module to optimise. It is expected to already live on
            ``device``; this function does not move it.
        train_loader: Iterable of batches whose first two items are the
            inputs and the class targets.
        training_epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device each batch is moved to.

    Returns:
        The same ``model`` object, left in training mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Evaluation helpers in this file leave the model in eval mode; switch
    # back explicitly so dropout / batch-norm layers behave as in training,
    # matching train_model_with_raw_tensors.
    model.train()

    for _ in range(training_epochs):
        for batch in train_loader:
            data, target = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    return model


def calculate_accuracy(model, dataloader, device='cpu'):
    """Compute the percentage of samples in ``dataloader`` classified correctly.

    Puts ``model`` into eval mode (and leaves it there), runs it without
    gradient tracking and compares the highest-scoring class with the label.

    Args:
        model: Callable module producing per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    n_right = n_all = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            top_class = model(inputs).max(dim=1)[1]
            n_all += targets.size(0)
            n_right += (top_class == targets).sum().item()
    return 100 * n_right / n_all


# Function to train a model
def train_model_with_raw_tensors(model, train_data, train_labels, epochs=100, lr=0.01, bs=128 * 2, device='cuda'):
    """Fit ``model`` on in-memory tensors using Adam and cross-entropy loss.

    The tensors are wrapped in a shuffled ``DataLoader``; every mini-batch is
    moved to ``device`` before the forward pass. The model itself is not
    moved, so it is expected to already be on ``device``.

    Args:
        model: Module to optimise.
        train_data: Input tensor, first dimension indexes samples.
        train_labels: Class-target tensor aligned with ``train_data``.
        epochs: Number of passes over the data.
        lr: Adam learning rate.
        bs: Mini-batch size.
        device: Device each mini-batch is moved to.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=lr)
    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(train_data, train_labels),
        batch_size=bs,
        shuffle=True,
    )

    for _ in range(epochs):
        for img, label in batches:
            img = img.to(device)
            label = label.to(device)
            adam.zero_grad()
            loss_fn(model(img), label).backward()
            adam.step()
    return model


def train_model_with_loader(model, train_loader, training_epochs=100, lr=0.01, device='cuda'):
    """Train ``model`` with Adam and cross-entropy over ``train_loader``.

    Args:
        model: Module to optimise. It is expected to already live on
            ``device``; this function does not move it.
        train_loader: Iterable of batches whose first two items are the
            inputs and the class targets (extra items are ignored).
        training_epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device each batch is moved to.

    Returns:
        The same ``model`` object, left in training mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # get_accuracy / calculate_accuracy / get_prob all leave the model in
    # eval mode. Without this call a model that was evaluated first would be
    # trained with dropout disabled and batch-norm statistics frozen.
    model.train()

    for _ in range(training_epochs):
        for batch in train_loader:
            data, target = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    return model


def get_prob(model, data, target, device):
    """Return the softmax probability ``model`` gives to the first sample's label.

    Only the first row of ``data`` and the first entry of ``target`` are
    used. ``model`` is switched to eval mode and left there.

    Args:
        model: Module producing per-class logits of shape (batch, classes).
        data: Input batch; moved to ``device`` for the forward pass.
        target: Tensor whose first element is the class index to look up.
        device: Device used for the forward pass.

    Returns:
        The probability as a plain Python float.
    """
    model.eval()
    label = target[0].item()
    with torch.no_grad():
        logits = model(data.to(device))
        class_probs = torch.softmax(logits, dim=1)
    return class_probs[0][label].cpu().item()


def online_shadow_zone(
        shadow_data,
        shadow_labels,
        target_data,
        target_labels,
        num_shadow_models=5,
        shaodw_epochs=30,
        shadow_lr=1e-2,
        random_sample_number=100,
        device='cuda',
        train_shadow_models=True,
        get_shadow_models=None
):
    """Online RMIA shadow stage: train IN/OUT shadow models per target sample.

    For every target sample and every shadow-model slot, the shadow pool is
    randomly split in half. An IN model is trained on the first half plus the
    target sample and an OUT model on the second half. The probability each
    model assigns to the target's label is recorded, and the OUT model is
    additionally queried on a fixed random subset of the shadow pool.

    Args:
        shadow_data: Tensor of shadow samples (first dim = samples).
        shadow_labels: Labels aligned with ``shadow_data``.
        target_data: Tensor of samples under attack.
        target_labels: Labels aligned with ``target_data``.
        num_shadow_models: IN/OUT model pairs trained per target sample.
        shaodw_epochs: Training epochs per shadow model (the parameter name
            is misspelled but kept for backward compatibility).
        shadow_lr: Learning rate for shadow training.
        random_sample_number: Maximum number of reference samples drawn from
            the shadow pool.
        device: Device used for both training and inference.
        train_shadow_models: Must be True for the online variant.
        get_shadow_models: Unused; kept for signature compatibility.

    Returns:
        Tuple ``(random_data, random_labels, prob_target,
        rand_probs_per_target)`` where ``prob_target`` has shape
        (target_samples,) and ``rand_probs_per_target`` has shape
        (target_samples, random_samples).

    Raises:
        ValueError: If ``train_shadow_models`` is False.
    """
    # Online RMIA must have train_shadow_models=True
    if not train_shadow_models:
        raise ValueError("For online RMIA, train_shadow_models must be True.")

    n_shadow_samples = shadow_data.size(0)
    rand_indices = torch.randperm(n_shadow_samples)
    random_data = shadow_data[rand_indices[:random_sample_number]]
    random_labels = shadow_labels[rand_indices[:random_sample_number]]
    # The pool may hold fewer samples than requested; iterate over what was
    # actually drawn so get_prob never receives an empty slice.
    n_random = random_data.size(0)

    in_probs = [[] for _ in range(num_shadow_models)]
    out_probs = [[] for _ in range(num_shadow_models)]
    rand_probs_per_target = []

    for omeaga in range(target_data.size(0)):
        target_x = target_data[omeaga:omeaga + 1]
        target_y = target_labels[omeaga:omeaga + 1]
        rand_probs = []

        for ind in range(num_shadow_models):
            # split the shadow_data and shadow_labs into two parts (random)
            indices = torch.randperm(n_shadow_samples)
            half = len(indices) // 2
            in_idx, out_idx = indices[:half], indices[half:]

            # Include the target example in the dataset.
            # NOTE(review): ShadowModel is not defined in the visible part of
            # this file; presumably it comes from a star import -- confirm.
            model_in = ShadowModel().to(device)
            train_data_in = torch.cat((shadow_data[in_idx], target_x), 0)
            train_labels_in = torch.cat((shadow_labels[in_idx], target_y), 0)
            # Forward ``device`` explicitly: the trainer defaults to 'cuda'
            # and would otherwise put the batches on a different device than
            # the model.
            model_in = train_model_with_raw_tensors(model_in, train_data_in, train_labels_in,
                                                    shaodw_epochs, shadow_lr, device=device)
            in_probs[ind].append(get_prob(model_in, target_x, target_y, device))

            # Exclude the target example from the dataset
            model_out = ShadowModel().to(device)
            model_out = train_model_with_raw_tensors(model_out, shadow_data[out_idx], shadow_labels[out_idx],
                                                     shaodw_epochs, shadow_lr, device=device)
            out_probs[ind].append(get_prob(model_out, target_x, target_y, device))

            # get probs on random data
            rand_probs.append([
                get_prob(model_out, random_data[i:i + 1], random_labels[i:i + 1], device)
                for i in range(n_random)
            ])

        # (shadow_models-out, random_samples) -> mean over models -> (random_samples)
        rand_probs_per_target.append(np.mean(np.array(rand_probs), axis=0))

    rand_probs_per_target = np.array(rand_probs_per_target)  # (target_samples, random_samples)

    # (shadow_models, target_samples) -> mean over models -> (target_samples)
    in_probs = np.mean(np.array(in_probs), axis=0)
    out_probs = np.mean(np.array(out_probs), axis=0)

    prob_target = 1 / 2 * (in_probs + out_probs)  # (target_samples)

    return random_data, random_labels, prob_target, rand_probs_per_target


def offline_shadow_zone(
        shadow_data,
        shadow_labels,
        target_data,
        target_labels,
        num_shadow_models=5,
        shadow_epochs=30,
        shadow_lr=1e-2,
        random_sample_number=500,    # check here
        device='cuda',
        a=0.6,                       # varies on different dataset
        train_shadow_models=True,
        shadow_model_list=None,
        rand_probs_per_target=None,
        random_data=None,
        random_labels=None
):
    """Offline RMIA shadow stage: estimate the target probability from OUT models only.

    Two modes are supported:

    * ``train_shadow_models=True``: for every target sample and shadow slot
      a fresh OUT model is trained on a random half of the shadow pool. The
      reference probabilities on ``random_data`` are recomputed with those
      models. If no reference samples are supplied they are drawn from the
      shadow pool.
    * ``train_shadow_models=False``: the already-trained models in
      ``shadow_model_list`` are queried. ``random_data``, ``random_labels``
      and ``rand_probs_per_target`` are passed through unchanged.

    In both modes the mean OUT probability per target sample is converted
    with ``0.5 * ((1 + a) * p_out + (1 - a))``.

    Args:
        shadow_data: Shadow sample tensor (only read when training).
        shadow_labels: Labels aligned with ``shadow_data``.
        target_data: Tensor of samples under attack.
        target_labels: Labels aligned with ``target_data``.
        num_shadow_models: Number of OUT models per target sample.
        shadow_epochs: Training epochs per shadow model.
        shadow_lr: Learning rate for shadow training.
        random_sample_number: Maximum number of reference samples used when
            training.
        device: Device used for training and inference.
        a: Dataset-dependent offline scaling factor.
        train_shadow_models: Selects the mode described above.
        shadow_model_list: Pre-trained OUT models (load mode only).
        rand_probs_per_target: Pre-computed reference probabilities (load
            mode only; recomputed when training).
        random_data: Reference samples.
        random_labels: Labels aligned with ``random_data``.

    Returns:
        Tuple ``(random_data, random_labels, prob_target,
        rand_probs_per_target)``; ``prob_target`` has shape (target_samples,).

    Raises:
        ValueError: In load mode, if ``shadow_model_list`` is missing or its
            length differs from ``num_shadow_models``.
    """
    n_targets = target_data.size(0)
    out_probs = [[] for _ in range(num_shadow_models)]

    # Train / load shadow models
    if train_shadow_models:
        if random_data is None or random_labels is None:
            # No reference set supplied: draw one from the shadow pool.
            rand_indices = torch.randperm(shadow_data.size(0))
            random_data = shadow_data[rand_indices[:random_sample_number]]
            random_labels = shadow_labels[rand_indices[:random_sample_number]]
        # Never index past the reference samples that actually exist.
        n_random = min(random_sample_number, random_data.size(0))

        # Freshly trained models need fresh reference probabilities, so any
        # value passed in by the caller is replaced here.
        rand_probs_per_target = []

        for omeaga in range(n_targets):
            target_x = target_data[omeaga:omeaga + 1]
            target_y = target_labels[omeaga:omeaga + 1]
            rand_probs = []

            for ind in range(num_shadow_models):
                # Random half of the shadow pool; the target example is not
                # added, so the model is an OUT model for it.
                indices = torch.randperm(shadow_data.size(0))
                out_idx = indices[len(indices) // 2:]

                # NOTE(review): ShadowModel is not defined in the visible part
                # of this file; presumably it comes from a star import.
                model_out = ShadowModel().to(device)
                # Forward ``device``: the trainer defaults to 'cuda'.
                model_out = train_model_with_raw_tensors(model_out, shadow_data[out_idx], shadow_labels[out_idx],
                                                         shadow_epochs, shadow_lr, device=device)
                out_probs[ind].append(get_prob(model_out, target_x, target_y, device))

                # get probs on random data
                rand_probs.append([
                    get_prob(model_out, random_data[i:i + 1], random_labels[i:i + 1], device)
                    for i in range(n_random)
                ])

            # (shadow_models-out, random_samples) -> mean over models
            rand_probs_per_target.append(np.mean(np.array(rand_probs), axis=0))

        rand_probs_per_target = np.array(rand_probs_per_target)  # (target_samples, random_samples)

    else:
        if shadow_model_list is None:
            raise ValueError("When train_shadow_models=False, shadow_model_list must be provided.")
        shadow_models = shadow_model_list
        if len(shadow_models) != int(num_shadow_models):
            raise ValueError(f"Expected {num_shadow_models} shadow models, got {len(shadow_models)}.")

        for omeaga in range(n_targets):
            target_x = target_data[omeaga:omeaga + 1]
            target_y = target_labels[omeaga:omeaga + 1]
            for ind, model_out in enumerate(shadow_models):

                if (omeaga + 1) % 200 == 0:
                    print(f"Target data {omeaga + 1}, [Shadow-Zone] Using Shadow Model {ind + 1}/{num_shadow_models}")

                model_out.eval()
                out_probs[ind].append(get_prob(model_out, target_x, target_y, device))

    # (shadow_models-out, target_samples) -> mean over models -> (target_samples)
    out_probs = np.mean(np.array(out_probs), axis=0)

    # a=0.3 for CIFAR-10 and CINIC-10. a=0.6 for CIFAR-100. a=1 for ImageNet. a=0.2 for Purchase-100.
    prob_target = 1 / 2 * ((1 + a) * out_probs + (1 - a))  # (target_samples)

    return random_data, random_labels, prob_target, rand_probs_per_target


def attack_zone(target_model,
                target_data,
                target_labels,
                random_data,
                random_labels,
                prob_target,
                rand_probs_per_target,               # (random_samples)
                gamma,                             # check here
                lr_rand_per_target,
                device='cuda',
                num_members=None):
    """Score every target sample against the reference likelihood ratios.

    For target sample ``i`` the likelihood ratio is
    ``P(target_model)(x_i) / (prob_target[i] + 1e-15)``. The score is the
    fraction of reference samples ``j`` for which
    ``lr_target / lr_rand_per_target[j] > gamma``.

    Args:
        target_model: Model under attack.
        target_data: Batch of samples to score.
        target_labels: Labels aligned with ``target_data``.
        random_data: Reference samples; only their count is used here.
        random_labels: Unused; kept for signature compatibility.
        prob_target: Per-target reference probability from the shadow stage.
        rand_probs_per_target: Unused; kept for signature compatibility.
        gamma: Ratio threshold for counting a reference sample as dominated.
        lr_rand_per_target: Likelihood ratio of each reference sample.
        device: Device used for the target-model forward passes.
        num_members: Number of leading rows in ``target_data`` that are
            members. When omitted the batch is assumed to be laid out as
            members first, then non-members, split in the middle.

    Returns:
        ``(scores, gt_result_per_batch)``: a list of floats and an int array
        of the same length (1 = member, 0 = non-member).

    Raises:
        IndexError: If ``lr_rand_per_target`` has fewer entries than
            ``random_data`` has rows.
        ValueError: If ``num_members`` is outside ``[0, batch size]``.
    """
    n_targets = target_data.size(0)
    n_random = random_data.size(0)

    reference_ratios = np.asarray(lr_rand_per_target, dtype=float)[:n_random]
    if reference_ratios.shape[0] < n_random:
        raise IndexError(
            f"lr_rand_per_target has {reference_ratios.shape[0]} entries, expected {n_random}."
        )

    scores = []
    for i in range(n_targets):
        target_prob_given_target_model = get_prob(target_model, target_data[i:i + 1],
                                                  target_labels[i:i + 1], device)
        lr_target = target_prob_given_target_model / (prob_target[i] + 1e-15)
        # One vectorised comparison replaces the per-reference Python loop.
        dominated = np.count_nonzero(lr_target / reference_ratios > gamma)
        scores.append(dominated / n_random)

    # Create ground truth array
    if num_members is None:
        # Legacy layout assumption: first half members, rest non-members.
        # For odd batch sizes the extra sample is labelled non-member, so the
        # label array always has exactly one entry per score.
        num_members = n_targets // 2
    if not 0 <= num_members <= n_targets:
        raise ValueError(f"num_members must be in [0, {n_targets}], got {num_members}.")
    gt_result_per_batch = np.concatenate([np.ones(num_members, dtype=int),
                                          np.zeros(n_targets - num_members, dtype=int)])

    return scores, gt_result_per_batch    # list, array


def find_best_threshold(gt_result, Scores):
    """Pick the score threshold with the highest accuracy and summarise the ROC.

    Every threshold returned by ``roc_curve`` is tried; a sample is predicted
    as member when ``score >= threshold``. The first threshold that reaches
    the highest accuracy wins.

    Args:
        gt_result: Array of ground-truth labels (1 = member, 0 = non-member).
        Scores: Array of attack scores aligned with ``gt_result``.

    Returns:
        Dict with the ROC AUC, the best accuracy and threshold
        (``best_beta``), the binary predictions at that threshold (full array
        plus member / non-member slices as float tensors), the confusion
        counts, and the TPR at FPR <= 0.001 (``tpr01fpr``) and FPR <= 0.0001
        (``tpr001fpr``).
    """
    fpr, tpr, thresholds = roc_curve(gt_result, Scores)
    auc_value = auc(fpr, tpr)

    def get_tpr_at_fpr(target_fpr):
        # TPR at the last ROC point whose FPR does not exceed the target.
        within_budget = np.where(fpr <= target_fpr)[0]
        if len(within_budget) == 0:
            return 0.0
        return tpr[within_budget[-1]]

    truly_member = gt_result == 1
    truly_nonmember = gt_result == 0

    best_acc, best_beta, best_result = 0.0, 0.0, None
    best_conf = (0, 0, 0, 0)  # tp, fn, tn, fp

    # go through thresholds
    for beta in thresholds:
        predicted = (Scores >= beta).astype(int)
        acc = accuracy_score(gt_result, predicted)
        if not acc > best_acc:
            continue  # only a strict improvement replaces the incumbent

        said_member = predicted == 1
        said_nonmember = predicted == 0
        best_conf = (
            np.sum(said_member & truly_member),        # tp
            np.sum(said_nonmember & truly_member),     # fn
            np.sum(said_nonmember & truly_nonmember),  # tn
            np.sum(said_member & truly_nonmember),     # fp
        )
        best_acc, best_beta, best_result = acc, beta, predicted

    # result at the best threshold
    tp, fn, tn, fp = best_conf

    # Binary predictions split by true class, as float tensors for further
    # use in accuracy()
    member_pred = torch.tensor(best_result[truly_member], dtype=torch.float)
    nonmember_pred = torch.tensor(best_result[truly_nonmember], dtype=torch.float)

    summary = {
        "auc": auc_value,
        "best_accuracy": best_acc,
        "predict": best_result,
        "member_pred": member_pred,
        "nonmember_pred": nonmember_pred,
        "tpr01fpr": get_tpr_at_fpr(0.001),
        "tpr001fpr": get_tpr_at_fpr(0.0001),
        "tp": tp,
        "fn": fn,
        "tn": tn,
        "fp": fp,
        "best_beta": best_beta,  # return best threshold
    }
    return summary


class RMIA:
    def __init__(self,
                 num_shadow_models=5,
                 shadow_epochs=30,
                 shadow_lr=1e-2,
                 random_sample_number=1000,
                 gamma=2.0,
                 beta=0.5,
                 device='cuda',
                 mia_mode="eval",**_):
        """Store the attack hyper-parameters and reset all runtime state.

        Args:
            num_shadow_models: Number of shadow (OUT) models.
            shadow_epochs: Training epochs per shadow model.
            shadow_lr: Learning rate for shadow training.
            random_sample_number: Number of reference samples to draw.
            gamma: Likelihood-ratio threshold used when scoring.
            beta: Accepted for compatibility but not stored.
            device: Device used for model inference.
            mia_mode: ``"eval"`` or ``"attack"``.
            **_: Any further keyword arguments are ignored.
        """
        # --- hyper-parameters -------------------------------------------
        self.num_shadow_models = num_shadow_models
        self.shadow_epochs = shadow_epochs
        self.shadow_lr = shadow_lr
        self.random_sample_number = random_sample_number
        self.gamma = gamma
        self.device = device
        self.mia_mode = mia_mode

        # --- populated by fit() -----------------------------------------
        self.args = None
        self.target_model = None
        self.shadow_model_list = None
        self.a = None
        self.best_threshold_for_attack = None

        # Loaders, numbered as in the rest of this class:
        # 1 = target members, 2 = shadow members,
        # 3 = target non-members, 4 = shadow non-members.
        self.target_train_data_loader = self.target_test_data_loader = None
        self.shadow_train_data_loader = self.shadow_test_data_loader = None

        # Shadow loaders flattened into tensors.
        self.shadow_member_data = self.shadow_member_labels = None
        self.shadow_nonmember_data = self.shadow_nonmember_labels = None

        # Reference ("random") samples and what was precomputed on them.
        self.random_data = self.random_labels = None
        self.rand_probs_per_target = None
        self.lr_rand_per_target = None

        # --- accumulated across infer() calls -----------------------------
        self.Scores = None
        self.batch_count = None
        self.gt_result = None
        self.gt_result_24 = None




    def fit(self, target_model, fit_data_loaders, **kwargs):
        """Prepare the attack: load shadow models, build the reference set, calibrate.

        Steps:
          1. Store the four loaders (1 = target members, 2 = shadow members,
             3 = target non-members, 4 = shadow non-members) and reset the
             per-run accumulators.
          2. Obtain the shadow models and pick the offline factor ``a`` for
             the configured dataset.
          3. Flatten the shadow loaders into tensors and draw the reference
             ("random") samples from the shadow non-members (4).
          4. Precompute, for every reference sample, the mean shadow-model
             probability and the likelihood ratio under the target model.
          5. In ``"attack"`` mode, score target members (1) against the first
             shadow non-member loader (4[0]) and keep the accuracy-maximising
             threshold in ``self.best_threshold_for_attack``.

        Args:
            target_model: Model under attack.
            fit_data_loaders: Mapping with the keys ``"shadow_nonmember"``,
                ``"shadow_member"``, ``"nonmember_train"`` and
                ``"member_train"``. Shadow entries may be a single loader or
                a list of loaders. Batches are ``(data, label, ...)`` tuples.
            **kwargs: Run settings; missing ones are filled in by
                ``update_args_with_defaults``. ``model2load`` is only read
                for a log message.

        Raises:
            ValueError: In attack mode, if a batch does not have 3 or 4 items.
        """
        self.shadow_test_data_loader = fit_data_loaders["shadow_nonmember"]    # 4
        self.shadow_train_data_loader = fit_data_loaders["shadow_member"]    # 2

        self.target_test_data_loader = fit_data_loaders["nonmember_train"]    # 3
        self.target_train_data_loader = fit_data_loaders["member_train"]    # 1

        # Update args
        self.args = update_args_with_defaults(SimpleNamespace(**kwargs))
        print("args", self.args)

        self.Scores = []  # Empty list store RMIA output_result
        self.batch_count = 0
        self.gt_result = np.array([], dtype=int)
        self.gt_result_24 = np.array([], dtype=int)

        device = self.device

        # load target model and shadow models
        self.target_model = target_model
        self.shadow_model_list = self._get_shadow_models()

        # a=0.3 for CIFAR-10 and CINIC-10. a=0.6 for CIFAR-100. a=1 for ImageNet. a=0.2 for Purchase-100.
        # The former test `dataset == "cifar10" or "cinic10"` was always true,
        # so every dataset silently used 0.3. A lookup table makes each
        # dataset get its own value.
        offline_a_by_dataset = {
            "cifar10": 0.3,
            "cinic10": 0.3,
            "cifar100": 0.6,
            "ImageNet": 1,
            "purchase100": 0.2,
        }
        if self.args.dataset in offline_a_by_dataset:
            self.a = offline_a_by_dataset[self.args.dataset]
        else:
            # Unlisted datasets used to run with 0.3 (see above); keep that
            # so they keep working instead of failing later on ``a is None``.
            print("[RMIA] Dataset not supported!")
            print("[RMIA] Falling back to a=0.3.")
            self.a = 0.3

        def get_shadow_data(loaders):
            """Concatenate all (data, label) batches of one loader or a list of loaders."""
            if isinstance(loaders, list):
                print("[RMIA] Detected list of shadow_member_loaders, will merge as shadow_data.")
                loader_seq = loaders
            else:
                loader_seq = [loaders]

            data_parts, label_parts = [], []
            for loader in loader_seq:
                for batch in loader:
                    # Batches may carry 3 or 4 items (see the calibration
                    # loop below); only the first two are needed here.
                    data_parts.append(batch[0])
                    label_parts.append(batch[1])
            shadow_data = torch.cat(data_parts)
            shadow_labels = torch.cat(label_parts)

            print(f"[RMIA] Shadow images: {shadow_data.shape}")

            model2load = kwargs.get('model2load', None)
            use_load = model2load is not None and str(model2load).strip() != ''
            print(f"[RMIA] Using model2load: {model2load}, use_load: {use_load}")

            return shadow_data, shadow_labels

        self.shadow_member_data, self.shadow_member_labels = get_shadow_data(self.shadow_train_data_loader)    # 2[0:4]
        self.shadow_nonmember_data, self.shadow_nonmember_labels = get_shadow_data(self.shadow_test_data_loader)    # 4[0:4]

        # Random data selection (4)
        rand_indices = torch.randperm(self.shadow_nonmember_data.size(0))
        chosen = rand_indices[:self.random_sample_number]
        self.random_data = self.shadow_nonmember_data[chosen]
        self.random_labels = self.shadow_nonmember_labels[chosen]
        # Fewer samples than requested may be available; iterate over what was
        # actually drawn so get_prob never receives an empty slice.
        n_random = self.random_data.size(0)

        # Probability of every reference sample under each shadow model.
        rand_probs = []
        for model_out in self.shadow_model_list:
            model_out.eval()
            rand_probs.append([
                get_prob(model_out, self.random_data[i:i + 1], self.random_labels[i:i + 1], device)
                for i in range(n_random)
            ])

        # (shadow_models-out, random_samples) -> mean over models -> (random_samples)
        self.rand_probs_per_target = np.mean(np.array(rand_probs), axis=0)

        # Likelihood ratio of every reference sample: target model vs shadow mean.
        self.lr_rand_per_target = np.array([
            get_prob(self.target_model, self.random_data[j:j + 1], self.random_labels[j:j + 1], device)
            / (self.rand_probs_per_target[j] + 1e-15)
            for j in range(n_random)
        ])

        if self.mia_mode != "attack":
            self.best_threshold_for_attack = None
            return

        # For attack mode, calculate the best threshold in advance.
        Scores_24 = []

        print(f"[RMIA] shadow_train_data_loader / shadow_test_data_loader length: {len(self.shadow_train_data_loader)}, {len(self.shadow_test_data_loader)}")

        # threshold 1 + 4[0]; shadow 2[0:4]
        member_loader = self.target_train_data_loader    # 1
        nonmember_loader = self.shadow_test_data_loader    # 4
        if isinstance(nonmember_loader, (list, tuple)):
            nonmember_loader = nonmember_loader[0]    # 4[0]

        for member_data, nonmember_data in zip(member_loader, nonmember_loader):
            # 3 items: (data, label, idx); 4 items: (data, label, true_label, idx)
            if len(member_data) not in (3, 4):
                raise ValueError('The input data is not valid!')

            member_data_batch, member_label_batch = member_data[0], member_data[1]
            nonmember_data_batch, nonmember_label_batch = nonmember_data[0], nonmember_data[1]

            # ``.to`` works for CPU as well as CUDA devices, unlike ``.cuda``.
            data_batch = torch.cat((member_data_batch, nonmember_data_batch)).to(device)
            label_batch = torch.cat((member_label_batch, nonmember_label_batch)).to(device)

            _, _, prob_target, rand_probs_per_target = offline_shadow_zone(
                shadow_data=self.shadow_nonmember_data,    # 4 (not read in load mode)
                shadow_labels=self.shadow_nonmember_labels,
                target_data=data_batch,                    # 1 + 4[0]
                target_labels=label_batch,
                num_shadow_models=self.num_shadow_models,
                shadow_epochs=self.shadow_epochs,
                shadow_lr=self.shadow_lr,
                random_sample_number=self.random_sample_number,
                device=self.device,
                a=self.a,
                train_shadow_models=False,                 # always reuse the loaded models
                shadow_model_list=self.shadow_model_list,
                rand_probs_per_target=self.rand_probs_per_target,
                random_data=self.random_data,
                random_labels=self.random_labels,
            )

            scores_24, _ = attack_zone(target_model=self.target_model,
                                       target_data=data_batch,
                                       target_labels=label_batch,
                                       random_data=self.random_data,  # random data alignment
                                       random_labels=self.random_labels,
                                       prob_target=prob_target,
                                       rand_probs_per_target=rand_probs_per_target,
                                       gamma=self.gamma,
                                       lr_rand_per_target=self.lr_rand_per_target,
                                       device=self.device)
            Scores_24.extend(scores_24)

            # The real member / non-member counts are known here, so build the
            # labels from them instead of assuming an even half/half split.
            gt_result_per_batch_24 = np.concatenate([
                np.ones(member_data_batch.size(0), dtype=int),
                np.zeros(nonmember_data_batch.size(0), dtype=int),
            ])
            self.gt_result_24 = np.concatenate([self.gt_result_24, gt_result_per_batch_24])

        aggregation_result = find_best_threshold(self.gt_result_24, np.array(Scores_24))
        self.best_threshold_for_attack = aggregation_result["best_beta"]



    def infer(self, target_model, target_data, target_labels):
        """Score one batch of target samples and accumulate the results.

        The batch is run through the offline shadow stage (re-using the
        shadow models and reference probabilities prepared by ``fit``) and
        then through ``attack_zone``. Scores and per-batch ground truth are
        appended to ``self.Scores`` / ``self.gt_result`` for ``output``.

        Args:
            target_model: Model under attack, used for the target likelihoods.
            target_data: Batch of samples (comments in this file describe it
                as members followed by non-members).
            target_labels: Labels aligned with ``target_data``.

        Returns:
            A one-element tuple holding a float tensor: binary decisions in
            ``"attack"`` mode, raw scores otherwise.
            NOTE(review): the 1-tuple shape comes from the original trailing
            comma and is kept in case callers unpack it -- confirm.
        """
        batch_count = self.batch_count

        # === Shadow Zone ===
        _, _, prob_target, rand_probs_per_target = offline_shadow_zone(
            shadow_data=self.shadow_member_data,    # 2 (not read in load mode)
            shadow_labels=self.shadow_member_labels,
            target_data=target_data,                # 1 + 3
            target_labels=target_labels,
            num_shadow_models=self.num_shadow_models,
            shadow_epochs=self.shadow_epochs,
            shadow_lr=self.shadow_lr,
            random_sample_number=self.random_sample_number,
            device=self.device,
            a=self.a,
            train_shadow_models=False,              # always reuse the loaded models
            shadow_model_list=self.shadow_model_list,
            rand_probs_per_target=self.rand_probs_per_target,
            random_data=self.random_data,
            random_labels=self.random_labels,
        )

        print(f"[RMIA] Random data selection at batch count [{batch_count}] done! shadow prob_target shape: {prob_target.shape}")  # (target_samples)

        # === Attack Zone ===
        scores, gt_result_per_batch = attack_zone(target_model,
                                                  target_data,
                                                  target_labels,
                                                  random_data=self.random_data,  # random data alignment
                                                  random_labels=self.random_labels,
                                                  prob_target=prob_target,
                                                  rand_probs_per_target=rand_probs_per_target,
                                                  gamma=self.gamma,
                                                  lr_rand_per_target=self.lr_rand_per_target,
                                                  device=self.device)

        # store gt_result per batch
        self.gt_result = np.concatenate([self.gt_result, gt_result_per_batch])

        if self.mia_mode == "attack":
            # find_best_threshold calibrates the threshold with
            # ``score >= beta``; apply it the same way so samples scoring
            # exactly at the threshold are classified as during calibration.
            # ``scores`` is a list, hence the explicit array conversion.
            binary_result = np.asarray(scores) >= self.best_threshold_for_attack
        else:
            # evaluate mode: hand back the raw scores of this batch
            binary_result = scores

        # store scores per batch (flat list, one float per sample)
        self.Scores.extend(scores)

        self.batch_count = batch_count + 1

        # Change to tensor for further use
        binary_result = torch.tensor(binary_result, dtype=torch.float)

        return (binary_result,)


    def output(self):
        """
        Aggregate the per-batch results accumulated on this instance.

        Reads ``self.gt_result`` and ``self.Scores``, converts the scores to a numpy
        array, prints both shapes and passes them to ``find_best_threshold``.

        Returns:
            Whatever ``find_best_threshold(gt_result, scores)`` returns.
        """
        ground_truth = self.gt_result
        # self.Scores is converted with np.array so that .shape is available below.
        score_array = np.array(self.Scores)

        print(f"[RMIA] Scores shape: {score_array.shape}, gt_result_shape: {ground_truth.shape}")

        return find_best_threshold(ground_truth, score_array)



    def _get_shadow_models(self):
        """
        Build the list of shadow models, one per index in ``range(self.args.n_shadow)``.

        For index ``i`` the checkpoint path is ``{args.model2load}/{i}/resnet{suffix}``,
        where ``suffix`` is ``.ckpt`` when ``args.load_epoch == 200`` and
        ``_{load_epoch}.ckpt`` otherwise.

        * If ``args.load_model`` is truthy and the checkpoint exists, the model is
          created with ``parse_model`` and its weights are loaded from the file.
        * Otherwise the model is trained with ``train_model_with_loader`` on
          ``self.shadow_train_loaders[i]`` and its state dict is saved to that path.

        Returns:
            list: the shadow models, in index order.
        """
        shadow_model_list = []

        # File suffix depends on load_epoch ('.ckpt' for 200, otherwise '_{epoch}.ckpt').
        # TODO: fix this hard code here
        if self.args.load_epoch == 200:
            suffix = '.ckpt'
        else:
            suffix = f'_{self.args.load_epoch}.ckpt'

        # The file name does not depend on the shadow index, so build it once.
        # NOTE: the name is fixed to "resnet" rather than args.shadow_model_type.
        model_filename = "resnet" + suffix

        for i in range(self.args.n_shadow):
            # Save path of shadow model: args.model2load/{i}/model_filename
            model_dir = os.path.join(self.args.model2load, str(i))
            model_path = os.path.join(model_dir, model_filename)

            if self.args.load_model and os.path.exists(model_path):
                print(f"Load the saved shadow model:{model_path}")
                # Initialize the model and load the previously saved state dictionary
                model = parse_model(self.args.dataset, arch=self.args.shadow_model_type, normalize=self.args.normalize)
                model.to(self.args.device)
                # Deserialize onto CPU so a checkpoint saved on another device (e.g. a
                # GPU that is not present here) still loads; load_state_dict below
                # copies the values into the model's own parameters.
                state_dict = torch.load(model_path, map_location="cpu", weights_only=True)

                if "dpsgd" in model_path:
                    # Imported lazily: opacus is only needed for DP-SGD checkpoints.
                    from opacus.validators import ModuleValidator

                    # Let Opacus rewrite the model when its validator reports errors.
                    if ModuleValidator.validate(model, strict=False):
                        model = ModuleValidator.fix(model)

                    # Drop a leading '_module.' from checkpoint keys (presumably added
                    # by the Opacus wrapper at save time -- confirm against training code).
                    prefix = '_module.'
                    state_dict = {
                        (key[len(prefix):] if key.startswith(prefix) else key): value
                        for key, value in state_dict.items()
                    }

                model.load_state_dict(state_dict)
            else:
                print(f"Model file not found, start training shadow model, index:{i}")
                # Train Shadow Model
                model = train_model_with_loader(self.args, self.shadow_train_loaders[i], self.args.shadow_model_type)
                # Persist the freshly trained weights so later runs can load them.
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_path)
            shadow_model_list.append(model)

        return shadow_model_list


# Model architectures; not needed when use_load is True
class CNN(nn.Module):
    """Four-stage convolutional classifier.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2)``
    with 32, 64, 128 and 256 output channels. The flattened feature map feeds
    ``fc1`` (-> 1024, ReLU, dropout 0.5) and ``fc2`` (-> ``num_classes``).

    Args:
        channel: number of input channels.
        num_classes: number of output logits.
        input_size: spatial size of the input, either a single int (square input)
            or a ``(height, width)`` pair. Defaults to 32, which gives the
            previously hard-coded ``fc1`` input size of 1024.

    Raises:
        ValueError: if ``input_size`` is too small to survive four 2x2 poolings.
    """

    def __init__(self, channel, num_classes, input_size=32):
        super().__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(channel, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)

        # Dropout layer
        self.dropout = nn.Dropout(0.5)

        # Size of the input to the first fully connected layer: the convolutions
        # keep the spatial size (3x3, padding 1) and each of the four poolings
        # floor-halves it, so every side ends up as side // 16.
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size
        fc1_input_size = 256 * (int(height) // 16) * (int(width) // 16)
        if fc1_input_size <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small: each side must be at least 16"
            )

        # Fully connected layers
        self.fc1 = nn.Linear(fc1_input_size, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        x = self.feature(x)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

    def feature(self, x):
        """Return the flattened output of the four convolutional stages."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        # Flatten everything except the batch dimension
        return torch.flatten(x, 1)



class ShadowModel(nn.Module):
    """Shadow classifier with four convolutional stages and two linear layers.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2)``
    with 32, 64, 128 and 256 output channels. The flattened feature map feeds
    ``fc1`` (-> 1024, ReLU, dropout 0.5) and ``fc2`` (-> ``num_classes``).

    Args:
        channel: number of input channels (default 3).
        num_classes: number of output logits (default 10).
        input_size: spatial size of the input, either a single int (square input)
            or a ``(height, width)`` pair. Defaults to 32, which gives the
            previously hard-coded ``fc1`` input size of 1024.

    Raises:
        ValueError: if ``input_size`` is too small to survive four 2x2 poolings.
    """

    def __init__(self, channel=3, num_classes=10, input_size=32):
        super().__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(channel, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)

        # Dropout layer
        self.dropout = nn.Dropout(0.5)

        # Size of the input to the first fully connected layer: the convolutions
        # keep the spatial size (3x3, padding 1) and each of the four poolings
        # floor-halves it, so every side ends up as side // 16.
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size
        fc1_input_size = 256 * (int(height) // 16) * (int(width) // 16)
        if fc1_input_size <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small: each side must be at least 16"
            )

        # Fully connected layers
        self.fc1 = nn.Linear(fc1_input_size, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        x = self.feature(x)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

    def feature(self, x):
        """Return the flattened output of the four convolutional stages."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        # Flatten everything except the batch dimension
        return torch.flatten(x, 1)
