"""
Effective Sample Size (ESS) computation.

This module provides a function to compute the effective sample size
of importance-weighted samples, which is a key metric for assessing
the quality of normalizing flow training.
"""

import torch


def ess(log_probs: torch.Tensor, action_samples: torch.Tensor) -> float:
    """
    Compute the effective sample size (ESS) using importance reweighting.

    Each sample x drawn from the flow q is given the unnormalized importance
    weight w(x) = exp(-S(x)) / q(x). The normalized ESS is

        ESS / N = (sum_i w_i)^2 / (N * sum_i w_i^2),

    which measures how many independent samples the weighted samples are
    equivalent to, as a fraction of N. A value close to 1.0 indicates the
    flow closely matches the target distribution exp(-S).

    Parameters
    ----------
    log_probs : torch.Tensor
        Log probabilities of samples under the flow. Shape: (N,).
    action_samples : torch.Tensor
        Action values S(x) of the samples. Shape: (N,).

    Returns
    -------
    float
        Effective sample size normalized by N, in the range [1/N, 1].

    Raises
    ------
    ValueError
        If action_samples and log_probs have different shapes, or if no
        samples are given.
    """
    if action_samples.shape != log_probs.shape:
        raise ValueError("action_samples and log_probs must have the same shape.")
    if log_probs.numel() == 0:
        raise ValueError("ess requires at least one sample.")

    # Unnormalized log importance weights log(exp(-S) / q). Any constant
    # offset (e.g. the unknown log partition function) cancels in the ratio.
    log_weight = (-action_samples - log_probs).detach().reshape(-1)
    n = log_weight.numel()

    # Evaluate (sum w)^2 / sum w^2 in log space. Exponentiating weights
    # directly (even after a mean shift) overflows w**2 to inf once the
    # log-weight spread exceeds ~44 in float32, which wrongly yields 0 or NaN.
    log_ratio = (
        2.0 * torch.logsumexp(log_weight, dim=0)
        - torch.logsumexp(2.0 * log_weight, dim=0)
    )
    return (log_ratio.exp() / n).item()