import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import copy
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# 
from ProtLig_GPCRclassA.amino_GNN.concentration.dataset_conc import AminoConcentrationDatasetPrecompute, AminoConcentrationDatasetPairsSamplingPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.concentration.element_conc import AminoConcentrationElementPrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.concentration.sampler_conc import ConcentrationSampler, ConcentrationSamplerFixedExtremes, ConcentrationSamplerFixedExtremesAdjustWeightsExact, LabelSampler
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader, get_tf_loader_masked

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.concentration.make_init_conc import make_init_model
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
from ProtLig_GPCRclassA.amino_GNN.concentration.train.make_train_conc_masked_pmap_epoch import make_train_conc_masked_pmap_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss
from ProtLig_GPCRclassA.amino_GNN.concentration.make_monotonicity_loss import make_monotonicity_loss_func

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name
from ProtLig_GPCRclassA.amino_GNN.select_dataset_type import get_dataset_type_by_name
from ProtLig_GPCRclassA.amino_GNN.select_element_type import get_element_type_by_name
from ProtLig_GPCRclassA.amino_GNN.concentration.select_concentration_sampler import get_concentration_sampler_by_name

import logging

def get_tf_loader_conc_data_samples_masked(jax_dataset, batch_size, use_cache, shuffle, shuffle_buffer_size, drop_last, 
              id_mapping_table = None, seq_embedding_lookup = None,
              n_partitions = 0):
    """
    Build the per-example tf.data pipeline used to replay concentration sampling.

    Stages, in order: (optional) cache -> (optional) shuffle, reshuffled every
    iteration -> sequence-id lookup + embedding lookup -> batching -> prefetch.

    The sequence stage maps each ``(s, g, l)`` element in two steps:
      1. ``s[0]`` is replaced by ``id_mapping_table.lookup(s[0])``;
      2. the element becomes ``seq_embedding_lookup(s[0]) + s[1:] + (s[0],)``,
         i.e. the looked-up value is kept as the last sequence entry.

    Batching: with ``n_partitions > 0`` two nested batches are applied. The first
    groups ``batch_size // n_partitions`` examples. The second groups
    ``n_partitions`` of those, giving a leading partition axis (presumably for
    ``pmap``). Otherwise a single batch of ``batch_size`` is applied.

    Raises:
        NotImplementedError: if ``id_mapping_table`` or ``seq_embedding_lookup`` is
            None. This is raised after the cache/shuffle stages have been set up.
    """
    pipeline = jax_dataset.tf_Dataset_by_example()

    if use_cache:
        pipeline = pipeline.cache()

    if shuffle:
        pipeline = pipeline.shuffle(buffer_size = shuffle_buffer_size, reshuffle_each_iteration = True)

    # Only the cached-sequence-embedding path is supported.
    if id_mapping_table is None or seq_embedding_lookup is None:
        raise NotImplementedError('I did not check this.')

    def _map_seq_id_to_index(seq, graph, label):
        # Swap the raw sequence id (first sequence entry) for its mapped value.
        return (id_mapping_table.lookup(seq[0]), ) + seq[1:], graph, label

    def _expand_seq_embedding(seq, graph, label):
        # Replace the mapped value with its cached embedding tensors and keep it at the end.
        return seq_embedding_lookup(seq[0]) + seq[1:] + (seq[0], ), graph, label

    pipeline = pipeline.map(_map_seq_id_to_index).map(_expand_seq_embedding)

    if n_partitions > 0:
        batch_sizes = (batch_size // n_partitions, n_partitions)
    else:
        batch_sizes = (batch_size, )
    for size in batch_sizes:
        pipeline = pipeline.batch(size, drop_remainder = drop_last, num_parallel_calls = tf.data.AUTOTUNE)

    return pipeline.prefetch(buffer_size = tf.data.AUTOTUNE)


def make_conc_data_samples_masked_epoch(is_weighted, loss_option, init_rngs, logger, 
                        aux_loss_option = None, reg_loss_func = None, monotonicity_loss_func = None,
                        loader_output_type = 'jax', num_classes = 3, mask_token_id = None, 
                        cls_token_id = None, pad_token_id = None, sep_token_id = None, 
                        unk_token_id = None, eos_token_id = None, bos_token_id = None,
                        conc_sampler = None):
    """
    Create a function that replays one epoch of concentration/label sampling.

    The returned function walks a tf.data loader. For each batch it applies
    ``conc_sampler.call`` to the concentrations and labels, computes the masked-token
    labels, and advances the PRNG keys once per batch. It collects the sampled data
    instead of training on it.

    Only ``loader_output_type``, ``conc_sampler`` and the special-token ids are used.
    ``is_weighted``, ``loss_option``, ``init_rngs``, ``logger``, ``aux_loss_option``,
    ``reg_loss_func``, ``monotonicity_loss_func``, ``num_classes``, ``mask_token_id`` and
    ``pad_token_id`` are accepted but unused (presumably to mirror the training factory's
    signature).

    Returns:
        ``conc_data_samples_masked_epoch(conc_sampler_rng, seq_mask_rng, loader)``, which
        returns ``(conc_sampler_rng, seq_mask_rng, batches)`` with the keys advanced once
        per batch. Each item of ``batches`` is the tuple
        ``(seq_ids, G.globals['id'], original '_main_label', sampled '_conc',
        sampled '_main_label', sampled '_main_sample_weight', original C0, original C1)``.

    Raises:
        ValueError: if ``conc_sampler`` is None or ``loader_output_type`` is not
            'jax' or 'tf'.
        NotImplementedError: if ``loader_output_type == 'jax'``.
    """
    if conc_sampler is None:
        raise ValueError('conc_sampler must be provided (an object exposing a `call` method).')

    def get_masked_tokens_and_labels(input_ids, attn_mask, seq_mask_rng):
        # Select ~15% of positions at random. Never select special tokens, and never
        # select positions where attn_mask is falsy.
        seq_label = input_ids.copy()
        seq_mask = jax.random.uniform(seq_mask_rng, shape = input_ids.shape) < 0.15
        cls_mask = input_ids != cls_token_id
        pad_mask = attn_mask.astype(bool)
        sep_mask = input_ids != sep_token_id
        unk_mask = input_ids != unk_token_id
        eos_mask = input_ids != eos_token_id
        bos_mask = input_ids != bos_token_id
        input_ids_mask = seq_mask * cls_mask * pad_mask * sep_mask * unk_mask * eos_mask * bos_mask
        return seq_label, input_ids_mask

    conc_sampler_call = conc_sampler.call

    # Case loader outputs jnp.DeviceArray:
    if loader_output_type == 'jax':
        raise NotImplementedError('jax loader not supported with pmap.')
    # Any other value would otherwise leave the epoch function undefined.
    if loader_output_type != 'tf':
        raise ValueError('Unknown loader_output_type: {!r} (expected "jax" or "tf").'.format(loader_output_type))

    # Case loader outputs tf.Tensor:
    def conc_data_samples_masked_epoch(conc_sampler_rng, seq_mask_rng, loader):
        batches = []
        for batch in loader:
            batch = jax.tree_util.tree_map(tf_to_jax, batch)
            S, G, labels = batch
            G = tf_to_jraph_graph_reshape(G)
            # Sample concentrations and labels:
            conc_original_C0 = G.globals['_conc']['C0']
            conc_original_C1 = G.globals['_conc']['C1']
            original_labels = labels
            _globals = G.globals.copy()
            sampeled_conc, new_labels = conc_sampler_call(inputs = ((G.globals['_conc'], labels), conc_sampler_rng)) # NOTE: conc_sampler_rng is updated in train_step.
            _globals.update({'_conc' : sampeled_conc})
            G = G._replace(globals = _globals)
            labels = new_labels
            # Masked tokens. They do not appear in the returned batch tuple, but they are
            # computed the same way the training step presumably computes them.
            input_ids_label, _ = get_masked_tokens_and_labels(input_ids = S[2],
                                                              attn_mask = S[1],
                                                              seq_mask_rng = seq_mask_rng)
            seq_ids = S[3]
            labels.update({'_input_ids_label' : input_ids_label})
            batch = (seq_ids, G.globals['id'], original_labels['_main_label'], G.globals['_conc'], labels['_main_label'], labels['_main_sample_weight'], conc_original_C0, conc_original_C1)
            # Advance the PRNG keys once per batch, as done in train_conc_masked_pmap_step.
            conc_sampler_rng = jax.random.split(conc_sampler_rng)[0]
            seq_mask_rng = jax.random.split(seq_mask_rng)[0]
            batches.append(batch)
        return conc_sampler_rng, seq_mask_rng, batches
    return conc_data_samples_masked_epoch


# --------------
# Main function:
# --------------
def main_conc_data_samples_masked(hparams):
    """
    Replay the concentration/label sampling of masked concentration training and dump it.

    For every epoch ``1 .. hparams['N_EPOCH'][-1]`` this builds the same data path the
    configuration in ``hparams`` describes: tokenizer, datasets, tf.data loaders and
    concentration sampler. It records the batches produced by the function from
    ``make_conc_data_samples_masked_epoch``. No model is trained.

    Side effects:
      * creates ``<LOGGING_PARENT_DIR>/<DATACASE>/<model class>/[<SLURM_ARRAY_JOB_ID>/]<timestamp>``
        with a ``ckpts`` subdirectory and a ``run.log`` file;
      * pickles ``{epoch: batches}`` (leaves converted to numpy arrays) to
        ``hparams.get('DATA_SAMPLES_OUTPUT_FILE')``;
      * writes the sequence-id -> index mapping as JSON to
        ``hparams.get('ID_MAPPING_OUTPUT_FILE')``.
      Both outputs default to files under /data_mount/d/Experiments/concentration_insights/.

    Raises:
        ValueError: if ``hparams['SEQ_MODEL_NAME']`` has no known tokenizer.
        NotImplementedError: if ``hparams['LOADER_OUTPUT_TYPE']`` is not 'tf'.
    """
    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    tokenizer = _load_seq_tokenizer(hparams)

    logdir = os.path.join(hparams['LOGGING_PARENT_DIR'], hparams['DATACASE'])

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'],
                        seq_d_model = hparams['SEQ_EMBEDDING_SIZE'],
                        vocab_size = len(tokenizer))

    logdir = os.path.join(logdir, model.__class__.__name__)

    if hparams['SLURM_JOB_ARRAY']:
        logdir = os.path.join(logdir, hparams['SLURM_ARRAY_JOB_ID'])
    logger = logging.getLogger('main_train_conc_masked_pmap')
    logger.setLevel(logging.INFO)
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(logdir, _datetime)
    os.makedirs(logdir)
    os.mkdir(os.path.join(logdir, 'ckpts'))
    logger_file_handler = logging.FileHandler(os.path.join(logdir, 'run.log'))
    logger_stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(logger_file_handler)
    logger.addHandler(logger_stdout_handler)

    logger.info('jax_version = {}'.format(jax.__version__))
    logger.info('flax_version = {}'.format(flax.__version__))
    logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
    logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
    logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

    # ---------
    # Datasets:
    # ---------
    import tables
    h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
    h5_table = h5file.root.amino.table # h5file.root.bert.BERTtable

    element_class = get_element_type_by_name(name = hparams['ELEMENT_TYPE'])
    dataset_class = get_dataset_type_by_name(name = hparams['DATASET_TYPE'])

    datasets = []
    loaders = []
    for i in range(len(hparams['N_EPOCH'])):
        data_train_csv = os.path.join(datadir, hparams['TRAIN_CSV_NAME'][i])
        mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV'][i])
        seqs_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['SEQS_CSV'][i])

        _element = element_class(tokenizer = tokenizer,
                                 seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                 bert_table = h5_table,
                                 padding_n_node = hparams['PADDING_N_NODE'][i],
                                 padding_n_edge = hparams['PADDING_N_EDGE'][i],
                                 from_disk = hparams['PYTABLE_FROM_DISK'],
                                 seq_lookup = hparams['CACHE_SEQ_LOOKUP'],
                                 seq_col = hparams['SEQ_COL'])

        _partial_dataset = dataset_class(data_csv = data_train_csv,
                            mols_csv = mols_csv,
                            seqs_csv = seqs_csv,
                            mol_id_col = hparams['MOL_ID_COL'],
                            mol_col = hparams['MOL_COL'],
                            seq_id_col = hparams['SEQ_ID_COL'],
                            label_col = hparams['LABEL_COL'],
                            weight_col = hparams['WEIGHT_COL'],
                            atom_features = model.atom_features,
                            bond_features = model.bond_features,
                            class_alpha = hparams['CLASS_ALPHA'],
                            line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * hparams['PADDING_N_NODE'][i],
                            self_loops = hparams['SELF_LOOPS'],
                            line_graph = hparams['LINE_GRAPH'],
                            auxiliary_label_cols = hparams['AUXILIARY_LABEL_COLS'],
                            auxiliary_weight_cols = hparams['AUXILIARY_WEIGHT_COLS'],
                            mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                            seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                            parameter_col = hparams['CONC_PARAMETER_COL'],
                            n_ec50_copies = hparams['N_EC50_COPIES'],
                            conc_value_col = hparams['CONC_VALUE_COL'],
                            conc_value_screen_col = hparams['CONC_VALUE_SCREEN_COL'],
                            screening_lower_margin = hparams['SCREENING_LOWER_MARGIN'],
                            screening_upper_margin = hparams['SCREENING_UPPER_MARGIN'],
                            ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                            ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                            ec50_greater_than_lower_margin = hparams['EC50_GREATER_THAN_LOWER_MARGIN'],
                            sampling_region_lower_bound = hparams['SAMPLING_REGION_LOWER_BOUND'],
                            sampling_region_upper_bound = hparams['SAMPLING_REGION_UPPER_BOUND'],
                            include_conc_parameter_list = hparams['INCLUDE_CONC_PARAMETER_LIST'],
                            )
        # Datasets are cumulative: loader i sees the union of the partial datasets 0..i.
        # The sequence-embedding lookups are built once, from the first element.
        if i == 0:
            _dataset = copy.deepcopy(_partial_dataset)
            if hparams['CACHE_SEQ_LOOKUP']:
                id_mapping_table, seq_embedding_lookup = _element.create_seq_embedding_lookups()
            else:
                id_mapping_table, seq_embedding_lookup = None, None
        else:
            _dataset = _dataset + _partial_dataset
            logger.info('Adding {} records for train dataset {}'.format(len(_partial_dataset), i))

        if hparams['WEIGHT_COL'] == '_adjusted_class_weight':
            _dataset.adjusted_class_weight_col(ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                                               ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                                               ec50_lower_extreme = hparams['EC50_LOWER_EXTREME'],
                                               ec50_upper_extreme = hparams['EC50_UPPER_EXTREME'])
        logger.info('Train dataset {} size: {}'.format(i, len(_dataset)))

        if hparams['LOADER_OUTPUT_TYPE'] == 'tf':
            _dataset.element_preprocess = _element.make_element_preprocess()
            _loader = get_tf_loader_conc_data_samples_masked(_dataset,
                                   batch_size = hparams['BATCH_SIZE'][i],
                                   use_cache = hparams['CACHE'],
                                   shuffle = True,
                                   shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                                   drop_last = False,
                                   id_mapping_table = id_mapping_table,
                                   seq_embedding_lookup = seq_embedding_lookup,
                                   n_partitions = hparams['N_PARTITIONS'])
        else:
            raise NotImplementedError('jax loader is not supported for pmap.')

        datasets.append(_dataset)
        loaders.append(_loader)

    concentration_sampler_class = get_concentration_sampler_by_name(name = hparams['CONCENTRATION_SAMPLER_TYPE'])

    if hparams['CONCENTRATION_SAMPLER_TYPE'] == 'ConcentrationSamplerFixedExtremesAdjustWeightsExact':
        augmented_data = datasets[-1].data
    else:
        augmented_data = None

    conc_sampler = concentration_sampler_class(ec50_seed = None,
                    screening_lower_margin = hparams['SCREENING_LOWER_MARGIN'],
                    screening_upper_margin = hparams['SCREENING_UPPER_MARGIN'],
                    screening_lower_extreme = hparams['SCREENING_LOWER_EXTREME'],
                    screening_upper_extreme = hparams['SCREENING_UPPER_EXTREME'],
                    ec50_nd_lower_extreme = hparams['EC50_ND_LOWER_EXTREME'],
                    ec50_nd_upper_extreme = hparams['EC50_ND_UPPER_EXTREME'],
                    ec50_std_multiplier = hparams['EC50_STD_MULTIPLIER'],
                    ec50_lower_margin = hparams['EC50_LOWER_MARGIN'],
                    ec50_upper_margin = hparams['EC50_UPPER_MARGIN'],
                    ec50_lower_extreme = hparams['EC50_LOWER_EXTREME'],
                    ec50_upper_extreme = hparams['EC50_UPPER_EXTREME'],
                    ec50_greater_than_lower_margin = hparams['EC50_GREATER_THAN_LOWER_MARGIN'],
                    ec50_greater_than_lower_extreme = hparams['EC50_GREATER_THAN_LOWER_EXTREME'],
                    mean_ec50 = datasets[0].mean_ec50, # NOTE: Taking statistics from only the first dataset.
                    std_ec50 = datasets[0].std_ec50, # NOTE: Taking statistics from only the first dataset.
                    screening_lower_perturbation_prob = hparams['SCREENING_LOWER_PERTURBATION_PROB'],
                    screening_lower_perturbation_shift = hparams['SCREENING_LOWER_PERTURBATION_SHIFT'],
                    conc_parameter_id_map = datasets[0].conc_parameter_id_map,
                    augmented_data = augmented_data,
                    unknown_case_sample_weight_scale = hparams['UNKNOWN_CASE_SAMPLE_WEIGHT_SCALE'],
                    sampling_region_lower_bound = hparams['SAMPLING_REGION_LOWER_BOUND'],
                    sampling_region_upper_bound = hparams['SAMPLING_REGION_UPPER_BOUND'])

    if not hparams['PYTABLE_FROM_DISK']:
        h5file.close()
        logger.info('Table closed...: {}'.format(hparams['H5FILE']))

    # ----------------
    # Initializations:
    # ----------------
    prng_key = jax.random.PRNGKey(int(time.time()))
    # Keep the 6-way split so the sampler/mask keys are derived the same way as in the
    # training script. The first four keys are unused here.
    _, _, _, _, seq_mask_key, conc_sampler_key = jax.random.split(prng_key, 6)

    conc_data_sample_masked = make_conc_data_samples_masked_epoch(is_weighted = hparams['WEIGHT_COL'] is not None,
                                        loss_option = hparams['LOSS_OPTION'],
                                        init_rngs = None,
                                        logger = logger,
                                        aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                        reg_loss_func = None,
                                        monotonicity_loss_func = None,
                                        loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                        num_classes = model.out_features,
                                        mask_token_id = tokenizer.mask_token_id,
                                        cls_token_id = tokenizer.cls_token_id,
                                        pad_token_id = tokenizer.pad_token_id,
                                        sep_token_id = tokenizer.sep_token_id,
                                        unk_token_id = tokenizer.unk_token_id,
                                        eos_token_id = tokenizer.eos_token_id,
                                        bos_token_id = tokenizer.bos_token_id,
                                        conc_sampler = conc_sampler,
                                        )

    id_mapping_dict = _id_mapping_table_to_dict(id_mapping_table)

    # Sampling loop:
    data_samples = {}
    epoch = 1
    while epoch <= hparams['N_EPOCH'][-1]:
        loader_switch_i = _select_loader_index(epoch, hparams['N_EPOCH'])
        # Carry the returned (advanced) keys into the next epoch. Otherwise every epoch
        # would restart sampling from the same keys.
        conc_sampler_key, seq_mask_key, batches = conc_data_sample_masked(conc_sampler_rng = conc_sampler_key,
                                                                         seq_mask_rng = seq_mask_key,
                                                                         loader = loaders[loader_switch_i])
        # NOTE(review): `batches` is passed as a single tree, so the lambda receives one
        # leaf at a time. This does NOT concatenate across batches; `data` keeps the
        # per-batch structure. `tree_map(..., *batches)` would concatenate, but it would
        # fail on ragged final batches when N_PARTITIONS > 0 and drop_last=False.
        # Behaviour is kept as is. Confirm the intended pickle layout.
        data = jax.tree_util.tree_map(lambda *x: jnp.concatenate(x), batches)
        # Move to host memory per epoch so device buffers do not accumulate.
        data_samples[epoch] = jax.tree_util.tree_map(lambda x: np.array(x), data)
        epoch += 1

    data_samples_file = hparams.get('DATA_SAMPLES_OUTPUT_FILE',
                                    '/data_mount/d/Experiments/concentration_insights/data_samples_v4_1.pkl')
    id_mapping_file = hparams.get('ID_MAPPING_OUTPUT_FILE',
                                  '/data_mount/d/Experiments/concentration_insights/id_mapping_dict_v4_1.json')
    with open(data_samples_file, 'wb') as pklfile:
        pickle.dump(data_samples, pklfile)
    with open(id_mapping_file, 'w') as jsonfile:
        json.dump(id_mapping_dict, jsonfile)
    logger.info('Data samples saved to: {}'.format(data_samples_file))
    logger.info('ID mapping saved to: {}'.format(id_mapping_file))

    # Release resources kept open for the duration of the run.
    if hparams['PYTABLE_FROM_DISK']:
        h5file.close()
    logger.removeHandler(logger_file_handler)
    logger.removeHandler(logger_stdout_handler)
    logger_file_handler.close()


