import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import copy
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.amino_GNN.loader import get_tf_loader_masked

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.concentration.make_init_conc import make_init_model
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
from ProtLig_GPCRclassA.amino_GNN.concentration.train.make_train_conc_masked_pmap_epoch import make_train_conc_masked_pmap_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss
from ProtLig_GPCRclassA.amino_GNN.concentration.make_monotonicity_loss import make_monotonicity_loss_func

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name
from ProtLig_GPCRclassA.amino_GNN.select_element_type import get_element_type_by_name
from ProtLig_GPCRclassA.amino_GNN.select_dataset_type import get_dataset_type_by_name
from ProtLig_GPCRclassA.amino_GNN.concentration.select_concentration_sampler import get_concentration_sampler_by_name


import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


def main_train_conc_masked_pmap(hparams):
    """Train the concentration model with pmap data parallelism.

    Training is split into sections. Section ``i`` uses a loader built from the
    i-th entries of the per-section hyperparameters and runs until epoch
    ``hparams['N_EPOCH'][i]``. The last section runs up to and including
    ``N_EPOCH[-1]``.

    Args:
        hparams: dict of hyperparameters. The per-section entries are lists
            indexed by section: ``N_EPOCH``, ``TRAIN_CSV_NAME``, ``MOLS_CSV``,
            ``SEQS_CSV``, ``PADDING_N_NODE``, ``PADDING_N_EDGE`` and
            ``BATCH_SIZE``. ``N_EPOCH`` must be non-empty and sorted ascending.

    Returns:
        None. The run writes the following under
        ``LOGGING_PARENT_DIR/DATACASE/<model class>/[SLURM_ARRAY_JOB_ID/]<datetime>``:
        ``run.log``, ``hparams_logs.json``, ``train/`` summaries and
        ``ckpts/state_e<epoch>.pkl``.

    Raises:
        ValueError: if ``N_EPOCH`` is empty or unsorted, if ``SEQ_MODEL_NAME``
            is unsupported, or if ``N_PARTITIONS`` is not positive.
        NotImplementedError: if ``LOADER_OUTPUT_TYPE`` is not ``'tf'``.
    """
    n_epoch = list(hparams['N_EPOCH'])
    if not n_epoch:
        raise ValueError('N_EPOCH must contain at least one epoch boundary.')
    if n_epoch != sorted(n_epoch):
        # Loader selection in the training loop relies on ascending section boundaries.
        raise ValueError('N_EPOCH must be sorted in ascending order, got {}.'.format(n_epoch))

    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    tokenizer = _load_tokenizer(hparams)

    logdir = os.path.join(hparams['LOGGING_PARENT_DIR'], hparams['DATACASE'])
    restore_file = hparams['RESTORE_FILE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'],
                        seq_d_model = hparams['SEQ_EMBEDDING_SIZE'],
                        vocab_size = len(tokenizer))

    logdir = os.path.join(logdir, model.__class__.__name__)
    if hparams['SLURM_JOB_ARRAY']:
        logdir = os.path.join(logdir, hparams['SLURM_ARRAY_JOB_ID'])
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(logdir, _datetime)
    logger, log_handlers = _setup_run_logger(logdir)

    h5file = None
    try:
        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        # ---------
        # Datasets:
        # ---------
        import tables
        h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
        h5_table = h5file.root.amino.table # h5file.root.bert.BERTtable

        # Elements are kept referenced for the lifetime of the run alongside their loaders.
        _elements, datasets, loaders = _build_train_loaders(hparams, model, tokenizer, h5_table, datadir, logger)
        conc_sampler = _make_conc_sampler(hparams, datasets)

        if not hparams['PYTABLE_FROM_DISK']:
            # Data was loaded into memory, so the table is no longer needed.
            # Otherwise it stays open for lazy reads and is closed in `finally`.
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        # ----------------
        # Initializations:
        # ----------------
        prng_key = jax.random.PRNGKey(int(time.time()))
        key_params, _key_num_steps, key_num_steps, key_dropout, seq_mask_key, conc_sampler_key = jax.random.split(prng_key, 6)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                     batch_size = hparams['BATCH_SIZE'][0],
                                     n_partitions = hparams['N_PARTITIONS'],
                                     seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                     num_node_features = len(hparams['ATOM_FEATURES']),
                                     num_edge_features = len(hparams['BOND_FEATURES']),
                                     self_loops = hparams['SELF_LOOPS'],
                                     line_graph = hparams['LINE_GRAPH'],
                                     seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                     padding_n_node = hparams['PADDING_N_NODE'][0],
                                     padding_n_edge = hparams['PADDING_N_EDGE'][0])
        params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
        end = time.time()
        logger.info('TIME: init_model: {}'.format(end - start))

        # Transition length is measured in optimizer steps over the last section's dataset.
        transition_steps = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']*(len(datasets[-1])/hparams['BATCH_SIZE'][-1])

        create_optimizer = make_create_optimizer(model,
                                                 option = hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps = transition_steps)
        init_state, scheduler = create_optimizer(params,
                                                 rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps, 'seq_mask' : seq_mask_key, 'conc_sampler' : conc_sampler_key},
                                                 learning_rate = hparams['LEARNING_RATE'])

        # Restore params:
        if restore_file is not None:
            logger.info('Restoring parameters from {}'.format(restore_file))
            # NOTE(security): pickle.load can execute arbitrary code; only restore trusted checkpoint files.
            with open(restore_file, 'rb') as pklfile:
                bytes_output = pickle.load(pklfile)
            state = serialization.from_bytes(init_state, bytes_output)
            logger.info('Parameters restored...')
        else:
            state = init_state

        reg_loss_func_embed = make_regularization_loss(params_path = ['params/atomic_num_embed/node_embed/embedding',
                                                                      'params/chiral_tag_embed/node_embed/embedding',
                                                                      'params/hybridization_embed/node_embed/embedding',
                                                                      'params/X_proj_non_embeded/kernel',
                                                                      'params/X_proj_non_embeded/bias',
                                                                      'params/bond_type_embed/edge_embed/embedding',
                                                                      'params/E_proj_non_embeded/kernel',
                                                                      'params/E_proj_non_embeded/bias',
                                                                      'params/stereo_embed/edge_embed/embedding',
                                                                      ], alpha = 0.01, option = 'l1')
        reg_loss_func_kernel = make_regularization_loss(params_path = ['kernel',
                                                                       ], alpha = 0.01, option = 'l2')

        def reg_loss_func(params):
            """Return the L1 embedding/projection penalty plus the L2 kernel penalty."""
            return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

        monotonicity_loss_func = make_monotonicity_loss_func(is_weighted = hparams['WEIGHT_COL'] is not None,
                                                             apply_fn = state.apply_fn,
                                                             slope_pos = hparams['MONOTONICITY_SLOPE_POS'],
                                                             slope_neg = hparams['MONOTONICITY_SLOPE_NEG'],
                                                             eps_pos = hparams['MONOTONICITY_EPS_POS'],
                                                             eps_neg = hparams['MONOTONICITY_EPS_NEG']) # NOTE: monotonicity loss.

        if hparams['N_PARTITIONS'] > 0:
            train_epoch = make_train_conc_masked_pmap_epoch(is_weighted = hparams['WEIGHT_COL'] is not None,
                                                            loss_option = hparams['LOSS_OPTION'],
                                                            init_rngs = state.rngs,
                                                            logger = logger,
                                                            aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                                            reg_loss_func = reg_loss_func,
                                                            monotonicity_loss_func = monotonicity_loss_func,
                                                            loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                                            num_classes = model.out_features,
                                                            mask_token_id = tokenizer.mask_token_id,
                                                            cls_token_id = tokenizer.cls_token_id,
                                                            pad_token_id = tokenizer.pad_token_id,
                                                            sep_token_id = tokenizer.sep_token_id,
                                                            unk_token_id = tokenizer.unk_token_id,
                                                            eos_token_id = tokenizer.eos_token_id,
                                                            bos_token_id = tokenizer.bos_token_id,
                                                            conc_sampler = conc_sampler,
                                                            )
            # One replica of the train state per device for pmap.
            pstate = flax.jax_utils.replicate(state)
        else:
            raise ValueError('N_PARTITIONS = 0 for pmap training.')

        # Log hyperparams:
        hparams_logs = {key: _serialize_hparam(value) for key, value in hparams.items()}
        with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
            json.dump(hparams_logs, jsonfile)

        # Training:
        logger.info('Training...')
        _run_training_loop(pstate, train_epoch, loaders, scheduler, reg_loss_func, hparams, logdir, logger)
        logger.info('Finished...')
    finally:
        # The table is still open here when PYTABLE_FROM_DISK is set, or when an error occurred.
        if h5file is not None and h5file.isopen:
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))
        # Detach handlers so repeated calls do not duplicate log output, and release run.log.
        for handler in log_handlers:
            logger.removeHandler(handler)
            handler.close()
    return None


