import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import copy
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# 
from ProtLig_GPCRclassA.amino_GNN.concentration.dataset_conc import AminoConcentrationDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.concentration.element_conc import AminoConcentrationElementPrecomputeMasked
from ProtLig_GPCRclassA.amino_GNN.concentration.sampler_conc import ConcentrationSampler
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader, get_tf_loader_masked

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.concentration.make_init_conc import make_init_model
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.concentration.make_compute_metrics import log_metrics_concentration_from_epoch
from ProtLig_GPCRclassA.amino_GNN.concentration.eval.make_valid_conc_masked_params_epoch import make_valid_conc_masked_params_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss
from ProtLig_GPCRclassA.amino_GNN.concentration.make_monotonicity_loss import make_monotonicity_loss_func

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


def _load_tokenizer(hparams):
    """Return the Hugging Face tokenizer for ``hparams['SEQ_MODEL_NAME']``.

    The tokenizer is loaded from ``hparams['SEQ_MODEL_TOKENIZER_PATH']`` using
    ``hparams['HUGGINGFACE_CACHE_DIR']`` as cache directory (instead of the hub
    names "facebook/<name>" / "Rostlab/prot_bert").

    Raises:
        ValueError: if ``SEQ_MODEL_NAME`` is not a supported sequence model.
    """
    seq_model_name = hparams['SEQ_MODEL_NAME']
    if seq_model_name in ['esm2_t33_650M_UR50D', 'esm2_t48_15B_UR50D']:
        from transformers import EsmTokenizer
        return EsmTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                            cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    if seq_model_name == 'ProtBERT':
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(hparams['SEQ_MODEL_TOKENIZER_PATH'],
                                             do_lower_case = False,
                                             cache_dir = hparams['HUGGINGFACE_CACHE_DIR'])
    raise ValueError('Unsupported SEQ_MODEL_NAME: {!r}'.format(seq_model_name))


