"""
Paper: Practical Blind Membership Inference Attack via Differential Comparisons
Link: https://arxiv.org/pdf/2101.01341
Code: https://github.com/hyhmia/BlindMI/tree/master
"""
import logging
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from tqdm import tqdm

from sklearn.svm import OneClassSVM
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import sys
import os

# Make the MIABench project root importable (needed by `from MIA.MIA import MIA`).
# When the working directory path already contains 'MIABench' it is used as the
# root; otherwise the root is taken to be ./MIABench under the working directory.
# NOTE(review): this is a substring test, so a sub-directory of MIABench is also
# treated as the root — confirm this is intended.
sys.path.insert(
    0,
    os.getcwd() if 'MIABench' in os.getcwd() else os.path.join(os.getcwd(), 'MIABench'),
)

from MIA.MIA import MIA


def update_args_with_defaults(args):
    """Fill in every BlindMI option missing from ``args`` and return the same object.

    Options already present on ``args`` are never overwritten.

    Defaults:
        device: CUDA if available, otherwise CPU.
        attack_method: 'one_class' (alternatives: 'diff_w', 'diff_single', 'diff_bi').
        non_mem_generator: 'sobel' — the transform used to synthesize non-members.
            Image data: 'sobel', 'gaussian_noise', 'sp_noise', 'scharr', 'laplace';
            table data: 'gaussian_noise_table';
            binary data (purchase100): 'flip_binary_data'.
    """
    fallback_values = (
        ("device", torch.device("cuda" if torch.cuda.is_available() else "cpu")),
        ("attack_method", "one_class"),
        ("non_mem_generator", "sobel"),
    )
    for option, fallback in fallback_values:
        if hasattr(args, option):
            continue
        setattr(args, option, fallback)
    return args


