import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import numpy as np
import math
import os
from typing import Tuple, Dict, List, Union

# --- Import scikit-learn (optional) ---
try:
    from sklearn.metrics import accuracy_score
except ImportError:
    print("scikit-learn not found. Accuracy will not be calculated in the DRO classes.")
    print("Please run: pip install scikit-learn")
    accuracy_score = None

# --- Import plotting libraries ---
import matplotlib.pyplot as plt
import seaborn as sns


# --- 1. Global settings ---
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")


# --- 2. Adversarial attack (PGD) ---
def pgd_attack(model, features, labels, epsilon, alpha, num_iter):
    """Craft adversarial inputs with projected gradient ascent in an L2 ball.

    Every iteration takes a signed-gradient ascent step of size ``alpha`` on
    the cross-entropy loss and then projects the total perturbation back onto
    the L2 ball of radius ``epsilon`` centred on the clean input.

    NOTE(review): the ascent direction is ``grad.sign()`` while the projection
    is L2; a textbook L2-PGD step would use the L2-normalised gradient. The
    original step rule is kept so existing results stay comparable.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        features: Clean inputs; dimension 0 is the batch dimension. Any number
            of trailing dimensions is supported.
        labels: Integer class targets, one per input.
        epsilon: Per-sample radius of the L2 ball.
        alpha: Size of each ascent step.
        num_iter: Number of ascent/projection iterations.

    Returns:
        A detached tensor shaped like ``features``, located on the model's
        device, whose per-sample L2 distance to ``features`` is at most
        ``epsilon``. With ``num_iter == 0`` it is a copy of ``features``.
    """
    # Run the attack wherever the model lives so inputs and weights can never
    # end up on different devices; a parameter-free model falls back to the
    # device of the inputs.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else features.device

    # Plain floats avoid NumPy-scalar/tensor mixed arithmetic surprises.
    epsilon = float(epsilon)
    alpha = float(alpha)

    original_features = features.clone().detach().to(device)
    perturbed_features = original_features.clone()
    labels = labels.to(device)
    criterion = nn.CrossEntropyLoss()

    # Per-sample scale factors must broadcast as (batch, 1, ..., 1) so inputs
    # with more than two dimensions are projected correctly.
    factor_shape = (-1,) + (1,) * (original_features.dim() - 1)

    for _ in range(num_iter):
        perturbed_features.requires_grad_(True)
        loss = criterion(model(perturbed_features), labels)
        # autograd.grad returns a fresh gradient each step (no accumulation in
        # ``.grad``) and leaves the model parameters' gradients untouched.
        grad, = torch.autograd.grad(loss, perturbed_features)

        with torch.no_grad():
            perturbation = perturbed_features + alpha * grad.sign() - original_features
            norm = torch.linalg.norm(perturbation.reshape(perturbation.shape[0], -1), dim=1)
            # Shrink only the samples that left the ball (factor < 1).
            factor = torch.clamp(epsilon / (norm + 1e-12), max=1.0)
            perturbed_features = original_features + perturbation * factor.view(factor_shape)

    return perturbed_features.detach()


# --- 3. Model definitions ---

# Standard logistic regression model used for SAA
class LogisticRegression(nn.Module):
    """Logistic-regression classifier: a single affine layer emitting logits."""

    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        # Kept under the name ``linear`` so state-dict keys stay the same.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the unnormalised class scores for ``x``."""
        logits = self.linear(x)
        return logits

# Classes for the base DRO framework
class DROError(Exception):
    """Error raised when data handed to a DRO estimator fails validation."""

class LinearModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim`` outputs."""

    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__()
        # Kept under the name ``linear`` so state-dict keys stay the same.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine layer to ``x``."""
        out = self.linear(x)
        return out

