import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import linear_model
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os
import gc

from dataset import load_dataset

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# NOTE(review): USE_PARALLEL is not read anywhere in this file; the
# `__main__` entry point always calls run_parallel_experiments().
USE_PARALLEL = True
# Upper bound on pool workers; the pool size is min(cpu_count(), MAX_PROCESSES).
# Each worker that moves tensors to CUDA creates its own CUDA context, so lower
# this when GPU memory is tight.
MAX_PROCESSES = 1000
OUTPUT_DIR = "outputs"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


class PinballLoss(nn.Module):
    """Mean pinball (quantile) loss.

    With residual ``r = target - pred``, each element contributes
    ``quantile * r`` when ``r >= 0`` and ``(quantile - 1) * r`` otherwise;
    the result is the mean over all elements.
    """

    def __init__(self, quantile):
        super().__init__()
        # Quantile level, e.g. 0.05 for the 5th percentile.
        self.quantile = quantile

    def forward(self, pred, target):
        """Return the mean pinball loss of ``pred`` against ``target``.

        pred: predicted values
        target: true values (broadcast against ``pred``)
        """
        residual = target - pred
        above = self.quantile * residual          # target at or above prediction
        below = (self.quantile - 1) * residual    # target below prediction
        return torch.where(residual >= 0, above, below).mean()


class QuantileLinearModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to one quantile prediction."""

    def __init__(self, input_dim):
        super(QuantileLinearModel, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Predict one value per input row.

        ``squeeze(-1)`` drops only the size-1 output dimension, so an input of
        shape ``(n, input_dim)`` always yields shape ``(n,)`` -- including
        ``n == 1``, where a bare ``squeeze()`` would also collapse the batch
        axis and return a 0-d tensor.
        """
        return self.linear(x).squeeze(-1)


