import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Render all figure text in a serif face: Times first, DejaVu Serif as the fallback.
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times", "DejaVu Serif"],
})


# ==========================================
# 0. Dataset definition (high-dimensional; dimension set by `dim`)
# ==========================================
class KillerDatasetHighDim:
    """Synthetic linear-regression data: clean points, high-leverage points and outliers.

    All randomness comes from NumPy's global RNG, so call ``np.random.seed``
    beforehand for reproducible data. Three groups are generated, in this order:

    * center points (70%): features uniform in [-2, 2]; targets follow the
      true model ``x @ w_true + b_true`` plus N(0, 1.0) noise.
    * high-leverage points (10%): features clustered (std 0.2) around the
      point at distance 10 along the ``w_true`` direction; targets follow the
      true model plus N(0, 0.1) noise.
    * outliers (20%): standard-normal features; targets use the negated
      weights and an extra +10 offset, plus N(0, 0.5) noise.

    Group sizes are the floor of the stated percentage of ``n_samples``.

    Attributes:
        w_true: true weight vector, random direction with Euclidean norm 2.0.
        b_true: true intercept (1.0).
        X, Y: float32 NumPy arrays of all groups, shapes (N, dim) and (N, 1).
        X_t, Y_t: torch tensors sharing the data of ``X`` and ``Y``.
    """

    def __init__(self, n_samples=300, dim=10):
        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        # Integer arithmetic for the group sizes: int(0.7 * n) truncates float
        # rounding error (0.7 * 700 == 489.99999999999994 -> 489), which
        # silently dropped a row for some sample sizes.
        n_total = int(n_samples)
        n_center = (7 * n_total) // 10
        n_lev = n_total // 10
        n_out = n_total // 5

        # === True weights ===
        # Random direction, rescaled to norm 2.0 so the signal strength is the
        # same for every draw.
        raw_w = np.random.randn(dim)
        self.w_true = (raw_w / np.linalg.norm(raw_w)) * 2.0

        # === 1. Center points: the bulk of well-behaved data ===
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        self.y_norm = self.x_norm @ self.w_true + self.b_true + np.random.normal(0, 1.0, n_center)

        # === 2. Informative high-leverage points ===
        # Far out along the w direction (distance ~10), with small Gaussian
        # jitter around that center.
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + np.random.normal(0, 0.2, (n_lev, dim))
        self.y_lev = self.x_lev @ self.w_true + self.b_true + np.random.normal(0, 0.1, n_lev)

        # === 3. Proximal (low-leverage) outliers ===
        # Located near the origin; the targets use the flipped slope (-w) and
        # a large additional bias (+10.0).
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_outlier = -1.0 * self.w_true
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + np.random.normal(0, 0.5, n_out)

        # === Assemble the training set (center, leverage, outliers) ===
        self.X = np.concatenate([self.x_norm, self.x_lev, self.x_out]).astype(np.float32)
        # Y is kept as a column vector of shape (N, 1).
        self.Y = np.concatenate([self.y_norm, self.y_lev, self.y_out]).astype(np.float32).reshape(-1, 1)

        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the full (contaminated) training set as torch tensors ``(X_t, Y_t)``."""
        return self.X_t, self.Y_t

    def get_clean_data(self):
        """Return the center and high-leverage groups only, as NumPy arrays.

        The outlier group is excluded. ``X_clean`` has shape (n, dim) and
        ``Y_clean`` has shape (n,); neither is cast to float32.
        """
        X_clean = np.concatenate([self.x_norm, self.x_lev])
        Y_clean = np.concatenate([self.y_norm, self.y_lev])
        return X_clean, Y_clean


# ==========================================
# Excess-risk computation (works for any feature dimension)
# ==========================================
def calc_excess_risk(w_hat, b_hat, w_true, b_true, X_clean, Y_clean):
    """Return the L1 (MAE) excess risk of a fitted linear model on clean data.

    Computes ``mean(|y_hat - y|) - mean(|y_star - y|)``, where ``y_hat`` comes
    from ``(w_hat, b_hat)`` and ``y_star`` from the reference ``(w_true, b_true)``.

    Args:
        w_hat: fitted weights; a torch tensor or array-like of any shape with
            ``dim`` elements (it is flattened).
        b_hat: fitted intercept.
        w_true: reference weights, used as given.
        b_true: reference intercept.
        X_clean: feature array of shape (N, dim).
        Y_clean: targets of shape (N,) or (N, 1).
    """
    # Bring the estimate into a flat NumPy vector of length dim.
    estimate = w_hat.detach().cpu().numpy() if isinstance(w_hat, torch.Tensor) else w_hat
    estimate = np.array(estimate).reshape(-1)

    targets = Y_clean.reshape(-1)

    def _mean_abs_error(weights, intercept):
        # Mean absolute residual of the linear predictor on the clean set.
        return np.mean(np.abs(X_clean @ weights + intercept - targets))

    fitted_risk = _mean_abs_error(estimate, b_hat)
    oracle_risk = _mean_abs_error(w_true, b_true)
    return fitted_risk - oracle_risk


# ==========================================
# 1. Nested DRO
# ==========================================
def train_nested_dro(X, Y, dim, epochs=800, lr=0.01):
    """Train a zero-initialised linear model with an adversarial inner loop
    and a variance-adjusted outer objective.

    Each epoch first perturbs the inputs by gradient ascent on
    ``huber(model(x), y) - lam * ||x - x0||^2`` (5 steps of size 0.1), then
    takes one SGD step on ``lam * rho + mean(psi) - sqrt(2 * epsilon * var(psi) + 1e-8)``,
    where ``psi`` is the per-sample Huber loss at the perturbed inputs.

    Args:
        X: input tensor of shape (N, dim).
        Y: target tensor of shape (N, 1).
        dim: number of input features.
        epochs: number of outer iterations.
        lr: SGD learning rate for the model parameters.

    Returns:
        dict with 'loss' (list of per-epoch objective values), 'w' (NumPy
        array of shape (1, dim)) and 'b' (float).
    """
    # Fixed hyper-parameters.
    epsilon = 0.1          # scales the variance inside the sqrt term
    rho = 1.0              # multiplies lambda in the outer objective
    n_ascent_steps = 5
    ascent_step = 0.1

    net = nn.Linear(dim, 1)
    nn.init.zeros_(net.weight)
    nn.init.zeros_(net.bias)

    # lambda is parameterised through its logarithm so it stays positive.
    log_lambda = nn.Parameter(torch.tensor(-1.0))
    sgd = optim.SGD(net.parameters(), lr=lr)
    huber = nn.HuberLoss(delta=1.0, reduction='none')

    def _ascend_inputs(lam):
        # Inner maximisation: plain gradient ascent on the penalised loss,
        # starting from the unperturbed inputs.
        x_adv = X.clone().detach().requires_grad_(True)
        for _ in range(n_ascent_steps):
            per_sample = huber(net(x_adv), Y).flatten()
            # Squared Euclidean distance of each perturbed row to its original.
            transport = ((x_adv - X) ** 2).sum(dim=1)
            objective = (per_sample - lam * transport).sum()
            (grad_x,) = torch.autograd.grad(objective, x_adv)
            with torch.no_grad():
                x_adv += ascent_step * grad_x
        return x_adv.detach()

    loss_trace = []
    for _ in range(epochs):
        lam = torch.exp(log_lambda)
        x_worst = _ascend_inputs(lam)

        # Outer minimisation over the model parameters.
        # NOTE(review): psi does not include the -lam * cost term, so the only
        # gradient reaching log_lambda is from lam * rho — confirm this is intended.
        sgd.zero_grad()
        psi = huber(net(x_worst), Y).flatten()
        mean_psi = psi.mean()
        var_psi = psi.var(unbiased=True)
        spread = torch.sqrt(2 * epsilon * var_psi + 1e-8)
        total_loss = lam * rho + mean_psi - spread

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
        sgd.step()

        # log_lambda is not registered with the optimizer; step it by hand.
        if log_lambda.grad is not None:
            with torch.no_grad():
                log_lambda.data -= 0.01 * log_lambda.grad
                log_lambda.grad.zero_()

        loss_trace.append(total_loss.item())

    return {
        'loss': loss_trace,
        'w': net.weight.detach().cpu().numpy(),
        'b': net.bias.detach().cpu().item(),
    }


# ==========================================
# 2. OR-WDRO
# ==========================================
class OR_WDRO_LinearRegression(nn.Module):
    """Linear regressor bundled with the extra scalar variables trained by OR-WDRO.

    Besides the zero-initialised ``linear`` layer (dim -> 1), the module owns
    three scalar parameters: ``pre_lambda1`` and ``pre_lambda2`` (both start at
    0.55) and ``alpha`` (starts at 0.0).
    """

    def __init__(self, dim):
        super().__init__()
        head = nn.Linear(dim, 1)
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)
        self.linear = head

        # Scalar parameters, registered in a fixed order.
        self.register_parameter("pre_lambda1", nn.Parameter(torch.tensor(0.55)))
        self.register_parameter("pre_lambda2", nn.Parameter(torch.tensor(0.55)))
        self.register_parameter("alpha", nn.Parameter(torch.tensor(0.0)))

    def forward(self, x):
        """Apply the linear layer to ``x``; the scalar parameters are not used here."""
        prediction = self.linear(x)
        return prediction


def get_robust_estimates_high_dim(X, Y, epsilon):
    """Estimate a center and a scale for the joint data ``[X, Y]`` by coordinate-wise trimming.

    Every column of the joint matrix is sorted independently and
    ``int(N * 2 * epsilon)`` values are dropped from each end before averaging.
    Trimming is skipped when that count is 0 or is not below ``N // 2`` — this
    includes every ``epsilon >= 0.25`` — in which case the plain column means
    are returned.

    Args:
        X: tensor of shape (N, d).
        Y: tensor of shape (N, 1).
        epsilon: contamination level; the trim fraction per side is ``2 * epsilon``.

    Returns:
        ``(z0_x, z0_y, sigma_sq)``: the center of the X columns with shape
        (1, d), the center of the Y column with shape (1,), and a scalar equal
        to 8 times the mean squared distance of the kept values to the center,
        floored at 1.0.
    """
    joint = torch.cat([X, Y], dim=1)
    n_rows, _ = joint.shape  # the second entry is d + 1

    cut = int(n_rows * (2 * epsilon))

    # Columns are sorted independently, so rows of `ordered` are no longer
    # actual samples; only per-column statistics are taken from it.
    ordered = torch.sort(joint, dim=0).values
    kept = ordered[cut:-cut, :] if 0 < cut < n_rows // 2 else ordered

    center = kept.mean(dim=0)
    # Squared deviations summed over all coordinates, averaged over kept rows.
    mean_sq_dist = ((kept - center) ** 2).sum(dim=1).mean()
    sigma_sq = torch.max(mean_sq_dist * 8.0, torch.tensor(1.0))

    # First d entries belong to X, the last one to Y.
    return center[:-1].unsqueeze(0), center[-1].unsqueeze(0), sigma_sq


def train_or_wdro(X, Y, dim, epochs=1000, lr=0.01, epsilon=0.4, rho=1.0):
    """Train the OR-WDRO linear regressor with SGD on its CVaR-style objective.

    The center ``(z0_x, z0_y)`` and scale ``sigma_sq`` come from
    ``get_robust_estimates_high_dim`` and are treated as constants. Each epoch
    minimises ``lam1 * sigma_sq + lam2 * rho**2 + cvar_loss`` jointly over the
    linear layer, ``pre_lambda1``, ``pre_lambda2`` and ``alpha`` (all owned by
    ``OR_WDRO_LinearRegression``). The learning rate is multiplied by 0.1
    halfway through training.

    Args:
        X: input tensor of shape (N, dim).
        Y: target tensor of shape (N, 1).
        dim: number of input features.
        epochs: number of SGD steps.
        lr: initial learning rate.
        epsilon: contamination level; used for the robust estimates and in the
            ``1 / (1 - epsilon)`` factor of the CVaR term.
        rho: radius entering the objective as ``lam2 * rho**2``.

    Returns:
        dict with 'loss' (list of per-epoch objective values), 'w' (NumPy
        array of shape (1, dim)) and 'b' (float).
    """
    z0_x, z0_y, sigma_sq = get_robust_estimates_high_dim(X, Y, epsilon)
    z0_x = z0_x.detach()
    z0_y = z0_y.detach()
    sigma_sq = sigma_sq.detach()

    model = OR_WDRO_LinearRegression(dim)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # Clamp to >= 1: a step_size of 0 (epochs < 2) makes StepLR take a modulo
    # by zero on its first step.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 2), gamma=0.1)

    # Squared distance of every sample to the center, shape (N, 1). It depends
    # only on the data, so it is computed once outside the loop.
    dist_sq = torch.sum((X - z0_x) ** 2, dim=1, keepdim=True) + (Y - z0_y) ** 2

    history = {'loss': [], 'w': [], 'b': []}

    for epoch in range(epochs):
        # softplus keeps both multipliers positive.
        lam1 = torch.nn.functional.softplus(model.pre_lambda1)
        lam2 = torch.nn.functional.softplus(model.pre_lambda2)
        gamma = lam1 + lam2 + 1e-6

        optimizer.zero_grad()
        # Blend each sample with the center; (1, dim) broadcasts against (N, dim).
        bar_X = (lam1 * z0_x + lam2 * X) / gamma
        bar_Y = (lam1 * z0_y + lam2 * Y) / gamma

        theta = model.linear.weight  # (1, dim)
        bias = model.linear.bias

        # Residual at the blended points: (N, dim) @ (dim, 1) -> (N, 1).
        linear_val = (torch.matmul(bar_X, theta.T) + bias) - bar_Y

        # The +1.0 presumably accounts for the unit coefficient on y — TODO confirm.
        w_norm_sq = torch.sum(theta ** 2) + 1.0
        conjugate_term = w_norm_sq / (4.0 * gamma)

        correction_term = (lam1 * lam2 / gamma) * dist_sq
        sup_val = torch.abs(linear_val) + conjugate_term - correction_term
        relu_part = torch.relu(sup_val - model.alpha)
        cvar_loss = model.alpha + torch.mean(relu_part) / (1.0 - epsilon)

        total_loss = lam1 * sigma_sq + lam2 * (rho ** 2) + cvar_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        history['loss'].append(total_loss.item())

    history['w'] = model.linear.weight.detach().cpu().numpy()
    history['b'] = model.linear.bias.detach().cpu().item()
    return history


# ==========================================
# 3. UOT-DRO (Updated for Dim=5)
# ==========================================
def train_uot_dro(X, Y, dim, epochs=300, lr=0.01):
    """Train a linear model on the UOT-DRO log-sum-exp objective.

    Each sample's absolute error is reduced by ``lam2`` times its distance to
    the coordinate-wise median of ``[X, Y]``, divided by ``lam * beta``, and the
    results are combined with a log-sum-exp scaled by ``lam * beta``.

    The model keeps ``nn.Linear``'s default random initialisation, and the
    feature count is read from ``X``; ``dim`` is accepted but not used.

    Args:
        X: input tensor of shape (N, d).
        Y: target tensor of shape (N,) or (N, 1).
        dim: unused.
        epochs: number of SGD steps.
        lr: SGD learning rate.

    Returns:
        dict with 'loss' (list of per-epoch objective values), 'w' (NumPy
        array of shape (1, d)) and 'b' (float).
    """
    targets = Y.unsqueeze(1) if Y.dim() == 1 else Y
    _, n_features = X.shape
    net = nn.Linear(n_features, 1)

    # Best values found by grid search.
    lam = 10.0
    beta = 10.0
    lam2 = 10.0
    temperature = lam * beta
    sgd = optim.SGD(net.parameters(), lr=lr)

    with torch.no_grad():
        # Coordinate-wise medians serve as the reference point.
        median_x = X.median(dim=0).values
        median_y = targets.median(dim=0).values
        offsets = torch.cat([X - median_x, targets - median_y], dim=1)
        # Euclidean distance of every sample to the reference point, shape (N, 1).
        norm_prior = torch.norm(offsets, p=2, dim=1, keepdim=True)

    loss_trace = []
    for _ in range(epochs):
        sgd.zero_grad()
        abs_error = torch.abs(net(X) - targets)

        scaled = (abs_error - lam2 * norm_prior) / temperature
        # Subtract the (detached) maximum before exponentiating to avoid overflow.
        shift = scaled.max().detach()
        loss = temperature * (torch.log(torch.sum(torch.exp(scaled - shift))) + shift)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
        sgd.step()
        loss_trace.append(loss.item())

    return {
        'loss': loss_trace,
        'w': net.weight.detach().cpu().numpy(),
        'b': net.bias.detach().cpu().item(),
    }


# ==========================================
# 4. Standard DRO
# ==========================================
def train_standard_dro(X, Y, dim, epochs=800, lr=0.01):
    """Train a zero-initialised linear model on adversarially perturbed inputs.

    Each epoch perturbs the inputs by gradient ascent on
    ``huber(model(x), y) - lam * ||x - x0||^2`` (5 steps of size 0.1), then
    takes one SGD step on ``lam * rho + mean(psi)``, where ``psi`` is the
    per-sample Huber loss at the perturbed inputs.

    Args:
        X: input tensor of shape (N, dim).
        Y: target tensor of shape (N, 1).
        dim: number of input features.
        epochs: number of outer iterations.
        lr: SGD learning rate for the model parameters.

    Returns:
        dict with 'loss' (list of per-epoch objective values), 'w' (NumPy
        array of shape (1, dim)) and 'b' (float).
    """
    rho = 1.0
    n_pgd_steps = 5
    pgd_step = 0.1

    net = nn.Linear(dim, 1)
    nn.init.zeros_(net.weight)
    nn.init.zeros_(net.bias)

    # lambda is parameterised through its logarithm so it stays positive.
    log_lambda = nn.Parameter(torch.tensor(-1.0))
    sgd = optim.SGD(net.parameters(), lr=lr)
    huber = nn.HuberLoss(delta=1.0, reduction='none')

    loss_trace = []
    for _ in range(epochs):
        lam = torch.exp(log_lambda)

        # Inner maximisation: gradient ascent from the unperturbed inputs.
        x_adv = X.clone().detach().requires_grad_(True)
        for _ in range(n_pgd_steps):
            per_sample = huber(net(x_adv), Y).flatten()
            transport = ((x_adv - X) ** 2).sum(dim=1)
            objective = (per_sample - lam * transport).sum()
            (grad_x,) = torch.autograd.grad(objective, x_adv)
            with torch.no_grad():
                x_adv += pgd_step * grad_x
        x_adv = x_adv.detach()

        # Outer minimisation over the model parameters.
        # NOTE(review): psi does not include the -lam * cost term, so the only
        # gradient reaching log_lambda is from lam * rho — confirm this is intended.
        sgd.zero_grad()
        psi = huber(net(x_adv), Y).flatten()
        total_loss = lam * rho + psi.mean()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
        sgd.step()

        # log_lambda is not registered with the optimizer; step it by hand.
        if log_lambda.grad is not None:
            with torch.no_grad():
                log_lambda.data -= 0.01 * log_lambda.grad
                log_lambda.grad.zero_()

        loss_trace.append(total_loss.item())

    return {
        'loss': loss_trace,
        'w': net.weight.detach().cpu().numpy(),
        'b': net.bias.detach().cpu().item(),
    }


# ==========================================
# Main Execution: Multi-Seed & Multi-Sample Experiment
# ==========================================

DIM = 5
n_epoch_per_run = 1500
seeds = range(33, 43)  # 10 random seeds
n_samples_list = [500, 600, 700, 800, 900, 1000]
algo_names = ['Nested DRO', 'OR-WDRO', 'UOT-DRO', 'Standard DRO']

# Training routine for each algorithm. They are run in `algo_names` order for
# every seed; keep that order stable, because train_uot_dro keeps nn.Linear's
# random initialisation and the torch RNG is seeded only once per seed.
trainers = {
    'Nested DRO': train_nested_dro,
    'OR-WDRO': train_or_wdro,
    'UOT-DRO': train_uot_dro,
    'Standard DRO': train_standard_dro,
}

# results[algo]['risk'][n_samples] -> list of excess-risk values, one per seed.
results = {algo: {'risk': {n: [] for n in n_samples_list}} for algo in algo_names}

print(f"Starting {DIM}D experiments for sample sizes: {n_samples_list}")

for n_samples in n_samples_list:
    print(f"\n=== Processing n_samples = {n_samples} ===")

    for seed in seeds:
        print(f"[Seed {seed}]", end=" ", flush=True)

        # 1. Seed both RNGs and generate a fresh dataset.
        np.random.seed(seed)
        torch.manual_seed(seed)
        dataset = KillerDatasetHighDim(n_samples=n_samples, dim=DIM)
        X_t, Y_t = dataset.get_data()
        X_clean, Y_clean = dataset.get_clean_data()

        # 2. Train every algorithm on the contaminated data and score it on
        #    the clean subset against the true parameters.
        for algo in algo_names:
            fitted = trainers[algo](X_t, Y_t, DIM, epochs=n_epoch_per_run)
            risk = calc_excess_risk(fitted['w'], fitted['b'],
                                    dataset.w_true, dataset.b_true, X_clean, Y_clean)
            results[algo]['risk'][n_samples].append(risk)

    print("Done.")

print("All experiments finished. Plotting...")


def plot_metric_vs_samples(ax, metric_key, results_dict, y_label, log_scale=False):
    """Draw mean-with-error-bar curves of one metric against sample size.

    One curve is drawn per entry of the module-level ``algo_names``, with one
    point per entry of the module-level ``n_samples_list``. Each point is the
    mean over the recorded runs; the error bar is the standard error,
    ``std / sqrt(number of runs)``, with NumPy's default population std.

    Args:
        ax: Matplotlib axes to draw on.
        metric_key: metric name inside each algorithm's entry (e.g. 'risk').
        results_dict: nested mapping where ``results_dict[algo][metric_key][n]``
            is the list of per-run values for sample size ``n``.
        y_label: label for the y axis.
        log_scale: if True, use a logarithmic y axis.
    """
    colors = {
        'Nested DRO': 'blue',
        'OR-WDRO': 'orange',
        'UOT-DRO': 'green',
        'Standard DRO': 'tab:blue'
    }

    # Small horizontal shifts so error bars of different algorithms do not overlap.
    offsets = {'Nested DRO': -5, 'OR-WDRO': -2, 'UOT-DRO': 2, 'Standard DRO': 5}

    for algo in algo_names:
        means = []
        std_errs = []

        for n in n_samples_list:
            runs = np.asarray(results_dict[algo][metric_key][n])
            if runs.size == 0:
                # Nothing recorded for this cell: leave a gap instead of
                # triggering empty-mean / divide-by-zero warnings.
                means.append(np.nan)
                std_errs.append(np.nan)
                continue
            means.append(np.mean(runs))
            # Divide by the number of runs actually recorded, not by the
            # length of the global seed list.
            std_errs.append(np.std(runs) / np.sqrt(runs.size))

        x_vals = np.array(n_samples_list) + offsets.get(algo, 0)

        ax.errorbar(x_vals, means, yerr=std_errs, label=algo,
                    marker='o', capsize=5, linestyle='-', color=colors[algo],
                    alpha=0.8, linewidth=3.2)

    ax.set_xlabel("Number of Samples", fontsize=24)
    ax.set_ylabel(y_label, fontsize=24)
    ax.tick_params(axis='both', labelsize=14)
    ax.grid(True, alpha=0.3)
    # Ticks sit at the true sample sizes, not at the shifted x positions.
    ax.set_xticks(n_samples_list)
    for spine in ax.spines.values():
        spine.set_linewidth(1.8)

    if log_scale:
        ax.set_yscale('log')

    ax.legend(fontsize=20)


# --- Draw the excess-risk figure on its own canvas ---
fig, ax = plt.subplots(figsize=(10, 7))
plot_metric_vs_samples(ax, 'risk', results, y_label="Excess Risk", log_scale=False)

# Tighten the layout, write the PDF, then open the interactive window.
fig.tight_layout()
plt.savefig('excess_risk_5D_plot.pdf', format='pdf', bbox_inches='tight')
plt.show()