import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import copy
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# from ProtLig_GPCRclassA.amino_GNN.loader import AminoDatasetPrecompute, AminoLoader, AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
# from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader, get_tf_loader_masked

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
from ProtLig_GPCRclassA.amino_GNN.base.train.make_train_masked_epoch import make_train_masked_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


def _load_tokenizer(hparams):
    """Return the HuggingFace tokenizer selected by ``hparams['SEQ_MODEL_NAME']``.

    Reads ``SEQ_MODEL_NAME``, ``SEQ_MODEL_TOKENIZER_PATH`` and
    ``HUGGINGFACE_CACHE_DIR`` from ``hparams``.

    Raises
    ------
    ValueError
        If ``SEQ_MODEL_NAME`` is not one of the supported names. Failing here
        gives a clear message instead of an unbound-variable error later on.
    """
    seq_model_name = hparams['SEQ_MODEL_NAME']
    esm_model_names = ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']

    if seq_model_name in esm_model_names:
        # Imported lazily so that `transformers` is only needed when used.
        from transformers import EsmTokenizer
        return EsmTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                            cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    if seq_model_name == 'ProtBERT':
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                             do_lower_case=False,
                                             cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])

    raise ValueError('Unsupported SEQ_MODEL_NAME: {!r}. Supported values: {}'.format(
        seq_model_name, esm_model_names + ['ProtBERT']))


def _build_train_loaders(hparams, datadir, tokenizer, model, h5_table, logger):
    """Build one element/dataset/loader triple per entry of ``hparams['N_EPOCH']``.

    The dataset of section ``i`` is cumulative: section 0 is a deep copy of
    its own CSV-backed dataset and every later section is the previous
    cumulative dataset ``+`` the new partial dataset.

    Returns
    -------
    tuple
        ``(elements, datasets, loaders)`` - three lists of equal length.

    Raises
    ------
    NotImplementedError
        If ``hparams['LOADER_OUTPUT_TYPE']`` is not ``'tf'``.
    """
    elements = []
    datasets = []
    loaders = []
    for i in range(len(hparams['N_EPOCH'])):
        data_train_csv = os.path.join(datadir, hparams['TRAIN_CSV_NAME'][i])
        mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV'][i])
        seqs_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['SEQS_CSV'][i])

        _element = AminoElementPrecomputeMasked(tokenizer = tokenizer,
                                     seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                     bert_table = h5_table,
                                     padding_n_node = hparams['PADDING_N_NODE'][i],
                                     padding_n_edge = hparams['PADDING_N_EDGE'][i],
                                     from_disk = hparams['PYTABLE_FROM_DISK'],
                                     seq_lookup = hparams['CACHE_SEQ_LOOKUP'],
                                     seq_col = hparams['SEQ_COL'])

        _partial_dataset = AminoDatasetPrecompute(data_csv = data_train_csv,
                            mols_csv = mols_csv,
                            seqs_csv = seqs_csv,
                            mol_id_col = hparams['MOL_ID_COL'],
                            mol_col = hparams['MOL_COL'],
                            seq_id_col = hparams['SEQ_ID_COL'],
                            label_col = hparams['LABEL_COL'],
                            weight_col = hparams['WEIGHT_COL'],
                            atom_features = model.atom_features,
                            bond_features = model.bond_features,
                            class_alpha = hparams['CLASS_ALPHA'],
                            line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * hparams['PADDING_N_NODE'][i],
                            self_loops = hparams['SELF_LOOPS'],
                            line_graph = hparams['LINE_GRAPH'],
                            auxiliary_label_cols = hparams['AUXILIARY_LABEL_COLS'],
                            auxiliary_weight_cols = hparams['AUXILIARY_WEIGHT_COLS'],
                            mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                            seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                            )
        if i == 0:
            _dataset = copy.deepcopy(_partial_dataset)
            # The sequence lookups are created once, from the first element,
            # and then shared by the loaders of all sections.
            if hparams['CACHE_SEQ_LOOKUP']:
                id_mapping_table, seq_embedding_lookup = _element.create_seq_embedding_lookups()
            else:
                id_mapping_table, seq_embedding_lookup = None, None
        else:
            _dataset = _dataset + _partial_dataset
            logger.info('Adding {} records for train dataset {}'.format(len(_partial_dataset), i))
        logger.info('Train dataset {} size: {}'.format(i, len(_dataset)))

        if hparams['LOADER_OUTPUT_TYPE'] != 'tf':
            raise NotImplementedError('jax loader is not supported for pmap.')

        _dataset.element_preprocess = _element.make_element_preprocess()
        # Both loader factories are called with the same keyword arguments;
        # only the sampling strategy differs.
        if hparams['SAMPLE_BY_LABEL_DISTRIBUTION']:
            from ProtLig_GPCRclassA.amino_GNN.loader import get_tf_loader_masked_sample_by_label_distribution
            make_loader = get_tf_loader_masked_sample_by_label_distribution
        else:
            make_loader = get_tf_loader_masked
        _loader = make_loader(_dataset,
                              batch_size = hparams['BATCH_SIZE'][i],
                              use_cache = hparams['CACHE'],
                              shuffle = True,
                              shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                              drop_last = True,
                              id_mapping_table = id_mapping_table,
                              seq_embedding_lookup = seq_embedding_lookup,
                              n_partitions = hparams['N_PARTITIONS'])

        elements.append(_element)
        datasets.append(_dataset)
        loaders.append(_loader)

    return elements, datasets, loaders


