from typing import Tuple

import jax
import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params, PRNGKey


def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, temperature: float, exp_a_clip: float) -> Tuple[Model, InfoDict]:
    """Advantage-weighted actor update with a mean-reduced loss.

    Each dataset action's log-likelihood under the actor is weighted by
    ``exp(temperature * (min(Q1, Q2) - V))``. The weight is capped from above
    at ``exp_a_clip``, and the negative batch mean of the weighted
    log-likelihoods is used as the loss.

    Args:
        key: RNG key handed to ``actor.apply`` as the ``'dropout'`` stream.
        actor: policy model; ``actor.apply`` must return an object exposing
            ``log_prob(actions)``.
        critic: called as ``critic(observations, actions)``; must return a
            pair of Q estimates.
        value: called as ``value(observations)``; state-value estimate.
        batch: provides ``observations`` and ``actions``.
        temperature: scale applied to the advantage before exponentiation.
        exp_a_clip: upper bound on the exponentiated advantage.

    Returns:
        ``(new_actor, info)`` as produced by ``actor.apply_gradient``, where
        ``info`` carries ``'actor_loss'`` and ``'adv'``.
    """
    obs = batch.observations
    acts = batch.actions

    # The weights are built from critic/value outputs outside the loss
    # closure, so they act as constants with respect to the actor parameters.
    state_value = value(obs)
    first_q, second_q = critic(obs, acts)
    advantage = jnp.minimum(first_q, second_q) - state_value
    awr_weight = jnp.minimum(jnp.exp(advantage * temperature), exp_a_clip)

    def weighted_nll(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        policy = actor.apply({'params': actor_params}, obs,
                             training=True, rngs={'dropout': key})
        weighted_ll = awr_weight * policy.log_prob(acts)
        loss = -weighted_ll.mean()
        return loss, {'actor_loss': loss, 'adv': advantage}

    new_actor, info = actor.apply_gradient(weighted_nll)
    return new_actor, info

def update_weighted(key: PRNGKey, actor: Model, critic: Model, value: Model,
                    batch: Batch, temperature: float, weights: jnp.ndarray,
                    exp_a_clip: float = 100.0) -> Tuple[Model, InfoDict]:
    """Advantage-weighted actor update with per-sample weights, mean-reduced.

    The exponentiated advantage ``exp(temperature * (min(Q1, Q2) - V))`` is
    multiplied by ``weights`` and then clipped to ``[0, exp_a_clip]``. That
    coefficient weights each action's log-likelihood, and the negative batch
    mean is used as the loss.

    Args:
        key: RNG key handed to ``actor.apply`` as the ``'dropout'`` stream.
        actor: policy model; ``actor.apply`` must return an object exposing
            ``log_prob(actions)``.
        critic: called as ``critic(observations, actions)``; must return a
            pair of Q estimates.
        value: called as ``value(observations)``; state-value estimate.
        batch: provides ``observations`` and ``actions``.
        temperature: scale applied to the advantage before exponentiation.
        weights: per-sample multipliers. Presumably shaped like the
            advantage (one per batch element); TODO confirm against callers.
        exp_a_clip: upper bound on the weighted exponentiated advantage.
            Previously this name was read without being defined, which raised
            ``NameError``. It is now a trailing keyword parameter, so existing
            positional callers keep working. The default of 100.0 matches the
            cap used by the earlier implementation.

    Returns:
        ``(new_actor, info)`` as produced by ``actor.apply_gradient``, where
        ``info`` carries ``'actor_loss'``, ``'adv'`` and ``'exp_a'``.
    """
    v = value(batch.observations)

    q1, q2 = critic(batch.observations, batch.actions)
    q = jnp.minimum(q1, q2)
    exp_a = jnp.exp((q - v) * temperature)
    # Apply the weights before clipping so the cap bounds the combined
    # coefficient. The lower bound of 0 zeroes any sample with a negative
    # weight.
    exp_a = jnp.clip(weights * exp_a, 0, exp_a_clip)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(batch.actions)
        actor_loss = -(exp_a * log_probs).mean()

        return actor_loss, {'actor_loss': actor_loss, 'adv': q - v, "exp_a": exp_a}

    new_actor, info = actor.apply_gradient(actor_loss_fn)

    return new_actor, info

def update_weighted_sum(key: PRNGKey, actor: Model, critic: Model, value: Model,
                        batch: Batch, temperature: float, exp_a_clip: float,
                        weights: jnp.ndarray) -> Tuple[Model, InfoDict]:
    """Advantage-weighted actor update; weights are applied after clipping and the loss is summed.

    The exponentiated advantage ``exp(temperature * (min(Q1, Q2) - V))`` is
    capped from above at ``exp_a_clip``. Each action's log-likelihood is
    scaled by that capped value and then by ``weights``. The loss is the
    negative *sum* (not mean) over the batch.

    Args:
        key: RNG key handed to ``actor.apply`` as the ``'dropout'`` stream.
        actor: policy model; ``actor.apply`` must return an object exposing
            ``log_prob(actions)``.
        critic: called as ``critic(observations, actions)``; must return a
            pair of Q estimates.
        value: called as ``value(observations)``; state-value estimate.
        batch: provides ``observations`` and ``actions``.
        temperature: scale applied to the advantage before exponentiation.
        exp_a_clip: upper bound on the exponentiated advantage; ``weights``
            is applied afterwards and is not subject to this cap.
        weights: per-sample multipliers applied inside the loss.

    Returns:
        ``(new_actor, info)`` as produced by ``actor.apply_gradient``, where
        ``info`` carries ``'actor_loss'``, ``'adv'`` and ``'exp_a'``.
    """
    observations = batch.observations
    actions = batch.actions

    baseline = value(observations)
    q_a, q_b = critic(observations, actions)
    advantage = jnp.minimum(q_a, q_b) - baseline
    clipped_exp_adv = jnp.minimum(jnp.exp(advantage * temperature), exp_a_clip)

    def weighted_nll(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        policy = actor.apply({'params': actor_params},
                             observations,
                             training=True,
                             rngs={'dropout': key})
        per_sample = clipped_exp_adv * policy.log_prob(actions)
        # Sum rather than mean. Presumably the caller normalizes `weights`
        # (e.g. to sum to 1); TODO confirm.
        loss = -(weights * per_sample).sum()
        diagnostics = {'actor_loss': loss, 'adv': advantage, "exp_a": clipped_exp_adv}
        return loss, diagnostics

    new_actor, info = actor.apply_gradient(weighted_nll)
    return new_actor, info


def update_weighted_sum_before_clip(key: PRNGKey, actor: Model, critic: Model, value: Model,
                                    batch: Batch, temperature: float, exp_a_clip: float,
                                    weights: jnp.ndarray) -> Tuple[Model, InfoDict]:
    """Advantage-weighted actor update; weights are applied before clipping and the loss is summed.

    The coefficient ``weights * exp(temperature * (min(Q1, Q2) - V))`` is
    capped from above at ``exp_a_clip``. Each action's log-likelihood is
    scaled by that coefficient, and the loss is the negative *sum* over the
    batch.

    Args:
        key: RNG key handed to ``actor.apply`` as the ``'dropout'`` stream.
        actor: policy model; ``actor.apply`` must return an object exposing
            ``log_prob(actions)``.
        critic: called as ``critic(observations, actions)``; must return a
            pair of Q estimates.
        value: called as ``value(observations)``; state-value estimate.
        batch: provides ``observations`` and ``actions``.
        temperature: scale applied to the advantage before exponentiation.
        exp_a_clip: upper bound on the combined weighted coefficient.
        weights: per-sample multipliers, applied before the cap.

    Returns:
        ``(new_actor, info)`` as produced by ``actor.apply_gradient``, where
        ``info`` carries ``'actor_loss'``, ``'adv'`` and ``'exp_a'``.
    """
    obs = batch.observations
    acts = batch.actions

    v_s = value(obs)
    q_one, q_two = critic(obs, acts)
    adv = jnp.minimum(q_one, q_two) - v_s
    # Weighting happens *before* the cap, so exp_a_clip bounds the product.
    coef = jnp.minimum(weights * jnp.exp(adv * temperature), exp_a_clip)

    def summed_loss(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        policy = actor.apply({'params': actor_params}, obs,
                             training=True, rngs={'dropout': key})
        log_lik = policy.log_prob(acts)
        loss = -(coef * log_lik).sum()
        return loss, {'actor_loss': loss, 'adv': adv, "exp_a": coef}

    new_actor, info = actor.apply_gradient(summed_loss)
    return new_actor, info

