from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch


@functools.partial(jax.jit, static_argnames=["soft_critic", "k_samples"])
def update(
    rng: jax.random.PRNGKey,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    soft_critic: bool,
    k_samples: int = 1,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on a two-head critic using a min-of-heads TD target.

    Each critic head is evaluated ``k_samples`` times with dropout enabled
    (``training=True``) and the per-sample values are averaged over the
    samples. The TD target is
    ``rewards + discount * masks * min(mean_q1', mean_q2')``, computed from the
    target critic at the next observation and a freshly sampled next action.
    When ``soft_critic`` is set, it additionally subtracts
    ``discount * masks * temp * log_prob(next_action)``.

    Args:
        rng: PRNG key. It is split internally and NOT returned, so callers
            should pass a fresh key on every call.
        actor: Policy state. ``actor.apply_fn`` must return an object exposing
            ``sample(seed=...)`` and ``log_prob``. It is only read, never updated.
        critic: Online critic state whose ``apply_fn`` returns two Q heads.
            This is the only state that receives a gradient step.
        target_critic: Critic state (two heads) used to build the TD target.
            It is only read.
        temp: State whose ``apply_fn(params)`` value scales the log-prob term.
            It is only read when ``soft_critic`` is true.
        batch: Transitions providing ``observations``, ``actions``,
            ``rewards``, ``masks`` and ``next_observations``.
        discount: Discount factor applied to the bootstrapped value.
        soft_critic: Static. Whether to subtract the temperature-scaled
            log-prob term from the target.
        k_samples: Static. Number of dropout forward passes averaged per head.
            The Python loops over it are unrolled at trace time.

    Returns:
        ``(new_critic, info)``. ``info`` holds the scalar critic loss, the
        batch means of the sample-averaged Q heads (``q1``, ``q2``), and the
        batch means of their across-sample variances (``q1_var``, ``q2_var``).
    """
    # The first output of the split is unused; the key is not returned.
    _, actor_key, q_key, next_q_key = jax.random.split(rng, 4)

    def sample_q_heads(apply_fn, params, observations, actions, key):
        """Run ``k_samples`` dropout-enabled passes and stack each head to (batch, k).

        Each head output is assumed to be 1-D over the batch (TODO confirm
        against the critic module).
        """
        passes = []
        for _ in range(k_samples):
            # Chain splits off ``key`` so every pass gets its own dropout mask.
            key, dropout_key = jax.random.split(key, 2)
            passes.append(
                apply_fn(
                    params,
                    observations,
                    actions,
                    rngs={"dropout": dropout_key},
                    training=True,
                )
            )
        # ``passes`` is k tuples of heads. Regroup per head, then stack the
        # samples along axis 1. Stacking keeps the critic's own output dtype
        # instead of forcing the default dtype of a preallocated jnp.zeros buffer.
        return tuple(jnp.stack(head, axis=1) for head in zip(*passes))

    # Next action and its log-probability from the current policy.
    dist = actor.apply_fn(actor.params, batch.next_observations)
    next_actions = dist.sample(seed=actor_key)
    next_log_probs = dist.log_prob(next_actions)

    # Average each target head over dropout samples, then take the minimum.
    next_q1s, next_q2s = sample_q_heads(
        target_critic.apply_fn,
        target_critic.params,
        batch.next_observations,
        next_actions,
        next_q_key,
    )
    next_q = jnp.minimum(next_q1s.mean(axis=1), next_q2s.mean(axis=1))

    target_q = batch.rewards + discount * batch.masks * next_q
    if soft_critic:
        target_q -= (
            discount
            * batch.masks
            * temp.apply_fn(temp.params)
            * next_log_probs
        )

    # target_q does not depend on critic_params, so jax.grad treats it as a
    # constant. No explicit stop_gradient is needed.
    def critic_loss_fn(critic_params: Params) -> Tuple[jax.Array, InfoDict]:
        q1s, q2s = sample_q_heads(
            critic.apply_fn,
            critic_params,
            batch.observations,
            batch.actions,
            q_key,
        )
        q1 = q1s.mean(axis=1)
        q2 = q2s.mean(axis=1)

        critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
        info = {
            "critic_loss": critic_loss,
            "q1": q1.mean(),
            "q2": q2.mean(),
            "q1_var": q1s.var(axis=1).mean(),
            "q2_var": q2s.var(axis=1).mean(),
        }
        return critic_loss, info

    grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params)
    new_critic = critic.apply_gradients(grads=grads)
    return new_critic, info


@functools.partial(jax.jit, static_argnames=["soft_critic", "k_samples"])
def update_max(
    rng: jax.random.PRNGKey,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    soft_critic: bool,
    k_samples: int = 1,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on a three-head critic using a max-of-heads TD target.

    Mirrors ``update`` with two differences. The critics return three Q heads.
    The bootstrapped value is the element-wise MAXIMUM of the three
    sample-averaged target heads rather than a minimum, which presumably makes
    this an optimistic "explore" critic, judging from the ``_explore`` suffix
    on every info key. Each head is evaluated ``k_samples`` times with dropout
    enabled (``training=True``) and averaged over the samples. When
    ``soft_critic`` is set, the target additionally subtracts
    ``discount * masks * temp * log_prob(next_action)``.

    Args:
        rng: PRNG key. It is split internally and NOT returned, so callers
            should pass a fresh key on every call.
        actor: Policy state. ``actor.apply_fn`` must return an object exposing
            ``sample(seed=...)`` and ``log_prob``. It is only read, never updated.
        critic: Online critic state whose ``apply_fn`` returns three Q heads.
            This is the only state that receives a gradient step.
        target_critic: Critic state (three heads) used to build the TD target.
            It is only read.
        temp: State whose ``apply_fn(params)`` value scales the log-prob term.
            It is only read when ``soft_critic`` is true.
        batch: Transitions providing ``observations``, ``actions``,
            ``rewards``, ``masks`` and ``next_observations``.
        discount: Discount factor applied to the bootstrapped value.
        soft_critic: Static. Whether to subtract the temperature-scaled
            log-prob term from the target.
        k_samples: Static. Number of dropout forward passes averaged per head.
            The Python loops over it are unrolled at trace time.

    Returns:
        ``(new_critic, info)``. ``info`` holds ``critic_loss_explore`` plus,
        per head, the batch mean of the sample-averaged Q (``q{i}_explore``)
        and of its across-sample variance (``q{i}_var_explore``).
    """
    # The first output of the split is unused; the key is not returned.
    _, actor_key, q_key, next_q_key = jax.random.split(rng, 4)

    def sample_q_heads(apply_fn, params, observations, actions, key):
        """Run ``k_samples`` dropout-enabled passes and stack each head to (batch, k).

        Each head output is assumed to be 1-D over the batch (TODO confirm
        against the critic module).
        """
        passes = []
        for _ in range(k_samples):
            # Chain splits off ``key`` so every pass gets its own dropout mask.
            key, dropout_key = jax.random.split(key, 2)
            passes.append(
                apply_fn(
                    params,
                    observations,
                    actions,
                    rngs={"dropout": dropout_key},
                    training=True,
                )
            )
        # ``passes`` is k tuples of heads. Regroup per head, then stack the
        # samples along axis 1. Stacking keeps the critic's own output dtype
        # instead of forcing the default dtype of a preallocated jnp.zeros buffer.
        return tuple(jnp.stack(head, axis=1) for head in zip(*passes))

    # Next action and its log-probability from the current policy.
    dist = actor.apply_fn(actor.params, batch.next_observations)
    next_actions = dist.sample(seed=actor_key)
    next_log_probs = dist.log_prob(next_actions)

    # Average each target head over dropout samples, then take the maximum
    # across the three heads.
    next_q1s, next_q2s, next_q3s = sample_q_heads(
        target_critic.apply_fn,
        target_critic.params,
        batch.next_observations,
        next_actions,
        next_q_key,
    )
    next_q = jnp.vstack(
        [
            next_q1s.mean(axis=1),
            next_q2s.mean(axis=1),
            next_q3s.mean(axis=1),
        ]
    ).max(axis=0)

    target_q = batch.rewards + discount * batch.masks * next_q
    if soft_critic:
        target_q -= (
            discount
            * batch.masks
            * temp.apply_fn(temp.params)
            * next_log_probs
        )

    # target_q does not depend on critic_params, so jax.grad treats it as a
    # constant. No explicit stop_gradient is needed.
    def critic_loss_fn(critic_params: Params) -> Tuple[jax.Array, InfoDict]:
        q1s, q2s, q3s = sample_q_heads(
            critic.apply_fn,
            critic_params,
            batch.observations,
            batch.actions,
            q_key,
        )
        q1 = q1s.mean(axis=1)
        q2 = q2s.mean(axis=1)
        q3 = q3s.mean(axis=1)

        critic_loss = (
            (q1 - target_q) ** 2 + (q2 - target_q) ** 2 + (q3 - target_q) ** 2
        ).mean()
        info = {
            "critic_loss_explore": critic_loss,
            "q1_explore": q1.mean(),
            "q2_explore": q2.mean(),
            "q3_explore": q3.mean(),
            "q1_var_explore": q1s.var(axis=1).mean(),
            "q2_var_explore": q2s.var(axis=1).mean(),
            "q3_var_explore": q3s.var(axis=1).mean(),
        }
        return critic_loss, info

    grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params)
    new_critic = critic.apply_gradients(grads=grads)
    return new_critic, info
