import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import numpy as np
import math
import os
from typing import Tuple, Dict, List, Union

# --- Import scikit-learn (optional dependency) ---
try:
    from sklearn.metrics import accuracy_score
except ImportError:
    # Keep the module importable without scikit-learn: accuracy_score is
    # bound to None, which BaseLinearDRO.score checks before using it.
    print("scikit-learn not found. Accuracy will not be calculated in the DRO classes.")
    print("Please run: pip install scikit-learn")
    accuracy_score = None

# --- Import plotting libraries ---
import matplotlib.pyplot as plt
import seaborn as sns


# --- 1. Global settings ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


# --- 2. Adversarial attack (PGD attack) ---
def pgd_attack(model, features, labels, epsilon, alpha, num_iter):
    """Craft adversarial features with projected gradient ascent in an L2 ball.

    Every iteration moves the current point by ``alpha`` along the sign of
    the cross-entropy gradient with respect to the features, then rescales
    each sample's total perturbation so that its L2 norm is at most
    ``epsilon``.

    The gradient is taken with ``torch.autograd.grad`` with respect to the
    inputs only, so (a) gradients from earlier iterations are never
    accumulated into later steps and (b) the ``.grad`` fields of the model's
    parameters are neither zeroed nor overwritten by the attack.

    Args:
        model: Callable mapping a feature batch to class logits.
        features: Clean input batch; the first dimension indexes samples.
        labels: Integer class labels, one per sample.
        epsilon: Radius of the L2 ball around each clean sample.
        alpha: Size of each sign-gradient step.
        num_iter: Number of ascent/projection iterations.

    Returns:
        A detached tensor shaped like ``features`` containing the perturbed
        batch, located on the device the attack ran on.
    """
    # Run where the model's weights live so inputs and weights always agree;
    # a parameter-free callable falls back to the device of the inputs.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else features.device

    # Plain Python floats avoid NumPy-scalar / tensor mixed arithmetic.
    epsilon = float(epsilon)
    alpha = float(alpha)

    original = features.detach().to(device)
    perturbed = original.clone()
    labels = labels.to(device)
    criterion = nn.CrossEntropyLoss()

    # Shape (N, 1, ..., 1): lets one scale factor per sample broadcast over
    # inputs of any rank, not only 2-D feature matrices.
    scale_shape = (-1,) + (1,) * (original.dim() - 1)

    for _ in range(num_iter):
        perturbed.requires_grad_(True)
        loss = criterion(model(perturbed), labels)
        grad, = torch.autograd.grad(loss, perturbed)

        with torch.no_grad():
            delta = perturbed + alpha * grad.sign() - original
            norms = delta.reshape(delta.shape[0], -1).norm(dim=1)
            # Shrink only the perturbations that left the ball (factor < 1).
            factor = torch.clamp(epsilon / (norms + 1e-12), max=1.0)
            perturbed = original + delta * factor.view(scale_shape)

    return perturbed.detach()


# --- 3. Model definitions ---

# Standard logistic-regression model used for the SAA baseline
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: a single affine layer producing logits."""

    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        # The attribute name is kept as ``linear`` so state-dict keys are unchanged.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        logits = self.linear(x)
        return logits

# Classes that make up the basic DRO framework
class DROError(Exception):
    """Error raised by the DRO classes when their inputs are invalid."""

class LinearModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim`` outputs."""

    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__()
        # ``bias`` toggles the intercept term of the layer.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine layer to ``x`` and return the result."""
        outputs = self.linear(x)
        return outputs

