'''
ERA5 Dataset experiment with parallel processing capabilities
Based on the original run_models_165.py but configured for parallel execution.
'''
import os
import sys
sys.path.append('.')
import multiprocessing as mp
import time
import numpy as np
import pandas as pd

from common import *

from baseline.bayesian_last_layer_networks import create_bll_model, train_bll_model
from baseline.monte_carlo_dropout import MCDropout, train_mc_dropout_model
from baseline.probabilistic_neural_networks import ProbabilisticNeuralNetwork, train_pnn_model
from baseline.stochastic_weight_averaging_gaussian import create_swag_model, train_swag_model
from baseline.deterministic_variational_inference import DVI, train_dvi_model
from baseline.deep_gaussian_process_regression import DeepGPRegression, train_deep_gp_model
from baseline.mixture_density_network import MixtureDensityNetwork, train_mdn_model
from hvbll.vbll import VBLL, train_vbll_model
from hvbll.hvbll import HVBLL, train_hvbll_model

# --- Sweep size and parallelism ---------------------------------------------
N_SEED = 10                  # each (model, sample size) pair is run for seeds 0 .. N_SEED-1
NUM_PARALLEL_PROCESSES = 32  # pool size; a value of 1 runs every case sequentially


# --- Hyper-parameters read by the model configurations in run_single_case -----
dim_latent = 64        # latent width for HVBLL / VBLL (BLL sets its own)
dim_hidden = 128
n_hidden_layers = 4
learning_rate = 1e-2
lr_step_size = 500

# Models to run in this sweep.
# Full roster: 'HVBLL', 'VBLL', 'BLL', 'MC-Dropout', 'Deep-GP', 'PNN', 'SWAG', 'DVI', 'MDN'
MODELS = ['MDN']

# Values of max_samples to sweep over.
# Full sweep: 500, 4000, 20000, int(5e6)
N_SAMPLES = [500, 4000, 20000]


def _build_model_configs(dim_input, dim_output):
    """Return the per-model configuration table used by run_single_case.

    Each entry maps a model name (as listed in MODELS) to:
      - 'create_fn':  constructor/factory, called with **params
      - 'params':     keyword arguments for create_fn
      - 'train_fn':   training routine
      - 'num_epochs': number of epochs passed to train_fn
    The 'color' field (and 'num_samples' for MC-Dropout) is not read by
    run_single_case.
    """
    return {
        'HVBLL': {
            'create_fn': HVBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'dim_hidden_noise': 16,
                'n_hidden_layers': n_hidden_layers,
                'n_noise_layers': 1,
                'reg_weight_latent': 0.1,
                'reg_weight_noise': 0.1,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 1.0,
                'dof': 1.0
            },
            'train_fn': train_hvbll_model,
            'color': 'gray',
            'num_epochs': 5000
        },
        'VBLL': {
            'create_fn': VBLL,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_latent': dim_latent,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'reg_weight_latent': 0.1,
                'reg_weight_noise': 0.1,
                'covariance_type': 'dense',
                'prior_scale': 1.0,
                'wishart_scale': 1.0,
                'dof': 1.0
            },
            'train_fn': train_vbll_model,
            'color': 'pink',
            'num_epochs': 5000
        },
        'BLL': {
            'create_fn': create_bll_model,
            'params': {
                'dim_input': dim_input,
                'dim_hidden': dim_hidden,
                'dim_latent': 32,    # Dimension of the feature space
                'n_hidden_layers': n_hidden_layers,
                'kernel_type': 'rbf',
                'use_derivatives': False,  # Set to True for LDGBLL, False for GBLL
                'prior_log_obs_var': 3.0
            },
            'train_fn': train_bll_model,
            'color': 'blue',
            'num_epochs': 5000
        },
        'MC-Dropout': {
            'create_fn': MCDropout,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'dropout_rate': 0.1
            },
            'train_fn': train_mc_dropout_model,
            'color': 'red',
            'num_samples': 30,  # This is used in get_prediction, not in the constructor
            'num_epochs': 5000
        },
        'Deep-GP': {
            'create_fn': DeepGPRegression,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_inducing': 64,
                'num_samples': 50
            },
            'train_fn': train_deep_gp_model,
            'color': 'cyan',
            'num_epochs': 5000
        },
        'PNN': {
            'create_fn': ProbabilisticNeuralNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'beta': 0.5  # Beta parameter for beta-NLL loss
            },
            'train_fn': train_pnn_model,
            'color': 'purple',
            'num_epochs': 5000
        },
        'SWAG': {
            'create_fn': create_swag_model,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'max_models': 20,
                'swa_start': 3000,  # Epoch to start SWA
                'swa_lr': 1e-3,
                'var_clamp': 0.001,
                'full_cov': True,  # Use full covariance (diagonal + low-rank)
                'prior_log_obs_var': None
            },
            'train_fn': train_swag_model,
            'color': 'orange',
            'num_epochs': 5000
        },
        'DVI': {
            'create_fn': DVI,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'n_hidden_layers': n_hidden_layers,
                'ratio_kl': 1.0,
                'n_components': 5,  # Number of Gaussian components
                'prior_mean': 0.0,
                'prior_var': 0.1,
                'prior_log_obs_var': 0.0
            },
            'train_fn': train_dvi_model,
            'color': 'brown',
            'num_epochs': 1000
        },
        'MDN': {
            'create_fn': MixtureDensityNetwork,
            'params': {
                'dim_input': dim_input,
                'dim_output': dim_output,
                'dim_hidden': dim_hidden,
                'num_hidden_layers': n_hidden_layers,
                'num_components': 3,
            },
            'train_fn': train_mdn_model,
            'color': 'green',
            'num_epochs': 10000
        },
    }


