import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools
from tqdm import tqdm


# ==========================================
# 0. Dataset Definition (Reusing KillerDatasetHighDim)
# ==========================================
class KillerDatasetHighDim:
    """Synthetic linear-regression dataset with leverage points and outliers.

    The data are drawn from the global ``np.random`` state (seed it before
    construction for reproducibility) and consist of three groups:

    * ~70% "normal" points: ``x`` uniform in [-2, 2]^dim, ``y`` from the true
      linear model plus N(0, 1) noise.
    * ~10% high-leverage points: ``x`` clustered (std 0.2) around a point at
      distance 10 along the true weight direction, ``y`` from the true model
      plus N(0, 0.1) noise.
    * ~20% outliers: ``x`` ~ N(0, 1), ``y`` generated with the *negated* true
      weights, an extra +10 offset and N(0, 0.5) noise.

    Attributes:
        n_samples: Total number of rows in ``X`` / ``Y``.
        dim: Number of features.
        w_true: True weight vector, shape ``(dim,)``, Euclidean norm 2.
        b_true: True intercept (1.0).
        X, Y: float32 arrays of shape ``(n_samples, dim)`` / ``(n_samples, 1)``
            ordered as [normal, leverage, outlier].
        X_t, Y_t: Torch tensors sharing memory with ``X`` / ``Y``.

    Raises:
        ValueError: If ``n_samples`` is negative or ``dim`` is smaller than 1.
    """

    def __init__(self, n_samples=500, dim=5):  # Default N=500 for hyperparameter tuning
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if dim < 1:
            # dim == 0 would normalise an empty vector (0 / 0 -> NaN weights).
            raise ValueError(f"dim must be at least 1, got {dim}")

        self.n_samples = n_samples
        self.dim = dim
        self.b_true = 1.0

        # Group sizes. The two minority groups are truncated as before; the
        # majority group takes the remainder so the three groups always add
        # up to exactly n_samples (independent truncation could drop rows,
        # e.g. n_samples=15 used to yield only 14).
        n_lev = int(0.1 * n_samples)
        n_out = int(0.2 * n_samples)
        n_center = n_samples - n_lev - n_out

        # === Define High-Dimensional True Weights ===
        raw_w = np.random.randn(dim)
        self.w_true = (raw_w / np.linalg.norm(raw_w)) * 2.0

        # === 1. Background Noise (Majority, ~70%) ===
        self.x_norm = np.random.uniform(-2, 2, (n_center, dim))
        self.y_norm = self.x_norm @ self.w_true + self.b_true + np.random.normal(0, 1.0, n_center)

        # === 2. Distant High-Leverage Points (~10%) ===
        # Placed far out along the true weight direction and consistent with
        # the true model (small label noise).
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + np.random.normal(0, 0.2, (n_lev, dim))
        self.y_lev = self.x_lev @ self.w_true + self.b_true + np.random.normal(0, 0.1, n_lev)

        # === 3. Proximal Outliers (~20% - Critical Contamination) ===
        # Close to the bulk in x, but labelled by the opposite weights plus a
        # constant shift.
        self.x_out = np.random.normal(0, 1.0, (n_out, dim))
        w_outlier = -1.0 * self.w_true
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + np.random.normal(0, 0.5, n_out)

        # === Data Integration ===
        self.X = np.concatenate([self.x_norm, self.x_lev, self.x_out]).astype(np.float32)
        self.Y = np.concatenate([self.y_norm, self.y_lev, self.y_out]).astype(np.float32).reshape(-1, 1)

        self.X_t = torch.from_numpy(self.X)
        self.Y_t = torch.from_numpy(self.Y)

    def get_data(self):
        """Return the full (contaminated) data as torch tensors ``(X_t, Y_t)``."""
        return self.X_t, self.Y_t

    def get_clean_data(self):
        """Return the uncontaminated subset as NumPy arrays ``(X_clean, Y_clean)``.

        For validation: contains only the normal and high-leverage points
        (outliers excluded). ``Y_clean`` is 1-D; both arrays keep the dtype
        they were generated with (not cast to float32).
        """
        X_clean = np.concatenate([self.x_norm, self.x_lev])
        Y_clean = np.concatenate([self.y_norm, self.y_lev])
        return X_clean, Y_clean


