import functools
from typing import Tuple
import jax
import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params, PRNGKey


def loss(diff, expectile=0.8):
    """Return the element-wise asymmetric (expectile-weighted) squared error.

    Strictly positive entries of ``diff`` are scaled by ``expectile``; all
    other entries (zero or negative) are scaled by ``1 - expectile``. No
    reduction is applied, so the result has the same shape as ``diff``.
    """
    squared = diff**2
    # Pick the per-element weight from the sign of the residual.
    scale = jnp.where(diff > 0, expectile, (1 - expectile))
    return scale * squared


# @functools.partial(
#     jax.jit, static_argnames=("critic_ensemble_size", "redq_subset_size")
# )
def update_v(
    key: PRNGKey,
    critic: Model,
    value: Model,
    batch: Batch,
    expectile: float,
    critic_ensemble_size: int,
    redq_subset_size: int = 2,
) -> Tuple[Model, InfoDict]:
    """Update the value model towards a pessimistic Q estimate.

    A random subset of ``redq_subset_size`` distinct critic indices is drawn
    from ``range(critic_ensemble_size)``; the regression target is the
    minimum over that subset of the critic outputs for the batch's
    observations and actions. The value model is then fit to this target
    with the expectile-weighted squared error from ``loss``.

    Args:
        key: PRNG key used to sample the critic subset.
        critic: Critic ensemble, called as ``critic(observations, actions)``.
        value: Value model being updated.
        batch: Batch providing ``observations`` and ``actions``.
        expectile: Weight applied to positive ``q - v`` residuals.
        critic_ensemble_size: Number of critics to sample indices from.
        redq_subset_size: How many critics to take the minimum over.

    Returns:
        The pair produced by ``value.apply_gradient``: the updated value
        model and an info dict with ``"value_loss"`` and ``"v"``.
    """
    # Sample without replacement so the subset holds distinct members.
    subset = jax.random.choice(
        key, critic_ensemble_size, shape=(redq_subset_size,), replace=False
    )
    # NOTE(review): indexing with `subset` and reducing over axis 0 assumes
    # the critic output is an array with ensemble members along the leading
    # axis -- confirm against the critic definition.
    ensemble_qs = critic(batch.observations, batch.actions)
    q = jnp.min(ensemble_qs[subset], axis=0)

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v_pred = value.apply({"params": value_params}, batch.observations)
        mean_loss = loss(q - v_pred, expectile).mean()
        metrics = {"value_loss": mean_loss, "v": v_pred.mean()}
        return mean_loss, metrics

    updated_value, metrics = value.apply_gradient(value_loss_fn)
    return updated_value, metrics


def update_q(
    critic: Model, target_value: Model, batch: Batch, discount: float
) -> Tuple[Model, InfoDict]:
    """Update the critic ensemble towards a one-step bootstrapped target.

    The target is ``rewards + discount * masks * V_target(next_observations)``
    and the loss is the squared error to that target, summed over ensemble
    members and averaged over the remaining axes.

    Args:
        critic: Critic ensemble being updated.
        target_value: Value model evaluated on ``batch.next_observations``.
        batch: Batch providing ``observations``, ``actions``, ``rewards``,
            ``masks`` and ``next_observations``.
        discount: Factor multiplying the bootstrapped next-state value.

    Returns:
        The pair produced by ``critic.apply_gradient``: the updated critic
        and an info dict with ``"critic_loss"``, ``"q1"`` and ``"q2"``.
    """
    bootstrap_v = target_value(batch.next_observations)
    target_q = batch.rewards + discount * batch.masks * bootstrap_v

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        qs = critic.apply({"params": critic_params}, batch.observations, batch.actions)
        # One squared-error term per ensemble member, all against one target.
        member_errors = [(member_q - target_q) ** 2 for member_q in qs]
        critic_loss = jnp.stack(member_errors).sum(axis=0).mean()
        # Only the first two members are logged; this requires the ensemble
        # to contain at least two critics.
        metrics = {
            "critic_loss": critic_loss,
            "q1": qs[0].mean(),
            "q2": qs[1].mean(),
        }
        return critic_loss, metrics

    updated_critic, metrics = critic.apply_gradient(critic_loss_fn)
    return updated_critic, metrics
