# add clean + nested DRO
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times", "DejaVu Serif"]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF
import numpy as np
import random
from tqdm import tqdm
import os
# --- New Dependencies ---
from sklearn.manifold import TSNE
import matplotlib.patches as mpatches


# ==========================================
# 0. Basic Configuration
# ==========================================
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch. When CUDA
    is available, all GPU generators are seeded as well and cuDNN is switched
    to deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🔥 Running on Device: {device}")


# ==========================================
# 1. Model: ResNet-18 for CIFAR-10
# ==========================================
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN).

    The block output is ``relu(main_path(x) + shortcut(x))``. The shortcut is
    an empty ``nn.Sequential`` (identity) unless the stride or the channel
    count changes, in which case it is a 1x1 convolution followed by
    batch-norm so both branches have matching shapes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; only the first convolution applies the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Projection shortcut is needed whenever the residual shape differs.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        main_path += self.shortcut(x)
        return F.relu(main_path)


class ResNet18_CIFAR(nn.Module):
    """ResNet-18 style classifier built from ``BasicBlock`` stages.

    The stem is a single 3x3, stride-1 convolution and the four stages use
    strides 1, 2, 2, 2, so the spatial size shrinks by a factor of 8 overall.
    The fixed 4x4 average pool therefore reduces 32x32 inputs to a single
    512-dimensional vector per image, which is the width ``self.linear``
    expects.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.linear = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_features(self, x):
        """Return the flattened, average-pooled backbone features.

        For 32x32 inputs the result has shape (batch, 512).
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        return out.view(out.size(0), -1)

    def forward(self, x):
        """Return class logits: the linear head applied to ``get_features``."""
        # Reuse the feature path so logits and extracted features can never
        # diverge if the backbone is edited later.
        return self.linear(self.get_features(x))