def _ordered_checkpoint_filenames(ckpts_dir, logger, prefix = 'state_e', suffix = '.pkl'):
    """Return checkpoint file names in ``ckpts_dir`` sorted by ascending epoch.

    Only names of the form ``<prefix><int><suffix>`` (e.g. ``state_e12.pkl``)
    are returned. Any other entry is logged and skipped rather than aborting
    the whole evaluation.
    """
    epoch_and_name = []
    for name in os.listdir(ckpts_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            logger.warning('Skipping non-checkpoint file: {}'.format(name))
            continue
        epoch_str = name[len(prefix):len(name) - len(suffix)]
        try:
            epoch = int(epoch_str)
        except ValueError:
            logger.warning('Skipping checkpoint with unparsable epoch: {}'.format(name))
            continue
        epoch_and_name.append((epoch, name))
    epoch_and_name.sort()
    return [name for _, name in epoch_and_name]


def main_eval_conc_masked_params_ckpts(hparams):
    """Evaluate every checkpoint in ``<RESTORE_MODEL_DIR>/ckpts`` on a validation CSV.

    Checkpoints named ``state_e<epoch>.pkl`` are restored in ascending epoch
    order. Each one is run through the masked-parameter concentration
    validation epoch. Its metrics, learning rate and regularization loss are
    written to a jaxboard summary, using the checkpoint's epoch as the step.

    Outputs go to ``<RESTORE_MODEL_DIR>/eval_ckpts/<csv stem>_<timestamp>/``:
    a log file, ``hparams_logs.json``, and a summary directory named after the
    validation CSV stem.

    Args:
        hparams: dict of run hyperparameters (data paths, model/tokenizer
            names, padding sizes, loss options, ...). Only the keys read below
            are required.

    Returns:
        None.

    Raises:
        ValueError: ``SEQ_MODEL_NAME`` is not a supported sequence model.
        NotImplementedError: ``LOADER_OUTPUT_TYPE`` is not ``'tf'``, or
            ``N_PARTITIONS > 0`` (pmap path not supported).
    """
    datadir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    writer_name, _ = os.path.splitext(hparams['VALID_CSV_NAME'])

    # Resolve the tokenizer first so that an unsupported model name fails
    # before any output directory is created.
    tokenizer = _load_tokenizer(hparams)

    restore_ckpts_dir = os.path.join(hparams['RESTORE_MODEL_DIR'], 'ckpts')

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'],
                        seq_d_model = hparams['SEQ_EMBEDDING_SIZE'],
                        vocab_size = len(tokenizer))

    logger = logging.getLogger('main_eval_conc_masked_ckpts')
    logger.setLevel(logging.INFO)
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(hparams['RESTORE_MODEL_DIR'], 'eval_ckpts', writer_name + '_' + _datetime)
    os.makedirs(logdir)
    # The handlers are detached and closed in the ``finally`` below. Otherwise
    # repeated calls in one process would duplicate log lines and leak file handles.
    log_handlers = [logging.FileHandler(os.path.join(logdir, 'run_eval_conc_masked_ckpts.log')),
                    logging.StreamHandler(sys.stdout)]
    for handler in log_handlers:
        logger.addHandler(handler)

    h5file = None
    try:
        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        # ---------
        # Datasets:
        # ---------
        import tables
        h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title = hparams['H5FILE_TITLE'])
        h5_table = h5file.root.amino.table

        data_valid_csv = os.path.join(datadir, hparams['VALID_CSV_NAME'])
        mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV'])
        seqs_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['SEQS_CSV'])

        element = AminoConcentrationElementPrecomputeMasked(tokenizer = tokenizer,
                                     seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                     bert_table = h5_table,
                                     padding_n_node = hparams['PADDING_N_NODE'],
                                     padding_n_edge = hparams['PADDING_N_EDGE'],
                                     from_disk = hparams['PYTABLE_FROM_DISK'],
                                     seq_lookup = hparams['CACHE_SEQ_LOOKUP'],
                                     seq_col = hparams['SEQ_COL'])

        valid_dataset = AminoConcentrationDatasetPrecompute(data_csv = data_valid_csv,
                            mols_csv = mols_csv,
                            seqs_csv = seqs_csv,
                            mol_id_col = hparams['MOL_ID_COL'],
                            mol_col = hparams['MOL_COL'],
                            seq_id_col = hparams['SEQ_ID_COL'],
                            label_col = hparams['LABEL_COL'],
                            weight_col = hparams['VALID_WEIGHT_COL'],
                            atom_features = model.atom_features,
                            bond_features = model.bond_features,
                            line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * hparams['PADDING_N_NODE'],
                            self_loops = hparams['SELF_LOOPS'],
                            line_graph = hparams['LINE_GRAPH'],
                            auxiliary_label_cols = hparams['AUXILIARY_LABEL_COLS'],
                            auxiliary_weight_cols = hparams['AUXILIARY_WEIGHT_COLS'],
                            mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                            seq_global_cols = hparams['SEQ_GLOBAL_COLS'],
                            parameter_col = hparams['CONC_PARAMETER_COL'],
                            n_ec50_copies = 1,
                            conc_value_col = hparams['CONC_VALUE_COL'],
                            conc_value_screen_col = hparams['CONC_VALUE_SCREEN_COL'],
                            ec50_lower_margin  = hparams['EC50_LOWER_MARGIN'],
                            ec50_upper_margin  = hparams['EC50_UPPER_MARGIN'],
                            ec50_lower_extreme = hparams['EC50_LOWER_EXTREME'],
                            ec50_upper_extreme = hparams['EC50_UPPER_EXTREME'],
                            )

        if hparams['CACHE_SEQ_LOOKUP']:
            id_mapping_table, seq_embedding_lookup = element.create_seq_embedding_lookups()
        else:
            id_mapping_table, seq_embedding_lookup = None, None
        logger.info('Valid dataset size: {}'.format(len(valid_dataset)))

        if hparams['LOADER_OUTPUT_TYPE'] != 'tf':
            raise NotImplementedError('jax loader is not supported.')
        valid_dataset.element_preprocess = element.make_element_preprocess()
        valid_loader = get_tf_loader_masked(valid_dataset,
                                   batch_size = hparams['BATCH_SIZE'],
                                   use_cache = hparams['CACHE'],
                                   shuffle = False,
                                   shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                                   drop_last = False,
                                   id_mapping_table = id_mapping_table,
                                   seq_embedding_lookup = seq_embedding_lookup,
                                   n_partitions = hparams['N_PARTITIONS'])

        if not hparams['PYTABLE_FROM_DISK']:
            # The original code assumes that, when not reading from disk, the table
            # is no longer needed once the loader is built. In the from-disk case
            # the file stays open for the loader and is closed in ``finally``.
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        # ----------------
        # Initializations:
        # ----------------
        prng_key = jax.random.PRNGKey(int(time.time()))
        key_params, _key_num_steps, key_num_steps, key_dropout, seq_mask_key = jax.random.split(prng_key, 5)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                    batch_size = hparams['BATCH_SIZE'],
                                    n_partitions = hparams['N_PARTITIONS'],
                                    seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                    num_node_features = len(hparams['ATOM_FEATURES']),
                                    num_edge_features = len(hparams['BOND_FEATURES']),
                                    self_loops = hparams['SELF_LOOPS'],
                                    line_graph = hparams['LINE_GRAPH'],
                                    seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                    padding_n_node = hparams['PADDING_N_NODE'],
                                    padding_n_edge = hparams['PADDING_N_EDGE'])
        params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
        end = time.time()
        logger.info('TIME: init_model: {}'.format(end - start))

        # The optimizer/scheduler is only needed to rebuild a template state for
        # deserialization and to report the learning rate at each checkpoint's step.
        transition_steps = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']*(len(valid_dataset)/hparams['BATCH_SIZE'])
        create_optimizer = make_create_optimizer(model,
                                                option = hparams['OPTIMIZATION']['OPTION'],
                                                warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                transition_steps = transition_steps)
        init_state, scheduler = create_optimizer(params,
                                                rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps, 'seq_mask' : seq_mask_key},
                                                learning_rate = hparams['LEARNING_RATE'])

        reg_loss_func_embed = make_regularization_loss(params_path = ['params/atomic_num_embed/node_embed/embedding',
                                                                'params/chiral_tag_embed/node_embed/embedding',
                                                                'params/hybridization_embed/node_embed/embedding',
                                                                'params/X_proj_non_embeded/kernel',
                                                                'params/X_proj_non_embeded/bias',
                                                                'params/bond_type_embed/edge_embed/embedding',
                                                                'params/E_proj_non_embeded/kernel',
                                                                'params/E_proj_non_embeded/bias',
                                                                'params/stereo_embed/edge_embed/embedding',
                                                                ], alpha = 0.01, option = 'l1')
        reg_loss_func_kernel = make_regularization_loss(params_path = ['kernel',
                                                                ], alpha = 0.01, option = 'l2')
        def reg_loss_func(params):
            return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

        if hparams['N_PARTITIONS'] > 0:
            raise NotImplementedError('pmap needs to be checked...')
        valid_epoch = make_valid_conc_masked_params_epoch(loss_option = hparams['LOSS_OPTION'],
                                    logger = logger,
                                    aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                    loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                    num_classes = model.out_features,
                                    mask_token_id = tokenizer.mask_token_id,
                                    cls_token_id = tokenizer.cls_token_id,
                                    pad_token_id = tokenizer.pad_token_id,
                                    sep_token_id = tokenizer.sep_token_id,
                                    unk_token_id = tokenizer.unk_token_id,
                                    eos_token_id = tokenizer.eos_token_id,
                                    bos_token_id = tokenizer.bos_token_id,
                                    min_conc_sample = hparams['MIN_CONC_SAMPLE'],
                                    max_conc_sample = hparams['MAX_CONC_SAMPLE'],
                                    step_conc_sample = hparams['STEP_CONC_SAMPLE'],
                                    conc_parameter_id_map = valid_dataset.conc_parameter_id_map)

        # Log hyperparams:
        hparams_logs = {key: _serialize_hparam(value) for key, value in hparams.items()}
        with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
            json.dump(hparams_logs, jsonfile)

        logger.info('Training...')
        valid_writer = SummaryWriter(os.path.join(logdir, writer_name))
        try:
            ckpt_filenames = _ordered_checkpoint_filenames(restore_ckpts_dir, logger)
            if not ckpt_filenames:
                logger.warning('No checkpoints found in {}'.format(restore_ckpts_dir))

            for filename in ckpt_filenames:
                restore_file = os.path.join(restore_ckpts_dir, filename)

                logger.info('Restoring parameters from {}'.format(restore_file))
                # NOTE(security): pickle.load can execute arbitrary code. Only point
                # RESTORE_MODEL_DIR at checkpoints from a trusted source.
                with open(restore_file, 'rb') as pklfile:
                    bytes_output = pickle.load(pklfile)
                state = serialization.from_bytes(init_state, bytes_output)
                logger.info('Parameters restored...')

                epoch = state.epoch
                logger.info('Epoch:  {}'.format(epoch))

                start = time.time()
                state, valid_metrics = valid_epoch(state, valid_loader)
                end = time.time()
                logger.info('TIME: valid_epoch: {}'.format(end - start))
                valid_metrics_np = jax.device_get(valid_metrics)
                lr = scheduler(state.step)

                summary = Summary()
                summary.scalar('learning_rate', lr)
                summary = log_metrics_concentration_from_epoch(name = writer_name,
                                    epoch = epoch,
                                    hparams = hparams,
                                    metrics_np = valid_metrics_np,
                                    logger = logger,
                                    summary = summary)
                summary.scalar('regularization_loss', reg_loss_func(state.params))

                valid_writer.write(summary, step = int(epoch))
                valid_writer.writer.flush()
        finally:
            valid_writer.close()
        logger.info('Finished...')
    finally:
        # Close the HDF5 file if it is still open (PYTABLE_FROM_DISK=True, or an
        # error occurred before the early close above).
        if h5file is not None and h5file.isopen:
            h5file.close()
        for handler in log_handlers:
            logger.removeHandler(handler)
            handler.close()
    return None