from ProtLig_GPCRclassA.amino_GNN.precompute.ESM2 import PrecomputeESM2
from ProtLig_GPCRclassA.amino_GNN.precompute.prot_bert import PrecomputeProtBERT

# Model names that share the ESM2 loading path; both checkpoints are loaded
# and consumed in exactly the same way, only the weights differ.
_ESM2_MODEL_NAMES = ('esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D')


def _common_precompute_kwargs(hparams):
    """Collect the constructor keyword arguments shared by every precompute class."""
    return dict(
        data_file=hparams['DATA_FILE'],
        save_dir=hparams['SAVE_DIR'],
        mode=hparams['MODE'],
        dbname=hparams['DBNAME'],
        id_col=hparams['ID_COL'],
        seq_col=hparams['SEQ_COL'],
        batch_size=hparams['BATCH_SIZE'],
        hidden_states_shape=hparams['HIDDEN_STATES_SHAPE'],
        max_length=hparams['MAX_LENGTH'],
    )


def _run_and_close(precompute):
    """Run ``precompute_and_save()`` and close ``precompute.h5file`` even on failure."""
    try:
        precompute.precompute_and_save()
    finally:
        # Close in ``finally`` so an exception during precomputation does not
        # leave the file handle open. ``getattr`` guards against the attribute
        # not existing yet, which would otherwise mask the original error.
        h5file = getattr(precompute, 'h5file', None)
        if h5file is not None:
            h5file.close()


def main_precompute(hparams):
    """Load the sequence model named in ``hparams`` and run its precompute job.

    Dispatches on ``hparams['MODEL_NAME']``:

    * ``'ProtBERT'`` loads ``BertTokenizer`` / ``BertConfig`` / ``FlaxBertModel``
      and runs ``PrecomputeProtBERT``.
    * ``'esm2_t33_650M_UR50D'`` or ``'esm2_t48_15B_UR50D'`` loads
      ``EsmTokenizer`` / ``EsmConfig`` / ``EsmModel`` and runs ``PrecomputeESM2``.

    After the run, the precompute object's ``h5file`` is closed, whether or
    not ``precompute_and_save()`` raised.

    Parameters
    ----------
    hparams : mapping
        Configuration indexed by string keys. Always read: ``MODEL_NAME``.
        For a supported model the following are also read:
        ``SEQ_MODEL_TOKENIZER_PATH``, ``SEQ_MODEL_CONFIG_PATH``,
        ``SEQ_MODEL_PATH``, ``HUGGINGFACE_CACHE_DIR``, ``DATA_FILE``,
        ``SAVE_DIR``, ``MODE``, ``DBNAME``, ``ID_COL``, ``SEQ_COL``,
        ``BATCH_SIZE``, ``HIDDEN_STATES_SHAPE`` and ``MAX_LENGTH``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``hparams['MODEL_NAME']`` is not one of the supported names.
    """
    model_name = hparams['MODEL_NAME']

    if model_name == 'ProtBERT':
        # Imported lazily so only the selected backend's classes are pulled in.
        from transformers import BertTokenizer, BertConfig, FlaxBertModel

        cache_dir = hparams['HUGGINGFACE_CACHE_DIR']
        tokenizer = BertTokenizer.from_pretrained(
            hparams['SEQ_MODEL_TOKENIZER_PATH'],
            do_lower_case=False,
            cache_dir=cache_dir,
        )
        config = BertConfig.from_pretrained(
            hparams['SEQ_MODEL_CONFIG_PATH'],
            output_hidden_states=True,
            output_attentions=False,
            cache_dir=cache_dir,
        )
        bert_model = FlaxBertModel.from_pretrained(
            hparams['SEQ_MODEL_PATH'],
            from_pt=True,
            config=config,
            cache_dir=cache_dir,
        )
        # PrecomputeProtBERT takes the model under the keyword ``bert_model``.
        precompute = PrecomputeProtBERT(
            bert_model=bert_model,
            tokenizer=tokenizer,
            **_common_precompute_kwargs(hparams),
        )

    elif model_name in _ESM2_MODEL_NAMES:
        from transformers import EsmTokenizer, EsmConfig, EsmModel

        cache_dir = hparams['HUGGINGFACE_CACHE_DIR']
        tokenizer = EsmTokenizer.from_pretrained(
            hparams['SEQ_MODEL_TOKENIZER_PATH'],
            cache_dir=cache_dir,
        )
        config = EsmConfig.from_pretrained(
            hparams['SEQ_MODEL_CONFIG_PATH'],
            cache_dir=cache_dir,
        )
        model = EsmModel.from_pretrained(
            hparams['SEQ_MODEL_PATH'],
            config=config,
            add_pooling_layer=False,
            cache_dir=cache_dir,
        )
        # PrecomputeESM2 takes the model under the keyword ``model``.
        precompute = PrecomputeESM2(
            model=model,
            tokenizer=tokenizer,
            **_common_precompute_kwargs(hparams),
        )

    else:
        raise ValueError('Unknown MODEL_NAME: \t {}'.format(hparams['MODEL_NAME']))

    _run_and_close(precompute)
    return None