# ==========================================
# 2. Dataset: Compound Trap CIFAR-10
# ==========================================
class CompoundTrapCIFAR10(Dataset):
    """CIFAR-10 wrapper that mixes rotated ("hard geo") samples and label noise.

    Behaviour depends on ``mode``:

    * Any mode containing ``'train'`` loads the training split. A fixed
      (seed 42) permutation of the indices selects the first 10% as
      ``geo_indices`` (rotated by 90 degrees when fetched) and the following
      30% as ``noise_indices``.
    * ``'dirty_train'`` additionally replaces the label of every index in
      ``noise_indices`` with a different, randomly drawn class.
    * Other training modes (e.g. ``'oracle_train'``) keep all labels intact
      but still record ``noise_indices``.
    * Modes without ``'train'`` load the test split; ``'hard_test'`` rotates
      every image by 90 degrees, any other test mode leaves images unrotated.

    Each item is ``(normalized_image_tensor, target, idx)``; the index is
    returned so callers can look the sample up in ``geo_indices`` /
    ``noise_indices``.
    """

    def __init__(self, root='./data', mode='dirty_train'):
        self.mode = mode
        train = 'train' in mode
        self.base_dataset = datasets.CIFAR10(root=root, train=train, download=True)
        self.data = self.base_dataset.data
        self.targets = np.array(self.base_dataset.targets)

        # Build the transforms once instead of on every __getitem__ call.
        self.to_pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        self.geo_indices = set()
        self.noise_indices = set()

        if train:
            num_samples = len(self.data)
            indices = np.arange(num_samples)
            rng = np.random.default_rng(42)
            rng.shuffle(indices)

            n_geo = int(num_samples * 0.1)  # 10% Hard Geo (rotated)
            n_noise = int(num_samples * 0.3)  # 30% of the samples get a wrong label in dirty mode

            self.geo_indices = set(indices[:n_geo])
            self.noise_indices = set(indices[n_geo: n_geo + n_noise])

            if mode == 'dirty_train':
                # Report the real amount of noise: every index in
                # noise_indices is flipped, i.e. n_noise samples.
                print(f"💀 [Dataset] Injecting Label Noise into Group C: {n_noise} samples "
                      f"({n_noise / num_samples:.0%} of the training set)")
                for idx in self.noise_indices:
                    original = self.targets[idx]
                    # Redraw until the new label differs from the true one.
                    new_label = rng.integers(0, 10)
                    while new_label == original:
                        new_label = rng.integers(0, 10)
                    self.targets[idx] = new_label
            else:
                print(f"✨ [Dataset] Oracle Mode: No Label Noise injected.")

            print(f"📐 [Dataset] Group B (Hard Geo - Rotate 90°): {n_geo} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_arr, target = self.data[idx], self.targets[idx]
        img = self.to_pil(img_arr)

        if 'train' in self.mode:
            # Only the "hard geo" group is rotated during training.
            if idx in self.geo_indices:
                img = TF.rotate(img, 90)
        elif self.mode == 'hard_test':
            # The hard test set rotates every image.
            img = TF.rotate(img, 90)

        img = self.to_tensor(img)
        img = self.normalize(img)
        # The index is returned so t-SNE can identify which group a sample is in.
        return img, target, idx


# ===========================================
# Global Hyperparameters
# ===========================================
# Budget `rho` copied into both WDRO trainers; it multiplies the dual
# variable lambda in their training objectives.
brho = 1.0
# Target ratio (r_target) handed to DirectEpsilonScheduler by NestedWDROTrainer.
t_r = 0.12
# Lower clamp for the scheduled epsilon; deliberately non-zero.
min_eps_val = 0.001
# Upper clamp for the scheduled epsilon. It can be generous because the
# scheduler's formula determines the value below this cap.
max_epsilon_val = 1.0


# ==========================================
# 3. Dynamic Epsilon Scheduler
# ==========================================
class DirectEpsilonScheduler:
    """Closed-form schedule for the variance-regularisation radius epsilon.

    Each call to :meth:`step` evaluates

        eps_t = (r_target * E[psi])^2 / (2 * Var(psi) + delta)

    and clamps the result to ``[min_eps, max_eps]``.

    Args:
        target_ratio: ``r_target`` in the formula above.
        min_eps: Lower clamp; also the initial value of ``epsilon``.
        max_eps: Upper clamp.
        delta: Small constant added to the denominator.
    """

    def __init__(self, target_ratio=0.4, min_eps=0.1, max_eps=1.0, delta=1e-8):
        self.target_ratio = target_ratio
        self.min_eps = min_eps
        self.max_eps = max_eps
        self.delta = delta
        self.epsilon = min_eps  # Initial value

    def step(self, mean_loss, var_loss):
        """Update and return epsilon from the current batch statistics.

        Args:
            mean_loss: Batch mean of psi.
            var_loss: Batch variance of psi; negative values are treated as 0.

        Returns:
            The new epsilon, clamped to ``[min_eps, max_eps]``. If either
            statistic is NaN or infinite, the previous epsilon is kept and
            returned unchanged.
        """
        # NaN would slip through max()/min() below and poison epsilon for the
        # rest of training, so ignore non-finite statistics altogether.
        if not (np.isfinite(mean_loss) and np.isfinite(var_loss)):
            return self.epsilon

        # Ensure variance is non-negative
        var_loss = max(var_loss, 0.0)

        numerator = (self.target_ratio * mean_loss) ** 2
        denominator = 2 * var_loss + self.delta

        if denominator == 0:
            # Only reachable with delta == 0 and zero variance: use the limit
            # of the ratio instead of raising ZeroDivisionError.
            calculated_eps = self.max_eps if numerator > 0 else self.min_eps
        else:
            calculated_eps = numerator / denominator

        # Apply Bounds [min_eps, max_eps]
        self.epsilon = max(min(calculated_eps, self.max_eps), self.min_eps)
        return self.epsilon


# ==========================================
# 4. Algorithm: Nested WDRO
# ==========================================
class NestedWDROTrainer:
    """WDRO trainer over 90-degree rotations with a variance-based correction.

    For every sample the trainer picks the rotation ``k`` (``torch.rot90`` with
    ``k`` in ``group_actions``) that maximises ``loss - lambda * cost`` and then
    minimises

        lambda * rho + mean(psi) - sqrt(2 * eps * Var(psi)),

    where ``psi`` is the per-sample maximised objective and ``eps`` comes from
    ``DirectEpsilonScheduler``. ``lambda = exp(log_lambda)`` is updated by a
    manual gradient step, independently of ``optimizer``.

    Args:
        model: Network mapping an image batch to logits.
        optimizer: Optimizer over the model parameters.
        device: Device on which per-sample buffers are allocated.
    """

    # log(lambda) is clamped to this range at the start of every step.
    LOG_LAMBDA_MIN = -2.0
    LOG_LAMBDA_MAX = 5.0
    # Step size of the manual gradient update on log(lambda).
    LAMBDA_LR = 0.01
    # Floor applied to the variance inside the square root.
    VAR_FLOOR = 1e-8

    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.rho = brho
        self.log_lambda = nn.Parameter(torch.tensor(0.0).to(device))
        self.criterion_none = nn.CrossEntropyLoss(reduction='none')

        # Use new direct calculation scheduler
        self.eps_scheduler = DirectEpsilonScheduler(
            target_ratio=t_r,
            min_eps=min_eps_val,
            max_eps=max_epsilon_val
        )

        # Discrete Group Definition: number of quarter turns and their cost.
        self.group_actions = [0, 1, 2, 3]
        self.group_costs = [0.0, 1.0, 2.0, 1.0]

    def train_step(self, X, Y):
        """Run one optimisation step on the batch ``(X, Y)``.

        Args:
            X: Image batch; it is rotated over dims 2 and 3.
            Y: Class-index targets for ``X``.

        Returns:
            Tuple ``(total_loss, std_term, epsilon)`` of Python floats.
        """
        self.log_lambda.data = torch.clamp(
            self.log_lambda.data, min=self.LOG_LAMBDA_MIN, max=self.LOG_LAMBDA_MAX
        )
        lam = torch.exp(self.log_lambda)
        batch_size = X.shape[0]

        # Phase 1: Search Worst Transformation (no gradients, eval mode)
        self.model.eval()
        with torch.no_grad():
            objectives = torch.stack(
                [
                    self.criterion_none(self.model(torch.rot90(X, k, [2, 3])), Y) - lam * cost
                    for k, cost in zip(self.group_actions, self.group_costs)
                ],
                dim=1,
            )
        best_indices = objectives.max(dim=1).indices

        # Phase 2: Compute Gradients on each sample's selected rotation
        self.model.train()
        self.optimizer.zero_grad()
        final_losses = torch.zeros(batch_size, device=self.device)
        final_costs = torch.zeros(batch_size, device=self.device)

        for i, k in enumerate(self.group_actions):
            mask = (best_indices == i)
            if mask.any():
                X_rot_sub = torch.rot90(X[mask], k, [2, 3])
                pred_sub = self.model(X_rot_sub)
                final_losses[mask] = self.criterion_none(pred_sub, Y[mask])
                final_costs[mask] = self.group_costs[i]

        # Robust Surrogate Loss
        psi = final_losses - lam * final_costs
        mean_psi = psi.mean()
        if batch_size > 1:
            var_psi = psi.var(unbiased=True)
        else:
            # The unbiased variance of a single sample is NaN, which would
            # turn the loss (and then the weights) into NaN. Treat it as 0.
            var_psi = torch.zeros((), device=psi.device, dtype=psi.dtype)

        # --- Dynamic Epsilon Calculation (Based on Formula) ---
        current_epsilon = self.eps_scheduler.step(mean_psi.item(), var_psi.item())

        # Variance Regularization Term
        std_term = torch.sqrt(2 * current_epsilon * torch.clamp(var_psi, min=self.VAR_FLOOR))

        # Final Objective
        total_loss = lam * self.rho + mean_psi - std_term

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # Manual gradient step on log(lambda); it is not part of the optimizer.
        with torch.no_grad():
            if self.log_lambda.grad is not None:
                self.log_lambda.data -= self.LAMBDA_LR * self.log_lambda.grad
                self.log_lambda.grad.zero_()

        return total_loss.item(), std_term.item(), current_epsilon


# ==========================================
# 5. Algorithm: Standard OT-DRO
# ==========================================
class StandardWDROTrainer:
    """Baseline WDRO trainer over 90-degree rotations (no variance term).

    Every sample is rotated by the number of quarter turns that maximises
    ``loss - lambda * cost``; the model then minimises
    ``lambda * rho + mean(psi)``, with ``psi`` the maximised per-sample
    objective. ``lambda = exp(log_lambda)`` is updated by a manual gradient
    step that is separate from ``optimizer``.
    """

    def __init__(self, model, optimizer, device):
        self.model, self.optimizer, self.device = model, optimizer, device
        self.rho = brho
        self.log_lambda = nn.Parameter(torch.tensor(0.0).to(device))
        self.criterion_none = nn.CrossEntropyLoss(reduction='none')
        # Quarter-turn counts and the transport cost charged for each.
        self.group_actions = [0, 1, 2, 3]
        self.group_costs = [0.0, 1.0, 2.0, 1.0]

    def train_step(self, X, Y):
        """Perform one update on batch ``(X, Y)``.

        Returns:
            ``(total_loss, 0.0, 0.0)``; the two zeros fill the slots that the
            nested trainer uses for its std term and epsilon.
        """
        self.log_lambda.data = torch.clamp(self.log_lambda.data, min=-2.0, max=5.0)
        dual = torch.exp(self.log_lambda)
        n_samples = X.shape[0]

        # Phase 1: score every rotation without tracking gradients.
        self.model.eval()
        with torch.no_grad():
            scored = []
            for pos, turns in enumerate(self.group_actions):
                logits = self.model(torch.rot90(X, turns, [2, 3]))
                penalised = self.criterion_none(logits, Y) - dual * self.group_costs[pos]
                scored.append(penalised.unsqueeze(1))
        worst_action = torch.max(torch.cat(scored, dim=1), dim=1)[1]

        # Phase 2: recompute the loss of the chosen rotation with gradients.
        self.model.train()
        self.optimizer.zero_grad()
        final_losses = torch.zeros(n_samples, device=self.device)
        final_costs = torch.zeros(n_samples, device=self.device)

        for pos, turns in enumerate(self.group_actions):
            chosen = worst_action == pos
            if not chosen.any():
                continue
            rotated = torch.rot90(X[chosen], turns, [2, 3])
            final_losses[chosen] = self.criterion_none(self.model(rotated), Y[chosen])
            final_costs[chosen] = self.group_costs[pos]

        psi = final_losses - dual * final_costs
        mean_psi = psi.mean()

        # No Variance Regularization
        total_loss = dual * self.rho + mean_psi

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # log(lambda) is not owned by the optimizer, so step it by hand.
        with torch.no_grad():
            lambda_grad = self.log_lambda.grad
            if lambda_grad is not None:
                self.log_lambda.data -= 0.01 * lambda_grad
                lambda_grad.zero_()

        return total_loss.item(), 0.0, 0.0


# ==========================================
# 6. Visualization
# ==========================================
def plot_metrics(history, save_dir, method_name):
    """Plot per-epoch training curves for one experiment and save them as PNG.

    Three stacked panels are always drawn (total loss, clean test accuracy,
    hard/rotated test accuracy). A fourth panel with the dynamic epsilon is
    added when ``history['eps']`` is present, non-empty and sums to a
    positive value.

    Args:
        history: Dict with the lists ``'loss'``, ``'test_clean_acc'`` and
            ``'test_hard_acc'`` (one entry per epoch) and optionally ``'eps'``.
        save_dir: Output directory; created if missing.
        method_name: Used as the figure title and as the file-name prefix
            (``<method_name>_metrics.png``).
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    epochs = range(1, len(history['loss']) + 1)

    has_eps = 'eps' in history and len(history['eps']) > 0 and sum(history['eps']) > 0

    n_plots = 4 if has_eps else 3
    fig, axs = plt.subplots(n_plots, 1, figsize=(10, 4 * n_plots), sharex=True)
    ax1, ax2, ax3 = axs[:3]
    ax4 = axs[3] if has_eps else None

    ax1.plot(epochs, history['loss'], 'b-', label='Total Loss')
    ax1.set_ylabel('Loss', fontsize=20)
    ax1.set_title(f'{method_name}')
    ax1.legend()

    ax2.plot(epochs, history['test_clean_acc'], 'g-', label='Clean Test Acc')
    ax2.set_ylabel('Acc (%)', fontsize=20)
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    ax3.plot(epochs, history['test_hard_acc'], 'r-', label='Hard (Rotated) Acc')
    ax3.set_ylabel('Acc (%)', fontsize=20)
    ax3.grid(True, alpha=0.3)
    ax3.legend()

    # Compare against None explicitly rather than relying on the truthiness
    # of an Axes object. The x-label goes on whichever panel is at the bottom.
    if ax4 is not None:
        ax4.plot(epochs, history['eps'], 'm--', label='Dynamic Epsilon')
        ax4.set_ylabel('Epsilon')
        ax4.set_xlabel('Epoch')
        ax4.grid(True, alpha=0.3)
        ax4.legend()
    else:
        ax3.set_xlabel('Epoch')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{method_name}_metrics.png'))
    plt.close()


def plot_final_comparison(results, save_dir='results'):
    """Draw a grouped bar chart of best clean / best hard accuracy per method.

    Args:
        results: Mapping ``method name -> dict`` containing at least the keys
            ``'best_clean'`` and ``'best_hard'`` (accuracies in percent).
            Bars follow the mapping's iteration order.
        save_dir: Directory that receives ``final_best_acc_chart.pdf``;
            created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    methods = list(results.keys())
    # Extract BEST accuracy from results dict
    clean_accs = [results[m]['best_clean'] for m in methods]
    hard_accs = [results[m]['best_hard'] for m in methods]

    x = np.arange(len(methods))
    width = 0.35

    color_clean = '#467879'
    color_hard = '#C66447'

    fig, ax = plt.subplots(figsize=(10, 7))
    rects1 = ax.bar(x - width / 2, clean_accs, width, label='Best Clean Test', color=color_clean, edgecolor='black',
                    alpha=0.9)
    rects2 = ax.bar(x + width / 2, hard_accs, width, label='Best Hard (Rot) Test', color=color_hard, edgecolor='black',
                    alpha=0.9)

    ax.set_ylabel('Best Accuracy (%)', fontsize=20)
    ax.set_xticks(x)
    # Derive the tick labels from the actual method names (line break after
    # each '+') so they always match the bars in number and order. For the
    # keys 'Clean+ERM', 'Dirty+ERM', 'Dirty+NestedDRO', 'Clean+NestedDRO'
    # this yields exactly the labels that used to be hard-coded.
    tick_labels = [name.replace('+', '+\n') for name in methods]
    ax.set_xticklabels(tick_labels, fontsize=20)
    ax.legend(fontsize=16)
    ax.set_ylim(0, 100)
    for spine in ax.spines.values():
        spine.set_linewidth(2.0)

    def autolabel(rects):
        """Write each bar's height (one decimal, in percent) above the bar."""
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f'{height:.1f}%',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontweight='bold')

    autolabel(rects1)
    autolabel(rects2)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'final_best_acc_chart.pdf'), format='pdf', bbox_inches='tight')
    plt.close()


