from flax import nnx
import jax
import jax.numpy as jnp
import optax
from functools import partial


@partial(jax.jit, static_argnames=("n_quantiles3",))
def jax_projection(quantiles1: jax.Array, weight1: jax.Array, quantiles2: jax.Array, weight2: jax.Array,
                   n_quantiles3: int):
    """
    Merge two quantile sets into one weighted mixture and re-project it onto
    ``n_quantiles3`` evenly spaced quantile midpoints.

    Every quantile of input ``i`` is treated as an atom with probability mass
    ``weight_i / (n_quantiles_i * (weight1 + weight2))``. The atoms of both inputs
    are pooled and sorted, each atom is placed at the midpoint of its cumulative
    mass interval, and the output is read off by linear interpolation at
    ``(k + 0.5) / n_quantiles3`` for ``k = 0 .. n_quantiles3 - 1``.

    :param quantiles1: quantiles shape (Batch, N_quantiles_1)
    :param weight1: weight of 1; a scalar, shape (Batch, ), or a shape that already
        broadcasts against ``quantiles1`` (e.g. (Batch, 1)). A weight with exactly one
        dimension fewer than ``quantiles1`` is interpreted as one weight per sample.
    :param quantiles2: quantiles shape (Batch, N_quantiles_2)
    :param weight2: weight of 2, same conventions as ``weight1``
    :param n_quantiles3: number of output quantile, must be int, enforcing same shape of output
    :return: projected quantiles, shape (Batch, n_quantiles3)

    The masses are normalised by ``weight1 + weight2``, so the two weights must not
    both be zero.
    """

    def _per_sample(weight, quantiles):
        # A (Batch,) weight needs a trailing axis; otherwise broadcasting would
        # line it up with the quantile axis instead of the batch axis.
        weight = jnp.asarray(weight)
        if weight.ndim >= 1 and weight.ndim == quantiles.ndim - 1:
            weight = weight[..., None]
        return weight

    weight1 = _per_sample(weight1, quantiles1)
    weight2 = _per_sample(weight2, quantiles2)
    total_weight = weight1 + weight2

    n_quantiles1 = quantiles1.shape[-1]
    n_quantiles2 = quantiles2.shape[-1]
    # Probability mass carried by each individual atom of either input.
    integ_pmf1 = weight1 / (n_quantiles1 * total_weight) * jnp.ones_like(quantiles1)
    integ_pmf2 = weight2 / (n_quantiles2 * total_weight) * jnp.ones_like(quantiles2)

    merged_quantiles = jnp.concatenate([quantiles1, quantiles2], axis=-1)
    integ_pmf = jnp.concatenate([integ_pmf1, integ_pmf2], axis=-1)

    # Sort the pooled atoms by value and carry their masses along.
    index = jnp.argsort(merged_quantiles, axis=-1)
    sorted_quantiles = jnp.take_along_axis(merged_quantiles, index, axis=-1)
    integ_pmf = jnp.take_along_axis(integ_pmf, index, axis=-1)

    # Midpoint of each atom's cumulative-mass interval.
    integ_cmf = jnp.cumsum(integ_pmf, axis=-1)
    cum_prob = integ_cmf - integ_pmf / 2.0

    # The query levels are shared by all samples; the interpolation nodes are per sample.
    query_levels = (jnp.arange(n_quantiles3) + 0.5) / n_quantiles3
    new_quantiles = jax.vmap(jnp.interp, in_axes=(None, 0, 0), out_axes=0)(
        query_levels, cum_prob, sorted_quantiles)

    return new_quantiles


def copy_param(model) -> nnx.Param:
    """
    Return a copy of the ``nnx.Param`` partition of ``model``.

    The model is split with ``nnx.split(model, nnx.Param, ...)``; the graph
    definition and the remaining state are discarded, and ``.copy()`` is called
    on every leaf of the ``nnx.Param`` partition.

    :param model: object accepted by ``nnx.split``
    :return: pytree with the structure of the ``nnx.Param`` partition, holding the copied leaves
    """
    # NOTE(review): the annotation says nnx.Param, but the value returned is the
    # mapped partition produced by nnx.split - confirm the intended type.

    def _clone(leaf):
        return leaf.copy()

    _, param_state, _ = nnx.split(model, nnx.Param, ...)
    return jax.tree.map(_clone, param_state)


@jax.jit
def quanitle_regression_loss(target, predict, taus):
    """
    Pairwise quantile-regression (pinball) loss, returned without reduction.

    For every pair (prediction ``i``, target ``j``) along the last axes the
    difference ``target[..., j] - predict[..., i]`` is formed. Negative
    differences are weighted by ``1 - taus[..., i]``, all others by
    ``taus[..., i]``, and the weight is multiplied by the absolute difference.

    :param target: target values, last axis indexes the targets
    :param predict: predicted values, last axis indexes the predictions
    :param taus: quantile levels, broadcast against ``predict`` along its last axis
    :return: element-wise loss whose last two axes are (prediction, target)
    """
    diff = jnp.expand_dims(target, axis=-2) - jnp.expand_dims(predict, axis=-1)
    tau_col = jnp.expand_dims(taus, axis=-1)
    # Asymmetric weight: (1 - tau) where the difference is negative, tau elsewhere.
    asym_weight = jnp.where(diff < 0, 1 - tau_col, tau_col)
    return asym_weight * jnp.abs(diff)