class BaseLinearDRO:
    """Shared plumbing for the linear DRO classifiers in this file.

    Stores the problem dimensions and offers input validation, tensor and
    DataLoader construction, prediction and accuracy scoring. ``self.model``
    is only declared here; it must be assigned by a subclass or by the
    caller before ``predict`` or ``score`` is used.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool):
        self.input_dim, self.num_classes, self.fit_intercept = input_dim, num_classes, fit_intercept
        self.device = torch.device("cpu")
        self.model: nn.Module

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Convert ``data`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(data, dtype=torch.float32, device=self.device)

    def _validate_inputs(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Normalise the shapes of ``X`` and ``y`` and check that they agree.

        A 1-D ``X`` is reshaped to ``(-1, input_dim)`` and a multi-dimensional
        ``y`` is flattened.

        Raises:
            DROError: If the sample counts differ or ``X`` does not have
                ``input_dim`` columns.
        """
        if X.ndim == 1:
            X = X.reshape(-1, self.input_dim)
        if y.ndim > 1:
            y = y.flatten()
        if X.shape[0] != y.shape[0]:
            raise DROError(f"X and y must have the same number of samples. Got X: {X.shape[0]}, y: {y.shape[0]}")
        if X.shape[1] != self.input_dim:
            raise DROError(f"Expected input_dim={self.input_dim} features for X, got {X.shape[1]}")
        return X, y

    def _create_dataloader(self, X: np.ndarray, y: np.ndarray, batch_size: int) -> DataLoader:
        """Wrap ``X`` (float32) and ``y`` (int64) in a shuffling DataLoader."""
        dataset = TensorDataset(self._to_tensor(X), torch.as_tensor(y, dtype=torch.long, device=self.device))
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    def _model_device(self) -> torch.device:
        """Return the device holding the model's parameters.

        Falls back to ``self.device`` for a model without parameters.
        """
        first_param = next(iter(self.model.parameters()), None)
        return self.device if first_param is None else first_param.device

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the model's logits for ``X`` as a NumPy array.

        ``X`` may be a 2-D ``(n_samples, input_dim)`` array or a single 1-D
        sample. The model is switched to eval mode.
        """
        X = np.asarray(X)
        # Reshape first so the placeholder targets below have one entry per
        # row; sizing them from a still-1-D X would fail validation.
        if X.ndim == 1:
            X = X.reshape(-1, self.input_dim)
        X_val, _ = self._validate_inputs(X, np.zeros(X.shape[0]))
        self.model.eval()
        with torch.no_grad():
            # The model may have been moved after construction (for example
            # with ``.to(...)``), so follow its parameters rather than
            # trusting ``self.device``.
            inputs = self._to_tensor(X_val).to(self._model_device())
            return self.model(inputs).cpu().numpy()

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the classification accuracy on ``(X, y)``.

        Returns ``-1.0`` when scikit-learn could not be imported.
        """
        if accuracy_score is None:
            return -1.0
        return accuracy_score(y.flatten(), np.argmax(self.predict(X), axis=1))

