'''
Variance estimation using MC Dropout
'''
import os
import sys
sys.path.append('.')
path0 = os.path.dirname(sys.argv[0])

import numpy as np
import multiprocessing as mp
import pandas as pd
import time

import torch
from torch.utils.data import Dataset, DataLoader
from hvbll.toy_functions import *
from baseline.monte_carlo_dropout import MCDropout, train_mc_dropout_model


# Define a custom dataset class
class TensorDataset(Dataset):
    """Map-style dataset that pairs a feature container with a label container.

    Indexing returns the ``(features, label)`` pair stored at that position;
    the dataset length is taken from the feature container.

    Args:
        X_tensor (Tensor): feature data, indexed with the same index as ``y_tensor``
        y_tensor (Tensor): label data, indexed with the same index as ``X_tensor``
    """

    def __init__(self, X_tensor, y_tensor):
        self.X, self.y = X_tensor, y_tensor

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        features = self.X[index]
        label = self.y[index]
        return features, label

#* Constants
N_SEED = 10                   # seeds 0 .. N_SEED-1 are run for every setting
NUM_PARALLEL_PROCESSES = 60   # worker processes; a value of 1 runs the cases sequentially

# Noise-model parameters forwarded to the toy-function datasets
NOISE_LEVEL = 0.1
NOISE_LEVEL_SLOPE = 1.0           # passed to the linear-noise functions (0 and 2)
NOISE_LEVEL_OMEGA = 2 * np.pi     # passed to the sine-noise functions (1 and 3)

# Experiment grid: every entry is a (dim_input, num_samples) pair
list_setting = (
    [(1, n) for n in (20, 100, 1000)]
    + [(10, n) for n in (50, 1000, 10000)]
    + [(100, n) for n in (100, 1000, 10000)]
)

# Per-(function, dimension) summary files live here; created up front
path_summary = os.path.join(path0, 'summary')
os.makedirs(path_summary, exist_ok=True)

# Common parameters for all models
dim_output = 1
dim_latent = 64        # not referenced in the code shown here
dim_hidden = 128
n_hidden_layers = 3
learning_rate = 0.01
lr_step_size = 500     # forwarded to train_mc_dropout_model (presumably an LR-scheduler step size)
num_epochs = 5000


def set_seed(seed: int) -> None:
    '''
    Seed NumPy and PyTorch (CPU, plus every CUDA device when one is present)
    and put cuDNN into deterministic mode so that runs can be reproduced.
    '''
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False   # disable cuDNN altogether rather than rely on its deterministic kernels

def fname_summary(i_function: int, dim_input: int) -> str:
    '''Return the summary-file path for one toy function at one input
    dimension: ``<path_summary>/summary-F<i_function>-X<dim_input>.txt``.'''
    basename = 'summary-F%d-X%d.txt' % (i_function, dim_input)
    return os.path.join(path_summary, basename)

def assign_gpu(idx: int, num_gpus: int) -> int:
    """Pick a GPU for job ``idx`` by cycling through ``num_gpus`` devices.

    When no GPUs are available (``num_gpus == 0``) every job gets id 0.
    """
    return 0 if num_gpus == 0 else idx % num_gpus

def _make_toy_dataset(i_function, num_samples, dim_input, seed, gpu_id):
    """Build the toy dataset for ``i_function`` (0-3) using the module-level noise settings."""
    common = dict(dim_input=dim_input, seed=seed, noise_level=NOISE_LEVEL, gpu_id=gpu_id)
    if i_function == 0:
        return ToyFn_Lin_Noise_Lin(num_samples, noise_level_slope=NOISE_LEVEL_SLOPE, **common)
    if i_function == 1:
        return ToyFn_Lin_Noise_Sin(num_samples, noise_level_omega=NOISE_LEVEL_OMEGA, **common)
    if i_function == 2:
        return ToyFn_Sin_Noise_Lin(num_samples, noise_level_slope=NOISE_LEVEL_SLOPE, **common)
    if i_function == 3:
        return ToyFn_Sin_Noise_Sin(num_samples, noise_level_omega=NOISE_LEVEL_OMEGA, **common)
    raise ValueError(f"Unknown toy function index: {i_function!r} (expected 0, 1, 2 or 3)")