class BaseLinearDRO:
    """Shared plumbing for classifiers that are fed NumPy arrays.

    Provides input validation, tensor/dataloader construction, prediction and
    accuracy scoring. ``self.model`` is only declared here; it must be
    assigned by a subclass or by the caller before ``predict``/``score``.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool):
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.fit_intercept = fit_intercept
        # Default device for tensors built by this object; subclasses override it.
        self.device = torch.device("cpu")
        self.model: nn.Module

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Convert ``data`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(data, dtype=torch.float32, device=self.device)

    def _validate_inputs(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Normalise shapes and check that ``X`` and ``y`` are consistent.

        A 1-D ``X`` is reshaped to ``(-1, input_dim)`` and a multi-dimensional
        ``y`` is flattened.

        Raises:
            DROError: if the sample counts differ or ``X`` does not have
                ``input_dim`` columns.
        """
        if X.ndim == 1:
            X = X.reshape(-1, self.input_dim)
        if y.ndim > 1:
            y = y.flatten()
        if X.shape[0] != y.shape[0]:
            raise DROError(f"X and y must have the same number of samples. Got X: {X.shape[0]}, y: {y.shape[0]}")
        if X.shape[1] != self.input_dim:
            raise DROError(f"Expected input_dim={self.input_dim} features for X, got {X.shape[1]}")
        return X, y

    def _create_dataloader(self, X: np.ndarray, y: np.ndarray, batch_size: int) -> DataLoader:
        """Wrap ``X`` (float32) and ``y`` (int64) in a shuffling ``DataLoader``."""
        dataset = TensorDataset(self._to_tensor(X), torch.as_tensor(y, dtype=torch.long, device=self.device))
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    def _model_device(self) -> torch.device:
        """Return the device holding the model's parameters.

        The model may have been moved after construction (for example by an
        external ``model.to(...)``), so ``self.device`` alone is not reliable.
        Falls back to ``self.device`` for a parameter-free model.
        """
        first_param = next(self.model.parameters(), None)
        return self.device if first_param is None else first_param.device

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the raw model outputs for ``X`` as a NumPy array.

        A single 1-D sample of length ``input_dim`` is accepted and treated as
        a batch of one. The model is switched to eval mode.

        Raises:
            DROError: if ``X`` does not have ``input_dim`` columns.
        """
        X = np.asarray(X)
        # Reshape before sizing the dummy labels; otherwise a 1-D sample would
        # be compared against ``input_dim`` dummy labels and rejected.
        if X.ndim == 1:
            X = X.reshape(-1, self.input_dim)
        X_val, _ = self._validate_inputs(X, np.zeros(X.shape[0]))
        self.model.eval()
        with torch.no_grad():
            # Build the input on the model's actual device to avoid a
            # CPU/GPU mismatch when the model was moved externally.
            inputs = torch.as_tensor(X_val, dtype=torch.float32, device=self._model_device())
            return self.model(inputs).cpu().numpy()

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the classification accuracy on ``(X, y)`` as a fraction in [0, 1].

        Computed with NumPy so the result does not depend on scikit-learn
        being installed (no ``-1.0`` sentinel).

        Raises:
            DROError: if ``X`` and ``y`` disagree on the number of samples.
        """
        y_true = np.asarray(y).flatten()
        y_pred = np.argmax(self.predict(X), axis=1)
        if y_true.shape[0] != y_pred.shape[0]:
            raise DROError(f"X and y must have the same number of samples. Got X: {y_pred.shape[0]}, y: {y_true.shape[0]}")
        return float(np.mean(y_true == y_pred))