def train_quantile_model(X_train, y_train, quantile, epochs=1000, lr=0.01, batch_size=64, patience=50, device=None):
    """
    Train a PyTorch linear model using pinball loss for a specific quantile.

    Mini-batches are taken in row order (no shuffling). Early stopping watches
    the per-batch training loss. Once ``patience`` consecutive batches fail to
    beat the best batch loss seen so far, training stops for good, not just
    for the current epoch.

    Args:
        X_train: training features; ``X_train.shape[1]`` is the input dimension
        y_train: training targets, one per row of ``X_train``
        quantile: quantile level (e.g., 0.05 for 5th percentile)
        epochs: maximum number of passes over the data
        lr: Adam learning rate
        batch_size: rows per optimisation step
        patience: consecutive non-improving batches tolerated before stopping
        device: torch device; defaults to CUDA when available, else CPU

    Returns:
        (model, last_loss): the trained ``QuantileLinearModel`` and the pinball
        loss of the last mini-batch it was updated on (``nan`` when no update
        happened, e.g. ``epochs == 0``).

    Raises:
        ValueError: if ``X_train`` contains no rows.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_samples = len(X_train)
    if n_samples == 0:
        # No rows means no batch would run and there would be no loss to report.
        raise ValueError("train_quantile_model needs at least one training sample")

    X_tensor = torch.FloatTensor(X_train).to(device)
    y_tensor = torch.FloatTensor(y_train).to(device)

    model = QuantileLinearModel(X_train.shape[1]).to(device)
    criterion = PinballLoss(quantile)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_loss = float('inf')
    patience_counter = 0
    last_loss = float('nan')
    stopped_early = False

    for _epoch in range(epochs):
        for start in range(0, n_samples, batch_size):
            X_batch = X_tensor[start:start + batch_size]
            y_batch = y_tensor[start:start + batch_size]

            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

            # Read the scalar once per batch (on GPU, .item() forces a sync).
            last_loss = loss.item()
            if last_loss < best_loss:
                best_loss = last_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                stopped_early = True
                break
        if stopped_early:
            # Exit the epoch loop as well; breaking only the batch loop would
            # resume training at the start of the next epoch.
            break

    # Release cached, unused GPU memory blocks.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, last_loss

def _transform_dataset(X_train, y_train, X_test, y_test, train_ratio, calibration_ratio, seed=None):

    # reshape the data
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    y_test = np.asarray(y_test)

    # input dimensions
    n_train_original = X_train.shape[0]
    in_shape = X_train.shape[1]

    # set seed for splitting the data into proper train and calibration
    if seed is not None:
        np.random.seed(seed)
    idx = np.random.permutation(n_train_original)

    # divide the data into proper training set and calibration set
    n_train = int(np.floor(n_train_original * train_ratio))
    n_calibration = int(np.floor(n_train_original * calibration_ratio))
    
    # Ensure non-overlapping splits
    idx_train = idx[:n_train]
    idx_cal = idx[-n_calibration:]

    # zero mean and unit variance scaling of the train and test features
    scalerX = StandardScaler()
    scalerX = scalerX.fit(X_train[idx_train])
    X_train = scalerX.transform(X_train)
    X_test = scalerX.transform(X_test)
    
    # scale the labels by dividing each by the mean absolute response
    mean_ytrain = np.mean(np.abs(y_train[idx_train]))
    y_train = np.squeeze(y_train) / mean_ytrain
    y_test = np.squeeze(y_test) / mean_ytrain

    X_train = X_train.astype(np.float32)
    X_test = X_test.astype(np.float32)
    y_train = y_train.astype(np.float32)
    y_test = y_test.astype(np.float32)
    idx_train = idx_train.astype(np.int32)
    # idx_cal = idx_cal.astype(np.int32)
    return X_train, X_test, y_train, y_test, idx_train, idx_cal


def successive_halving_lr0_tuning(X_train, y_train, gamma, epochs=1,
                                 lr0_candidates=None, n_iterations=3,
                                 eta=2, min_budget=1, debug=False, device=None):
    """
    Successive Halving search over the learning rate ``lr0``.

    Each round trains a fresh model for every surviving candidate (``epochs``
    epochs each) and ranks the candidates by the loss returned from
    ``train_quantile_model``. It then keeps the best
    ``max(min_budget, len(candidates) // eta)`` candidates. The final round
    selects the single best candidate.

    NOTE(review): the per-candidate budget (``epochs``) does not grow between
    rounds as in textbook Successive Halving; later rounds only re-draw the
    model initialisation. Confirm this is intended.

    Parameters:
    -----------
    X_train, y_train : training data
    gamma : quantile level
    epochs : number of training epochs per candidate per round
    lr0_candidates : sequence of learning rates to try (if None, 20 log-spaced
                     values from 1e-5 to 1e0)
    n_iterations : number of successive halving rounds (must be >= 1)
    eta : elimination factor (each round keeps roughly 1/eta of the candidates)
    min_budget : minimum number of candidates to keep after a round
    debug : print progress when True
    device : torch device forwarded to train_quantile_model

    Returns:
    --------
    model : the model trained with best_lr0 in the final round
    best_lr0 : best learning rate found
    best_loss : corresponding loss value
    tuning_history : list of dicts with keys 'iteration', 'candidates', 'lr'
                     (last candidate evaluated in that round) and 'losses'

    Raises:
    -------
    ValueError : if n_iterations < 1 or lr0_candidates is empty
    """
    def pprint(text):
        if debug:
            print(text)

    if n_iterations < 1:
        raise ValueError("n_iterations must be at least 1")

    if lr0_candidates is None:
        # Log-uniform grid of 20 learning rates from 1e-5 to 1e0.
        lr0_candidates = np.logspace(-5, 0, 20)
    # Copy into a float ndarray so lists/tuples also support the fancy
    # indexing used when pruning candidates below.
    current_candidates = np.array(lr0_candidates, dtype=float)
    if current_candidates.size == 0:
        raise ValueError("lr0_candidates must contain at least one value")

    tuning_history = []

    pprint(f"Starting Successive Halving with {len(current_candidates)} candidates")
    pprint(f"Learning rate range: [{min(current_candidates):.2e}, {max(current_candidates):.2e}]")

    for iteration in range(n_iterations):
        pprint(f"\n--- Iteration {iteration + 1} ---")
        pprint(f"Evaluating {len(current_candidates)} candidates")

        candidate_losses = []
        candidate_models = []
        for lr0 in current_candidates:
            model, final_loss = train_quantile_model(
                X_train,
                y_train,
                quantile=gamma,
                epochs=epochs,
                lr=lr0,
                device=device
            )
            candidate_models.append(model)
            candidate_losses.append(final_loss)
            pprint(f"  lr0={lr0:.2e}: final_loss={final_loss:.6f}")

        tuning_history.append({
            'iteration': iteration + 1,
            'candidates': current_candidates.copy(),
            'lr': lr0,
            'losses': candidate_losses.copy()
        })

        # Rank by loss (lower is better). The stable sort keeps ties in
        # candidate order, and NaN losses sort to the end, so diverged runs
        # are never kept or selected ahead of finite ones.
        ranking = np.argsort(candidate_losses, kind="stable")

        if iteration < n_iterations - 1:
            n_keep = max(min_budget, len(current_candidates) // eta)
            current_candidates = current_candidates[ranking[:n_keep]]
            pprint(f"Kept top {len(current_candidates)} candidates for next round")
        else:
            best_idx = ranking[0]
            best_lr0 = current_candidates[best_idx]
            best_loss = candidate_losses[best_idx]
            # Return the model actually trained with best_lr0, not whichever
            # candidate happened to be trained last.
            best_model = candidate_models[best_idx]
            pprint(f"\nFinal best: lr0={best_lr0:.2e}, loss={best_loss:.6f}")

    return best_model, best_lr0, best_loss, tuning_history


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs; on CUDA also force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def run_experiment(
    dataset_name,
    alpha=0.1,
    calibration_ratio=0.2,
    test_ratio=0.2,
    train_ratio=0.5,
    seed=42,
    device=None
):
    """
    Run one conformalized quantile regression (CQR) experiment.

    The data are always split with the fixed data seed 42, so every call sees
    the same train/test and proper-train/calibration partition. ``seed`` only
    re-seeds the RNGs used afterwards for model training. Two linear quantile
    models (levels alpha/2 and 1 - alpha/2) are fit on the proper training
    rows. The CQR nonconformity score is computed on the calibration rows, and
    the conformalised intervals are evaluated on the test rows.

    Args:
        dataset_name: name passed to ``load_dataset``
        alpha: miscoverage level; target coverage is 1 - alpha
        calibration_ratio: fraction of training rows used for calibration
        test_ratio: fraction of all rows held out for testing
        train_ratio: fraction of training rows used to fit the quantile models
        seed: seed for the model-training phase
        device: torch device; defaults to CUDA when available, else CPU

    Returns:
        dict with the input settings ('seed', 'dataset_name', 'alpha',
        'calibration_ratio', 'train_ratio') plus:
          'crossing_mean'  fraction of test points whose lower-quantile
                           prediction exceeds the upper-quantile prediction
          'coverage'       fraction of test targets inside [lo - q, hi + q]
          'interval_width' mean width hi - lo before conformal adjustment
          'predicted_interval_width' mean width after adding 2 * q
          'delta'          per-test-point hi - lo (ndarray)
          'q'              calibration quantile of the nonconformity scores
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Lower and upper quantile levels of the (1 - alpha) prediction interval.
    quantiles = [alpha / 2, 1 - alpha / 2]
    data = {"seed": seed, "dataset_name": dataset_name, "alpha": alpha, "calibration_ratio": calibration_ratio, "train_ratio": train_ratio}

    # Fixed data seed: the split is identical for every `seed`.
    data_seed = 42
    _seed_everything(data_seed)

    X, y = load_dataset(dataset_name)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, random_state=data_seed)
    X_train, X_test, y_train, y_test, idx_train, idx_cal = _transform_dataset(
        X_train, y_train, X_test, y_test,
        train_ratio=train_ratio, calibration_ratio=calibration_ratio, seed=data_seed
    )

    _seed_everything(seed)

    # Inputs shared by both quantile models; build the tensors once.
    X_test_tensor = torch.FloatTensor(X_test).to(device)
    X_cal_tensor = torch.FloatTensor(X_train[idx_cal]).to(device)

    predictions_calib = {}
    predictions_test = {}
    for quantile in quantiles:
        model, _, _, _ = successive_halving_lr0_tuning(
            X_train[idx_train],
            y_train[idx_train],
            gamma=quantile,
            epochs=1,
            device=device
        )

        with torch.no_grad():
            predictions_test[quantile] = model(X_test_tensor).cpu().numpy()
            predictions_calib[quantile] = model(X_cal_tensor).cpu().numpy()

    # CQR nonconformity score on the calibration rows:
    # E_i = max(lo(x_i) - y_i, y_i - hi(x_i)).
    t_lo = predictions_calib[quantiles[0]]
    t_hi = predictions_calib[quantiles[1]]
    y_cal = y_train[idx_cal]
    cqr_nc = np.maximum(t_lo - y_cal, y_cal - t_hi)

    # NOTE(review): split-conformal CQR (Romano et al., 2019) uses the
    # ceil((n + 1) * (1 - alpha)) / n empirical quantile for its finite-sample
    # coverage guarantee; this uses the plain 1 - alpha level. Confirm intent.
    q = np.quantile(cqr_nc, 1 - alpha, method="higher")

    t_lo_test = predictions_test[quantiles[0]]
    t_hi_test = predictions_test[quantiles[1]]

    # Quantile crossing: the lower-quantile prediction lies above the upper one.
    data['crossing_mean'] = np.mean(t_lo_test > t_hi_test).item()

    # Re-order each pair so the interval bounds are well defined even where
    # the two quantile predictions cross.
    t_lo_test, t_hi_test = np.minimum(t_lo_test, t_hi_test), np.maximum(t_lo_test, t_hi_test)

    coverage = np.mean((y_test >= t_lo_test - q) & (y_test <= t_hi_test + q))
    data['coverage'] = coverage.item()
    data['interval_width'] = (t_hi_test - t_lo_test).mean().item()
    data['predicted_interval_width'] = (t_hi_test - t_lo_test + 2 * q).mean().item()
    data['delta'] = t_hi_test - t_lo_test
    data['q'] = q.item()

    # Clear GPU memory and run garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return data



