import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import linear_model
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os
import gc
import click
import json

from dataset import load_dataset

# ---------------------------------------------------------------------------
# Run-time configuration shared by the experiment commands below.
# ---------------------------------------------------------------------------
# NOTE(review): not referenced in this file; presumably toggles a sequential
# runner elsewhere ("Set to False to use sequential version").
USE_PARALLEL = True
# Upper bound on worker processes; the commands use min(cpu_count, this).
MAX_PROCESSES = 1000
# Root directory for result CSVs.
OUTPUT_DIR = "outputs"

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


class PinballLoss(nn.Module):
    """Mean pinball (quantile) loss.

    With residual ``r = target - pred`` and quantile level ``q``, each element
    contributes ``q * r`` when ``r >= 0`` and ``(q - 1) * r`` otherwise; the
    contributions are averaged over all elements.
    """

    def __init__(self, quantile):
        super().__init__()
        # Quantile level, e.g. 0.05 for the 5th percentile.
        self.quantile = quantile

    def forward(self, pred, target):
        """Return the mean pinball loss of ``pred`` against ``target``."""
        residual = target - pred
        pos_part = self.quantile * residual  # selected where residual >= 0
        neg_part = (self.quantile - 1) * residual  # selected where residual < 0
        return torch.where(residual >= 0, pos_part, neg_part).mean()


