from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch


@functools.partial(jax.jit, static_argnames=["soft_critic"])
def update(
    rng: jax.Array,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    soft_critic: bool,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on the twin critic towards a clipped double-Q target.

    The bootstrap target is built from the target critic evaluated at actions
    sampled from the current actor on ``batch.next_observations``::

        target = r + discount * mask * (min(Q1', Q2') - [alpha * log pi(a'|s')])

    The bracketed entropy term is included only when ``soft_critic`` is true.
    Only ``critic`` is updated. ``actor``, ``target_critic`` and ``temp`` are
    only read here.

    Args:
        rng: PRNG key used to sample next actions from the actor. It is split
            internally and no key is returned, so the caller must advance its
            own key between calls.
        actor: Policy state. ``apply_fn(params, obs)`` must return a
            distribution supporting ``sample(seed=...)`` and ``log_prob``.
        critic: Twin-Q critic being trained. ``apply_fn(params, obs, act)``
            returns a pair ``(q1, q2)``.
        target_critic: Critic copy used only to compute the bootstrap target.
        temp: Temperature state. ``apply_fn(params)`` is assumed to return the
            entropy coefficient alpha (TODO confirm it is already positive,
            e.g. exp of a log-temperature).
        batch: Transitions exposing ``observations``, ``actions``,
            ``rewards``, ``masks`` and ``next_observations``.
        discount: Discount factor gamma applied to the bootstrap term.
        soft_critic: Static (compile-time) flag. When true, the entropy term
            is subtracted from the next-state value.

    Returns:
        ``(new_critic, info)``. ``info`` holds the mean critic loss and the
        mean predicted Q-value of each head.
    """
    # Only the second half of the split is needed; no key is handed back.
    _, actor_key = jax.random.split(rng)

    # Sample next actions from the current policy.
    next_dist = actor.apply_fn(actor.params, batch.next_observations)
    next_actions = next_dist.sample(seed=actor_key)

    next_q1, next_q2 = target_critic.apply_fn(
        target_critic.params,
        batch.next_observations,
        next_actions,
    )
    # Clipped double-Q: taking the smaller head curbs overestimation.
    next_v = jnp.minimum(next_q1, next_q2)

    if soft_critic:
        # `soft_critic` is static, so this branch is resolved at trace time.
        next_log_probs = next_dist.log_prob(next_actions)
        next_v = next_v - temp.apply_fn(temp.params) * next_log_probs

    # `masks` gate the bootstrap term (presumably 0 at terminal transitions
    # and 1 otherwise; verify against the replay buffer).
    target_q = batch.rewards + discount * batch.masks * next_v

    def critic_loss_fn(critic_params: Params) -> Tuple[jax.Array, InfoDict]:
        # `target_q` is closed over and does not depend on `critic_params`,
        # so no gradient flows through the target.
        q1, q2 = critic.apply_fn(
            critic_params,
            batch.observations,
            batch.actions,
        )
        critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
        info = {
            "critic_loss": critic_loss,
            "q1": q1.mean(),
            "q2": q2.mean(),
        }
        return critic_loss, info

    grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params)
    new_critic = critic.apply_gradients(grads=grads)

    return new_critic, info