def plot_all_methods_evolution(all_results, save_dir='results'):
    """Plot clean and hard accuracy versus epoch for every method, side by side.

    Args:
        all_results: Mapping ``method name -> dict`` with the per-epoch lists
            ``'history_clean'`` and ``'history_hard'``.
        save_dir: Directory that receives ``global_evolution_comparison.png``;
            created if missing.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Colour palette; it is cycled if there are more methods than colours.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    # Both panels share the same axis decoration and differ only in title.
    for axis, title in ((ax1, 'Clean Accuracy Evolution'),
                        (ax2, 'Hard (Rotated) Accuracy Evolution')):
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel('Accuracy (%)')
        axis.grid(True, alpha=0.3)

    for i, (method_name, data) in enumerate(all_results.items()):
        color = colors[i % len(colors)]
        epochs = range(1, len(data['history_clean']) + 1)

        ax1.plot(epochs, data['history_clean'], label=method_name, color=color, linewidth=2)
        ax2.plot(epochs, data['history_hard'], label=method_name, color=color, linewidth=2)

    ax1.legend()
    ax2.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'global_evolution_comparison.png'), dpi=300)
    plt.close()


# ==========================================
# t-SNE Visualization Function
# ==========================================
def run_tsne_visualization(model, dataset, save_dir, exp_name, max_samples=2000):
    """
    Run t-SNE on backbone features of a random subset and save a scatter plot.

    Up to ``max_samples`` dataset indices are drawn uniformly at random (no
    stratification by class or group) and coloured by group:
    - Blue: Normal
    - Green: Hard Geo (index in ``dataset.geo_indices``)
    - Red: Label Noise (index in ``dataset.noise_indices``)

    Args:
        model: Network exposing ``get_features(batch)``.
        dataset: Dataset yielding ``(image, target, idx)`` and exposing
            ``geo_indices`` / ``noise_indices``.
        save_dir: Directory that receives ``<exp_name>_tsne.png``; created if
            missing.
        exp_name: Used in the plot title and the output file name.
        max_samples: Maximum number of samples embedded.
    """
    print(f"🎨 [Viz] Generating t-SNE for {exp_name}...")
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    # Run inference on whatever device the model lives on; fall back to the
    # module-level default for a parameter-free model.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device

    # Draw the random subset first and load only those samples. Sorting keeps
    # the ascending-index order a non-shuffled pass over the dataset would give.
    all_indices = np.arange(len(dataset))
    np.random.shuffle(all_indices)
    selected = sorted(all_indices[:max_samples].tolist())

    subset = torch.utils.data.Subset(dataset, selected)
    loader = DataLoader(subset, batch_size=128, shuffle=False, num_workers=2)

    features = []
    groups = []  # 0: Normal, 1: Geo, 2: Noise

    with torch.no_grad():
        for imgs, _, indices in loader:
            # Extract features
            feats = model.get_features(imgs.to(model_device))
            features.append(feats.cpu().numpy())

            # Determine groups
            for idx in indices.tolist():
                if idx in dataset.geo_indices:
                    groups.append(1)  # Green
                elif idx in dataset.noise_indices:
                    groups.append(2)  # Red
                else:
                    groups.append(0)  # Blue

    if len(features) == 0:
        print("Warning: No features collected for t-SNE.")
        return

    X = np.concatenate(features, axis=0)
    y = np.array(groups)

    if X.shape[0] < 2:
        print("Warning: Not enough samples for t-SNE.")
        return

    print(f"   Running t-SNE on {X.shape[0]} samples...")
    # scikit-learn requires perplexity < n_samples, so shrink it for tiny subsets.
    perplexity = min(30, X.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, init='pca', learning_rate='auto')
    X_embedded = tsne.fit_transform(X)

    # Plotting
    plt.figure(figsize=(10, 8))

    # Plot Normal
    idx_0 = (y == 0)
    plt.scatter(X_embedded[idx_0, 0], X_embedded[idx_0, 1], c='#1f77b4', alpha=0.5, s=20, label='Normal')

    # Plot Geo (Valid Hard)
    idx_1 = (y == 1)
    plt.scatter(X_embedded[idx_1, 0], X_embedded[idx_1, 1], c='#2ca02c', alpha=0.8, s=30, marker='^',
                label='Hard Geo (Rotated)')

    # Plot Noise (Invalid)
    idx_2 = (y == 2)
    plt.scatter(X_embedded[idx_2, 0], X_embedded[idx_2, 1], c='#d62728', alpha=0.8, s=30, marker='x',
                label='Label Noise')

    plt.title(f't-SNE Feature Visualization: {exp_name}', fontsize=16)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.2)
    plt.tight_layout()

    save_path = os.path.join(save_dir, f'{exp_name}_tsne.png')
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"✅ t-SNE saved to {save_path}")


# ==========================================
# 7. Experiment Workflow
# ==========================================
def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    ``loader`` must yield ``(images, targets, idx)`` triples; ``idx`` is ignored.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += y.size(0)
    return 100 * correct / total


def run_experiment(exp_name, train_loader, test_clean_loader, test_hard_loader, method='erm', epochs=100):
    """Train a fresh ResNet-18 with the chosen method and evaluate every epoch.

    The seed is reset to 42 at the start, so every experiment starts from the
    same initialisation. Per-epoch curves are saved under
    ``results/<sanitized exp_name>/``.

    Args:
        exp_name: Human-readable name; sanitised to build the output directory.
        train_loader: Loader yielding ``(images, targets, idx)`` training batches.
        test_clean_loader: Loader for the clean test set.
        test_hard_loader: Loader for the hard (rotated) test set.
        method: ``'nested_dro'`` or ``'standard_dro'`` select the respective
            trainer; any other value trains with plain cross-entropy (ERM).
        epochs: Number of training epochs (also the cosine schedule length).

    Returns:
        Dict with ``'best_clean'`` / ``'best_hard'`` (best accuracy over
        epochs), ``'history_clean'`` / ``'history_hard'`` (per-epoch
        accuracies) and ``'model'`` (the trained network).
    """
    print(f"\n🚀 [Start] {exp_name} | Method: {method}")
    set_seed(42)

    sanitized_name = exp_name.replace(" ", "_").replace("(", "").replace(")", "").replace("+", "_")
    save_dir = os.path.join("results", sanitized_name)
    os.makedirs(save_dir, exist_ok=True)

    model = ResNet18_CIFAR(num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    trainer = None
    if method == 'nested_dro':
        trainer = NestedWDROTrainer(model, optimizer, device)
    elif method == 'standard_dro':
        trainer = StandardWDROTrainer(model, optimizer, device)

    metrics = {'loss': [], 'std_term': [], 'eps': [], 'test_clean_acc': [], 'test_hard_acc': []}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_std = 0.0
        # Collect every batch's epsilon so the epoch average can be recorded.
        batch_eps_list = []

        pbar = tqdm(train_loader, desc=f"Ep {epoch + 1}/{epochs}", leave=False, ncols=110)

        for x, y, _ in pbar:  # the sample index is not needed for training
            x, y = x.to(device), y.to(device)

            if method == 'nested_dro':
                loss, std, eps = trainer.train_step(x, y)
                epoch_std += std
                batch_eps_list.append(eps)
                # Show the current epsilon live in the progress bar.
                pbar.set_postfix({'L': f"{loss:.3f}", 'Std': f"{std:.3f}", 'Eps': f"{eps:.4f}"})
                epoch_loss += loss
            elif method == 'standard_dro':
                loss, _, _ = trainer.train_step(x, y)
                pbar.set_postfix({'L': f"{loss:.3f}", 'Type': 'StdDRO'})
                epoch_loss += loss
            else:
                # Plain ERM step.
                optimizer.zero_grad()
                out = model(x)
                loss_tensor = criterion(out, y)
                loss_tensor.backward()
                optimizer.step()
                epoch_loss += loss_tensor.item()
                pbar.set_postfix({'L': f"{loss_tensor.item():.3f}"})

        scheduler.step()

        # Record Metrics
        avg_loss = epoch_loss / len(train_loader)
        metrics['loss'].append(avg_loss)

        if method == 'nested_dro':
            metrics['std_term'].append(epoch_std / len(train_loader))
            avg_eps = sum(batch_eps_list) / len(batch_eps_list)
            metrics['eps'].append(avg_eps)

        # Eval on both test sets with the shared helper.
        acc_clean = _evaluate_accuracy(model, test_clean_loader)
        metrics['test_clean_acc'].append(acc_clean)

        acc_hard = _evaluate_accuracy(model, test_hard_loader)
        metrics['test_hard_acc'].append(acc_hard)

        print(f"  Ep {epoch + 1} | Loss: {avg_loss:.3f} | Clean: {acc_clean:.2f}% | Hard: {acc_hard:.2f}%")

    plot_metrics(metrics, save_dir, sanitized_name)
    best_clean = max(metrics['test_clean_acc'])
    best_hard = max(metrics['test_hard_acc'])
    last_hard = metrics['test_hard_acc'][-1]

    print(f"✅ Finished {exp_name}. Peak Hard: {best_hard:.2f}%, Last Hard: {last_hard:.2f}%")

    # Return dictionary with full history and best stats
    return {
        'best_clean': best_clean,
        'best_hard': best_hard,
        'history_clean': metrics['test_clean_acc'],
        'history_hard': metrics['test_hard_acc'],
        'model': model  # Return model for subsequent t-SNE
    }


# ==========================================
# 8. Main
# ==========================================
# Script entry point: build the four datasets, run the experiments in a fixed
# order, then print and plot the comparison. Each result dict keeps the trained
# model so that t-SNE can be drawn right after the corresponding run.
if __name__ == '__main__':
    BATCH_SIZE = 128
    EPOCHS = 80

    print("Preparing Data...")
    # Training split with label noise injected (plus the rotated group).
    ds_dirty_train = CompoundTrapCIFAR10(mode='dirty_train')
    dl_dirty_train = DataLoader(ds_dirty_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    # Training split with the rotated group but without label noise.
    ds_oracle_train = CompoundTrapCIFAR10(mode='oracle_train')
    dl_oracle_train = DataLoader(ds_oracle_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    # Test split, images unrotated.
    ds_clean_test = CompoundTrapCIFAR10(mode='clean_test')
    dl_clean_test = DataLoader(ds_clean_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Test split, every image rotated by 90 degrees.
    ds_hard_test = CompoundTrapCIFAR10(mode='hard_test')
    dl_hard_test = DataLoader(ds_hard_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Insertion order of this dict is the order used in the report and plots.
    final_results = {}

    # 1. Clean + ERM
    res_clean_erm = run_experiment(
        "Clean+ERM", dl_oracle_train, dl_clean_test, dl_hard_test, method='erm', epochs=EPOCHS
    )
    final_results['Clean+ERM'] = res_clean_erm
    # No t-SNE here: the oracle labels contain no noise, so the plot would add little.

    # 2. Dirty + ERM (baseline; also used for the t-SNE comparison)
    res_dirty_erm = run_experiment(
        "Dirty+ERM", dl_dirty_train, dl_clean_test, dl_hard_test, method='erm', epochs=EPOCHS
    )
    final_results['Dirty+ERM'] = res_dirty_erm
    # t-SNE of the Dirty+ERM features, saved into that experiment's result directory.
    run_tsne_visualization(res_dirty_erm['model'], ds_dirty_train, "results/Dirty_ERM", "Dirty_ERM")

    # 3. Dirty + NestedDRO (the proposed method)
    res_nested = run_experiment(
        "Dirty+NestedDRO", dl_dirty_train, dl_clean_test, dl_hard_test, method='nested_dro', epochs=EPOCHS
    )
    final_results['Dirty+NestedDRO'] = res_nested
    # t-SNE of the Nested DRO features, saved into that experiment's result directory.
    run_tsne_visualization(res_nested['model'], ds_dirty_train, "results/Dirty_NestedDRO", "Dirty_NestedDRO")

    # 4. Clean + NestedDRO (sanity check that the method does no harm without label noise)
    print("\n🧐 [Experiment] Verifying Clean + NestedDRO (Sanity Check)...")
    res_clean_nested = run_experiment(
        "Clean+NestedDRO", dl_oracle_train, dl_clean_test, dl_hard_test, method='nested_dro', epochs=EPOCHS
    )
    final_results['Clean+NestedDRO'] = res_clean_nested

    # 5. Standard DRO (baseline) - currently disabled; uncomment to include it.
    # res_std = run_experiment(
    #    "Dirty+StandardDRO", dl_dirty_train, dl_clean_test, dl_hard_test, method='standard_dro', epochs=EPOCHS
    # )
    # final_results['Dirty+StandardDRO'] = res_std
    # run_tsne_visualization(res_std['model'], ds_dirty_train, "results/Dirty_StandardDRO", "Dirty_StandardDRO")

    print("\n" + "=" * 70)
    print("📊 FINAL REPORT (Best Epoch Accuracy)")
    print("=" * 70)
    for name, res in final_results.items():
        print(f"{name:<20} | Best Clean: {res['best_clean']:.2f}% | Best Hard: {res['best_hard']:.2f}%")

    # Bar chart of the best accuracy per method.
    plot_final_comparison(final_results, save_dir='results')

    # Accuracy-versus-epoch curves for all methods.
    plot_all_methods_evolution(final_results, save_dir='results')