'''
Shared code for the comparison of different methods on UCI datasets.
'''
import os
import random
import sys
sys.path.append('.')

import numpy as np
import torch
from hvbll.uci_database import UCIDataset, get_UCI_datasets
from torch.utils.data import DataLoader


# Benchmark configurations, keyed by UCI repository dataset ID.
#   name                  : human-readable dataset name
#   n_samples             : candidate total sample counts (train + test) passed to
#                           get_UCI_datasets; the trailing comment is the dataset size
#   index_delete_features : input columns removed for each feature-deletion case;
#                           the trailing comment is the resulting input dimension
#   batch_size            : batch size of the training DataLoader
TEST_CASES = {
    165: {
        'name': 'Concrete Compressive Strength',
        'n_samples': [200, 500, 1000],      # total: 1030
        'index_delete_features': [
            [],                             # dim_input: 8
            [0, 3],                         # dim_input: 6
            [0, 3, 6],                      # dim_input: 5
            [0, 4, 6],                      # dim_input: 5
            [0, 5, 7],                      # dim_input: 5
        ],
        'batch_size': 500,
    },
    186: {
        'name': 'Wine Quality',
        'n_samples': [500, 1000, 4000],     # total: 4898
        'index_delete_features': [
            [],                             # dim_input: 11
            [1, 6, 10],                     # dim_input: 8
            [5, 7, 10],                     # dim_input: 8
            [1, 5, 6, 7, 10],               # dim_input: 6
        ],
        'batch_size': 1000,
    },
    291: {
        'name': 'Airfoil Self-Noise',
        'n_samples': [200, 500, 1000],      # total: 1503
        'index_delete_features': [
            [],                             # dim_input: 5
            [0],                            # dim_input: 4
            [1, 4],                         # dim_input: 3
            [2, 4],                         # dim_input: 3
            [2, 3, 4],                      # dim_input: 2
        ],
        'batch_size': 500,
    },
    294: {
        'name': 'Combined Cycle Power Plant',
        'n_samples': [500, 1000, 4000],     # total: 9568
        'index_delete_features': [
            [],                             # dim_input: 4
            [0],                            # dim_input: 3
            [0, 1],                         # dim_input: 2
            [0, 2],                         # dim_input: 2
            [0, 1, 2],                      # dim_input: 1
        ],
        'batch_size': 1000,
    },
    464: {
        'name': 'Superconductivity Data',
        'n_samples': [500, 1000, 4000],     # total: 21263
        'index_delete_features': [
            [],                             # dim_input: 81
            list(range(5, 81)),             # dim_input: 5 (only columns 0-4 kept)
        ],
        'batch_size': 1000,
    },
}


# Autograd anomaly detection is meant for debugging (e.g. tracking down NaN
# gradients); per the PyTorch docs it slows execution, so it is off by default.
CHECK_NAN_GRAD = False
torch.autograd.set_detect_anomaly(mode=CHECK_NAN_GRAD)


# Summary files go into a 'summary' folder next to the entry-point script
# (sys.argv[0]); if argv[0] has no directory part this is relative to the CWD.
path0 = os.path.dirname(sys.argv[0])
path_summary = os.path.join(path0, 'summary')

