import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools


# Hyperparameter grid search for OR-WDRO; excess risk is averaged over dimensions 10 and 50.
# ==========================================
# 0. Dataset and Evaluation Function
# ==========================================
class KillerDatasetHighDim:
    """Synthetic contaminated linear-regression sample in ``dim`` dimensions.

    The sample concatenates three groups, drawn in this order from NumPy's
    global random state (seed it beforehand for reproducibility):

    * "normal" points (``int(0.7 * n_samples)`` rows): ``x ~ U(-2, 2)`` per
      coordinate, ``y = x @ w_true + b_true + N(0, 1)``.
    * leverage points (``int(0.1 * n_samples)`` rows): ``x`` drawn with
      standard deviation 0.2 around ``10 * w_true / ||w_true||``; labels follow
      the true model with noise standard deviation 0.1.
    * outliers (all remaining rows, 20% for the default size): ``x ~ N(0, 1)``
      per coordinate; labels use the flipped weights ``-w_true``, an extra
      offset of +10 and noise standard deviation 0.5.

    Attributes:
        w_true: ground-truth weight vector of shape ``(dim,)`` with norm 2.
        b_true: ground-truth intercept (1.0).
        X, Y: float32 NumPy arrays of shape ``(n_samples, dim)`` and
            ``(n_samples, 1)``.
        X_t, Y_t: torch tensors created from ``X`` and ``Y``.

    Args:
        n_samples: total number of rows to generate.
        dim: number of features.
    """

    def __init__(self, n_samples=500, dim=10):
        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        # Ground-truth weights: a random direction rescaled to Euclidean norm 2.
        raw_w = np.random.randn(dim)
        self.w_true = (raw_w / np.linalg.norm(raw_w)) * 2.0

        # Group sizes. The outlier group takes the remainder so the three
        # groups always add up to n_samples; truncating 0.2 * n_samples as well
        # silently dropped rows whenever the truncations did not add up
        # (e.g. n_samples=15 produced only 14 rows).
        n_center = int(0.7 * n_samples)
        n_lev = int(0.1 * n_samples)
        n_out = n_samples - n_center - n_lev

        # Clean bulk of the data.
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        noise_norm = np.random.normal(0, 1.0, n_center)
        self.y_norm = self.x_norm @ self.w_true + self.b_true + noise_norm

        # Leverage points: far from the origin along w_true but on the true model.
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + np.random.normal(0, 0.2, (n_lev, dim))
        noise_lev = np.random.normal(0, 0.1, n_lev)
        self.y_lev = self.x_lev @ self.w_true + self.b_true + noise_lev

        # Outliers: labels generated from the opposite weights plus a +10 shift.
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_outlier = -1.0 * self.w_true
        noise_out = np.random.normal(0, 0.5, n_out)
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + noise_out

        self.X = np.concatenate([self.x_norm, self.x_lev, self.x_out]).astype(np.float32)
        self.Y = np.concatenate([self.y_norm, self.y_lev, self.y_out]).astype(np.float32).reshape(-1, 1)
        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the ``(X_t, Y_t)`` tensor pair."""
        return self.X_t, self.Y_t


def calc_excess_risk_high_dim(w_hat, b_hat, w_true, b_true, dim, n_test=20000):
    """Monte-Carlo estimate of the excess mean-absolute-error of an estimate.

    A fresh outlier-free test sample is drawn from NumPy's global random state
    (bulk points followed by leverage points), and the mean absolute error of
    ``(w_hat, b_hat)`` is compared with that of the true model on it.

    Args:
        w_hat: estimated weights; a torch tensor or anything ``np.array``
            accepts. It is flattened to one dimension.
        b_hat: estimated intercept; a torch scalar tensor or a number.
        w_true: true weight vector (used with NumPy operations).
        b_true: true intercept.
        dim: number of features of the generated test points.
        n_test: total number of test points.

    Returns:
        ``MAE(w_hat, b_hat) - MAE(w_true, b_true)`` on the generated sample.
    """
    # Bring the estimate into plain NumPy / Python form.
    coef = w_hat.detach().cpu().numpy() if isinstance(w_hat, torch.Tensor) else w_hat
    coef = np.array(coef).reshape(-1)
    intercept = b_hat.detach().cpu().item() if isinstance(b_hat, torch.Tensor) else b_hat

    # NOTE(review): 7/9 differs from the 7:1 bulk-to-leverage ratio used by the
    # training generator in this file -- confirm this mix is intended.
    bulk_count = int(n_test * (7 / 9))
    lev_count = n_test - bulk_count

    # Bulk points: uniform features, unit-variance label noise.
    bulk_x = np.random.uniform(-2, 2, (bulk_count, dim))
    bulk_noise = np.random.normal(0, 1.0, bulk_count)
    bulk_y = bulk_x @ w_true + b_true + bulk_noise

    # Leverage points: tight cluster at distance 10 along the true direction.
    unit_w = w_true / np.linalg.norm(w_true)
    lev_x = unit_w * 10.0 + np.random.normal(0, 0.2, (lev_count, dim))
    lev_noise = np.random.normal(0, 0.1, lev_count)
    lev_y = lev_x @ w_true + b_true + lev_noise

    features = np.concatenate([bulk_x, lev_x])
    targets = np.concatenate([bulk_y, lev_y])

    mae_hat = np.mean(np.abs(features @ coef + intercept - targets))
    mae_star = np.mean(np.abs(features @ w_true + b_true - targets))
    return mae_hat - mae_star


# ==========================================
# 1. Modified OR-WDRO Component
# ==========================================
class OR_WDRO_LinearRegression(nn.Module):
    """Linear predictor plus the scalar variables optimised alongside it.

    Besides the ``dim -> 1`` linear layer (weights and bias start at zero),
    the module owns three trainable scalars that the training routine in this
    file reads directly:

    * ``pre_lambda1`` / ``pre_lambda2`` -- initialised to 0.55 and passed
      through softplus by the trainer to obtain positive multipliers.
    * ``alpha`` -- initialised to 0.0 and used by the trainer as the threshold
      inside its ReLU / CVaR-style term.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, 1)
        # Start from the all-zero predictor instead of the default random init.
        for param in self.linear.parameters():
            nn.init.zeros_(param)

        self.pre_lambda1 = nn.Parameter(torch.tensor(0.55))
        self.pre_lambda2 = nn.Parameter(torch.tensor(0.55))
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        prediction = self.linear(x)
        return prediction


