from typing import Tuple

import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params


def loss(diff, expectile=0.8):
    """Return the element-wise asymmetric (expectile-weighted) squared error.

    Args:
        diff: Residual(s). Entries strictly greater than zero are weighted
            by ``expectile``; all other entries (including exact zeros) are
            weighted by ``1 - expectile``.
        expectile: Weight applied to positive residuals.

    Returns:
        The weighted squared residuals, not reduced.
    """
    squared_error = diff ** 2
    is_positive = diff > 0
    return jnp.where(is_positive, expectile, 1 - expectile) * squared_error


def update_v(critic: Model, value: Model, batch: Batch,
             expectile: float) -> Tuple[Model, InfoDict]:
    """Take one gradient step on the value model using the expectile loss.

    The regression target is the element-wise minimum of the two critic
    outputs evaluated on ``batch.observations`` and ``batch.actions``.

    Args:
        critic: Callable model returning two Q outputs.
        value: Model whose parameters are updated via ``apply_gradient``.
        batch: Provides ``observations`` and ``actions``.
        expectile: Forwarded to ``loss`` as the positive-residual weight.

    Returns:
        The updated value model and the info dict produced by the loss
        function (keys ``'value_loss'`` and ``'v'``).
    """
    # Targets are computed outside the loss closure, so they are constants
    # with respect to ``value_params``.
    q_first, q_second = critic(batch.observations, batch.actions)
    q_target = jnp.minimum(q_first, q_second)

    def _expectile_objective(
            value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v_pred = value.apply({'params': value_params}, batch.observations)
        objective = loss(q_target - v_pred, expectile).mean()
        metrics = {
            'value_loss': objective,
            'v': v_pred.mean(),
        }
        return objective, metrics

    updated_value, metrics = value.apply_gradient(_expectile_objective)
    return updated_value, metrics


def update_q(critic: Model, target_value: Model, batch: Batch, reward_dim: int,
             discount: float, ensemble_number, delta) -> Tuple[Model, InfoDict]:
    """Take one gradient step on the critic towards a bootstrapped target.

    Two groups of transitions are read from ``batch`` (the ``*_one`` and
    ``*_zero`` fields). For each group the target is
    ``rew + discount * mask * mean(target_value(next_obs), axis=-1)``.
    The two target vectors are concatenated ("one" group first) and every
    column of both critic outputs is regressed onto the same target.

    Args:
        critic: Model evaluated through ``apply`` on ``batch.observations``
            and ``batch.actions``; returns two Q outputs and is updated via
            ``apply_gradient``.
        target_value: Callable model giving next-state values; its last
            axis is averaged.
        batch: Provides ``observations``, ``actions``, ``next_obs_one``,
            ``rew_one``, ``mask_one``, ``next_obs_zero``, ``rew_zero`` and
            ``mask_zero``.
        reward_dim: Unused; kept for interface compatibility.
        discount: Multiplier on the masked next-state value.
        ensemble_number: Unused; the ensemble width is taken from the
            critic output shape instead.
        delta: Unused; kept for interface compatibility.

    Returns:
        The updated critic and the info dict produced by the loss function.
    """
    next_v_one = jnp.mean(target_value(batch.next_obs_one), axis=-1)
    target_q_one = batch.rew_one + discount * batch.mask_one * next_v_one

    next_v_zero = jnp.mean(target_value(batch.next_obs_zero), axis=-1)
    target_q_zero = batch.rew_zero + discount * batch.mask_zero * next_v_zero

    # Column vector, one row per transition; expanded to the critic's output
    # shape inside the loss.
    target_q = jnp.concatenate((target_q_one, target_q_zero)).reshape(-1, 1)

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        q1, q2 = critic.apply({'params': critic_params}, batch.observations,
                              batch.actions)
        # Match the target to whatever ensemble width the critic produces
        # rather than a fixed column count. broadcast_to raises if q1 is not
        # compatible with a (rows, 1) target.
        # NOTE(review): assumes q1 and q2 share a (batch, ensemble) shape --
        # confirm against the critic definition.
        tiled_target = jnp.broadcast_to(target_q, q1.shape)
        critic_loss = ((q1 - tiled_target)**2 +
                       (q2 - tiled_target)**2).mean()
        return critic_loss, {
            'critic_loss': critic_loss,
            'q1': q1.mean(),
            'q2': q2.mean(),
            # Mean of the regression target itself.
            'target_mean': target_q.mean(),
            'qone': next_v_one.mean(),
            'qzero': next_v_zero.mean()
        }

    new_critic, info = critic.apply_gradient(critic_loss_fn)

    return new_critic, info

