'''
Parallel UCI dataset experiment runner.
Based on the original run_models_165.py, but configured to run every
(feature subset, sample size, seed) case in parallel worker processes.
'''
import os
import sys
sys.path.append('.')
import multiprocessing as mp
from ucimlrepo import dotdict, fetch_ucirepo 
import time
import numpy as np
import pandas as pd

from common import *

from baseline.bayesian_last_layer_networks import create_bll_model, train_bll_model
from baseline.monte_carlo_dropout import MCDropout, train_mc_dropout_model
from baseline.probabilistic_neural_networks import ProbabilisticNeuralNetwork, train_pnn_model
from baseline.stochastic_weight_averaging_gaussian import create_swag_model, train_swag_model
from baseline.deterministic_variational_inference import DVI, train_dvi_model
from baseline.deep_gaussian_process_regression import DeepGPRegression, train_deep_gp_model
from baseline.mixture_density_network import MixtureDensityNetwork, train_mdn_model
from hvbll.vbll import VBLL, train_vbll_model
from hvbll.hvbll import HVBLL, train_hvbll_model


# UCI Machine Learning Repository dataset ID; passed to fetch_ucirepo() and
# prepare_case() below.
ID_UCI = 291
# NOTE(review): GPU_ID is not referenced in the code shown here — per-case GPU
# ids are produced by assign_gpu() in the __main__ block. Confirm it is still needed.
GPU_ID = 0  # Default GPU ID (will be overridden when parallelizing)
N_SEED = 10  # Number of seeds run for each (feature subset, sample size) pair
NUM_PARALLEL_PROCESSES = 150  # Number of processes to run in parallel; 1 runs sequentially


# Common hyperparameters shared by the model configurations in run_single_case()
dim_latent = 32       # latent dimension passed to VBLL / HVBLL
dim_hidden = 64       # hidden-layer width passed to every model
n_hidden_layers = 3   # number of hidden layers passed to every model
learning_rate = 0.01  # passed to every train_fn
lr_step_size = 500    # passed to every train_fn (presumably an LR-scheduler step period — confirm)


# Confidence levels (in percent) for which per-level coverage / width columns are written.
_CONFIDENCE_LEVELS = (10, 20, 30, 40, 50, 60, 70, 80, 90)


def _summary_metric_keys():
    """Return the metric keys written to the summary CSV, in column order.

    The same list drives both the header and each data row, so the two can
    never drift out of alignment.
    """
    keys = []
    for split in ("train", "test"):
        keys.extend(f"{split}_{name}" for name in
                    ("mse", "nll", "mae", "crps", "coverage_95", "width_95", "ace"))
        keys.extend(f"{split}_coverage_{conf}" for conf in _CONFIDENCE_LEVELS)
        keys.extend(f"{split}_width_{conf}" for conf in _CONFIDENCE_LEVELS)
    keys.extend(["training_time", "epoch"])
    return keys


def _format_summary_row(n_train_sample, i_seed, train_results):
    """Format one CSV data line (with trailing newline) for a finished run.

    Missing metrics are written as 0.0; ``epoch`` uses one decimal, every
    other metric six.
    """
    parts = [str(n_train_sample), str(i_seed)]
    for key in _summary_metric_keys():
        spec = ".1f" if key == "epoch" else ".6f"
        parts.append(format(train_results.get(key, 0.0), spec))
    return ", ".join(parts) + "\n"


def _append_summary_row(fname, n_train_sample, i_seed, train_results):
    """Append one result row to ``fname``, writing the header first if the file is empty.

    Several pool workers can append to the same summary file at once. Where
    ``fcntl`` is available, an exclusive advisory lock is held across the
    emptiness check and the write, so at most one header is written. Header
    and row are emitted with a single ``write`` call.
    """
    try:
        import fcntl
    except ImportError:  # e.g. Windows: no advisory file locking available
        fcntl = None

    text = _format_summary_row(n_train_sample, i_seed, train_results)
    with open(fname, 'a') as f:
        if fcntl is not None:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            # Re-check the size under the lock: another worker may have written
            # the header between our open() and acquiring the lock.
            f.seek(0, os.SEEK_END)
            if f.tell() == 0:
                header = ", ".join(["n_sample", "seed"] + _summary_metric_keys()) + "\n"
                text = header + text
            f.write(text)
            f.flush()  # make the row visible before another worker takes the lock
        finally:
            if fcntl is not None:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)


