import jax
from jax import numpy as jnp

def make_monotonicity_loss_func(is_weighted, apply_fn, slope_pos = 0.1, slope_neg = 0.1, eps_pos = 0.01, eps_neg = 0.05):
    """Build a loss penalizing unwanted dependence of the prediction on concentration.

    The returned function is called as ``loss(params, batch, labels)`` where ``batch``
    is an ``(S, G)`` pair, ``G.globals['_conc']`` holds the concentration and
    ``labels['_main_label']`` holds the target. With ``g`` the derivative of the
    ``'_main_label'`` prediction w.r.t. the concentration, the per-sample loss is

        y * slope_pos * max(0, -g - eps_pos) + (1 - y) * slope_neg * max(0, |g| - eps_neg)

    so positive examples are penalized when the prediction decreases with
    concentration by more than ``eps_pos``, and negative examples when the slope
    magnitude exceeds ``eps_neg``.

    Args:
        is_weighted: If true, per-sample losses are multiplied by
            ``labels['_main_sample_weight']`` before averaging.
        apply_fn: Model apply function, called as
            ``apply_fn(params, (S, G), deterministic = True)``; its result must have a
            ``'_main_label'`` entry.
        slope_pos: Weight of the positive-example term; ``None`` disables the term.
        slope_neg: Weight of the negative-example term; ``None`` disables the term.
        eps_pos: Tolerated negative slope for positive examples.
        eps_neg: Tolerated absolute slope for negative examples.

    Returns:
        ``monotonicity_loss_func(params, batch, labels)`` returning the mean
        (optionally sample-weighted) loss. If both slopes are ``None`` the loss is 0.

    Raises:
        ValueError: If at least one slope is given but neither slope is positive.
    """
    def func(params, S, G, c):
        # Substitute `c` for the concentration so it can be differentiated against.
        _globals = G.globals.copy()
        _globals['_conc'] = c
        _G = G._replace(globals = _globals)
        val = apply_fn(params, (S, _G), deterministic = True)
        return jnp.sum(val['_main_label'])

    # Gradient of the batch-summed prediction w.r.t. the concentration vector. This is
    # the per-sample derivative (Jacobian diagonal) only if each sample's prediction
    # depends solely on its own concentration -- assumed here, not checked.
    dfunc_dc = jax.grad(func, argnums = 3)

    if slope_pos is None and slope_neg is None:
        # NOTE: Ignoring monotonicity loss.
        def _main_monotonicity_loss_func(params, batch, labels):
            return 0.0
    else:
        # A single `None` slope disables its term (previously crashed on `None > 0.0`).
        w_pos = 0.0 if slope_pos is None else slope_pos
        w_neg = 0.0 if slope_neg is None else slope_neg
        if not (w_pos > 0.0 or w_neg > 0.0):
            raise ValueError(
                f"At least one of slope_pos/slope_neg must be positive, "
                f"got slope_pos={slope_pos!r}, slope_neg={slope_neg!r}."
            )

        def _main_monotonicity_loss_func(params, batch, labels):
            """Per-sample loss; dfunc_dc is evaluated at the batch concentration."""
            S, G = batch
            c = G.globals['_conc']
            g_c = dfunc_dc(params, S, G, c)
            y = labels['_main_label']
            # NOTE(review): assumes `y` and `g_c` have matching shapes; a (B, 1) vs (B,)
            # mismatch would silently broadcast to (B, B) -- confirm against the caller.
            pos_term = y * w_pos * jax.nn.relu(- g_c - eps_pos)
            neg_term = (1 - y) * w_neg * jax.nn.relu(jnp.abs(g_c) - eps_neg)
            return pos_term + neg_term

    # Weighted:
    if is_weighted:
        def monotonicity_loss_func(params, batch, labels):
            loss_vals = _main_monotonicity_loss_func(params, batch, labels)
            weighted_loss_val = labels['_main_sample_weight'] * loss_vals
            return jnp.mean(weighted_loss_val)
    else:
        def monotonicity_loss_func(params, batch, labels):
            loss_vals = _main_monotonicity_loss_func(params, batch, labels)
            return jnp.mean(loss_vals)

    return monotonicity_loss_func