class QuantileLinearModel(nn.Module):
    """Linear model for quantile regression.

    Maps inputs of shape ``(batch, input_dim)`` to predictions of shape
    ``(batch,)``.

    Args:
        input_dim: number of input features.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        # Squeeze only the trailing size-1 output dim: a bare ``squeeze()``
        # would also drop the batch dim when the batch holds a single row.
        return self.linear(x).squeeze(-1)


class QuantileTwoLayerModel(nn.Module):
    """One-hidden-layer MLP (Linear -> ReLU -> Linear) for quantile regression.

    Maps inputs of shape ``(batch, input_dim)`` to predictions of shape
    ``(batch,)``.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer (default 10).
    """

    def __init__(self, input_dim, hidden_dim=10):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        # Squeeze only the trailing output dim so a batch of one keeps shape (1,).
        return self.linear2(x).squeeze(-1)


class QuantileThreeLayerModel(nn.Module):
    """Two-hidden-layer MLP for quantile regression.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear. Maps inputs of
    shape ``(batch, input_dim)`` to predictions of shape ``(batch,)``.

    Args:
        input_dim: number of input features.
        hidden_dim: width of both hidden layers (default 10).
    """

    def __init__(self, input_dim, hidden_dim=10):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.linear(x))
        x = self.relu(self.linear2(x))
        # Squeeze only the trailing output dim so a batch of one keeps shape (1,).
        return self.linear3(x).squeeze(-1)


def train_quantile_model(
    X_train,
    y_train,
    quantile,
    epochs=1000,
    lr=0.01,
    batch_size=64,
    patience=50,
    device=None,
    kwargs=None,
):
    """
    Train a quantile-regression network with the pinball loss.

    Mini-batches are taken in the row order of ``X_train`` (no reshuffling
    between epochs). Early stopping tracks the per-batch training loss: once
    ``patience`` consecutive batches fail to improve on the best batch loss
    seen so far, training stops entirely (across all remaining epochs).

    Args:
        X_train: 2-D array-like of training features, one row per sample.
        y_train: array-like of training targets aligned with ``X_train``.
        quantile: quantile level (e.g., 0.05 for 5th percentile).
        epochs: maximum number of passes over the training data.
        lr: learning rate.
        batch_size: number of rows per mini-batch.
        patience: consecutive non-improving batches tolerated before stopping.
        device: torch device; defaults to CUDA when available, else CPU.
        kwargs: optional dict of options:
            "model": "linear" (default), "two_layer" or "three_layer".
            "optim": "SGD", "MomentumSGD", "Adam" (default) or "AdamW";
            "VR" is reserved and raises NotImplementedError.

    Returns:
        (model, last_loss): the trained model and the pinball loss (float) of
        the last mini-batch processed.

    Raises:
        ValueError: unknown model/optimizer name, or nothing to train on
            (empty training set or ``epochs < 1``).
        NotImplementedError: if ``kwargs["optim"] == "VR"``.
    """
    kwargs = {} if kwargs is None else kwargs

    # Without at least one batch there is no loss to report.
    if epochs < 1 or len(X_train) == 0:
        raise ValueError(
            "train_quantile_model needs at least one epoch and one sample "
            f"(got epochs={epochs}, n_samples={len(X_train)})"
        )

    # Initialize device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Convert to PyTorch tensors and move to device
    X_tensor = torch.FloatTensor(X_train).to(device)
    y_tensor = torch.FloatTensor(y_train).to(device)

    # Build the requested architecture (linear by default).
    model_classes = {
        "linear": QuantileLinearModel,
        "two_layer": QuantileTwoLayerModel,
        "three_layer": QuantileThreeLayerModel,
    }
    model_name = kwargs.get("model", "linear")
    if model_name not in model_classes:
        raise ValueError(f"Invalid model: {model_name}")
    input_dim = X_train.shape[1]
    model = model_classes[model_name](input_dim).to(device)

    # Loss function and optimizer (Adam by default).
    criterion = PinballLoss(quantile)
    optim_name = kwargs.get("optim", "Adam")
    if optim_name == "VR":
        # Variance-reduced optimizer: placeholder, not implemented yet.
        raise NotImplementedError("VR optimizer not implemented")
    if optim_name == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=lr)
    elif optim_name == "MomentumSGD":
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    elif optim_name == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optim_name == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Invalid optimizer: {optim_name}")

    # Training loop
    best_loss = float("inf")
    patience_counter = 0
    last_loss = None
    stop_training = False

    for _epoch in range(epochs):
        for start in range(0, len(X_tensor), batch_size):
            X_batch = X_tensor[start : start + batch_size]
            y_batch = y_tensor[start : start + batch_size]

            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

            # One host sync per batch; reused for early stopping and the return value.
            last_loss = loss.item()
            if last_loss < best_loss:
                best_loss = last_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                stop_training = True
                break
        # Propagate the early stop out of the epoch loop as well.
        if stop_training:
            break

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, last_loss


def _transform_dataset(
    X_train, y_train, X_test, y_test, train_ratio, calibration_ratio, seed=None
):
    """
    Split the training pool into proper-training / calibration indices and
    rescale features and targets using proper-training statistics only.

    The pool rows are permuted with the global NumPy RNG. It is seeded with
    ``np.random.seed(seed)`` when ``seed`` is given. The first
    ``floor(n * train_ratio)`` permuted indices form the proper-training set
    and the last ``floor(n * calibration_ratio)`` form the calibration set.

    Features are standardized with a ``StandardScaler`` fit on the
    proper-training rows. Targets are divided by the mean absolute
    proper-training target. The returned feature and target arrays are float32.

    Args:
        X_train, y_train: training-pool features / targets.
        X_test, y_test: held-out test features / targets.
        train_ratio: fraction of the pool used for proper training.
        calibration_ratio: fraction of the pool used for calibration.
        seed: optional seed for the global NumPy RNG before permuting.

    Returns:
        (X_train, X_test, y_train, y_test, idx_train, idx_cal).
        ``X_train`` and ``y_train`` still cover the whole pool.
        ``idx_train`` (int32) selects the proper-training rows and ``idx_cal``
        selects the calibration rows.

    Raises:
        ValueError: if the two requested splits together exceed the pool size,
            which would make them overlap.
    """
    # reshape the data
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    y_test = np.asarray(y_test)

    n_train_original = X_train.shape[0]

    # divide the data into proper training set and calibration set
    n_train = int(np.floor(n_train_original * train_ratio))
    n_calibration = int(np.floor(n_train_original * calibration_ratio))
    if n_train + n_calibration > n_train_original:
        raise ValueError(
            f"train ({n_train}) + calibration ({n_calibration}) rows exceed "
            f"the training pool ({n_train_original}); splits would overlap"
        )

    # set seed for splitting the data into proper train and calibration
    if seed is not None:
        np.random.seed(seed)
    idx = np.random.permutation(n_train_original)

    # Non-overlapping splits: head for training, tail for calibration. Use an
    # explicit start index because ``idx[-0:]`` would return the whole array
    # when n_calibration == 0.
    idx_train = idx[:n_train]
    idx_cal = idx[n_train_original - n_calibration :]

    # zero mean and unit variance scaling of the train and test features
    scalerX = StandardScaler()
    scalerX = scalerX.fit(X_train[idx_train])
    X_train = scalerX.transform(X_train)
    X_test = scalerX.transform(X_test)

    # scale the labels by dividing each by the mean absolute response
    mean_ytrain = np.mean(np.abs(y_train[idx_train]))
    y_train = np.squeeze(y_train) / mean_ytrain
    y_test = np.squeeze(y_test) / mean_ytrain

    X_train = X_train.astype(np.float32)
    X_test = X_test.astype(np.float32)
    y_train = y_train.astype(np.float32)
    y_test = y_test.astype(np.float32)
    idx_train = idx_train.astype(np.int32)
    return X_train, X_test, y_train, y_test, idx_train, idx_cal


def successive_halving_lr0_tuning(
    X_train,
    y_train,
    gamma,
    epochs=1,
    lr0_candidates=None,
    n_iterations=3,
    eta=2,
    min_budget=1,
    debug=False,
    device=None,
    kwargs=None,
):
    """
    Successive Halving algorithm for tuning the lr0 hyperparameter.

    Each round trains one model per surviving candidate (via
    ``train_quantile_model``) and scores it by the returned final-batch
    pinball loss. After every round except the last, only the best
    ``max(min_budget, len(candidates) // eta)`` candidates survive. Finally,
    a fresh model is trained with the winning learning rate and the same
    ``kwargs`` (model / optimizer choice), and that model is returned.

    Parameters:
    -----------
    X_train, y_train : training data
    gamma : quantile level
    epochs : number of training epochs per candidate evaluation
    lr0_candidates : array-like of learning rates to try; if None, 20
        log-spaced values from 1e-5 to 1e0
    n_iterations : number of successive halving rounds (must be >= 1)
    eta : reduction factor; each round keeps ``len(candidates) // eta``
    min_budget : minimum number of candidates to keep after a round
    debug : print progress when True
    device : torch device forwarded to ``train_quantile_model``
    kwargs : dict of options forwarded to ``train_quantile_model``
        (e.g. {"model": "two_layer", "optim": "SGD"})

    Returns:
    --------
    model : model retrained with ``best_lr0`` and ``kwargs``
    best_lr0 : best learning rate found in the final round
    best_loss : corresponding loss value from the final round
    tuning_history : list of dicts (one per round) with keys "iteration",
        "candidates", "lr" (last candidate evaluated) and "losses"

    Raises:
    -------
    ValueError : if ``n_iterations < 1`` or ``lr0_candidates`` is empty
    """
    kwargs = {} if kwargs is None else kwargs

    def pprint(text):
        if debug:
            print(text)

    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # Generate learning rate candidates if not provided
    if lr0_candidates is None:
        # Log-uniform grid from 1e-5 to 1e0
        lr0_candidates = np.logspace(-5, 0, 20)

    # np.array copies (caller's data is never touched) and turns plain lists
    # into arrays so the fancy indexing below works.
    current_candidates = np.array(lr0_candidates, dtype=float)
    if current_candidates.size == 0:
        raise ValueError("lr0_candidates must contain at least one value")

    tuning_history = []

    pprint(f"Starting Successive Halving with {len(current_candidates)} candidates")
    pprint(
        f"Learning rate range: [{min(current_candidates):.2e}, {max(current_candidates):.2e}]"
    )

    for iteration in range(n_iterations):
        pprint(f"\n--- Iteration {iteration + 1} ---")
        pprint(f"Evaluating {len(current_candidates)} candidates")

        # Evaluate all current candidates
        candidate_losses = []
        for lr0 in current_candidates:
            _, final_loss = train_quantile_model(
                X_train,
                y_train,
                quantile=gamma,
                epochs=epochs,
                lr=lr0,
                device=device,
                kwargs=kwargs,
            )
            candidate_losses.append(final_loss)
            pprint(f"  lr0={lr0:.2e}: final_loss={final_loss:.6f}")

        # Store history ("lr" is the last candidate evaluated, kept for
        # compatibility with existing consumers of the history).
        tuning_history.append(
            {
                "iteration": iteration + 1,
                "candidates": current_candidates.copy(),
                "lr": current_candidates[-1],
                "losses": candidate_losses.copy(),
            }
        )

        if iteration < n_iterations - 1:  # Don't eliminate on last iteration
            # Lower pinball loss is better; keep the top candidates.
            sorted_indices = np.argsort(candidate_losses)
            n_keep = max(min_budget, len(current_candidates) // eta)
            current_candidates = current_candidates[sorted_indices[:n_keep]]
            pprint(f"Kept top {len(current_candidates)} candidates for next round")
        else:
            # Final iteration - select best
            best_idx = np.argmin(candidate_losses)
            best_lr0 = current_candidates[best_idx]
            best_loss = candidate_losses[best_idx]
            pprint(f"\nFinal best: lr0={best_lr0:.2e}, loss={best_loss:.6f}")

    # Retrain with the winning lr and the SAME model/optimizer options that
    # were tuned; omitting kwargs would silently fall back to linear + Adam.
    model, _ = train_quantile_model(
        X_train,
        y_train,
        quantile=gamma,
        epochs=epochs,
        lr=best_lr0,
        device=device,
        kwargs=kwargs,
    )

    return model, best_lr0, best_loss, tuning_history


def run_experiment(
    dataset_name,
    alpha=0.1,
    calibration_ratio=0.2,
    test_ratio=0.2,
    train_ratio=0.5,
    n_iterations=3,
    lr0_candidates=np.logspace(-5, 0, 20),
    kwargs={},
    seed=42,
    device=None,
):
    """
    Run one split-conformal regression experiment and return its record.

    Steps:
      1. Load ``dataset_name`` and split it into train/test with a hard-coded
         seed (42). The test set is therefore the same for every ``seed``.
      2. Split the training pool into proper-training and calibration rows and
         rescale the data (``_transform_dataset``, permutation seeded by
         ``seed``).
      3. For each quantile in ``quantiles`` (only 0.5 here), tune the learning
         rate with successive halving (1 epoch per training run). Then predict
         on the calibration rows and the test set.
      4. Score each calibration row by its absolute residual from the median
         prediction. Take the (1 - alpha) empirical quantile ``q`` of these
         scores. Evaluate the interval ``[pred - q, pred + q]`` on the test set.

    Args:
        dataset_name: name passed to ``load_dataset``.
        alpha: miscoverage level.
        calibration_ratio: fraction of the training pool used for calibration.
        test_ratio: fraction of the data held out for testing (not stored in
            the returned record).
        train_ratio: fraction of the training pool used for proper training.
        n_iterations: number of successive-halving rounds.
        lr0_candidates: learning-rate candidates for the search.
        kwargs: model/optimizer options; anything ``dict(...)`` accepts (e.g.
            a list of (key, value) pairs). Its entries are copied into the
            returned record.
        seed: seed for the train/calibration split and for model training.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        dict with the configuration and the results.
        Configuration keys: seed, dataset_name, alpha, calibration_ratio,
        train_ratio, the kwargs entries, n_iterations, "len(lr0_candidates)".
        Result keys: "coverage", "q", "quantile_interval" (= 2 * q) and
        "best_losses" (a JSON string keyed by quantile).
    """
    # Copy: never mutate the shared default ``{}``; also turns a list of
    # (key, value) pairs into a dict.
    kwargs = dict(kwargs)
    # Initialize device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the median is fitted; the interval is built symmetrically around it.
    quantiles = [0.5]
    data = {
        "seed": seed,
        "dataset_name": dataset_name,
        "alpha": alpha,
        "calibration_ratio": calibration_ratio,
        "train_ratio": train_ratio,
        **kwargs,
    }

    ## Fix seed
    # Hard-coded data seed: the train/test split below does not vary with
    # ``seed``.
    data_seed = 42
    random.seed(data_seed)
    np.random.seed(data_seed)
    torch.manual_seed(data_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(data_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    X, y = load_dataset(dataset_name)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_ratio, random_state=data_seed
    )
    # X_train / y_train keep every pool row; idx_train / idx_cal select the
    # proper-training and calibration subsets.
    X_train, X_test, y_train, y_test, idx_train, idx_cal = _transform_dataset(
        X_train,
        y_train,
        X_test,
        y_test,
        train_ratio=train_ratio,
        calibration_ratio=calibration_ratio,
        seed=seed,
    )

    ## Fix seed
    # Re-seed with the experiment seed so that model initialization and the
    # learning-rate search vary with ``seed``.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    models = {}
    predictions_calib = {}
    predictions_test = {}
    best_losses_dict = {}
    for quantile in quantiles:
        # Fit only on proper-training rows; calibration rows stay unseen.
        model, _, best_loss, _ = successive_halving_lr0_tuning(
            X_train[idx_train],
            y_train[idx_train],
            gamma=quantile,
            epochs=1,
            device=device,
            # debug=True,
            n_iterations=n_iterations,
            lr0_candidates=lr0_candidates,
            kwargs=kwargs,
        )

        models[quantile] = model
        best_losses_dict[quantile] = best_loss

        # Make predictions
        with torch.no_grad():
            X_test_tensor = torch.FloatTensor(X_test).to(device)
            pred = model(X_test_tensor).cpu().numpy()
            predictions_test[quantile] = pred

            X_cal_tensor = torch.FloatTensor(X_train[idx_cal]).to(device)
            pred_cal = model(X_cal_tensor).cpu().numpy()
            predictions_calib[quantile] = pred_cal

    # Calculate prediction intervals
    t_mid = predictions_calib[quantiles[0]]

    def compute_cmr_nonconformity_score(t_mid, y):
        # Absolute-residual score. The list wrapping gives shape (1, n);
        # np.quantile below flattens it because no axis is passed.
        return np.abs([t_mid - y])

    cmr_nc = compute_cmr_nonconformity_score(t_mid, y_train[idx_cal])

    # NOTE(review): textbook split conformal uses the
    # ceil((n + 1) * (1 - alpha)) / n empirical quantile of the n calibration
    # scores. Level 1 - alpha with method="higher" can select an order
    # statistic one lower. Confirm this is intended.
    q = np.quantile(cmr_nc, 1 - alpha, method="higher")

    # Coverage: fraction of test points inside [t_mid_test - q, t_mid_test + q]
    t_mid_test = predictions_test[quantiles[0]]
    coverage = np.mean((y_test >= t_mid_test - q) & (y_test <= t_mid_test + q))
    # Full width of the symmetric interval.
    interval_width = q * 2

    data["coverage"] = coverage.item()
    data["q"] = q.item()
    data["quantile_interval"] = interval_width.item()
    data["n_iterations"] = n_iterations
    data["len(lr0_candidates)"] = len(lr0_candidates)
    data["best_losses"] = json.dumps(best_losses_dict)

    # Clear GPU memory and run garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return data


def run_single_experiment(args):
    """
    Wrapper function for parallel execution (picklable ``Pool`` target).

    Args:
        args: 8-tuple ``(dataset_name, alpha, calibration_ratio, train_ratio,
            n_iterations, lr0_candidates, kwargs, seed)``. ``test_ratio`` is
            fixed at 0.2.

    Returns:
        The result dict produced by ``run_experiment``.

    Raises:
        Any exception from unpacking ``args`` or from ``run_experiment``. The
        exception is printed together with ``args`` and then re-raised, so
        the pool iteration in the parent process fails with it.
    """
    try:
        # Unpack inside the try so malformed tuples are logged with their args.
        (
            dataset_name,
            alpha,
            calibration_ratio,
            train_ratio,
            n_iterations,
            lr0_candidates,
            kwargs,
            seed,
        ) = args

        # Initialize CUDA device in subprocess
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return run_experiment(
            dataset_name=dataset_name,
            alpha=alpha,
            calibration_ratio=calibration_ratio,
            test_ratio=0.2,
            train_ratio=train_ratio,
            n_iterations=n_iterations,
            lr0_candidates=lr0_candidates,
            kwargs=kwargs,
            seed=seed,
            device=device,
        )
    except Exception as e:
        print(f"Error in experiment {args}: {e}")
        # Bare raise preserves the original traceback.
        raise


@click.group()
def cli():
    # Top-level click command group; the experiment drivers below register
    # themselves on it with ``@cli.command()``. It has no behavior of its own.
    # (Kept comment-only: a docstring would become click's --help text.)
    pass


################
@cli.command()
def grid_budget_exploration():
    """Sweep seeds, calibration/train ratios and alphas with an 80-candidate lr grid."""
    mp.set_start_method("spawn", force=True)

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Define experiment parameters
    seeds = list(range(20, 50))
    calibration_ratios = [0.05, 0.1, 0.15, 0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]

    filename = "grid_budget_exploration"
    output_file = f"{OUTPUT_DIR}/{filename}.csv"

    experiment_args = []
    for dataset_name in ["meps_20", "meps_19", "cal_housing"]:
        # Create all experiment combinations
        for seed in seeds:
            for calibration_ratio in calibration_ratios:
                for train_ratio in train_ratios:
                    for alpha in alphas:
                        for n_iterations, lr0_candidates in [
                            (20, np.logspace(-6, 0, 80)),
                        ]:
                            # run_single_experiment unpacks an 8-tuple, so
                            # the kwargs slot must be present (empty list ->
                            # default linear model with Adam).
                            experiment_args.append(
                                (
                                    dataset_name,
                                    alpha,
                                    calibration_ratio,
                                    train_ratio,
                                    n_iterations,
                                    lr0_candidates,
                                    [],
                                    seed,
                                )
                            )

    print(f"Total experiments to run: {len(experiment_args)}")

    # Determine number of processes (use CPU count, but limit for GPU memory)
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    # Run experiments in parallel (ordered results)
    with mp.Pool(processes=num_processes) as pool:
        results = list(
            tqdm(
                pool.imap(run_single_experiment, experiment_args),
                total=len(experiment_args),
                desc="Running experiments",
            )
        )

    # Defensive: drop any None entries before building the table
    results = [r for r in results if r is not None]

    # Convert to DataFrame and save
    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")
    return df


@cli.command()
def different_optimizers_datasets():
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Sweep: 20 seeds x 5 datasets x 1 calibration ratio x 8 train ratios
    # x 4 alphas x 1 optimizer, each with 10 halving rounds over 20 lrs.
    # Earlier configurations of this sweep also used the optimizers
    # "SGD", "MomentumSGD", "Adam" and the dataset subsets
    # [meps_19, meps_20, cal_housing] / [cal_housing, abalone, cpu_small]
    # (written to result.csv / result_extended_dataset.csv).
    seed_list = list(range(20))
    dataset_names = ["meps_19", "meps_20", "cal_housing", "abalone", "cpu_small"]
    cal_ratio_list = [0.2]
    train_ratio_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alpha_list = [0.01, 0.05, 0.1, 0.2]
    optimizer_names = ["AdamW"]

    taskname = "different_optimizers"
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_extended_optimizers.csv"

    # One tuple per run, in run_single_experiment's argument order; each
    # tuple gets its own lr grid and kwargs list.
    experiment_args = [
        (
            dataset_name,
            alpha,
            calibration_ratio,
            train_ratio,
            10,
            np.logspace(-6, 0, 20),
            [("optim", optimizer_name)],
            seed,
        )
        for seed in seed_list
        for dataset_name in dataset_names
        for calibration_ratio in cal_ratio_list
        for train_ratio in train_ratio_list
        for alpha in alpha_list
        for optimizer_name in optimizer_names
    ]
    n_total = len(experiment_args)
    print(f"Total experiments to run: {n_total}")

    # Cap the worker count by both the CPU count and MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        progress = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=n_total,
            desc="Running experiments",
        )
        raw_results = list(progress)

    # Drop None entries (failed experiments).
    results = [res for res in raw_results if res is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


@cli.command()
def different_models_datasets():
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Sweep: 20 seeds x 3 datasets x 1 calibration ratio x 8 train ratios
    # x 4 alphas x 3 architectures, all trained with plain SGD.
    # Quick smoke-test variant: seeds range(1), train ratios [0.2], alphas [0.05].
    # An earlier dataset list was [meps_19, meps_20, cal_housing].
    seed_list = list(range(20))
    dataset_names = ["cal_housing", "abalone", "cpu_small"]
    cal_ratio_list = [0.2]
    train_ratio_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alpha_list = [0.01, 0.05, 0.1, 0.2]
    model_names = ["linear", "two_layer", "three_layer"]

    taskname = "different_models"
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_sgd.csv"

    # One tuple per run, in run_single_experiment's argument order; each
    # tuple gets its own lr grid and kwargs list.
    experiment_args = [
        (
            dataset_name,
            alpha,
            calibration_ratio,
            train_ratio,
            10,
            np.logspace(-6, 0, 20),
            [("model", model_name), ("optim", "SGD")],
            seed,
        )
        for seed in seed_list
        for dataset_name in dataset_names
        for calibration_ratio in cal_ratio_list
        for train_ratio in train_ratio_list
        for alpha in alpha_list
        for model_name in model_names
    ]
    n_total = len(experiment_args)
    print(f"Total experiments to run: {n_total}")

    # Cap the worker count by both the CPU count and MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        progress = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=n_total,
            desc="Running experiments",
        )
        raw_results = list(progress)

    # Drop None entries (failed experiments).
    results = [res for res in raw_results if res is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


@cli.command()
def data_allocation_guidance():
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Sweep: 20 seeds x 13 train ratios x 9 small alphas on cpu_small.
    # For each train ratio, the complementary fraction (1 - train_ratio) of
    # the training pool is used for calibration.
    # An earlier alpha grid was 0.01-0.1 (step 0.01) plus 0.15 and 0.2, and an
    # earlier dataset list was [meps_19, meps_20, cal_housing].
    seed_list = list(range(20))
    dataset_names = ["cpu_small"]
    train_ratio_list = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
    alpha_list = [0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009]

    taskname = "data_allocation_guidance"
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_var_train_1.csv"

    # One tuple per run, in run_single_experiment's argument order; each
    # tuple gets its own lr grid and kwargs list.
    experiment_args = [
        (
            dataset_name,
            alpha,
            1 - train_ratio,
            train_ratio,
            10,
            np.logspace(-6, 0, 20),
            [("model", "linear")],
            seed,
        )
        for seed in seed_list
        for dataset_name in dataset_names
        for train_ratio in train_ratio_list
        for alpha in alpha_list
    ]
    n_total = len(experiment_args)
    print(f"Total experiments to run: {n_total}")

    # Cap the worker count by both the CPU count and MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        progress = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=n_total,
            desc="Running experiments",
        )
        raw_results = list(progress)

    # Drop None entries (failed experiments).
    results = [res for res in raw_results if res is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df

if __name__ == "__main__":
    # Script entry point: hand the command line to the click group, which
    # dispatches to the sub-command registered with ``@cli.command()``.
    cli()
