"""
Metric utilities for fields and scalars.
"""

from typing import List
import torch
from torch import linalg
import numpy as np


def step_function_field(y: torch.Tensor, threshold: float=0.0) -> torch.Tensor:
    """
    Return the per-component peak magnitude of a field, for use as a normalizer.

    The peak is max_n |y[n, c]|, reduced over the first dimension. When
    `threshold` is positive, every component whose peak is below it gets a
    factor of 1.0 instead, so callers never divide by a near-zero value.
    With the default threshold of 0.0, the raw peaks are returned unchanged.
    They can be exactly 0.0 for an all-zero component.

    Args:
        y: Field tensor reduced over dim 0, typically of shape (N, C)
            (N nodes, C components/features).
        threshold: Smallest peak magnitude that is kept as-is; only applied
            when > 0.

    Returns:
        Tensor of per-component factors, shape (C,) for an (N, C) input.
    """
    peak_magnitude = y.abs().max(dim=0).values
    if threshold > 0.0:
        # Components too small to divide by safely fall back to unit scaling.
        peak_magnitude = torch.where(peak_magnitude < threshold, 1.0, peak_magnitude)
    return peak_magnitude


def relative_rmse_field(
        y_true: List[torch.Tensor], y_pred: List[torch.Tensor], threshold: float=0.0
) -> torch.Tensor:
    """
    Compute relative RMSE for a list of field tensors.

    For each sample i and component c:
        mse_i(c) = ||y_i(:,c) - yhat_i(:,c)||_2^2 / (N_i * norm_i(c)^2)
    where N_i = y_i.shape[0] and norm_i(c) = step_function_field(y_i, threshold).

    The result is sqrt(mean_i(mse_i(c))). Each sample's relative MSE is
    weighted equally, regardless of its number of nodes N_i. This equals an
    RMSE pooled over all nodes (each normalized by its own sample's factor)
    only when every sample has the same N_i.

    Args:
        y_true: List of tensors, each of shape (N_i, C).
        y_pred: List of tensors, each of shape (N_i, C), aligned with y_true.
        threshold: Thresholding used in normalization to avoid divide-by-small.

    Returns:
        Tensor of shape (C,) with relative RMSE per component.

    Raises:
        ValueError: If y_true and y_pred contain a different number of samples.
    """
    # zip() would silently drop unmatched samples and report a metric computed
    # on a subset of the data; fail loudly instead.
    if len(y_true) != len(y_pred):
        raise ValueError(
            "y_true and y_pred must contain the same number of samples, "
            f"got {len(y_true)} and {len(y_pred)}"
        )
    per_sample_mse = [
        linalg.norm(y - y_hat, dim=0) ** 2                       # pylint: disable=not-callable
        / (y.shape[0] * step_function_field(y, threshold) ** 2)
        for y, y_hat in zip(y_true, y_pred)
    ]
    return torch.sqrt(torch.mean(torch.stack(per_sample_mse, dim=0), dim=0))


def relative_rmse_scalar(
        y_true: np.ndarray, y_pred: np.ndarray, threshold: float=0.0
) -> np.ndarray:
    """
    Compute relative RMSE for scalar outputs stored in numpy arrays.

    Each error is divided by its ground-truth value before averaging:
        rel_rmse = sqrt(mean_n(((y_true[n] - y_pred[n]) / d[n])^2))  (over axis 0)
    where d = y_true. When threshold > 0, entries with |y_true| < threshold
    use d = 1.0 instead, to avoid dividing by near-zero values. The test is
    on the magnitude |y_true|, like step_function_field. Large negative
    targets are therefore normalized by their own value, not by 1.0.

    Args:
        y_true: Array (or array-like) of shape (N, C) or (N,) for
            ground-truth scalars.
        y_pred: Array (or array-like) with the same shape as y_true.
        threshold: If > 0, entries with |y_true| < threshold are normalized
            by 1.0.

    Returns:
        Array of shape (C,) (or scalar) with relative RMSE.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if threshold > 0.0:
        # Compare magnitudes: the scale is squared below, so only |y_true|
        # matters, and a negative target must not be mistaken for a small one.
        scale = np.where(np.abs(y_true) < threshold, 1.0, y_true)
    else:
        scale = y_true
    return np.sqrt(
        np.mean(
            (y_true - y_pred) ** 2 / scale ** 2, axis=0
        )
    )
