import jax
from jax import numpy as jnp

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func, make_aux_loss_func
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import make_compute_metrics

# from Receptor_odorant.JAX.BERT_GNN.CLS_GTransformer.profiling_configs import hparams, _datacase, _h5file, dataparams, datadir, logdir, model, get_profiler_logdir

def make_eval_step():
    """
    Create the (un-jitted) evaluation step.

    Returns
    -------
    eval_step : callable
        ``eval_step(state, batch) -> (state, logits)``.

        ``state`` must expose ``params``, ``rngs``, ``apply_fn`` and
        ``replace(**fields)``. ``batch`` is a sequence whose last element
        (the labels, in the ``(S, G, labels)`` tuples built by
        ``make_valid_epoch``) is dropped; the remaining elements are passed to
        ``state.apply_fn(state.params, batch[:-1], deterministic = True)``.

        Every leaf of ``state.rngs`` is replaced by the first key of
        ``jax.random.split`` on it, so consecutive calls see fresh keys.
        The input ``state`` is not mutated; an updated copy is returned.
    """
    def eval_step(state, batch):
        # Advance every PRNG key. ``jax.tree_util.tree_map`` is used instead
        # of the deprecated (and since removed) ``jax.tree_map`` alias.
        new_rngs = jax.tree_util.tree_map(lambda key: jax.random.split(key)[0], state.rngs)
        state = state.replace(rngs = new_rngs)
        # Labels (last element) are not part of the model input.
        logits = state.apply_fn(state.params, batch[:-1], deterministic = True)
        return state, logits
    return eval_step


def make_valid_epoch(loss_option, logger, is_weighted = False, aux_loss_option = None, loader_output_type = 'jax', num_classes = 3):
    """
    Helper function to create valid_epoch function.

    Builds the loss/metrics functions and a jitted evaluation step once, then
    returns a ``valid_epoch(state, valid_loader)`` closure that runs the step
    over every batch of ``valid_loader`` and collects per-batch metrics.

    Parameters
    ----------
    loss_option :
        Forwarded as ``option`` to ``make_loss_func``.
    logger :
        Logger used for a per-batch ``debug`` message with the eval loss.
    is_weighted : bool
        Forwarded to ``make_loss_func``.
    aux_loss_option : optional
        If not None, forwarded as ``option`` to ``make_aux_loss_func`` and the
        resulting function is passed on to ``make_compute_metrics``.
    loader_output_type : {'jax', 'tf'}
        ``'jax'``: the loader is iterated directly, each batch is indexed as
        ``batch[0]`` (sequence input), ``batch[1]`` (graph), ``batch[2]``
        (labels), and ``valid_loader.reset()`` is called after the epoch.
        ``'tf'``: the loader is iterated via ``valid_loader.enumerate()``;
        every leaf is converted with ``tf_to_jax`` and put on the first JAX
        device, and ``batch[1]`` is rebuilt with ``tf_to_jraph_graph_reshape``.
    num_classes : int
        Forwarded to ``make_loss_func`` and ``make_compute_metrics``.

    Returns
    -------
    valid_epoch : callable
        ``valid_epoch(state, valid_loader) -> (state, batch_metrics)`` where
        ``batch_metrics`` is a list with one ``compute_metrics`` result per batch.

    Raises
    ------
    ValueError
        If ``loader_output_type`` is not ``'jax'`` or ``'tf'``.
    """
    # Fail fast: an unknown value used to fall through both branches below and
    # surface as an UnboundLocalError on ``return valid_epoch``.
    if loader_output_type not in ('jax', 'tf'):
        raise ValueError('loader_output_type must be "jax" or "tf", got {!r}'.format(loader_output_type))

    loss_func = make_loss_func(is_weighted = is_weighted, option = loss_option, num_classes = num_classes)
    aux_loss_func = make_aux_loss_func(option = aux_loss_option) if aux_loss_option is not None else None
    compute_metrics = make_compute_metrics(loss_func = loss_func, use_jit = False, num_classes = num_classes, aux_loss_func = aux_loss_func)
    eval_step = jax.jit(make_eval_step())

    def _run_batch(state, i, seq, G, labels):
        """Run one jitted eval step on (seq, G, labels); return (state, metrics)."""
        # Some loaders yield labels wrapped in a list/tuple; only the first entry is used.
        if isinstance(labels, (list, tuple)):
            labels = labels[0]
        state, logits = eval_step(state, (seq, G, labels))
        metrics = compute_metrics(logits, labels = labels)
        logger.debug('eval_step: {}:  eval_loss:  {}'.format(i, metrics['loss']))
        return state, metrics

    # Case loader outputs jnp.DeviceArray:
    if loader_output_type == 'jax':
        def valid_epoch(state, valid_loader):
            """Evaluate every batch of a JAX-array loader, then reset the loader."""
            batch_metrics = []
            try:
                for i, batch in enumerate(valid_loader):
                    state, metrics = _run_batch(state, i, batch[0], batch[1], batch[2])
                    batch_metrics.append(metrics)
            finally:
                # Reset even if a step raises, so the loader is not left
                # mid-iteration (assumes reset() rewinds it — confirm).
                valid_loader.reset()
            return state, batch_metrics
        return valid_epoch

    # Case loader outputs tf.Tensor:
    def valid_epoch(state, valid_loader):
        """Evaluate every batch of a tf-tensor loader, converting to JAX first."""
        device = jax.devices()[0]  # queried once per epoch instead of once per leaf

        def _to_device(x):
            return jax.device_put(tf_to_jax(x), device = device)

        batch_metrics = []
        for i, batch in valid_loader.enumerate():
            batch = jax.tree_util.tree_map(_to_device, batch)
            G = tf_to_jraph_graph_reshape(batch[1])
            state, metrics = _run_batch(state, i, batch[0], G, batch[2])
            batch_metrics.append(metrics)
        return state, batch_metrics
    return valid_epoch