# ==============================================================================
# START: performance-optimized (vectorized) SinkhornBaseDRO class
# ==============================================================================
class SinkhornBaseDRO(BaseLinearDRO):
    """Linear classifier trained on a Monte-Carlo Sinkhorn-DRO objective.

    For every training point ``zeta`` the model sees ``num_samples`` Gaussian
    perturbations ``xi = zeta + N(0, noise^2)`` and minimises

        lambda * epsilon * mean_i log( mean_j exp(
            (CE(xi_ij) - lambda * ||xi_ij - zeta_i||_2) / (lambda * epsilon) ) )

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        fit_intercept: Whether the linear layer has a bias term.
        epsilon: Regularisation strength in the objective above.
        lambda_param: Weight of the transport-cost term.
        max_iter: Number of training epochs.
        learning_rate: Adam learning rate.
        num_samples: Perturbed copies drawn per training point.
        batch_size: Mini-batch size (in original training points).
        noise: Standard deviation of the Gaussian perturbations.
        device: ``"cuda"`` to train on the GPU when available; anything else
            selects the CPU.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool = True, epsilon: float = 1.0,
                 lambda_param: float = 1.0, max_iter: int = 100, learning_rate: float = 1e-3,
                 num_samples: int = 10, batch_size: int = 64, noise: float = 0.1, device: str = "cpu"):
        super().__init__(input_dim, num_classes, fit_intercept)
        self.epsilon = epsilon
        self.lambda_param = lambda_param
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.noise = noise
        # Must come after super().__init__(), which resets the device to CPU.
        self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
        self.model = LinearModel(input_dim, output_dim=num_classes, bias=fit_intercept).to(self.device)

    def _dro_batch_loss(self, zeta_batch: torch.Tensor, target_batch: torch.Tensor,
                        ce_loss_fn: nn.Module) -> torch.Tensor:
        """Return the scalar DRO objective for one mini-batch."""
        num_points = zeta_batch.size(0)

        # Each point is repeated ``num_samples`` times in consecutive rows, so
        # a later ``view(num_points, num_samples)`` regroups the copies.
        zeta_expanded = zeta_batch.repeat_interleave(self.num_samples, dim=0).to(self.device)
        target_expanded = target_batch.repeat_interleave(self.num_samples, dim=0).to(self.device)

        # One noise draw for the whole expanded batch.
        noise_tensor = torch.normal(mean=0.0, std=self.noise, size=zeta_expanded.shape, device=self.device)
        xi_batch = zeta_expanded + noise_tensor

        loss_values = ce_loss_fn(self.model(xi_batch), target_expanded)
        norm_diffs = torch.norm(xi_batch - zeta_expanded, p=2, dim=1)

        scale = self.lambda_param * self.epsilon
        inner_terms = (loss_values - self.lambda_param * norm_diffs) / scale

        # log(mean(exp(x))) == logsumexp(x) - log(n); the logsumexp form does
        # not overflow when the exponent is large.
        log_mean_exp = (
            torch.logsumexp(inner_terms.view(num_points, self.num_samples), dim=1)
            - math.log(self.num_samples)
        )
        return scale * log_mean_exp.mean()

    def fit(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Train the model with Adam and return the mean loss of every epoch.

        Raises:
            DROError: if ``X``/``y`` fail validation.
        """
        X, y = self._validate_inputs(X, y)
        dataloader = self._create_dataloader(X, y, batch_size=self.batch_size)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        # Per-sample losses are required inside the log-mean-exp.
        ce_loss_fn = nn.CrossEntropyLoss(reduction='none')

        self.model.train()
        loss_history = []

        for epoch in range(self.max_iter):
            epoch_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Training SinkhornBaseDRO Epoch {epoch+1}/{self.max_iter}")

            for zeta_batch, target_batch in pbar:
                optimizer.zero_grad()
                final_loss = self._dro_batch_loss(zeta_batch, target_batch, ce_loss_fn)
                final_loss.backward()
                optimizer.step()

                batch_loss = final_loss.item()
                epoch_loss += batch_loss
                pbar.set_postfix(loss=batch_loss)

            loss_history.append(epoch_loss / len(dataloader))

        return loss_history
# ==============================================================================
# END: performance-optimized SinkhornBaseDRO class
# ==============================================================================


