import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# from ProtLig_GPCRclassA.amino_GNN.loader import AminoDatasetPrecompute, AminoLoader, AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
# from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
from ProtLig_GPCRclassA.amino_GNN.base.train.make_train_pmap_epoch import make_train_pmap_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


def _make_element(h5_table, hparams, padding_n_node, padding_n_edge):
    """Build an ``AminoElementPrecompute`` over ``h5_table`` with the given padding sizes."""
    return AminoElementPrecompute(bert_table=h5_table,
                                  padding_n_node=padding_n_node,
                                  padding_n_edge=padding_n_edge,
                                  from_disk=hparams['PYTABLE_FROM_DISK'],
                                  seq_lookup=hparams['CACHE_SEQ_LOOKUP'])


def _make_dataset(hparams, model, data_csv, padding_n_node):
    """Build an ``AminoDatasetPrecompute`` for ``data_csv``.

    Every setting except the CSV path and the node padding (which bounds the
    line-graph size) is shared between the small and the big training datasets.
    """
    return AminoDatasetPrecompute(
        data_csv=data_csv,
        mols_csv=os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV']),
        mol_id_col=hparams['MOL_ID_COL'],
        mol_col=hparams['MOL_COL'],
        seq_id_col=hparams['SEQ_ID_COL'],
        label_col=hparams['LABEL_COL'],
        weight_col=hparams['WEIGHT_COL'],
        atom_features=model.atom_features,
        bond_features=model.bond_features,
        class_alpha=hparams['CLASS_ALPHA'],
        line_graph_max_size=hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * padding_n_node,
        self_loops=hparams['SELF_LOOPS'],
        line_graph=hparams['LINE_GRAPH'],
        auxiliary_label_cols=hparams['AUXILIARY_LABEL_COLS'],
        auxiliary_weight_cols=hparams['AUXILIARY_WEIGHT_COLS'],
        mol_global_cols=hparams['MOL_GLOBAL_COLS'],
        seq_global_cols=hparams['SEQ_GLOBAL_COLS'],
    )


def _make_loader(dataset, hparams, batch_size, id_mapping_table, seq_embedding_lookup):
    """Wrap ``dataset`` in a shuffled, drop-last tf loader yielding ``batch_size`` batches."""
    return get_tf_loader(dataset,
                         batch_size=batch_size,
                         use_cache=hparams['CACHE'],
                         shuffle=True,
                         shuffle_buffer_size=hparams['SHUFFLE_BUFFER_SIZE'],
                         drop_last=True,
                         id_mapping_table=id_mapping_table,
                         seq_embedding_lookup=seq_embedding_lookup,
                         n_partitions=hparams['N_PARTITIONS'])


def _setup_logger(logdir):
    """Return the ``main_train_pmap`` logger and the handlers attached to it.

    The handlers write to ``<logdir>/run.log`` and to stdout. The caller must
    remove and close them when done; otherwise repeated calls stack duplicate
    handlers on the same named logger.
    """
    logger = logging.getLogger('main_train_pmap')
    logger.setLevel(logging.INFO)
    handlers = [logging.FileHandler(os.path.join(logdir, 'run.log')),
                logging.StreamHandler(sys.stdout)]
    for handler in handlers:
        logger.addHandler(handler)
    return logger, handlers


