import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import linear_model
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os
import gc

from dataset import load_dataset

# Configuration
USE_PARALLEL = True  # Set to False to use sequential version
# Maximum number of parallel processes (adjust based on GPU memory)
MAX_PROCESSES = 1000
OUTPUT_DIR = "outputs"

# Pick the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


class PinballLoss(nn.Module):
    """Mean pinball (quantile-regression) loss at a fixed quantile level"""

    def __init__(self, quantile):
        super().__init__()
        # Quantile level (e.g., 0.05 for 5th percentile)
        self.quantile = quantile

    def forward(self, pred, target):
        """
        Return the mean pinball loss of the predictions
        pred: predicted values
        target: true values
        """
        residual = target - pred
        # Non-negative residuals are weighted by q, negative ones by (q - 1)
        weight_under = self.quantile * residual
        weight_over = (self.quantile - 1) * residual
        per_sample = torch.where(residual < 0, weight_over, weight_under)
        return per_sample.mean()


class QuantileLinearModel(nn.Module):
    """Linear model for quantile regression (one scalar output per sample)"""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """
        Map features of shape (..., input_dim) to predictions of shape (...).

        Only the trailing singleton output dimension is removed, so a batch
        holding a single sample still yields a 1-D tensor of length 1 rather
        than a 0-d scalar.
        """
        return self.linear(x).squeeze(-1)


def train_quantile_model(
    X_train,
    y_train,
    quantile,
    epochs=1000,
    lr=0.01,
    batch_size=64,
    patience=50,
    device=None,
):
    """
    Train a PyTorch linear model using pinball loss for a specific quantile

    Mini-batches are taken in the order given (no shuffling) and optimised
    with Adam.

    Args:
        X_train: training features, array-like of shape (n_samples, n_features)
        y_train: training targets, array-like with n_samples values
        quantile: quantile level (e.g., 0.05 for 5th percentile)
        epochs: maximum number of training epochs
        lr: learning rate
        batch_size: number of samples per mini-batch
        patience: early stopping patience, counted in consecutive mini-batches
            whose loss did not improve on the best mini-batch loss seen so far
        device: torch device to train on (defaults to CUDA when available)

    Returns:
        (model, final_loss): the trained model and the loss of the last
        mini-batch that was processed (``inf`` if no step was taken because
        ``epochs <= 0``)

    Raises:
        ValueError: if X_train is not 2-D, is empty, or its number of rows
            does not match the number of targets
    """
    # Initialize device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Convert to float32 tensors on the target device
    X_tensor = torch.as_tensor(X_train, dtype=torch.float32, device=device)
    # Flatten the targets: a (n, 1) column would otherwise broadcast against
    # the 1-D predictions and silently produce an (n, n) loss matrix
    y_tensor = torch.as_tensor(y_train, dtype=torch.float32, device=device).reshape(-1)

    if X_tensor.ndim != 2:
        raise ValueError("X_train must be 2-D (n_samples, n_features)")
    n_samples, input_dim = X_tensor.shape
    if n_samples == 0:
        raise ValueError("X_train must contain at least one sample")
    if y_tensor.shape[0] != n_samples:
        raise ValueError("X_train and y_train must have the same number of samples")

    # Initialize model and move to device
    model = QuantileLinearModel(input_dim).to(device)

    # Loss function and optimizer
    criterion = PinballLoss(quantile)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    best_loss = float("inf")
    last_loss = float("inf")
    patience_counter = 0
    stop_training = False

    for _ in range(epochs):
        for start in range(0, n_samples, batch_size):
            X_batch = X_tensor[start : start + batch_size]
            y_batch = y_tensor[start : start + batch_size]

            # Forward pass
            optimizer.zero_grad()
            predictions = model(X_batch)
            loss = criterion(predictions, y_batch)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Read the loss once per step (each .item() forces a device sync)
            last_loss = loss.item()

            # Early stopping
            if last_loss < best_loss:
                best_loss = last_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                stop_training = True
                break

        # Propagate the stop to the epoch loop; breaking only the batch loop
        # would let training resume on the next epoch
        if stop_training:
            break

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, last_loss


