import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary

import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# from ProtLig_GPCRclassA.amino_GNN.loader import AminoDatasetPrecompute, AminoLoader, AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.loader_NEW import AminoLoader, get_tf_loader, get_tf_loader_masked

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
# from ProtLig_GPCRclassA.amino_GNN.base.train.make_train_epoch import make_train_epoch
from ProtLig_GPCRclassA.amino_GNN.base.train.make_valid_masked_epoch import make_valid_masked_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


# SEQ_MODEL_NAME values that are tokenized with transformers.EsmTokenizer.
_ESM_SEQ_MODEL_NAMES = ('esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D')


def _load_tokenizer(hparams):
    """Return the tokenizer selected by ``hparams['SEQ_MODEL_NAME']``.

    Raises:
        ValueError: if the model name is neither one of the ESM2 names nor
            'ProtBERT'.
    """
    seq_model_name = hparams['SEQ_MODEL_NAME']
    if seq_model_name in _ESM_SEQ_MODEL_NAMES:
        # Imported lazily so transformers is only needed when actually used.
        from transformers import EsmTokenizer
        return EsmTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                            cache_dir=hparams['HUGGINGFACE_CACHE_DIR'])
    if seq_model_name == 'ProtBERT':
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                             do_lower_case=False,
                                             cache_dir=hparams['HUGGINGFACE_CACHE_DIR'])
    raise ValueError(
        'Unsupported SEQ_MODEL_NAME: {!r} (expected one of {} or \'ProtBERT\')'.format(
            seq_model_name, list(_ESM_SEQ_MODEL_NAMES)))


def _detach_log_handler(logger, handler):
    """Remove ``handler`` from ``logger`` and release its resources."""
    logger.removeHandler(handler)
    handler.close()


def _close_h5file_if_open(h5file):
    """Close the PyTables file unless it has already been closed."""
    if h5file.isopen:
        h5file.close()


def _ordered_checkpoint_filenames(ckpts_dir, logger):
    """List checkpoint file names in ``ckpts_dir`` sorted by ascending epoch.

    The epoch is what remains of the file name after stripping 'state_e' and
    '.pkl'. Names whose remainder is not an integer are skipped with a warning.
    """
    epoch_and_name = []
    for filename in os.listdir(ckpts_dir):
        try:
            epoch = int(filename.replace('state_e', '').replace('.pkl', ''))
        except ValueError:
            logger.warning('Skipping non-checkpoint file in {}: {}'.format(ckpts_dir, filename))
            continue
        epoch_and_name.append((epoch, filename))
    epoch_and_name.sort()
    return [filename for _, filename in epoch_and_name]


