import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools


# Grid-search risk is averaged over the dim=10 and dim=50 settings.
# ==========================================
# 0. 数据集定义 (Killer Dataset)
# ==========================================
class KillerDatasetHighDim:
    """Synthetic linear-regression dataset mixing clean data, leverage points and outliers.

    The rows are built from three groups, concatenated in this order:

    1. Bulk (70%): features uniform on [-2, 2], targets from the true linear
       model plus N(0, 1) noise.
    2. Far leverage points (20%): features clustered around ``10 * direction``
       (``direction`` is the unit vector along ``w_true``), targets from the
       true model plus N(0, 0.1) noise.
    3. Near outliers (remaining ~10%): features from N(0, 1), targets generated
       with the *negated* weight vector, shifted by +10, plus N(0, 0.5) noise.

    The 70/20 split between bulk and leverage points matches the 7:2 mix that
    ``calc_excess_risk_high_dim`` uses for its clean test distribution.

    All randomness comes from the global ``np.random`` state, so seed it
    before constructing the object to get reproducible data.

    Attributes:
        n_samples: Total number of rows requested (and produced).
        dim: Feature dimension.
        w_true: True weight vector, random direction rescaled to L2 norm 2.
        b_true: True intercept (1.0).
        X, Y: float32 NumPy arrays of shape (n_samples, dim) and (n_samples, 1).
        X_t, Y_t: Torch tensors created from ``X`` and ``Y`` via ``from_numpy``.
    """

    def __init__(self, n_samples=500, dim=10):
        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        # True weights: random direction rescaled to norm 2.
        raw_w = np.random.randn(dim)
        self.w_true = (raw_w / np.linalg.norm(raw_w)) * 2.0

        # Group sizes. The outlier group takes the remainder so that the three
        # groups always add up to exactly n_samples, even when the fractions
        # do not divide n_samples evenly.
        n_center = int(0.7 * n_samples)
        n_lev = int(0.2 * n_samples)
        n_out = n_samples - n_center - n_lev

        # 1. Bulk background data (70%).
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        noise_norm = np.random.normal(0, 1.0, n_center)
        self.y_norm = self.x_norm @ self.w_true + self.b_true + noise_norm

        # 2. Far high-leverage points (20%), placed along the true direction.
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + np.random.normal(0, 0.2, (n_lev, dim))
        noise_lev = np.random.normal(0, 0.1, n_lev)
        self.y_lev = self.x_lev @ self.w_true + self.b_true + noise_lev

        # 3. Near outliers (remaining ~10%): opposite-sign weights, +10 shift.
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_outlier = -1.0 * self.w_true
        noise_out = np.random.normal(0, 0.5, n_out)
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + noise_out

        # Stack the three groups: bulk, leverage, outliers.
        self.X = np.concatenate([self.x_norm, self.x_lev, self.x_out]).astype(np.float32)
        self.Y = np.concatenate([self.y_norm, self.y_lev, self.y_out]).astype(np.float32).reshape(-1, 1)

        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the ``(X_t, Y_t)`` torch tensors."""
        return self.X_t, self.Y_t


# ==========================================
# Excess Risk 计算工具
# ==========================================
def calc_excess_risk_high_dim(w_hat, b_hat, w_true, b_true, dim, n_test=20000):
    """Estimate the excess mean-absolute-error of a linear model on clean test data.

    A fresh test sample is drawn from the global ``np.random`` state: roughly
    7/9 of the points are "bulk" points (uniform features on [-2, 2], N(0, 1)
    target noise) and the rest are leverage points clustered around
    ``10 * w_true / ||w_true||`` (N(0, 0.2) feature jitter, N(0, 0.1) target
    noise). No outliers are included.

    Args:
        w_hat: Estimated weights (torch tensor or array-like); flattened to 1-D.
        b_hat: Estimated intercept (torch tensor or Python/NumPy scalar).
        w_true: True weight vector used to generate the test targets.
        b_true: True intercept.
        dim: Feature dimension of the generated test points.
        n_test: Total number of test points.

    Returns:
        Mean absolute error of ``(w_hat, b_hat)`` minus the mean absolute error
        of the true parameters ``(w_true, b_true)`` on the same test sample.
    """
    # Bring the estimate into plain NumPy / Python form.
    weights = w_hat.detach().cpu().numpy() if isinstance(w_hat, torch.Tensor) else w_hat
    weights = np.array(weights).reshape(-1)
    intercept = b_hat.detach().cpu().item() if isinstance(b_hat, torch.Tensor) else b_hat

    bulk_count = int(n_test * (7 / 9))
    leverage_count = n_test - bulk_count

    # Bulk part of the test sample.
    bulk_x = np.random.uniform(-2, 2, (bulk_count, dim))
    bulk_noise = np.random.normal(0, 1.0, bulk_count)
    bulk_y = bulk_x @ w_true + b_true + bulk_noise

    # Leverage part, centred far out along the true weight direction.
    leverage_centre = (w_true / np.linalg.norm(w_true)) * 10.0
    leverage_x = leverage_centre + np.random.normal(0, 0.2, (leverage_count, dim))
    leverage_noise = np.random.normal(0, 0.1, leverage_count)
    leverage_y = leverage_x @ w_true + b_true + leverage_noise

    features = np.concatenate([bulk_x, leverage_x])
    targets = np.concatenate([bulk_y, leverage_y])

    def mean_abs_error(w, b):
        # L1 loss of the linear predictor (w, b) on the generated sample.
        return np.mean(np.abs(features @ w + b - targets))

    return mean_abs_error(weights, intercept) - mean_abs_error(w_true, b_true)


# ==========================================
# 核心算法: UOT-DRO (参数化版本)
# ==========================================
def train_uot_dro_tuned(X, Y, dim, epochs=1000, lr=0.01, lam=10.0, beta=10.0, lam2=10.0):
    """Fit a linear model with the UOT-DRO log-sum-exp objective using full-batch SGD.

    Per epoch the loss is
    ``lam * beta * log(sum_i exp((|pred_i - y_i| - lam2 * prior_i) / (lam * beta)))``,
    where ``prior_i`` is the Euclidean distance of the concatenated
    ``(x_i, y_i)`` row from the coordinate-wise median of the data. Gradients
    are clipped to a total norm of 5.0 before every step.

    Args:
        X: 2-D feature tensor (samples x features).
        Y: Target tensor; a 1-D tensor is reshaped to a column.
        dim: Not read by this function; kept for the existing call signature.
        epochs: Number of full-batch SGD steps.
        lr: SGD learning rate.
        lam: Scale factor of the objective (the original notes call it
            global regularization / radius control).
        beta: Multiplies ``lam`` in the exponent scale (the original notes
            call it the softmax temperature).
        lam2: Weight of the distance-from-median prior subtracted from the
            absolute error.

    Returns:
        Dict with ``'w'`` (NumPy array of the weight matrix, shape
        ``(1, n_features)``) and ``'b'`` (Python float bias).
    """
    targets = Y.unsqueeze(1) if Y.dim() == 1 else Y
    _, feature_count = X.shape
    model = nn.Linear(feature_count, 1)

    # Start every run from the all-zero model so runs are comparable.
    with torch.no_grad():
        for parameter in (model.weight, model.bias):
            parameter.fill_(0.0)

    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Prior: distance of each (x, y) row from the coordinate-wise median.
    with torch.no_grad():
        x_offsets = X - torch.median(X, dim=0).values
        y_offsets = targets - torch.median(targets, dim=0).values
        norm_prior = torch.cat([x_offsets, y_offsets], dim=1).norm(p=2, dim=1, keepdim=True)

    for _ in range(epochs):
        optimizer.zero_grad()
        residual = torch.abs(model(X) - targets)

        # Log-sum-exp objective; subtracting the (detached) maximum keeps the
        # exponentials from overflowing.
        scaled = (residual - lam2 * norm_prior) / (lam * beta)
        shift = scaled.max().detach()
        log_sum = torch.log(torch.sum(torch.exp(scaled - shift)))
        loss = lam * beta * (log_sum + shift)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

    fitted_weight = model.weight.detach().cpu().numpy()
    fitted_bias = model.bias.detach().cpu().item()
    return {'w': fitted_weight, 'b': fitted_bias}


# ==========================================
# 通用网格搜索 (Universal Grid Search)
# ==========================================
def grid_search_uot_dro_universal():
    """Grid-search ``lam`` / ``beta`` / ``lam2`` for UOT-DRO across several dimensions.

    One fixed ``KillerDatasetHighDim`` training set is generated per dimension
    in ``test_dims``. Every hyperparameter combination is trained on each of
    those sets with ``train_uot_dro_tuned`` and scored with
    ``calc_excess_risk_high_dim``; the combination with the lowest risk
    averaged over the dimensions wins. Progress and results are printed.

    The NumPy RNG is reseeded before every evaluation so that all
    combinations are scored on the same random test sample for a given
    dimension; otherwise differences between combinations would partly be
    test-set sampling noise. Note that this function reseeds the global
    ``np.random`` and ``torch`` RNGs as a side effect.

    Returns:
        Dict with keys ``'lam'``, ``'beta'`` and ``'lam2'`` holding the best
        combination found.
    """
    print("=== 开始 UOT-DRO 通用参数搜索 (Universal Grid Search) ===")

    # 1. Build one fixed training set per dimension; risk is averaged over them.
    test_dims = [10, 50]
    n_samples = 500
    seed_for_data = 42
    # Separate seed for the evaluation sample so it does not replay the random
    # stream that produced the training data.
    seed_for_eval = seed_for_data + 1

    datasets = {}
    print("生成固定验证数据集...")
    for d in test_dims:
        np.random.seed(seed_for_data)
        torch.manual_seed(seed_for_data)
        ds = KillerDatasetHighDim(n_samples=n_samples, dim=d)
        X, Y = ds.get_data()
        datasets[d] = {
            'X': X,
            'Y': Y,
            'w_true': ds.w_true,
            'b_true': ds.b_true
        }
        print(f"  -> Dim={d} 生成完毕")

    # 2. Search grid.
    param_grid = {
        'lam': [1.0, 5.0, 10.0, 20.0],
        'beta': [1.0, 10.0],
        'lam2': [1.0, 10.0, 20.0, 50.0]
    }

    keys, values = zip(*param_grid.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

    # One risk column per tested dimension, derived from test_dims.
    dim_labels = [f"Risk(d={d})" for d in test_dims]
    dim_header = " ".join(f"{label:<12}" for label in dim_labels)

    print("-" * 85)
    print(f"待测组合总数: {len(combinations)}")
    print(f"{'lam':<6} {'beta':<6} {'lam2':<6} | {'Avg Risk':<12} | {dim_header}")
    print("-" * 85)

    best_avg_risk = float('inf')
    best_params = {}

    # 3. Evaluate every combination on every dimension.
    for params in combinations:
        risks = []

        for d in test_dims:
            # Reset the torch seed so each training run is reproducible.
            torch.manual_seed(seed_for_data)

            data_pack = datasets[d]

            result = train_uot_dro_tuned(
                data_pack['X'], data_pack['Y'], d,
                epochs=1000,
                lr=0.01,
                lam=params['lam'],
                beta=params['beta'],
                lam2=params['lam2']
            )

            # Reseed NumPy so every combination is scored on the identical
            # test sample for this dimension (the evaluator draws its data
            # from the global np.random state).
            np.random.seed(seed_for_eval)
            risk = calc_excess_risk_high_dim(
                result['w'], result['b'],
                data_pack['w_true'], data_pack['b_true'],
                d
            )
            risks.append(risk)

        avg_risk = np.mean(risks)

        risk_columns = " ".join(f"{f'{r:.4f}':<12}" for r in risks)
        print(
            f"{params['lam']:<6} {params['beta']:<6} {params['lam2']:<6} | {avg_risk:.4f}       | {risk_columns}")

        # Keep the combination with the lowest average risk seen so far.
        if avg_risk < best_avg_risk:
            best_avg_risk = avg_risk
            best_params = params

    print("-" * 85)
    print(f"【最优通用参数】: {best_params}")
    print(f"【最低平均 Risk】: {best_avg_risk:.4f}")

    return best_params


# ==========================================
# 主程序入口
# ==========================================
if __name__ == "__main__":
    # Run the grid search only when executed as a script; the returned best
    # hyperparameter dict is stored in best_config and not used further here.
    best_config = grid_search_uot_dro_universal()