import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

# -----------------------------------------------
# This script will be run if ACTION is precompute
# -----------------------------------------------

def main_precompute_script(params):

    if 'precompute' not in params['ACTION']:
        raise ValueError('Invoking precompute script but "precompute" not in ACTION')
    elif 'TRAIN_CSV_NAME' in params.keys() or 'VALID_CSV_NAME' in params.keys() or 'PREDICT_CSV_NAME' in params.keys():
        raise ValueError('precompute is not supported while TRAIN_CSV_NAME or VALID_CSV_NAME or PREDICT_CSV_NAME are set in params.')
    elif 'DATA_FILE' not in params.keys():
        raise ValueError('no DATA_FILE privided for precompute.')

    if 'HUGGINGFACE_CACHE_DIR' not in params.keys():
        params['HUGGINGFACE_CACHE_DIR'] = None # Set to HuggingFace default, ~/.cache/huggingface/

    if 'SEQ_MODEL_PATH' not in params.keys() or params['SEQ_MODEL_PATH'] is None:
        if params['SEQ_MODEL_NAME'] in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
            params['SEQ_MODEL_PATH'] = "facebook/" + params['SEQ_MODEL_NAME']
        elif params['SEQ_MODEL_NAME'] == 'ProtBERT':
            params['SEQ_MODEL_PATH'] = "Rostlab/prot_bert"

    if 'SEQ_MODEL_CONFIG_PATH' not in params.keys() or params['SEQ_MODEL_CONFIG_PATH'] is None:
        if params['SEQ_MODEL_NAME'] in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
            params['SEQ_MODEL_CONFIG_PATH'] = "facebook/" + params['SEQ_MODEL_NAME']
        elif params['SEQ_MODEL_NAME'] == 'ProtBERT':
            params['SEQ_MODEL_CONFIG_PATH'] = "Rostlab/prot_bert"

    if 'SEQ_MODEL_TOKENIZER_PATH' not in params.keys() or params['SEQ_MODEL_TOKENIZER_PATH'] is None:
        if params['SEQ_MODEL_NAME'] in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
            params['SEQ_MODEL_TOKENIZER_PATH'] = "facebook/" + params['SEQ_MODEL_NAME']
        elif params['SEQ_MODEL_NAME'] == 'ProtBERT':
            params['SEQ_MODEL_TOKENIZER_PATH'] = "Rostlab/prot_bert"

    if 'CACHE_SEQ_LOOKUP' in params.keys():
        if params['CACHE_SEQ_LOOKUP'] and not params['CACHE']:
            print('-----> WARNING: CACHE_SEQ_LOOKUP is True but CACHE is False.')


    if params['CASE'] == 'amino_GNN':
        if params['ACTION'] == 'precompute':
            from ProtLig_GPCRclassA.amino_GNN.precompute.main_precompute import main_precompute
            output = main_precompute(params)
        else:
            raise ValueError('Unknown action {} for case {}. Available options: {}'.format(params['ACTION'], params['CASE'], ['precompute']))
    print('Finished...')

    # TODO: Handle outputs from precompute scripts. Currently assumed to be None (26.09.2023).