def get_robust_estimates_tuned(X, Y, epsilon, sigma_multiplier):
    """Compute a trimmed location estimate and a scaled spread for ``[X, Y]``.

    ``X`` and ``Y`` are concatenated along dim 1, every column is sorted
    independently, and ``int(N * 2 * epsilon)`` values are cut from each end
    of every column before averaging (coordinate-wise trimming).

    Args:
        X: 2-D feature tensor.
        Y: 2-D target tensor with the same number of rows as ``X``; its last
            column becomes the target part of the centre.
        epsilon: contamination level; the per-tail trim fraction is
            ``2 * epsilon``.
        sigma_multiplier: factor applied to the trimmed second moment
            (replaces a previously hard-coded 8.0).

    Returns:
        Tuple ``(z0_x, z0_y, sigma_sq)``: the feature centre with a leading
        unit dimension, the target centre as a one-element tensor, and the
        scalar ``max(second_moment * sigma_multiplier, 1.0)``.
    """
    joint = torch.cat([X, Y], dim=1)
    n_rows = joint.shape[0]

    # Number of entries removed from each tail of every column.
    cut = int(n_rows * (2 * epsilon))

    per_column_sorted = torch.sort(joint, dim=0).values

    # NOTE(review): when cut >= N // 2 (i.e. roughly epsilon >= 0.25) no
    # trimming happens at all and the plain mean is used -- confirm that this
    # fallback is intended for the larger epsilon values.
    trimming_applies = 0 < cut < n_rows // 2
    kept = per_column_sorted[cut:n_rows - cut, :] if trimming_applies else per_column_sorted

    center = kept.mean(dim=0)
    # Mean over rows of the squared distance to the centre.
    spread = (kept - center).pow(2).sum(dim=1).mean()

    # Scale by the tunable multiplier and floor the result at 1.0.
    sigma_sq = torch.maximum(spread * sigma_multiplier, torch.tensor(1.0))

    return center[:-1].unsqueeze(0), center[-1].unsqueeze(0), sigma_sq


