import jax
from jax import numpy as jnp
import jraph

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape

def make_predict_step(model, return_intermediates = False, num_classes = 2):
    """Build a jitted inference step that applies ``model`` and activates ``'_main_label'``.

    Args:
        model: object whose ``apply(params, batch, deterministic=True, ...)`` returns a
            dict of outputs containing ``'_main_label'``; when called with ``mutable``
            it is expected to return an ``(outputs, mutated_variables)`` pair.
        return_intermediates: when True, ``apply`` is called with
            ``mutable=['intermediates']`` and the step returns ``(preds, intermediates)``.
        num_classes: ``2`` selects an element-wise sigmoid on ``'_main_label'``; any
            other value selects ``jax.nn.softmax`` (last axis, JAX's default).

    Returns:
        ``jax.jit``-compiled ``predict_step(params, batch)``.
    """
    activation = jax.nn.sigmoid if num_classes == 2 else jax.nn.softmax

    def _activate(preds):
        # Replace the raw head output in place with its activated version.
        preds['_main_label'] = activation(preds['_main_label'])
        return preds

    if not return_intermediates:
        def predict_step(params, batch):
            return _activate(model.apply(params, batch, deterministic = True))
        return jax.jit(predict_step)

    def predict_step(params, batch):
        preds, intermediates = model.apply(params, batch, deterministic = True, mutable=['intermediates'])
        return _activate(preds), intermediates
    return jax.jit(predict_step)




def _concentration_grid(min_conc, max_conc, step_conc):
    """Return the concentrations ``min_conc, min_conc + step_conc, ...`` up to ``max_conc``.

    ``max_conc`` is included when it lies on the grid (up to a small floating-point
    tolerance); no returned value exceeds it.

    Raises:
        ValueError: if ``step_conc`` is not positive or ``max_conc < min_conc``.
    """
    if step_conc <= 0:
        raise ValueError(f'step_conc_sample must be positive, got {step_conc}.')
    if max_conc < min_conc:
        raise ValueError(
            f'max_conc_sample ({max_conc}) must be >= min_conc_sample ({min_conc}).')
    # Count the grid points explicitly instead of using jnp.arange with a float stop:
    # arange(min, max + step, step) yields a point above max whenever (max - min) is
    # not a multiple of step, and float rounding can add or drop the endpoint.
    # The epsilon keeps an endpoint that is on the grid up to rounding error.
    num_points = int((max_conc - min_conc) / step_conc + 1e-6) + 1
    # Values are computed in Python arithmetic (like np.arange) and then converted.
    return jnp.asarray([min_conc + k * step_conc for k in range(num_points)])


def make_predict_conc_range_epoch(model, min_conc_sample, max_conc_sample, step_conc_sample, return_intermediates = False, num_classes = 2, loader_output_type = 'jax'):
    """Create a ``predict_epoch`` function that scores each batch over a range of concentrations.

    Args:
        model: forwarded to ``make_predict_step``.
        min_conc_sample, max_conc_sample, step_conc_sample: define the concentration grid
            ``min, min + step, ...``; ``max`` is included when on the grid and never exceeded.
        return_intermediates: forwarded to ``make_predict_step``; also changes the
            per-batch output format of ``predict_epoch`` (see Returns).
        num_classes: forwarded to ``make_predict_step`` (2 -> sigmoid, otherwise softmax).
        loader_output_type: only ``'tf'`` is implemented.

    Returns:
        ``predict_epoch(params, predict_loader)`` returning a list with one entry per batch.
        Each entry is the predictions dict with every leaf stacked along a new axis 1
        (one slot per concentration), plus a ``'_conc'`` entry holding the concentrations
        that were used. With ``return_intermediates=True`` each entry is instead
        ``(stacked_predictions, [intermediates for each concentration])``.

    Raises:
        ValueError: invalid concentration range or unknown ``loader_output_type``.
        NotImplementedError: ``loader_output_type == 'jax'``.
    """
    concentrations = _concentration_grid(min_conc_sample, max_conc_sample, step_conc_sample)

    # Case loader outputs jnp.DeviceArray:
    if loader_output_type == 'jax':
        raise NotImplementedError('Logic cahnged. See tf version of epoch...')
    if loader_output_type != 'tf':
        # Previously this fell through to an UnboundLocalError on `predict_epoch`.
        raise ValueError(
            f"loader_output_type must be 'jax' or 'tf', got {loader_output_type!r}.")

    predict_step = make_predict_step(model = model, return_intermediates = return_intermediates, num_classes = num_classes)

    # Case loader outputs tf.Tensor:
    def predict_epoch(params, predict_loader):
        """Run ``predict_step`` on every batch of ``predict_loader`` at every concentration."""
        device = jax.devices()[0]  # resolved once per epoch rather than per tensor
        batch_outputs = []
        for _, batch in predict_loader.enumerate():
            batch = jax.tree.map(lambda x: jax.device_put(tf_to_jax(x), device = device), batch)
            S = batch[0]
            G = tf_to_jraph_graph_reshape(batch[1])

            per_conc_outputs = []
            per_conc_intermediates = []
            # Re-run the same batch once per concentration, overriding the '_conc' global.
            for conc in concentrations:
                _globals = G.globals.copy()
                # TODO: This forces to create a single dummy concentration value per pair.
                # NOTE(review): the dtype follows JAX promotion of conc with S[0].dtype;
                # presumably S[0] is a float feature tensor — confirm.
                new_conc = conc * jnp.ones(shape = (S[0].shape[0],), dtype = S[0].dtype)
                _globals.update({'_conc' : new_conc})
                new_G = G._replace(globals = _globals)
                outputs = predict_step(params, (S, new_G))
                if return_intermediates:
                    # predict_step returns (preds, intermediates) in this mode.
                    outputs, intermediates = outputs
                    per_conc_intermediates.append(intermediates)
                outputs['_conc'] = jnp.expand_dims(new_conc, axis = -1)
                per_conc_outputs.append(outputs)

            # Stack per-concentration results so axis 1 indexes the concentration.
            stacked = jax.tree.map(lambda *x: jnp.stack(x, axis = 1), *per_conc_outputs)
            batch_outputs.append((stacked, per_conc_intermediates) if return_intermediates else stacked)
        return batch_outputs

    return predict_epoch