def _load_seq_tokenizer(hparams):
    """Load the HuggingFace tokenizer for ``hparams['SEQ_MODEL_NAME']``.

    Raises:
        ValueError: if the sequence model name has no known tokenizer.
    """
    seq_model_name = hparams['SEQ_MODEL_NAME']
    if seq_model_name in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
        from transformers import EsmTokenizer
        return EsmTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'], cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    if seq_model_name == 'ProtBERT':
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'], do_lower_case=False, cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    # Fail fast instead of leaving the tokenizer unbound further down.
    raise ValueError('Unsupported SEQ_MODEL_NAME: {}'.format(seq_model_name))


def _select_loader_index(epoch, n_epoch):
    """Return the index of the loader to use at ``epoch``.

    ``n_epoch`` is assumed to be sorted ascending. The result is the first ``i`` with
    ``epoch < n_epoch[i]``. When no entry matches (e.g. at ``epoch == n_epoch[-1]``),
    the last loader is used.
    """
    for i, threshold in enumerate(n_epoch):
        if epoch < threshold:
            return i
    return len(n_epoch) - 1


def _id_mapping_table_to_dict(id_mapping_table):
    """Export the id-mapping lookup table (bytes keys -> int values) to a plain dict.

    ``id_mapping_table`` must expose ``export()`` returning ``(keys, values)``. It is
    presumably a ``tf.lookup`` table.
    """
    keys_tensor, vals_tensor = id_mapping_table.export()
    return {key.decode('utf-8'): int(val)
            for key, val in zip(np.array(keys_tensor), np.array(vals_tensor))}