@partial(jax.jit, static_argnames=('soft_update_ratio',))
def polyak_update(graph_def, state, target_params, soft_update_ratio: float):
    """
    Soft (Polyak) update of the target parameters towards the current model parameters.

    ``graph_def`` and ``state`` are merged back into live objects, the
    ``nnx.Param`` partition of the first merged object is extracted, and it is
    blended with ``target_params`` by ``optax.incremental_update`` using
    ``soft_update_ratio`` as the step size.

    :param graph_def: graph definition passed to ``nnx.merge``
    :param state: state passed to ``nnx.merge``
    :param target_params: current target parameters, used as the "old" tensors of the update
    :param soft_update_ratio: step size of the incremental update (static under jit)
    :return: updated target parameters
    """
    # The merge result is unpacked as a triple and only its first element is used.
    # Presumably (model, optimizer, metrics) or similar - confirm with the caller.
    online_model, _second, _third = nnx.merge(graph_def, state)
    _graph, online_params, _rest = nnx.split(online_model, nnx.Param, ...)
    return optax.incremental_update(online_params, target_params, soft_update_ratio)


@partial(jax.jit, static_argnames=("n_target_quantiles",))
def distributional_gae(rewards, dones, next_quantiles,
                       gamma, gae_lambda, n_target_quantiles):
    """
    Build per-step target quantiles with a backward scan over the trajectory.

    Walking from the last step to the first, each step mixes (via
    ``jax_projection``) the one-step target
    ``rewards[t] + gamma * (1 - dones[t]) * next_quantiles[t]`` with the target
    carried over from the later steps, and emits the projected mixture as the
    target for step ``t``.

    :param rewards: per-step rewards, indexed along the first axis
    :param dones: per-step termination flags, indexed along the first axis
    :param next_quantiles: quantiles of the next-state value, one row per step
    :param gamma: discount factor
    :param gae_lambda: lambda used for the mixture weights
    :param n_target_quantiles: number of quantiles of every emitted target (static under jit).
        The scan carry must keep a constant shape, so this has to equal the number of
        quantiles in ``next_quantiles``.
    :return: target quantiles in chronological order, one row of ``n_target_quantiles`` per step
    """

    def scan_body(carry, t):
        gae_weight = carry['gae_weight']
        gae_target = carry['gae_target']
        one_step = rewards[t] + gamma * (1.0 - dones[t]) * next_quantiles[t]

        # NOTE(review): the one-step target is weighted by the running gae_weight and
        # the carried target by (1 - gae_lambda); confirm this is the intended mixture.
        # Only the batch axis added by [None] is removed: a bare squeeze() would also
        # drop the quantile axis when n_target_quantiles == 1 and break the scan carry.
        target = jax_projection(one_step[None], gae_weight, gae_target[None], 1.0 - gae_lambda,
                                n_target_quantiles)[0]

        # Prepare the carry for step t - 1. In the final iteration (t == 0) these
        # values are computed too, but the final carry is discarded by the caller.
        gae_target = rewards[t - 1] + gamma * (1.0 - dones[t - 1]) * target
        gae_weight = gae_lambda * (1.0 - dones[t - 1]) * (1.0 - gae_lambda + gae_weight)
        new_carry = {"gae_weight": gae_weight, "gae_target": gae_target}

        return new_carry, target

    # The recursion starts from the one-step target of the last step.
    last_target = rewards[-1] + gamma * (1.0 - dones[-1]) * next_quantiles[-1]
    init_carry = {"gae_weight": gae_lambda, "gae_target": last_target}

    # Visit the time steps from last to first; only the step count is needed here,
    # so no full target array is materialised for it.
    xs = jnp.arange(0, rewards.shape[0], dtype=jnp.int32)[::-1]
    _, target_quantiles_jax = jax.lax.scan(scan_body, init_carry, xs=xs)

    # Outputs were produced in reverse time order; flip them back.
    return target_quantiles_jax[::-1]


"""


@partial(jax.jit, static_argnames=("n_target_quantiles",))
def distributional_gae(rewards, dones, next_quantiles,
                       gamma, gae_lambda, n_target_quantiles):
    init = rewards[-1] + gamma * (1.0 - dones[-1]) * next_quantiles[-1]

    def body(carry, t):

        one_step = rewards[t] + gamma * (1.0 - dones[t]) * next_quantiles[t]
        # λ-return: (1-λ)·one_step ⊕ λ·carry
        target = jax_projection(one_step[None], 1.0 - gae_lambda,
                                carry[None], gae_lambda * (1.0 - dones[t]),
                                n_target_quantiles).squeeze()
        return target, target

    _, targets_rev = jax.lax.scan(body,
                                  init,
                                  jnp.arange(len(rewards) - 1)[::-1],
                                  reverse=False)
    return jnp.concatenate([targets_rev[::-1], init[None]], axis=0)
"""
