import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# print('\nWARNING: XLA_PYTHON_CLIENT_PREALLOCATE set to False...\n')
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = '0'
import sys
import pickle
import time
import pandas
import jax
from jax import numpy as jnp
import flax
from flax import serialization


from ProtLig_GPCRclassA.amino_GNN.concentration.dataset_conc import AminoConcentrationDatasetPrecomputePredict
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.concentration.element_conc import AminoConcentrationElementPrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader

from ProtLig_GPCRclassA.amino_GNN.concentration.make_init_conc import make_init_model
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.concentration.predict.make_predict_conc_range_epoch import make_predict_conc_range_epoch

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())

def _get_stdout_logger(name):
    """Return the logger ``name`` at INFO level with a single stdout handler.

    ``logging.getLogger`` returns the same object for the same name, so
    attaching a new handler on every call would duplicate each log line.
    A stdout handler is therefore added only when none is attached yet.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger


def _predict_with_open_table(hparams, model, h5file, logger):
    """Build the data pipeline, restore the model state and run prediction.

    Parameters
    ----------
    hparams : dict
        Hyper-parameters (already copied/adjusted by the caller).
    model :
        Model instance created from ``hparams['MODEL_NAME']``.
    h5file :
        Open PyTables file; ``h5file.root.amino.table`` is handed to the
        collate and element-preprocess objects. It is closed here early when
        ``hparams['PYTABLE_FROM_DISK']`` is falsy; otherwise the caller is
        responsible for closing it.
    logger : logging.Logger
        Logger used for progress and timing messages.

    Returns
    -------
    dict
        ``DataFrame.to_dict()`` of the merged table of pair ids and
        flattened prediction outputs.

    Raises
    ------
    NotImplementedError
        If ``hparams['N_PARTITIONS'] > 0``.
    """
    # ---------
    # Datasets:
    # ---------
    h5_table = h5file.root.amino.table

    collate = AminoCollatePrecompute(bert_table = h5_table,
                                    padding_n_node = hparams['PADDING_N_NODE'],
                                    padding_n_edge = hparams['PADDING_N_EDGE'],
                                    n_partitions = hparams['N_PARTITIONS'],
                                    from_disk = hparams['PYTABLE_FROM_DISK'],
                                    line_graph = hparams['LINE_GRAPH'])

    element = AminoConcentrationElementPrecompute(bert_table = h5_table,
                                    padding_n_node = hparams['PADDING_N_NODE'],
                                    padding_n_edge = hparams['PADDING_N_EDGE'],
                                    from_disk = hparams['PYTABLE_FROM_DISK'])
    if not hparams['PYTABLE_FROM_DISK']:
        # Same early close as before: the table is not kept open when the
        # data is not read from disk.
        h5file.close()
        print('Table closed...')

    predict_dataset = AminoConcentrationDatasetPrecomputePredict(data_csv = hparams['PREDICT_CSV_PATH'],
                        mols_csv = hparams['MOLS_CSV'],
                        seqs_csv = hparams['SEQS_CSV'],
                        mol_id_col = hparams['MOL_ID_COL'],
                        mol_col = hparams['MOL_COL'],
                        seq_id_col = hparams['SEQ_ID_COL'],
                        atom_features = model.atom_features,
                        bond_features = model.bond_features,
                        line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * collate.padding_n_node,
                        self_loops = hparams['SELF_LOOPS'],
                        line_graph = hparams['LINE_GRAPH'],
                        mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                        seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                        conc_value_col = hparams['CONC_VALUE_COL'],
                        )

    predict_loader = AminoLoader(predict_dataset,
                        batch_size = hparams['BATCH_SIZE'],
                        collate_fn = collate.make_collate(),
                        shuffle = False,
                        rng = jax.random.PRNGKey(int(time.time())),
                        drop_last = False,
                        n_partitions = hparams['N_PARTITIONS'])

    if hparams['LOADER_OUTPUT_TYPE'] == 'tf':
        # The tf loader replaces the default loader and preprocesses
        # elements through the dataset itself.
        predict_dataset.element_preprocess = element.make_element_preprocess()

        predict_loader = get_tf_loader(predict_dataset,
                               batch_size = hparams['BATCH_SIZE'],
                               use_cache = hparams['CACHE'],
                               shuffle = False,
                               shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                               drop_last = False)

    # ----------------
    # Initializations:
    # ----------------
    prng_key = jax.random.PRNGKey(int(time.time()))
    key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(prng_key, 4)

    start = time.time()
    logger.info('Initializing...')
    init_model = make_init_model(model,
                                batch_size = hparams['BATCH_SIZE'],
                                seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                num_node_features = len(hparams['ATOM_FEATURES']),
                                num_edge_features = len(hparams['BOND_FEATURES']),
                                self_loops = hparams['SELF_LOOPS'],
                                line_graph = hparams['LINE_GRAPH'],
                                seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                padding_n_node = hparams['PADDING_N_NODE'],
                                padding_n_edge = hparams['PADDING_N_EDGE'])
    params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
    end = time.time()
    logger.info('TIME: init_model: {}'.format(end - start))

    # The optimizer state is built only as the restore template for the
    # serialized state below; no training step is taken here.
    transition_steps = 1000 # NOTE: This is a dummy value.
    create_optimizer = make_create_optimizer(model,
                                             option = hparams['OPTIMIZATION']['OPTION'],
                                             warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                             transition_steps = transition_steps)
    # The scheduler returned alongside the state is not needed for prediction.
    init_state, _ = create_optimizer(params,
                                     rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps},
                                     learning_rate = hparams['LEARNING_RATE'])

    # Restore params:
    restore_file = hparams['RESTORE_FILE']
    if restore_file is not None:
        logger.info('Restoring parameters from {}'.format(restore_file))
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only restore checkpoints that come from a trusted source.
        with open(restore_file, 'rb') as pklfile:
            bytes_output = pickle.load(pklfile)
        state = serialization.from_bytes(init_state, bytes_output)
        logger.info('Parameters restored...')
    else:
        state = init_state

    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('pmap needs to be checked...')

    predict_epoch = make_predict_conc_range_epoch(model,
                                                  min_conc_sample = hparams['MIN_CONC_SAMPLE'],
                                                  max_conc_sample = hparams['MAX_CONC_SAMPLE'],
                                                  step_conc_sample = hparams['STEP_CONC_SAMPLE'],
                                                  return_intermediates = hparams['RETURN_INTERMEDIATES'],
                                                  num_classes = hparams['OUT_FEATURES'],
                                                  loader_output_type = hparams['LOADER_OUTPUT_TYPE'])

    # --------
    # PREDICT:
    # --------
    start = time.time()
    predict_outputs = predict_epoch(state.params, predict_loader)
    end = time.time()
    logger.info('TIME: predict_epoch: {}'.format(end - start))

    # Concatenate the per-batch output trees along axis 0, then drop the
    # trailing axis (assumes every leaf ends in a size-1 axis - TODO confirm
    # against make_predict_conc_range_epoch).
    outputs = jax.tree.map(lambda *x: jnp.concatenate(x, axis = 0), *predict_outputs)
    outputs = jax.tree.map(lambda x: jnp.squeeze(x, axis = -1), outputs)

    # Tag every predicted value with the dataset row index of its pair so the
    # flattened outputs can be merged back onto the molecule/sequence ids.
    pair_ids = predict_dataset.data[[hparams['MOL_ID_COL'], hparams['SEQ_ID_COL']]].copy()
    idx = jnp.array(pair_ids.index)
    idx = jnp.broadcast_to(jnp.reshape(idx, (len(pair_ids), 1)), outputs['_conc'].shape)
    outputs.update({'_pair_id' : idx})
    flatten_outputs = jax.tree.map(lambda x: x.flatten(), outputs)
    df = pandas.DataFrame(flatten_outputs)
    df = pandas.merge(pair_ids, df, left_index = True, right_on = '_pair_id', how = 'inner')
    df.drop(columns = ['_pair_id'], inplace = True)
    return df.to_dict()


def main_predict_conc_range(hparams):
    """Predict over a range of concentrations for every pair in the predict CSV.

    The model named by ``hparams['MODEL_NAME']`` is instantiated, optionally
    restored from ``hparams['RESTORE_FILE']``, and evaluated with the epoch
    function built by ``make_predict_conc_range_epoch`` using the
    ``MIN_CONC_SAMPLE`` / ``MAX_CONC_SAMPLE`` / ``STEP_CONC_SAMPLE`` settings.

    Parameters
    ----------
    hparams : dict
        Hyper-parameters and paths. The dict is shallow-copied, so the
        caller's mapping is never modified (the self-loop adjustment of
        ``PADDING_N_EDGE`` is applied to the copy only).

    Returns
    -------
    dict
        ``DataFrame.to_dict()`` of a table holding the ``MOL_ID_COL`` and
        ``SEQ_ID_COL`` columns together with the flattened prediction outputs.

    Raises
    ------
    ValueError
        If ``SELF_LOOPS`` is set while ``BOND_FEATURES`` is non-empty.
    NotImplementedError
        If ``N_PARTITIONS`` is greater than zero.
    """
    # Work on a copy: the padding adjustment below would otherwise accumulate
    # in the caller's dict on every call.
    hparams = dict(hparams)

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'])

    if hparams['SELF_LOOPS']:
        if len(hparams['BOND_FEATURES']) > 0:
            raise ValueError('Can not have both bond features and self_loops.')
        # NOTE: Because of self_loops, one extra edge per (padded) node.
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE']

    logger = _get_stdout_logger('main_predict_conc')

    logger.info('jax_version = {}'.format(jax.__version__))
    logger.info('flax_version = {}'.format(flax.__version__))
    logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
    logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
    logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
    try:
        return _predict_with_open_table(hparams, model, h5file, logger)
    finally:
        # Release the HDF5 handle on every path, including the from-disk
        # case and failures, where it was previously left open.
        if h5file.isopen:
            h5file.close()