def _load_tokenizer(hparams):
    """Load the HuggingFace tokenizer for ``hparams['SEQ_MODEL_NAME']``.

    Raises:
        ValueError: if no tokenizer is known for the sequence model name.
    """
    seq_model_name = hparams['SEQ_MODEL_NAME']
    if seq_model_name in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
        from transformers import EsmTokenizer
        return EsmTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'], cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    if seq_model_name == 'ProtBERT':
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'], do_lower_case=False, cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    raise ValueError('Unsupported SEQ_MODEL_NAME: {}'.format(seq_model_name))


def _setup_run_logger(logdir):
    """Create ``logdir`` and ``logdir/ckpts``, then attach file and stdout log handlers.

    Returns:
        (logger, handlers): the run logger and the handlers added to it. The
        caller must detach and close the handlers when the run ends.
    """
    logger = logging.getLogger('main_train_conc_masked_pmap')
    # logger.setLevel(logging.DEBUG)
    logger.setLevel(logging.INFO)
    os.makedirs(logdir)
    os.mkdir(os.path.join(logdir, 'ckpts'))
    file_handler = logging.FileHandler(os.path.join(logdir, 'run.log'))
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(file_handler)
    logger.addHandler(stdout_handler)
    return logger, [file_handler, stdout_handler]


