"""Utilities for computing metrics."""

import numpy as np
import torch
from torch import Tensor
from scipy.stats import kendalltau
import numpy as np
from utils.data import tnp
from typing import Dict


def reduce(
    tensor: Tensor, dim: int | tuple = None, reduction: str = "nanmean"
) -> Tensor:
    """Reduce a tensor along the specified dimension.

    Args:
        tensor: Can be of any shape.
        dim: Dimension(s) to reduce. If None, reduces all.
        reduction: ["mean", "sum", "nanmean"].

    Returns:
        Reduced tensor.
    """
    if reduction == "nanmean":
        return torch.nanmean(tensor, dim=dim)
    elif reduction == "mean":
        return torch.mean(tensor, dim=dim)
    elif reduction == "sum":
        return torch.sum(tensor, dim=dim)
    else:
        raise ValueError(
            f"Invalid reduction type: {reduction}. Must be one of ['mean', 'sum', 'nanmean']."
        )


def kendalltau_correlation(
    input: Tensor, target: Tensor, reduction: str = "mean"
) -> np.ndarray:
    """Compute the Kendall-tau statistic between input and target.

    - Closer to 1 for samples with strongly positive ordinal correlation
    - Closer to -1 for samples with strongly negative ordinal correlation
    - Closer to 0 indicates a weak correlation

    The statistic is computed independently for every (dataset b, dimension d)
    pair, comparing ``input[b, :, d]`` with ``target[b, :, d]``.

    Args:
        input: predicted values, [B, N, D]
        target: ground truth values, [B, N, D] (must match ``input``'s shape)
        reduction: how to reduce the output, in ["mean", "none"]

    Returns:
        tau: numpy array of shape [D] if reduction is "mean" (NaN taus are
            skipped by the average), or [B, D] if reduction is "none".

    Raises:
        ValueError: if ``input`` and ``target`` have different shapes.
    """
    assert reduction in ["mean", "none"], "reduction must be either 'mean' or 'none'"
    # NOTE(review): tnp (utils.data) presumably converts tensors to numpy
    # arrays; the code below relies on `.shape` and numpy-style indexing.
    input = tnp(input)
    target = tnp(target)

    if input.shape != target.shape:
        raise ValueError(
            f"input and target must have the same shape, got {input.shape} and {target.shape}."
        )

    B, _, D = input.shape

    # Pre-allocate [B, D] so that B == 0 yields an empty array instead of
    # np.stack failing on an empty list.
    tau = np.full((B, D), np.nan)
    for b in range(B):
        for d in range(D):
            tau[b, d] = kendalltau(
                x=input[b, :, d],
                y=target[b, :, d],
            ).statistic

    if reduction == "mean":
        # Average across datasets, skipping NaN taus (kendalltau can return
        # NaN, e.g. when a slice is constant).
        return np.nanmean(tau, axis=0)
    return tau


def performance_profile(
    trajectories: Dict,
    taus=None,
) -> Dict:
    """Compute a performance profile where each iteration is one "problem".

    At iteration ``t`` the performance ratio of an algorithm is
    ``best_t / value``, where ``best_t`` is the largest non-NaN value across
    algorithms at that iteration (higher is better, e.g. hypervolume).
    Non-positive or NaN values get an infinite ratio. For each ``tau``,
    ``rho(tau)`` is the fraction of iterations whose ratio is ``<= tau``.

    Args:
        trajectories: dict {algorithm_name: [v_1, v_2, ..., v_T]} where each
            v_t is a scalar (e.g. the hypervolume at iteration t) convertible
            with ``float()``. All trajectories must have the same length T.
        taus: array-like, range of tau values for the profile. Default = np.linspace(1, 5, 100).

    Returns:
        profile: dict {algorithm_name: (taus, rho)}; ``taus`` is returned as
            passed in and ``rho`` is a float array with one entry per tau.

    Raises:
        ValueError: if ``trajectories`` is empty or its trajectories have
            different lengths.
    """
    if taus is None:
        taus = np.linspace(1, 5, 100)
    if not trajectories:
        raise ValueError("trajectories must contain at least one algorithm.")

    algs = list(trajectories.keys())
    lengths = {a: len(trajectories[a]) for a in algs}
    if len(set(lengths.values())) != 1:
        raise ValueError(
            f"All trajectories must have the same length, got {lengths}."
        )

    # [A, T] matrix of values; float() accepts Python numbers, numpy scalars
    # and 0-d tensors alike.
    values = np.array(
        [[float(v) for v in trajectories[a]] for a in algs], dtype=float
    )

    # Best value per iteration, ignoring NaNs. Python's max() is
    # order-dependent with NaN (a leading NaN poisoned every ratio).
    best = np.where(np.isnan(values), -np.inf, values).max(axis=0)  # [T]

    # Performance ratio per (algorithm, iteration); only positive values get
    # a finite ratio, everything else (<= 0 or NaN) is infinitely bad.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratios = np.where(values > 0, best / values, np.inf)  # [A, T]

    # rho[i, k] = fraction of iterations with ratios[i, t] <= taus[k]; [A, K]
    taus_arr = np.asarray(taus, dtype=float)
    rho = (ratios[:, :, None] <= taus_arr[None, None, :]).mean(axis=1)

    return {a: (taus, rho[i]) for i, a in enumerate(algs)}