def run_single_case(args):
    """
    Run a single experiment case and append its metrics to the summary file.

    Args:
        args: Tuple ``(i_function, dim_input, num_samples, seed, gpu_id)``.
            ``i_function`` selects the toy problem (0-3). For the sine-noise
            problems (1 and 3) an input dimension of 10 or 100 is replaced by
            2 or 4 respectively before anything else is done.

    Returns:
        None. One line of metrics is appended to
        ``fname_summary(i_function, dim_input)`` (with the remapped
        ``dim_input``). Cases already recorded in that file are skipped.
        Training errors are printed and swallowed so that a single failure
        does not stop a worker pool.

    Raises:
        ValueError: if ``i_function`` is not one of 0, 1, 2, 3.
    """
    i_function, dim_input, num_samples, seed, gpu_id = args

    # Log which process is handling this task
    process_id = mp.current_process().name
    print(f"Process {process_id} running case: {i_function}, sample={num_samples}, seed={seed}, GPU={gpu_id}")

    # Fail fast on an unknown toy function; otherwise this used to surface
    # later as an UnboundLocalError on ``dataset``.
    if i_function not in (0, 1, 2, 3):
        raise ValueError(f"Unknown toy function index: {i_function!r} (expected 0, 1, 2 or 3)")

    #* Prepare the case
    if i_function in (1, 3):    # sine noise is run at a reduced input dimension
        dim_input = {10: 2, 100: 4}.get(dim_input, dim_input)

    #* Skip finished cases before spending time generating the data
    if case_already_run(i_function, dim_input, num_samples, seed):
        print(f"Process {process_id}: Skipping - Function {i_function}, Nx {dim_input}, Ns {num_samples}, Seed {seed}")
        return

    dataset = _make_toy_dataset(i_function, num_samples, dim_input, seed, gpu_id)

    X_train_tensor = dataset.X
    y_train_tensor = dataset.Y
    # No held-out split: the first 10 training points double as the "test" set.
    X_test_tensor = X_train_tensor[:10]
    y_test_tensor = y_train_tensor[:10]

    train_set = TensorDataset(X_train_tensor, y_train_tensor)
    dataloader = DataLoader(train_set, batch_size=1000, shuffle=True, drop_last=False)

    #* Create model. Seed per case so that weight initialisation and the
    #* dropout/shuffle randomness are reproducible and tied to ``seed``.
    #* This is done after the dataset is built, so the data matches earlier runs.
    set_seed(seed)
    model = MCDropout(dim_input, dim_output, dim_hidden, n_hidden_layers, dropout_rate=0.2)
    if torch.cuda.is_available():
        model.cuda(gpu_id)

    #* Train model
    try:
        train_results = train_mc_dropout_model(
            model, dataloader, X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor,
            num_epochs, learning_rate, lr_step_size
        )

        train_var = train_results['train_var']

        # Mean of train_var over dim 0, first remaining entry
        # (presumably train_var is (num_samples, dim_output) — TODO confirm).
        E_noise = torch.mean(train_var, dim=0).detach().cpu().numpy()[0]

    except Exception as e:
        print(f"Process {process_id} Error: {e}")
        return

    # Format the row before touching the file, so a formatting error cannot
    # leave a newly created, empty summary file behind.
    row = (f"{num_samples}, {seed}, "
           f"{E_noise:.6f}, "
           f"{train_results['train_mse']:.6f}, "
           f"{train_results['train_nll']:.6f}, "
           f"{train_results['training_time']:.6f}, "
           f"{train_results['epoch']:.1f} \n")

    fname = fname_summary(i_function, dim_input)

    # No file locking is used. Header and row go out in a single write() so
    # they stay together. Two workers that both find a brand-new file empty
    # can still each write a header line.
    with open(fname, 'a') as f:
        if os.path.getsize(fname) == 0:
            row = "n_sample, seed, E_noise, train_mse, train_nll, training_time, epoch\n" + row
        f.write(row)
        
