import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

# ------------------------------------------
# This script will be run if ACTION contains eval_conc
# ------------------------------------------

# Pretrained ESM-2 sequence models whose Hugging Face id is "facebook/<name>".
_ESM2_MODEL_NAMES = ('esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D')

# ACTION values dispatched by main_eval_conc_script for CASE == 'amino_GNN'.
_AMINO_GNN_EVAL_CONC_ACTIONS = ['eval_conc_masked_ckpts', 'eval_conc_masked_params_ckpts']


def _default_seq_model_path(seq_model_name):
    """Return the default model id for a known SEQ_MODEL_NAME, or None if unknown."""
    if seq_model_name in _ESM2_MODEL_NAMES:
        return "facebook/" + seq_model_name
    if seq_model_name == 'ProtBERT':
        return "Rostlab/prot_bert"
    return None


def _fill_seq_model_path(params, key):
    """If params[key] is missing or None, set it from SEQ_MODEL_NAME (left untouched for unknown names)."""
    if key not in params or params[key] is None:
        default_path = _default_seq_model_path(params['SEQ_MODEL_NAME'])
        if default_path is not None:
            params[key] = default_path


def _normalize_list_param(params, key):
    """Normalize params[key] in place: missing/None -> [], non-list value -> [value]."""
    if key not in params or params[key] is None:
        params[key] = []
    elif not isinstance(params[key], list):
        params[key] = [params[key]]


def main_eval_conc_script(params):
    """Validate and normalize ``params``, then run the selected concentration-evaluation action.

    ``params`` is modified in place (defaults filled in, list-valued options
    normalized, feature lists cast to tuples) before dispatching on
    ``params['CASE']`` and ``params['ACTION']``.

    Args:
        params: Mapping of configuration values. Must contain at least
            ACTION, VALID_CSV_NAME, ATOM_FEATURES, BOND_FEATURES, SELF_LOOPS
            and CASE. It must also contain SEQ_MODEL_NAME unless all
            SEQ_MODEL_*_PATH keys are set, and MODEL_NAME / PADDING_N_EDGE /
            PADDING_N_NODE when SELF_LOOPS is true.

    Raises:
        ValueError: if ACTION does not contain 'eval_conc', PREDICT_CSV_NAME
            is set, VALID_CSV_NAME is missing, the self-loop consistency
            checks fail, or CASE / ACTION is not supported.
        KeyError: if a required key listed above is missing.
    """
    if 'eval_conc' not in params['ACTION']:
        raise ValueError('Invoking validation concentration script but "eval_conc" not in ACTION')
    if 'PREDICT_CSV_NAME' in params:
        raise ValueError('Validation concentration is not supported while PREDICT_CSV_NAME are set in params.')
    if 'VALID_CSV_NAME' not in params:
        raise ValueError('no VALID_CSV_NAME provided for validation concentration.')

    if 'HUGGINGFACE_CACHE_DIR' not in params:
        params['HUGGINGFACE_CACHE_DIR'] = None  # Set to HuggingFace default, ~/.cache/huggingface/

    for path_key in ('SEQ_MODEL_PATH', 'SEQ_MODEL_CONFIG_PATH', 'SEQ_MODEL_TOKENIZER_PATH'):
        _fill_seq_model_path(params, path_key)

    # Only a warning: a missing CACHE entry is treated as "caching disabled"
    # instead of raising KeyError.
    if 'CACHE_SEQ_LOOKUP' in params and params['CACHE_SEQ_LOOKUP']:
        if 'CACHE' not in params or not params['CACHE']:
            print('-----> WARNING: CACHE_SEQ_LOOKUP is True but CACHE is False.')

    # Cast ATOM_FEATURES and BOND_FEATURES to tuple.
    params['ATOM_FEATURES'] = tuple(params['ATOM_FEATURES'])
    params['BOND_FEATURES'] = tuple(params['BOND_FEATURES'])

    for list_key in ('MOL_GLOBAL_COLS', 'SEQ_GLOBAL_COLS'):
        _normalize_list_param(params, list_key)

    # Consistency checks:
    if params['SELF_LOOPS']:
        if params['MODEL_NAME'] not in ['Simple_GAT_model', 'Transformer_GAT_model']:
            raise ValueError('Not supported model for self loops: {}'.format(params['MODEL_NAME']))
        if len(params['BOND_FEATURES']) > 0:
            raise ValueError('Non-empty bond features while self loop is True.')
        # NOTE: Because of self_loops, each node adds one edge, so the edge
        # padding is grown by the node padding (in place, element-wise).
        padding_n_edge = params['PADDING_N_EDGE']
        padding_n_node = params['PADDING_N_NODE']
        for i, n_edge in enumerate(padding_n_edge):
            padding_n_edge[i] = n_edge + padding_n_node[i]

    for list_key in ('AUXILIARY_LABEL_COLS', 'AUXILIARY_WEIGHT_COLS'):
        _normalize_list_param(params, list_key)

    if 'AUXILIARY_LOSS_OPTION' not in params:
        params['AUXILIARY_LOSS_OPTION'] = None

    # Concentration sampling and training parameters:
    if 'MONOTONICITY_SLOPE_POS' not in params:
        params['MONOTONICITY_SLOPE_POS'] = 0.0
    if 'MONOTONICITY_SLOPE_NEG' not in params:
        params['MONOTONICITY_SLOPE_NEG'] = 0.0

    if params['CASE'] != 'amino_GNN':
        raise ValueError('Unknown case: {}'.format(params['CASE']))

    # Evaluation (imports are deferred so only the selected pipeline is loaded):
    if params['ACTION'] == 'eval_conc_masked_ckpts':
        from ProtLig_GPCRclassA.amino_GNN.concentration.eval.main_eval_conc_masked_ckpts import main_eval_conc_masked_ckpts
        main_eval_conc_masked_ckpts(params)
    elif params['ACTION'] == 'eval_conc_masked_params_ckpts':
        from ProtLig_GPCRclassA.amino_GNN.concentration.eval.main_eval_conc_masked_params_ckpts import main_eval_conc_masked_params_ckpts
        main_eval_conc_masked_params_ckpts(params)
    else:
        raise ValueError('Unknown action {} for case {}. Available options for validation: {}'.format(
            params['ACTION'], params['CASE'], _AMINO_GNN_EVAL_CONC_ACTIONS))
    print('Finished...')

    # TODO: Handle outputs from validation scripts. Currently assumed to be None (26.09.2023).