import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

# ------------------------------------------
# This script is run when ACTION is a predict action (e.g. predict, predict_conc)
# ------------------------------------------

ALLOWED_OUTPUT_FILE_EXT = ['.json', '.csv', '.pkl']

# Default 'org/name' model identifiers used to fill SEQ_MODEL_PATH,
# SEQ_MODEL_CONFIG_PATH and SEQ_MODEL_TOKENIZER_PATH when they are not given.
_SEQ_MODEL_DEFAULT_IDS = {
    'esm2_t33_650M_UR50D': 'facebook/esm2_t33_650M_UR50D',
    'esm2_t48_15B_UR50D': 'facebook/esm2_t48_15B_UR50D',
    'ProtBERT': 'Rostlab/prot_bert',
}

# CASE -> ACTION -> (module, function) of the prediction entry point.
# The "Available options" error messages are derived from this table.
_PREDICT_ENTRY_POINTS = {
    'amino_GNN': {
        'predict': ('ProtLig_GPCRclassA.amino_GNN.base.predict.main_predict', 'main_predict'),
        'predict_conc': ('ProtLig_GPCRclassA.amino_GNN.concentration.predict.main_predict_conc', 'main_predict_conc'),
        'predict_conc_range': ('ProtLig_GPCRclassA.amino_GNN.concentration.predict.main_predict_conc_range', 'main_predict_conc_range'),
        'predict_single_precompute': ('ProtLig_GPCRclassA.amino_GNN.base.predict.main_predict_single_precompute', 'main_predict_single_precompute'),
        'predict_single_apply': ('ProtLig_GPCRclassA.amino_GNN.base.predict.main_predict_single_apply', 'main_predict_single_apply'),
        'predict_batch_precompute': ('ProtLig_GPCRclassA.amino_GNN.base.predict.main_predict_batch_precompute', 'main_predict_batch_precompute'),
    },
}


def _is_unset(params, key):
    """Return True when *key* is missing from *params* or explicitly None."""
    return key not in params or params[key] is None


def _check_predict_params(params):
    """Reject configurations that are not valid for a predict run."""
    if 'predict' not in params['ACTION']:
        raise ValueError('Invoking predicting script but "predict" not in ACTION')
    if 'TRAIN_CSV_NAME' in params or 'VALID_CSV_NAME' in params:
        raise ValueError('Predicting is not supported while TRAIN_CSV_NAME or VALID_CSV_NAME are set in params.')
    if 'PREDICT_CSV_PATH' not in params:
        raise ValueError('No PREDICT_CSV_PATH provided for predicting.')


def _parse_output_file(params):
    """Split OUTPUT_FILE into _OUTPUT_DIR/_OUTPUT_FILENAME/_OUTPUT_FILE_EXT and validate the extension."""
    if _is_unset(params, 'OUTPUT_FILE'):
        raise ValueError('No output file path provided for the output.')
    output_dir, output_filename = os.path.split(params['OUTPUT_FILE'])
    output_file_ext = os.path.splitext(output_filename)[1]
    params['_OUTPUT_DIR'] = output_dir
    params['_OUTPUT_FILENAME'] = output_filename
    params['_OUTPUT_FILE_EXT'] = output_file_ext
    if output_file_ext not in ALLOWED_OUTPUT_FILE_EXT:
        raise ValueError('Unknown output_file extension {}. Available options: {}'.format(output_file_ext, ALLOWED_OUTPUT_FILE_EXT))


def _fill_seq_model_defaults(params):
    """Fill unset sequence-model path/config/tokenizer entries from SEQ_MODEL_NAME."""
    for key in ('SEQ_MODEL_PATH', 'SEQ_MODEL_CONFIG_PATH', 'SEQ_MODEL_TOKENIZER_PATH'):
        if _is_unset(params, key):
            default_id = _SEQ_MODEL_DEFAULT_IDS.get(params['SEQ_MODEL_NAME'])
            # Unknown model names are left untouched (no default is invented).
            if default_id is not None:
                params[key] = default_id


def _normalize_global_cols(params, key):
    """Coerce params[key] to a list: missing/None -> [], scalar -> [scalar]."""
    value = None if key not in params else params[key]
    if value is None:
        params[key] = []
    elif not isinstance(value, list):
        params[key] = [value]


