from typing import Tuple
import functools
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

from src.common import InfoDict, Params, Batch


@functools.partial(jax.jit, static_argnames=["soft_critic", "k_samples"])
def update(
    rng: jax.random.PRNGKey,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    soft_critic: bool,
    k_samples: int = 1,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on a two-headed critic.

    The regression target is ``rewards + discount * masks * next_q``. Here
    ``next_q`` is the element-wise minimum of the two ``target_critic`` heads,
    each averaged over ``k_samples`` forward passes evaluated at next actions
    sampled from ``actor``. Every pass is called with ``training=True`` and
    its own ``"dropout"`` RNG. With ``soft_critic`` set, the term
    ``discount * masks * temp.apply_fn(temp.params) * next_log_probs`` is
    additionally subtracted from the target.

    Args:
        rng: PRNG key; split into the action-sampling key and the two
            dropout key chains (online and target critic).
        actor: State whose ``apply_fn(params, observations)`` returns an
            object offering ``sample(seed=...)`` and ``log_prob(...)``.
        critic: State being optimised; ``apply_fn`` must return two values.
        target_critic: State used only to build the target.
        temp: State whose ``apply_fn(params)`` gives the multiplier of the
            log-probability term (read only when ``soft_critic`` is true).
        batch: Provides ``observations``, ``actions``, ``rewards``,
            ``masks`` and ``next_observations``; the leading dimension of
            ``rewards`` is taken as the batch size.
        discount: Factor applied to the bootstrapped value.
        soft_critic: Static flag enabling the log-probability term.
        k_samples: Static number of forward passes per critic.

    Returns:
        The result of ``critic.apply_gradients`` and a dict with the keys
        ``critic_loss``, ``q1``, ``q2``, ``q1_var`` and ``q2_var``.
    """
    # The four-way split is kept (first key discarded) so that the three
    # keys actually used are the same ones a four-way split has always given.
    _, sample_key, online_key, target_key = jax.random.split(rng, 4)
    n_rows = batch.rewards.shape[0]

    def _two_head_passes(state, params, observations, actions, key):
        # Column j of each returned (n_rows, k_samples) buffer holds one
        # head's output from the j-th forward pass.
        first = jnp.zeros((n_rows, k_samples))
        second = jnp.zeros((n_rows, k_samples))
        for col in range(k_samples):
            key, dropout_key = jax.random.split(key, 2)
            head_a, head_b = state.apply_fn(
                params,
                observations,
                actions,
                rngs={"dropout": dropout_key},
                training=True,
            )
            first = first.at[:, col].set(head_a)
            second = second.at[:, col].set(head_b)
        return first, second

    # Sample the next action and its log-probability from the actor.
    policy = actor.apply_fn(actor.params, batch.next_observations)
    next_actions = policy.sample(seed=sample_key)
    next_log_probs = policy.log_prob(next_actions)

    # Target-critic estimate at the next state: average over passes first,
    # then take the smaller of the two heads.
    target_a, target_b = _two_head_passes(
        target_critic,
        target_critic.params,
        batch.next_observations,
        next_actions,
        target_key,
    )
    next_q = jnp.minimum(target_a.mean(axis=1), target_b.mean(axis=1))

    bootstrap = discount * batch.masks * next_q
    target_q = batch.rewards + bootstrap
    if soft_critic:
        log_prob_term = (
            discount * batch.masks * temp.apply_fn(temp.params) * next_log_probs
        )
        target_q = target_q - log_prob_term

    def critic_loss_fn(critic_params: Params) -> Tuple[jax.Array, InfoDict]:
        samples_a, samples_b = _two_head_passes(
            critic,
            critic_params,
            batch.observations,
            batch.actions,
            online_key,
        )
        q1 = samples_a.mean(axis=1)
        q2 = samples_b.mean(axis=1)

        per_row = (q1 - target_q) ** 2 + (q2 - target_q) ** 2
        critic_loss = per_row.mean()

        stats = {
            "critic_loss": critic_loss,
            "q1": q1.mean(),
            "q2": q2.mean(),
            # Spread across the forward passes, averaged over the batch.
            "q1_var": samples_a.var(axis=1).mean(),
            "q2_var": samples_b.var(axis=1).mean(),
        }
        return critic_loss, stats

    grad_fn = jax.grad(critic_loss_fn, has_aux=True)
    grads, info = grad_fn(critic.params)
    return critic.apply_gradients(grads=grads), info


@functools.partial(jax.jit, static_argnames=["soft_critic", "k_samples"])
def update_max(
    rng: jax.random.PRNGKey,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    soft_critic: bool,
    k_samples: int = 1,
) -> Tuple[TrainState, InfoDict]:
    """Take one gradient step on a multi-head critic using a max-over-heads target.

    The regression target is ``rewards + discount * masks * next_q``, where
    ``next_q`` is the element-wise maximum over all ``target_critic`` heads,
    each head first averaged over ``k_samples`` forward passes evaluated at
    next actions sampled from ``actor``. Every pass is called with
    ``training=True`` and its own ``"dropout"`` RNG. With ``soft_critic``
    set, ``discount * masks * temp.apply_fn(temp.params) * next_log_probs``
    is additionally subtracted from the target.

    The number of heads is whatever the critic's ``apply_fn`` returns (it
    used to be hard-coded to ten). Every head is regressed onto the same
    target and the per-head squared errors are summed.

    Args:
        rng: PRNG key; split into the action-sampling key and the two
            dropout key chains (online and target critic).
        actor: State whose ``apply_fn(params, observations)`` returns an
            object offering ``sample(seed=...)`` and ``log_prob(...)``.
        critic: State being optimised; ``apply_fn`` must return an iterable
            of per-head values, each broadcastable to ``(batch_size,)``.
        target_critic: State used only to build the target.
        temp: State whose ``apply_fn(params)`` gives the multiplier of the
            log-probability term (read only when ``soft_critic`` is true).
        batch: Provides ``observations``, ``actions``, ``rewards``,
            ``masks`` and ``next_observations``; the leading dimension of
            ``rewards`` is taken as the batch size.
        discount: Factor applied to the bootstrapped value.
        soft_critic: Static flag enabling the log-probability term.
        k_samples: Static number of forward passes per critic (>= 1).

    Returns:
        The result of ``critic.apply_gradients`` and a dict with the keys
        ``critic_loss_explore`` plus, for each head ``i`` counted from 1,
        ``q{i}_explore`` and ``q{i}_var_explore``.

    Raises:
        ValueError: If ``k_samples`` is smaller than 1 (raised while tracing;
            averaging over zero passes would yield NaN targets).
    """
    if k_samples < 1:
        raise ValueError(f"k_samples must be >= 1, got {k_samples}")

    # The four-way split is kept (first key discarded) so that the three
    # keys actually used are the same ones a four-way split has always given.
    _, actor_key, q_key, next_q_key = jax.random.split(rng, 4)
    batch_size = batch.rewards.shape[0]
    # Samples are held in the default float dtype, which is what the
    # zero-initialised buffers of the fixed-size implementation used.
    sample_dtype = jnp.result_type(float)

    def _sample_heads(state, params, observations, actions, key):
        """Return head outputs stacked as (k_samples, n_heads, batch_size)."""
        passes = []
        for _ in range(k_samples):
            key, dropout_key = jax.random.split(key, 2)
            heads = state.apply_fn(
                params,
                observations,
                actions,
                rngs={"dropout": dropout_key},
                training=True,
            )
            # broadcast_to doubles as a shape check: a head shaped
            # (batch_size, 1) fails here instead of silently broadcasting
            # against the (batch_size,) target later on.
            passes.append(
                jnp.stack(
                    [jnp.broadcast_to(head, (batch_size,)) for head in heads]
                )
            )
        return jnp.stack(passes).astype(sample_dtype)

    # Sample the next action and its log-probability from the actor.
    dist = actor.apply_fn(actor.params, batch.next_observations)
    next_actions = dist.sample(seed=actor_key)
    next_log_probs = dist.log_prob(next_actions)

    # Average each target head over the forward passes, then keep the
    # largest head per batch row.
    next_samples = _sample_heads(
        target_critic,
        target_critic.params,
        batch.next_observations,
        next_actions,
        next_q_key,
    )
    next_q = next_samples.mean(axis=0).max(axis=0)

    target_q = batch.rewards + discount * batch.masks * next_q
    if soft_critic:
        target_q = target_q - (
            discount
            * batch.masks
            * temp.apply_fn(temp.params)
            * next_log_probs
        )

    def critic_loss_fn(critic_params: Params) -> Tuple[jax.Array, InfoDict]:
        samples = _sample_heads(
            critic, critic_params, batch.observations, batch.actions, q_key
        )
        head_means = samples.mean(axis=0)  # (n_heads, batch_size)
        head_vars = samples.var(axis=0)  # spread across forward passes
        n_heads = head_means.shape[0]

        # Sum the squared errors over heads, then average over the batch.
        critic_loss = ((head_means - target_q) ** 2).sum(axis=0).mean()

        info = {"critic_loss_explore": critic_loss}
        info.update(
            {f"q{i + 1}_explore": head_means[i].mean() for i in range(n_heads)}
        )
        info.update(
            {
                f"q{i + 1}_var_explore": head_vars[i].mean()
                for i in range(n_heads)
            }
        )
        return critic_loss, info

    grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params)
    new_critic = critic.apply_gradients(grads=grads)

    return new_critic, info
