# Implement Scalable Membership Inference Attacks via Quantile Regression(https://arxiv.org/abs/2307.03694)

import os
import sys
sys.path.insert(0, './')
import json
import time
import pickle
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

# 尝试导入torchvision，并检查它是否支持ConvNeXt
def is_convnext_available():
    """Probe torchvision for ConvNeXt support.

    Returns:
        A ``(available, module)`` tuple: ``(True, torchvision.models)`` when
        the module exposes ``convnext_tiny``; ``(False, None)`` when
        torchvision cannot be imported or lacks that attribute. A warning is
        printed in both failure cases.
    """
    try:
        import torchvision.models as models_module
        supported = hasattr(models_module, 'convnext_tiny')
    except ImportError:
        print("警告: 无法导入torchvision")
        return False, None

    if supported:
        return True, models_module

    # torchvision imported fine but does not expose ConvNeXt.
    print("警告: 当前torchvision版本不支持ConvNeXt模型")
    return False, None

# Probe torchvision once at import time: `convnext_available` is the support
# flag and `tvm` is the torchvision.models module (None when unsupported).
convnext_available, tvm = is_convnext_available()

from MIA.MIA import MIA

# Import ConvNeXt components or define them here
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two memory layouts.

    ``channels_last`` normalizes the trailing dimension through
    ``F.layer_norm``; ``channels_first`` normalizes dimension 1 by hand and
    applies the affine parameters broadcast over two trailing dimensions.

    Args:
        normalized_shape: number of channels (size of weight and bias).
        eps: constant added to the variance before the square root.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: if ``data_format`` is neither supported value.
    """

    _SUPPORTED_FORMATS = ("channels_last", "channels_first")

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Manual normalization along dim 1 using the biased variance.
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = centered / torch.sqrt(variance + self.eps)
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normalized + shift

class Block(nn.Module):
    """Residual block: depthwise 7x7 conv -> LayerNorm -> Linear -> GELU -> Linear.

    The pointwise layers run in channels-last layout, so the tensor is
    permuted before and after them. An optional learnable per-channel scale
    (``gamma``) and an optional ``DropPath`` are applied to the branch before
    it is added back to the input.

    Args:
        dim: number of input/output channels.
        drop_path: drop-path rate; ``nn.Identity`` is used when it is not > 0.
        layer_scale_init_value: initial value of ``gamma``; no scaling
            parameter is created when it is not > 0.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        expanded_dim = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, expanded_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(expanded_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        branch = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        if self.gamma is not None:
            branch = self.gamma * branch
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(branch)

def to_onehot(labels, num_classes):
    """Convert 1-D integer class labels to a float one-hot matrix.

    Args:
        labels: 1-D tensor or numpy array of class indices (any integer dtype).
        num_classes: number of columns of the result.

    Returns:
        Float tensor of shape ``(len(labels), num_classes)`` on the same
        device as ``labels``, with a single 1 per row.
    """
    if isinstance(labels, np.ndarray):
        labels = torch.from_numpy(labels)
    # scatter_ only accepts int64 indices; int32/uint8 labels (common for
    # numpy arrays) would otherwise raise a RuntimeError.
    index = labels.long().unsqueeze(1)
    one_hot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    one_hot.scatter_(1, index, 1)
    return one_hot

# Pin_ball loss
def pinball_loss_fn(score, target, quantile):
    """Element-wise pinball (quantile) loss.

    ``target`` is reshaped to a column and broadcast against ``score``.
    Where the prediction is below the target the residual is weighted by
    ``quantile``; where it is above, by ``1 - quantile``.

    Args:
        score: 2-D tensor of predictions.
        target: tensor of targets, reshaped to ``(-1, 1)``.
        quantile: quantile level(s), broadcastable against ``score``.

    Returns:
        Tensor of un-reduced losses with the broadcast shape.
    """
    column_target = target.reshape([-1, 1])
    assert (
        score.ndim == 2
    ), "score has the wrong shape, expected 2d input but got {}".format(score.shape)
    residual = column_target - score
    under_estimate = torch.relu(residual) * quantile
    over_estimate = torch.relu(-residual) * (1.0 - quantile)
    return under_estimate + over_estimate

# 确保预测的分位数单调递增
def rearrange_quantile_fn(test_preds, all_quantiles, target_quantiles=None):
    """Produce monotonic quantiles by re-reading them from the predictions.

    Each row of ``test_preds`` is treated as a sample; the requested levels
    are rescaled to the span of ``all_quantiles`` and evaluated with
    ``torch.quantile``, which yields non-decreasing values along each row.

    Parameters
    ----------
    test_preds : tensor of predicted quantiles (n x q)
    all_quantiles : tensor (q), grid of quantile levels in the range (0,1)
    target_quantiles : tensor (q'), grid of target quantile levels in the
        range (0,1); defaults to ``all_quantiles`` when None

    Returns
    -------
    q_fixed : tensor (n x q'), the rearranged estimates at the target levels
    """
    # Compare against None explicitly: truth-testing a multi-element tensor
    # raises, and a single zero-valued level would be wrongly discarded.
    if target_quantiles is None:
        target_quantiles = all_quantiles

    scaling = all_quantiles[-1] - all_quantiles[0]
    rescaled_target_qs = (target_quantiles - all_quantiles[0]) / scaling
    q_fixed = torch.quantile(
        test_preds, rescaled_target_qs, interpolation="linear", dim=-1
    ).T
    assert (
        q_fixed.shape[0] == test_preds.shape[0] and q_fixed.ndim == test_preds.ndim
    ), "fixed quantiles have the wrong shape, {}".format(q_fixed.shape)
    return q_fixed


def label_logit_and_hinge_scoring_fn(samples, label, base_model):
    """Compute the hinge score z_y(x) - max_{y' != y} z_{y'}(x).

    Puts ``base_model`` in eval mode and evaluates it without gradients.

    Args:
        samples: input batch passed to ``base_model``.
        label: 1-D class indices for the batch.
        base_model: callable returning logits with classes on the last axis.

    Returns:
        ``(score, logits)`` where ``score`` is 1-D (one value per sample).
    """
    base_model.eval()
    with torch.no_grad():
        logits = base_model(samples)

        true_class_mask = to_onehot(label, logits.shape[-1]).bool()
        true_logit = logits[true_class_mask]
        # Every row keeps the logits of all classes except the labelled one.
        other_logits = logits[~true_class_mask].view(logits.shape[0], -1)
        score = true_logit - other_logits.max(dim=1).values
        assert (
            score.ndim == 1
        ), "hinge loss score should be 1-dimensional, got {}".format(score.shape)
    return score, logits


class QMIAModel(nn.Module):
    """Quantile regressor mapping features to ``num_quantiles`` outputs.

    For scalar features (``input_size == 1``), or when torchvision's ConvNeXt
    is unavailable, the model is a plain MLP stored in ``self.model``.
    Otherwise a pretrained ConvNeXt-tiny with an MLP head is used, fed through
    a linear ``input_adapter`` that reshapes features to 3x224x224; an MLP is
    still built as a fallback for construction or forward failures.

    Args:
        input_size: number of input features per sample.
        hidden_dims: iterable of hidden layer widths.
        num_quantiles: number of predicted quantiles (output width).
    """

    def __init__(self, input_size, hidden_dims, num_quantiles):
        super().__init__()

        # `input_size == 1` is tested first so the module-level ConvNeXt flag
        # is only consulted for multi-dimensional inputs.
        if input_size == 1 or not convnext_available:
            self.model = nn.Sequential(
                *self._mlp_layers(input_size, hidden_dims, num_quantiles)
            )
            self.use_convnext = False
            if input_size > 1:
                # Only reachable here when ConvNeXt is unavailable.
                print("ConvNeXt不可用，即使输入维度>1也使用MLP")
            return

        try:
            # Load the pretrained ConvNeXt-tiny backbone.
            model_fn = tvm.convnext_tiny
            model_weights = tvm.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
            self.base_model = model_fn(weights=model_weights)

            # Replace the classifier. The backbone hands the classifier an
            # (N, 768, 1, 1) tensor, so it must be flattened before the first
            # Linear layer; without this the ConvNeXt path always failed.
            self.base_model.classifier = nn.Sequential(
                nn.Flatten(1),
                *self._mlp_layers(768, hidden_dims, num_quantiles),
            )

            # Adapter projecting the features to an image-shaped tensor.
            self.input_adapter = nn.Linear(input_size, 3*224*224)
            self.use_convnext = True
            print("成功初始化ConvNeXt模型")

            # Keep an MLP around as a fallback for forward-pass failures.
            self.model = nn.Sequential(
                *self._mlp_layers(input_size, hidden_dims, num_quantiles)
            )
        except Exception as e:
            # Best-effort: any failure (missing weights, no network, ...)
            # degrades to the plain MLP.
            print(f"无法加载ConvNeXt模型，回退到MLP: {e}")
            self.model = nn.Sequential(
                *self._mlp_layers(input_size, hidden_dims, num_quantiles)
            )
            self.use_convnext = False

    @staticmethod
    def _mlp_layers(in_features, hidden_dims, out_features):
        """Return the layer list [Linear, ReLU, Dropout(0.2)]* + [Linear]."""
        layers = []
        prev_size = in_features
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_size, dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))
            prev_size = dim
        layers.append(nn.Linear(prev_size, out_features))
        return layers

    def forward(self, x):
        # Treat a 1-D batch of scalars as a single feature column.
        if x.dim() == 1:
            x = x.unsqueeze(1)

        if not self.use_convnext:
            return self.model(x)

        try:
            # `x` is deliberately left untouched so the fallback below still
            # receives the original features rather than the adapter output.
            image_like = self.input_adapter(x).view(x.shape[0], 3, 224, 224)
            return self.base_model(image_like)
        except Exception as e:
            print(f"ConvNeXt前向传播错误，回退到MLP: {e}")
            return self.model(x)

