from typing import Tuple

import jax.numpy as jnp
import jax

from common import Batch, InfoDict, Model, Params
from vae_utils import compute_vae
import policy


def loss(diff, use_bellman, expectile=0.8, bellman_expectile=0.7):
    """Element-wise asymmetric (expectile) squared loss with a per-element tau.

    Each element is weighted ``tau`` when ``diff > 0`` and ``1 - tau`` otherwise,
    then multiplied by ``diff ** 2``. The expectile ``tau`` is chosen per element:
    ``bellman_expectile`` where ``use_bellman`` is true, ``expectile`` elsewhere.

    Args:
        diff: residuals (target - prediction); any array shape.
        use_bellman: boolean mask broadcastable against ``diff`` selecting which
            expectile applies to each element.
        expectile: tau used where ``use_bellman`` is false.
        bellman_expectile: tau used where ``use_bellman`` is true.

    Returns:
        The un-reduced loss, with the broadcast shape of ``diff`` and ``use_bellman``.
    """
    positive = diff > 0
    squared = diff ** 2

    def _asymmetric(tau):
        # ``1 - tau`` is evaluated on the raw (Python) tau, exactly as before.
        return jnp.where(positive, tau, 1 - tau) * squared

    return jnp.where(use_bellman,
                     _asymmetric(bellman_expectile),
                     _asymmetric(expectile))


def update_v(key, actor:Model, critic: Model, value: Model, vae_params: Tuple,
             batch: Batch,
             expectile: float, bernoulli_p: float, bellman_expectile: float,
             discount: float, temperature: float
             ) -> Tuple[Model, InfoDict, jnp.ndarray]:
    """Update the value network towards a per-element mix of two Q targets.

    For every element of the critic output, a Bernoulli(``bernoulli_p``) draw
    picks the regression target:

    * true  -> min of the twin critics at an action sampled from ``actor``
               (the "bellman" target), fitted with ``bellman_expectile``;
    * false -> min of the twin critics at the dataset action, fitted with
               ``expectile``.

    Args:
        key: PRNG key; split into an actor-sampling key and a mask key.
        actor: policy model passed to ``policy.sample_actions``.
        critic: twin-Q model; ``critic(obs, act)`` returns ``(q1, q2)``.
        value: value model updated via ``value.apply_gradient``.
        vae_params: unused here; kept for interface compatibility.
        batch: transitions; ``observations`` and ``actions`` are read.
        expectile: tau for elements using the dataset-action target.
        bernoulli_p: probability that an element uses the policy-action target.
        bellman_expectile: tau for elements using the policy-action target.
        discount: unused here; kept for interface compatibility.
        temperature: unused here (the actor is sampled at a fixed 0.01);
            kept for interface compatibility.

    Returns:
        ``(new_value, info, bellman_cnt)``: the updated value model, the info
        dict from the gradient step ('value_loss', 'v'), and the number of
        elements that used the policy-action target.
    """
    # Split into three (same as before) so the PRNG streams are unchanged;
    # the leading carry key is not needed inside this function.
    _, actor_rng, bernoulli_rng = jax.random.split(key, 3)
    observations = batch.observations
    actions = batch.actions

    # Pessimistic Q (min over twin critics) at the dataset actions.
    q1, q2 = critic(observations, actions)
    q = jnp.minimum(q1, q2)

    # Pessimistic Q at actions sampled from the current actor. The sampling
    # temperature is deliberately hard-coded to 0.01 (presumably near-greedy
    # sampling — TODO confirm); the ``temperature`` argument is not used here.
    _, pred_actions = policy.sample_actions(actor_rng, actor.apply_fn,
                                            actor.params, observations,
                                            temperature=0.01)
    qb1, qb2 = critic(observations, pred_actions)
    q_bellman = jnp.minimum(qb1, qb2)

    # Per-element choice between the two targets.
    use_bellman = jax.random.bernoulli(bernoulli_rng, p=bernoulli_p,
                                       shape=q_bellman.shape)
    q_final = jnp.where(use_bellman, q_bellman, q)
    bellman_cnt = jnp.count_nonzero(use_bellman)

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v = value.apply({'params': value_params}, observations)
        value_loss = loss(q_final - v, use_bellman, expectile,
                          bellman_expectile).mean()
        return value_loss, {
            'value_loss': value_loss,
            'v': v.mean(),
        }

    # NOTE(review): the second positional argument is presumably ``has_aux``
    # (value_loss_fn returns (loss, info)) — confirm against Model.apply_gradient.
    new_value, info = value.apply_gradient(value_loss_fn, True)

    return new_value, info, bellman_cnt


def update_q(key, actor: Model, critic: Model, target_value: Model, batch: Batch,
             discount: float) -> Tuple[Model, InfoDict]:
    """Regress both critics onto a one-step TD target from the target value net.

    ``target_q = rewards + discount * masks * target_value(next_observations)``,
    and the loss is the sum of the two critics' mean squared errors to it.

    Args:
        key: unused; kept for interface compatibility.
        actor: unused; kept for interface compatibility.
        critic: twin-Q model; ``critic.apply(...)`` returns ``(q1, q2)``.
        target_value: value model evaluated on ``batch.next_observations``.
        batch: transitions; reads observations, actions, rewards, masks and
            next_observations. (``masks`` presumably zeroes the bootstrap at
            terminal transitions — confirm against the data pipeline.)
        discount: discount factor applied to the bootstrapped value.

    Returns:
        ``(new_critic, info)`` where ``info`` holds 'critic_loss', 'q1', 'q2'.
    """
    del key, actor  # Not needed for this update.

    # The target does not depend on the critic parameters, so it is built once
    # outside the loss function and treated as a constant by the gradient.
    next_v = target_value(batch.next_observations)
    target_q = batch.rewards + discount * batch.masks * next_v

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        q1, q2 = critic.apply({'params': critic_params}, batch.observations,
                              batch.actions)
        critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
        return critic_loss, {
            'critic_loss': critic_loss,
            'q1': q1.mean(),
            'q2': q2.mean()
        }

    new_critic, info = critic.apply_gradient(critic_loss_fn)

    return new_critic, info