def run_single_experiment(args):
    """
    Pool worker: run one grid point through ``run_experiment``.

    Args:
        args: tuple ``(dataset_name, alpha, calibration_ratio, train_ratio, seed)``;
            the test split is fixed at 0.2.

    Returns:
        The metrics dict returned by ``run_experiment``.

    Raises:
        Whatever ``run_experiment`` raises. The error is printed and then
        re-raised; ``Pool.imap`` re-raises it in the parent when that result
        is consumed, which aborts the run. Failures are never turned into None.
    """
    dataset_name, alpha, calibration_ratio, train_ratio, seed = args

    try:
        # Resolve the device inside the worker process (workers are started
        # with the 'spawn' method by this script's entry point).
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        return run_experiment(
            dataset_name=dataset_name,
            alpha=alpha,
            calibration_ratio=calibration_ratio,
            test_ratio=0.2,
            train_ratio=train_ratio,
            seed=seed,
            device=device
        )
    except Exception as e:
        print(f"Error in experiment {args}: {e}")
        # Bare `raise` re-raises with the original traceback intact.
        raise


def run_parallel_experiments():
    """
    Evaluate the full hyper-parameter grid in a process pool and save the results.

    Output files (all under ``OUTPUT_DIR``):
      * ``meps_cqr/index.txt`` - one
        ``{dataset}_{alpha}_{calibration_ratio}_{train_ratio}_{seed}`` line per
        result, in result order;
      * ``meps_cqr/delta.npy`` - the per-test-point ``delta`` arrays, stacked
        in the same order;
      * ``meps_cqr.csv`` - every other metric, one row per result.

    Returns:
        pandas.DataFrame with one row per result (without the ``delta`` entry).
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Hyper-parameter grid.
    datasets = ["meps_20"]
    seeds = list(range(5, 20))
    calibration_ratios = [0.05, 0.1, 0.15, 0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]

    # Cartesian product: dataset varies slowest, alpha fastest.
    experiment_args = [
        (dataset_name, alpha, calibration_ratio, train_ratio, seed)
        for dataset_name in datasets
        for seed in seeds
        for calibration_ratio in calibration_ratios
        for train_ratio in train_ratios
        for alpha in alphas
    ]
    print(f"Total experiments to run: {len(experiment_args)}")

    # One worker per CPU core, capped at MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        progress = tqdm(
            pool.imap(run_single_experiment, experiment_args),
            total=len(experiment_args),
            desc="Running experiments",
        )
        results = list(progress)

    result_dir = f"{OUTPUT_DIR}/meps_cqr"
    os.makedirs(result_dir, exist_ok=True)

    # Drop experiments that produced no result.
    results = [record for record in results if record is not None]

    deltas = []
    with open(f"{result_dir}/index.txt", 'w') as index_file:
        for record in results:
            # Move the array-valued entry out so the CSV holds only scalars.
            deltas.append(record.pop('delta'))
            index_file.write(
                f"{record['dataset_name']}_{record['alpha']}_{record['calibration_ratio']}"
                f"_{record['train_ratio']}_{record['seed']}\n"
            )

    np.save(f"{result_dir}/delta.npy", np.c_[deltas])

    output_file = f"{OUTPUT_DIR}/meps_cqr.csv"
    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


if __name__ == "__main__":
    # The CUDA runtime does not support the 'fork' start method, so the
    # worker pool uses 'spawn'; force=True overrides any earlier setting.
    mp.set_start_method("spawn", force=True)
    print("Running experiments in parallel...")
    df = run_parallel_experiments()