def _apply_self_loops(params):
    """Validate the self-loop configuration and enlarge edge padding by node padding."""
    if not params['SELF_LOOPS']:
        return
    if params['MODEL_NAME'] not in ('Simple_GAT_model', 'Transformer_GAT_model'):
        raise ValueError('Not supported model for self loops: {}'.format(params['MODEL_NAME']))
    if len(params['BOND_FEATURES']) > 0:
        raise ValueError('Non-empty bond features while self loop is True.')
    # NOTE: Because of self_loops, edge padding is increased by the node padding
    # (presumably one extra edge per node). A new list is built instead of
    # mutating the caller's sequence, so tuples also work.
    n_node = params['PADDING_N_NODE']
    params['PADDING_N_EDGE'] = [n_edge + n_node[i] for i, n_edge in enumerate(params['PADDING_N_EDGE'])]


def _run_prediction(params):
    """Import and call the prediction entry point selected by CASE and ACTION."""
    import importlib

    case, action = params['CASE'], params['ACTION']
    if case not in _PREDICT_ENTRY_POINTS:
        raise ValueError('Unknown case {}. Available options: {}'.format(case, list(_PREDICT_ENTRY_POINTS)))
    actions = _PREDICT_ENTRY_POINTS[case]
    if action not in actions:
        raise ValueError('Unknown action {} for case {}. Available options: {}'.format(action, case, list(actions)))
    module_name, func_name = actions[action]
    # Import lazily so only the selected entry-point module is loaded.
    entry_point = getattr(importlib.import_module(module_name), func_name)
    return entry_point(params)


def _save_output(output, params):
    """Write *output* into _OUTPUT_DIR under a timestamped file name and return its path."""
    output_dir = params['_OUTPUT_DIR']
    output_file_ext = params['_OUTPUT_FILE_EXT']
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_path = os.path.join(output_dir, '{}_output_{}_{}_{}'.format(
        params['ACTION'], params['MODEL_NAME'], timestamp, params['_OUTPUT_FILENAME']))

    # Fail before touching the filesystem for formats that cannot be written.
    if output_file_ext == '.csv':
        raise NotImplementedError('Saving prediction output as .csv is not implemented; use .json or .pkl.')
    if output_file_ext not in ('.json', '.pkl'):
        raise ValueError('Unknown output_file extension {}. Available options: {}'.format(output_file_ext, ALLOWED_OUTPUT_FILE_EXT))

    # An OUTPUT_FILE without a directory part yields '' -> write to the CWD;
    # os.makedirs('') would raise FileNotFoundError.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if output_file_ext == '.json':
        with open(output_path, 'w') as jsonfile:
            json.dump(output, jsonfile)
    else:
        import pickle
        import numpy
        import jax
        # Convert every leaf (e.g. device arrays) to a plain numpy array before pickling.
        # jax.tree_util.tree_map replaces the deprecated/removed jax.tree_map alias.
        with open(output_path, 'wb') as pklfile:
            pickle.dump(jax.tree_util.tree_map(numpy.array, output), pklfile)
    return output_path


def main_predict_script(params):
    """Validate and normalize a predict configuration, run the prediction and save its output.

    Args:
        params: Mutable mapping of configuration values. It is updated in place:
            derived keys (_OUTPUT_DIR, _OUTPUT_FILENAME, _OUTPUT_FILE_EXT) are added,
            defaults are filled in, and several entries are normalized (features to
            tuples, *_GLOBAL_COLS to lists, PADDING_N_EDGE for self loops).

    Returns:
        The path of the written output file, or None if the prediction entry point
        returned None.

    Raises:
        ValueError: on an invalid configuration, unknown CASE/ACTION, or an unsupported
            output file extension.
        NotImplementedError: if the output is to be saved as .csv.
    """
    _check_predict_params(params)
    _parse_output_file(params)

    if 'HUGGINGFACE_CACHE_DIR' not in params:
        params['HUGGINGFACE_CACHE_DIR'] = None  # Set to HuggingFace default, ~/.cache/huggingface/

    _fill_seq_model_defaults(params)

    if 'CACHE_SEQ_LOOKUP' in params and params['CACHE_SEQ_LOOKUP'] and not params['CACHE']:
        print('-----> WARNING: CACHE_SEQ_LOOKUP is True but CACHE is False.')

    # Cast ATOM_FEATURES and BOND_FEATURES to tuple.
    params['ATOM_FEATURES'] = tuple(params['ATOM_FEATURES'])
    params['BOND_FEATURES'] = tuple(params['BOND_FEATURES'])

    _normalize_global_cols(params, 'MOL_GLOBAL_COLS')
    _normalize_global_cols(params, 'SEQ_GLOBAL_COLS')

    # Consistency checks:
    _apply_self_loops(params)

    output = _run_prediction(params)
    print('Finished...')

    if output is None:
        return None
    return _save_output(output, params)