from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch
from src.agents.td3_derl.rnd_states import RNDTrainState


def rnd_bonus(
    rnd: TrainState, state: jax.Array, action: jax.Array
) -> jax.Array:
    """Return the squared error between the two outputs of ``rnd.apply_fn``.

    ``rnd.apply_fn`` is called with ``rnd.params``, ``state`` and ``action``
    and must return a pair of arrays; the element-wise squared difference of
    that pair is summed over axis 1.

    Args:
        rnd: Train state providing ``apply_fn`` and ``params``.
        state: Observations forwarded to ``apply_fn``.
        action: Actions forwarded to ``apply_fn``.

    Returns:
        The squared error reduced over axis 1.
    """
    # NOTE(review): the axis=1 reduction presumably means both outputs are
    # (batch, embedding_dim), giving one bonus per sample -- confirm.
    prediction, reference = rnd.apply_fn(rnd.params, state, action)
    error = prediction - reference
    return jnp.sum(error * error, axis=1)


@jax.jit
def update_rnd(
    rng: jax.random.PRNGKey,
    rnd: RNDTrainState,
    batch: Batch,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on the RND state and gather logging metrics.

    The loss is the squared difference between the two outputs of
    ``rnd.apply_fn`` on ``(batch.observations, batch.actions)``, summed over
    axis 1 and averaged over the remaining axis. The per-sample (pre-mean)
    errors are also passed to ``rnd.rms.update`` and the object it returns is
    stored on the new state.

    Args:
        rng: PRNG key; only used to sample the random actions that feed the
            ``"rnd_random"`` metric.
        rnd: Current state (``apply_fn``, ``params``, ``rms``).
        batch: Batch providing ``observations`` and ``actions``.

    Returns:
        A ``(new_rnd, info)`` pair. ``new_rnd`` is ``rnd`` after
        ``apply_gradients`` with ``rms`` replaced by the updated statistics.
        ``info`` holds the scalars ``"rnd_loss"``, ``"rnd_rms"``,
        ``"rnd_data"`` and ``"rnd_random"``.
    """

    def rnd_loss_fn(params: Params) -> Tuple[jax.Array, object]:
        pred, target = rnd.apply_fn(params, batch.observations, batch.actions)
        per_sample_error = jnp.sum((pred - target) ** 2, axis=1)
        # The statistics see the per-sample errors, not their mean. They are
        # returned as auxiliary output, so they are not differentiated.
        new_rms = rnd.rms.update(per_sample_error)
        return per_sample_error.mean(), new_rms

    (rnd_loss, new_rms), grads = jax.value_and_grad(rnd_loss_fn, has_aux=True)(
        rnd.params
    )
    new_rnd = rnd.apply_gradients(grads=grads).replace(rms=new_rms)

    # Log the RND bonus for random actions. It is evaluated with the
    # pre-update parameters (rnd.params), not those of new_rnd.
    # Only the second sub-key is consumed; the first was never used.
    _, random_key = jax.random.split(rng, 2)
    # NOTE(review): sampling from [-1, 1) assumes actions are normalised to
    # that range -- confirm against the environment wrapper.
    random_actions = jax.random.uniform(
        random_key, shape=batch.actions.shape, minval=-1.0, maxval=1.0
    )
    pred, target = rnd.apply_fn(rnd.params, batch.observations, random_actions)
    random_bonus = jnp.sum((pred - target) ** 2, axis=1).mean()

    # NOTE(review): "rnd_data" is divided by the *pre-update* rms.std while
    # "rnd_random" is left unnormalised, so the two are not on the same scale
    # -- confirm this is intended before comparing them in dashboards.
    info = {
        "rnd_loss": rnd_loss,
        "rnd_rms": new_rnd.rms.std,
        "rnd_data": rnd_loss / rnd.rms.std,
        "rnd_random": random_bonus,
    }
    return new_rnd, info
