from typing import Tuple

import jax
import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params, PRNGKey
import policy


def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, bernoulli_p: float, temperature: float,
           sample_temperature: float = 0.01) -> Tuple[Model, InfoDict]:
    """Take one advantage-weighted actor step on a partially policy-resampled batch.

    For each transition, a Bernoulli(``bernoulli_p``) draw decides the target:

    * False: the dataset action, weighted using ``min(Q1, Q2)`` at that action.
    * True: an action freshly sampled from the current (pre-update) policy,
      weighted using ``min(Q1, Q2)`` at the sampled action.

    In both cases the weight is ``exp((Q - V) * temperature)``, capped at 100.
    The actor loss is the negative mean of weight * log-prob of the target.

    Args:
        key: PRNG key. It is split into a policy-sampling key and a Bernoulli
            key, and the remaining key is used as the actor's dropout rng.
        actor: policy model being updated.
        critic: model returning two Q estimates; their elementwise minimum
            is used.
        value: model evaluated on observations to give the baseline ``V``.
        batch: transitions; only ``observations`` and ``actions`` are read.
        bernoulli_p: probability of replacing a dataset action with a
            policy-sampled action.
        temperature: multiplier applied to the advantage before ``exp``.
        sample_temperature: ``temperature`` forwarded to
            ``policy.sample_actions`` when drawing replacement actions.
            The default of 0.01 matches the previously hard-coded value.

    Returns:
        The updated actor and an info dict containing ``actor_loss`` and
        ``adv`` (the per-sample advantage ``q_final - v``).
    """
    v = value(batch.observations)

    q1, q2 = critic(batch.observations, batch.actions)
    q = jnp.minimum(q1, q2)

    key, actor_rng, bernoulli_rng = jax.random.split(key, 3)

    # Candidate actions from the current actor parameters. They are computed
    # outside the loss, so no gradient flows through them. The rng returned by
    # sample_actions is not needed here.
    _, pred_actions = policy.sample_actions(actor_rng, actor.apply_fn,
                                            actor.params, batch.observations,
                                            temperature=sample_temperature)
    qb1, qb2 = critic(batch.observations, pred_actions)
    q_bellman = jnp.minimum(qb1, qb2)

    # One draw per transition: True -> use the policy-sampled action and its Q.
    use_sampled = jax.random.bernoulli(bernoulli_rng, p=bernoulli_p,
                                       shape=q_bellman.shape)

    q_final = jnp.where(use_sampled, q_bellman, q)
    # Broadcast the per-sample mask across the action dimension instead of
    # tiling it. This assumes q is (batch,) and actions are
    # (batch, action_dim), as the original tile did. TODO confirm.
    actions = jnp.where(use_sampled[:, None], pred_actions, batch.actions)

    adv = q_final - v
    # The weights do not depend on actor_params, so they act as constants in
    # the loss. Capping at 100 presumably stops a few large advantages from
    # dominating the update.
    exp_a = jnp.minimum(jnp.exp(adv * temperature), 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        """Weighted negative log-likelihood of the chosen target actions."""
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(actions)
        actor_loss = -(exp_a * log_probs).mean()

        return actor_loss, {'actor_loss': actor_loss, 'adv': adv}

    new_actor, info = actor.apply_gradient(actor_loss_fn)

    return new_actor, info
