import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

# ------------------------------------------
# This script will be run if ACTION is train
# ------------------------------------------

def _default_seq_model_location(seq_model_name):
    """Return the default Hugging Face identifier for a known SEQ_MODEL_NAME, or None if unknown."""
    if seq_model_name in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
        return "facebook/" + seq_model_name
    if seq_model_name == 'ProtBERT':
        return "Rostlab/prot_bert"
    return None


def _ensure_list_param(params, key):
    """Normalise params[key] to a list: missing or None becomes [], a non-list value becomes [value]."""
    if key not in params or params[key] is None:
        params[key] = []
    elif not isinstance(params[key], list):
        params[key] = [params[key]]


def _add_self_loop_padding(n_edge, n_node):
    """Return the edge padding increased by the node padding (needed when self loops are added).

    Both arguments may be a single value or a list with one entry per cut.
    A list ``n_edge`` is updated in place and returned; otherwise a new value is returned.
    """
    if isinstance(n_edge, list):
        for i in range(len(n_edge)):
            n_edge[i] = n_edge[i] + (n_node[i] if isinstance(n_node, list) else n_node)
        return n_edge
    if isinstance(n_node, list):
        return [n_edge + n for n in n_node]
    return n_edge + n_node


def main_train_script(params):
    """Validate and complete the training parameters, then dispatch to the training routine.

    The ``params`` mapping is modified in place: optional keys receive defaults,
    ATOM_FEATURES / BOND_FEATURES are cast to tuples, column parameters are
    normalised to lists and, when any multi-cut parameter is a list, the remaining
    multi-cut parameters are replicated to one entry per cut.

    Parameters
    ----------
    params : dict-like
        Configuration of the run. Must contain at least ACTION, TRAIN_CSV_NAME,
        ATOM_FEATURES, BOND_FEATURES, SELF_LOOPS, MOLS_CSV, SEQS_CSV,
        PADDING_N_NODE, PADDING_N_EDGE, BATCH_SIZE, N_EPOCH and CASE.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the parameters are inconsistent or ACTION is unknown for CASE 'amino_GNN'.
    NotImplementedError
        If SLURM_JOB_ARRAY is used together with a non-empty RESTORE_FILE.
    Exception
        With message 'Slurm array ends...' when the latest restore state of a
        Slurm job array has already reached the last N_EPOCH value.
    """
    if 'train' not in params['ACTION']:
        raise ValueError('Invoking training script but "train" not in ACTION')
    elif 'VALID_CSV_NAME' in params.keys() or 'PREDICT_CSV_NAME' in params.keys():
        raise ValueError('Training is not supported while VALID_CSV_NAME or PREDICT_CSV_NAME are set in params.')
    elif 'TRAIN_CSV_NAME' not in params.keys():
        raise ValueError('no TRAIN_CSV_NAME provided for training.')

    if 'HUGGINGFACE_CACHE_DIR' not in params.keys():
        params['HUGGINGFACE_CACHE_DIR'] = None # Set to HuggingFace default, ~/.cache/huggingface/

    # Fill missing model / config / tokenizer locations from SEQ_MODEL_NAME.
    # An unknown SEQ_MODEL_NAME leaves the corresponding key untouched.
    for path_key in ('SEQ_MODEL_PATH', 'SEQ_MODEL_CONFIG_PATH', 'SEQ_MODEL_TOKENIZER_PATH'):
        if path_key not in params.keys() or params[path_key] is None:
            default_location = _default_seq_model_location(params['SEQ_MODEL_NAME'])
            if default_location is not None:
                params[path_key] = default_location

    if 'CACHE_SEQ_LOOKUP' in params.keys():
        # This is only a warning, so a missing CACHE key is treated as False instead of raising KeyError.
        cache_enabled = 'CACHE' in params.keys() and params['CACHE']
        if params['CACHE_SEQ_LOOKUP'] and not cache_enabled:
            print('-----> WARNING: CACHE_SEQ_LOOKUP is True but CACHE is False.')

    # Cast ATOM_FEATURES and BOND_FEATURES to tuple.
    params['ATOM_FEATURES'] = tuple(params['ATOM_FEATURES'])
    params['BOND_FEATURES'] = tuple(params['BOND_FEATURES'])

    _ensure_list_param(params, 'MOL_GLOBAL_COLS')
    _ensure_list_param(params, 'SEQ_GLOBAL_COLS')

    # Consistency checks:
    if params['SELF_LOOPS']:
        if not params['MODEL_NAME'] in ['Simple_GAT_model', 'Transformer_GAT_model']:
            raise ValueError('Not supported model for self loops: {}'.format(params['MODEL_NAME']))
        if len(params['BOND_FEATURES']) > 0:
            raise ValueError('Non-empty bond features while self loop is True.')
        # NOTE: Because of self_loops. Works for both single values and per-cut lists.
        params['PADDING_N_EDGE'] = _add_self_loop_padding(params['PADDING_N_EDGE'], params['PADDING_N_NODE'])

    # Prepare data with multiple cuts and perform series of consistency checks:
    params_test_list = {key : isinstance(params[key], list) for key in ['MOLS_CSV', 'SEQS_CSV', 'PADDING_N_NODE', 'PADDING_N_EDGE', 'BATCH_SIZE', 'N_EPOCH']}
    params_test_list['TRAIN_CSV_NAME'] = isinstance(params['TRAIN_CSV_NAME'], list)
    check_groups = {
        'MOLS_CSV' : ['MOLS_CSV', 'PADDING_N_NODE', 'PADDING_N_EDGE', 'BATCH_SIZE', 'N_EPOCH'],
        'SEQS_CSV' : ['SEQS_CSV', 'BATCH_SIZE', 'N_EPOCH'],
        'TRAIN_CSV_NAME' : ['TRAIN_CSV_NAME', 'BATCH_SIZE', 'N_EPOCH'],
        }
    if any(params_test_list.values()):
        # Check consistency:
        for key, group in check_groups.items():
            if params_test_list[key] and not all(params_test_list[k] for k in group):
                raise ValueError('{} is list while some of others ({}) are not.'.format(key, ', '.join(group)))
        # N_EPOCH defines the number of cuts, so it has to be a list as soon as any other parameter is.
        if not params_test_list['N_EPOCH']:
            raise ValueError('N_EPOCH must be a list when any of ({}) is a list.'.format(', '.join(params_test_list.keys())))
        if sorted(params['N_EPOCH']) != params['N_EPOCH']:
            raise ValueError('N_EPOCH is not sorted.')

        # Check list lengths:
        n_cuts = len(params['N_EPOCH'])

        for key in params_test_list.keys():
            if params_test_list[key] and len(params[key]) != n_cuts:
                raise ValueError('Parameter {} has different length than N_EPOCH'.format(key))

        # Construct lists from non-list params:
        for key in params_test_list.keys():
            if not params_test_list[key]:
                params[key] = [params[key]]*n_cuts

        # Every cut must differ from all others in at least one of its input files.
        cuts = list(zip(params['MOLS_CSV'], params['SEQS_CSV'], params['TRAIN_CSV_NAME']))
        if len(cuts) > len(set(cuts)):
            raise ValueError('There are at least two cuts that are identical.')

    _ensure_list_param(params, 'AUXILIARY_LABEL_COLS')
    _ensure_list_param(params, 'AUXILIARY_WEIGHT_COLS')

    if 'AUXILIARY_LOSS_OPTION' not in params:
        params['AUXILIARY_LOSS_OPTION'] = None

    if 'SAMPLE_BY_LABEL_DISTRIBUTION' not in params:
        params['SAMPLE_BY_LABEL_DISTRIBUTION'] = False

    # Slurm job array:
    if 'SLURM_JOB_ARRAY' not in params.keys():
        params['SLURM_JOB_ARRAY'] = False

    if params['SLURM_JOB_ARRAY'] and 'RESTORE_FILE' in params.keys() and params['RESTORE_FILE'] is not None:
        raise NotImplementedError('SLURM_JOB_ARRAY with non-empty RESTORE_FILE is not supported.')

    # Get restore file if using Slurm job array:
    if params['SLURM_JOB_ARRAY']:
        params['SLURM_ARRAY_JOB_ID'] = os.environ['SLURM_ARRAY_JOB_ID']
        if params['CASE'] == 'amino_GNN':
            from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name
            from ProtLig_GPCRclassA.utils import get_last_restore_file_and_state

            model_class = get_model_by_name(params['MODEL_NAME'])
            restore_dir = os.path.join(params['LOGGING_PARENT_DIR'], params['DATACASE'], model_class.__name__, params['SLURM_ARRAY_JOB_ID'])
            print('restoring file for SLURM JOB ARRAY: {}'.format(restore_dir))

            if os.path.exists(restore_dir):
                _restore_file, _restore_state = get_last_restore_file_and_state(restore_dir)
                if _restore_file is not None:
                    params['RESTORE_FILE'] = _restore_file
                    print('latest restore file: {}'.format(_restore_file))

                    # N_EPOCH is a list only in the multi-cut setup; otherwise it is a single value.
                    n_epoch = params['N_EPOCH']
                    final_epoch = n_epoch[-1] if isinstance(n_epoch, list) else n_epoch
                    restored_epoch = int(_restore_state.replace('state_e', '').replace('.pkl', ''))
                    if restored_epoch >= final_epoch:
                        # Signal through the environment that the job array has nothing left to do.
                        os.environ["PROTLIG_SLURM_ARRAY_EXIT_SIGNAL"] = "1"
                        os.environ["PROTLIG_SLURM_ARRAY_EXIT_SLURM_ARRAY_JOB_ID"] = os.environ['SLURM_ARRAY_JOB_ID']
                        raise Exception('Slurm array ends...')
                    else:
                        os.environ["PROTLIG_SLURM_ARRAY_EXIT_SIGNAL"] = "0"

    # NOTE(review): any CASE other than 'amino_GNN' falls through without training - confirm this is intended.
    if params['CASE'] == 'amino_GNN':
        # Training (imports are kept local so that only the selected training routine is loaded):
        if params['ACTION'] == 'train':
            from ProtLig_GPCRclassA.amino_GNN.base.train.main_train import main_train
            output = main_train(params)
        elif params['ACTION'] == 'train_pmap':
            from ProtLig_GPCRclassA.amino_GNN.base.train.main_train_pmap import main_train_pmap
            output = main_train_pmap(params)
        elif params['ACTION'] == 'train_masked':
            from ProtLig_GPCRclassA.amino_GNN.base.train.main_train_masked import main_train_masked
            output = main_train_masked(params)
        elif params['ACTION'] == 'train_masked_pmap':
            from ProtLig_GPCRclassA.amino_GNN.base.train.main_train_masked_pmap import main_train_masked_pmap
            output = main_train_masked_pmap(params)
        else:
            raise ValueError('Unknown action {} for case {}. Available options: {}'.format(params['ACTION'], params['CASE'], ['train', 'train_pmap', 'train_masked', 'train_masked_pmap']))
    print('Finished...')

    # TODO: Handle outputs from training scripts. Currently assumed to be None (26.09.2023).