import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary, Image, DelayedScalar

from ProtLig_GPCRclassA.metrics import *

# from ProtLig_GPCRclassA.amino_GNN.loader import AminoDatasetPrecompute, AminoLoader, AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch
# from ProtLig_GPCRclassA.amino_GNN.base.train.make_valid_epoch import make_valid_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())



def _get_eval_logger():
    """Return the ``'main_eval'`` logger at INFO level with a single stdout handler.

    The stdout handler is attached only if one is not already present, so calling
    ``main_eval`` several times in one process does not duplicate log lines.
    """
    logger = logging.getLogger('main_eval')
    logger.setLevel(logging.INFO)
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler) and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger


def _make_valid_dataset(hparams, data_csv, model, padding_n_node):
    """Build an ``AminoDatasetPrecompute`` for ``data_csv`` with the shared eval settings.

    ``padding_n_node`` is the node padding of the collate that will batch this
    dataset; it is multiplied by ``LINE_GRAPH_MAX_SIZE_MULTIPLIER`` to get the
    maximum line-graph size.
    """
    return AminoDatasetPrecompute(
        data_csv=data_csv,
        mols_csv=os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV']),
        mol_id_col=hparams['MOL_ID_COL'],
        mol_col=hparams['MOL_COL'],
        seq_id_col=hparams['SEQ_ID_COL'],  # Gene is only sequence id.
        label_col=hparams['LABEL_COL'],
        weight_col=hparams['VALID_WEIGHT_COL'],
        atom_features=model.atom_features,
        bond_features=model.bond_features,
        line_graph_max_size=hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * padding_n_node,
        self_loops=hparams['SELF_LOOPS'],
        line_graph=hparams['LINE_GRAPH'],
        auxiliary_label_cols=hparams['AUXILIARY_LABEL_COLS'],
        auxiliary_weight_cols=hparams['AUXILIARY_WEIGHT_COLS'],
        mol_global_cols=hparams['MOL_GLOBAL_COLS'],
        seq_global_cols=hparams['SEQ_GLOBAL_COLS'],
    )


def _make_valid_loader(hparams, dataset, collate, element, batch_size):
    """Build the evaluation loader for ``dataset``.

    An ``AminoLoader`` (batched by ``collate``) is always constructed. When
    ``LOADER_OUTPUT_TYPE == 'tf'``, ``dataset.element_preprocess`` is set from
    ``element`` and a ``get_tf_loader`` loader is returned instead.

    Evaluation must score every example exactly once, so neither loader
    shuffles nor drops the final partial batch.
    """
    loader = AminoLoader(dataset,
                         batch_size=batch_size,
                         collate_fn=collate.make_collate(),
                         shuffle=False,
                         rng=jax.random.PRNGKey(int(time.time())),
                         drop_last=False,
                         n_partitions=hparams['N_PARTITIONS'])

    if hparams['LOADER_OUTPUT_TYPE'] == 'tf':
        dataset.element_preprocess = element.make_element_preprocess()
        # Previously shuffle=True / drop_last=True (training settings): that silently
        # discarded up to batch_size - 1 examples and made NUM_TEST_DATA inaccurate.
        loader = get_tf_loader(dataset,
                               batch_size=batch_size,
                               use_cache=hparams['CACHE'],
                               shuffle=False,
                               shuffle_buffer_size=hparams['SHUFFLE_BUFFER_SIZE'],
                               drop_last=False)
    return loader


def _restore_state(init_state, restore_file, logger):
    """Return ``init_state`` with values restored from ``restore_file``, if given.

    ``restore_file`` is expected to be a pickle of the bytes produced by
    ``flax.serialization.to_bytes`` for a state with the same structure as
    ``init_state``. When ``restore_file`` is None, ``init_state`` is returned unchanged.
    """
    if restore_file is None:
        return init_state
    logger.info('Restoring parameters from {}'.format(restore_file))
    # SECURITY NOTE: pickle.load can execute arbitrary code; only restore
    # checkpoints from a trusted source.
    with open(restore_file, 'rb') as pklfile:
        bytes_output = pickle.load(pklfile)
    state = serialization.from_bytes(init_state, bytes_output)
    logger.info('Parameters restored...')
    return state


def _add_scalar_summaries(result, summary):
    """Copy every ``DelayedScalar`` in ``summary`` into ``result`` as Python floats.

    ``Image`` entries (and any other non-scalar entries) are skipped. Keys and
    values are printed as they are processed. Returns ``result`` (mutated in place).
    """
    if summary is None:
        return result
    for key in summary.keys():
        print(key)
        entry = summary[key]
        if isinstance(entry, DelayedScalar) and not isinstance(entry, Image):
            print(entry.values)
            result[key] = jax.tree_util.tree_map(float, entry.values)
        print('--------')
    return result