def _build_model_configs(dim_input, dim_output):
    """Return the name -> {create_fn, params, train_fn, color, num_epochs, ...} table of all models."""
    return {
        'HVBLL': {
            'create_fn': HVBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'dim_hidden_noise': 64,
                'n_hidden_layers': n_hidden_layers,
                'n_noise_layers': 1,
                'reg_weight_latent': 1.0,
                'reg_weight_noise': 1.0,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 10.0,
                'dof': 1.0
            },
            'train_fn': train_hvbll_model,
            'color': 'gray',
            'num_epochs': 3000
        },
        'VBLL': {
            'create_fn': VBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'reg_weight_latent': 1.0,
                'reg_weight_noise': 1.0,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 1.0,
                'dof': 1.0
            },
            'train_fn': train_vbll_model,
            'color': 'pink',
            'num_epochs': 3000
        },
        'BLL': {
            'create_fn': create_bll_model,
            'params': {
                'dim_input': dim_input,
                'dim_hidden': dim_hidden,
                'dim_latent': 8,    # Dimension of the feature space
                'n_hidden_layers': n_hidden_layers,
                'kernel_type': 'rbf',
                'use_derivatives': False,  # Set to True for LDGBLL, False for GBLL
                'prior_log_obs_var': 2.0
            },
            'train_fn': train_bll_model,
            'color': 'blue',
            'num_epochs': 3000
        },
        'MC-Dropout': {
            'create_fn': MCDropout,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'dropout_rate': 0.1
            },
            'train_fn': train_mc_dropout_model,
            'color': 'red',
            'num_samples': 30,  # This is used in get_prediction, not in the constructor
            'num_epochs': 3000
        },
        'Deep-GP': {
            'create_fn': DeepGPRegression,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_inducing': 64,
                'num_samples': 50
            },
            'train_fn': train_deep_gp_model,
            'color': 'cyan',
            'num_epochs': 3000
        },
        'PNN': {
            'create_fn': ProbabilisticNeuralNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'beta': 0.5  # Beta parameter for beta-NLL loss
            },
            'train_fn': train_pnn_model,
            'color': 'purple',
            'num_epochs': 3000
        },
        'SWAG': {
            'create_fn': create_swag_model,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'max_models': 20,
                'swa_start': 2000,  # Epoch to start SWA
                'swa_lr': 1e-4,
                'var_clamp': 1e-6,
                'full_cov': True,  # Use full covariance (diagonal + low-rank)
                'prior_log_obs_var': None
            },
            'train_fn': train_swag_model,
            'color': 'orange',
            'num_epochs': 3000
        },
        'DVI': {
            'create_fn': DVI,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'ratio_kl': 1.0,
                'n_components': 5,  # Number of Gaussian components
                'prior_mean': 0.0,
                'prior_var': 0.1,
                'prior_log_obs_var': 0.0
            },
            'train_fn': train_dvi_model,
            'color': 'brown',
            'num_epochs': 1000
        },
        'MDN': {
            'create_fn': MixtureDensityNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_components': 5,
            },
            'train_fn': train_mdn_model,
            'color': 'green',
            'num_epochs': 3000
        },
    }


def run_single_case(args):
    """
    Run a single experiment case: prepare the data, then train and evaluate
    each selected model that has no result row yet, appending one row per
    model to its summary CSV.

    Args:
        args: Tuple containing (i_case_partial_x, i_case_sample, i_seed, gpu_id, old_dataset)

    Returns:
        None, results are written to files

    Errors while creating or training a single model are printed and that
    model is skipped, so one failure does not abort the whole pool.
    """
    i_case_partial_x, i_case_sample, i_seed, gpu_id, old_dataset = args

    # Set the random seed for reproducibility
    set_seed(i_seed)

    # Worker name, used to tag log lines
    process_id = mp.current_process().name

    dict_problem = prepare_case(ID_UCI,
                                i_case_partial_x=i_case_partial_x,
                                i_case_sample=i_case_sample,
                                seed=i_seed,
                                GPU_ID=gpu_id,
                                old_dataset=old_dataset)

    X_train_tensor = dict_problem['X_train_tensor']
    y_train_tensor = dict_problem['y_train_tensor']
    X_test_tensor = dict_problem['X_test_tensor']
    y_test_tensor = dict_problem['y_test_tensor']
    dataloader = dict_problem['dataloader']
    n_train_sample = len(X_train_tensor)

    models = _build_model_configs(dict_problem['dim_input'], dict_problem['dim_output'])

    # Restrict this sweep to a subset of the configured models.
    # NOTE(review): an earlier comment read "only run VBLL and HVBLL", but only
    # 'VBLL' was selected; that selection is preserved — add 'HVBLL' if intended.
    selected_models = ('VBLL',)
    models = {k: v for k, v in models.items() if k in selected_models}

    for model_name, model_config in models.items():

        # Skip cases whose (n_sample, seed) row already exists in the summary file
        if case_already_run(model_name, ID_UCI, i_case_partial_x, n_train_sample, i_seed):
            print(f"Process {process_id}: Skipping for {model_name} - already run for partial_x={i_case_partial_x}, sample={i_case_sample}, seed={i_seed}")
            continue

        # Construction / device placement can fail too (bad params, CUDA OOM);
        # handle it like a training failure instead of killing the pool.
        try:
            model = model_config['create_fn'](**model_config['params'])
            if torch.cuda.is_available():
                model.cuda(gpu_id)
        except Exception as e:
            print(f"Process {process_id}: Error creating {model_name}, partial_x={i_case_partial_x}, sample={i_case_sample}: {e}")
            continue

        try:
            train_results = model_config['train_fn'](
                model, dataloader, X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor,
                model_config['num_epochs'], learning_rate, lr_step_size
            )
        except Exception as e:
            print(f"Process {process_id}: Error training {model_name}, partial_x={i_case_partial_x}, sample={i_case_sample}: {e}")
            continue

        fname = fname_summary(model_name, ID_UCI, i_case_partial_x)
        _append_summary_row(fname, n_train_sample, i_seed, train_results)

        # Use .get like the CSV writer: a missing metric must not crash the
        # worker after its row has already been recorded.
        nan = float('nan')
        print(f"\nProcess {process_id}: Results for {model_name} (partial_x={i_case_partial_x}, sample={i_case_sample}, seed={i_seed}, GPU={gpu_id}):")
        print(f"Training Time: {train_results.get('training_time', nan):.2f} seconds")
        print(f"Train MSE: {train_results.get('train_mse', nan):.6f}")
        print(f"Train NLL: {train_results.get('train_nll', nan):.6f}")
        print(f"Test MSE: {train_results.get('test_mse', nan):.6f}")
        print(f"Test NLL: {train_results.get('test_nll', nan):.6f}")


