from typing import Tuple
import jax
import jax.numpy as jnp
from sources.utils import Batch, InfoDict, Model, Params, PRNGKey


def update_actor(key: PRNGKey, actor: Model, critic: Model, v: jnp.ndarray,
                 batch: Batch, reward_weight: jnp.ndarray, is_bad: jnp.ndarray,
                 actor_temperature: float, double: bool,
                 cal_log: bool) -> Tuple[Model, InfoDict]:
    """Advantage-weighted log-likelihood update of the actor on dataset actions.

    The loss is ``-mean(w * log pi(a|s))`` over ``batch.actions`` where
    ``w = clip(exp(actor_temperature * min(Q(s, a) - v, 5)), 0, 100)``.
    The weights are computed from the critic outside the loss function, so
    no gradient flows into the critic.

    Args:
        key: PRNG key passed to the actor as its ``'dropout'`` RNG.
        actor: Actor model; ``actor.apply`` must return a distribution
            exposing ``log_prob``.
        critic: Critic model returning two Q estimates ``(q1, q2)``.
        v: Value estimates subtracted from Q to form the advantage
            (presumably one per batch element -- TODO confirm).
        batch: Batch providing ``observations`` and ``actions``.
        reward_weight: Used only for the logged ``disc_weight`` metric,
            ``clip(1 / -reward_weight, 0, 100)``.
        is_bad: Mask selecting "bad" samples for the logged metrics
            (assumed 0/1 -- TODO confirm); ``1 - is_bad`` selects "good" ones.
        actor_temperature: Multiplier applied to the advantage inside ``exp``.
        double: If True use ``min(q1, q2)``; otherwise use ``q1`` only.
        cal_log: If True, diagnostic metrics are added to the info dict.

    Returns:
        ``(new_actor, info)`` exactly as returned by ``actor.apply_gradient``.
    """
    q1, q2 = critic(batch.observations, batch.actions)
    # Clipped double-Q: the minimum of both heads counters overestimation.
    q = jnp.minimum(q1, q2) if double else q1

    # Cap the advantage before exponentiating so exp() cannot blow up, then
    # cap the resulting weight for numerical stability.
    adv = (q - v).clip(max=5.0)
    adv_weight = jnp.exp(adv * actor_temperature).clip(min=0, max=100)

    # Diagnostic only (not part of the loss). reward_weight == 0 divides by
    # zero before the clip; positive reward_weight is clipped to 0.
    disc_weight = (1 / (-reward_weight)).clip(min=0, max=100)

    # The loss currently uses only the advantage weights (already in [0, 100]).
    weight = adv_weight

    is_good = 1 - is_bad

    def _masked_mean(values, mask):
        """Mean of ``values`` over ``mask``; 0.0 when the mask is empty."""
        count = mask.sum()
        # Guard: a batch with no bad (or no good) samples would otherwise
        # divide by zero and log NaN.
        safe_count = jnp.where(count > 0, count, 1)
        return jnp.where(count > 0, (values * mask).sum() / safe_count, 0.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})

        # Sum log-probs over the last axis (presumably per-action-dimension
        # terms -- confirm the distribution is not already Independent).
        log_probs = dist.log_prob(batch.actions).sum(-1)

        # Minimizing this maximizes likelihood of high-advantage actions.
        actor_loss = -(weight * log_probs).mean()

        info = {}
        if cal_log:
            info.update({
                'actor_update/bad_adv': _masked_mean(adv, is_bad),
                'actor_update/good_adv': _masked_mean(adv, is_good),
                'actor_update/bad_weight_a': _masked_mean(weight, is_bad),
                'actor_update/good_weight_a': _masked_mean(weight, is_good),
                'actor_update/bad_adv_weight': _masked_mean(adv_weight, is_bad),
                'actor_update/good_adv_weight': _masked_mean(adv_weight, is_good),
                'mixed_data/bad_disc_weight': _masked_mean(disc_weight, is_bad),
                'mixed_data/good_disc_weight': _masked_mean(disc_weight, is_good),
                'mixed_data/bad_logp': _masked_mean(log_probs, is_bad),
                'mixed_data/good_logp': _masked_mean(log_probs, is_good),
                'info/actor_temperature': actor_temperature,
            })

        return actor_loss, info

    new_actor, grad_info = actor.apply_gradient(actor_loss_fn)
    return new_actor, grad_info

def update_actor_Q(key: PRNGKey, actor: Model, critic: Model, v: jnp.ndarray,
                   batch: Batch, reward_weight: jnp.ndarray, is_bad: jnp.ndarray,
                   actor_temperature: float, double: bool, cal_log: bool,
                   entropy_coef: float = 0.01) -> Tuple[Model, InfoDict]:
    """Q-maximizing actor update on actions sampled from the current policy.

    The loss is ``mean(-Q(s, a~) + entropy_coef * log pi(a~|s))`` where
    ``a~`` is sampled from the actor's distribution. The critic is evaluated
    at the sampled actions inside the loss function, so gradients reach the
    actor through Q as well (this presumes ``dist.sample`` is
    reparameterized -- confirm for the distribution in use).

    ``v``, ``reward_weight`` and ``actor_temperature`` are unused here; the
    signature mirrors ``update_actor``.

    Args:
        key: PRNG key; split into independent dropout and sampling keys.
        actor: Actor model; ``actor.apply`` must return a distribution
            exposing ``sample(seed=...)`` and ``log_prob``.
        critic: Critic model returning two Q estimates ``(q1, q2)``.
        batch: Batch providing ``observations``.
        is_bad: Mask selecting "bad" samples for the logged metrics
            (assumed 0/1 -- TODO confirm); ``1 - is_bad`` selects "good" ones.
        double: If True use ``min(q1, q2)``; otherwise use ``q1`` only.
        cal_log: If True, diagnostic metrics are added to the info dict.
        entropy_coef: Weight of the ``log_prob`` term; penalizing
            log-probability encourages a higher-entropy policy.

    Returns:
        ``(new_actor, info)`` exactly as returned by ``actor.apply_gradient``.
    """
    # Independent keys for dropout and action sampling: reusing one key for
    # both would correlate the two random streams.
    dropout_key, sample_key = jax.random.split(key)

    is_good = 1 - is_bad

    def _masked_mean(values, mask):
        """Mean of ``values`` over ``mask``; 0.0 when the mask is empty."""
        count = mask.sum()
        # Guard: a batch with no bad (or no good) samples would otherwise
        # divide by zero and log NaN.
        safe_count = jnp.where(count > 0, count, 1)
        return jnp.where(count > 0, (values * mask).sum() / safe_count, 0.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': dropout_key})

        # Sample fresh actions from the current policy and score them.
        actions = dist.sample(seed=sample_key)
        log_probs = dist.log_prob(actions).sum(-1)

        q1, q2 = critic(batch.observations, actions)
        # Clipped double-Q when requested, consistent with update_actor.
        q = jnp.minimum(q1, q2) if double else q1

        # -Q pushes toward high-value actions; +entropy_coef * log_prob
        # penalizes overly confident (low-entropy) policies.
        actor_loss = (-q + entropy_coef * log_probs).mean()

        info = {}
        if cal_log:
            info.update({
                'mixed_data/bad_q': _masked_mean(q, is_bad),
                'mixed_data/good_q': _masked_mean(q, is_good),
                'mixed_data/bad_logp': _masked_mean(log_probs, is_bad),
                'mixed_data/good_logp': _masked_mean(log_probs, is_good),
            })

        return actor_loss, info

    new_actor, grad_info = actor.apply_gradient(actor_loss_fn)
    return new_actor, grad_info