def case_already_run(i_function: int, dim_input: int, num_samples: int, seed: int) -> bool:
    """
    Check whether a specific case is already recorded in its summary file.

    Args:
        i_function: Toy function index
        dim_input: Dimension of the input (already remapped, as used in the file name)
        num_samples: Number of samples
        seed: Random seed used for the case

    Returns:
        ``True`` if a row with matching ``n_sample`` and ``seed`` exists.
        Any problem reading the file (missing, empty, malformed) yields
        ``False``, i.e. the case is treated as not yet run.
    """
    fname = fname_summary(i_function, dim_input)

    if not os.path.exists(fname):
        return False

    try:
        # An empty file means nothing has been recorded yet
        if os.path.getsize(fname) == 0:
            return False

        # Quick sanity check that the file looks like CSV at all
        with open(fname, 'r') as f:
            content = f.read().strip()
        if not content or ',' not in content:
            return False

        # 'error_bad_lines' was renamed to 'on_bad_lines' in pandas 1.3; older
        # versions reject the new keyword with a TypeError.
        try:
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, on_bad_lines='warn')
        except TypeError:
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, comment='#',
                             skip_blank_lines=True, error_bad_lines=False)

        # Check if required columns exist
        if 'n_sample' not in df.columns or 'seed' not in df.columns:
            print(f"Warning: Missing required columns in {fname}")
            return False

        # Coerce to numbers. A repeated header line or a partial row would
        # otherwise make these columns object dtype, and comparing them with
        # ints would then never match.
        n_col = pd.to_numeric(df['n_sample'], errors='coerce')
        seed_col = pd.to_numeric(df['seed'], errors='coerce')
        mask = (n_col == num_samples) & (seed_col == seed)

        return bool(mask.any())

    except Exception as e:
        print(f"Error checking if case already run: {e} (file: {fname})")
        # Best effort: on any read error, assume the case hasn't been run
        return False


if __name__ == '__main__':

    # Manual switch. True runs every experiment case (cases already recorded
    # in a summary file are skipped). False only collects results from the
    # existing summary files.
    RUN_EXPERIMENTS = False

    #* Run the cases in parallel
    if RUN_EXPERIMENTS:

        # Get number of available GPUs
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        print(f"Number of available GPUs: {num_gpus}")

        # Generate all combinations of parameters
        all_cases = []
        for i_function in range(4):
            for dim_input, num_samples in list_setting:
                for seed in range(N_SEED):
                    # Assign a GPU to this job in round-robin fashion if multiple GPUs available
                    gpu_id = assign_gpu(len(all_cases), num_gpus)
                    all_cases.append((i_function, dim_input, num_samples, seed, gpu_id))

        total_cases = len(all_cases)
        print(f"Total number of cases to run: {total_cases}")

        # Run cases in parallel
        start_time = time.time()

        if NUM_PARALLEL_PROCESSES > 1:
            print(f"Running {NUM_PARALLEL_PROCESSES} processes in parallel")
            # Use 'spawn' for better compatibility with CUDA
            mp.set_start_method('spawn', force=True)
            with mp.Pool(processes=NUM_PARALLEL_PROCESSES) as pool:
                # Workers write their own results; just drain the iterator
                for _ in pool.imap_unordered(run_single_case, all_cases):
                    pass
        else:
            print("Running sequentially")
            for i, case in enumerate(all_cases):
                print(f"Running case {i+1}/{total_cases}")
                run_single_case(case)

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Total execution time: {total_time:.2f} seconds")
        print(f"Average time per case: {total_time/total_cases:.2f} seconds")


    #* Collect results: mean E_noise over all recorded seeds of each case

    with open(os.path.join(path0, 'variance-mc-dropout.dat'), 'w') as f:

        f.write('Variables= %9s %9s %19s %19s \n' % ('i_func', 'dim_input', 'num_samples', 'E_noise'))

        for i_function in range(4):
            for dim_input, num_samples in list_setting:

                # Sine-noise functions were run at a reduced dimension
                # (same remap as in run_single_case)
                if i_function in [1, 3]:
                    if dim_input == 10:
                        dim_input = 2
                    elif dim_input == 100:
                        dim_input = 4

                E_noise = []

                fname = fname_summary(i_function, dim_input)
                if not os.path.exists(fname):
                    print(f"Warning: Summary file not found: {fname}")
                else:
                    with open(fname, 'r') as f2:
                        next(f2, None)  # skip the header line
                        for line in f2:
                            parts = line.strip().split(',')

                            try:
                                n_sample = int(parts[0].strip())
                            except ValueError:
                                # Blank line or non-numeric row (e.g. a repeated header line)
                                continue

                            if n_sample != num_samples:
                                continue

                            if len(parts) < 3:
                                print(f"Warning: Not enough parts in line: {line.strip()}")
                                continue

                            try:
                                E_noise.append(float(parts[2].strip()))
                            except ValueError:
                                print(f"Warning: Invalid float value in line: {line.strip()}")

                if E_noise:
                    E_mean = np.mean(E_noise)
                else:
                    # Avoid np.mean([]) RuntimeWarning; the row is still written as nan
                    print(f"Warning: No E_noise values for function {i_function}, "
                          f"dim_input {dim_input}, num_samples {num_samples}")
                    E_mean = np.nan

                text = '%20d %9d %19d %19.6e \n'%(i_function, dim_input, num_samples, E_mean)

                f.write(text)

    print('Results saved to variance-mc-dropout.dat')