def _run_training(hparams, model, tokenizer, datasets, loaders, logdir, restore_file, logger):
    """Initialise (or restore) the train state and run the epoch loop.

    Writes ``hparams_logs.json``, TensorBoard summaries under ``train/`` and
    periodic checkpoints under ``ckpts/`` inside ``logdir``.

    Raises
    ------
    NotImplementedError
        If ``hparams['N_PARTITIONS'] > 0``.
    """
    # ----------------
    # Initializations:
    # ----------------
    prng_key = jax.random.PRNGKey(int(time.time()))
    key_params, _key_num_steps, key_num_steps, key_dropout, seq_mask_key = jax.random.split(prng_key, 5)

    start = time.time()
    logger.info('Initializing...')
    init_model = make_init_model(model,
                                batch_size = hparams['BATCH_SIZE'][0],
                                n_partitions = hparams['N_PARTITIONS'],
                                seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                num_node_features = len(hparams['ATOM_FEATURES']),
                                num_edge_features = len(hparams['BOND_FEATURES']),
                                self_loops = hparams['SELF_LOOPS'],
                                line_graph = hparams['LINE_GRAPH'],
                                seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                padding_n_node = hparams['PADDING_N_NODE'][0],
                                padding_n_edge = hparams['PADDING_N_EDGE'][0])
    params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
    end = time.time()
    logger.info('TIME: init_model: {}'.format(end - start))

    # Schedule length is derived from the last (largest) dataset and its batch size.
    transition_steps = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']*(len(datasets[-1])/hparams['BATCH_SIZE'][-1])

    create_optimizer = make_create_optimizer(model,
                                            option = hparams['OPTIMIZATION']['OPTION'],
                                            warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                            transition_steps = transition_steps)
    init_state, scheduler = create_optimizer(params,
                                            rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps, 'seq_mask' : seq_mask_key},
                                            learning_rate = hparams['LEARNING_RATE'])

    # Restore params:
    if restore_file is not None:
        logger.info('Restoring parameters from {}'.format(restore_file))
        # SECURITY NOTE: pickle.load executes arbitrary code from the file;
        # only restore from checkpoint files you trust.
        with open(restore_file, 'rb') as pklfile:
            bytes_output = pickle.load(pklfile)
        state = serialization.from_bytes(init_state, bytes_output)
        logger.info('Parameters restored...')
    else:
        state = init_state

    reg_loss_func_embed = make_regularization_loss(params_path = ['params/atomic_num_embed/node_embed/embedding',
                                                            'params/chiral_tag_embed/node_embed/embedding',
                                                            'params/hybridization_embed/node_embed/embedding',
                                                            'params/X_proj_non_embeded/kernel',
                                                            'params/X_proj_non_embeded/bias',
                                                            'params/bond_type_embed/edge_embed/embedding',
                                                            'params/E_proj_non_embeded/kernel',
                                                            'params/E_proj_non_embeded/bias',
                                                            'params/stereo_embed/edge_embed/embedding',
                                                            ], alpha = 0.01, option = 'l1')
    reg_loss_func_kernel = make_regularization_loss(params_path = ['kernel',
                                                            ], alpha = 0.01, option = 'l2')
    def reg_loss_func(params):
        # Total regularization = L1 term on the listed paths + L2 term on 'kernel' paths.
        return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('N_PARTITIONS > 0 for single-GPU training.')

    train_epoch = make_train_masked_epoch(is_weighted = hparams['WEIGHT_COL'] is not None,
                                    loss_option = hparams['LOSS_OPTION'],
                                    init_rngs = state.rngs,
                                    logger = logger,
                                    aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                    reg_loss_func = reg_loss_func,
                                    loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                    num_classes = model.out_features,
                                    mask_token_id = tokenizer.mask_token_id,
                                    cls_token_id = tokenizer.cls_token_id,
                                    pad_token_id = tokenizer.pad_token_id,
                                    sep_token_id = tokenizer.sep_token_id,
                                    unk_token_id = tokenizer.unk_token_id,
                                    eos_token_id = tokenizer.eos_token_id,
                                    bos_token_id = tokenizer.bos_token_id,
                                    )

    # Log hyperparams:
    hparams_logs = {key: _serialize_hparam(hparams[key]) for key in hparams.keys()}
    with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
        json.dump(hparams_logs, jsonfile)

    # Training:
    logger.info('Training...')
    train_writer = SummaryWriter(os.path.join(logdir, 'train'))

    n_epoch_sections = hparams['N_EPOCH']
    # Default to the last loader. The selection loop below finds no section
    # when epoch == N_EPOCH[-1]; without this default, resuming from a state
    # saved at exactly that epoch would hit an unbound variable.
    loader_switch_i = len(n_epoch_sections) - 1

    try:
        epoch = state.epoch
        # Training loop:
        while epoch <= n_epoch_sections[-1]:
            # Pick the first section whose upper bound is still ahead.
            # NOTE: This assumes that N_EPOCH are sorted.
            for i in range(len(n_epoch_sections)):
                if epoch < n_epoch_sections[i]:
                    loader_switch_i = i
                    break
            logger.info('Epoch:  {}'.format(epoch))

            start = time.time()
            state, batch_losses = train_epoch(state, loaders[loader_switch_i])
            end = time.time()

            logger.info('TIME: train_epoch: {}'.format(end - start))
            loss = jnp.mean(jnp.asarray(batch_losses)) # NOTE: If N_i = N_j => mean of means equals total mean.
            lr = scheduler(state.step)
            logger.info('TIME: loss: {}'.format(loss))

            summary = Summary()
            summary.scalar('epoch_loss', loss)
            summary.scalar('learning_rate', lr)
            summary.scalar('regularization_loss', reg_loss_func(state.params))

            train_writer.write(summary, step = int(epoch))
            train_writer.writer.flush()

            # Save current state:
            if epoch%hparams['SAVE_FREQUENCY'] == 0:
                bytes_output = serialization.to_bytes(state)
                with open(os.path.join(logdir, 'ckpts', 'state_e' + str(epoch) + '.pkl'), 'wb') as pklfile:
                    pickle.dump(bytes_output, pklfile)
                logger.info('State {} saved...'.format('state_e' + str(epoch)))

            # Update epoch number.
            # jax.tree_util.tree_map is used because the jax.tree_map alias is
            # deprecated and no longer available in recent JAX releases.
            state = state.replace(epoch = jax.tree_util.tree_map(lambda x: x + 1, state.epoch))

            # Switch loader at the end of each N_EPOCH section:
            # NOTE(review): `epoch` still holds the pre-increment value here and
            # the index is re-selected at the top of every iteration, so this
            # branch appears to fire only after the final epoch - confirm intent.
            if epoch == n_epoch_sections[loader_switch_i]: # NOTE: loader_switch_i is chosen such that it is the index of the lowest UNPROCESSED section.
                loader_switch_i += 1
                logger.info('Switching to loader {} ...'.format(loader_switch_i))

            # Get epoch for while condition:
            epoch = state.epoch
    finally:
        # Release the summary writer even if an epoch raised.
        train_writer.close()