def _summary_metric_keys():
    """Return the metric keys written to the summary CSV, in column order.

    The order must not change: existing summary files already carry this header,
    and case_already_run reads them back.
    """
    conf_levels = (10, 20, 30, 40, 50, 60, 70, 80, 90)
    basic = ('mse', 'nll', 'mae', 'crps', 'coverage_95', 'width_95', 'ace')
    keys = []
    for split in ('train', 'test'):
        keys.extend(f"{split}_{name}" for name in basic)
        keys.extend(f"{split}_coverage_{conf}" for conf in conf_levels)
        keys.extend(f"{split}_width_{conf}" for conf in conf_levels)
    keys.extend(['training_time', 'epoch'])
    return keys


def _format_summary_header():
    """Return the CSV header line (with trailing newline) for a summary file."""
    return ", ".join(["n_sample", "seed"] + _summary_metric_keys()) + "\n"


def _format_summary_row(n_train_sample, i_seed, train_results):
    """Return one CSV data line for a finished run.

    Missing metrics are written as 0.0. Every metric uses 6 decimals except
    'epoch', which uses 1.
    """
    parts = [str(n_train_sample), str(i_seed)]
    for key in _summary_metric_keys():
        spec = '.1f' if key == 'epoch' else '.6f'
        parts.append(format(train_results.get(key, 0.0), spec))
    return ", ".join(parts) + "\n"


def _append_summary_row(fname, header_line, row_line):
    """Append row_line to fname, writing header_line first if the file is empty.

    Several worker processes append to the same per-model file. On platforms
    that provide fcntl, an exclusive advisory lock is held across the emptiness
    check and both writes, so only one process emits the header. Without
    fcntl the append is unlocked.
    """
    try:
        import fcntl
    except ImportError:  # e.g. Windows
        fcntl = None

    with open(fname, 'a') as f:
        if fcntl is not None:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            # Check the size under the lock, via the open descriptor.
            if os.fstat(f.fileno()).st_size == 0:
                f.write(header_line)
            f.write(row_line)
            # Push buffered text to the file before the lock is released.
            f.flush()
        finally:
            if fcntl is not None:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)


def run_single_case(args):
    """
    Run a single experiment case and append its metrics to the model's summary file.

    Args:
        args: Tuple containing (i_seed, gpu_id, model_name, max_samples)

    Returns:
        None, results are written to files. Cases already present in the
        summary file are skipped. Failures while building or training the model
        are printed, and the case is skipped without writing a row.

    Raises:
        ValueError: if the prepared training set does not contain
            int(max_samples * 0.8) samples.
    """
    import traceback

    i_seed, gpu_id, model_name, max_samples = args

    # Log which process is handling this task
    process_id = mp.current_process().name
    print(f"Process {process_id}: Assigned to GPU {gpu_id} for {model_name} (seed={i_seed})")

    # The training split is expected to be 80% of max_samples (checked below).
    _n_sample = int(max_samples*0.8)
    if case_already_run(model_name, i_seed, _n_sample):
        print(f"Process {process_id}: Skipping for {model_name} - already run for seed={i_seed}")
        return

    # Set the random seed for reproducibility
    set_seed(i_seed)

    dict_problem = prepare_case(batch_size=512, GPU_ID=gpu_id, max_samples=max_samples)

    dim_input = dict_problem['dim_input']
    dim_output = dict_problem['dim_output']
    X_train_tensor = dict_problem['X_train_tensor']
    y_train_tensor = dict_problem['y_train_tensor']
    X_test_tensor = dict_problem['X_test_tensor']
    y_test_tensor = dict_problem['y_test_tensor']
    dataloader = dict_problem['dataloader']
    n_train_sample = len(X_train_tensor)
    if _n_sample != n_train_sample:
        raise ValueError(f"n_train_sample={n_train_sample} does not match _n_sample={_n_sample}")

    model_config = _build_model_configs(dim_input, dim_output)[model_name]

    # Construction sits inside the try as well. A bad config is then reported
    # for this case only, instead of propagating out of the pool and aborting
    # every other in-flight case.
    try:
        model = model_config['create_fn'](**model_config['params'])
        if torch.cuda.is_available():
            model.cuda(gpu_id)

        train_results = model_config['train_fn'](
            model, dataloader, X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor,
            model_config['num_epochs'], learning_rate, lr_step_size
        )
    except Exception as e:
        print(f"Process {process_id}: Error training {model_name}, seed={i_seed}: {e}")
        traceback.print_exc()
        return

    _append_summary_row(
        fname_summary(model_name),
        _format_summary_header(),
        _format_summary_row(n_train_sample, i_seed, train_results),
    )

    # Print summary of results (same 0.0 defaults as the CSV row, so a missing
    # metric cannot crash the worker after the row was written).
    print(f"\nProcess {process_id}: Results for {model_name} (seed={i_seed}, GPU={gpu_id}):")
    print(f"Training Time: {train_results.get('training_time', 0.0):.2f} seconds")
    print(f"Train MAE: {train_results.get('train_mae', 0.0):.6f}")
    print(f"Train NLL: {train_results.get('train_nll', 0.0):.6f}")
    print(f"Test MAE: {train_results.get('test_mae', 0.0):.6f}")
    print(f"Test NLL: {train_results.get('test_nll', 0.0):.6f}")


