import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools
import matplotlib.pyplot as plt

# Render all figures with a serif family; the serif list names "Times"
# ahead of "DejaVu Serif".
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times", "DejaVu Serif"],
})


# ==========================================
# 0. High-Dimensional Dataset Definition
# ==========================================
class KillerDatasetHighDim:
    """Synthetic contaminated linear-regression dataset in ``dim`` dimensions.

    Three groups are drawn from NumPy's global RNG (seed it beforehand for
    reproducibility) and stacked in this order:

    * central points (the remainder, ~70%): ``x ~ U(-2, 2)^dim``,
      ``y = x @ w_true + b_true + N(0, 1)``;
    * valid high-leverage points (10%): ``x`` clustered (std 0.2) around
      ``10 * w_true / ||w_true||``, same linear model, noise std 0.1;
    * outliers (20%): ``x ~ N(0, I)``, generated with the sign-flipped
      weight ``-w_true``, an extra offset of 10 and noise std 0.5.

    Args:
        n_samples: Total number of rows to generate.
        dim: Feature dimension.

    Attributes:
        w_true: True weight vector, random direction with norm 2.0.
        b_true: True intercept (1.0).
        x_norm, y_norm, x_lev, y_lev, x_out, y_out: Per-group NumPy arrays.
        X, Y: Stacked float32 arrays of shape (n_samples, dim) and
            (n_samples, 1).
        X_t, Y_t: Torch tensors created from ``X`` and ``Y`` with
            ``torch.from_numpy``.
    """

    def __init__(self, n_samples=500, dim=10):
        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        # === Define High-Dimensional True Weights ===
        # w_true direction is random, but magnitude is fixed at 2.0
        raw_w = np.random.randn(dim)
        self.w_true = (raw_w / np.linalg.norm(raw_w)) * 2.0

        # === Group sizes ===
        # Truncating all three fractions independently can drop rows (e.g.
        # n_samples=505 used to yield 504 rows). The contaminated groups keep
        # their truncated sizes and the central group takes the remainder,
        # so the total is always exactly n_samples.
        n_lev = int(0.1 * n_samples)
        n_out = int(0.2 * n_samples)
        n_center = n_samples - n_lev - n_out

        # === 1. Basic Background Noise (Center Normal) ~70% ===
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        noise_norm = np.random.normal(0, 1.0, n_center)
        self.y_norm = self.x_norm @ self.w_true + self.b_true + noise_norm

        # === 2. Distant High Leverage Points (Valid High Leverage) 10% ===
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + np.random.normal(0, 0.2, (n_lev, dim))
        noise_lev = np.random.normal(0, 0.1, n_lev)  # Low Noise
        self.y_lev = self.x_lev @ self.w_true + self.b_true + noise_lev

        # === 3. Proximal Outliers (Conflated Outliers) 20% ===
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_outlier = -1.0 * self.w_true  # Adversarial Weight
        noise_out = np.random.normal(0, 0.5, n_out)
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + noise_out  # Huge Bias

        # === Data Integration ===
        self.X = np.concatenate([self.x_norm, self.x_lev, self.x_out]).astype(np.float32)
        self.Y = np.concatenate([self.y_norm, self.y_lev, self.y_out]).astype(np.float32).reshape(-1, 1)

        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the stacked features and targets as ``(X_t, Y_t)`` tensors."""
        return self.X_t, self.Y_t


# ==========================================
# High-Dimensional Excess Risk Calculation
# ==========================================
def calc_excess_risk_high_dim(w_hat, b_hat, w_true, b_true, dim, n_test=20000):
    """
    Monte-Carlo estimate of the L1 excess risk of a linear predictor.

        excess = mean|x @ w_hat + b_hat - y| - mean|x @ w_true + b_true - y|

    Both means are taken over one freshly drawn, outlier-free test set built
    from NumPy's global RNG: central points (x ~ U(-2, 2)^dim, noise std 1.0)
    followed by high-leverage points (x clustered with std 0.2 around
    10 * w_true / ||w_true||, noise std 0.1).

    NOTE(review): int(n_test * 7 / 9) points are central and the rest are
    leverage points (7:2), whereas KillerDatasetHighDim in this file mixes
    them 70% / 10% (7:1) -- confirm which ratio is intended.

    Args:
        w_hat: Estimated weights; a torch tensor or array-like of any shape
            holding ``dim`` values (it is flattened).
        b_hat: Estimated intercept; a one-element torch tensor or a number.
        w_true: True weight vector (NumPy array of length ``dim``).
        b_true: True intercept.
        dim: Feature dimension of the generated test inputs.
        n_test: Total number of test points.

    Returns:
        The excess risk as a NumPy float.
    """
    # Bring the estimate into plain NumPy / Python form.
    weights = w_hat.detach().cpu().numpy() if isinstance(w_hat, torch.Tensor) else w_hat
    weights = np.array(weights).reshape(-1)
    offset = b_hat.detach().cpu().item() if isinstance(b_hat, torch.Tensor) else b_hat

    n_central = int(n_test * (7 / 9))
    n_leverage = n_test - n_central

    # Central group: features first, then label noise (draw order matters
    # for reproducibility under a fixed seed).
    x_central = np.random.uniform(-2, 2, (n_central, dim))
    y_central = x_central @ w_true + b_true + np.random.normal(0, 1.0, n_central)

    # Leverage group, far out along the direction of w_true.
    lev_center = (w_true / np.linalg.norm(w_true)) * 10.0
    x_leverage = lev_center + np.random.normal(0, 0.2, (n_leverage, dim))
    y_leverage = x_leverage @ w_true + b_true + np.random.normal(0, 0.1, n_leverage)

    features = np.concatenate([x_central, x_leverage])
    targets = np.concatenate([y_central, y_leverage])

    def mean_abs_error(w, b):
        # Mean absolute error of the linear predictor (w, b) on the test set.
        return np.mean(np.abs(features @ w + b - targets))

    return mean_abs_error(weights, offset) - mean_abs_error(w_true, b_true)


# ==========================================
# 1. Nested DRO (Updated: Dynamic Ratio-based Epsilon)
# ==========================================
def train_nested_dro(X, Y, dim, epochs=800, lr=0.01, target_ratio=0.15):
    """
    Train a zero-initialised linear model with the nested DRO objective,
    using an epsilon that is recomputed every epoch.

    Per epoch (full batch):
      1. Inner maximisation: 5 gradient-ascent steps (step size 0.1) on a
         copy of X, ascending  Huber(model(x'), y) - lambda * ||x' - x||^2.
      2. Outer minimisation: one SGD step (gradient norm clipped to 5.0) on
             lambda * rho + mean(psi) - sqrt(2 * eps_t * var(psi) + 1e-8)
         where psi is the per-sample Huber loss at the perturbed inputs
         and rho = 1.0.

    Dynamic epsilon: chosen so that sqrt(2 * eps_t * var) / mean is about
    target_ratio, i.e.
        eps_t = (target_ratio * mean_psi)^2 / (2 * var_psi + 1e-8),
    computed from detached statistics and clamped to [1e-6, 5.0].

    Args:
        X: Feature tensor; assumed (n, dim) as produced by
            KillerDatasetHighDim.get_data().
        Y: Target tensor; assumed (n, 1), matching the model output.
        dim: Input dimension of the linear model.
        epochs: Number of outer iterations.
        lr: SGD learning rate for the model parameters.
        target_ratio: Desired penalty-to-mean ratio used to set eps_t.

    Returns:
        dict with keys 'loss' and 'eps' (one float per epoch), 'w' (the
        final weight as a NumPy array) and 'b' (the final bias as a float).
    """
    # Fixed hyperparameters.
    radius = 1.0       # rho
    n_pgd_steps = 5
    pgd_step = 0.1

    net = nn.Linear(dim, 1)
    nn.init.zeros_(net.weight)
    nn.init.zeros_(net.bias)

    # lambda = exp(log_lambda) stays positive; it is NOT handed to the
    # optimizer and is updated by hand at the end of every epoch.
    log_lambda = nn.Parameter(torch.tensor(-1.0))
    optimizer = optim.SGD(net.parameters(), lr=lr)
    huber = nn.HuberLoss(delta=1.0, reduction='none')

    loss_trace = []
    eps_trace = []

    for _ in range(epochs):
        dual = torch.exp(log_lambda)

        # --- Inner maximisation: perturb the inputs by gradient ascent ---
        x_adv = X.clone().detach().requires_grad_(True)
        for _step in range(n_pgd_steps):
            per_sample = huber(net(x_adv), Y).flatten()
            transport = ((x_adv - X) ** 2).sum(dim=1)
            inner_obj = (per_sample - dual * transport).sum()
            (ascent_dir,) = torch.autograd.grad(inner_obj, x_adv)
            with torch.no_grad():
                x_adv += pgd_step * ascent_dir
        x_adv = x_adv.detach()

        # --- Outer minimisation on the perturbed batch ---
        optimizer.zero_grad()
        psi = huber(net(x_adv), Y).flatten()
        psi_mean = psi.mean()
        psi_var = psi.var(unbiased=True)

        # eps_t is a per-epoch constant: it is built from detached
        # statistics so no gradient flows through it.
        eps_t = (target_ratio * psi_mean.detach()) ** 2 / (2 * psi_var.detach() + 1e-8)
        eps_t = torch.clamp(eps_t, min=1e-6, max=5.0)

        spread = torch.sqrt(2 * eps_t * psi_var + 1e-8)
        outer_obj = dual * radius + psi_mean - spread

        outer_obj.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
        optimizer.step()

        # Manual gradient step on log_lambda (step size 0.01).
        # NOTE(review): x_adv is detached, so `dual * radius` is the only
        # term of outer_obj that depends on log_lambda; this step can only
        # shrink lambda. Confirm the transport-cost term is meant to be
        # absent from the outer objective.
        if log_lambda.grad is not None:
            with torch.no_grad():
                log_lambda.data -= 0.01 * log_lambda.grad
                log_lambda.grad.zero_()

        loss_trace.append(outer_obj.item())
        eps_trace.append(eps_t.item())

    return {
        'loss': loss_trace,
        'w': net.weight.detach().cpu().numpy(),
        'b': net.bias.detach().cpu().item(),
        'eps': eps_trace,
    }


# ==========================================
# 2. OR-WDRO
# ==========================================
class OR_WDRO_LinearRegression(nn.Module):
    """Zero-initialised linear predictor plus three trainable scalars.

    Parameters registered on the module:
        linear: ``nn.Linear(dim, 1)`` with weight and bias set to zero.
        pre_lambda1, pre_lambda2: unconstrained scalars (init 0.55);
            ``train_or_wdro`` in this file passes them through softplus.
        alpha: scalar initialised to 0.0; ``train_or_wdro`` uses it as the
            threshold inside its ReLU term.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, 1)
        for tensor in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(tensor)

        # Register the scalar parameters in a fixed order.
        for name, start in (("pre_lambda1", 0.55), ("pre_lambda2", 0.55), ("alpha", 0.0)):
            self.register_parameter(name, nn.Parameter(torch.tensor(start)))

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.linear(x)


def get_robust_estimates_high_dim(X, Y, epsilon):
    """Coordinate-wise trimmed centre and a scale bound for the joint (X, Y) data.

    X and Y are concatenated column-wise, every column is sorted on its own,
    and ``int(N * 2 * epsilon)`` rows are dropped from each end before the
    per-column mean is taken. The scale is 4x the summed per-column second
    moment around that centre, floored at 1.0.

    NOTE(review): trimming only happens when the per-side count is positive
    and below ``N // 2``; for epsilon >= 0.25 that check fails and the plain
    (untrimmed) mean is returned. Confirm this fallback is intended.

    Args:
        X: 2-D feature tensor (N, d).
        Y: 2-D target tensor (N, 1) -- its column becomes the last coordinate.
        epsilon: Contamination level controlling the trim fraction.

    Returns:
        (z0_x, z0_y, sigma_sq): centre of the feature columns with shape
        (1, d), centre of the last column with shape (1,), and the scalar
        scale tensor.
    """
    joint = torch.cat([X, Y], dim=1)
    n_rows, _ = joint.shape
    n_cut = int(n_rows * (2 * epsilon))

    # Each column is ordered independently, so rows no longer correspond to
    # original samples; only per-column statistics are used below.
    columns_sorted = torch.sort(joint, dim=0).values
    if 0 < n_cut < n_rows // 2:
        kept = columns_sorted[n_cut:n_rows - n_cut]
    else:
        kept = columns_sorted

    center = kept.mean(dim=0)
    spread = ((kept - center) ** 2).sum(dim=1).mean()
    sigma_sq = torch.max(spread * 4.0, torch.tensor(1.0))

    return center[:-1].unsqueeze(0), center[-1].unsqueeze(0), sigma_sq


def train_or_wdro(X, Y, dim, epochs=1000, lr=0.01, epsilon=0.4, rho=1.0):
    """Fit the OR-WDRO linear regressor with full-batch SGD.

    A robust centre (z0_x, z0_y) and scale sigma_sq are first obtained from
    ``get_robust_estimates_high_dim``. Each epoch then minimises

        lam1 * sigma_sq + lam2 * rho**2
            + alpha + mean(relu(sup_val - alpha)) / (1 - epsilon)

    where lam1, lam2 = softplus(pre_lambda1), softplus(pre_lambda2),
    gamma = lam1 + lam2 + 1e-6, and per sample

        sup_val = |theta . x_bar + bias - y_bar|
                  + (||theta||^2 + 1) / (4 * gamma)
                  - (lam1 * lam2 / gamma) * ||(x, y) - (z0_x, z0_y)||^2

    with (x_bar, y_bar) the (lam1, lam2)-weighted blend of the centre and
    the sample. All model parameters (weights, bias, pre-lambdas, alpha) are
    updated jointly with gradient norm clipped to 1.0, and the learning rate
    is multiplied by 0.1 halfway through training.

    Args:
        X: Feature tensor; assumed (n, dim).
        Y: Target tensor; assumed (n, 1).
        dim: Input dimension of the linear model.
        epochs: Number of SGD steps.
        lr: Initial learning rate.
        epsilon: Contamination level (used for trimming and in the
            1 / (1 - epsilon) factor).
        rho: Radius entering the objective as rho**2.

    Returns:
        dict with 'loss' (one float per epoch), 'w' (final weight as a NumPy
        array) and 'b' (final bias as a float).
    """
    z0_x, z0_y, sigma_sq = get_robust_estimates_high_dim(X, Y, epsilon)
    z0_x = z0_x.detach()
    z0_y = z0_y.detach()
    sigma_sq = sigma_sq.detach()

    model = OR_WDRO_LinearRegression(dim)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # step_size must be at least 1: int(epochs / 2) is 0 for epochs == 1,
    # which makes StepLR raise ZeroDivisionError on the first step().
    # For epochs >= 2 this is the same value as before.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 2), gamma=0.1)

    history = {'loss': [], 'w': [], 'b': []}

    for epoch in range(epochs):
        # softplus keeps both multipliers positive.
        lam1 = torch.nn.functional.softplus(model.pre_lambda1)
        lam2 = torch.nn.functional.softplus(model.pre_lambda2)
        gamma = lam1 + lam2 + 1e-6

        optimizer.zero_grad()
        # Blend of the robust centre and each sample.
        bar_X = (lam1 * z0_x + lam2 * X) / gamma
        bar_Y = (lam1 * z0_y + lam2 * Y) / gamma

        theta = model.linear.weight
        bias = model.linear.bias

        linear_val = (torch.matmul(bar_X, theta.T) + bias) - bar_Y
        # "+ 1.0" accounts for the fixed coefficient on the y coordinate.
        w_norm_sq = torch.sum(theta ** 2) + 1.0
        conjugate_term = w_norm_sq / (4.0 * gamma)

        dist_sq = torch.sum((X - z0_x) ** 2, dim=1, keepdim=True) + (Y - z0_y) ** 2
        correction_term = (lam1 * lam2 / gamma) * dist_sq

        sup_val = torch.abs(linear_val) + conjugate_term - correction_term
        relu_part = torch.relu(sup_val - model.alpha)
        cvar_loss = model.alpha + torch.mean(relu_part) / (1.0 - epsilon)

        total_loss = lam1 * sigma_sq + lam2 * (rho ** 2) + cvar_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        history['loss'].append(total_loss.item())

    history['w'] = model.linear.weight.detach().cpu().numpy()
    history['b'] = model.linear.bias.detach().cpu().item()
    return history


# ==========================================
# 3. UOT-DRO
# ==========================================
def train_uot_dro(X, Y, dim, epochs=300, lr=0.01):
    """Fit a linear model with the UOT-DRO log-sum-exp objective.

    Each sample gets a fixed prior distance: the Euclidean norm of its
    offset from the coordinate-wise median of (X, Y). Every epoch takes one
    full-batch SGD step (gradient norm clipped to 5.0) on

        T * log(sum_i exp((|pred_i - y_i| - lam2 * prior_i) / T)),
        T = lam * beta,

    evaluated with the usual max-shift for numerical stability. The
    constants are lam = 20.0, beta = 10.0, lam2 = 20.0.

    Unlike the other trainers in this file, the model keeps the default
    ``nn.Linear`` initialisation (it is not zeroed), so the result depends
    on torch's RNG state.

    Args:
        X: 2-D feature tensor (n, n_features).
        Y: Targets; a 1-D tensor is reshaped to a column.
        dim: Unused; the input width is read from ``X``.
        epochs: Number of SGD steps.
        lr: Learning rate.

    Returns:
        dict with 'loss' (one float per epoch), 'w' (final weight as a NumPy
        array) and 'b' (final bias as a float).
    """
    targets = Y.unsqueeze(1) if Y.dim() == 1 else Y
    _, n_features = X.shape
    model = nn.Linear(n_features, 1)

    lam = 20.0
    beta = 10.0
    lam2 = 20.0
    temperature = lam * beta
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Per-sample distance to the coordinate-wise median; constant during
    # training, hence computed once without gradient tracking.
    with torch.no_grad():
        x_offset = X - torch.median(X, dim=0).values
        y_offset = targets - torch.median(targets, dim=0).values
        norm_prior = torch.norm(torch.cat([x_offset, y_offset], dim=1), p=2, dim=1, keepdim=True)

    loss_trace = []
    for _ in range(epochs):
        optimizer.zero_grad()
        residual = torch.abs(model(X) - targets)

        scaled = (residual - lam2 * norm_prior) / temperature
        # Subtract the (detached) maximum before exponentiating, then add
        # it back outside the log.
        shift = scaled.max().detach()
        loss = temperature * (torch.log(torch.sum(torch.exp(scaled - shift))) + shift)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_trace.append(loss.item())

    return {
        'loss': loss_trace,
        'w': model.weight.detach().cpu().numpy(),
        'b': model.bias.detach().cpu().item(),
    }


# ==========================================
# 4. Standard DRO
# ==========================================
def train_standard_dro(X, Y, dim, epochs=800, lr=0.01):
    """Train a zero-initialised linear model with the plain Wasserstein-DRO loop.

    Per epoch (full batch):
      1. Inner maximisation: 5 gradient-ascent steps (step size 0.1) on a
         copy of X, ascending  Huber(model(x'), y) - lambda * ||x' - x||^2.
      2. Outer minimisation: one SGD step (gradient norm clipped to 5.0) on
         lambda * rho + mean Huber loss at the perturbed inputs, rho = 1.0.

    Args:
        X: Feature tensor; assumed (n, dim).
        Y: Target tensor; assumed (n, 1), matching the model output.
        dim: Input dimension of the linear model.
        epochs: Number of outer iterations.
        lr: SGD learning rate for the model parameters.

    Returns:
        dict with 'loss' (one float per epoch), 'w' (final weight as a NumPy
        array) and 'b' (final bias as a float).
    """
    radius = 1.0       # rho
    n_pgd_steps = 5
    pgd_step = 0.1

    net = nn.Linear(dim, 1)
    nn.init.zeros_(net.weight)
    nn.init.zeros_(net.bias)

    # lambda = exp(log_lambda); kept outside the optimizer and stepped by hand.
    log_lambda = nn.Parameter(torch.tensor(-1.0))
    optimizer = optim.SGD(net.parameters(), lr=lr)
    huber = nn.HuberLoss(delta=1.0, reduction='none')

    loss_trace = []
    for _ in range(epochs):
        dual = torch.exp(log_lambda)

        # --- Inner maximisation: perturb the inputs by gradient ascent ---
        x_adv = X.clone().detach().requires_grad_(True)
        for _step in range(n_pgd_steps):
            per_sample = huber(net(x_adv), Y).flatten()
            transport = ((x_adv - X) ** 2).sum(dim=1)
            inner_obj = (per_sample - dual * transport).sum()
            (ascent_dir,) = torch.autograd.grad(inner_obj, x_adv)
            with torch.no_grad():
                x_adv += pgd_step * ascent_dir
        x_adv = x_adv.detach()

        # --- Outer minimisation on the perturbed batch ---
        optimizer.zero_grad()
        psi_mean = huber(net(x_adv), Y).flatten().mean()
        outer_obj = dual * radius + psi_mean
        outer_obj.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
        optimizer.step()

        # Manual gradient step on log_lambda (step size 0.01).
        # NOTE(review): x_adv is detached, so `dual * radius` is the only
        # term that depends on log_lambda and this step can only shrink
        # lambda. Confirm the transport-cost term is meant to be absent.
        if log_lambda.grad is not None:
            with torch.no_grad():
                log_lambda.data -= 0.01 * log_lambda.grad
                log_lambda.grad.zero_()

        loss_trace.append(outer_obj.item())

    return {
        'loss': loss_trace,
        'w': net.weight.detach().cpu().numpy(),
        'b': net.bias.detach().cpu().item(),
    }


# ==========================================
# Main Execution: Multi-Dim Experiment
# ==========================================

n_epoch_per_run = 1500
seeds = range(33, 43)
dim_list = [10, 20, 30, 40, 50]
fixed_n_samples = 500

# Algorithm label -> (training function, extra keyword arguments).
# Insertion order is the order in which the algorithms are run for each
# dataset; Nested DRO uses the dynamic-ratio epsilon with target_ratio=0.4.
_trainer_specs = {
    'Nested DRO': (train_nested_dro, {'target_ratio': 0.4}),
    'OR-WDRO': (train_or_wdro, {}),
    'UOT-DRO': (train_uot_dro, {}),
    'Standard DRO': (train_standard_dro, {}),
}
algo_names = list(_trainer_specs)

# results[algo]['risk'][dim] collects one excess-risk value per seed.
results = {name: {'risk': {d: [] for d in dim_list}} for name in algo_names}

print(f"Starting experiments for dimensions: {dim_list} with N={fixed_n_samples}")

for dim in dim_list:
    print(f"\n=== Processing Dimension = {dim} ===")

    for seed in seeds:
        print(".", end="", flush=True)

        # Re-seed both RNGs so every (dim, seed) pair is reproducible, then
        # draw the contaminated training set.
        np.random.seed(seed)
        torch.manual_seed(seed)

        dataset = KillerDatasetHighDim(n_samples=fixed_n_samples, dim=dim)
        X_t, Y_t = dataset.get_data()

        # Train each algorithm on the same data and score it immediately,
        # so the train / evaluate calls interleave in a fixed order.
        for label, (trainer, extra_kwargs) in _trainer_specs.items():
            fitted = trainer(X_t, Y_t, dim, epochs=n_epoch_per_run, **extra_kwargs)
            risk = calc_excess_risk_high_dim(
                fitted['w'], fitted['b'], dataset.w_true, dataset.b_true, dim
            )
            results[label]['risk'][dim].append(risk)

    print(" Done.")

print("All experiments finished. Plotting...")


# ==========================================
# Plotting
# ==========================================
def plot_metric_vs_samples(ax, metric_key, results_dict, y_label, log_scale=False):
    """Plot mean +/- standard error of a metric against dimension.

    One error-bar line is drawn per algorithm in the module-level
    ``algo_names``, with x positions taken from the module-level
    ``dim_list`` and shifted slightly per algorithm so the error bars do
    not overlap.

    Args:
        ax: Matplotlib Axes to draw on.
        metric_key: Key selecting the metric inside each algorithm's entry,
            e.g. 'risk'.
        results_dict: Mapping ``results_dict[algo][metric_key][dim]`` to a
            sequence of per-run values.
        y_label: Label for the y axis.
        log_scale: If True, use a logarithmic y axis.
    """
    colors = {
        'Nested DRO': 'blue',
        'OR-WDRO': 'orange',
        'UOT-DRO': 'green',
        'Standard DRO': 'tab:blue'
    }

    # Horizontal shift per algorithm (in x-axis units).
    offsets = {'Nested DRO': -0.6, 'OR-WDRO': -0.2, 'UOT-DRO': 0.2, 'Standard DRO': 0.6}

    for algo in algo_names:
        per_dim = [results_dict[algo][metric_key][d] for d in dim_list]
        means = [np.mean(values) for values in per_dim]
        # Standard error of the mean, based on the number of values that
        # were actually averaged rather than on the global seed list.
        sems = [np.std(values) / np.sqrt(len(values)) for values in per_dim]

        x_vals = np.array(dim_list) + offsets.get(algo, 0)

        # Every algorithm is drawn with the same line style (the former
        # per-algorithm branches were identical).
        ax.errorbar(x_vals, means, yerr=sems, label=algo,
                    marker='o', capsize=5, linestyle='-', color=colors[algo],
                    alpha=0.8, linewidth=3.2)

    ax.set_xlabel("Dimension", fontsize=24)
    ax.set_ylabel(y_label, fontsize=24)
    ax.tick_params(axis='both', labelsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(dim_list)
    for spine in ax.spines.values():
        spine.set_linewidth(2.0)

    if log_scale:
        ax.set_yscale('log')

    ax.legend(fontsize=20)


# --- Excess-risk figure: single panel, saved as PDF and then shown ---
output_pdf = 'excess_risk_dimension_dynamic_ratio.pdf'

fig, ax = plt.subplots(figsize=(10, 7))
plot_metric_vs_samples(
    ax,
    'risk',
    results,
    y_label="Excess Risk ",
    log_scale=False,
)

plt.tight_layout()
plt.savefig(output_pdf, format='pdf', bbox_inches='tight')
plt.show()