def main_eval(hparams):
    """Evaluate a (optionally restored) model on the validation CSV(s) and return its metrics.

    When ``SIZE_CUT_DIRNAME`` is set, evaluation runs on the concatenation of the
    regular validation split and the "big" validation split (with their own
    padding/batch sizes). Otherwise it runs on the regular validation split only.

    Args:
        hparams: mapping of hyper-parameters (data paths, model/padding settings,
            loader options, ``RESTORE_FILE``, ...). It is NOT modified: padding and
            epoch adjustments are applied to a local shallow copy.

    Returns:
        dict with ``'DATA_CSV'``, ``'RESTORE_FILE'``, ``'H5FILE'``,
        ``'NUM_TEST_DATA'`` plus one float entry per scalar logged by
        ``log_metrics_from_epoch``.

    Raises:
        ValueError: if ``SELF_LOOPS`` is combined with non-empty ``BOND_FEATURES``.
        NotImplementedError: if ``N_PARTITIONS > 0`` (pmap evaluation is unsupported).
    """
    # Work on a copy so the adjustments below neither leak into the caller's dict
    # nor accumulate (e.g. PADDING_N_EDGE += PADDING_N_NODE) across repeated calls.
    hparams = dict(hparams)
    size_cut = hparams['SIZE_CUT_DIRNAME'] is not None
    if not size_cut:
        hparams['BIG_SWITCH_EPOCH'] = hparams['N_EPOCH'] + 1

    # Validate the configuration before any expensive model/data construction.
    if hparams['SELF_LOOPS'] and len(hparams['BOND_FEATURES']) > 0:
        raise ValueError('Can not have both bond features and self_loops.')
    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('pmap is not implemented correctly.')

    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    valid_dir = os.path.join(datadir, hparams['SIZE_CUT_DIRNAME']) if size_cut else datadir
    data_valid_csv = os.path.join(valid_dir, hparams['VALID_CSV_NAME'])
    big_data_valid_csv = os.path.join(valid_dir, hparams['BIG_VALID_CSV_NAME']) if size_cut else None

    print('\n\n----->\tWARNING: Runnung tests....\n\n')
    restore_file = hparams['RESTORE_FILE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features=hparams['ATOM_FEATURES'],
                        bond_features=hparams['BOND_FEATURES'],
                        out_features=hparams['OUT_FEATURES'],
                        seq_d_model=hparams['SEQ_EMBEDDING_SIZE'])

    if hparams['SELF_LOOPS']:
        # Self loops add one edge per node, so the edge padding grows by the node padding.
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE']
        if size_cut:
            hparams['BIG_PADDING_N_EDGE'] = hparams['BIG_PADDING_N_EDGE'] + hparams['BIG_PADDING_N_NODE']

    logger = _get_eval_logger()

    # ---------
    # Datasets:
    # ---------
    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode='r', title=hparams['H5FILE_TITLE'])
    try:
        h5_table = h5file.root.amino.table

        collate = AminoCollatePrecompute(h5_table,
                                         padding_n_node=hparams['PADDING_N_NODE'],
                                         padding_n_edge=hparams['PADDING_N_EDGE'],
                                         n_partitions=hparams['N_PARTITIONS'],
                                         from_disk=hparams['PYTABLE_FROM_DISK'],
                                         line_graph=hparams['LINE_GRAPH'])
        element = AminoElementPrecompute(bert_table=h5_table,
                                         padding_n_node=hparams['PADDING_N_NODE'],
                                         padding_n_edge=hparams['PADDING_N_EDGE'],
                                         from_disk=hparams['PYTABLE_FROM_DISK'])
        if size_cut:
            big_collate = AminoCollatePrecompute(h5_table,
                                                 padding_n_node=hparams['BIG_PADDING_N_NODE'],
                                                 padding_n_edge=hparams['BIG_PADDING_N_EDGE'],
                                                 n_partitions=hparams['N_PARTITIONS'],
                                                 from_disk=hparams['PYTABLE_FROM_DISK'],
                                                 line_graph=hparams['LINE_GRAPH'])
            big_element = AminoElementPrecompute(bert_table=h5_table,
                                                 padding_n_node=hparams['BIG_PADDING_N_NODE'],
                                                 padding_n_edge=hparams['BIG_PADDING_N_EDGE'],
                                                 from_disk=hparams['PYTABLE_FROM_DISK'])

        if not hparams['PYTABLE_FROM_DISK']:
            # Presumably the precompute objects hold the data in memory now
            # (from_disk=False); the table is not read past this point.
            h5file.close()
            print('Table closed...')

        valid_dataset = _make_valid_dataset(hparams, data_valid_csv, model, collate.padding_n_node)
        if size_cut:
            _big_valid_dataset = _make_valid_dataset(hparams, big_data_valid_csv, model,
                                                     big_collate.padding_n_node)
            # The size-cut evaluation covers the regular AND the big split together.
            big_valid_dataset = valid_dataset + _big_valid_dataset

        valid_loader = _make_valid_loader(hparams, valid_dataset, collate, element, hparams['BATCH_SIZE'])
        if size_cut:
            big_valid_loader = _make_valid_loader(hparams, big_valid_dataset, big_collate, big_element,
                                                  hparams['BIG_BATCH_SIZE'])

        # ----------------
        # Initializations:
        # ----------------
        key1, _ = jax.random.split(jax.random.PRNGKey(int(time.time())), 2)
        key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(key1, 4)

        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                     batch_size=hparams['BATCH_SIZE'],
                                     seq_embedding_size=hparams['SEQ_EMBEDDING_SIZE'],
                                     num_node_features=len(hparams['ATOM_FEATURES']),
                                     num_edge_features=len(hparams['BOND_FEATURES']),
                                     self_loops=hparams['SELF_LOOPS'],
                                     line_graph=hparams['LINE_GRAPH'],
                                     seq_max_length=hparams['SEQ_MAX_LENGTH'],
                                     padding_n_node=hparams['PADDING_N_NODE'],
                                     padding_n_edge=hparams['PADDING_N_EDGE'])
        params = init_model(rngs={'params': key_params, 'dropout': key_dropout, 'num_steps': _key_num_steps})
        logger.info('TIME: init_model: {}'.format(time.time() - start))

        # The optimizer state serves as the template that a checkpoint is restored into.
        transition_epochs = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']
        transition_steps = transition_epochs * (len(valid_dataset) / hparams['BATCH_SIZE'])
        if size_cut:
            transition_steps += transition_epochs * (len(_big_valid_dataset) / hparams['BIG_BATCH_SIZE'])
        create_optimizer = make_create_optimizer(model,
                                                 option=hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps=hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps=transition_steps)
        init_state, _ = create_optimizer(params,
                                         rngs={'dropout': key_dropout, 'num_steps': key_num_steps},
                                         learning_rate=hparams['LEARNING_RATE'])

        state = _restore_state(init_state, restore_file, logger)

        # NOTE(review): `make_valid_epoch` is not imported in the visible code (its import
        # is commented out at the top of the file); confirm it is in scope, e.g. via the
        # star import from ProtLig_GPCRclassA.metrics.
        valid_epoch = make_valid_epoch(is_weighted=hparams['VALID_WEIGHT_COL'] is not None,
                                       loss_option=hparams['LOSS_OPTION'],
                                       logger=logger,
                                       aux_loss_option=hparams['AUXILIARY_LOSS_OPTION'],
                                       loader_output_type=hparams['LOADER_OUTPUT_TYPE'],
                                       num_classes=model.out_features)

        # ------
        # VALID:
        # ------
        eval_dataset = big_valid_dataset if size_cut else valid_dataset
        eval_loader = big_valid_loader if size_cut else valid_loader
        start = time.time()
        valid_metrics = valid_epoch(state, eval_loader)
        logger.info('TIME: valid_epoch: {}'.format(time.time() - start))
        valid_metrics_np = jax.device_get(valid_metrics)
    finally:
        # Covers PYTABLE_FROM_DISK=True (table needed until evaluation ends) and errors.
        if h5file.isopen:
            h5file.close()

    summary = log_metrics_from_epoch(name='eval',
                                     epoch=0,
                                     hparams=hparams,
                                     metrics_np=valid_metrics_np,
                                     logger=logger,
                                     summary=Summary())

    if size_cut:
        data_csv = os.path.join(valid_dir, hparams['VALID_CSV_NAME'] + '__' + hparams['BIG_VALID_CSV_NAME'])
    else:
        data_csv = data_valid_csv
    result = {'DATA_CSV': data_csv,
              'RESTORE_FILE': restore_file,
              'H5FILE': hparams['H5FILE'],
              'NUM_TEST_DATA': len(eval_dataset)}
    result = _add_scalar_summaries(result, summary)

    print(result)
    logger.info('Finished...')
    return result