from typing import Tuple
import jax.numpy as jnp
import jax
from functools import partial

from sources.utils import Batch, InfoDict, Model, Params, PRNGKey


# Chi-square style value loss: rectified quadratic term blended with a linear term.
def chi_square_loss(diff, alpha, args=None):
    """Element-wise chi-square style loss.

    Evaluates ``alpha * max(diff + diff**2 / 4, 0) - (1 - alpha) * diff``.

    Args:
        diff: residual array of any shape.
        alpha: weight on the rectified quadratic term; ``1 - alpha`` weights
            the (negated) linear term.
        args: unused; accepted so the signature matches ``reverse_kl_loss``.

    Returns:
        Array with the same shape as ``diff``.
    """
    rectified = jnp.maximum(diff + diff ** 2 / 4, 0)
    linear_part = (1 - alpha) * diff
    return alpha * rectified - linear_part

# Total-variation style value loss: purely piecewise-linear in the residual.
def total_variation_loss(diff, alpha, args=None):
    """Element-wise total-variation style loss.

    Evaluates ``alpha * max(diff, 0) - (1 - alpha) * diff``.

    Args:
        diff: residual array of any shape.
        alpha: weight on the positive part of ``diff``; ``1 - alpha`` weights
            the (negated) linear term.
        args: unused; accepted so the signature matches ``reverse_kl_loss``.

    Returns:
        Array with the same shape as ``diff``.
    """
    positive_part = jnp.maximum(diff, 0)
    linear_part = (1 - alpha) * diff
    return alpha * positive_part - linear_part

# Reverse-KL value loss: exp(z) - z - 1 on z = diff / alpha, rescaled for stability.
def reverse_kl_loss(diff, alpha, args=None):
    """Element-wise rescaled ``exp(z) - z - 1`` loss with ``z = diff / alpha``.

    The returned value equals ``exp(-max_z) * (exp(z) - z - 1)``, where
    ``max_z`` is the max of ``z`` over axis 0, floored at ``-1`` and wrapped
    in ``stop_gradient``. Multiplying by ``exp(-max_z)`` keeps ``exp(z - max_z)``
    bounded by 1. Because ``max_z`` carries no gradient, the gradient
    direction is unchanged, only its scale.

    Args:
        diff: residual array; the rescaling max is taken over axis 0.
        alpha: temperature dividing ``diff``.
        args: optional config object. If it has a non-None ``max_clip``
            attribute, ``z`` is clipped from above at that value. ``None``
            (the default) or an object without ``max_clip`` means no clipping.

    Returns:
        Array with the same shape as ``diff``.
    """
    z = diff / alpha
    # getattr tolerates the documented default args=None (previously an
    # AttributeError) as well as configs that omit max_clip.
    max_clip = getattr(args, 'max_clip', None)
    if max_clip is not None:
        z = jnp.minimum(z, max_clip)  # cap z so exp(z) cannot overflow
    max_z = jnp.max(z, axis=0)
    # Floor at -1 so the rescaling factor exp(-max_z) never exceeds e.
    max_z = jnp.where(max_z < -1.0, -1.0, max_z)
    # The scale factor must not contribute gradients of its own.
    max_z = jax.lax.stop_gradient(max_z)
    # exp(-max_z) * (exp(z) - z - 1), expanded term by term.
    loss = jnp.exp(z - max_z) - z * jnp.exp(-max_z) - jnp.exp(-max_z)
    return loss

# Asymmetric squared loss used for expectile regression.
def expectile_loss(diff, expectile=0.8):
    """Element-wise expectile (asymmetrically weighted squared) loss.

    Positive residuals are weighted by ``expectile``; zero and negative
    residuals by ``1 - expectile``.

    Args:
        diff: residual array of any shape.
        expectile: weight for strictly positive residuals.

    Returns:
        ``weight * diff**2`` with the same shape as ``diff``.
    """
    is_positive = diff > 0
    asym_weight = jnp.where(is_positive, expectile, (1 - expectile))
    squared = diff ** 2
    return asym_weight * squared