def main_eval_masked_ckpts(hparams):
    """Run the masked validation epoch on every checkpoint of a finished run.

    For each file found in ``<RESTORE_MODEL_DIR>/ckpts`` (in ascending epoch
    order) the state is restored and ``valid_epoch`` is called first on the
    loader built from TRAIN_CSV_NAME and then on the one built from
    VALID_CSV_NAME. Metrics are written with ``SummaryWriter`` under
    ``<RESTORE_MODEL_DIR>/eval_ckpts/<YYYYmmdd-HHMMSS>/{train,validation}``;
    the same directory receives ``run_eval_masked_ckpts.log`` and
    ``hparams_logs.json``.

    Args:
        hparams: mapping of configuration values. Keys read here include the
            data locations (DATA_PARENT_DIR, DATACASE, TRAIN_CSV_NAME,
            VALID_CSV_NAME, MOLS_CSV, SEQS_CSV, H5FILE, H5FILE_TITLE), the
            tokenizer/model selection (SEQ_MODEL_NAME,
            SEQ_MODEL_TOKENIZER_PATH, HUGGINGFACE_CACHE_DIR, MODEL_NAME),
            RESTORE_MODEL_DIR, and the dataset/loader/optimizer settings
            passed through to the project helpers below.

    Returns:
        None.

    Raises:
        ValueError: if SEQ_MODEL_NAME does not select a known tokenizer.
        NotImplementedError: if LOADER_OUTPUT_TYPE is not 'tf' or
            N_PARTITIONS is greater than 0.
    """
    from contextlib import ExitStack

    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    tokenizer = _load_tokenizer(hparams)

    restore_ckpts_dir = os.path.join(hparams['RESTORE_MODEL_DIR'], 'ckpts')

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features=hparams['ATOM_FEATURES'],
                        bond_features=hparams['BOND_FEATURES'],
                        out_features=hparams['OUT_FEATURES'],
                        seq_d_model=hparams['SEQ_EMBEDDING_SIZE'],
                        vocab_size=len(tokenizer))

    logger = logging.getLogger('main_eval_masked_ckpts')
    logger.setLevel(logging.INFO)
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(hparams['RESTORE_MODEL_DIR'], 'eval_ckpts', _datetime)
    os.makedirs(logdir)

    # Everything registered on `cleanup` is released in reverse order when the
    # block exits, whether normally or through an exception.
    with ExitStack() as cleanup:
        log_handlers = (
            logging.FileHandler(os.path.join(logdir, 'run_eval_masked_ckpts.log')),
            logging.StreamHandler(sys.stdout),
        )
        for handler in log_handlers:
            logger.addHandler(handler)
            # Detach on exit so repeated calls do not duplicate log lines.
            cleanup.callback(_detach_log_handler, logger, handler)

        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        # ---------
        # Datasets:
        # ---------
        import tables
        h5file = tables.open_file(hparams['H5FILE'], mode='r', title=hparams['H5FILE_TITLE'])
        cleanup.callback(_close_h5file_if_open, h5file)
        h5_table = h5file.root.amino.table

        data_train_csv = os.path.join(datadir, hparams['TRAIN_CSV_NAME'])
        data_valid_csv = os.path.join(datadir, hparams['VALID_CSV_NAME'])
        mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV'])
        seqs_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['SEQS_CSV'])

        element = AminoElementPrecomputeMasked(tokenizer=tokenizer,
                                               seq_max_length=hparams['SEQ_MAX_LENGTH'],
                                               bert_table=h5_table,
                                               padding_n_node=hparams['PADDING_N_NODE'],
                                               padding_n_edge=hparams['PADDING_N_EDGE'],
                                               from_disk=hparams['PYTABLE_FROM_DISK'],
                                               seq_lookup=hparams['CACHE_SEQ_LOOKUP'],
                                               seq_col=hparams['SEQ_COL'])

        # Arguments common to the train and validation datasets.
        shared_dataset_kwargs = dict(
            mols_csv=mols_csv,
            seqs_csv=seqs_csv,
            mol_id_col=hparams['MOL_ID_COL'],
            mol_col=hparams['MOL_COL'],
            seq_id_col=hparams['SEQ_ID_COL'],
            label_col=hparams['LABEL_COL'],
            atom_features=model.atom_features,
            bond_features=model.bond_features,
            line_graph_max_size=hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * hparams['PADDING_N_NODE'],
            self_loops=hparams['SELF_LOOPS'],
            line_graph=hparams['LINE_GRAPH'],
            auxiliary_label_cols=hparams['AUXILIARY_LABEL_COLS'],
            auxiliary_weight_cols=hparams['AUXILIARY_WEIGHT_COLS'],
            mol_global_cols=hparams['MOL_GLOBAL_COLS'],
            seq_global_cols=hparams['SEQ_GLOBAL_COLS'],
        )
        # Only the train dataset receives class_alpha; each split has its own
        # weight column.
        dataset = AminoDatasetPrecompute(data_csv=data_train_csv,
                                         weight_col=hparams['WEIGHT_COL'],
                                         class_alpha=hparams['CLASS_ALPHA'],
                                         **shared_dataset_kwargs)
        valid_dataset = AminoDatasetPrecompute(data_csv=data_valid_csv,
                                               weight_col=hparams['VALID_WEIGHT_COL'],
                                               **shared_dataset_kwargs)

        if hparams['CACHE_SEQ_LOOKUP']:
            id_mapping_table, seq_embedding_lookup = element.create_seq_embedding_lookups()
        else:
            id_mapping_table, seq_embedding_lookup = None, None
        logger.info('Train dataset size: {}'.format(len(dataset)))
        logger.info('Valid dataset size: {}'.format(len(valid_dataset)))

        if hparams['LOADER_OUTPUT_TYPE'] != 'tf':
            raise NotImplementedError('jax loader is not supported.')

        dataset.element_preprocess = element.make_element_preprocess()
        valid_dataset.element_preprocess = element.make_element_preprocess()

        # Both loaders are built unshuffled and keep the last partial batch.
        shared_loader_kwargs = dict(
            batch_size=hparams['BATCH_SIZE'],
            use_cache=hparams['CACHE'],
            shuffle=False,
            shuffle_buffer_size=hparams['SHUFFLE_BUFFER_SIZE'],
            drop_last=False,
            id_mapping_table=id_mapping_table,
            seq_embedding_lookup=seq_embedding_lookup,
            n_partitions=hparams['N_PARTITIONS'],
        )
        loader = get_tf_loader_masked(dataset, **shared_loader_kwargs)
        valid_loader = get_tf_loader_masked(valid_dataset, **shared_loader_kwargs)

        if not hparams['PYTABLE_FROM_DISK']:
            # Table is not read from disk during iteration, so release it now.
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        # ----------------
        # Initializations:
        # ----------------
        # NOTE(review): the PRNG seed is the wall-clock time, so two runs use
        # different keys.
        prng_key = jax.random.PRNGKey(int(time.time()))
        key_params, _key_num_steps, key_num_steps, key_dropout, seq_mask_key = jax.random.split(prng_key, 5)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                     batch_size=hparams['BATCH_SIZE'],
                                     seq_embedding_size=hparams['SEQ_EMBEDDING_SIZE'],
                                     num_node_features=len(hparams['ATOM_FEATURES']),
                                     num_edge_features=len(hparams['BOND_FEATURES']),
                                     self_loops=hparams['SELF_LOOPS'],
                                     line_graph=hparams['LINE_GRAPH'],
                                     seq_max_length=hparams['SEQ_MAX_LENGTH'],
                                     padding_n_node=hparams['PADDING_N_NODE'],
                                     padding_n_edge=hparams['PADDING_N_EDGE'])
        params = init_model(rngs={'params': key_params, 'dropout': key_dropout, 'num_steps': _key_num_steps})
        end = time.time()
        logger.info('TIME: init_model: {}'.format(end - start))

        transition_steps = hparams['OPTIMIZATION']['TRANSITION_EPOCHS'] * (len(dataset) / hparams['BATCH_SIZE'])

        create_optimizer = make_create_optimizer(model,
                                                 option=hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps=hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps=transition_steps)
        # init_state is the template that each checkpoint is deserialized into.
        init_state, scheduler = create_optimizer(params,
                                                 rngs={'dropout': key_dropout,
                                                       'num_steps': key_num_steps,
                                                       'seq_mask': seq_mask_key},
                                                 learning_rate=hparams['LEARNING_RATE'])

        reg_loss_func_embed = make_regularization_loss(
            params_path=['params/atomic_num_embed/node_embed/embedding',
                         'params/chiral_tag_embed/node_embed/embedding',
                         'params/hybridization_embed/node_embed/embedding',
                         'params/X_proj_non_embeded/kernel',
                         'params/X_proj_non_embeded/bias',
                         'params/bond_type_embed/edge_embed/embedding',
                         'params/E_proj_non_embeded/kernel',
                         'params/E_proj_non_embeded/bias',
                         'params/stereo_embed/edge_embed/embedding',
                         ],
            alpha=0.01, option='l1')
        reg_loss_func_kernel = make_regularization_loss(params_path=['kernel'],
                                                        alpha=0.01, option='l2')

        def reg_loss_func(params):
            """Sum of the L1 (embedding/projection) and L2 (kernel) terms."""
            return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

        if hparams['N_PARTITIONS'] > 0:
            raise NotImplementedError('pmap needs to be checked...')

        valid_epoch = make_valid_masked_epoch(loss_option=hparams['LOSS_OPTION'],
                                              logger=logger,
                                              aux_loss_option=hparams['AUXILIARY_LOSS_OPTION'],
                                              loader_output_type=hparams['LOADER_OUTPUT_TYPE'],
                                              num_classes=model.out_features,
                                              mask_token_id=tokenizer.mask_token_id,
                                              cls_token_id=tokenizer.cls_token_id,
                                              pad_token_id=tokenizer.pad_token_id,
                                              sep_token_id=tokenizer.sep_token_id,
                                              unk_token_id=tokenizer.unk_token_id,
                                              eos_token_id=tokenizer.eos_token_id,
                                              bos_token_id=tokenizer.bos_token_id)

        # Log hyperparams:
        hparams_logs = {key: _serialize_hparam(hparams[key]) for key in hparams.keys()}
        with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
            json.dump(hparams_logs, jsonfile)

        # NOTE(review): the message says 'Training...', but the loop below only
        # calls valid_epoch on both loaders.
        logger.info('Training...')
        train_writer = SummaryWriter(os.path.join(logdir, 'train'))
        cleanup.callback(train_writer.close)
        valid_writer = SummaryWriter(os.path.join(logdir, 'validation'))
        cleanup.callback(valid_writer.close)

        for filename in _ordered_checkpoint_filenames(restore_ckpts_dir, logger):
            restore_file = os.path.join(restore_ckpts_dir, filename)

            # Restore params:
            logger.info('Restoring parameters from {}'.format(restore_file))
            # NOTE: pickle.load can execute arbitrary code; only point
            # RESTORE_MODEL_DIR at checkpoints from a trusted run.
            with open(restore_file, 'rb') as pklfile:
                bytes_output = pickle.load(pklfile)
            state = serialization.from_bytes(init_state, bytes_output)
            logger.info('Parameters restored...')

            epoch = state.epoch
            logger.info('Epoch:  {}'.format(epoch))

            # TRAIN split, evaluated with the validation epoch function:
            start = time.time()
            state, train_metrics = valid_epoch(state, loader)
            end = time.time()
            logger.info('TIME: train_epoch: {}'.format(end - start))
            train_metrics_np = jax.device_get(train_metrics)
            lr = scheduler(state.step)

            summary = Summary()
            summary.scalar('learning_rate', lr)
            summary = log_metrics_from_epoch(name='train',
                                             epoch=epoch,
                                             hparams=hparams,
                                             metrics_np=train_metrics_np,
                                             logger=logger,
                                             summary=summary)
            summary.scalar('regularization_loss', reg_loss_func(state.params))

            train_writer.write(summary, step=int(epoch))
            train_writer.writer.flush()

            # VALID:
            start = time.time()
            state, valid_metrics = valid_epoch(state, valid_loader)
            end = time.time()
            logger.info('TIME: valid_epoch: {}'.format(end - start))
            valid_metrics_np = jax.device_get(valid_metrics)

            summary = Summary()
            summary = log_metrics_from_epoch(name='valid',
                                             epoch=epoch,
                                             hparams=hparams,
                                             metrics_np=valid_metrics_np,
                                             logger=logger,
                                             summary=summary)

            valid_writer.write(summary, step=int(epoch))
            valid_writer.writer.flush()

        logger.info('Finished...')
    return None