from typing import Tuple
import jax
import jax.numpy as jnp
from sources.utils import Batch, InfoDict, Model, Params, PRNGKey
from jax.debug import print

def update_discriminator(key: PRNGKey, discriminator: Model,
                         high_batch: Batch, low_batch: Batch, 
                         cal_log: bool, is_bad: jnp.ndarray,
                         noise_scale: float, prefix: str) -> Tuple[Model, InfoDict, jnp.ndarray]:
    """
    Runs one gradient step on the discriminator so that it scores the
    high-quality batch towards 1 and the low-quality batch towards 0.

    Args:
        key: Random key for JAX operations. It is split internally so the
            action noise and the two dropout streams use independent subkeys.
        discriminator: The discriminator model to be updated. Must provide
            ``apply`` and ``apply_gradient``.
        high_batch: Batch of high-quality demonstrations
            (``observations`` and ``actions`` are read).
        low_batch: Batch of low-quality demonstrations
            (``observations`` and ``actions`` are read).
        cal_log: Flag to calculate additional logging information. It is used
            in a Python ``if``, so it has to be a concrete Python bool.
        is_bad: Array multiplied element-wise with the concatenated
            (high, low) discriminator outputs to select "bad" samples.
            Presumably a 0/1 mask that broadcasts against those outputs;
            verify against the caller.
        noise_scale: Scale of noise to add to actions
        prefix: Prefix for logging keys

    Returns:
        A tuple ``(new_discriminator, info, diff)`` where ``new_discriminator``
        and ``info`` are what ``discriminator.apply_gradient`` returns, and
        ``diff`` is ``info[f'{prefix}_disc/diff']``, i.e. mean output on the
        high batch minus mean output on the low batch.
    """

    # Never reuse a PRNG key: derive one subkey per random consumer.
    noise_high_key, noise_low_key, dropout_high_key, dropout_low_key = \
        jax.random.split(key, 4)

    # Generate random noise to add to actions for regularization.
    # Each batch gets its own draw, shaped like its own actions, so the two
    # batches do not need matching shapes and are not perturbed identically.
    # NOTE(review): jax.random.uniform samples from [0, 1) by default, so the
    # perturbation is one-sided (never negative) -- confirm this is intended.
    high_noise = jax.random.uniform(noise_high_key, shape=high_batch.actions.shape)
    low_noise = jax.random.uniform(noise_low_key, shape=low_batch.actions.shape)

    def safe_log(x: jnp.ndarray) -> jnp.ndarray:
        """log(x) with x floored at the dtype's smallest normal value, so a
        saturated discriminator yields a large finite loss instead of -inf."""
        return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))

    def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
        """sum(values * mask) / sum(mask); reports 0 instead of NaN when the
        mask selects nothing."""
        count = mask.sum()
        return (values * mask).sum() / jnp.where(count > 0, count, 1)
    
    def disc_loss_fn(discriminator_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        """
        Defines the loss function for the discriminator
        - Aims to maximize log probability of high-quality samples being real
        - And maximize log probability of low-quality samples being fake
        """
        
        # Apply discriminator to high-quality samples with added noise
        # (noisy actions are clipped back into [-1, 1]).
        d_high = discriminator.apply({'params': discriminator_params},
                                high_batch.observations,
                                jnp.clip(high_batch.actions + noise_scale*high_noise, -1, 1),
                                rngs={'dropout': dropout_high_key})
        
        # Apply discriminator to low-quality samples with added noise
        d_low = discriminator.apply({'params': discriminator_params},
                                low_batch.observations,
                                jnp.clip(low_batch.actions + noise_scale*low_noise, -1, 1),
                                rngs={'dropout': dropout_low_key})
        
        # Calculate binary cross-entropy loss
        # Maximize log(d_high) and log(1-d_low)
        loss = -safe_log(d_high).mean() - safe_log(1-d_low).mean()
        
        # Store basic metrics
        info = {f'{prefix}_disc/diff': (d_high.mean()-d_low.mean())}
        
        # If logging is enabled, calculate additional metrics
        if cal_log:
            # Combine predictions for both high and low quality samples
            # (high-batch outputs first, then low-batch outputs).
            d_merge = jnp.concatenate((d_high, d_low), axis=0)

            # Calculate and store detailed metrics:
            # - Average discriminator output for high-quality samples
            # - Average discriminator output for low-quality samples
            # - Average output for samples marked as bad
            # - Average output for samples marked as good
            # NOTE(review): d_merge * is_bad relies on broadcasting; if the
            # discriminator returns shape (N, 1) while is_bad is (N,), this
            # would broadcast to (N, N) -- confirm the shapes line up.
            info.update({f'{prefix}_disc/d_high': d_high.mean(),
                         f'{prefix}_disc/d_low': d_low.mean(),
                         f'{prefix}_disc/d_bad': masked_mean(d_merge, is_bad),
                         f'{prefix}_disc/d_good': masked_mean(d_merge, 1-is_bad),
                         })
        return loss, info
    
    # Apply gradient update to discriminator using the loss function
    new_discriminator, info = discriminator.apply_gradient(disc_loss_fn)
    
    # Return updated discriminator, info dict, and the difference between high and low predictions
    return new_discriminator, info, info[f'{prefix}_disc/diff']