class BlindMI(MIA):
    """BlindMI membership inference attack ("Practical Blind Membership Inference
    Attack via Differential Comparisons").

    The attack is blind: ``fit`` trains nothing and only stores the configuration.
    ``infer`` queries the target model on the batch under attack and labels each
    sample 1 (member) or 0 (non-member) with the strategy selected by
    ``args.attack_method``: 'one_class', 'diff_w', 'diff_single' or 'diff_bi'.
    """

    def __init__(self, name="BlindMI", threshold=None, metric=None, mia_mode="attack", **_):
        """
        name: name of this method
        threshold: float, the threshold to identify member or non-member
            (forwarded to MIA; not read by this class)
        metric: metric function, you can obtain the metric number by metric(model, data)
        mia_mode: must be "attack"; other modes are rejected
        """
        super().__init__(name, threshold, metric, mia_mode)
        # Attack configuration; set by fit() (or filled with defaults on first infer()).
        self.args = None
        assert self.mia_mode == "attack", "BlindMI only supports attack mode."

    def fit(self, model=None, fit_data_loaders=None, **kwargs):
        """Store the attack configuration. BlindMI is blind, so nothing is trained.

        model, fit_data_loaders: accepted for interface compatibility; ignored.
        kwargs: configuration options such as device, attack_method and
            non_mem_generator; missing options get the values from
            update_args_with_defaults.
        """
        self.args = update_args_with_defaults(SimpleNamespace(**kwargs))
        print("args", self.args)

    def infer(self, model, data, label=None):
        """Predict membership for one batch.

        model: target model; model(data) is expected to return per-class logits on dim 1.
        data: batch of inputs (tensor).
        label: true class indices of the batch; required only by 'diff_w'.
        Returns (preds, None) where preds is a tensor holding 1 (member) or
        0 (non-member) for each sample, in input order.
        Raises ValueError for an unknown attack_method or a missing label for 'diff_w'.
        """
        if self.args is None:
            # fit() was never called. BlindMI learns nothing in fit(), so the defaults
            # from update_args_with_defaults are a safe fallback instead of crashing on None.
            self.args = update_args_with_defaults(SimpleNamespace())
        model = model.to(self.args.device)
        model.eval()
        data = data.to('cpu')

        method = self.args.attack_method
        if method == 'one_class':
            preds = self._one_class(model, data)
        elif method == 'diff_w':
            preds = self._diff_w(model, data, label)
        elif method == 'diff_single':
            preds = self._diff_single(model, data)
        elif method == 'diff_bi':
            preds = self._diff_bi(model, data)
        else:
            raise ValueError("argument attack_method should in ['one_class','diff_w','diff_single','diff_bi']")

        # as_tensor wraps numpy output and, unlike torch.tensor(), does not warn
        # (or copy) when preds is already a tensor.
        return torch.as_tensor(preds), None

    def _non_member_generator(self, data):
        """Synthesize non-member samples from ``data`` with the configured transform."""
        generators = {
            'sobel': sobel,
            'sp_noise': sp_noise,
            'gaussian_noise': gaussian_noise,
            'gaussian_noise_table': gaussian_noise_table,
            'flip_binary_data': flip_binary_data,
            'scharr': scharr,
            'laplace': laplace,
        }
        try:
            generator = generators[self.args.non_mem_generator]
        except KeyError:
            raise ValueError(
                f"argument non_mem_generator should be one of {list(generators)}"
            ) from None
        return generator(data)

    @staticmethod
    def _label_and_top2(probs, labels):
        """Concatenate each row's true-label probability with its two largest probabilities."""
        label_probs = probs.gather(1, labels.unsqueeze(1))
        top_2_probs = torch.sort(probs, dim=1, descending=True)[0][:, :2]
        return torch.cat([label_probs, top_2_probs], dim=1)

    def _one_class(self, model, data):
        """BlindMI-1CLASS: fit a one-class SVM on generated non-members and use it
        to classify the samples under attack.

        Each sample is represented by its three largest softmax probabilities.
        Returns a numpy int array (1 = member, 0 = non-member).
        """
        device = self.args.device
        # Generate non-member data by transforming every input sample
        data_sobel = torch.tensor(self._non_member_generator(data)).to(device).float()
        data = data.to(device)

        with torch.no_grad():
            # Top-3 probabilities (descending) of the samples under attack
            output = F.softmax(model(data), dim=1).cpu().numpy()
            mix = np.sort(output, axis=1)[:, ::-1][:, :3]

            # Top-3 probabilities of the generated non-members
            nonMem_pred = F.softmax(model(data_sobel), dim=1).cpu().numpy()
            nonMem = np.sort(nonMem_pred, axis=1)[:, ::-1][:, :3]

        # Fit the one-class SVM on the generated non-members only
        cls = OneClassSVM(nu=0.9, kernel='sigmoid', gamma='scale')
        cls.fit(nonMem)

        # OneClassSVM.predict returns +1 (inlier) / -1 (outlier) per sample.
        # NOTE(review): inliers resemble the non-member training set, so mapping
        # +1 -> member looks inverted; behaviour kept unchanged pending a check
        # against the reference BlindMI implementation.
        outputs = cls.predict(mix)
        return np.where(outputs == 1, 1, 0)

    def _diff_w(self, model, data, target):
        """BlindMI-DIFF-W: differential comparison against generated non-members.

        20 samples chosen at random (with replacement) are transformed by the
        non-member generator to form a reference non-member set. Each sample is
        represented by [probability of its true label, top-1 prob, top-2 prob].
        Within shuffled batches of 20, a member is moved to the non-member side
        whenever doing so increases the MMD between the non-member set (plus that
        sample) and the rest of the batch; passes repeat until nothing moves.

        target: true class indices for ``data`` (required).
        Returns a float tensor (1 = member, 0 = non-member) in input order.
        """
        if target is None:
            raise ValueError("attack_method 'diff_w' requires the true labels (label argument of infer).")
        device = self.args.device

        nonMem_index = np.random.randint(0, data.shape[0], size=20)
        data_sobel = torch.tensor(self._non_member_generator(data[nonMem_index]), device=device).float()

        data = data.to(device)
        target = target.to(device).long()  # gather() requires int64 indices

        with torch.no_grad():
            mix = self._label_and_top2(F.softmax(model(data), dim=1), target).cpu()
            nonMem = self._label_and_top2(
                F.softmax(model(data_sobel), dim=1), target[nonMem_index]
            ).cpu()

        # Carry the original index so shuffled batch results land in input order.
        dataset = TensorDataset(mix, torch.arange(mix.size(0)))
        data_loader = DataLoader(dataset, batch_size=20, shuffle=True)

        m_pred = torch.zeros(mix.size(0))
        for mix_batch, index_batch in data_loader:
            # 1 = still a member, 0 = moved to the non-member side (never moved back).
            m_pred_batch = torch.ones(mix_batch.size(0))
            m_pred_epoch = torch.ones(mix_batch.size(0))
            nonMemInMix = True

            # Every candidate in a pass is judged against the state at the start of
            # the pass; a pass repeats only if something moved, so this terminates.
            while nonMemInMix:
                dis_ori = mmd_loss(nonMem, mix_batch[m_pred_epoch.bool()], weight=1)
                nonMemInMix = False

                for index, item in enumerate(mix_batch):
                    if m_pred_batch[index] != 1:
                        continue
                    # Candidate split: this sample joins the non-members
                    nonMem_batch_new = torch.cat([nonMem, item.unsqueeze(0)], dim=0)
                    keep = m_pred_batch.bool()
                    keep[index] = False
                    mix_batch_new = mix_batch[keep]
                    dis_new = mmd_loss(nonMem_batch_new, mix_batch_new, weight=1)
                    # Larger distance means the sample looks like a non-member
                    if dis_new > dis_ori:
                        nonMemInMix = True
                        m_pred_epoch[index] = 0
                m_pred_batch = m_pred_epoch.clone()

            m_pred[index_batch] = m_pred_batch

        return m_pred

    def _diff_single(self, model, data):
        """BlindMI-DIFF-W/O, one-directional: differential comparison without
        generated non-members.

        The reference non-member set is the samples whose top-1 probability is at
        or below the min(0.1, 1000 / len(data)) quantile. Each sample is
        represented by its top-3 probabilities. Within shuffled batches of 20, a
        member is moved to the non-member side whenever moving it increases the
        MMD between the reference set (plus that sample) and the remaining members.

        Returns a float tensor (1 = member, 0 = non-member) in input order.
        """
        data = data.to(self.args.device)
        with torch.no_grad():
            y_pred = F.softmax(model(data), dim=1).to(device='cpu').numpy()
        mix = np.sort(y_pred, axis=1)[:, ::-1][:, :3].copy()  # top-3 probabilities
        # Reference code used ratio 1000 / len(data); capped at 10% here.
        ratio = min(1 / 10, 1000 / data.shape[0])
        non_Mem = torch.tensor(mix[threshold_Divide(mix, ratio) == 0], dtype=torch.float32)

        mix = torch.tensor(mix, dtype=torch.float32)
        dataset = DataLoader(
            TensorDataset(mix, torch.arange(mix.size(0))),
            shuffle=True, batch_size=20,
        )  # reference code: batch_size=len(data)/10 or 1000; now 20

        m_pred = torch.zeros(mix.size(0))
        for mix_batch, index_batch in dataset:
            m_pred_batch = torch.ones(mix_batch.size(0))
            flag = True

            while flag:
                # Snapshot of the membership at the start of this pass
                m_in_loop = m_pred_batch.clone()
                dis_ori = mmd_loss(non_Mem, mix_batch[m_in_loop.bool()], weight=1)
                flag = False

                for index, item in enumerate(mix_batch):
                    if m_in_loop[index] != 1:
                        continue
                    # Tentatively move this sample to the non-member side
                    m_in_loop[index] = 0
                    mix_1 = mix_batch[m_in_loop.bool()]
                    mix_2 = torch.cat([non_Mem, item.unsqueeze(0)], dim=0)
                    dis_new = mmd_loss(mix_2, mix_1, weight=1)
                    m_in_loop[index] = 1
                    # Larger distance means the sample looks like a non-member
                    if dis_new > dis_ori:
                        flag = True
                        m_pred_batch[index] = 0
            m_pred[index_batch] = m_pred_batch
        return m_pred

    def _diff_bi(self, model, data):
        """BlindMI-DIFF-W/O, bidirectional: differential comparison without
        generated non-members.

        Samples (top-3 probabilities) are first split at the median top-1
        probability. Within shuffled batches of 20, each sample is tentatively
        flipped to the other side; the flip is kept unless it lowers the MMD
        between the two sides. Passes repeat until nothing flips or 100 passes.

        Returns a float tensor (1 = member, 0 = non-member) in input order.
        """
        data = data.to(self.args.device)
        with torch.no_grad():
            y_pred = F.softmax(model(data), dim=1).to(device='cpu').numpy()
        mix = np.sort(y_pred, axis=1)[:, ::-1][:, :3].copy()
        # Preliminary split at the median: roughly half members, half non-members
        m_pred = threshold_Divide(mix, 0.5)

        mix = torch.tensor(mix, dtype=torch.float32)
        m_pred = torch.tensor(m_pred, dtype=torch.float32)
        dataset = DataLoader(
            TensorDataset(mix, torch.arange(mix.size(0)), m_pred),
            shuffle=True, batch_size=20,
        )  # reference code: batch_size=len(data)/10 or 1000; now 20

        max_rounds = 100  # safety cap on passes per batch
        final_preds = torch.zeros(mix.size(0))

        for mix_batch, index_batch, m_pred_batch in dataset:
            m_pred_batch = m_pred_batch.clone().float()
            flag = True
            rounds = 0

            while flag and rounds < max_rounds:
                rounds += 1
                dis_ori = mmd_loss(mix_batch[m_pred_batch == 0], mix_batch[m_pred_batch == 1], weight=1)
                flag = False

                for index in range(mix_batch.size(0)):
                    old_label = m_pred_batch[index].item()
                    # Tentatively flip this sample to the other side
                    m_pred_batch[index] = 1 - old_label
                    non_members = mix_batch[m_pred_batch == 0]
                    members = mix_batch[m_pred_batch == 1]
                    if non_members.size(0) == 0 or members.size(0) == 0:
                        # The MMD of an empty set is NaN, and NaN compares False, so
                        # this flip (and every later one) would be accepted. Never
                        # empty a side.
                        m_pred_batch[index] = old_label
                        continue
                    dis_new = mmd_loss(non_members, members, weight=1)
                    if dis_new < dis_ori:
                        # The flip made the two sides more similar: undo it
                        m_pred_batch[index] = old_label
                    else:
                        flag = True
                        dis_ori = dis_new

            final_preds[index_batch] = m_pred_batch
        return final_preds