def case_already_run(model_name, id_dataset, i_case_partial_x, n_train_sample, i_seed):
    """
    Check if a specific case has already been run by looking at the summary file

    Args:
        model_name: Name of the model
        id_dataset: Dataset ID
        i_case_partial_x: Index for partial feature case
        n_train_sample: Training sample size
        i_seed: Seed index

    Returns:
        bool: True if a row with this (n_sample, seed) exists. This is a
        best-effort check: a missing, empty, or unreadable file yields False,
        so the case is (re)run.
    """
    fname = fname_summary(model_name, id_dataset, i_case_partial_x)

    if not os.path.exists(fname):
        return False

    try:
        if os.path.getsize(fname) == 0:
            return False

        # Cheap sanity check before handing the file to pandas
        with open(fname, 'r') as f:
            content = f.read().strip()
        if not content or ',' not in content:  # empty or not CSV-shaped
            return False

        # 'error_bad_lines' was renamed to 'on_bad_lines' in newer pandas versions
        try:
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, on_bad_lines='warn')
        except TypeError:
            # Fall back to older pandas versions
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, error_bad_lines=False)

        if 'n_sample' not in df.columns or 'seed' not in df.columns:
            print(f"Warning: Missing required columns in {fname}")
            return False

        # A stray duplicated header line or a torn row turns these columns into
        # strings, and comparing strings with ints is always False. Coerce to
        # numbers; unparsable cells become NaN and simply never match.
        n_sample_col = pd.to_numeric(df['n_sample'], errors='coerce')
        seed_col = pd.to_numeric(df['seed'], errors='coerce')
        case_exists = bool(((n_sample_col == n_train_sample) & (seed_col == i_seed)).any())

        if case_exists:
            print(f"Found existing run for {model_name}, n_sample={n_train_sample}, seed={i_seed}")

        return case_exists

    except Exception as e:
        print(f"Error checking if case already run: {e} (file: {fname})")
        # If there's any error reading the file, assume the case hasn't been run
        return False


if __name__ == "__main__":

    # Get number of available GPUs
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    print(f"Number of available GPUs: {num_gpus}")

    # Fetch the dataset once in the parent; it is shipped to workers inside each case tuple
    old_dataset = fetch_ucirepo(id=ID_UCI)

    # Enumerate every (feature subset, sample size, seed) combination
    all_cases = []
    for i_case_partial_x in range(len(TEST_CASES[ID_UCI]['index_delete_features'])):
        for i_case_sample in range(len(TEST_CASES[ID_UCI]['n_samples'])):
            for i_seed in range(N_SEED):
                # GPU chosen by assign_gpu() from the running case index
                # (presumably round-robin over num_gpus — defined elsewhere)
                gpu_id = assign_gpu(len(all_cases), num_gpus)
                all_cases.append((i_case_partial_x, i_case_sample, i_seed, gpu_id, old_dataset))

    total_cases = len(all_cases)
    print(f"Total number of cases to run: {total_cases}")

    # Never start more workers than there are cases: each spawned worker
    # re-imports this module (and torch), so idle workers are pure overhead.
    n_workers = max(1, min(NUM_PARALLEL_PROCESSES, total_cases))

    start_time = time.time()

    if n_workers > 1:
        print(f"Running {n_workers} processes in parallel")
        # Use 'spawn' for better compatibility with CUDA
        mp.set_start_method('spawn', force=True)
        with mp.Pool(processes=n_workers) as pool:
            # Drain results as they complete; workers write their own output files
            for _ in pool.imap_unordered(run_single_case, all_cases):
                pass
    else:
        print("Running sequentially")
        for i, case in enumerate(all_cases):
            print(f"Running case {i+1}/{total_cases}")
            run_single_case(case)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total execution time: {total_time:.2f} seconds")
    if total_cases:  # avoid ZeroDivisionError when there is nothing to run
        print(f"Average time per case: {total_time/total_cases:.2f} seconds")