def main_train_masked(hparams):
    """Train a masked protein-ligand model as configured by ``hparams``.

    Steps, in order: load the tokenizer, build the model, create a
    timestamped log directory (with a ``ckpts`` sub-directory and a
    ``run.log`` file), open the HDF5 table, build the per-section loaders,
    then initialise/restore the state and run the training loop.

    Parameters
    ----------
    hparams : mapping
        Hyperparameter mapping indexed by upper-case string keys (e.g.
        ``'SEQ_MODEL_NAME'``, ``'N_EPOCH'``, ``'BATCH_SIZE'``, ``'H5FILE'``).
        ``N_EPOCH`` and the other per-section entries are indexed per
        training section.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``hparams['SEQ_MODEL_NAME']`` is not a supported model name.
    NotImplementedError
        If ``LOADER_OUTPUT_TYPE`` is not ``'tf'`` or ``N_PARTITIONS > 0``.
    """
    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])

    # Fail fast on an unsupported sequence model, before any directory or
    # file is created.
    tokenizer = _load_tokenizer(hparams)

    logdir = os.path.join(hparams['LOGGING_PARENT_DIR'], hparams['DATACASE'])
    restore_file = hparams['RESTORE_FILE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'],
                        seq_d_model = hparams['SEQ_EMBEDDING_SIZE'],
                        vocab_size = len(tokenizer))

    # Log directory layout: <parent>/<datacase>/<ModelClass>[/<slurm job id>]/<timestamp>
    logdir = os.path.join(logdir, model.__class__.__name__)
    if hparams['SLURM_JOB_ARRAY']:
        logdir = os.path.join(logdir, hparams['SLURM_ARRAY_JOB_ID'])
    logger = logging.getLogger('main_train_masked')
    logger.setLevel(logging.INFO)
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(logdir, _datetime)
    os.makedirs(logdir)
    os.mkdir(os.path.join(logdir, 'ckpts'))
    logger.addHandler(logging.FileHandler(os.path.join(logdir, 'run.log')))
    logger.addHandler(logging.StreamHandler(sys.stdout))

    logger.info('jax_version = {}'.format(jax.__version__))
    logger.info('flax_version = {}'.format(flax.__version__))
    logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
    logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
    logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

    # ---------
    # Datasets:
    # ---------
    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
    try:
        h5_table = h5file.root.amino.table

        # NOTE(review): `_elements` is not used after loader construction; it is
        # kept referenced in case the loaders rely on the elements staying alive.
        _elements, datasets, loaders = _build_train_loaders(
            hparams, datadir, tokenizer, model, h5_table, logger)

        # When the table is not read from disk during training, the file can
        # be released as soon as the loaders exist.
        if not hparams['PYTABLE_FROM_DISK']:
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        _run_training(hparams, model, tokenizer, datasets, loaders, logdir, restore_file, logger)
    finally:
        # Covers the PYTABLE_FROM_DISK case and any failure above, so the
        # HDF5 handle is never leaked.
        if h5file.isopen:
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

    logger.info('Finished...')
    return None