# RGO model
class SinkhornDROLogisticRGO(BaseLinearDRO):
    """Linear classifier trained on samples drawn by a vectorised RGO sampler.

    For every mini-batch the sampler first moves each input a few gradient
    steps towards a minimiser of ``-CE(x)/lambda + ||x - x0||^2`` and then
    draws ``num_samples`` points around it by rejection sampling with Gaussian
    proposals of variance ``epsilon``. The model is trained with plain
    cross-entropy on those samples.

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        fit_intercept: Whether the linear layer has a bias term.
        epsilon: Proposal variance and scaling of the sampler's potential.
        lambda_param: Weight dividing the loss term in the potential.
        rgo_inner_lr: Step size of the inner gradient steps.
        rgo_inner_steps: Scales the (epoch-dependent) number of inner steps.
        num_samples: Samples generated per training point.
        max_iter: Number of training epochs.
        learning_rate: Adam learning rate.
        batch_size: Mini-batch size (in original training points).
        device: ``"cuda"`` to train on the GPU when available; anything else
            selects the CPU.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool = True, epsilon: float = 0.1,
                 lambda_param: float = 1.0, rgo_inner_lr: float = 0.01, rgo_inner_steps: int = 20,
                 num_samples: int = 10, max_iter: int = 30, learning_rate: float = 0.01,
                 batch_size: int = 64, device: str = "cpu"):
        # The base initialiser resets ``self.device`` to CPU, so it has to run
        # BEFORE the requested device is stored; otherwise ``device`` is ignored.
        super().__init__(input_dim, num_classes, fit_intercept)
        self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
        self.epsilon, self.lambda_param = epsilon, lambda_param
        self.rgo_inner_lr, self.rgo_inner_steps = rgo_inner_lr, rgo_inner_steps
        self.num_samples, self.max_iter = num_samples, max_iter
        self.learning_rate, self.batch_size = learning_rate, batch_size
        # Upper bound on rejection-sampling rounds per mini-batch.
        self.rgo_vectorized_max_trials = 100
        self.model = LinearModel(self.input_dim, output_dim=self.num_classes, bias=self.fit_intercept).to(self.device)

    def _get_model_loss_value_batched(self, x_features_batch: torch.Tensor, y_target_batch: torch.Tensor,
                                      model_instance: nn.Module) -> torch.Tensor:
        """Return the per-sample cross-entropy of ``model_instance`` on the batch."""
        return nn.CrossEntropyLoss(reduction='none')(model_instance(x_features_batch), y_target_batch)

    def _rgo_sampler_vectorized(self, x_original_batch: torch.Tensor, y_original_batch: torch.Tensor,
                                current_model_state: nn.Module, num_samples_to_generate: int,
                                epoch: int) -> torch.Tensor:
        """Draw ``num_samples_to_generate`` samples around every input of the batch.

        Returns:
            A tensor of shape ``(batch * num_samples_to_generate, input_dim)``
            with the samples of each input stored in consecutive rows. Slots
            that are never accepted within ``rgo_vectorized_max_trials`` rounds
            keep a zero perturbation, i.e. they equal the inner optimum.
        """
        batch_size = x_original_batch.size(0)
        x_orig_detached_batch = x_original_batch.detach()
        x_pert_batch = x_orig_detached_batch.clone()
        lr_inner = self.rgo_inner_lr

        # Number of inner steps ramps up with the epoch and is capped at 5.
        inner_steps = int(min(5, self.rgo_inner_steps * (epoch + 1) / self.max_iter))
        for _ in range(inner_steps):
            x_pert_batch.requires_grad_(True)
            per_sample_losses = self._get_model_loss_value_batched(x_pert_batch, y_original_batch, current_model_state)
            per_sample_grads, = torch.autograd.grad(
                outputs=per_sample_losses, inputs=x_pert_batch,
                grad_outputs=torch.ones_like(per_sample_losses))
            x_pert_batch = x_pert_batch.detach()
            # Gradient of -loss/lambda + ||x - x0||^2 with respect to x.
            grad_total = -per_sample_grads / self.lambda_param + 2 * (x_pert_batch - x_orig_detached_batch)
            x_pert_batch -= lr_inner * grad_total
        x_opt_star_batch = x_pert_batch

        var_rgo = self.epsilon
        if var_rgo <= 1e-12:
            # Degenerate proposal: every sample is the inner optimum itself.
            return x_opt_star_batch.repeat_interleave(num_samples_to_generate, dim=0)
        std_rgo = math.sqrt(var_rgo)

        # Nothing below is differentiated, so skip building autograd graphs.
        with torch.no_grad():
            f_model_loss_opt_star = self._get_model_loss_value_batched(
                x_opt_star_batch, y_original_batch, current_model_state)
            norm_sq_opt_star = torch.sum((x_opt_star_batch - x_orig_detached_batch) ** 2, dim=1)
            f_L_xi_opt_star = (-f_model_loss_opt_star / (self.lambda_param * self.epsilon)) + (norm_sq_opt_star / self.epsilon)

            x_opt_star_3d = x_opt_star_batch.unsqueeze(1)
            x_original_3d = x_orig_detached_batch.unsqueeze(1)
            f_L_xi_opt_star_3d = f_L_xi_opt_star.unsqueeze(1)

            final_accepted_perturbations = torch.zeros(
                (batch_size, num_samples_to_generate, self.input_dim), device=self.device)
            # True while a (point, sample) slot is still waiting for an accepted draw.
            active_flags = torch.ones((batch_size, num_samples_to_generate), dtype=torch.bool, device=self.device)
            # Loop-invariant: labels repeated to match the flattened candidates.
            y_repeated = y_original_batch.repeat_interleave(num_samples_to_generate, dim=0)

            for _ in range(self.rgo_vectorized_max_trials):
                if not active_flags.any():
                    break
                pert_proposals = torch.randn_like(final_accepted_perturbations) * std_rgo
                x_candidates = x_opt_star_3d + pert_proposals
                x_candidates_flat = x_candidates.view(-1, self.input_dim)
                f_model_loss_candidates = self._get_model_loss_value_batched(
                    x_candidates_flat, y_repeated, current_model_state
                ).view(batch_size, num_samples_to_generate)
                norm_sq_candidates = torch.sum((x_candidates - x_original_3d) ** 2, dim=2)
                f_L_xi_candidates = (-f_model_loss_candidates / (self.lambda_param * self.epsilon)) + (norm_sq_candidates / self.epsilon)
                diff_cand_opt_norm_sq = torch.sum(pert_proposals ** 2, dim=2)
                exponent_term3 = diff_cand_opt_norm_sq / (2 * var_rgo)
                # The exponent is clamped at 10 to keep exp() finite.
                acceptance_probs = torch.exp(
                    torch.clamp(-f_L_xi_candidates + f_L_xi_opt_star_3d + exponent_term3, max=10))
                newly_accepted_mask = (torch.rand_like(acceptance_probs) < acceptance_probs) & active_flags
                final_accepted_perturbations[newly_accepted_mask] = pert_proposals[newly_accepted_mask]
                active_flags[newly_accepted_mask] = False

            return (x_opt_star_3d + final_accepted_perturbations).view(-1, self.input_dim)

    def fit(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Train with Adam on RGO samples and return the mean loss of every epoch.

        Raises:
            DROError: if ``X``/``y`` fail validation.
        """
        X, y = self._validate_inputs(X, y)
        dataloader = self._create_dataloader(X, y, batch_size=self.batch_size)
        optimizer_theta = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        loss_history = []
        for epoch in range(self.max_iter):
            epoch_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Training RGO Epoch {epoch+1}/{self.max_iter}")
            for x_original_batch, y_original_batch in pbar:
                # Sample with the model in eval mode, then switch back to train.
                self.model.eval()
                x_rgo_batch = self._rgo_sampler_vectorized(
                    x_original_batch, y_original_batch, self.model, self.num_samples, epoch)
                y_repeated_batch = y_original_batch.repeat_interleave(self.num_samples, dim=0)
                self.model.train()
                loss = criterion(self.model(x_rgo_batch), y_repeated_batch)
                optimizer_theta.zero_grad()
                loss.backward()
                optimizer_theta.step()
                batch_loss = loss.item()
                epoch_loss += batch_loss
                pbar.set_postfix(loss=batch_loss)
            loss_history.append(epoch_loss / len(dataloader))
        return loss_history