def train_or_wdro_tuned(X, Y, dim, epochs=1000, lr=0.01, epsilon=0.4, sigma_multiplier=8.0, rho=1.0):
    """Fit a linear model with SGD on the OR-WDRO objective defined below.

    Each epoch performs one full-batch step on

        lam1 * sigma_sq + lam2 * rho**2
        + alpha + mean(relu(sup_val - alpha)) / (1 - epsilon)

    where ``lam1``/``lam2`` are the softplus of the model's ``pre_lambda``
    parameters, ``sigma_sq`` and the centre ``(z0_x, z0_y)`` come from
    ``get_robust_estimates_tuned``, and ``sup_val`` is the absolute residual
    at the lambda-weighted mix of each sample with the centre, plus a term
    in the weight norm, minus a term in the sample-to-centre distance.
    Gradients are clipped to norm 1.0 and the learning rate is multiplied by
    0.1 halfway through training.

    Args:
        X: feature tensor; combined with ``Y`` via ``torch.cat([X, Y], dim=1)``.
        Y: target tensor with one column.
        dim: number of input features of the linear model.
        epochs: number of full-batch SGD steps.
        lr: initial learning rate.
        epsilon: contamination level; used for the robust estimates and in
            the ``1 / (1 - epsilon)`` factor, so it must be below 1.
        sigma_multiplier: multiplier forwarded to the robust spread estimate.
        rho: radius whose square weights ``lam2`` in the objective.

    Returns:
        dict with ``'w'`` (NumPy array, the linear layer's weight matrix) and
        ``'b'`` (Python float, its bias).
    """
    # Robust centre / spread are fixed inputs of the optimisation.
    z0_x, z0_y, sigma_sq = get_robust_estimates_tuned(X, Y, epsilon, sigma_multiplier)
    z0_x = z0_x.detach()
    z0_y = z0_y.detach()
    sigma_sq = sigma_sq.detach()

    model = OR_WDRO_LinearRegression(dim)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # Decay the learning rate once, halfway through. StepLR needs a positive
    # step_size; int(epochs / 2) was 0 for epochs < 2, which made the first
    # scheduler step fail with a modulo by zero.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 2), gamma=0.1)

    # Squared distance of every sample to the robust centre. It depends only
    # on the data, so it is computed once instead of on every epoch.
    dist_sq = torch.sum((X - z0_x) ** 2, dim=1, keepdim=True) + (Y - z0_y) ** 2

    for epoch in range(epochs):
        # Softplus keeps both multipliers positive; the small constant guards
        # the divisions by gamma.
        lam1 = torch.nn.functional.softplus(model.pre_lambda1)
        lam2 = torch.nn.functional.softplus(model.pre_lambda2)
        gamma = lam1 + lam2 + 1e-6

        optimizer.zero_grad()
        # Lambda-weighted combination of the centre and each sample.
        bar_X = (lam1 * z0_x + lam2 * X) / gamma
        bar_Y = (lam1 * z0_y + lam2 * Y) / gamma

        theta = model.linear.weight
        bias = model.linear.bias

        linear_val = (torch.matmul(bar_X, theta.T) + bias) - bar_Y
        w_norm_sq = torch.sum(theta ** 2) + 1.0
        conjugate_term = w_norm_sq / (4.0 * gamma)

        correction_term = (lam1 * lam2 / gamma) * dist_sq

        sup_val = torch.abs(linear_val) + conjugate_term - correction_term
        relu_part = torch.relu(sup_val - model.alpha)

        # epsilon also sets the tail weight of the CVaR-style term.
        cvar_loss = model.alpha + torch.mean(relu_part) / (1.0 - epsilon)

        total_loss = lam1 * sigma_sq + lam2 * (rho ** 2) + cvar_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    return {
        'w': model.linear.weight.detach().cpu().numpy(),
        'b': model.linear.bias.detach().cpu().item()
    }


