import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

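# ---------------------------------------------------------------------
# Entry point used when ACTION is an evaluation action
# ('eval', 'eval_ckpts' or 'eval_masked_ckpts').
# ---------------------------------------------------------------------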

# Sequence-model names whose HuggingFace hub id is "facebook/<name>".
_ESM2_MODEL_NAMES = ('esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D')

# Keys that fall back to the default hub id of SEQ_MODEL_NAME when unset/None.
_SEQ_MODEL_PATH_KEYS = ('SEQ_MODEL_PATH', 'SEQ_MODEL_CONFIG_PATH', 'SEQ_MODEL_TOKENIZER_PATH')


def _get(params, key):
    """Return params[key], or None when the key is absent (uses only `in` and `[]`)."""
    return params[key] if key in params else None


def _default_seq_model_path(seq_model_name):
    """Return the default HuggingFace hub id for a known sequence model, else None."""
    if seq_model_name in _ESM2_MODEL_NAMES:
        return "facebook/" + seq_model_name
    if seq_model_name == 'ProtBERT':
        return "Rostlab/prot_bert"
    return None


def _as_list(value):
    """Normalise a config value into a list: None -> [], list -> itself, scalar -> [scalar]."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def main_eval_script(params):
    """Validate and normalise an evaluation config, then dispatch to the evaluation routine.

    ``params`` is mutated in place: defaults are filled in and list-like options are
    normalised before being passed to the CASE/ACTION specific evaluation function.

    Args:
        params: Config mapping. Must contain 'ACTION', 'VALID_CSV_NAME', 'ATOM_FEATURES',
            'BOND_FEATURES', 'SELF_LOOPS' and 'CASE'; 'SEQ_MODEL_NAME' is needed whenever
            one of the SEQ_MODEL_*_PATH keys is missing or None; 'MODEL_NAME',
            'PADDING_N_EDGE' and 'PADDING_N_NODE' are needed when SELF_LOOPS is truthy.

    Returns:
        Whatever the dispatched evaluation function returns (currently assumed to be None).

    Raises:
        ValueError: if ACTION does not contain 'eval', PREDICT_CSV_NAME is set,
            VALID_CSV_NAME is missing, the self-loop settings are inconsistent, or the
            CASE/ACTION combination is not supported.
        KeyError: if a required key is missing.
    """
    action = params['ACTION']
    if 'eval' not in action:
        raise ValueError('Invoking validation script but "eval" not in ACTION')
    if 'PREDICT_CSV_NAME' in params:
        raise ValueError('validation is not supported while PREDICT_CSV_NAME is set in params.')
    if 'VALID_CSV_NAME' not in params:
        raise ValueError('no VALID_CSV_NAME provided for validation.')

    if 'HUGGINGFACE_CACHE_DIR' not in params:
        params['HUGGINGFACE_CACHE_DIR'] = None  # Set to HuggingFace default, ~/.cache/huggingface/

    # Fill in default hub ids for known models; unknown model names leave the key untouched.
    for key in _SEQ_MODEL_PATH_KEYS:
        if _get(params, key) is None:
            default_path = _default_seq_model_path(params['SEQ_MODEL_NAME'])
            if default_path is not None:
                params[key] = default_path

    # A missing CACHE key is treated as False instead of raising KeyError.
    if _get(params, 'CACHE_SEQ_LOOKUP') and not _get(params, 'CACHE'):
        print('-----> WARNING: CACHE_SEQ_LOOKUP is True but CACHE is False.')

    # Cast ATOM_FEATURES and BOND_FEATURES to tuple.
    params['ATOM_FEATURES'] = tuple(params['ATOM_FEATURES'])
    params['BOND_FEATURES'] = tuple(params['BOND_FEATURES'])

    for key in ('MOL_GLOBAL_COLS', 'SEQ_GLOBAL_COLS'):
        params[key] = _as_list(_get(params, key))

    # Consistency checks:
    if params['SELF_LOOPS']:
        if params['MODEL_NAME'] not in ('Simple_GAT_model', 'Transformer_GAT_model'):
            raise ValueError('Not supported model for self loops: {}'.format(params['MODEL_NAME']))
        if len(params['BOND_FEATURES']) > 0:
            raise ValueError('Non-empty bond features while self loop is True.')
        # NOTE: Because of self_loops the edge padding is enlarged by the node padding
        # (presumably one self-loop edge per node). A new list is built instead of
        # mutating in place so tuples and shared list objects are handled safely.
        n_node_padding = params['PADDING_N_NODE']
        params['PADDING_N_EDGE'] = [
            n_edge + n_node_padding[i] for i, n_edge in enumerate(params['PADDING_N_EDGE'])
        ]

    for key in ('AUXILIARY_LABEL_COLS', 'AUXILIARY_WEIGHT_COLS'):
        params[key] = _as_list(_get(params, key))

    if 'AUXILIARY_LOSS_OPTION' not in params:
        params['AUXILIARY_LOSS_OPTION'] = None

    eval_actions = ['eval', 'eval_ckpts', 'eval_masked_ckpts']
    if params['CASE'] != 'amino_GNN':
        # Previously an unsupported CASE was silently ignored and "Finished..." printed.
        raise ValueError('Unknown case {} for validation. Available options: {}'.format(params['CASE'], ['amino_GNN']))

    # Imports are local so only the selected evaluation routine is loaded.
    if action == 'eval':
        from ProtLig_GPCRclassA.amino_GNN.base.eval.main_eval import main_eval
        output = main_eval(params)
    elif action == 'eval_ckpts':
        from ProtLig_GPCRclassA.amino_GNN.base.eval.main_eval_ckpts import main_eval_ckpts
        output = main_eval_ckpts(params)
    elif action == 'eval_masked_ckpts':
        from ProtLig_GPCRclassA.amino_GNN.base.eval.main_eval_masked_ckpts import main_eval_masked_ckpts
        output = main_eval_masked_ckpts(params)
    else:
        raise ValueError('Unknown action {} for case {}. Available options for validation: {}'.format(action, params['CASE'], eval_actions))
    print('Finished...')

    # TODO: Handle outputs from validation scripts. Currently assumed to be None (26.09.2023).
    return output