import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization

from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.base.predict.make_predict_epoch import make_predict_epoch

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging


def main_predict_batch_precompute(hparams):
    """Run batch prediction over a CSV of (molecule, sequence) pairs.

    Builds the model selected by ``hparams['MODEL_NAME']``, optionally restores
    its state from ``hparams['RESTORE_FILE']``, runs one prediction epoch over
    ``hparams['PREDICT_CSV_PATH']`` and returns the main-label predictions
    next to the molecule / sequence identifier columns.

    Parameters
    ----------
    hparams : dict
        Configuration mapping. It is NOT modified: the function works on a
        shallow copy. Keys read here:

        - model: ``MODEL_NAME``, ``ATOM_FEATURES``, ``BOND_FEATURES``,
          ``OUT_FEATURES``, ``SELF_LOOPS``, ``LINE_GRAPH``,
          ``LINE_GRAPH_MAX_SIZE_MULTIPLIER``, ``SEQ_EMBEDDING_SIZE``,
          ``SEQ_MAX_LENGTH``, ``RETURN_INTERMEDIATES``
        - padding / batching: ``PADDING_N_NODE``, ``PADDING_N_EDGE``,
          ``BATCH_SIZE``, ``N_PARTITIONS``
        - data: ``H5FILE``, ``H5FILE_TITLE``, ``PYTABLE_FROM_DISK``,
          ``PREDICT_CSV_PATH``, ``MOLS_CSV``, ``SEQS_CSV``, ``MOL_ID_COL``,
          ``MOL_COL``, ``SEQ_ID_COL``, ``LABEL_COL``, ``MOL_GLOBAL_COLS``,
          ``SEQ_GLOBAL_COLS``
        - loader: ``LOADER_OUTPUT_TYPE`` (``'tf'`` switches to the loader
          returned by ``get_tf_loader``), ``CACHE``, ``SHUFFLE_BUFFER_SIZE``
        - optimizer / restore: ``OPTIMIZATION`` (``OPTION``,
          ``WARMUP_STEPS``), ``LEARNING_RATE``, ``RESTORE_FILE`` (``None``
          keeps the freshly initialised state)

    Returns
    -------
    dict
        ``df.to_dict()`` of a frame holding the ``MOL_ID_COL`` and
        ``SEQ_ID_COL`` columns of the prediction data plus a ``'prediction'``
        column taken from column 0 of the ``'_main_label'`` output.
        NOTE(review): ``predict_dataset.data`` is presumably a pandas
        DataFrame, which would make this a column -> {index -> value}
        mapping -- confirm against ``AminoDatasetPrecompute``.

    Raises
    ------
    ValueError
        If ``SELF_LOOPS`` is set while ``BOND_FEATURES`` is non-empty.
    NotImplementedError
        If ``N_PARTITIONS`` is greater than 0 (pmap is not supported).
    """
    # Shallow copy: PADDING_N_EDGE is rewritten below, and writing it back
    # into the caller's dict would inflate the padding again on every
    # subsequent call made with the same hparams object.
    hparams = dict(hparams)

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'])

    if hparams['SELF_LOOPS']:
        if len(hparams['BOND_FEATURES']) > 0:
            raise ValueError('Can not have both bond features and self_loops.')
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE'] # NOTE: Because of self_loops

    # Fail fast: no point in loading data and initialising the model when the
    # requested execution mode is not available.
    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('pmap is not implemented yet...')

    logger = logging.getLogger('main_predict_batch_precompute')
    logger.setLevel(logging.INFO)
    # getLogger returns the same object on every call, so only attach a stdout
    # handler when one is missing; otherwise each call duplicates every message.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers)
    if not has_stdout_handler:
        logger.addHandler(logging.StreamHandler(sys.stdout))

    logger.info('jax_version = {}'.format(jax.__version__))
    logger.info('flax_version = {}'.format(flax.__version__))
    logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
    logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
    logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

    # ---------
    # Datasets:
    # ---------
    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
    try:
        h5_table = h5file.root.amino.table

        collate = AminoCollatePrecompute(bert_table = h5_table,
                                        padding_n_node = hparams['PADDING_N_NODE'],
                                        padding_n_edge = hparams['PADDING_N_EDGE'],
                                        n_partitions = hparams['N_PARTITIONS'],
                                        from_disk = hparams['PYTABLE_FROM_DISK'],
                                        line_graph = hparams['LINE_GRAPH'])

        element = AminoElementPrecompute(bert_table = h5_table,
                                        padding_n_node = hparams['PADDING_N_NODE'],
                                        padding_n_edge = hparams['PADDING_N_EDGE'],
                                        from_disk = hparams['PYTABLE_FROM_DISK'])
        if not hparams['PYTABLE_FROM_DISK']:
            # The table is not read from disk in this mode, so the file can be
            # released right away.
            h5file.close()
            print('Table closed...')

        predict_dataset = AminoDatasetPrecompute(data_csv = hparams['PREDICT_CSV_PATH'],
                            mols_csv = hparams['MOLS_CSV'],
                            seqs_csv = hparams['SEQS_CSV'],
                            mol_id_col = hparams['MOL_ID_COL'],
                            mol_col = hparams['MOL_COL'],
                            seq_id_col = hparams['SEQ_ID_COL'],
                            label_col = hparams['LABEL_COL'],
                            weight_col = None,
                            atom_features = model.atom_features,
                            bond_features = model.bond_features,
                            line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * collate.padding_n_node,
                            self_loops = hparams['SELF_LOOPS'],
                            line_graph = hparams['LINE_GRAPH'],
                            mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                            seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                            )

        predict_loader = AminoLoader(predict_dataset,
                            batch_size = hparams['BATCH_SIZE'],
                            collate_fn = collate.make_collate(),
                            shuffle = False,
                            rng = jax.random.PRNGKey(int(time.time())),
                            drop_last = False,
                            n_partitions = hparams['N_PARTITIONS'])

        if hparams['LOADER_OUTPUT_TYPE'] == 'tf':
            predict_dataset.element_preprocess = element.make_element_preprocess()

            predict_loader = get_tf_loader(predict_dataset,
                                   batch_size = hparams['BATCH_SIZE'],
                                   use_cache = hparams['CACHE'],
                                   shuffle = False,
                                   shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                                   drop_last = False)

        # ----------------
        # Initializations:
        # ----------------
        prng_key = jax.random.PRNGKey(int(time.time()))
        key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(prng_key, 4)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                    batch_size = hparams['BATCH_SIZE'],
                                    seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                    num_node_features = len(hparams['ATOM_FEATURES']),
                                    num_edge_features = len(hparams['BOND_FEATURES']),
                                    self_loops = hparams['SELF_LOOPS'],
                                    line_graph = hparams['LINE_GRAPH'],
                                    seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                    padding_n_node = hparams['PADDING_N_NODE'],
                                    padding_n_edge = hparams['PADDING_N_EDGE'])
        params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
        end = time.time()
        logger.info('TIME: init_model: {}'.format(end - start))

        # The optimizer state is built only so that it can serve as the
        # restore target below.
        transition_steps = 1000 # NOTE: This is a dummy value.
        create_optimizer = make_create_optimizer(model,
                                                 option = hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps = transition_steps)
        init_state, scheduler = create_optimizer(params,
                                                 rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps},
                                                 learning_rate = hparams['LEARNING_RATE'])

        # Restore params:
        restore_file = hparams['RESTORE_FILE']
        if restore_file is not None:
            logger.info('Restoring parameters from {}'.format(restore_file))
            # NOTE(security): pickle.load can execute arbitrary code from the
            # file; only restore checkpoints that come from a trusted source.
            with open(restore_file, 'rb') as pklfile:
                bytes_output = pickle.load(pklfile)
            state = serialization.from_bytes(init_state, bytes_output)
            logger.info('Parameters restored...')
        else:
            state = init_state

        predict_epoch = make_predict_epoch(model,
                                           return_intermediates = hparams['RETURN_INTERMEDIATES'],
                                           num_classes = hparams['OUT_FEATURES'],
                                           loader_output_type = hparams['LOADER_OUTPUT_TYPE'])

        # --------
        # PREDICT:
        # --------
        start = time.time()
        predict_outputs = predict_epoch(state.params, predict_loader)
        end = time.time()
        logger.info('TIME: predict_epoch: {}'.format(end - start))

        # Stitch the per-batch outputs back together, one array per output key.
        outputs = {
            col: jnp.concatenate([batch[col] for batch in predict_outputs], axis = 0)
            for col in predict_outputs[0].keys()
        }

        df = predict_dataset.data[[hparams['MOL_ID_COL'], hparams['SEQ_ID_COL']]].copy()
        df['prediction'] = outputs['_main_label'][:, 0].tolist()
        return df.to_dict()
    finally:
        # Release the HDF5 handle when it is still open: either the from-disk
        # mode kept it open for the loader, or an exception interrupted setup.
        if h5file.isopen:
            h5file.close()