"""
--------------------------------Utils------------------------------------
"""


'''
def KMeans_Divide(mix):
    """
    Rough partitioning of data using K-Means method
    : param mix: data
    : return: predicted label (0 or 1)
    """
    kmeans = KMeans(n_clusters=2).fit(mix)
    mix_1 = mix[kmeans.labels_.astype(bool)]
    mix_2 = mix[~kmeans.labels_.astype(bool)]
    if np.mean(torch.tensor(mix_1).max(axis=1).values) > np.mean(torch.tensor(mix_2).max(axis=1).values):
        m_pred = kmeans.labels_
    else:
        m_pred = np.where(kmeans.labels_ == 1, 0, 1)
    return m_pred
'''

import cv2 as cv
import random

def threshold_Divide(mix, ratio):
    """Split samples into members / non-members by their top score.

    mix: 2-D array; the maximum of each row is that sample's confidence.
    ratio: fraction in [0, 1]; the ``ratio * 100``-th percentile of the
        confidences is used as the cut-off.
    Returns an int array: 1 where the confidence is strictly above the cut-off
    (member), 0 otherwise (non-member).
    """
    confidence = np.max(mix, axis=1)
    cutoff = np.percentile(confidence, ratio * 100)
    return np.greater(confidence, cutoff).astype(int)

