# Nested Wasserstein DRO on Compound Trap CIFAR-10

This repository contains a PyTorch implementation of **Nested Wasserstein Distributionally Robust Optimization (Nested WDRO)**. It is designed to evaluate model robustness against **Geometric Shifts** and **Label Noise** using a custom "Compound Trap" variant of the CIFAR-10 dataset.

## Overview

The project compares standard Empirical Risk Minimization (ERM) against the proposed Nested WDRO method. Nested WDRO combines a **Dynamic Epsilon Scheduler** with **Variance Regularization**, with the goal of handling data corruption without sacrificing clean accuracy.

### Key Features
* **Compound Trap Dataset**: A custom CIFAR-10 loader that simultaneously introduces:
    * **Geometric Traps**: 10% of samples are rotated by 90° (semantically valid but distributionally shifted).
    * **Label Noise**: 20% of samples have randomized labels (invalid data).
* **Nested WDRO Algorithm**: Implements a robust surrogate loss that accounts for variance in the loss distribution.
* **Dynamic Epsilon**: Automatically adjusts the robustness radius ($\epsilon$) based on the signal-to-noise ratio of the loss.
* **Visualizations**: Includes automated generation of t-SNE plots, accuracy evolution curves, and comparative bar charts.

## Dependencies

```bash
pip install torch torchvision numpy matplotlib scikit-learn tqdm
```

* **Python**: >= 3.6
* **PyTorch**: Compatible with your CUDA version (or CPU).
* **Scikit-Learn**: Required for t-SNE visualization.
* **Matplotlib**: For plotting metrics and embeddings.

## Usage

To run the full suite of experiments, simply execute the main script:

```bash
python cifar.py
```

The script is self-contained and will automatically:

1. Download CIFAR-10 (if not present).
2. Run experiments sequentially:
    * **Clean+ERM**: Baseline on clean data.
    * **Dirty+ERM**: Baseline on noisy/rotated data.
    * **Dirty+NestedDRO**: The proposed robust method on noisy/rotated data.
    * **Clean+NestedDRO**: Sanity check on clean data.
3. Save all results and plots to the `results_v6/` directory.
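
The four runs listed above can be pictured as a small experiment grid. The following is a minimal sketch, not the layout used in `cifar.py`: the `Experiment` tuple, the `method` strings, and the `output_paths` helper are hypothetical names introduced for illustration. Only the experiment names, the dataset mode names, and the `{ExpName}_metrics.png` / `{ExpName}_tsne.png` file patterns come from this README.

```python
import os
from typing import NamedTuple


class Experiment(NamedTuple):
    """Hypothetical description of one run (not taken from cifar.py)."""
    name: str        # experiment name as listed in this README
    train_mode: str  # dataset mode used for training
    method: str      # illustrative tag: "erm" or "nested_dro"


# The four runs, in the order in which the README says they execute.
EXPERIMENTS = [
    Experiment("Clean+ERM", "oracle_train", "erm"),
    Experiment("Dirty+ERM", "dirty_train", "erm"),
    Experiment("Dirty+NestedDRO", "dirty_train", "nested_dro"),
    Experiment("Clean+NestedDRO", "oracle_train", "nested_dro"),
]


def output_paths(exp, out_dir):
    """Build the per-experiment artifact paths from the README's file patterns."""
    return {
        "metrics": os.path.join(out_dir, exp.name + "_metrics.png"),
        "tsne": os.path.join(out_dir, exp.name + "_tsne.png"),
    }


if __name__ == "__main__":
    for exp in EXPERIMENTS:
        print(exp.name, exp.train_mode, exp.method)
```

Keeping the grid as plain data makes it easy to add or reorder runs without touching the training loop.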

## Code Structure

### 1. Data Loading (`CompoundTrapCIFAR10`)

* **Modes**:
    * `oracle_train`: Clean data (no noise, no rotations).
    * `dirty_train`: Injects 20% label noise and tracks 10% geometric shifts.
    * `hard_test`: All test images are rotated 90° to evaluate robustness.
    * `clean_test`: Standard CIFAR-10 test set.
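
To make the modes concrete, here is a minimal NumPy sketch of how such a corruption step might look. It is not the `CompoundTrapCIFAR10` implementation. The function name `make_compound_trap` is hypothetical, and so are three assumptions it makes: the rotated and noisy subsets are disjoint, the rotation is counter-clockwise, and noisy labels are drawn uniformly, so they can coincide with the true label.

```python
import numpy as np


def make_compound_trap(images, labels, mode, num_classes=10,
                       rot_frac=0.10, noise_frac=0.20, seed=0):
    """Return corrupted copies plus boolean masks marking each trap.

    images: array of shape (N, H, W, C) with square images (H == W).
    labels: integer array of shape (N,).
    """
    rng = np.random.default_rng(seed)
    images = images.copy()
    labels = labels.copy()
    n = len(labels)
    is_rotated = np.zeros(n, dtype=bool)
    is_noisy = np.zeros(n, dtype=bool)

    if mode == "dirty_train":
        perm = rng.permutation(n)
        n_rot = int(round(rot_frac * n))
        n_noise = int(round(noise_frac * n))
        # Assumption: the two corrupted subsets do not overlap.
        is_rotated[perm[:n_rot]] = True
        is_noisy[perm[n_rot:n_rot + n_noise]] = True
        # Uniform random labels; a draw may happen to equal the true label.
        labels[is_noisy] = rng.integers(0, num_classes, size=n_noise)
    elif mode == "hard_test":
        is_rotated[:] = True
    elif mode not in ("oracle_train", "clean_test"):
        raise ValueError("unknown mode: " + repr(mode))

    # Rotate the selected images by 90 degrees in the (H, W) plane.
    images[is_rotated] = np.rot90(images[is_rotated], k=1, axes=(1, 2))
    return images, labels, is_rotated, is_noisy
```

Returning the two masks alongside the data is what allows the t-SNE plots to color rotated and mislabeled samples separately.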



### 2. Model (`ResNet18_CIFAR`)

A modified ResNet-18 architecture optimized for CIFAR-10 image sizes, including a feature extractor method (`get_features`) used for t-SNE analysis.

### 3. Optimization Algorithms

* **ERM**: Standard Cross-Entropy Loss.
* **Nested WDRO (`NestedWDROTrainer`)**:
    * **Step 1**: Finds the worst-case geometric transformation (group action).
    * **Step 2**: Optimizes the robust objective:
      $$ \mathcal{L} = \lambda \rho + \mathbb{E}[\psi] - \sqrt{2\epsilon \cdot \text{Var}(\psi)} $$
    * **Step 3**: Dynamically updates $\epsilon$ using the `DirectEpsilonScheduler`.
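
The Step 2 objective can be evaluated numerically as in the sketch below. The function `nested_wdro_objective` is a hypothetical name, and it treats $\psi$ as a given list of per-sample values; how `NestedWDROTrainer` actually computes $\psi$ is not shown here. Using the population variance (dividing by $N$) is also an assumption. In training, the same arithmetic would be done on tensors so that gradients can flow.

```python
import math


def nested_wdro_objective(psi, lam, rho, eps):
    """Evaluate lam * rho + mean(psi) - sqrt(2 * eps * var(psi)).

    psi: non-empty sequence of per-sample surrogate losses (floats).
    lam, rho, eps: the scalars lambda, rho and epsilon from the formula.
    """
    n = len(psi)
    mean_psi = sum(psi) / n
    # Population variance (divide by n); an assumption of this sketch.
    var_psi = sum((p - mean_psi) ** 2 for p in psi) / n
    return lam * rho + mean_psi - math.sqrt(2.0 * eps * var_psi)


if __name__ == "__main__":
    # mean = 2.5, variance = 1.25, sqrt(2 * 0.1 * 1.25) = 0.5
    print(nested_wdro_objective([1.0, 2.0, 3.0, 4.0], lam=0.5, rho=1.0, eps=0.1))
```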



## Outputs & Visualization

All artifacts are generated in the `results/` directory.

### 1. Performance Metrics

* **`{ExpName}_metrics.png`**: Plots showing Total Loss, Clean Accuracy, Hard (Rotated) Accuracy, and Epsilon evolution over epochs.
* **`final_best_acc_chart.pdf`**: A summary bar chart comparing the best accuracies across all methods.
* **`global_evolution_comparison.png`**: Combined curves comparing all methods side-by-side.

### 2. t-SNE Feature Embeddings

The code generates t-SNE scatter plots (`{ExpName}_tsne.png`) to visualize how the model separates data:

* **Blue**: Normal Samples (Clean)
* **Green**: Hard Geometric Samples (Valid, Rotated 90°)
* **Red**: Label Noise (Invalid, should be isolated)

## Configuration

Global hyperparameters can be tweaked at the top of `cifar.py`:

```python
brho = 1.0           # Robustness radius constraint
t_r = 0.12           # Target ratio for epsilon scheduler
min_eps_val = 0.001  # Minimum epsilon floor
```
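
This README does not spell out the update rule inside `DirectEpsilonScheduler`. One plausible reading of `t_r` and `min_eps_val` is sketched below: choose $\epsilon$ so that the variance term $\sqrt{2\epsilon \cdot \text{Var}(\psi)}$ equals a fraction `t_r` of the mean loss, then clamp it to the floor. The function name `direct_epsilon` and the rule itself are assumptions made for illustration, not the repository's code.

```python
import math


def direct_epsilon(mean_loss, var_loss, t_r=0.12, min_eps_val=0.001):
    """Solve sqrt(2 * eps * var_loss) = t_r * mean_loss for eps (assumed rule).

    Noisier losses (larger variance) give a smaller eps, and the result
    never drops below min_eps_val.
    """
    if var_loss <= 0.0:
        # No spread in the loss: fall back to the floor.
        return min_eps_val
    eps = (t_r * mean_loss) ** 2 / (2.0 * var_loss)
    return max(eps, min_eps_val)


if __name__ == "__main__":
    eps = direct_epsilon(mean_loss=2.0, var_loss=0.5)
    # Ratio of the variance term to the mean loss; equals t_r when not clamped.
    print(eps, math.sqrt(2.0 * eps * 0.5) / 2.0)
```

Under this rule, lowering `t_r` shrinks the robustness radius for the same loss statistics, while `min_eps_val` keeps some regularization active when the loss variance is very large.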

Batch size and epoch count are set in the `__main__` block:

```python
BATCH_SIZE = 128
EPOCHS = 80
```

## License

This project is provided for educational and research purposes.