def case_already_run(model_name, i_seed, n_train_sample):
    """
    Check whether a (model, seed, training-set size) case already has a row in its summary file.

    Args:
        model_name: Name of the model; selects the summary file via fname_summary().
        i_seed: Seed index to look for in the 'seed' column.
        n_train_sample: Training-set size to look for in the 'n_sample' column.

    Returns:
        True if a matching row exists. False otherwise, including when the file
        is missing, empty, lacks the required columns, or cannot be parsed.
    """
    fname = fname_summary(model_name)

    if not os.path.exists(fname):
        return False

    try:
        if os.path.getsize(fname) == 0:
            return False

        # Cheap sanity check before handing the file to pandas.
        with open(fname, 'r') as f:
            content = f.read().strip()
        if not content or ',' not in content:  # Empty, or not CSV-like
            return False

        # 'error_bad_lines' was renamed to 'on_bad_lines' in newer pandas versions;
        # older versions reject the new keyword with a TypeError.
        try:
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, on_bad_lines='warn')
        except TypeError:
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, error_bad_lines=False)

        if 'n_sample' not in df.columns or 'seed' not in df.columns:
            print(f"Warning: Missing required columns in {fname}")
            return False

        # Coerce to numbers first. A single non-numeric row (e.g. a header
        # written twice by concurrent writers) turns these into object columns.
        # Comparing those strings against ints would silently match nothing.
        seeds = pd.to_numeric(df['seed'], errors='coerce')
        n_samples = pd.to_numeric(df['n_sample'], errors='coerce')
        case_exists = bool(((seeds == i_seed) & (n_samples == n_train_sample)).any())

        if case_exists:
            print(f"Found existing run for {model_name}, seed={i_seed}")

        return case_exists

    except Exception as e:
        print(f"Error checking if case already run: {e} (file: {fname})")
        # Best effort: if the file cannot be read, treat the case as not yet run.
        return False


if __name__ == "__main__":

    # Get number of available GPUs (0 when CUDA is unavailable)
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    print(f"Number of available GPUs: {num_gpus}")

    # Enumerate every (max_samples, model, seed) case. The GPU for each case is
    # chosen once here by assign_gpu (from common) from the case index; workers
    # use it as given, since the memory-based reassignment in run_single_case
    # is disabled.
    all_cases = []
    for max_samples in N_SAMPLES:
        for model_name in MODELS:
            for i_seed in range(N_SEED):
                gpu_id = assign_gpu(len(all_cases), num_gpus)
                all_cases.append((i_seed, gpu_id, model_name, max_samples))

    total_cases = len(all_cases)
    print(f"Total number of cases to run: {total_cases}")

    start_time = time.time()

    if NUM_PARALLEL_PROCESSES > 1:
        print(f"Running {NUM_PARALLEL_PROCESSES} processes in parallel")
        # 'spawn' avoids inheriting a CUDA context through fork
        mp.set_start_method('spawn', force=True)
        with mp.Pool(processes=NUM_PARALLEL_PROCESSES) as pool:
            # Results are written to files by the workers; just drain the iterator.
            for _ in pool.imap_unordered(run_single_case, all_cases):
                pass
    else:
        print("Running sequentially")
        for i, case in enumerate(all_cases):
            print(f"Running case {i+1}/{total_cases}")
            run_single_case(case)

    total_time = time.time() - start_time
    print(f"Total execution time: {total_time:.2f} seconds")
    # Guard against an empty sweep (e.g. MODELS or N_SAMPLES set to []).
    if total_cases:
        print(f"Average time per case: {total_time/total_cases:.2f} seconds")

