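'''
Laminate open-hole compression (OHC) test experiment.

Benchmarks uncertainty-aware regression models (HVBLL, VBLL and the baseline
methods imported below) over several random seeds and training-set sizes,
running the cases in parallel across worker processes and GPUs.
'''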
import os
import sys
sys.path.append('.')
import multiprocessing as mp
import time
import numpy as np
import pandas as pd

from common import *

from baseline.bayesian_last_layer_networks import create_bll_model, train_bll_model
from baseline.monte_carlo_dropout import MCDropout, train_mc_dropout_model
from baseline.probabilistic_neural_networks import ProbabilisticNeuralNetwork, train_pnn_model
from baseline.stochastic_weight_averaging_gaussian import create_swag_model, train_swag_model
from baseline.deterministic_variational_inference import DVI, train_dvi_model
from baseline.deep_gaussian_process_regression import DeepGPRegression, train_deep_gp_model
from baseline.mixture_density_network import MixtureDensityNetwork, train_mdn_model
from hvbll.vbll import VBLL, train_vbll_model
from hvbll.hvbll import HVBLL, train_hvbll_model

# Number of random seeds run for every (model, sample-size) pair.
N_SEED = 10
# Number of worker processes in the multiprocessing pool
# (values <= 1 run the cases sequentially).
NUM_PARALLEL_PROCESSES = 64


# Architecture defaults shared by the model configurations (individual models
# may override them) and optimizer settings passed to every training function.
dim_latent = 32
dim_hidden = 64
n_hidden_layers = 3
learning_rate = 0.01
lr_step_size = 500

# Every model this script knows how to configure.
ALL_MODELS = ['HVBLL', 'VBLL', 'BLL', 'MC-Dropout', 'Deep-GP', 'PNN', 'SWAG', 'DVI', 'MDN']
# Subset actually run by this invocation; set to list(ALL_MODELS) for the full sweep.
MODELS = ['MC-Dropout']
# Values of `max_samples` swept over; each case trains on int(0.8 * max_samples) samples.
N_SAMPLES = [500, 1000, 4000]


# Confidence levels (percent) at which per-level coverage and interval width
# are recorded, in addition to the 95% level.
_CONFIDENCE_LEVELS = (10, 20, 30, 40, 50, 60, 70, 80, 90)


def _metric_columns():
    """Return the ordered metric keys recorded for every case.

    A single list drives both the CSV header and each data row, so the two
    cannot drift apart. Each key is looked up in the dict returned by a
    model's ``train_fn``.
    """
    columns = []
    for split in ("train", "test"):
        columns.extend(
            f"{split}_{name}"
            for name in ("mse", "nll", "mae", "crps", "coverage_95", "width_95", "ace")
        )
        columns.extend(f"{split}_coverage_{conf}" for conf in _CONFIDENCE_LEVELS)
        columns.extend(f"{split}_width_{conf}" for conf in _CONFIDENCE_LEVELS)
    columns.extend(["training_time", "epoch"])
    return columns


def _append_results(fname, n_train_sample, i_seed, train_results):
    """Append one result row to ``fname``, writing the header first if the file is empty.

    On POSIX systems an exclusive ``fcntl.flock`` is held while checking the
    file size and writing, so concurrent worker processes neither duplicate
    the header nor interleave rows. Where ``fcntl`` is unavailable (e.g.
    Windows) the append proceeds without a lock.

    Metrics missing from ``train_results`` are written as 0.0. ``epoch`` is
    formatted with one decimal, every other metric with six.
    """
    try:
        import fcntl
    except ImportError:  # non-POSIX platform: fall back to an unlocked append
        fcntl = None

    columns = _metric_columns()
    header = ", ".join(["n_sample", "seed", *columns]) + "\n"

    values = [str(n_train_sample), str(i_seed)]
    for key in columns:
        spec = ".1f" if key == "epoch" else ".6f"
        values.append(format(train_results.get(key, 0.0), spec))
    row = ", ".join(values) + "\n"

    with open(fname, 'a') as f:
        if fcntl is not None:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        # The emptiness check must happen under the lock; otherwise two
        # processes can both observe an empty file and both write a header.
        if os.path.getsize(fname) == 0:
            f.write(header)
        f.write(row)
        # Hand the buffered data to the OS before closing the file, which
        # releases the lock.
        f.flush()


def _build_model_configs(dim_input, dim_output):
    """Return the per-model configuration table for the given problem size.

    Each entry maps a model name to its constructor (``create_fn``), the
    constructor kwargs (``params``), its training function (``train_fn``),
    the number of training epochs (``num_epochs``) and a plot colour.
    """
    return {
        'HVBLL': {
            'create_fn': HVBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'dim_hidden_noise': 16,
                'n_hidden_layers': n_hidden_layers,
                'n_noise_layers': 1,
                'reg_weight_latent': 0.1,
                'reg_weight_noise': 0.1,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 1.0,
                'dof': 1.0
            },
            'train_fn': train_hvbll_model,
            'color': 'gray',
            'num_epochs': 5000
        },
        'VBLL': {
            'create_fn': VBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'reg_weight_latent': 0.1,
                'reg_weight_noise': 0.1,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 1.0,
                'dof': 1.0
            },
            'train_fn': train_vbll_model,
            'color': 'pink',
            'num_epochs': 5000
        },
        'BLL': {
            'create_fn': create_bll_model,
            'params': {
                'dim_input': dim_input,
                'dim_hidden': dim_hidden,
                'dim_latent': 32,    # Dimension of the feature space
                'n_hidden_layers': n_hidden_layers,
                'kernel_type': 'rbf',
                'use_derivatives': False,  # Set to True for LDGBLL, False for GBLL
                'prior_log_obs_var': 3.0
            },
            'train_fn': train_bll_model,
            'color': 'blue',
            'num_epochs': 5000
        },
        'MC-Dropout': {
            'create_fn': MCDropout,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': 128,      # The default causes extremely bad results
                'n_hidden_layers': 1,   # The default causes extremely bad results
                'dropout_rate': 0.1
            },
            'train_fn': train_mc_dropout_model,
            'color': 'red',
            'num_samples': 30,  # This is used in get_prediction, not in the constructor
            'num_epochs': 100
        },
        'Deep-GP': {
            'create_fn': DeepGPRegression,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_inducing': 64,
                'num_samples': 50
            },
            'train_fn': train_deep_gp_model,
            'color': 'cyan',
            'num_epochs': 5000
        },
        'PNN': {
            'create_fn': ProbabilisticNeuralNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'beta': 0.5  # Beta parameter for beta-NLL loss
            },
            'train_fn': train_pnn_model,
            'color': 'purple',
            'num_epochs': 5000
        },
        'SWAG': {
            'create_fn': create_swag_model,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'max_models': 20,
                'swa_start': 3000,  # Epoch to start SWA
                'swa_lr': 1e-3,
                'var_clamp': 0.001,
                'full_cov': True,  # Use full covariance (diagonal + low-rank)
                'prior_log_obs_var': None
            },
            'train_fn': train_swag_model,
            'color': 'orange',
            'num_epochs': 5000
        },
        'DVI': {
            'create_fn': DVI,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'ratio_kl': 1.0,
                'n_components': 5,  # Number of Gaussian components
                'prior_mean': 0.0,
                'prior_var': 0.1,
                'prior_log_obs_var': 0.0
            },
            'train_fn': train_dvi_model,
            'color': 'brown',
            'num_epochs': 1000
        },
        'MDN': {
            'create_fn': MixtureDensityNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_components': 5,
            },
            'train_fn': train_mdn_model,
            'color': 'green',
            'num_epochs': 3000
        },
    }


def run_single_case(args):
    """
    Run a single experiment case and append its metrics to the model's summary file.

    Args:
        args: Tuple ``(i_seed, gpu_id, model_name, max_samples)``, packed into
            one argument so the function can be mapped with
            ``multiprocessing.Pool.imap_unordered``.

    Returns:
        None. Results are appended to ``fname_summary(model_name)``.

    Raises:
        ValueError: if the prepared training set does not contain exactly
            ``int(max_samples * 0.8)`` samples.
        KeyError: if ``model_name`` has no entry in the configuration table.

    Cases that ``case_already_run`` reports as done are skipped. If the
    model's training function raises, the error and its traceback are
    printed and no row is written.
    """
    i_seed, gpu_id, model_name, max_samples = args

    # Log which process is handling this task
    process_id = mp.current_process().name
    print(f"Process {process_id}: Assigned to GPU {gpu_id} for {model_name} (seed={i_seed})")

    # Completed cases are identified by the expected training-set size,
    # int(0.8 * max_samples). This presumably reflects an 80/20 split inside
    # prepare_case; the check after prepare_case enforces the count.
    _n_sample = int(max_samples*0.8)
    if case_already_run(model_name, i_seed, _n_sample):
        print(f"Process {process_id}: Skipping for {model_name} - already run for seed={i_seed}")
        return

    # NOTE: memory-based GPU reassignment (assign_gpu_by_memory) is disabled;
    # the gpu_id supplied by the caller is used as-is.

    # Set the random seed for reproducibility
    set_seed(i_seed)

    # Prepare the case
    dict_problem = prepare_case(batch_size=512, seed=i_seed, GPU_ID=gpu_id, max_samples=max_samples)
    dim_input = dict_problem['dim_input']
    dim_output = dict_problem['dim_output']
    X_train_tensor = dict_problem['X_train_tensor']
    y_train_tensor = dict_problem['y_train_tensor']
    X_test_tensor = dict_problem['X_test_tensor']
    y_test_tensor = dict_problem['y_test_tensor']
    dataloader = dict_problem['dataloader']

    n_train_sample = len(X_train_tensor)
    if _n_sample != n_train_sample:
        raise ValueError(f"n_train_sample={n_train_sample} does not match _n_sample={_n_sample}")

    # Look up the configuration of the single model this case trains
    models = _build_model_configs(dim_input, dim_output)
    if model_name not in models:
        raise KeyError(f"Unknown model {model_name!r}; expected one of {sorted(models)}")
    model_config = models[model_name]

    # Create model
    model = model_config['create_fn'](**model_config['params'])
    if torch.cuda.is_available():
        model.cuda(gpu_id)

    # Train model. A training failure is reported and the case abandoned, so
    # the exception does not propagate out of the worker and abort the pool.
    try:
        train_results = model_config['train_fn'](
            model, dataloader, X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor,
            model_config['num_epochs'], learning_rate, lr_step_size
        )
    except Exception as e:
        import traceback  # only needed on this error path
        print(f"Process {process_id}: Error training {model_name}, seed={i_seed}: {e}")
        traceback.print_exc()
        return

    # Write metrics to file (locked append where the platform supports it)
    _append_results(fname_summary(model_name), n_train_sample, i_seed, train_results)

    # Print summary of results. Missing metrics are shown as 0.0, matching
    # what was written to the file, instead of raising KeyError.
    print(f"\nProcess {process_id}: Results for {model_name} (seed={i_seed}, GPU={gpu_id}):")
    print(f"Training Time: {train_results.get('training_time', 0.0):.2f} seconds")
    print(f"Train MAE: {train_results.get('train_mae', 0.0):.6f}")
    print(f"Train NLL: {train_results.get('train_nll', 0.0):.6f}")
    print(f"Test MAE: {train_results.get('test_mae', 0.0):.6f}")
    print(f"Test NLL: {train_results.get('test_nll', 0.0):.6f}")


if __name__ == "__main__":

    # Number of visible CUDA devices (0 on CPU-only machines).
    # NOTE(review): `torch` is not imported directly in this file; presumably
    # it is provided by `from common import *` — confirm.
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    print(f"Number of available GPUs: {num_gpus}")

    # Enumerate every (seed, gpu_id, model, max_samples) case. GPUs are
    # assigned up front from the case index via assign_gpu; the worker uses
    # the assigned gpu_id as-is.
    all_cases = []
    for max_samples in N_SAMPLES:
        for model_name in MODELS:
            for i_seed in range(N_SEED):
                gpu_id = assign_gpu(len(all_cases), num_gpus)
                all_cases.append((i_seed, gpu_id, model_name, max_samples))

    total_cases = len(all_cases)
    print(f"Total number of cases to run: {total_cases}")

    # Run cases in parallel
    start_time = time.time()

    # Never start more workers than there are cases. With the 'spawn' start
    # method each worker re-imports this module and its dependencies.
    n_processes = min(NUM_PARALLEL_PROCESSES, total_cases)
    if n_processes > 1:
        print(f"Running {n_processes} processes in parallel")
        # Use 'spawn' for better compatibility with CUDA
        mp.set_start_method('spawn', force=True)
        with mp.Pool(processes=n_processes) as pool:
            # imap_unordered: consume completions in whatever order they
            # finish; results are written to files by the workers.
            for _ in pool.imap_unordered(run_single_case, all_cases):
                pass
    else:
        print("Running sequentially")
        for i, case in enumerate(all_cases):
            print(f"Running case {i+1}/{total_cases}")
            run_single_case(case)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total execution time: {total_time:.2f} seconds")
    # Guard against an empty sweep (e.g. MODELS or N_SAMPLES empty).
    if total_cases:
        print(f"Average time per case: {total_time/total_cases:.2f} seconds")