# ==========================================
# 1. UOT-DRO Function for Tuning
# ==========================================
def train_uot_dro_tuned(X, Y, dim, lam, beta, lam2, epochs=1000, lr=0.01):
    """Fit a linear model with full-batch SGD on a UOT-DRO style objective.

    Each epoch minimises

        lam * beta * log(sum_i exp((|pred_i - y_i| - lam2 * d_i) / (lam * beta)))

    where ``d_i`` is the Euclidean distance of the joint sample ``(x_i, y_i)``
    from the coordinate-wise median of the data. The distances are computed
    once, without gradients.

    Args:
        X: Feature tensor of shape ``(n, dim)``.
        Y: Target tensor of shape ``(n,)`` or ``(n, 1)``.
        dim: Number of input features of the linear model.
        lam: Multiplies ``beta`` to form the log-sum-exp temperature.
        beta: Multiplies ``lam`` to form the log-sum-exp temperature.
        lam2: Weight of the prior distance subtracted from the absolute error.
        epochs: Number of full-batch gradient steps.
        lr: SGD learning rate.

    Returns:
        ``(weight, bias)`` where ``weight`` is a NumPy array of shape
        ``(1, dim)`` and ``bias`` is a Python float.

    Raises:
        ValueError: If ``lam * beta`` is zero (the objective divides by it).
    """
    scale = lam * beta
    if scale == 0:
        raise ValueError("lam * beta must be non-zero (it is the log-sum-exp temperature)")

    if Y.dim() == 1:
        Y = Y.unsqueeze(1)

    # Create the model on the same device (and floating dtype) as the data so
    # float64 or non-CPU inputs do not fail with a dtype/device mismatch.
    model = nn.Linear(dim, 1).to(X.device)
    if X.is_floating_point():
        model = model.to(X.dtype)

    # Initialize weights to zero for fair comparison
    with torch.no_grad():
        model.weight.fill_(0.0)
        model.bias.fill_(0.0)

    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Pre-compute marginal prior distances: distance of every joint sample
    # (x_i, y_i) from the coordinate-wise median, shape (n, 1).
    with torch.no_grad():
        center_x = torch.median(X, dim=0)[0]
        center_y = torch.median(Y, dim=0)[0]
        diff_x = X - center_x
        diff_y = Y - center_y
        diff_concat = torch.cat([diff_x, diff_y], dim=1)
        norm_prior = torch.norm(diff_concat, p=2, dim=1, keepdim=True)

    for _ in range(epochs):
        optimizer.zero_grad()
        pred = model(X)
        abs_error = torch.abs(pred - Y)

        # Core UOT formulation
        exponent_numerator = abs_error - lam2 * norm_prior
        exponent_term = exponent_numerator / scale
        # Subtract the (detached) maximum before exponentiating so exp()
        # cannot overflow; adding it back keeps the loss value unchanged.
        max_val = exponent_term.max().detach()
        loss = scale * (torch.log(torch.sum(torch.exp(exponent_term - max_val))) + max_val)

        loss.backward()
        # Gradient clipping to prevent NaN
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

    return model.weight.detach().cpu().numpy(), model.bias.detach().cpu().item()


def calc_excess_risk(w_hat, b_hat, w_true, b_true, X_clean, Y_clean):
    """Return the mean-absolute-error gap between a fitted and the true model.

    Both linear models are evaluated on ``(X_clean, Y_clean)``; the result is
    ``MAE(w_hat, b_hat) - MAE(w_true, b_true)``. ``w_hat`` may have any shape
    and is flattened to 1-D; ``Y_clean`` is flattened as well.
    """
    targets = Y_clean.reshape(-1)

    def _mean_abs_error(weights, bias):
        # Mean absolute residual of the linear predictor on the clean data.
        return np.mean(np.abs(X_clean @ weights + bias - targets))

    fitted_weights = np.array(w_hat).ravel()
    fitted_risk = _mean_abs_error(fitted_weights, b_hat)
    reference_risk = _mean_abs_error(w_true, b_true)
    return fitted_risk - reference_risk


# ==========================================
# 2. Grid Search
# ==========================================
def main():
    """Grid-search the UOT-DRO hyperparameters on one fixed synthetic dataset.

    Every combination in the grid is trained on the contaminated data and
    scored by its excess risk on the clean subset. Combinations that raise
    or yield a non-finite risk are reported and skipped; the combination
    with the lowest excess risk is printed at the end.
    """
    # 1. Set fixed environment
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    dim = 5
    n_samples = 800  # Use a moderate sample size for tuning
    max_reported_risk = 100  # Rows at or above this are treated as unstable and not listed

    print(f"🔄 Generating Dataset (N={n_samples}, Dim={dim})...")
    dataset = KillerDatasetHighDim(n_samples=n_samples, dim=dim)
    X_t, Y_t = dataset.get_data()
    X_clean, Y_clean = dataset.get_clean_data()

    # 2. Define parameter search space
    # Note: UOT is sensitive to numerical scales
    param_grid = {
        'lam': [1.0, 5.0, 10.0, 20.0],  # KL divergence penalty strength
        'beta': [0.1, 1.0, 5.0, 10.0],  # Temperature coefficient
        'lam2': [1.0, 5.0, 10.0, 50.0],  # Transport cost weight
    }

    keys, values = zip(*param_grid.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

    print(f"🔍 Starting Grid Search with {len(combinations)} combinations...")
    print("-" * 60)
    print(f"{'lam':<8} | {'beta':<8} | {'lam2':<8} | {'Excess Risk':<15}")
    print("-" * 60)

    best_risk = float('inf')
    best_params = None

    for params in tqdm(combinations, leave=False):
        try:
            w_hat, b_hat = train_uot_dro_tuned(
                X_t, Y_t, dim,
                lam=params['lam'],
                beta=params['beta'],
                lam2=params['lam2'],
                epochs=1500  # Consistent with main experiments
            )
            risk = calc_excess_risk(w_hat, b_hat, dataset.w_true, dataset.b_true, X_clean, Y_clean)
        except Exception as exc:
            # One bad combination must not abort the search, but it should
            # not disappear silently either. tqdm.write keeps the bar intact.
            tqdm.write(f"skipped {params}: {type(exc).__name__}: {exc}")
            continue

        if not np.isfinite(risk):
            tqdm.write(f"skipped {params}: non-finite excess risk")
            continue

        # Simple filter to print only stable results
        if risk < max_reported_risk:
            tqdm.write(f"{params['lam']:<8} | {params['beta']:<8} | {params['lam2']:<8} | {risk:.4f}")

        if risk < best_risk:
            best_risk = risk
            best_params = params

    print("-" * 60)
    if best_params is None:
        # Previously this case crashed with a KeyError on the empty dict.
        raise SystemExit("No parameter combination produced a finite excess risk.")

    print("✅ Best Parameters Found:")
    print(f"   lam  = {best_params['lam']}")
    print(f"   beta = {best_params['beta']}")
    print(f"   lam2 = {best_params['lam2']}")
    print(f"   excess risk = {best_risk:.4f}")


if __name__ == '__main__':
    main()