from typing import Tuple

import jax
import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params, PRNGKey


def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, temperature: float, double: bool) -> Tuple[Model, InfoDict]:
    """Take one weighted log-likelihood gradient step on the actor.

    Every sample is weighted by ``min(exp((q - v) * temperature), 100)``,
    where ``q`` and ``v`` come from ``critic`` and ``value``. The weights are
    computed outside the loss function, so they do not depend on
    ``actor_params``.

    Args:
        key: PRNG key handed to the actor as its ``'dropout'`` rng.
        actor: Policy model; must provide ``apply`` and ``apply_gradient``.
        critic: Callable returning a pair ``(q1, q2)`` for
            ``(batch.observations, batch.actions)``.
        value: Callable returning ``v`` for ``batch.observations``.
        batch: Object exposing ``observations`` and ``actions``.
        temperature: Multiplier applied to ``q - v`` before exponentiation.
        double: If true, use the element-wise minimum of ``q1`` and ``q2``;
            otherwise use ``q1`` alone.

    Returns:
        The ``(new_actor, info)`` pair produced by ``actor.apply_gradient``.
        The loss function reports ``'actor_loss'`` and ``'adv'`` (``q - v``).
    """
    v = value(batch.observations)

    q1, q2 = critic(batch.observations, batch.actions)
    q = jnp.minimum(q1, q2) if double else q1

    adv = q - v
    # Cap the exponentiated advantage so a few large values cannot dominate
    # the batch loss.
    exp_a = jnp.minimum(jnp.exp(adv * temperature), 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(batch.actions)
        # Negative weighted log-likelihood of the batch actions.
        actor_loss = -(exp_a * log_probs).mean()

        return actor_loss, {'actor_loss': actor_loss, 'adv': adv}

    new_actor, info = actor.apply_gradient(actor_loss_fn)

    return new_actor, info



def update_imitate(
    key: PRNGKey,
    actor: Model,
    critic: Model,
    value: Model,
    batch: Batch,
    is_expert_mask,
    temperature: float,
    double: bool,
) -> Tuple[Model, InfoDict]:
    """Take one weighted log-likelihood step on the actor and log split metrics.

    Per-sample weights are ``min(exp(q * temperature), 100)``. Unlike
    ``update``, the value estimate is not subtracted before exponentiation;
    ``q - v`` is only reported as a diagnostic. Only the first half of the
    batch (along axis 0) contributes to the loss.

    Args:
        key: PRNG key handed to the actor as its ``'dropout'`` rng.
        actor: Policy model; must provide ``apply`` and ``apply_gradient``.
        critic: Callable returning a pair ``(q1, q2)`` for
            ``(batch.observations, batch.next_observations)``.
        value: Callable returning ``v`` for ``batch.observations``.
        batch: Object exposing ``observations``, ``next_observations`` and
            ``actions``.
        is_expert_mask: Array multiplied element-wise with first-half
            quantities to separate "expert" from "suboptimal" samples.
            Presumably 0/1 valued with the first half's length; confirm
            against the caller.
        temperature: Multiplier applied to ``q`` before exponentiation.
        double: If true, use the element-wise minimum of ``q1`` and ``q2``;
            otherwise use ``q1`` alone.

    Returns:
        The ``(new_actor, info)`` pair produced by ``actor.apply_gradient``,
        with ``info['clipped_adv']`` set to the mean clipped weight.
    """
    state_values = value(batch.observations)

    # NOTE(review): the critic receives next_observations here, whereas
    # `update` passes batch.actions. This looks like a state-state critic;
    # confirm it is intentional.
    q_a, q_b = critic(batch.observations, batch.next_observations)
    q_values = jnp.minimum(q_a, q_b) if double else q_a

    # Weights use q directly, not the advantage, and are capped at 100.
    weights = jnp.minimum(jnp.exp(q_values * temperature), 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        policy = actor.apply(
            {'params': actor_params},
            batch.observations,
            training=True,
            rngs={'dropout': key},
        )
        logp = policy.log_prob(batch.actions)

        # The loss is taken over the first half of the batch only.
        weighted_logp = weights * logp
        loss = -weighted_logp[:weights.shape[0] // 2].mean()

        advantage = q_values - state_values

        # Mean log-probability over the second half of the batch
        # (presumably the expert samples, going by the 'logp_exp' key).
        second_half_logp = logp[logp.shape[0] // 2:].mean()

        # First-half slices; the mask-split metrics below use only these.
        head_adv = advantage[:advantage.shape[0] // 2]
        head_weights = weights[:weights.shape[0] // 2]
        head_logp = logp[:logp.shape[0] // 2]

        # The 1e-8 keeps the averages finite when a group is empty.
        suboptimal_mask = 1 - is_expert_mask
        expert_norm = is_expert_mask.sum() + 1e-8
        suboptimal_norm = suboptimal_mask.sum() + 1e-8

        def expert_mean(values):
            return (values * is_expert_mask).sum() / expert_norm

        def suboptimal_mean(values):
            return (values * suboptimal_mask).sum() / suboptimal_norm

        metrics = {
            'actor_loss': loss,
            'adv': advantage,
            'unseen_adv_expert': expert_mean(head_adv),
            'unseen_adv_suboptimal': suboptimal_mean(head_adv),
            'unseen_weighted_adv_expert': expert_mean(head_weights),
            'unseen_weighted_adv_suboptimal': suboptimal_mean(head_weights),
            'logp_exp': second_half_logp,
            'unseen_logp_expert': expert_mean(head_logp),
            'unseen_logp_suboptimal': suboptimal_mean(head_logp),
        }
        return loss, metrics

    updated_actor, metrics = actor.apply_gradient(actor_loss_fn)
    metrics['clipped_adv'] = weights.mean()
    return updated_actor, metrics

