import jax
from jax import numpy as jnp
import jraph

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape

def make_predict_step(model, return_intermediates = False, num_classes = 2):
    """Build a jitted inference step that converts the model's main logits into probabilities.

    Args:
        model: object exposing ``apply(params, batch, deterministic=..., mutable=...)``
            (presumably a Flax module — TODO confirm). Its output must be a dict
            with a ``'_main_label'`` entry, which is overwritten with probabilities.
        return_intermediates: when truthy, ``apply`` is called with
            ``mutable=['intermediates']`` and the step returns
            ``(preds, intermediates)`` instead of just ``preds``.
        num_classes: ``2`` selects ``jax.nn.sigmoid``; any other value selects
            ``jax.nn.softmax`` (JAX's default axis, i.e. the last one).

    Returns:
        ``jax.jit``-compiled ``predict_step(params, batch)``.
    """
    # Binary heads emit independent logits -> sigmoid; multi-class heads -> softmax.
    to_probs = jax.nn.sigmoid if num_classes == 2 else jax.nn.softmax

    if not return_intermediates:
        def predict_step(params, batch):
            outputs = model.apply(params, batch, deterministic=True)
            outputs['_main_label'] = to_probs(outputs['_main_label'])
            return outputs

        return jax.jit(predict_step)

    def predict_step(params, batch):
        # Sown intermediates come back as the second element when a collection is mutable.
        outputs, captured = model.apply(
            params, batch, deterministic=True, mutable=['intermediates']
        )
        outputs['_main_label'] = to_probs(outputs['_main_label'])
        return outputs, captured

    return jax.jit(predict_step)




def make_predict_conc_epoch(model, return_intermediates = False, num_classes = 2, loader_output_type = 'jax'):
    """Build a ``predict_epoch`` function that runs inference over an entire loader.

    Args:
        model: model forwarded to ``make_predict_step`` (must expose ``apply``).
        return_intermediates: forwarded to ``make_predict_step``; when truthy each
            per-batch output is ``(preds, intermediates)`` instead of ``preds``.
        num_classes: forwarded to ``make_predict_step`` (``2`` -> sigmoid,
            otherwise softmax on ``preds['_main_label']``).
        loader_output_type: ``'tf'`` for a loader that yields TensorFlow tensors.
            ``'jax'`` is currently unsupported.

    Returns:
        ``predict_epoch(params, predict_loader) -> list`` holding one
        ``predict_step`` output per batch, in loader order.

    Raises:
        NotImplementedError: if ``loader_output_type == 'jax'`` (that path has not
            been ported to the graph-reshaping logic used by the ``'tf'`` path).
        ValueError: for any other unrecognised ``loader_output_type``.
    """
    # Validate before building the jitted step. Previously an unknown type fell
    # through to `return predict_epoch` with the name unbound (UnboundLocalError).
    if loader_output_type == 'jax':
        raise NotImplementedError('Logic changed. See tf version of epoch...')
    if loader_output_type != 'tf':
        raise ValueError(
            "loader_output_type must be 'tf' (or the unsupported 'jax'), "
            "got {!r}".format(loader_output_type)
        )

    predict_step = make_predict_step(model = model, return_intermediates = return_intermediates, num_classes = num_classes)

    def predict_epoch(params, predict_loader):
        """Run ``predict_step`` on every batch of ``predict_loader`` and collect the outputs.

        ``predict_loader`` is iterated through ``.enumerate()``. It is presumably a
        ``tf.data.Dataset`` whose elements are indexable as ``(S, G)``, where ``G``
        is graph data that ``tf_to_jraph_graph_reshape`` turns into a jraph graph.
        TODO: confirm against the loader construction.
        """
        # Resolve the target device once per epoch instead of once per leaf per batch.
        device = jax.devices()[0]
        batch_outputs = []
        for _, batch in predict_loader.enumerate():
            # Convert every TF tensor leaf to a JAX array placed on the first device.
            batch = jax.tree_util.tree_map(
                lambda x: jax.device_put(tf_to_jax(x), device = device), batch
            )
            S = batch[0]
            G = tf_to_jraph_graph_reshape(batch[1])
            batch_outputs.append(predict_step(params, (S, G)))
        return batch_outputs

    return predict_epoch