import sys
import time
import pickle
import jax
import flax
from flax import serialization

import tables

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name
from ProtLig_GPCRclassA.amino_GNN.base.predict.make_predict_single import make_predict_single_apply
from ProtLig_GPCRclassA.amino_GNN.seq_embedding.select_seq_embedding import select_seq_embedding

import logging

def main_predict_single_apply(hparams):
    """
    Build the model, optionally restore saved state, and run single prediction.

    Args:
        hparams: Mapping of hyperparameters. Keys read directly here:
            'MODEL_NAME', 'ATOM_FEATURES', 'BOND_FEATURES', 'OUT_FEATURES',
            'SELF_LOOPS', 'PADDING_N_NODE', 'PADDING_N_EDGE',
            'PYTABLE_FROM_DISK' (only logged), 'BATCH_SIZE',
            'SEQ_EMBEDDING_SIZE', 'LINE_GRAPH', 'SEQ_MAX_LENGTH',
            'OPTIMIZATION' (sub-keys 'OPTION' and 'WARMUP_STEPS'),
            'LEARNING_RATE', 'RESTORE_FILE' (None or absent -> no restore),
            'RETURN_INTERMEDIATES', 'LIST_SMILES', 'LIST_SEQS'.
            The (copied, padding-adjusted) mapping is also passed whole to
            `select_seq_embedding`, which needs 'SEQ_MODEL_NAME'.
            The caller's mapping is NOT modified.

    Returns:
        Whatever `predict_single` (from `make_predict_single_apply`) returns
        for hparams['LIST_SMILES'] and hparams['LIST_SEQS'].

    Raises:
        ValueError: if 'SELF_LOOPS' is truthy while 'BOND_FEATURES' is
            non-empty.
    """
    # Work on a shallow copy. The self-loop padding adjustment below must not
    # leak into the caller's mapping; otherwise repeated calls with the same
    # dict would keep growing PADDING_N_EDGE by PADDING_N_NODE.
    hparams = dict(hparams)

    if hparams['SELF_LOOPS']:
        # Validate before adjusting anything.
        if len(hparams['BOND_FEATURES']) > 0:
            raise ValueError('Can not have both bond features and self_loops.')
        # NOTE: Because of self_loops, the edge padding grows by the node
        # padding (presumably one self-loop edge per node).
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'])

    logger = logging.getLogger('main_predict_apply')
    logger.setLevel(logging.INFO)
    # No handler is attached here. Records propagate to ancestor loggers, so
    # where output goes is decided by the caller's logging configuration.

    # ----------------
    # Initializations:
    # ----------------
    # Seeded from wall-clock time, so randomly initialized parameters differ
    # between runs (irrelevant once RESTORE_FILE overwrites them).
    key1, _ = jax.random.split(jax.random.PRNGKey(int(time.time())), 2)
    key_params, key_init_num_steps, key_num_steps, key_dropout = jax.random.split(key1, 4)

    logger.info('jax_version = %s', jax.__version__)
    logger.info('flax_version = %s', flax.__version__)
    logger.info('from_disk = %s', hparams['PYTABLE_FROM_DISK'])
    logger.info('model_name = %s', hparams['MODEL_NAME'])

    start = time.perf_counter()
    logger.info('Initializing...')
    init_model = make_init_model(model,
                                 batch_size = hparams['BATCH_SIZE'],
                                 seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                 num_node_features = len(hparams['ATOM_FEATURES']),
                                 num_edge_features = len(hparams['BOND_FEATURES']),
                                 self_loops = hparams['SELF_LOOPS'],
                                 line_graph = hparams['LINE_GRAPH'],
                                 seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                 padding_n_node = hparams['PADDING_N_NODE'],
                                 padding_n_edge = hparams['PADDING_N_EDGE'])
    params = init_model(rngs = {'params' : key_params,
                                'dropout' : key_dropout,
                                'num_steps' : key_init_num_steps})
    logger.info('TIME: init_model: %s', time.perf_counter() - start)

    # Restore params:
    # The optimizer state is only built as a template for deserialization,
    # so the schedule length is irrelevant here.
    transition_steps = 1000 # NOTE: This is a dummy value.
    create_optimizer = make_create_optimizer(model,
                                             option = hparams['OPTIMIZATION']['OPTION'],
                                             warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                             transition_steps = transition_steps)
    init_state, _scheduler = create_optimizer(params,
                                              rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps},
                                              learning_rate = hparams['LEARNING_RATE'])

    restore_file = hparams.get('RESTORE_FILE')
    if restore_file is not None:
        logger.info('Restoring parameters from %s', restore_file)
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only restore from trusted checkpoint files.
        with open(restore_file, 'rb') as pklfile:
            bytes_output = pickle.load(pklfile)
        state = serialization.from_bytes(init_state, bytes_output)
        logger.info('Parameters restored...')
    else:
        state = init_state

    # --------
    # PREDICT:
    # --------
    apply_seqs_model = select_seq_embedding(hparams) # SEQ_MODEL_NAME needs to be in hparams
    predict_single = make_predict_single_apply(model,
                                               apply_seqs_model = apply_seqs_model,
                                               padding_n_node = hparams['PADDING_N_NODE'],
                                               padding_n_edge = hparams['PADDING_N_EDGE'],
                                               return_intermediates = hparams['RETURN_INTERMEDIATES'],
                                               num_classes = hparams['OUT_FEATURES'],
                                               self_loops = hparams['SELF_LOOPS'])

    output = predict_single(state.params, hparams['LIST_SMILES'], hparams['LIST_SEQS'])
    logger.info('Finished...')
    return output