# Noise used for table data
def gaussian_noise_table(data, mean=0, var_scaling=0.001):
    """
    Add Gaussian noise to each column of tabular data, with a per-column variance.

    Parameters:
        - data: table data of shape (rows, columns), as a numpy array (or array-like)
          or a torch tensor (any device; detached and copied to CPU).
        - mean: mean of the noise, default 0.
        - var_scaling: noise variance as a fraction of each column's own variance.

    Returns:
        - noisy_data: numpy array with the same shape as the input.
    """
    if isinstance(data, torch.Tensor):
        # .numpy() alone fails for GPU tensors and tensors that require grad
        data = data.detach().cpu().numpy()
    else:
        # The previous implementation called data.numpy() unconditionally,
        # so plain numpy input crashed despite being the documented type.
        data = np.asarray(data)
    column_variances = np.var(data, axis=0)
    # Per-column noise standard deviation derived from the column variance
    noise_std = np.sqrt(column_variances * var_scaling)
    noise = np.random.normal(mean, noise_std, data.shape)
    return data + noise

# Noise for binary data
def flip_binary_data(data, flip_prob=0.5):
    """
    Randomly flip binary values in table data (0 to 1 or 1 to 0).

    Parameters:
        - data: binary data of shape (rows, columns), as a numpy array (or
          array-like) or a torch tensor (any device; detached and copied to CPU).
        - flip_prob: probability that each element is flipped
          (default 0.5, i.e. about half of the entries are flipped).

    Returns:
        - flipped_data: numpy array with the same shape as the input.
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        # The previous implementation called data.numpy() unconditionally,
        # so the documented numpy input crashed.
        data = np.asarray(data)
    flip_mask = np.random.rand(*data.shape) < flip_prob
    return np.where(flip_mask, 1 - data, data)


# Noise for image data
def gaussian_noise(img_set, mean=0, var=0.001):
    """Add Gaussian noise to every image in a batch.

    Each image is divided by 255, perturbed with N(mean, var) noise, clipped to
    [0, 1] and mapped back to integer levels 0..255 (truncated through uint8).

    img_set: batch of images indexable along axis 0 (numpy array or CPU tensor);
        presumably raw 0-255 intensities — TODO confirm against callers.
    Returns a float64 numpy array with the same shape as img_set.
    """
    std = var ** 0.5
    ret = np.empty(img_set.shape)
    for m, image in enumerate(img_set):
        image = np.array(image / 255, dtype=float)
        out = image + np.random.normal(mean, std, image.shape)
        # Clip to [0, 1] before the uint8 cast. A lower bound of -1 (as before)
        # let negative values reach np.uint8, whose conversion of out-of-range
        # floats is undefined (in practice dark pixels wrap to bright ones).
        out = np.clip(out, 0.0, 1.0)
        ret[m, :] = np.uint8(out * 255)
    return ret


def sobel(img_set):
    """Replace each image with a 50/50 blend of its absolute x- and y-Sobel gradients.

    Every element of ``img_set`` is cast to float32 and filtered with
    ``cv.Sobel`` along x and along y. The absolute responses
    (``cv.convertScaleAbs``) are then averaged with ``cv.addWeighted``.
    Returns a float64 numpy array shaped like ``img_set``.
    """
    edges = np.empty(img_set.shape)
    for idx, picture in enumerate(img_set):
        as_float = np.float32(picture)
        abs_dx = cv.convertScaleAbs(cv.Sobel(as_float, cv.CV_32F, 1, 0))
        abs_dy = cv.convertScaleAbs(cv.Sobel(as_float, cv.CV_32F, 0, 1))
        edges[idx, :] = cv.addWeighted(abs_dx, 0.5, abs_dy, 0.5, 0)
    return edges


def sp_noise(img_set, prob=0.001):
    """Apply salt-and-pepper noise to every image in a batch.

    For each position (i, j) over the first two axes of an image, one draw r of
    ``random.random()`` decides the output: r < prob -> 0 (pepper),
    r > 1 - prob -> 255 (salt), otherwise the original value is kept. Values
    pass through a uint8 buffer, so kept values are cast to uint8.
    Returns a float64 numpy array shaped like ``img_set``.
    """
    upper = 1 - prob
    noisy = np.empty(img_set.shape)
    for m, image in enumerate(img_set):
        buf = np.zeros(image.shape, np.uint8)
        rows, cols = image.shape[0], image.shape[1]
        for i in range(rows):
            for j in range(cols):
                draw = random.random()
                if draw < prob:
                    buf[i, j] = 0
                elif draw > upper:
                    buf[i, j] = 255
                else:
                    buf[i, j] = image[i, j]
        noisy[m, :] = buf
    return noisy


def scharr(img_set):
    """Replace each image with a 50/50 blend of its absolute x- and y-Scharr gradients.

    Every element of ``img_set`` is cast to float32 and filtered with
    ``cv.Scharr`` along x and along y. The absolute responses
    (``cv.convertScaleAbs``) are then averaged with ``cv.addWeighted``.
    Returns a float64 numpy array shaped like ``img_set``.
    """
    result = np.empty(img_set.shape)
    for k, picture in enumerate(img_set):
        src = np.float32(picture)
        mag_x = cv.convertScaleAbs(cv.Scharr(src, cv.CV_32F, 1, 0))
        mag_y = cv.convertScaleAbs(cv.Scharr(src, cv.CV_32F, 0, 1))
        result[k, :] = cv.addWeighted(mag_x, 0.5, mag_y, 0.5, 0)
    return result


def laplace(img_set):
    """Replace each image with the absolute value of its 3x3 Laplacian.

    Every element of ``img_set`` is cast to float32, filtered with
    ``cv.Laplacian(ksize=3)`` and passed through ``cv.convertScaleAbs``.
    Returns a float64 numpy array shaped like ``img_set``.
    """
    filtered = np.empty(img_set.shape)
    for k, picture in enumerate(img_set):
        response = cv.Laplacian(np.float32(picture), cv.CV_32F, ksize=3)
        filtered[k, :] = cv.convertScaleAbs(response)
    return filtered


def compute_pairwise_distances(x, y):
    """Return the squared Euclidean distances between the rows of x and y.

    x: (n, d) tensor; y: (m, d) tensor. Result[i, j] = ||x_i - y_j||^2, shape (n, m).

    The expansion ||x||^2 + ||y||^2 - 2 x.y can round slightly below zero for
    (nearly) identical rows. Such values are clamped to 0, because the Gaussian
    kernel built on top uses bandwidths as small as 1e-6. At that scale a tiny
    negative distance becomes a large exp(+...) factor.
    """
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    return torch.clamp(dist, min=0.0)


def gaussian_kernel_matrix(x, y, sigmas):
    """Sum of Gaussian kernels over several bandwidths.

    For each s in ``sigmas`` the kernel exp(-d / (2 s)) is evaluated on the
    pairwise distances d returned by compute_pairwise_distances(x, y). The
    per-bandwidth kernels are summed element-wise into a tensor with the same
    shape as d.
    """
    dists = compute_pairwise_distances(x, y)
    inv_two_sigma = (1. / (2. * sigmas)).unsqueeze(1)   # (num_sigmas, 1)
    exponents = inv_two_sigma @ dists.reshape(1, -1)     # (num_sigmas, n*m)
    return torch.exp(-exponents).sum(dim=0).reshape(dists.shape)


def maximum_mean_discrepancy(x, y, sigmas):
    """Biased (V-statistic) estimate of the squared MMD between samples x and y.

    MMD^2 = mean k(x, x') + mean k(y, y') - 2 * mean k(x, y), where k is the
    multi-bandwidth kernel from gaussian_kernel_matrix and each mean runs over
    all pairs, including i == j. An empty x or y gives NaN (the mean of an
    empty kernel matrix).
    """
    def mean_kernel(a, b):
        return gaussian_kernel_matrix(a, b, sigmas).mean()

    within_x = mean_kernel(x, x)
    within_y = mean_kernel(y, y)
    cross = mean_kernel(x, y)
    return within_x + within_y - 2 * cross


def mmd_loss(source_samples, target_samples, weight):
    """Weighted multi-bandwidth MMD between two sample sets, floored at 1e-4.

    source_samples, target_samples: 2-D tensors whose rows are samples, with the
        same number of columns.
    weight: scalar multiplier applied to the floored MMD.
    Returns a 0-dim tensor. NaN (e.g. when one set is empty) is not floored:
    torch.max propagates it.

    The bandwidth and floor tensors are created on the device of
    ``source_samples``, and with its dtype when that dtype is floating point.
    They used to be CPU float32 only, which broke GPU and float64 inputs.
    """
    dtype = source_samples.dtype if source_samples.is_floating_point() else torch.float32
    sigmas = torch.tensor([
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ], dtype=dtype, device=source_samples.device)

    loss_value = maximum_mean_discrepancy(source_samples, target_samples, sigmas)
    floor = torch.tensor(1e-4, dtype=loss_value.dtype, device=loss_value.device)
    return torch.max(floor, loss_value) * weight