def _build_train_loaders(hparams, model, tokenizer, h5_table, datadir, logger):
    """Build one (element, dataset, loader) triple per training section.

    The dataset for section ``i`` accumulates the partial datasets of sections
    ``0..i`` via ``+``. The sequence-embedding lookups (when
    ``CACHE_SEQ_LOOKUP`` is set) are created once from the first section's
    element and passed to every loader.

    Returns:
        (elements, datasets, loaders): lists with one entry per section.

    Raises:
        NotImplementedError: if ``LOADER_OUTPUT_TYPE`` is not ``'tf'``.
    """
    if hparams['LOADER_OUTPUT_TYPE'] != 'tf':
        raise NotImplementedError('jax loader is not supported for pmap.')

    element_class = get_element_type_by_name(name = hparams['ELEMENT_TYPE'])
    dataset_class = get_dataset_type_by_name(name = hparams['DATASET_TYPE'])

    elements, datasets, loaders = [], [], []
    _dataset = None
    id_mapping_table, seq_embedding_lookup = None, None
    for i in range(len(hparams['N_EPOCH'])):
        data_train_csv = os.path.join(datadir, hparams['TRAIN_CSV_NAME'][i])
        mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV'][i])
        seqs_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['SEQS_CSV'][i])

        _element = element_class(tokenizer = tokenizer,
                                 seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                 bert_table = h5_table,
                                 padding_n_node = hparams['PADDING_N_NODE'][i],
                                 padding_n_edge = hparams['PADDING_N_EDGE'][i],
                                 from_disk = hparams['PYTABLE_FROM_DISK'],
                                 seq_lookup = hparams['CACHE_SEQ_LOOKUP'],
                                 seq_col = hparams['SEQ_COL'])

        _partial_dataset = dataset_class(data_csv = data_train_csv,
                                         mols_csv = mols_csv,
                                         seqs_csv = seqs_csv,
                                         mol_id_col = hparams['MOL_ID_COL'],
                                         mol_col = hparams['MOL_COL'],
                                         seq_id_col = hparams['SEQ_ID_COL'],
                                         label_col = hparams['LABEL_COL'],
                                         weight_col = hparams['WEIGHT_COL'],
                                         atom_features = model.atom_features,
                                         bond_features = model.bond_features,
                                         class_alpha = hparams['CLASS_ALPHA'],
                                         line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * hparams['PADDING_N_NODE'][i],
                                         self_loops = hparams['SELF_LOOPS'],
                                         line_graph = hparams['LINE_GRAPH'],
                                         auxiliary_label_cols = hparams['AUXILIARY_LABEL_COLS'],
                                         auxiliary_weight_cols = hparams['AUXILIARY_WEIGHT_COLS'],
                                         mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                                         seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                                         parameter_col = hparams['CONC_PARAMETER_COL'],
                                         n_ec50_copies = hparams['N_EC50_COPIES'],
                                         conc_value_col = hparams['CONC_VALUE_COL'],
                                         conc_value_screen_col = hparams['CONC_VALUE_SCREEN_COL'],
                                         screening_lower_margin = hparams['SCREENING_LOWER_MARGIN'],
                                         screening_upper_margin = hparams['SCREENING_UPPER_MARGIN'],
                                         ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                                         ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                                         ec50_greater_than_lower_margin = hparams['EC50_GREATER_THAN_LOWER_MARGIN'],
                                         sampling_region_lower_bound = hparams['SAMPLING_REGION_LOWER_BOUND'],
                                         sampling_region_upper_bound = hparams['SAMPLING_REGION_UPPER_BOUND'],
                                         include_conc_parameter_list = hparams['INCLUDE_CONC_PARAMETER_LIST'],
                                         )
        if i == 0:
            _dataset = copy.deepcopy(_partial_dataset)
            if hparams['CACHE_SEQ_LOOKUP']:
                id_mapping_table, seq_embedding_lookup = _element.create_seq_embedding_lookups()
        else:
            _dataset = _dataset + _partial_dataset
            logger.info('Adding {} records for train dataset {}'.format(len(_partial_dataset), i))

        if hparams['WEIGHT_COL'] == '_adjusted_class_weight':
            _dataset.adjusted_class_weight_col(ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                                               ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                                               ec50_lower_extreme = hparams['EC50_LOWER_EXTREME'],
                                               ec50_upper_extreme = hparams['EC50_UPPER_EXTREME'])
        logger.info('Train dataset {} size: {}'.format(i, len(_dataset)))

        _dataset.element_preprocess = _element.make_element_preprocess()
        _loader = get_tf_loader_masked(_dataset,
                                       batch_size = hparams['BATCH_SIZE'][i],
                                       use_cache = hparams['CACHE'],
                                       shuffle = True,
                                       shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                                       drop_last = True,
                                       id_mapping_table = id_mapping_table,
                                       seq_embedding_lookup = seq_embedding_lookup,
                                       n_partitions = hparams['N_PARTITIONS'])

        elements.append(_element)
        datasets.append(_dataset)
        loaders.append(_loader)
    return elements, datasets, loaders