def prepare_case(
        id_UCI: int,
        i_case_partial_x: int,
        i_case_sample: int,
        seed: int = 0,
        GPU_ID: int = 0,
        old_dataset=None,
) -> dict:
    '''
    Prepare one test case of the UCI benchmark.

    Steps:
    - create the summary folder (``path_summary``) if it does not exist;
    - load UCI dataset ``id_UCI`` through ``get_UCI_datasets``, drawing
      ``TEST_CASES[id_UCI]['n_samples'][i_case_sample]`` samples, scaling the
      inputs (``scale_x=True``) and removing the input columns listed in
      ``TEST_CASES[id_UCI]['index_delete_features'][i_case_partial_x]``;
    - split the samples into train and test sets (``ratio_train_samples=0.8``);
    - wrap the training set in a shuffling ``DataLoader`` using the case's
      ``batch_size``.

    Args:
        id_UCI: key into ``TEST_CASES``.
        i_case_partial_x: index into the case's ``index_delete_features`` list.
        i_case_sample: index into the case's ``n_samples`` list.
        seed: forwarded to ``get_UCI_datasets``.
        GPU_ID: forwarded to ``get_UCI_datasets`` as ``gpu_id``.
        old_dataset: forwarded to ``get_UCI_datasets`` as ``dataset``;
            presumably a previously loaded dataset to reuse instead of
            reloading it (TODO confirm in hvbll.uci_database).

    Returns:
        dict with the train/test sets, the training DataLoader, the ``X``/``Y``
        attributes of both splits, the training set's ``dim_input``/``dim_output``
        and the dataset object returned by ``get_UCI_datasets``.
    '''
    # No-op when the folder already exists.
    os.makedirs(path_summary, exist_ok=True)

    case = TEST_CASES[id_UCI]
    train_set, test_set, dataset = get_UCI_datasets(
        id_UCI=id_UCI,
        num_total_samples=case['n_samples'][i_case_sample],
        ratio_train_samples=0.8,
        seed=seed,
        gpu_id=GPU_ID,
        scale_x=True,
        print_info=False,
        index_delete_features=case['index_delete_features'][i_case_partial_x],
        dataset=old_dataset,
    )

    loader = DataLoader(
        train_set,
        batch_size=case['batch_size'],
        shuffle=True,
        drop_last=False,
    )

    return {
        'train_set': train_set,
        'test_set': test_set,
        'dataloader': loader,
        'X_train_tensor': train_set.X,
        'y_train_tensor': train_set.Y,
        'X_test_tensor': test_set.X,
        'y_test_tensor': test_set.Y,
        'dim_input': train_set.dim_input,
        'dim_output': train_set.dim_output,
        'dataset': dataset,
    }


def set_seed(seed: int):
    '''
    Set the random seed for reproducibility.

    Seeds every global generator the experiments may draw from: Python's
    ``random`` module, NumPy's legacy global RNG (``np.random``) and PyTorch
    (CPU, plus all CUDA devices when CUDA is available). It also puts cuDNN
    in deterministic mode with autotuning (``benchmark``) turned off.

    cuDNN is additionally disabled altogether (``cudnn.enabled = False``),
    presumably to rule out any remaining non-deterministic cuDNN kernels at
    the cost of GPU speed.

    Args:
        seed: value used to seed all generators.
    '''
    # Python's own generator was previously left unseeded, so any code using
    # `random` was not reproducible even after calling this function.
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False


def fname_summary(model_name: str, id_UCI: int, i_case_partial_x: int) -> str:
    '''
    Return the path of the summary CSV for one model / dataset / deletion case.

    The file lives in ``path_summary`` and is named
    ``summary-<model_name>-<id_UCI>-<i_case_partial_x>.csv``; the two
    numeric fields are rendered with ``%d``.
    '''
    basename = 'summary-%(model)s-%(uci)d-%(case)d.csv' % {
        'model': model_name,
        'uci': id_UCI,
        'case': i_case_partial_x,
    }
    return os.path.join(path_summary, basename)


def case_already_run(model_name: str, id_UCI: int, i_case_partial_x: int, 
                        n_train_sample: int, i_seed: int) -> bool:
    """
    Check if a specific case has already been run.

    The case is looked up in the summary file returned by ``fname_summary``.
    The first line of that file is treated as a header and skipped. In every
    other row, the first two comma-separated columns are read as the number
    of training samples and the seed.

    Args:
        model_name: The name of the model
        id_UCI: The UCI dataset ID
        i_case_partial_x: The index of the feature deletion case
        n_train_sample: The number of training samples
        i_seed: The random seed

    Returns:
        bool: True if some row matches both ``n_train_sample`` and ``i_seed``.
        False otherwise, including when the file is missing or empty. Rows with
        fewer than two columns, or whose first two columns are not integers
        (e.g. a line truncated by an interrupted write), are ignored rather
        than aborting the check.
    """
    fname = fname_summary(model_name, id_UCI, i_case_partial_x)

    # A missing or empty file means the case hasn't been run.
    if not os.path.exists(fname) or os.path.getsize(fname) == 0:
        return False

    with open(fname, 'r') as f:
        next(f, None)  # Skip header line
        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 2:
                continue
            try:
                file_n_sample = int(parts[0].strip())
                file_seed = int(parts[1].strip())
            except ValueError:
                # Malformed row: previously this raised and broke the resume logic.
                continue
            if file_n_sample == n_train_sample and file_seed == i_seed:
                return True

    return False


def assign_gpu(idx, num_gpus):
    """
    Map a job index to a GPU ID, cycling through the GPUs round-robin.

    Args:
        idx: index of the job.
        num_gpus: number of available GPUs; 0 means no GPU is available.

    Returns:
        ``idx % num_gpus``, or 0 when ``num_gpus`` equals 0.
    """
    return 0 if num_gpus == 0 else idx % num_gpus
