import torch
import utils


@torch.no_grad()
def compute_metrics(
    samples_dict: dict,
    target,
    n_proj: int = 100,
    N_ref: int = 10_000,
):
    """
    Compute quantitative metrics for all samplers.

    A single batch of reference samples is drawn from ``target`` and every
    sampler's output is compared against that same batch, so the values are
    comparable across methods.

    Args:
        samples_dict : dict[str, Tensor | None]
            Mapping {method_name -> samples [N, d]}. Entries whose value is
            None are skipped and get no entry in the result.
        target : object
            Target distribution; ``target.sample(N_ref)`` must return a
            tensor of reference samples.
        n_proj : int
            Number of projections for sliced Wasserstein
        N_ref : int
            Number of reference samples from target

    Returns:
        metrics : dict[str, dict]
            metrics[method][metric_name] = value (Python float), with
            metric_name in {"SW", "ED"}.
    """

    # Reference samples, drawn once and shared by every method
    x_ref = target.sample(N_ref)

    metrics = {}

    for name, x in samples_dict.items():

        # Safety: a sampler that produced nothing is skipped
        if x is None:
            continue

        # Samplers may return tensors on a different device or with a
        # different dtype than the target; matmul and cdist reject such
        # mixes, so align the reference to the samples (no-op if equal).
        ref = x_ref.to(device=x.device, dtype=x.dtype)

        metrics[name] = {
            "SW": sliced_wasserstein(x, ref, n_proj=n_proj).item(),
            "ED": energy_distance(x, ref).item(),
        }

    return metrics


# Sliced Wasserstein (SW) distance
def _quantile_alignment(n, m, device, dtype):
    """Index/weight triple aligning the quantile functions of n vs m samples.

    Returns None when n == m (sorted samples already pair up one-to-one).
    Otherwise returns (idx_x, idx_y, weights): the unit interval is cut at
    every multiple of 1/n and 1/m; on each piece both empirical quantile
    functions are constant, equal to sorted_x[idx_x] and sorted_y[idx_y],
    and ``weights`` holds the piece lengths (summing to 1).
    """
    if n == m:
        return None
    if n == 0 or m == 0:
        raise ValueError("both point clouds must contain at least one sample")

    # Work in integer units of 1/(n*m) so interval boundaries are exact.
    cuts = torch.cat([
        torch.arange(1, n + 1, device=device) * m,
        torch.arange(1, m + 1, device=device) * n,
    ]).sort().values
    starts = torch.cat([cuts.new_zeros(1), cuts[:-1]])
    weights = (cuts - starts).to(dtype) / (n * m)
    # The duplicated final cut yields a zero-width piece starting at n*m;
    # clamp its (irrelevant) index back into range.
    idx_x = torch.div(starts, m, rounding_mode="floor").clamp(max=n - 1)
    idx_y = torch.div(starts, n, rounding_mode="floor").clamp(max=m - 1)
    return idx_x, idx_y, weights


def sliced_wasserstein(x, y, n_proj=100):
    """
    Sliced Wasserstein-2 distance between two point clouds.

    The clouds are projected onto ``n_proj`` random unit directions; the
    squared 1-D Wasserstein-2 distances of the projections are averaged
    and the square root of that average is returned. The result is a
    Monte-Carlo estimate and depends on torch's global RNG state.

    x : tensor [N, d]
    y : tensor [M, d]
        N and M may differ; the 1-D distance is then computed exactly by
        integrating the difference of the two empirical quantile functions.
    n_proj : int
        Number of random projections (must be >= 1).

    Returns a 0-dim tensor with the dtype/device of ``x``.

    Raises ValueError if ``n_proj < 1`` or if the sample counts differ and
    one of the clouds is empty.
    """
    if n_proj < 1:
        raise ValueError(f"n_proj must be >= 1, got {n_proj}")

    d = x.shape[1]
    # Depends only on the sample counts, so compute it once for all slices.
    alignment = _quantile_alignment(x.shape[0], y.shape[0], x.device, x.dtype)
    total = x.new_zeros(())

    for _ in range(n_proj):
        # Random direction on the unit sphere, in the dtype of the data so
        # the matmul below also works for non-float32 inputs.
        v = torch.randn(d, device=x.device, dtype=x.dtype)
        v = v / v.norm()

        proj_x = torch.sort(x @ v).values
        proj_y = torch.sort(y @ v).values

        if alignment is None:
            total = total + torch.mean((proj_x - proj_y) ** 2)
        else:
            idx_x, idx_y, weights = alignment
            gaps = proj_x[idx_x] - proj_y[idx_y]
            total = total + torch.sum(weights * gaps ** 2)

    return torch.sqrt(total / n_proj)

def energy_distance(x, y):
    """
    Energy distance between the empirical distributions of x and y.

    Definition
    ----------
    For independent X, X' ~ P and Y, Y' ~ Q in R^d,
        E(P, Q) = 2 E[||X - Y||] - E[||X - X'||] - E[||Y - Y'||].

    Estimator
    ---------
    Each expectation is replaced by the plain average of Euclidean
    distances over all ordered pairs of samples (the within-sample
    averages include the zero-distance pairs i == i'):
        E(P, Q) ≈ 2 * mean_{i,j} ||x_i - y_j||
                  - mean_{i,i'} ||x_i - x_{i'}||
                  - mean_{j,j'} ||y_j - y_{j'}||.

    x, y : sample tensors, one point per row, as accepted by torch.cdist.
    Returns a 0-dim tensor.
    """

    def _mean_pairwise(a, b):
        # Average over the full pairwise distance matrix between a and b.
        return torch.cdist(a, b).mean()

    within_x = _mean_pairwise(x, x)
    within_y = _mean_pairwise(y, y)
    cross = _mean_pairwise(x, y)

    return 2 * cross - within_x - within_y