# --- 4. Training and evaluation functions ---
def train_saa(model, train_loader, epochs=30, learning_rate=0.001):
    """Train ``model`` with plain cross-entropy (sample average approximation).

    The model is moved to ``DEVICE`` and put in train mode, then optimised
    with Adam for ``epochs`` passes over ``train_loader``.

    Returns:
        A list holding the mean batch loss of every epoch.
    """
    model.to(DEVICE).train()
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=learning_rate)
    loss_history = []
    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for features, labels in progress:
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)
            adam.zero_grad()
            batch_loss = loss_fn(model(features), labels)
            batch_loss.backward()
            adam.step()
            loss_value = batch_loss.item()
            running_loss += loss_value
            progress.set_postfix(loss=loss_value)
        loss_history.append(running_loss / len(train_loader))
    return loss_history

def evaluate_model(model_object, test_features_np, test_labels_np, attack_fn, epsilon_list):
    """Measure clean and adversarial accuracy (in percent) of a wrapped model.

    ``model_object.model`` is moved to ``DEVICE`` and put in eval mode. Clean
    accuracy comes from ``model_object.score``; for every radius in
    ``epsilon_list`` the test set is attacked with ``attack_fn`` (10
    iterations, step size ``epsilon / 8``) and re-classified.

    Returns:
        A dict mapping ``0.0`` to the clean accuracy and every epsilon to the
        accuracy under attack.
    """
    pytorch_model = model_object.model.to(DEVICE).eval()
    results = {0.0: model_object.score(test_features_np, test_labels_np) * 100}
    print(f"Accuracy on clean test set: {results[0.0]:.2f}%")

    test_dataset = TensorDataset(torch.from_numpy(test_features_np), torch.from_numpy(test_labels_np))
    test_loader = DataLoader(test_dataset, batch_size=128)

    for epsilon in epsilon_list:
        num_correct = 0
        num_seen = 0
        progress = tqdm(test_loader, desc=f"Evaluating (epsilon={epsilon:.4f})")
        for features, labels in progress:
            # The attack itself needs gradients, so it runs outside no_grad.
            adversarial = attack_fn(
                model=pytorch_model, features=features, labels=labels,
                epsilon=epsilon, alpha=epsilon / 8, num_iter=10,
            )
            with torch.no_grad():
                _, predicted = pytorch_model(adversarial).max(dim=1)
            num_seen += labels.size(0)
            num_correct += (predicted == labels.to(DEVICE)).sum().item()
        results[epsilon] = 100 * num_correct / num_seen
        print(f"Accuracy under PGD attack (epsilon={epsilon:.4f}): {results[epsilon]:.2f}%")
    return results

