import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization

from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.base.predict.make_predict_epoch import make_predict_epoch

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging



def _get_logger():
    """Return the INFO-level 'main_predict' logger writing to stdout.

    A stdout handler is attached only when the logger has none yet. Attaching
    one on every call would print each message once per earlier call made in
    the same process.
    """
    logger = logging.getLogger('main_predict')
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger


def _build_predict_loader(hparams, model, collate):
    """Build an unshuffled, non-dropping AminoLoader over hparams['PREDICT_CSV_PATH'].

    ``collate`` must already exist because the dataset's line-graph size limit
    is derived from ``collate.padding_n_node``.
    """
    predict_dataset = AminoDatasetPrecompute(data_csv = hparams['PREDICT_CSV_PATH'],
                        mols_csv = hparams['MOLS_CSV_PATH'],
                        mol_id_col = hparams['MOL_ID_COL'],
                        mol_col = hparams['MOL_COL'],
                        seq_id_col = hparams['SEQ_ID_COL'], # Gene is only sequence id.
                        label_col = hparams['LABEL_COL'],
                        weight_col = None,
                        atom_features = model.atom_features,
                        bond_features = model.bond_features,
                        line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * collate.padding_n_node,
                        self_loops = hparams['SELF_LOOPS'],
                        line_graph = hparams['LINE_GRAPH'],
                        mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                        seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                        )
    return AminoLoader(predict_dataset,
                        batch_size = hparams['BATCH_SIZE'],
                        collate_fn = collate.make_collate(),
                        # Keep input order and every sample so outputs line up with the CSV rows.
                        shuffle = False,
                        # NOTE(review): with shuffle=False this rng is presumably unused — confirm in AminoLoader.
                        rng = jax.random.PRNGKey(int(time.time())),
                        drop_last = False,
                        n_partitions = hparams['N_PARTITIONS'])


def _init_state(model, hparams, logger):
    """Initialize model parameters and wrap them in a fresh optimizer state.

    The returned state serves as the structural template for deserializing a
    checkpoint (see ``_restore_state``). No optimization step is performed, so
    the optimizer/scheduler settings only need to reproduce the state layout.
    """
    key1, _ = jax.random.split(jax.random.PRNGKey(int(time.time())), 2)
    key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(key1, 4)

    start = time.time()
    logger.info('Initializing...')
    init_model = make_init_model(model,
                                batch_size = hparams['BATCH_SIZE'],
                                seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                num_node_features = len(hparams['ATOM_FEATURES']),
                                num_edge_features = len(hparams['BOND_FEATURES']),
                                self_loops = hparams['SELF_LOOPS'],
                                line_graph = hparams['LINE_GRAPH'],
                                seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                padding_n_node = hparams['PADDING_N_NODE'],
                                padding_n_edge = hparams['PADDING_N_EDGE'])
    params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
    logger.info('TIME: init_model: {}'.format(time.time() - start))

    transition_steps = 1000 # NOTE: This is a dummy value (no training happens here).
    create_optimizer = make_create_optimizer(model, option = hparams['OPTIMIZATION']['OPTION'], warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'], transition_steps = transition_steps)
    init_state, _scheduler = create_optimizer(params, rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps}, learning_rate = hparams['LEARNING_RATE'])
    return init_state


def _restore_state(init_state, restore_file, logger):
    """Return ``init_state`` overwritten with the checkpoint in ``restore_file``.

    If ``restore_file`` is None, ``init_state`` is returned unchanged.
    """
    if restore_file is None:
        return init_state
    logger.info('Restoring parameters from {}'.format(restore_file))
    # SECURITY: pickle.load can execute arbitrary code; only restore trusted checkpoints.
    with open(restore_file, 'rb') as pklfile:
        bytes_output = pickle.load(pklfile)
    state = serialization.from_bytes(init_state, bytes_output)
    logger.info('Parameters restored...')
    return state