# Baseline Sinkhorn DRO model (sinkhorn_base)
class SinkhornBaseDRO(BaseLinearDRO):
    """Linear classifier trained on a Monte-Carlo log-mean-exp robust loss.

    For every training point ``zeta`` the model is evaluated on
    ``num_samples`` Gaussian perturbations ``xi`` of it, and the per-point
    objective is::

        lambda * epsilon * log(mean_xi(exp((CE(xi) - lambda * ||zeta - xi||_2)
                                           / (lambda * epsilon))))

    which ``fit`` averages over each mini-batch.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool = True, epsilon: float = 1.0,
                 lambda_param: float = 1.0, max_iter: int = 100, learning_rate: float = 1e-3,
                 num_samples: int = 10, batch_size: int = 64, noise: float = 0.1, device: str = "cpu"):
        super().__init__(input_dim, num_classes, fit_intercept)
        self.epsilon, self.lambda_param, self.max_iter, self.learning_rate = epsilon, lambda_param, max_iter, learning_rate
        self.num_samples, self.batch_size, self.noise = num_samples, batch_size, noise
        # Only the exact string "cuda" selects the GPU, and only if available.
        self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
        self.model = LinearModel(input_dim, output_dim=num_classes, bias=fit_intercept).to(self.device)

    def _distance_metric(self, xi: torch.Tensor, zeta: torch.Tensor) -> torch.Tensor:
        """Return the L2 norm of ``xi - zeta`` as a scalar tensor."""
        return torch.norm(xi - zeta, p=2)

    def _cross_entropy_metric(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy of ``predictions`` against ``targets``."""
        return nn.CrossEntropyLoss()(predictions, targets)

    def _generate_corrupt_data(self, per_zeta: torch.Tensor, per_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw ``num_samples`` Gaussian perturbations of one sample.

        Args:
            per_zeta: One feature vector (reshaped to ``(1, n)`` internally).
            per_target: Its label.

        Returns:
            ``(xi, targets)``: ``xi`` has shape ``(num_samples, n)`` and
            ``targets`` repeats the label ``num_samples`` times.
        """
        per_zeta, per_target = per_zeta.to(self.device), per_target.to(self.device)
        n = per_zeta.size(0)
        per_zeta_repeated = per_zeta.reshape((1, n)).repeat(self.num_samples, 1)
        per_target_repeated = per_target.repeat(self.num_samples)
        # Zero-mean noise with standard deviation ``self.noise``.
        distribution_shift = torch.randn((self.num_samples, n), device=self.device) * self.noise
        return per_zeta_repeated + distribution_shift, per_target_repeated

    def _baseline_sinkDRO_loss(self, predictions: torch.Tensor, per_zeta: torch.Tensor, generated_xi: torch.Tensor, generated_target: torch.Tensor) -> torch.Tensor:
        """Return ``log(mean(exp(inner)))`` over the perturbations of one sample.

        ``inner`` is ``(CE - lambda * ||zeta - xi||_2) / (lambda * epsilon)``
        for each of the first ``num_samples`` rows. ``logsumexp`` is used so
        that large ``inner`` values do not overflow in ``exp``.
        """
        count = self.num_samples
        per_draw_ce = nn.CrossEntropyLoss(reduction='none')(predictions[:count], generated_target[:count])
        per_draw_dist = torch.norm(generated_xi[:count] - per_zeta.unsqueeze(0), p=2, dim=1)
        inner = (per_draw_ce - self.lambda_param * per_draw_dist) / (self.lambda_param * self.epsilon)
        return torch.logsumexp(inner, dim=0) - math.log(count)

    def _batched_sinkDRO_loss(self, data_batch: torch.Tensor, target_batch: torch.Tensor) -> torch.Tensor:
        """Return the robust loss of a whole mini-batch in one forward pass.

        Computes, for every row of ``data_batch``, the same quantity as
        ``_generate_corrupt_data`` followed by ``_baseline_sinkDRO_loss``,
        then averages over the batch and scales by ``lambda * epsilon``.
        """
        zeta = data_batch.to(self.device)
        targets = target_batch.to(self.device)
        batch, dim = zeta.shape
        shift = torch.randn((batch, self.num_samples, dim), device=self.device) * self.noise
        xi = zeta.unsqueeze(1) + shift                                   # (batch, num_samples, dim)
        logits = self.model(xi.reshape(-1, dim))                         # row b * num_samples + s
        # repeat_interleave matches the row order produced by the reshape above.
        repeated_targets = targets.repeat_interleave(self.num_samples)
        per_draw_ce = nn.CrossEntropyLoss(reduction='none')(logits, repeated_targets).view(batch, self.num_samples)
        per_draw_dist = torch.norm(xi - zeta.unsqueeze(1), p=2, dim=2)
        scale = self.lambda_param * self.epsilon
        inner = (per_draw_ce - self.lambda_param * per_draw_dist) / scale
        per_point = torch.logsumexp(inner, dim=1) - math.log(self.num_samples)
        return scale * per_point.mean()

    def fit(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Train the model with Adam and return the average loss of each epoch.

        Raises:
            DROError: If the inputs are inconsistent or contain no samples.
        """
        X, y = self._validate_inputs(X, y)
        if X.shape[0] == 0:
            raise DROError("Cannot fit on an empty training set.")
        dataloader = self._create_dataloader(X, y, batch_size=self.batch_size)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.model.train()
        loss_history = []
        for epoch in range(self.max_iter):
            epoch_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Training SinkhornBaseDRO Epoch {epoch+1}/{self.max_iter}")
            for data_batch, target_batch in pbar:
                optimizer.zero_grad()
                final_loss = self._batched_sinkDRO_loss(data_batch, target_batch)
                final_loss.backward()
                optimizer.step()
                batch_value = final_loss.item()
                epoch_loss += batch_value
                pbar.set_postfix(loss=batch_value)
            loss_history.append(epoch_loss / len(dataloader))
        return loss_history

# RGO model
class SinkhornDROLogisticRGO(BaseLinearDRO):
    """Linear classifier trained on samples drawn by an RGO-style sampler.

    For each mini-batch, ``_rgo_sampler_vectorized`` first moves every input
    by a few gradient steps and then draws ``num_samples`` points around the
    result with Gaussian proposals and an accept/reject test. The model is
    trained with ordinary cross-entropy on those points.
    """

    def __init__(self, input_dim: int, num_classes: int, fit_intercept: bool = True, epsilon: float = 0.1,
                 lambda_param: float = 1.0, rgo_inner_lr: float = 0.01, rgo_inner_steps: int = 20,
                 num_samples: int = 10, max_iter: int = 30, learning_rate: float = 0.01,
                 batch_size: int = 64, device: str = "cpu"):
        super().__init__(input_dim, num_classes, fit_intercept)
        # Must come after super().__init__(): the base initializer resets
        # ``self.device`` to CPU, which would discard the requested device.
        self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
        self.epsilon, self.lambda_param = epsilon, lambda_param
        self.rgo_inner_lr, self.rgo_inner_steps = rgo_inner_lr, rgo_inner_steps
        self.num_samples, self.max_iter, self.learning_rate, self.batch_size = num_samples, max_iter, learning_rate, batch_size
        # Upper bound on accept/reject rounds per batch.
        self.rgo_vectorized_max_trials = 100
        self.model = LinearModel(self.input_dim, output_dim=self.num_classes, bias=self.fit_intercept).to(self.device)

    def _get_model_loss_value_batched(self, x_features_batch: torch.Tensor, y_target_batch: torch.Tensor, model_instance: nn.Module) -> torch.Tensor:
        """Return the per-sample (unreduced) cross-entropy of ``model_instance``."""
        return nn.CrossEntropyLoss(reduction='none')(model_instance(x_features_batch), y_target_batch)

    def _rgo_sampler_vectorized(self, x_original_batch: torch.Tensor, y_original_batch: torch.Tensor, current_model_state: nn.Module, num_samples_to_generate: int, epoch: int) -> torch.Tensor:
        """Draw ``num_samples_to_generate`` perturbed points for every input.

        Step 1 runs gradient descent on
        ``-loss(x) / lambda + ||x - x_original||^2`` to obtain a centre
        ``x_opt_star`` per input. Step 2 proposes Gaussian offsets (variance
        ``epsilon``) around that centre and accepts each with probability
        ``exp(min(-f(candidate) + f(x_opt_star) + ||offset||^2 / (2 * epsilon), 10))``
        where ``f(x) = -loss(x) / (lambda * epsilon) + ||x - x_original||^2 / epsilon``.
        Slots still unfilled after ``rgo_vectorized_max_trials`` rounds keep a
        zero offset, i.e. they equal ``x_opt_star``.

        Returns:
            Tensor of shape ``(batch * num_samples_to_generate, input_dim)``
            with the samples of each input stored contiguously.
        """
        batch_size = x_original_batch.size(0)
        x_orig_detached_batch = x_original_batch.detach()
        x_pert_batch = x_orig_detached_batch.clone()
        lr_inner = self.rgo_inner_lr
        # The step count grows with the epoch and is capped at 5.
        inner_steps = int(min(5, self.rgo_inner_steps * (epoch + 1) / self.max_iter))
        for _ in range(inner_steps):
            x_pert_batch.requires_grad_(True)
            per_sample_losses = self._get_model_loss_value_batched(x_pert_batch, y_original_batch, current_model_state)
            # grad_outputs of ones yields each sample's own input gradient.
            per_sample_grads, = torch.autograd.grad(outputs=per_sample_losses, inputs=x_pert_batch, grad_outputs=torch.ones_like(per_sample_losses))
            x_pert_batch = x_pert_batch.detach()
            grad_total = -per_sample_grads / self.lambda_param + 2 * (x_pert_batch - x_orig_detached_batch)
            x_pert_batch -= lr_inner * grad_total
        x_opt_star_batch = x_pert_batch
        var_rgo = self.epsilon
        if var_rgo <= 1e-12:
            # Degenerate proposal: every sample is the centre itself.
            return x_opt_star_batch.repeat_interleave(num_samples_to_generate, dim=0)
        std_rgo = math.sqrt(var_rgo)
        # Nothing below is differentiated, so skip building autograd graphs.
        with torch.no_grad():
            f_model_loss_opt_star = self._get_model_loss_value_batched(x_opt_star_batch, y_original_batch, current_model_state)
            norm_sq_opt_star = torch.sum((x_opt_star_batch - x_orig_detached_batch) ** 2, dim=1)
            f_L_xi_opt_star = (-f_model_loss_opt_star / (self.lambda_param * self.epsilon)) + (norm_sq_opt_star / self.epsilon)
            x_opt_star_3d = x_opt_star_batch.unsqueeze(1)
            x_original_3d = x_orig_detached_batch.unsqueeze(1)
            f_L_xi_opt_star_3d = f_L_xi_opt_star.unsqueeze(1)
            final_accepted_perturbations = torch.zeros((batch_size, num_samples_to_generate, self.input_dim), device=self.device)
            active_flags = torch.ones((batch_size, num_samples_to_generate), dtype=torch.bool, device=self.device)
            # Loop-invariant: labels aligned with the flattened candidates.
            y_repeated = y_original_batch.repeat_interleave(num_samples_to_generate, dim=0)
            for _ in range(self.rgo_vectorized_max_trials):
                if not active_flags.any():
                    break
                pert_proposals = torch.randn_like(final_accepted_perturbations) * std_rgo
                x_candidates = x_opt_star_3d + pert_proposals
                x_candidates_flat = x_candidates.view(-1, self.input_dim)
                f_model_loss_candidates = self._get_model_loss_value_batched(x_candidates_flat, y_repeated, current_model_state).view(batch_size, num_samples_to_generate)
                norm_sq_candidates = torch.sum((x_candidates - x_original_3d) ** 2, dim=2)
                f_L_xi_candidates = (-f_model_loss_candidates / (self.lambda_param * self.epsilon)) + (norm_sq_candidates / self.epsilon)
                diff_cand_opt_norm_sq = torch.sum(pert_proposals**2, dim=2)
                exponent_term3 = diff_cand_opt_norm_sq / (2 * var_rgo)
                acceptance_probs = torch.exp(torch.clamp(-f_L_xi_candidates + f_L_xi_opt_star_3d + exponent_term3, max=10))
                # Only slots that are still empty may be filled this round.
                newly_accepted_mask = (torch.rand_like(acceptance_probs) < acceptance_probs) & active_flags
                final_accepted_perturbations[newly_accepted_mask] = pert_proposals[newly_accepted_mask]
                active_flags[newly_accepted_mask] = False
            return (x_opt_star_3d + final_accepted_perturbations).view(-1, self.input_dim)

    def fit(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Train the model with Adam and return the average loss of each epoch.

        Raises:
            DROError: If the inputs are inconsistent or contain no samples.
        """
        X, y = self._validate_inputs(X, y)
        if X.shape[0] == 0:
            raise DROError("Cannot fit on an empty training set.")
        dataloader = self._create_dataloader(X, y, batch_size=self.batch_size)
        optimizer_theta = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        loss_history = []
        for epoch in range(self.max_iter):
            epoch_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Training RGO Epoch {epoch+1}/{self.max_iter}")
            for x_original_batch, y_original_batch in pbar:
                # Sample with the model in eval mode, then train on the samples.
                self.model.eval()
                x_rgo_batch = self._rgo_sampler_vectorized(x_original_batch, y_original_batch, self.model, self.num_samples, epoch)
                y_repeated_batch = y_original_batch.repeat_interleave(self.num_samples, dim=0)
                self.model.train()
                predictions_logits_batch = self.model(x_rgo_batch)
                loss = nn.CrossEntropyLoss()(predictions_logits_batch, y_repeated_batch)
                optimizer_theta.zero_grad()
                loss.backward()
                optimizer_theta.step()
                batch_value = loss.item()
                epoch_loss += batch_value
                pbar.set_postfix(loss=batch_value)
            avg_epoch_loss = epoch_loss / len(dataloader)
            loss_history.append(avg_epoch_loss)
        return loss_history


# --- 4. Training and evaluation functions ---
def train_saa(model, train_loader, epochs=30, learning_rate=0.001):
    """Train ``model`` with plain cross-entropy and Adam on ``DEVICE``.

    Args:
        model: Module to optimise; it is moved to ``DEVICE`` and put in
            training mode.
        train_loader: Iterable of ``(features, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        A list with the mean batch loss of every epoch.
    """
    model.to(DEVICE).train()
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=learning_rate)
    history = []
    for epoch_idx in range(epochs):
        running_total = 0.0
        progress = tqdm(train_loader, desc=f"Epoch {epoch_idx+1}/{epochs}")
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            adam.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            adam.step()
            # Read the scalar once and reuse it for the sum and the progress bar.
            batch_value = batch_loss.item()
            running_total += batch_value
            progress.set_postfix(loss=batch_value)
        history.append(running_total / len(train_loader))
    return history

def evaluate_model(model_object, test_features_np, test_labels_np, attack_fn, epsilon_list,
                   num_iter=10, alpha_fraction=0.125):
    """Measure clean and attacked accuracy (in percent) of a wrapped model.

    Clean accuracy is computed with the same PyTorch model, device and data
    pipeline as the attacked accuracy, so it does not depend on the
    wrapper's own ``device`` attribute or on scikit-learn being installed.

    Args:
        model_object: Object exposing the network as ``.model``; the network
            is moved to ``DEVICE`` and set to eval mode.
        test_features_np: Test features as a NumPy array (cast to float32).
        test_labels_np: Integer test labels as a NumPy array (cast to int64).
        attack_fn: Callable invoked with keyword arguments ``model``,
            ``features``, ``labels``, ``epsilon``, ``alpha`` and ``num_iter``
            that returns the perturbed feature batch.
        epsilon_list: Attack radii to evaluate.
        num_iter: Number of attack iterations passed to ``attack_fn``.
        alpha_fraction: The attack step size is ``epsilon * alpha_fraction``
            (the default equals ``epsilon / 8``).

    Returns:
        Dict mapping ``0.0`` to the clean accuracy and every epsilon to the
        accuracy under attack.

    Raises:
        ValueError: If the test set contains no samples.
    """
    pytorch_model = model_object.model.to(DEVICE).eval()
    test_loader = DataLoader(
        TensorDataset(
            torch.as_tensor(test_features_np, dtype=torch.float32),
            torch.as_tensor(test_labels_np, dtype=torch.long),
        ),
        batch_size=128,
    )
    if len(test_loader.dataset) == 0:
        raise ValueError("Cannot evaluate on an empty test set.")

    def _count_correct(inputs, labels):
        """Return how many rows of ``inputs`` the model classifies as ``labels``."""
        with torch.no_grad():
            predicted = pytorch_model(inputs.to(DEVICE)).argmax(dim=1)
        return (predicted == labels.to(DEVICE)).sum().item()

    correct, total = 0, 0
    for features, labels in test_loader:
        correct += _count_correct(features, labels)
        total += labels.size(0)
    results = {0.0: 100 * correct / total}
    print(f"Accuracy on clean test set: {results[0.0]:.2f}%")

    for epsilon in epsilon_list:
        correct, total = 0, 0
        for features, labels in tqdm(test_loader, desc=f"Evaluating (epsilon={epsilon:.4f})"):
            perturbed_features = attack_fn(model=pytorch_model, features=features, labels=labels, epsilon=epsilon,
                                           alpha=epsilon * alpha_fraction, num_iter=num_iter)
            correct += _count_correct(perturbed_features, labels)
            total += labels.size(0)
        results[epsilon] = 100 * correct / total
        print(f"Accuracy under PGD attack (epsilon={epsilon:.4f}): {results[epsilon]:.2f}%")
    return results

# --- 5. Plotting functions ---
def plot_training_loss(loss_dict):
    """Plot each model's per-epoch training loss and save the figure.

    Args:
        loss_dict: Mapping from model name to its list of epoch losses.

    The figure is written to ``training_loss_comparison.png``.
    """
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 7))
    for curve_label, epoch_losses in loss_dict.items():
        # Epochs are numbered from 1 on the x axis.
        epoch_numbers = range(1, len(epoch_losses) + 1)
        plt.plot(epoch_numbers, epoch_losses, marker='o', linestyle='--', label=curve_label)
    plt.xlabel("Epoch")
    plt.ylabel("Average Training Loss")
    plt.title("Training Loss Comparison")
    plt.legend(title="Model")
    plt.grid(True, which='both', linestyle='-')
    plt.savefig("training_loss_comparison.png")
    print("\nTraining loss plot saved as 'training_loss_comparison.png'")

def plot_robustness_results(results_dict, perturbation_levels):
    """Plot each model's mis-classification rate per perturbation level.

    Args:
        results_dict: Mapping from model name to a dict of
            ``epsilon -> accuracy in percent``.
        perturbation_levels: X-axis values; one per epsilon key, matched to
            the keys in ascending order.

    The figure is written to ``robustness_comparison_plot.png``.
    """
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 7))
    for curve_label, accuracy_by_eps in results_dict.items():
        # Error rate = 100 - accuracy, ordered by increasing epsilon.
        error_rates = [100.0 - accuracy_by_eps[key] for key in sorted(accuracy_by_eps)]
        plt.plot(perturbation_levels, error_rates, marker='o', linestyle='--', label=curve_label)
    plt.xlabel("Level of Perturbations / Norm of Data")
    plt.ylabel("Mis-classification Rate (%)")
    plt.title("Model Robustness under PGD Attack on CIFAR-10 Features")
    plt.legend(title="Model")
    plt.grid(True, which='both', linestyle='-')
    plt.ylim(bottom=0)
    plt.xticks(perturbation_levels)
    plt.savefig("robustness_comparison_plot.png")
    print("\nRobustness plot saved as 'robustness_comparison_plot.png'")

# --- 6. Main execution flow ---
if __name__ == '__main__':
    # Trains the SAA, baseline Sinkhorn DRO and RGO models on pre-extracted
    # features, plots their training losses, then compares their accuracy
    # under PGD attacks of increasing radius.
    print("--- Loading Pre-extracted Features ---")
    file_path = "cifar10_resnet50_features.pth"
    SEED = 2022
    torch.manual_seed(SEED)

    # Check that the feature file exists
    if not os.path.exists(file_path):
        print(f"Error: Feature file '{file_path}' not found.")
        print("Please run the feature extraction script first.")
    else:
        # NOTE(review): torch.load unpickles the file; only load feature files you trust.
        # map_location keeps the tensors on the CPU even if they were saved
        # from a GPU, which the .numpy() conversions below require.
        data = torch.load(file_path, map_location="cpu")
        train_features = data["train_features"]
        train_labels = data["train_labels"]
        test_features = data["test_features"]
        test_labels = data["test_labels"]

        train_features_np, train_labels_np = train_features.numpy(), train_labels.numpy()
        test_features_np, test_labels_np = test_features.numpy(), test_labels.numpy()

        INPUT_DIM, NUM_CLASSES = train_features_np.shape[1], 10
        DRO_BATCH_SIZE = 128
        lam, eps, itr = 0.8, 1, 10

        all_training_losses = {}

        print("\n--- Training SAA Model ---")
        saa_model_wrapper = BaseLinearDRO(INPUT_DIM, NUM_CLASSES, True)
        saa_model_wrapper.model = LogisticRegression(INPUT_DIM, NUM_CLASSES).to(DEVICE)
        # The wrapper defaults to CPU; align it with the model so that
        # predict()/score() build their inputs on the same device.
        saa_model_wrapper.device = DEVICE
        feature_train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=128, shuffle=True)
        saa_loss_history = train_saa(saa_model_wrapper.model, feature_train_loader, epochs=itr)
        all_training_losses["SAA"] = saa_loss_history

        print("\n--- Training SinkhornBaseDRO (Wang et al. formulation) ---")
        base_dro_model = SinkhornBaseDRO(INPUT_DIM, NUM_CLASSES, device=DEVICE.type, lambda_param=lam, epsilon=eps, max_iter=itr, learning_rate=1e-4, batch_size=DRO_BATCH_SIZE, num_samples=10, noise=0.15)
        base_loss_history = base_dro_model.fit(train_features_np, train_labels_np)
        all_training_losses["Base"] = base_loss_history

        print("\n--- Training SinkhornDROLogisticRGO Model ---")
        rgo_dro_model = SinkhornDROLogisticRGO(INPUT_DIM, NUM_CLASSES, device=DEVICE.type, lambda_param=lam, epsilon=eps, max_iter=itr, learning_rate=1e-3, batch_size=DRO_BATCH_SIZE)
        rgo_loss_history = rgo_dro_model.fit(train_features_np, train_labels_np)
        all_training_losses["RGO"] = rgo_loss_history

        # Plot the training-loss curves
        plot_training_loss(all_training_losses)

        print("\n--- Evaluating Models ---")
        # Attack radii are expressed as fractions of the mean feature norm.
        avg_feature_norm = float(np.mean(np.linalg.norm(test_features_np, axis=1)))
        print(f"Average L2 norm of test set features: {avg_feature_norm:.2f}")
        perturbation_levels = [0.0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
        epsilon_values = [level * avg_feature_norm for level in perturbation_levels]

        all_results = {}
        models_to_evaluate = {
            "SAA": saa_model_wrapper,
            "Base": base_dro_model,
            "RGO": rgo_dro_model
        }

        for name, model_obj in models_to_evaluate.items():
            print(f"\n--- Evaluating {name} ---")
            # The clean (level 0.0) result is produced inside evaluate_model.
            results = evaluate_model(model_obj, test_features_np, test_labels_np, pgd_attack, epsilon_values[1:])
            all_results[name] = results

        plot_robustness_results(all_results, perturbation_levels)