class QuantileMIA(MIA):
    """Membership inference attack based on quantile regression.

    ``fit`` scores batches with the target model and trains a ``QMIAModel``
    with the pinball loss to predict a grid of quantiles of that score.
    ``infer`` flags a sample as a member when its score exceeds the predicted
    quantile selected by the threshold.
    """

    def __init__(self, name='QuantileMIA', threshold=0.5, metric=None, mia_mode="attack",
                 low_quantile=-4, high_quantile=0, n_quantile=41, use_logscale=True,
                 hidden_dims=[512, 512], learning_rate=1e-4, weight_decay=1e-4,
                 num_epochs=30, batch_size=128, device=None, **kwargs):
        '''
        >>> name: name of this method
        >>> threshold: float, the threshold to identify member or non-member (quantile level)
        >>> metric: metric function for scoring samples (hinge score by default)
        >>> low_quantile: lowest quantile; with use_logscale it is the base-10
            exponent of the smallest offset from 1 (paper default -4)
        >>> high_quantile: highest quantile; with use_logscale it is the base-10
            exponent of the largest offset from 1 (paper default 0)
        >>> n_quantile: number of quantiles (paper default 41)
        >>> use_logscale: whether to use log-spaced quantiles (paper default True)
        >>> hidden_dims: hidden layer widths (paper default [512, 512])
        >>> learning_rate: learning rate (paper default 1e-4)
        >>> weight_decay: weight decay (paper default 1e-4)
        >>> num_epochs: number of training epochs (paper default 30)
        >>> batch_size: batch size (128 for efficiency; the paper uses 16)
        >>> device: device (auto-detected when None)
        '''
        super(QuantileMIA, self).__init__(name, threshold, metric, mia_mode)

        # Pick the device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Quantile grid.
        if use_logscale:
            # torch.logspace takes *exponents*: (-4, 0) yields offsets
            # 10^-4 .. 10^0, so the levels are 1 - 10^0 .. 1 - 10^-4.
            # Converting the bounds with 10**x first (as was done before)
            # collapsed the whole grid to roughly -0.0002 .. 0.
            grid = 1 - torch.logspace(float(low_quantile), float(high_quantile), n_quantile)
        else:
            grid = torch.linspace(low_quantile, high_quantile, n_quantile)
        self.quantiles = torch.sort(grid)[0].reshape([1, -1]).to(self.device)

        # Training configuration (copy so the shared default list is never aliased).
        self.hidden_dims = list(hidden_dims)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.use_logscale = use_logscale
        self.n_quantile = n_quantile

        # Remaining state.
        self.q_model = None
        self.base_model = None
        self.metric_fn = label_logit_and_hinge_scoring_fn if metric is None else metric
        self.rearrange_on_predict = not use_logscale

    def _to_device(self, value):
        """Convert a numpy array to a tensor and move tensors to ``self.device``."""
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        if isinstance(value, torch.Tensor):
            value = value.to(self.device)
        return value

    def _collect_scores(self, data_source, num_batches):
        """Score up to ``num_batches`` batches drawn from ``data_source``.

        ``data_source`` may be a re-iterable (e.g. a DataLoader) or an
        iterator/generator. Each batch must be a sequence whose first two
        items are data and label; any further items (such as sample indices)
        are ignored.
        """
        is_plain_iterable = (
            hasattr(data_source, '__iter__') and not hasattr(data_source, '__next__')
        )
        batch_iter = iter(data_source) if is_plain_iterable else data_source

        scores = []
        for _ in range(num_batches):
            try:
                batch = next(batch_iter)
            except StopIteration:
                break
            data = self._to_device(batch[0])
            label = self._to_device(batch[1])
            score, _ = self.metric_fn(data, label, self.base_model)
            # Detach so a custom metric that tracks gradients cannot tie the
            # regressor's training graph to the target model.
            scores.append(score.detach())
        return scores

    def fit(self, model, train_data_generator, shadow_data_generator, num_batches=100, **kwargs):
        '''
        >>> model: target model
        >>> train_data_generator: member-data generator
        >>> shadow_data_generator: nonmember-data generator
        >>> num_batches: maximum number of batches drawn from each generator
        '''
        # Keep the target model.
        self.base_model = model
        self.base_model.eval()

        # Only the scores are kept; the raw batches are not needed afterwards.
        print("Collecting member data...")
        train_scores = self._collect_scores(train_data_generator, num_batches)

        print("Collecting non-member data...")
        shadow_scores = self._collect_scores(shadow_data_generator, num_batches)

        if not train_scores and not shadow_scores:
            raise RuntimeError("No data collected: both data generators are empty")

        all_scores = torch.cat(train_scores + shadow_scores)

        # NOTE(review): the regressor is fit on member and non-member scores
        # alike, with the score itself as its only input feature. Confirm
        # this simplification is intended.
        features = all_scores.reshape(-1, 1)
        print(f"特征形状: {features.shape}")

        # Build the quantile regressor.
        input_size = features.shape[1] if features.dim() > 1 else 1
        self.q_model = QMIAModel(input_size, self.hidden_dims, self.n_quantile).to(self.device)

        dataset = torch.utils.data.TensorDataset(features, all_scores)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True
        )

        optimizer = torch.optim.AdamW(
            self.q_model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        print("Training Q-function model...")
        self.q_model.train()
        quantile_levels = self.quantiles.to(self.device)

        for epoch in range(self.num_epochs):
            epoch_loss = 0.0
            for batch_features, batch_scores in data_loader:
                optimizer.zero_grad()

                predicted_quantiles = self.q_model(batch_features)
                loss = pinball_loss_fn(
                    predicted_quantiles,
                    batch_scores,
                    quantile_levels
                ).mean()

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {epoch_loss/len(data_loader):.6f}")

        print("QuantileMIA training completed.")
        self.q_model.eval()
        return self

    def infer(self, model, data, label):
        '''
        >>> model: the model under attack
        >>> data: input data
        >>> label: ground-truth labels

        Returns: boolean tensor, True where the sample is predicted to be a
        member and False otherwise.
        '''
        if self.q_model is None:
            raise RuntimeError("模型未训练，请先调用fit方法")

        if self.base_model is not model:
            print("Warning: Inference model different from training model")
            self.base_model = model

        self.base_model.eval()
        self.q_model.eval()

        # Same conversion as in fit: arrays become tensors, tensors are moved
        # to the attack device.
        data = self._to_device(data)
        label = self._to_device(label)

        with torch.no_grad():
            score, _ = self.metric_fn(data, label, self.base_model)

            # The score is the regressor's only input feature.
            features = score.reshape(-1, 1)
            predicted_quantiles = self.q_model(features)

            # Optionally enforce monotonic quantiles.
            if self.rearrange_on_predict:
                predicted_quantiles = rearrange_quantile_fn(
                    predicted_quantiles, self.quantiles.to(predicted_quantiles.device).flatten()
                )

            # `self.threshold` is presumably set by the MIA base constructor
            # (TODO confirm). It is mapped linearly onto the grid *index*.
            # NOTE(review): on a log-spaced grid this index does not
            # correspond to the quantile level `threshold` itself.
            threshold_idx = int(self.threshold * (self.n_quantile - 1))
            threshold_quantile = predicted_quantiles[:, threshold_idx]

            # Member when the score exceeds the selected predicted quantile.
            is_member = score > threshold_quantile

        return is_member