def _transform_dataset(
    X_train, y_train, X_test, y_test, train_ratio, calibration_ratio, seed=None
):
    """
    Scale the data and split the training rows into proper-train / calibration.

    The training rows are randomly permuted; the first
    ``floor(n * train_ratio)`` indices form the proper training set and the
    last ``floor(n * calibration_ratio)`` indices form the calibration set.
    Features are standardised with statistics of the proper training rows
    only, and targets are divided by the mean absolute target of those rows.

    Args:
        X_train, y_train: full training features / targets
        X_test, y_test: test features / targets
        train_ratio: fraction of training rows used as proper training set
        calibration_ratio: fraction of training rows used for calibration
        seed: if not None, the global NumPy RNG is re-seeded with this value
            before drawing the permutation

    Returns:
        (X_train, X_test, y_train, y_test, idx_train, idx_cal) where the data
        arrays are the full scaled float32 arrays and idx_train / idx_cal
        index into the scaled X_train / y_train

    Raises:
        ValueError: if the requested split sizes are negative or would make
            the proper training and calibration sets overlap
    """
    # reshape the data
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    y_test = np.asarray(y_test)

    n_train_original = X_train.shape[0]

    # set seed for splitting the data into proper train and calibration
    if seed is not None:
        np.random.seed(seed)
    idx = np.random.permutation(n_train_original)

    # divide the data into proper training set and calibration set
    n_train = int(np.floor(n_train_original * train_ratio))
    n_calibration = int(np.floor(n_train_original * calibration_ratio))

    # Ensure non-overlapping splits
    if n_train < 0 or n_calibration < 0:
        raise ValueError("train_ratio and calibration_ratio must be non-negative")
    if n_train + n_calibration > n_train_original:
        raise ValueError(
            "train_ratio + calibration_ratio is too large: the proper training "
            "and calibration sets would overlap"
        )

    idx_train = idx[:n_train]
    # Slice from an explicit start index: idx[-0:] would return the whole
    # permutation when the calibration set is empty
    idx_cal = idx[n_train_original - n_calibration :]

    # zero mean and unit variance scaling of the train and test features
    scalerX = StandardScaler()
    scalerX = scalerX.fit(X_train[idx_train])
    X_train = scalerX.transform(X_train)
    X_test = scalerX.transform(X_test)

    # scale the labels by dividing each by the mean absolute response
    mean_ytrain = np.mean(np.abs(y_train[idx_train]))
    y_train = np.squeeze(y_train) / mean_ytrain
    y_test = np.squeeze(y_test) / mean_ytrain

    X_train = X_train.astype(np.float32)
    X_test = X_test.astype(np.float32)
    y_train = y_train.astype(np.float32)
    y_test = y_test.astype(np.float32)
    idx_train = idx_train.astype(np.int32)
    return X_train, X_test, y_train, y_test, idx_train, idx_cal


