from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch


@functools.partial(jax.jit, static_argnames=["k_samples", "optimistic"])
def update(
    rng: jax.Array,
    actor: TrainState,
    critic: TrainState,
    batch: Batch,
    k_samples: int = 1,
    optimistic: bool = False,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on the actor, maximising the critic's first Q head.

    The critic is evaluated ``k_samples`` times on the actor's actions, each
    time with a fresh ``"dropout"`` RNG and ``training=True``, so the samples
    can differ. The actor loss is the negated mean of the first Q output over
    the batch and all samples. Only the actor parameters are differentiated;
    ``critic.params`` are treated as constants.

    Args:
        rng: PRNG key used to derive the per-sample dropout keys.
        actor: Actor train state; ``apply_fn(params, observations)`` is
            expected to return actions.
        critic: Critic train state; ``apply_fn`` is expected to return an
            indexable sequence of Q outputs (at least 2 heads, or 4 when
            ``optimistic``). Each head is presumably shaped ``(batch,)``
            (TODO confirm against the critic module).
        batch: Transition batch; ``observations`` and ``rewards`` are read.
        k_samples: Number of stochastic critic evaluations (static under jit).
            Must be >= 1.
        optimistic: If True, heads 3 and 4 are also recorded for variance
            logging (static under jit). Otherwise their reported variances
            are 0.

    Returns:
        The updated actor state and a dict with ``actor_loss`` and the
        mean per-state variance across samples for each Q head
        (``actor_q1_var`` .. ``actor_q4_var``). Variances are 0 when
        ``k_samples == 1``.

    Raises:
        ValueError: If ``k_samples < 1``. Raised at trace time because
            ``k_samples`` is static.
    """
    if k_samples < 1:
        # With zero samples the mean/var over an empty axis is NaN, which
        # would silently poison the actor parameters.
        raise ValueError(f"k_samples must be >= 1, got {k_samples}")

    # Only the dropout key is consumed here. The other half of the split is
    # discarded because no RNG is handed back to the caller.
    _, q_key = jax.random.split(rng, 2)

    def actor_loss_fn(actor_params: Params) -> Tuple[jax.Array, InfoDict]:
        pi_actions = actor.apply_fn(actor_params, batch.observations)

        q1_samples, q2_samples, q3_samples, q4_samples = [], [], [], []
        q_rng = q_key
        for _ in range(k_samples):
            # Same sequential split chain as before, so sample i always sees
            # the same dropout key for a given input rng.
            q_rng, dropout_key = jax.random.split(q_rng, 2)
            qs = critic.apply_fn(
                critic.params,
                batch.observations,
                pi_actions,
                rngs={"dropout": dropout_key},
                training=True,
            )
            q1_samples.append(qs[0])
            q2_samples.append(qs[1])
            if optimistic:
                q3_samples.append(qs[2])
                q4_samples.append(qs[3])

        # Stack samples along axis 1 -> (batch, k_samples).
        q1s = jnp.stack(q1_samples, axis=1)
        q2s = jnp.stack(q2_samples, axis=1)
        if optimistic:
            q3s = jnp.stack(q3_samples, axis=1)
            q4s = jnp.stack(q4_samples, axis=1)
        else:
            # Keep the info dict keys stable. Unused heads report zero variance.
            q3s = jnp.zeros_like(q1s)
            q4s = jnp.zeros_like(q1s)

        actor_loss = -q1s.mean()
        info = {
            "actor_loss": actor_loss,
            "actor_q1_var": q1s.var(axis=1).mean(),
            "actor_q2_var": q2s.var(axis=1).mean(),
            "actor_q3_var": q3s.var(axis=1).mean(),
            "actor_q4_var": q4s.var(axis=1).mean(),
        }
        return actor_loss, info

    grads, info = jax.grad(actor_loss_fn, has_aux=True)(actor.params)
    new_actor = actor.apply_gradients(grads=grads)

    return new_actor, info