def _make_conc_sampler(hparams, datasets):
    """Instantiate the sampler named by ``hparams['CONCENTRATION_SAMPLER_TYPE']``.

    EC50 mean/std and the conc-parameter id map are taken from the first
    dataset only. The 'ConcentrationSamplerFixedExtremesAdjustWeightsExact'
    sampler also receives the last dataset's ``data``.
    """
    sampler_class = get_concentration_sampler_by_name(name = hparams['CONCENTRATION_SAMPLER_TYPE'])

    if hparams['CONCENTRATION_SAMPLER_TYPE'] == 'ConcentrationSamplerFixedExtremesAdjustWeightsExact':
        augmented_data = datasets[-1].data
    else:
        augmented_data = None

    return sampler_class(ec50_seed = None,
                         screening_lower_margin = hparams['SCREENING_LOWER_MARGIN'],
                         screening_upper_margin = hparams['SCREENING_UPPER_MARGIN'],
                         screening_lower_extreme = hparams['SCREENING_LOWER_EXTREME'],
                         screening_upper_extreme = hparams['SCREENING_UPPER_EXTREME'],
                         ec50_nd_lower_extreme = hparams['EC50_ND_LOWER_EXTREME'],
                         ec50_nd_upper_extreme = hparams['EC50_ND_UPPER_EXTREME'],
                         ec50_std_multiplier = hparams['EC50_STD_MULTIPLIER'],
                         ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                         ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                         ec50_lower_extreme = hparams['EC50_LOWER_EXTREME'],
                         ec50_upper_extreme = hparams['EC50_UPPER_EXTREME'],
                         ec50_greater_than_lower_margin = hparams['EC50_GREATER_THAN_LOWER_MARGIN'],
                         ec50_greater_than_lower_extreme = hparams['EC50_GREATER_THAN_LOWER_EXTREME'],
                         mean_ec50 = datasets[0].mean_ec50, # NOTE: Taking statistics from only the first dataset.
                         std_ec50 = datasets[0].std_ec50, # NOTE: Taking statistics from only the first dataset.
                         screening_lower_perturbation_prob = hparams['SCREENING_LOWER_PERTURBATION_PROB'],
                         screening_lower_perturbation_shift = hparams['SCREENING_LOWER_PERTURBATION_SHIFT'],
                         conc_parameter_id_map = datasets[0].conc_parameter_id_map,
                         augmented_data = augmented_data,
                         unknown_case_sample_weight_scale = hparams['UNKNOWN_CASE_SAMPLE_WEIGHT_SCALE'],
                         sampling_region_lower_bound = hparams['SAMPLING_REGION_LOWER_BOUND'],
                         sampling_region_upper_bound = hparams['SAMPLING_REGION_UPPER_BOUND'])