# Updates the value network by regressing V(s) onto the critic's Q(s, a) estimates
def update_value(critic: Model, value: Model, batch: Batch, is_bad: jnp.ndarray,
             expectile: float, double: bool, key: PRNGKey, 
             args, cal_log: bool) -> Tuple[Model, InfoDict]:
    """Take one gradient step on ``value`` towards the critic's Q estimates.

    Args:
        critic: critic model; ``critic(obs, acts)`` returns ``(q1, q2)``.
        value: value model, updated through ``value.apply_gradient``.
        batch: transition batch; ``observations`` and ``actions`` are read.
        is_bad: per-row indicator aligned with the rows of ``batch``; used
            only for the good/bad logging metrics.
        expectile: expectile used when ``args.v_update == 'expectile_loss'``.
        double: if True, regress onto ``min(q1, q2)``; otherwise onto ``q1``.
        key: PRNG key for random-action sampling and action noise.
        args: config object; reads ``v_update``, ``sample_random_times``,
            ``noise``, ``noise_std`` and, for ``'rkl_loss'``, ``alpha`` (plus
            whatever ``reverse_kl_loss`` reads).
        cal_log: if True, populate the returned info dict with metrics.

    Returns:
        ``(new_value, info)`` as returned by ``value.apply_gradient``.

    Raises:
        ValueError: if ``args.v_update`` is not ``'expectile_loss'`` or
            ``'rkl_loss'``.
    """
    # Fail fast instead of hitting an unbound `value_loss` inside the loss fn.
    if args.v_update not in ('expectile_loss', 'rkl_loss'):
        raise ValueError(f"Unsupported v_update: {args.v_update!r}")

    actions = batch.actions
    batch_size = batch.observations.shape[0]
    rng1, rng2 = jax.random.split(key)
    if args.sample_random_times > 0:
        # Append `times` uniform random actions in [-1, 1] per observation;
        # observations are repeated so each random action keeps its state.
        times = args.sample_random_times
        random_action = jax.random.uniform(
            rng1, shape=(times * actions.shape[0],
                         actions.shape[1]),
            minval=-1.0, maxval=1.0)
        obs = jnp.concatenate([batch.observations, jnp.repeat(
            batch.observations, times, axis=0)], axis=0)
        acts = jnp.concatenate([actions, random_action], axis=0)
    else:
        obs = batch.observations
        acts = actions

    if args.noise:
        # Gaussian noise, clipped to [-0.5, 0.5], then actions clipped to
        # [-1, 1]. The noise is drawn with `acts`' shape, so it must be added
        # to `acts` (which may include the random-action rows), not to
        # `batch.actions`.
        std = args.noise_std
        noise = jax.random.normal(rng2, shape=(acts.shape[0], acts.shape[1]))
        noise = jnp.clip(noise * std, -0.5, 0.5)
        acts = jnp.clip(acts + noise, -1, 1)

    q1, q2 = critic(obs, acts)
    q = jnp.minimum(q1, q2) if double else q1

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        """Mean value loss for `value_params` plus optional logging metrics."""
        v = value.apply({'params': value_params}, obs)

        if args.v_update == 'expectile_loss':
            value_loss = expectile_loss(q - v, expectile=expectile).mean()
        else:  # 'rkl_loss' (validated at function entry)
            value_loss = reverse_kl_loss(q - v, alpha=args.alpha, args=args).mean()

        info = {}
        if cal_log:
            # `is_bad` only covers the original batch rows, so restrict V to
            # those rows. This is identical when no random-action rows exist.
            # Ratios are NaN if a good/bad split is empty.
            v_data = v[:batch_size]
            half = batch_size // 2
            info.update({
                'info/expectile': expectile,
                'value_update/loss': value_loss,
                'value_update/value': v.mean(),
                'value_update/q': q.mean(),
                'value_update/bad_value': (v_data*is_bad).sum()/is_bad.sum(),
                'value_update/good_value': (v_data*(1-is_bad)).sum()/(1-is_bad).sum(),
                'hidden/bad_V': (v_data*is_bad)[half:].sum()/(is_bad[half:].sum()),
                'hidden/good_V': (v_data*(1-is_bad))[half:].sum()/((1-is_bad)[half:].sum()),
            })

        return value_loss, info

    new_value, info = value.apply_gradient(value_loss_fn)
    return new_value, info



