import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools
from tqdm import tqdm


# ==========================================
# 0. Dataset
# ==========================================
class KillerDatasetHighDim:
    """Synthetic linear-regression sample with leverage points and outliers.

    All randomness comes from NumPy's global RNG (seed it beforehand for
    reproducibility). The sample is built from three groups:

    * ``int(0.7 * n_samples)`` bulk points: x ~ U(-2, 2)^dim,
      y = x @ w_true + b_true + N(0, 1) noise;
    * ``int(0.1 * n_samples)`` high-leverage points scattered (std 0.2) around
      ``10 * w_true / |w_true|``, labelled by the true model with N(0, 0.1) noise;
    * ``int(0.2 * n_samples)`` outliers: x ~ N(0, 1)^dim labelled by the
      sign-flipped model y = -x @ w_true + b_true + 10 + N(0, 0.5) noise.

    ``w_true`` is a random direction rescaled to norm 2.0 and ``b_true`` is 1.0.
    Because each group size is truncated separately, the total row count can be
    slightly below ``n_samples``.
    """

    def __init__(self, n_samples=500, dim=5):
        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        raw = np.random.randn(dim)
        self.w_true = (raw / np.linalg.norm(raw)) * 2.0
        w, b = self.w_true, self.b_true

        n_center = int(0.7 * n_samples)
        n_lev = int(0.1 * n_samples)
        n_out = int(0.2 * n_samples)

        # Bulk of well-behaved points.
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        self.y_norm = self.x_norm @ w + b + np.random.normal(0, 1.0, n_center)

        # Tight cluster far out along the true weight direction.
        unit = w / np.linalg.norm(w)
        self.x_lev = unit * 10.0 + np.random.normal(0, 0.2, (n_lev, dim))
        self.y_lev = self.x_lev @ w + b + np.random.normal(0, 0.1, n_lev)

        # Contamination generated by the flipped model with a +10 offset.
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_flipped = -1.0 * w
        self.y_out = self.x_out @ w_flipped + b + 10.0 + np.random.normal(0, 0.5, n_out)

        stacked_x = np.concatenate((self.x_norm, self.x_lev, self.x_out))
        stacked_y = np.concatenate((self.y_norm, self.y_lev, self.y_out))
        self.X = stacked_x.astype(np.float32)
        self.Y = stacked_y.astype(np.float32).reshape(-1, 1)

        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the full contaminated sample as float32 tensors; Y has shape (N, 1)."""
        return self.X_t, self.Y_t

    def get_clean_data(self):
        """Return bulk + leverage points (no outliers) as float64 arrays; Y is 1-D."""
        return (
            np.concatenate([self.x_norm, self.x_lev]),
            np.concatenate([self.y_norm, self.y_lev]),
        )


# ==========================================
# 1. OR-WDRO components
# ==========================================
class OR_WDRO_LinearRegression(nn.Module):
    """Linear regressor bundled with the scalar dual variables of the OR-WDRO loss.

    Attributes:
        linear: ``nn.Linear(dim, 1)`` whose weight and bias start at zero.
        pre_lambda1, pre_lambda2: unconstrained scalars initialised to 0.55;
            ``train_or_wdro_tuned`` passes them through softplus to obtain
            positive multipliers.
        alpha: scalar CVaR threshold, initialised to 0.0.
    """

    def __init__(self, dim):
        super().__init__()
        layer = nn.Linear(dim, 1)
        with torch.no_grad():
            layer.weight.zero_()
            layer.bias.zero_()
        self.linear = layer

        self.pre_lambda1 = nn.Parameter(torch.tensor(0.55))
        self.pre_lambda2 = nn.Parameter(torch.tensor(0.55))
        self.alpha = nn.Parameter(torch.tensor(0.0))


def get_robust_estimates_high_dim(X, Y, epsilon, var_mult):
    """Estimate a robust centre and variance budget for the joint data ``[X, Y]``.

    Each column of ``[X, Y]`` is sorted independently. Then ``int(N * 2 * epsilon)``
    values are dropped from each tail of every column (a coordinate-wise trimmed
    mean). The number trimmed per tail is capped at ``(N - 1) // 2`` so at least
    one row always survives.

    Args:
        X: feature tensor; concatenated with ``Y`` along dim 1, so both must be
            2-D with the same number of rows.
        Y: target tensor, presumably of shape (N, 1) — TODO confirm for other callers.
        epsilon: assumed contamination fraction; each tail trims ``2 * epsilon``
            of the rows.
        var_mult: multiplier applied to the trimmed second moment (formerly a
            hard-coded 4.0).

    Returns:
        ``(z0_x, z0_y, sigma_sq)``: the trimmed mean of the X columns with shape
        (1, d_x), the trimmed mean of the last column with shape (1,), and
        ``max(var_mult * second_moment, 1.0)`` as a scalar tensor.
    """
    data = torch.cat([X, Y], dim=1)
    n_rows = data.shape[0]

    trim_ratio = 2 * epsilon
    n_trim = int(n_rows * trim_ratio)
    # Previously, n_trim >= N // 2 silently disabled trimming altogether, so a
    # *larger* contamination estimate produced a *less* robust (plain-mean)
    # estimate. Cap it instead so at least one row remains.
    n_trim = max(0, min(n_trim, (n_rows - 1) // 2))

    sorted_data, _ = torch.sort(data, dim=0)
    valid_data = sorted_data[n_trim:n_rows - n_trim, :]

    z0 = torch.mean(valid_data, dim=0)
    diff = valid_data - z0
    # Columns were sorted independently, so rows no longer correspond to
    # samples. The mean squared norm still equals the sum of per-column
    # trimmed variances, which is what is wanted here.
    second_moment = torch.mean(torch.sum(diff ** 2, dim=1))

    # Floor at 1.0 so the variance budget never collapses. clamp keeps the
    # input's dtype/device, unlike comparing against a fresh CPU tensor.
    sigma_sq = torch.clamp(second_moment * var_mult, min=1.0)

    return z0[:-1].unsqueeze(0), z0[-1].unsqueeze(0), sigma_sq


def train_or_wdro_tuned(X, Y, dim, epsilon, var_mult, epochs=1000, lr=0.01, rho=1.0):
    """Fit a linear model by minimising the OR-WDRO dual objective with full-batch SGD.

    Args:
        X: feature tensor, assumed (N, dim); it is multiplied by the (1, dim)
            weight of ``nn.Linear(dim, 1)``.
        Y: target tensor, presumably (N, 1) — TODO confirm.
        dim: number of input features.
        epsilon: contamination fraction, used for the robust centre/variance
            estimate and for the CVaR denominator ``1 - epsilon``.
        var_mult: multiplier for the robust second-moment estimate.
        epochs: number of SGD steps; the learning rate is divided by 10 after
            ``epochs // 2`` steps (at least 1).
        lr: initial SGD learning rate.
        rho: Wasserstein radius; the objective uses ``rho ** 2``. The default
            1.0 matches the previously hard-coded value.

    Returns:
        ``(weight, bias)``: the learned weight as a NumPy array of shape
        (1, dim) and the bias as a Python float.
    """
    z0_x, z0_y, sigma_sq = get_robust_estimates_high_dim(X, Y, epsilon, var_mult)
    z0_x, z0_y, sigma_sq = z0_x.detach(), z0_y.detach(), sigma_sq.detach()

    model = OR_WDRO_LinearRegression(dim)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # StepLR takes `last_epoch % step_size`, so step_size must be >= 1.
    # int(epochs / 2) was 0 for epochs == 1 and crashed on the first step.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 2), gamma=0.1)

    # Squared distance of every sample to the robust centre. It does not
    # depend on any trainable parameter, so compute it once outside the loop.
    dist_sq = torch.sum((X - z0_x) ** 2, dim=1, keepdim=True) + (Y - z0_y) ** 2
    rho_sq = rho ** 2
    cvar_denom = 1.0 - epsilon + 1e-6

    for _ in range(epochs):
        # softplus keeps both multipliers positive.
        lam1 = torch.nn.functional.softplus(model.pre_lambda1)
        lam2 = torch.nn.functional.softplus(model.pre_lambda2)
        gamma = lam1 + lam2 + 1e-6

        optimizer.zero_grad()
        # Pull every sample towards the robust centre with weights lam1 : lam2.
        bar_X = (lam1 * z0_x + lam2 * X) / gamma
        bar_Y = (lam1 * z0_y + lam2 * Y) / gamma

        theta = model.linear.weight
        bias = model.linear.bias
        linear_val = (torch.matmul(bar_X, theta.T) + bias) - bar_Y

        # ||(theta, -1)||^2 / (4 * gamma)
        w_norm_sq = torch.sum(theta ** 2) + 1.0
        conjugate_term = w_norm_sq / (4.0 * gamma)

        correction_term = (lam1 * lam2 / gamma) * dist_sq
        sup_val = torch.abs(linear_val) + conjugate_term - correction_term
        relu_part = torch.relu(sup_val - model.alpha)

        # Rockafellar-Uryasev CVaR form: alpha + E[(s - alpha)_+] / (1 - epsilon).
        cvar_loss = model.alpha + torch.mean(relu_part) / cvar_denom

        total_loss = lam1 * sigma_sq + lam2 * rho_sq + cvar_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    return model.linear.weight.detach().cpu().numpy(), model.linear.bias.detach().cpu().item()


def calc_excess_risk(w_hat, b_hat, w_true, b_true, X_clean, Y_clean):
    """Return the mean-absolute-error of (w_hat, b_hat) minus that of (w_true, b_true).

    Both errors are measured on ``(X_clean, Y_clean)``. The two weight vectors
    are handled identically. Each is flattened, so row (1, d), column (d, 1)
    and 1-D (d,) weights all produce an (n,) prediction. ``Y_clean`` may be
    (n,) or (n, 1).

    Returns:
        A NumPy float; positive when the estimate is worse than the true model.
    """
    y = np.asarray(Y_clean).reshape(-1)

    def _mae(w, b):
        # Without the flatten, an (n, 1) prediction minus the (n,) target
        # would broadcast to an (n, n) matrix and give a meaningless risk.
        pred = X_clean @ np.asarray(w).reshape(-1) + b
        return np.mean(np.abs(pred - y))

    return _mae(w_hat, b_hat) - _mae(w_true, b_true)


# ==========================================
# 2. Grid search
# ==========================================
if __name__ == '__main__':
    # Grid-search (epsilon, var_mult) for OR-WDRO on the synthetic contaminated
    # dataset and report the combination with the lowest excess risk on the
    # clean points.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    dim = 5
    n_samples = 800

    print(f"🔄 Generating Dataset (N={n_samples}, Dim={dim})...")
    dataset = KillerDatasetHighDim(n_samples=n_samples, dim=dim)
    X_t, Y_t = dataset.get_data()
    X_clean, Y_clean = dataset.get_clean_data()

    # === Search space ===
    # epsilon: the true contamination rate is 0.2 (the outlier group is 20% of
    #          the sample), so the grid concentrates around 0.2.
    # var_mult: the old hard-coded value was 4.0; also try a tighter (1.0) and a
    #           looser (8.0) variance bound.
    param_grid = {
        'epsilon': [0.0, 0.05, 0.1, 0.2, 0.4],
        'var_mult': [1.0, 2.0, 4.0, 8.0]
    }

    keys, values = zip(*param_grid.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

    print(f"🔍 Starting Grid Search with {len(combinations)} combinations...")
    print(f"ℹ️  Fixed Rho = 1.0")
    print("-" * 60)
    print(f"{'epsilon':<10} | {'var_mult':<10} | {'Excess Risk':<15}")
    print("-" * 60)

    best_risk = float('inf')
    best_params = {}
    failures = []

    for params in tqdm(combinations, leave=False):
        try:
            w_hat, b_hat = train_or_wdro_tuned(
                X_t, Y_t, dim,
                epsilon=params['epsilon'],
                var_mult=params['var_mult'],  # variance scaling factor
                epochs=1500
            )
            risk = calc_excess_risk(w_hat, b_hat, dataset.w_true, dataset.b_true, X_clean, Y_clean)
        except Exception as e:
            # Keep sweeping, but never hide why a combination failed.
            failures.append((params, e))
            tqdm.write(f"{params['epsilon']:<10} | {params['var_mult']:<10} | FAILED: {e!r}")
            continue

        # tqdm.write prints the table row without breaking the progress bar.
        tqdm.write(f"{params['epsilon']:<10} | {params['var_mult']:<10} | {risk:.4f}")

        # A diverged run can yield NaN/inf; never select it as "best".
        if np.isfinite(risk) and risk < best_risk:
            best_risk = risk
            best_params = params

    print("-" * 60)
    if not best_params:
        print(f"❌ No combination produced a finite excess risk ({len(failures)} failed).")
    else:
        print(f"✅ Best Parameters Found:")
        print(f"   epsilon  = {best_params['epsilon']}")
        print(f"   var_mult = {best_params['var_mult']}")
        print(f"   excess risk = {best_risk:.4f}")
        if failures:
            print(f"⚠️  {len(failures)} combination(s) failed during training.")