# --- 5. Plotting functions ---
def plot_training_loss(loss_dict):
    """Plot the per-epoch training loss of every model and save it as a PNG.

    ``loss_dict`` maps a model name to its list of epoch losses; models with
    an empty history are skipped. The figure is written to
    ``training_loss_comparison.png`` in the working directory.
    """
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 7))
    for model_name, loss_history in loss_dict.items():
        if not loss_history:
            # Nothing to draw for this model.
            continue
        epoch_numbers = range(1, len(loss_history) + 1)
        plt.plot(epoch_numbers, loss_history, marker='o', linestyle='--', label=model_name)
    plt.xlabel("Epoch")
    plt.ylabel("Average Training Loss")
    plt.title("Training Loss Comparison")
    plt.legend(title="Model")
    plt.grid(True, which='both', linestyle='-')
    plt.savefig("training_loss_comparison.png")
    print("\nTraining loss plot saved as 'training_loss_comparison.png'")

def plot_robustness_results(results_dict, perturbation_levels):
    """Plot the misclassification rate of every model against the attack level.

    Args:
        results_dict: Maps a model name to a dict keyed by perturbation level
            (the same values as in ``perturbation_levels``) whose values are
            accuracies in percent.
        perturbation_levels: Relative perturbation levels used as x-axis.

    The figure is written to ``robustness_comparison_plot.png``. A level that
    is missing from a model's results is drawn as a gap (NaN) instead of being
    silently replaced by the clean accuracy.
    """
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 7))
    for model_name, results in results_dict.items():
        # Results are keyed by perturbation level, so look the level up
        # directly; rescaling it by a feature norm would miss every non-zero
        # key and would also depend on a script-only global.
        y_values = [100.0 - results.get(p_level, float("nan")) for p_level in perturbation_levels]
        plt.plot(perturbation_levels, y_values, marker='o', linestyle='--', label=model_name)
    plt.xlabel("Level of Perturbations / Norm of Data")
    plt.ylabel("Mis-classification Rate (%)")
    plt.title("Model Robustness under PGD Attack on CIFAR-10 Features")
    plt.legend(title="Model")
    plt.grid(True, which='both', linestyle='-')
    plt.ylim(bottom=0)
    plt.xticks(perturbation_levels)
    plt.savefig("robustness_comparison_plot.png")
    print("\nRobustness plot saved as 'robustness_comparison_plot.png'")

