from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch
from src.agents.sac_derl.rnd_states import RNDTrainState


@jax.jit
def update(
    rng: jax.Array,
    actor: TrainState,
    critic: TrainState,
    temp: TrainState,
    rnd: RNDTrainState,
    batch: Batch,
    rnd_coeff: float = 0.0,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on the actor with a SAC-style policy loss.

    The loss is ``mean(alpha * log_pi(a|s) - min(Q1(s, a), Q2(s, a)))``, where
    ``a`` is sampled from the current policy at ``batch.observations`` and
    ``alpha = temp.apply_fn(temp.params)``. Only the actor parameters are
    differentiated; critic and temperature parameters are held fixed.

    Args:
        rng: PRNG key. It is split once and the second sub-key is used to
            sample policy actions. The first sub-key is discarded and no key
            is returned, so callers must advance their own key between calls.
        actor: Policy train state. ``actor.apply_fn(params, obs)`` must return
            a distribution exposing ``sample(seed=...)`` and ``log_prob``.
        critic: Twin-Q train state; ``critic.apply_fn(params, obs, actions)``
            must return a pair ``(q1, q2)``.
        temp: Train state whose ``apply_fn(params)`` yields the entropy
            coefficient (presumably a scalar -- confirm against its module).
        rnd: Random-network-distillation state. Currently unused: the RND
            exploration bonus is disabled (see the TODO in the body).
        batch: Transition batch; only ``batch.observations`` is read.
        rnd_coeff: Weight of the RND bonus. Currently unused.

    Returns:
        ``(new_actor, info)``. ``new_actor`` is ``actor`` after
        ``apply_gradients``. ``info`` holds ``"actor_loss"`` and
        ``"actor_entropy"``, the negative mean log-probability of the sampled
        actions.
    """
    # Only the second sub-key is consumed; the first is intentionally dropped
    # because this function does not hand an updated key back to the caller.
    _, actor_key = jax.random.split(rng, 2)

    # The temperature does not depend on the actor params, so compute it once
    # outside the differentiated closure; it receives no gradient either way.
    alpha = temp.apply_fn(temp.params)

    # TODO: the RND exploration bonus is disabled. The drafted version computed
    # `pred, target = rnd.apply_fn(rnd.params, observations, pi_actions)`,
    # `bonus = sum((pred - target) ** 2, axis=1)`, subtracted `rnd_coeff * bonus`
    # from the per-sample loss and logged its mean as "actor_rnd_bonus".
    # Until it is re-enabled, `rnd` and `rnd_coeff` are ignored.

    def actor_loss_fn(actor_params: Params) -> Tuple[jax.Array, InfoDict]:
        dist = actor.apply_fn(actor_params, batch.observations)
        pi_actions = dist.sample(seed=actor_key)
        log_probs = dist.log_prob(pi_actions)

        # Clipped double-Q: use the pessimistic estimate of the two critics.
        q1, q2 = critic.apply_fn(critic.params, batch.observations, pi_actions)
        q = jnp.minimum(q1, q2)

        actor_loss = (alpha * log_probs - q).mean()
        info = {
            "actor_loss": actor_loss,
            "actor_entropy": -log_probs.mean(),
        }
        return actor_loss, info

    grads, info = jax.grad(actor_loss_fn, has_aux=True)(actor.params)
    new_actor = actor.apply_gradients(grads=grads)
    return new_actor, info