def successive_halving_lr0_tuning(
    X_train,
    y_train,
    gamma,
    epochs=1,
    lr0_candidates=None,
    n_iterations=3,
    eta=2,
    min_budget=1,
    debug=False,
    device=None,
):
    """
    Successive Halving algorithm for tuning lr0 hyperparameter.

    In every round each surviving learning rate is used to train a fresh
    model with the same ``epochs`` budget; the candidates with the lowest
    final loss survive to the next round.

    Parameters:
    -----------
    X_train, y_train : training data
    gamma : quantile level
    epochs : number of training epochs used for every candidate evaluation
    lr0_candidates : sequence of learning rates to try (if None, 20
        log-spaced values from 1e-5 to 1e0 are generated)
    n_iterations : number of successive halving iterations (must be >= 1)
    eta : elimination factor; ``len(candidates) // eta`` candidates are kept
        after each round (must be >= 1)
    min_budget : minimum number of candidates to keep
    debug : if True, print progress information
    device : torch device forwarded to ``train_quantile_model``

    Returns:
    --------
    model : the model trained with ``best_lr0`` in the final round
    best_lr0 : best learning rate found
    best_loss : corresponding loss value
    tuning_history : one dict per round with keys "iteration", "candidates",
        "lr" (the round's best learning rate) and "losses"

    Raises:
    -------
    ValueError : if ``n_iterations < 1``, ``eta < 1`` or no candidates are given
    """

    def pprint(text):
        if debug:
            print(text)

    if n_iterations < 1:
        raise ValueError("n_iterations must be at least 1")
    if eta < 1:
        raise ValueError("eta must be at least 1")

    # Generate learning rate candidates if not provided
    if lr0_candidates is None:
        # Log-spaced grid of learning rates from 1e-5 to 1e0
        lr0_candidates = np.logspace(-5, 0, 20)

    # Work on a private float array so plain lists are accepted and the
    # caller's sequence is never modified
    current_candidates = np.array(lr0_candidates, dtype=float).reshape(-1)
    if current_candidates.size == 0:
        raise ValueError("lr0_candidates must not be empty")

    tuning_history = []

    pprint(f"Starting Successive Halving with {len(current_candidates)} candidates")
    pprint(
        f"Learning rate range: [{min(current_candidates):.2e}, {max(current_candidates):.2e}]"
    )

    for iteration in range(n_iterations):
        pprint(f"\n--- Iteration {iteration + 1} ---")
        pprint(f"Evaluating {len(current_candidates)} candidates")

        # Evaluate all current candidates, remembering every trained model so
        # the one matching the winning learning rate can be returned
        candidate_losses = []
        round_models = []

        for lr0 in current_candidates:
            # Train model with current lr0
            model, final_loss = train_quantile_model(
                X_train,
                y_train,
                quantile=gamma,
                epochs=epochs,
                lr=float(lr0),
                device=device,
            )

            # Use the last (final) loss value as evaluation metric
            candidate_losses.append(final_loss)
            round_models.append(model)

            pprint(f"  lr0={lr0:.2e}: final_loss={final_loss:.6f}")

        # Rank on a NaN-free copy: np.argmin would otherwise pick a diverged
        # (NaN) candidate as the "best" one
        ranking_losses = np.asarray(candidate_losses, dtype=float)
        ranking_losses = np.where(np.isnan(ranking_losses), np.inf, ranking_losses)
        best_idx = int(np.argmin(ranking_losses))

        # Store history
        tuning_history.append(
            {
                "iteration": iteration + 1,
                "candidates": current_candidates.copy(),
                "lr": current_candidates[best_idx],
                "losses": candidate_losses.copy(),
            }
        )

        # Select top candidates for next round
        if iteration < n_iterations - 1:  # Don't eliminate on last iteration
            # Sort by loss (lower is better for quantile regression)
            sorted_indices = np.argsort(ranking_losses)

            # Keep top candidates (eliminate worst ones), never fewer than one
            n_keep = max(1, min_budget, int(len(current_candidates) // eta))
            top_indices = sorted_indices[:n_keep]

            current_candidates = current_candidates[top_indices]
            pprint(f"Kept top {len(current_candidates)} candidates for next round")
        else:
            # Final iteration - select best
            best_lr0 = current_candidates[best_idx]
            best_loss = candidate_losses[best_idx]
            best_model = round_models[best_idx]

            pprint(f"\nFinal best: lr0={best_lr0:.2e}, loss={best_loss:.6f}")

    return best_model, best_lr0, best_loss, tuning_history


def _set_global_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs (and pin cuDNN when CUDA is available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def run_experiment(
    dataset_name,
    alpha=0.1,
    calibration_ratio=0.2,
    test_ratio=0.2,
    train_ratio=0.5,
    seed=42,
    device=None,
):
    """
    Run one conformal experiment around a median regressor and report coverage.

    The data split and scaling always use a fixed seed (42); ``seed`` only
    controls the randomness of model training. A linear median model is tuned
    on the proper training rows, absolute residuals on the calibration rows
    give the interval half-width ``q``, and coverage of ``prediction +/- q``
    is measured on the test set.

    Args:
        dataset_name: name passed to ``load_dataset``
        alpha: miscoverage level; the (1 - alpha) quantile of the calibration
            scores is used
        calibration_ratio: fraction of the training rows used for calibration
        test_ratio: fraction of the whole dataset held out for testing
        train_ratio: fraction of the training rows used to fit the model
        seed: seed for model training
        device: torch device (defaults to CUDA when available)

    Returns:
        dict with the experiment settings plus "coverage", "q" and
        "quantile_interval" (the interval width, 2 * q)
    """
    # Initialize device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    quantiles = [0.5]
    data = {
        "seed": seed,
        "dataset_name": dataset_name,
        "alpha": alpha,
        "calibration_ratio": calibration_ratio,
        "train_ratio": train_ratio,
    }

    # Fixed seed for the data pipeline so every run sees the same split
    data_seed = 42
    _set_global_seeds(data_seed)

    X, y = load_dataset(dataset_name)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_ratio, random_state=data_seed
    )
    X_train, X_test, y_train, y_test, idx_train, idx_cal = _transform_dataset(
        X_train,
        y_train,
        X_test,
        y_test,
        train_ratio=train_ratio,
        calibration_ratio=calibration_ratio,
        seed=data_seed,
    )

    # Experiment seed: governs model initialisation / training only
    _set_global_seeds(seed)

    predictions_calib = {}
    predictions_test = {}
    for quantile in quantiles:
        model, _, _, _ = successive_halving_lr0_tuning(
            X_train[idx_train],
            y_train[idx_train],
            gamma=quantile,
            epochs=1,
            device=device,
        )

        # Make predictions
        with torch.no_grad():
            X_test_tensor = torch.FloatTensor(X_test).to(device)
            predictions_test[quantile] = model(X_test_tensor).cpu().numpy()

            X_cal_tensor = torch.FloatTensor(X_train[idx_cal]).to(device)
            predictions_calib[quantile] = model(X_cal_tensor).cpu().numpy()

    # Nonconformity scores: absolute residuals of the median prediction on
    # the calibration rows
    t_mid = predictions_calib[quantiles[0]]
    cmr_nc = np.abs(t_mid - y_train[idx_cal])

    # NOTE(review): split conformal prediction is usually stated with the
    # ceil((n + 1) * (1 - alpha)) / n empirical quantile; this takes the plain
    # (1 - alpha) quantile with method="higher" - confirm this is intended.
    q = np.quantile(cmr_nc, 1 - alpha, method="higher")

    # Coverage (percentage of test points within prediction interval)
    t_mid_test = predictions_test[quantiles[0]]
    coverage = np.mean((y_test >= t_mid_test - q) & (y_test <= t_mid_test + q))
    interval_width = q * 2

    data["coverage"] = coverage.item()
    data["q"] = q.item()
    data["quantile_interval"] = interval_width.item()

    # Clear GPU memory and run garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return data


def run_single_experiment(args):
    """
    Wrapper function for parallel execution

    Args:
        args: tuple ``(dataset_name, alpha, calibration_ratio, train_ratio, seed)``

    Returns:
        the result dict produced by ``run_experiment``

    Raises:
        Any exception raised by ``run_experiment``; it is logged together with
        the experiment arguments and then re-raised unchanged.
    """
    dataset_name, alpha, calibration_ratio, train_ratio, seed = args

    try:
        # Initialize CUDA device in subprocess
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return run_experiment(
            dataset_name=dataset_name,
            alpha=alpha,
            calibration_ratio=calibration_ratio,
            test_ratio=0.2,
            train_ratio=train_ratio,
            seed=seed,
            device=device,
        )
    except Exception as e:
        # Report which configuration failed, then re-raise with the original
        # traceback intact
        print(f"Error in experiment {args}: {e}")
        raise


def run_parallel_experiments():
    """
    Run the full experiment grid and save the results as CSV.

    Every combination of dataset, seed, calibration ratio, train ratio and
    alpha is passed to ``run_single_experiment``. When ``USE_PARALLEL`` is
    true the work is spread over a process pool of at most ``MAX_PROCESSES``
    workers; otherwise the experiments run one after another in this process.

    Returns:
        pandas DataFrame with one row per completed experiment; the same
        table is written to ``OUTPUT_DIR/meps_20_cmr.csv``
    """
    from itertools import product

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Define experiment parameters
    datasets = ["meps_20"]
    seeds = list(range(20))
    calibration_ratios = [0.05, 0.1, 0.15, 0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]

    # Create all experiment combinations (alpha varies fastest, dataset
    # slowest), reordered into the argument layout of run_single_experiment
    experiment_args = [
        (dataset_name, alpha, calibration_ratio, train_ratio, seed)
        for dataset_name, seed, calibration_ratio, train_ratio, alpha in product(
            datasets, seeds, calibration_ratios, train_ratios, alphas
        )
    ]

    print(f"Total experiments to run: {len(experiment_args)}")

    if USE_PARALLEL:
        # Determine number of processes (use CPU count, but limit for GPU memory)
        num_processes = min(mp.cpu_count(), MAX_PROCESSES)
        print(f"Using {num_processes} processes")

        # Run experiments in parallel
        with mp.Pool(processes=num_processes) as pool:
            results = list(
                tqdm(
                    pool.imap(run_single_experiment, experiment_args),
                    total=len(experiment_args),
                    desc="Running experiments",
                )
            )
    else:
        # Sequential version selected through the USE_PARALLEL switch
        print("Running sequentially in the current process")
        results = [
            run_single_experiment(single_args)
            for single_args in tqdm(experiment_args, desc="Running experiments")
        ]

    # Filter out None results (failed experiments)
    results = [r for r in results if r is not None]

    # Convert to DataFrame and save
    df = pd.DataFrame(results)
    output_file = os.path.join(OUTPUT_DIR, "meps_20_cmr.csv")
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


if __name__ == "__main__":
    # Script entry point: configure multiprocessing, then run the whole grid.
    # The start method is set to 'spawn' for CUDA compatibility: CUDA contexts
    # cannot be shared between forked processes, whereas 'spawn' creates fresh
    # processes so each one can initialize CUDA independently. This must
    # happen before the worker pool is created.
    mp.set_start_method("spawn", force=True)

    print("Running experiments in parallel...")
    df = run_parallel_experiments()