def _select_loader_index(epoch, n_epoch):
    """Return the index of the training section that ``epoch`` belongs to.

    This is the first ``i`` with ``epoch < n_epoch[i]``. Epochs at or beyond
    the last boundary map to the last section, because the training loop runs
    up to and including ``n_epoch[-1]``. ``n_epoch`` must be sorted ascending.
    """
    for i, boundary in enumerate(n_epoch):
        if epoch < boundary:
            return i
    return len(n_epoch) - 1


def _run_training_loop(pstate, train_epoch, loaders, scheduler, reg_loss_func, hparams, logdir, logger):
    """Train from ``pstate.epoch[0]`` through ``hparams['N_EPOCH'][-1]`` inclusive.

    For each epoch, the loop:

    - trains on the loader of the current section;
    - writes loss, learning-rate and regularization summaries to ``logdir/train``;
    - every ``SAVE_FREQUENCY`` epochs, pickles the serialized unreplicated
      state to ``logdir/ckpts/state_e<epoch>.pkl``.

    The summary writer is closed even if training raises.

    Returns:
        The final replicated train state.
    """
    n_epoch = hparams['N_EPOCH']
    train_writer = SummaryWriter(os.path.join(logdir, 'train'))
    try:
        loader_i = None
        epoch = pstate.epoch[0]
        while epoch <= n_epoch[-1]:
            next_loader_i = _select_loader_index(epoch, n_epoch)
            if loader_i is not None and next_loader_i != loader_i:
                logger.info('Switching to loader {} ...'.format(next_loader_i))
            loader_i = next_loader_i
            logger.info('Epoch:  {}'.format(epoch))

            start = time.time()
            pstate, batch_losses = train_epoch(pstate, loaders[loader_i])
            end = time.time()

            logger.info('TIME: train_epoch: {}'.format(end - start))
            # Loaders use drop_last=True, so all batches have equal size and the
            # mean of batch means equals the overall mean.
            loss = jnp.mean(jnp.asarray(batch_losses))
            lr = scheduler(pstate.step[0])
            logger.info('TIME: loss: {}'.format(loss))

            summary = Summary()
            summary.scalar('epoch_loss', loss)
            summary.scalar('learning_rate', lr)
            state_params = flax.jax_utils.unreplicate(pstate.params)
            summary.scalar('regularization_loss', reg_loss_func(state_params))

            train_writer.write(summary, step = int(epoch))
            train_writer.writer.flush()

            # Save current state:
            if epoch % hparams['SAVE_FREQUENCY'] == 0:
                state = flax.jax_utils.unreplicate(pstate)
                bytes_output = serialization.to_bytes(state)
                ckpt_name = 'state_e' + str(epoch)
                with open(os.path.join(logdir, 'ckpts', ckpt_name + '.pkl'), 'wb') as pklfile:
                    pickle.dump(bytes_output, pklfile)
                logger.info('State {} saved...'.format(ckpt_name))

            # Advance the epoch counter on every replica (jax.tree_map is deprecated/removed).
            pstate = pstate.replace(epoch = jax.tree_util.tree_map(lambda x: x + 1, pstate.epoch))
            epoch = pstate.epoch[0]
    finally:
        train_writer.close()
    return pstate