# --- 6. Main execution flow ---
if __name__ == '__main__':
    # Experiment driver: train SAA, Sinkhorn DRO and RGO models on
    # pre-extracted features, then compare training loss and PGD robustness.
    print("--- Loading Pre-extracted Features ---")
    file_path = "cifar10_resnet50_features.pth"
    SEED = 2022
    torch.manual_seed(SEED)

    if not os.path.exists(file_path):
        print(f"Error: Feature file '{file_path}' not found.")
        print("Please run the feature extraction script first.")
    else:
        # NOTE: torch.load unpickles the file; only load feature files you trust.
        # map_location="cpu" lets a file saved from a GPU run load on any host
        # and guarantees the tensors can be converted with .numpy() below.
        data = torch.load(file_path, map_location="cpu")
        train_features, train_labels = data["train_features"], data["train_labels"]
        test_features, test_labels = data["test_features"], data["test_labels"]

        train_features_np, train_labels_np = train_features.numpy(), train_labels.numpy()
        test_features_np, test_labels_np = test_features.numpy(), test_labels.numpy()

        INPUT_DIM, NUM_CLASSES = train_features_np.shape[1], 10
        DRO_BATCH_SIZE = 128
        # Shared hyper-parameters: lambda, epsilon and number of epochs.
        lam, eps, itr = 0.8, 1, 10

        all_training_losses = {}

        print("\n--- Training SAA Model ---")
        # The wrapper only supplies predict/score; the model is attached by hand.
        saa_model_wrapper = BaseLinearDRO(INPUT_DIM, NUM_CLASSES, True)
        saa_model_wrapper.model = LogisticRegression(INPUT_DIM, NUM_CLASSES).to(DEVICE)
        feature_train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=128, shuffle=True)
        saa_loss_history = train_saa(saa_model_wrapper.model, feature_train_loader, epochs=itr)
        all_training_losses["SAA"] = saa_loss_history

        print("\n--- Training SinkhornBaseDRO (Vectorized) ---")
        base_dro_model = SinkhornBaseDRO(INPUT_DIM, NUM_CLASSES, device=DEVICE.type, lambda_param=lam, epsilon=eps, max_iter=itr, learning_rate=1e-4, batch_size=DRO_BATCH_SIZE, num_samples=10, noise=0.15)
        base_loss_history = base_dro_model.fit(train_features_np, train_labels_np)
        all_training_losses["Base"] = base_loss_history

        print("\n--- Training SinkhornDROLogisticRGO Model ---")
        rgo_dro_model = SinkhornDROLogisticRGO(INPUT_DIM, NUM_CLASSES, device=DEVICE.type, lambda_param=lam, epsilon=eps, max_iter=itr, learning_rate=1e-3, batch_size=DRO_BATCH_SIZE)
        rgo_loss_history = rgo_dro_model.fit(train_features_np, train_labels_np)
        all_training_losses["RGO"] = rgo_loss_history

        plot_training_loss(all_training_losses)

        print("\n--- Evaluating Models ---")
        # Attack radii are expressed relative to the mean feature norm.
        avg_feature_norm = np.mean(np.linalg.norm(test_features_np, axis=1))
        print(f"Average L2 norm of test set features: {avg_feature_norm:.2f}")
        perturbation_levels = [0.0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
        epsilon_values = [level * avg_feature_norm for level in perturbation_levels]

        all_results = {}
        models_to_evaluate = {
            "SAA": saa_model_wrapper,
            "Base": base_dro_model,
            "RGO": rgo_dro_model
        }

        for name, model_obj in models_to_evaluate.items():
            print(f"\n--- Evaluating {name} ---")
            # Level 0.0 is the clean accuracy, which evaluate_model always reports.
            results = evaluate_model(model_obj, test_features_np, test_labels_np, pgd_attack, epsilon_values[1:])
            all_results[name] = results

        # evaluate_model keys its results by absolute epsilon; re-key them by
        # relative perturbation level so they line up with the plot's x-axis.
        all_results_by_plevel = {}
        for name, results in all_results.items():
            mapped_results = {0.0: results[0.0]}
            for p_level, eps_value in zip(perturbation_levels[1:], epsilon_values[1:]):
                mapped_results[p_level] = results[eps_value]
            all_results_by_plevel[name] = mapped_results

        plot_robustness_results(all_results_by_plevel, perturbation_levels)