def grid_search_or_wdro_universal():
    """Grid-search ``epsilon`` and ``sigma_multiplier`` across dimensions.

    For every parameter combination the OR-WDRO model is trained on one
    pre-generated dataset per validation dimension, its excess risk is
    estimated with ``calc_excess_risk_high_dim``, and the risks are averaged
    over the dimensions. A results table is printed along the way.

    Returns:
        dict with keys ``'epsilon'`` and ``'sigma_multiplier'`` for the
        combination with the lowest average risk. The dict is empty if no
        combination produced an average risk below infinity (e.g. all NaN).
    """
    print("=== Starting Universal Grid Search across dimensions ===")

    # 1. Representative dimensions: the two ends of the [10, 50] range.
    test_dims = [10, 50]
    n_samples = 500
    seed_for_data = 42
    # The evaluation sample gets its own fixed seed so that every parameter
    # combination is scored on the same test draw for a given dimension;
    # without reseeding, each evaluation used a different random sample and
    # the comparison between combinations carried extra Monte-Carlo noise.
    seed_for_eval = seed_for_data + 1

    # Pre-generate the training data once per dimension.
    datasets = {}
    for d in test_dims:
        np.random.seed(seed_for_data)
        torch.manual_seed(seed_for_data)
        ds = KillerDatasetHighDim(n_samples=n_samples, dim=d)
        X, Y = ds.get_data()
        datasets[d] = {
            'X': X,
            'Y': Y,
            'w_true': ds.w_true,
            'b_true': ds.b_true
        }

    # 2. Search space.
    # Epsilon: must cover the contamination rate (around 0.3).
    epsilon_list = [0.2, 0.3, 0.4]
    # Sigma multiplier: scales the spread estimate.
    multiplier_list = [4.0, 8.0, 16.0]

    param_grid = list(itertools.product(epsilon_list, multiplier_list))

    # One risk column per validation dimension, derived from test_dims so the
    # table stays correct if the list of dimensions changes.
    risk_labels = [f"Risk(d={d})" for d in test_dims]
    risk_header = " ".join(f"{label:<12}" for label in risk_labels)

    print(f"Validation Dimensions: {test_dims}")
    print(f"Search Combinations: {len(param_grid)}")
    print(f"{'Epsilon':<10} {'Mult':<10} | {'Avg Risk':<15} | {risk_header}")
    print("-" * 75)

    best_avg_risk = float('inf')
    best_params = {}

    # 3. Evaluate every parameter combination on every dimension.
    for eps, mult in param_grid:
        risks = []

        for d in test_dims:
            # Reset the torch seed so training is repeatable.
            torch.manual_seed(seed_for_data)

            data_pack = datasets[d]

            result = train_or_wdro_tuned(
                data_pack['X'], data_pack['Y'], d,
                epochs=1000,
                lr=0.01,
                epsilon=eps,
                sigma_multiplier=mult,
                rho=1.0
            )

            # Same evaluation sample for every combination (see seed_for_eval).
            np.random.seed(seed_for_eval)
            risk = calc_excess_risk_high_dim(
                result['w'], result['b'],
                data_pack['w_true'], data_pack['b_true'],
                d
            )
            risks.append(risk)

        avg_risk = np.mean(risks)

        risk_cells = " ".join(f"{f'{r:.4f}':<12}" for r in risks)
        print(f"{eps:<10} {mult:<10} | {avg_risk:.4f}          | {risk_cells}")

        if avg_risk < best_avg_risk:
            best_avg_risk = avg_risk
            best_params = {'epsilon': eps, 'sigma_multiplier': mult}

    print("-" * 75)
    print(f"Best Universal Parameters: {best_params}")
    print(f"Lowest Average Risk: {best_avg_risk:.4f}")

    return best_params

# Run the grid search and keep the winning parameter dict in `best_config`.
# NOTE(review): this executes at module level, so the full search (and its
# printed table) also runs whenever this file is imported -- consider an
# `if __name__ == "__main__":` guard if nothing relies on importing `best_config`.
best_config = grid_search_or_wdro_universal()