# Updates the critic network with a reward-regularized TD objective
def update_critic(critic: Model, target_value: Model, batch: Batch, is_bad: jnp.ndarray,
                reward_weight: jnp.ndarray, discount: float, double: bool, key: PRNGKey, 
                reward_gap: float, args, cal_log: bool) -> Tuple[Model, InfoDict]:
    """Take one gradient step on ``critic``.

    With implied reward ``r = Q(s, a) - discount * clip(V_target(s'))``, each
    Q head minimises ``-mean(reward_weight * r) + mean(r**2) / (2 * reward_gap)``.
    ``V_target(s')`` is clipped to ``[-reward_gap / (1 - discount),
    reward_gap / (1 - discount)]`` before discounting.

    Args:
        critic: critic model; ``critic.apply`` returns ``(q1, q2)``.
        target_value: value model evaluated on ``batch.next_observations``.
        batch: transition batch; reads ``observations``, ``actions`` and
            ``next_observations``.
        is_bad: per-row indicator used only for the good/bad logging metrics.
        reward_weight: weight multiplying the implied reward in the loss.
        discount: discount factor; must differ from 1 (division by zero).
        double: if True, train both Q heads (losses summed); else only ``q1``.
        key: unused by this function.
        reward_gap: scale for the Q bounds and the regularizer; must be
            non-zero (division by zero).
        args: config object; only ``grad_pen`` is read.
        cal_log: if True, populate the returned info dict with metrics.

    Returns:
        ``(new_critic, info)`` as returned by ``critic.apply_gradient``.

    Raises:
        NotImplementedError: if ``args.grad_pen`` is truthy.
    """
    # Checked up front rather than inside the traced loss function.
    if args.grad_pen:
        raise NotImplementedError("Gradient penalty not implemented")

    batch_size = batch.observations.shape[0]
    regularizer_weight = 1/(2*reward_gap)
    max_q = reward_gap/(1-discount)
    min_q = -max_q

    # Discounted target values, clipped into the admissible Q range first.
    next_v = target_value(batch.next_observations)
    next_v = discount * jnp.maximum(jnp.minimum(next_v, max_q), min_q)

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        """Critic loss for `critic_params` plus optional logging metrics."""
        q1, q2 = critic.apply({'params': critic_params}, batch.observations, batch.actions)

        def iq_loss(q, next_v):
            """Scalar loss and metrics dict for a single Q head."""
            reward = q - next_v  # reward implied by Q and the discounted V(s')
            reward_loss = -(reward_weight * reward).mean()
            regularizer_loss = regularizer_weight * (reward**2).mean()
            loss = reward_loss + regularizer_loss

            loss_dict = {}
            if cal_log:
                # Ratios are NaN if a good/bad split is empty.
                loss_dict.update({
                    'mixed_data/bad_reward': (reward*is_bad).sum()/is_bad.sum(),
                    'mixed_data/good_reward': (reward*(1-is_bad)).sum()/(1-is_bad).sum(),
                    'mixed_data/bad_reward_weight': (reward_weight*is_bad).sum()/is_bad.sum(),
                    'mixed_data/good_reward_weight': (reward_weight*(1-is_bad)).sum()/(1-is_bad).sum(),
                    'mixed_data/bad_q': (q*is_bad).sum()/is_bad.sum(),
                    'mixed_data/good_q': (q*(1-is_bad)).sum()/(1-is_bad).sum(),
                    'info/reward_gap': reward_gap,
                    'info/regularizer_weight': regularizer_weight,
                    'info/discount': discount,
                    'info/max_q': max_q,
                    'info/min_q': min_q,
                    'hidden/bad_Q': (q*is_bad)[batch_size//2:].sum()/(is_bad[batch_size//2:].sum()),
                    'hidden/good_Q': (q*(1-is_bad))[batch_size//2:].sum()/((1-is_bad)[batch_size//2:].sum()),
                })
            return loss, loss_dict

        if double:
            # Each head is trained on its own loss; the objective is their SUM
            # (both are scalars). Logged metrics are averaged across heads.
            loss1, loss_dict1 = iq_loss(q1, next_v)
            loss2, loss_dict2 = iq_loss(q2, next_v)
            critic_loss = loss1 + loss2
            info = {k: (loss_dict1[k] + loss_dict2[k]) / 2 for k in loss_dict1}
        else:
            critic_loss, loss_dict = iq_loss(q1, next_v)
            info = dict(loss_dict)

        return critic_loss, info

    new_critic, info = critic.apply_gradient(critic_loss_fn)

    return new_critic, info



# Huber loss: quadratic near zero, linear in the tails.
def huber_loss(x, delta: float = 1.):
    """Element-wise Huber loss.

    Returns ``0.5 * x**2`` where ``|x| <= delta`` and
    ``0.5 * delta**2 + delta * (|x| - delta)`` elsewhere. See Huber,
    "Robust Estimation of a Location Parameter"
    (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).
    The gradient matches that of ``0.5 * clip_gradient(x)**2``.

    Args:
        x: array of arbitrary shape.
        delta: threshold between the quadratic and linear regimes.

    Returns:
        Array with the same shape as ``x``.
    """
    magnitude = jnp.abs(x)
    clipped = jnp.minimum(magnitude, delta)
    # Equivalent to max(magnitude - delta, 0), but written so the gradient
    # is not counted twice at the kink.
    excess = magnitude - clipped
    quadratic_part = 0.5 * clipped**2
    linear_part = delta * excess
    return quadratic_part + linear_part