def _concat_batch_outputs(predict_outputs):
    """Concatenate a list of per-batch output dicts key by key along axis 0.

    The keys of the first batch define the output keys. An empty list yields an
    empty dict instead of raising IndexError.
    """
    if not predict_outputs:
        return {}
    return {col: jnp.concatenate([batch[col] for batch in predict_outputs], axis = 0)
            for col in predict_outputs[0]}


def main_predict(hparams):
    """Run a restored (or freshly initialized) model over a prediction CSV.

    Args:
        hparams: mapping of hyper-parameters. It is read, never modified: the
            self-loop padding adjustment is applied to a private copy.

    Returns:
        dict mapping each key of the per-batch outputs of ``predict_epoch`` to
        the concatenation of that key over all batches (axis 0). The dict is
        empty if there were no batches.

    Raises:
        NotImplementedError: if hparams['N_PARTITIONS'] > 0 (pmap path missing).
        RuntimeError: if hparams['LOADER_OUTPUT_TYPE'] == 'tf' (disabled).
        ValueError: for an unknown LOADER_OUTPUT_TYPE, or SELF_LOOPS together
            with non-empty BOND_FEATURES.
    """
    # Shallow copy: the PADDING_N_EDGE adjustment below must not leak into (and
    # accumulate in) the caller's dict across repeated calls.
    hparams = dict(hparams)

    # Fail fast on unsupported configurations before any expensive work.
    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('pmap is not implemented yet...')
    loader_output_type = hparams['LOADER_OUTPUT_TYPE']
    if loader_output_type == 'tf':
        # The tf.data pipeline (tf_Dataset_by_example + cache/prefetch) is disabled.
        raise RuntimeError('LOADER_OUTPUT_TYPE="tf" is disabled: the TensorFlow loader triggers '
                           'RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation '
                           'gpusolverDnCreate(&handle) failed: cuSolver internal error')
    if loader_output_type != 'jax':
        raise ValueError('Unknown LOADER_OUTPUT_TYPE: {!r} (expected \'jax\').'.format(loader_output_type))

    if hparams['SELF_LOOPS']:
        if len(hparams['BOND_FEATURES']) > 0:
            raise ValueError('Can not have both bond features and self_loops.')
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE'] # NOTE: Because of self_loops

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'], bond_features = hparams['BOND_FEATURES'], out_features = hparams['OUT_FEATURES'])

    logger = _get_logger()

    # ---------
    # Datasets:
    # ---------
    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
    h5file_open = True
    try:
        collate = AminoCollatePrecompute(h5file.root.amino.table,
                                        padding_n_node = hparams['PADDING_N_NODE'],
                                        padding_n_edge = hparams['PADDING_N_EDGE'],
                                        n_partitions = hparams['N_PARTITIONS'],
                                        from_disk = hparams['PYTABLE_FROM_DISK'],
                                        line_graph = hparams['LINE_GRAPH'])

        # When not reading from disk the file is no longer needed after collate is built;
        # otherwise it must stay open until prediction finishes (closed in `finally`).
        if not hparams['PYTABLE_FROM_DISK']:
            h5file.close()
            h5file_open = False
            print('Table closed...')

        predict_loader = _build_predict_loader(hparams, model, collate)

        # ----------------
        # Initializations:
        # ----------------
        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(loader_output_type))

        init_state = _init_state(model, hparams, logger)
        state = _restore_state(init_state, hparams['RESTORE_FILE'], logger)

        predict_epoch = make_predict_epoch(model, return_intermediates = hparams['RETURN_INTERMEDIATES'], num_classes = hparams['OUT_FEATURES'], loader_output_type = loader_output_type)

        # --------
        # PREDICT:
        # --------
        start = time.time()
        predict_outputs = predict_epoch(state.params, predict_loader)
        logger.info('TIME: predict_epoch: {}'.format(time.time() - start))

        return _concat_batch_outputs(predict_outputs)
    finally:
        if h5file_open:
            h5file.close()