def main_train_pmap(hparams):
    """Train the model named by ``hparams['MODEL_NAME']`` with ``jax.pmap`` data parallelism.

    Builds the training dataset and tf loader, plus a second "big" dataset/loader
    (small + big samples) used from epoch ``BIG_SWITCH_EPOCH`` on when
    ``SIZE_CUT_DIRNAME`` is set. It initialises (or restores from
    ``RESTORE_FILE``) the optimizer state and replicates it across devices. It
    then runs ``train_epoch`` until ``N_EPOCH``, logging the mean batch loss,
    learning rate and regularization loss every epoch, and checkpointing every
    ``SAVE_FREQUENCY`` epochs.

    Outputs go to
    ``<LOGGING_PARENT_DIR>/<DATACASE>/<model class>/[<SLURM_ARRAY_JOB_ID>/]<timestamp>/``:
    ``run.log``, ``hparams_logs.json``, ``train_pmap/`` summaries and
    ``ckpts/state_e<epoch>.pkl``.

    Args:
        hparams: Mapping of hyper-parameters. It is shallow-copied and never
            mutated; ``hparams_logs.json`` records the adjusted copy (self-loop
            edge padding, ``BIG_SWITCH_EPOCH``).

    Returns:
        None.

    Raises:
        ValueError: if ``SELF_LOOPS`` is combined with non-empty ``BOND_FEATURES``,
            or if ``N_PARTITIONS`` is not positive.
        NotImplementedError: if ``LOADER_OUTPUT_TYPE`` is not ``'tf'``.
    """
    # Work on a copy so the padding / switch-epoch adjustments below do not leak
    # into the caller's dict (re-using it would add the self-loop padding twice).
    hparams = dict(hparams)

    if hparams['SIZE_CUT_DIRNAME'] is None:
        # No big dataset: push the switch epoch past the last trained epoch.
        hparams['BIG_SWITCH_EPOCH'] = hparams['N_EPOCH'] + 1

    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    if hparams['SIZE_CUT_DIRNAME'] is not None:
        csv_dir = os.path.join(datadir, hparams['SIZE_CUT_DIRNAME'])
        big_data_train_csv = os.path.join(csv_dir, hparams['BIG_TRAIN_CSV_NAME'])
    else:
        csv_dir = datadir
        big_data_train_csv = None
    data_train_csv = os.path.join(csv_dir, hparams['TRAIN_CSV_NAME'])

    restore_file = hparams['RESTORE_FILE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features=hparams['ATOM_FEATURES'],
                        bond_features=hparams['BOND_FEATURES'],
                        out_features=hparams['OUT_FEATURES'])

    logdir = os.path.join(hparams['LOGGING_PARENT_DIR'], hparams['DATACASE'],
                          model.__class__.__name__)
    if hparams['SLURM_JOB_ARRAY']:
        logdir = os.path.join(logdir, hparams['SLURM_ARRAY_JOB_ID'])

    if hparams['SELF_LOOPS']:
        # Self-loops add up to one extra edge per padded node, so the edge
        # padding grows by the node padding.
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE']
        if hparams['SIZE_CUT_DIRNAME'] is not None:
            hparams['BIG_PADDING_N_EDGE'] = hparams['BIG_PADDING_N_EDGE'] + hparams['BIG_PADDING_N_NODE']
        if len(hparams['BOND_FEATURES']) > 0:
            raise ValueError('Can not have both bond features and self_loops.')

    logdir = os.path.join(logdir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(logdir)
    os.mkdir(os.path.join(logdir, 'ckpts'))
    logger, log_handlers = _setup_logger(logdir)

    h5file = None
    train_writer = None
    try:
        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        # ---------
        # Datasets:
        # ---------
        import tables

        h5file = tables.open_file(hparams['H5FILE'], mode='r', title=hparams['H5FILE_TITLE'])
        h5_table = h5file.root.amino.table

        element = _make_element(h5_table, hparams,
                                hparams['PADDING_N_NODE'], hparams['PADDING_N_EDGE'])
        big_element = None
        if hparams['SIZE_CUT_DIRNAME'] is not None:
            big_element = _make_element(h5_table, hparams,
                                        hparams['BIG_PADDING_N_NODE'], hparams['BIG_PADDING_N_EDGE'])

        if hparams['CACHE_SEQ_LOOKUP']:
            id_mapping_table, seq_embedding_lookup = element.create_seq_embedding_lookups()
        else:
            id_mapping_table, seq_embedding_lookup = None, None

        if not hparams['PYTABLE_FROM_DISK']:
            # Presumably the elements hold in-memory copies when not reading
            # from disk, so the file can be released early.
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        dataset = _make_dataset(hparams, model, data_train_csv, hparams['PADDING_N_NODE'])
        logger.info('Train dataset size: {}'.format(len(dataset)))

        _big_dataset = None
        big_dataset = None
        if hparams['SIZE_CUT_DIRNAME'] is not None:
            _big_dataset = _make_dataset(hparams, model, big_data_train_csv,
                                         hparams['BIG_PADDING_N_NODE'])
            # The big loader trains on the small and the big samples together.
            big_dataset = dataset + _big_dataset
            logger.info('Big train dataset size: {}'.format(len(big_dataset)))

        if hparams['LOADER_OUTPUT_TYPE'] != 'tf':
            raise NotImplementedError('jax loader is not supported for pmap.')

        dataset.element_preprocess = element.make_element_preprocess()
        loader = _make_loader(dataset, hparams, hparams['BATCH_SIZE'],
                              id_mapping_table, seq_embedding_lookup)
        big_loader = None
        if big_dataset is not None:
            big_dataset.element_preprocess = big_element.make_element_preprocess()
            big_loader = _make_loader(big_dataset, hparams, hparams['BIG_BATCH_SIZE'],
                                      id_mapping_table, seq_embedding_lookup)

        # ----------------
        # Initializations:
        # ----------------
        prng_key = jax.random.PRNGKey(int(time.time()))
        key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(prng_key, 4)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                     batch_size=hparams['BATCH_SIZE'],
                                     n_partitions=hparams['N_PARTITIONS'],
                                     seq_embedding_size=hparams['SEQ_EMBEDDING_SIZE'],
                                     num_node_features=len(hparams['ATOM_FEATURES']),
                                     num_edge_features=len(hparams['BOND_FEATURES']),
                                     self_loops=hparams['SELF_LOOPS'],
                                     line_graph=hparams['LINE_GRAPH'],
                                     seq_max_length=hparams['SEQ_MAX_LENGTH'],
                                     padding_n_node=hparams['PADDING_N_NODE'],
                                     padding_n_edge=hparams['PADDING_N_EDGE'])
        params = init_model(rngs={'params': key_params,
                                  'dropout': key_dropout,
                                  'num_steps': _key_num_steps})
        logger.info('TIME: init_model: {}'.format(time.time() - start))

        # Scheduler transition length in optimizer steps: TRANSITION_EPOCHS passes
        # over every training set that is used.
        n_transition_epochs = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']
        transition_steps = n_transition_epochs * (len(dataset) / hparams['BATCH_SIZE'])
        if _big_dataset is not None:
            transition_steps += n_transition_epochs * (len(_big_dataset) / hparams['BIG_BATCH_SIZE'])

        create_optimizer = make_create_optimizer(model,
                                                 option=hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps=hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps=transition_steps)
        init_state, scheduler = create_optimizer(params,
                                                 rngs={'dropout': key_dropout, 'num_steps': key_num_steps},
                                                 learning_rate=hparams['LEARNING_RATE'])

        if restore_file is not None:
            logger.info('Restoring parameters from {}'.format(restore_file))
            # NOTE: pickle.load can execute arbitrary code -- only restore trusted checkpoints.
            with open(restore_file, 'rb') as pklfile:
                bytes_output = pickle.load(pklfile)
            state = serialization.from_bytes(init_state, bytes_output)
            logger.info('Parameters restored...')
        else:
            state = init_state

        # L1 penalty on the listed embedding/projection parameters plus an L2
        # penalty on parameters matched by 'kernel'.
        reg_loss_func_embed = make_regularization_loss(params_path=['params/atomic_num_embed/node_embed/embedding',
                                                                    'params/chiral_tag_embed/node_embed/embedding',
                                                                    'params/hybridization_embed/node_embed/embedding',
                                                                    'params/X_proj_non_embeded/kernel',
                                                                    'params/X_proj_non_embeded/bias',
                                                                    'params/bond_type_embed/edge_embed/embedding',
                                                                    'params/E_proj_non_embeded/kernel',
                                                                    'params/E_proj_non_embeded/bias',
                                                                    'params/stereo_embed/edge_embed/embedding',
                                                                    ], alpha=0.01, option='l1')
        reg_loss_func_kernel = make_regularization_loss(params_path=['kernel',
                                                                     ], alpha=0.01, option='l2')

        def reg_loss_func(params):
            return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

        if hparams['N_PARTITIONS'] <= 0:
            raise ValueError('N_PARTITIONS = 0 for pmap training.')
        train_epoch = make_train_pmap_epoch(is_weighted=hparams['WEIGHT_COL'] is not None,
                                            loss_option=hparams['LOSS_OPTION'],
                                            init_rngs=state.rngs,
                                            logger=logger,
                                            aux_loss_option=hparams['AUXILIARY_LOSS_OPTION'],
                                            reg_loss_func=reg_loss_func,
                                            loader_output_type=hparams['LOADER_OUTPUT_TYPE'],
                                            num_classes=model.out_features)
        pstate = flax.jax_utils.replicate(state)

        # Log hyperparams (the adjusted local copy):
        hparams_logs = {key: _serialize_hparam(value) for key, value in hparams.items()}
        with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
            json.dump(hparams_logs, jsonfile)

        # Training:
        logger.info('Training...')
        train_writer = SummaryWriter(os.path.join(logdir, 'train_pmap'))

        # The epoch counter lives (replicated) in the state, so a restored run
        # resumes from the stored value.
        # NOTE(review): checkpoints are written before the counter is incremented,
        # so resuming from 'state_e<k>.pkl' appears to re-run epoch k -- confirm intended.
        epoch = pstate.epoch[0]
        while epoch <= hparams['N_EPOCH']:
            logger.info('Epoch:  {}'.format(epoch))
            if epoch == hparams['BIG_SWITCH_EPOCH']:
                logger.info('Switching to Big loader...')

            # big_loader exists only when SIZE_CUT_DIRNAME is set; otherwise
            # BIG_SWITCH_EPOCH > N_EPOCH and the small loader is used throughout.
            use_big = big_loader is not None and epoch >= hparams['BIG_SWITCH_EPOCH']
            start = time.time()
            pstate, batch_losses = train_epoch(pstate, big_loader if use_big else loader)
            logger.info('TIME: train_epoch: {}'.format(time.time() - start))

            loss = jnp.mean(jnp.asarray(batch_losses))  # NOTE: If N_i = N_j => mean of means equals total mean.
            lr = scheduler(pstate.step[0])
            logger.info('TIME: loss: {}'.format(loss))

            summary = Summary()
            summary.scalar('epoch_loss', loss)
            summary.scalar('learning_rate', lr)
            state_params = flax.jax_utils.unreplicate(pstate.params)
            summary.scalar('regularization_loss', reg_loss_func(state_params))
            train_writer.write(summary, step=int(epoch))
            train_writer.writer.flush()

            # Save current state:
            if epoch % hparams['SAVE_FREQUENCY'] == 0:
                ckpt_name = 'state_e' + str(epoch)
                bytes_output = serialization.to_bytes(flax.jax_utils.unreplicate(pstate))
                with open(os.path.join(logdir, 'ckpts', ckpt_name + '.pkl'), 'wb') as pklfile:
                    pickle.dump(bytes_output, pklfile)
                logger.info('State {} saved...'.format(ckpt_name))

            # Advance the replicated epoch counter on every device.
            pstate = pstate.replace(epoch=jax.tree_util.tree_map(lambda x: x + 1, pstate.epoch))
            epoch = pstate.epoch[0]

        logger.info('Finished...')
    finally:
        # Release everything acquired above, even if training raised.
        if train_writer is not None:
            train_writer.close()
        if h5file is not None and h5file.isopen:
            h5file.close()
        for handler in log_handlers:
            logger.removeHandler(handler)
            handler.close()
    return None