################################################################################
# spectral/modules/prefixsums/loss.py
#
# 
# 
# 
# 2024
#
# Implements a solution accuracy metric.

import torch

Tensor = torch.Tensor

def metric_solution_accuracy(
    # Arguments:
    y_hat: Tensor,
    y:     Tensor
  ) -> Tensor:
  """
  Measures the fraction of problems in a batch that were solved perfectly.

  Unlike a per-class (category) accuracy, a problem only counts as correct
  when every one of its predicted classes matches the target; a single wrong
  class makes the whole problem count as incorrect.

  Args:
    y_hat (Tensor):
      Predicted logits. The class dimension is dim 1 (the prediction is the
      argmax over it); dim 0 indexes the problems in the batch.
    y (Tensor):
      Target class indices, compared element-wise against the argmax of
      `y_hat`; dim 0 indexes the problems in the batch.

  Returns:
    Tensor:
      A 0-dim float32 tensor: the batch mean of a per-problem 1.0/0.0 flag
      that is 1.0 only when all of that problem's classes are correct.
  """
  # Predicted class at every position (argmax already yields int64 indices).
  predictions = torch.argmax(y_hat, dim = 1)
  matches = predictions == y
  # Fold every non-batch dimension into one axis so each row is one problem.
  per_problem = matches.view(y.size(0), -1)
  solved = per_problem.all(dim = -1)
  return solved.float().mean()