"""
Compute loss and Top-K accuracy metrics, with support for S-Softmax.
"""
__author__ = 'XYZ'


import torch
import torch.nn as nn


def compute_loss_and_topk(loss_function, outputs, labels, num_classes, topk_values=(1, 3, 5), score_level=10):
  """
  Compute the loss and the Top-K correct counts for one batch.

  When ``loss_function`` is an instance of a class named ``ScoreLossPlus``,
  ``outputs`` is first reshaped to ``[-1, score_level, num_classes]`` and a
  softmax is applied over the score-level axis (dim 1) before the loss is
  evaluated. Any other loss function receives ``outputs`` unchanged.

  Args:
    loss_function: callable invoked as ``loss_function(outputs, labels)``. It
      may return either a single loss value or a 3-tuple
      ``(loss, score_dis, score)``.
    outputs: 2-D tensor of model logits. For ScoreLossPlus its element count
      must be divisible by ``score_level * num_classes``.
    labels: ground-truth class indices, one per sample; ``labels.size(0)`` is
      reported as ``total``.
    num_classes: number of classes C (used only for the ScoreLossPlus reshape).
    topk_values: iterable of k values for which correct counts are reported.
      A k larger than the number of class columns is treated as "all classes"
      instead of raising. May be empty.
    score_level: number of score levels per class (ScoreLossPlus only).

  Returns:
    dict with:
      - loss: value returned by the loss function (first tuple element if a
        tuple was returned)
      - score_dis: second tuple element, or None for non-tuple losses
      - score: third tuple element, or None for non-tuple losses
      - correct_topk: dict mapping each k to the number of samples whose label
        is among the k highest-scoring classes
      - total: ``labels.size(0)``

  Raises:
    AssertionError: if ScoreLossPlus is used with ``score_level=None`` or with
      non-2-D ``outputs``, or if the loss returns a tuple while ``outputs`` is
      not 3-D.
  """
  result = {
    'loss': None,
    'score_dis': None,
    'score': None,
    'correct_topk': {k: 0 for k in topk_values},
    'total': labels.size(0),
  }

  ## ScoreLossPlus is detected by class name so this module does not need to import it.
  is_custom_scoreloss = loss_function.__class__.__name__ == 'ScoreLossPlus'

  if is_custom_scoreloss:
    assert score_level is not None, "Score Level must be provided for ScoreLossPlus"
    assert outputs.ndim == 2, f"Expected shape [B*score_level, C] before reshape, got {outputs.shape}"

    ## Batch dimension is inferred (-1); reshape (unlike view) also accepts
    ## non-contiguous inputs.
    outputs = outputs.reshape(-1, score_level, num_classes)  # [B, score_level, C]
    outputs = torch.softmax(outputs, dim=1)                  # softmax over score_level axis

  loss_out = loss_function(outputs, labels)

  if isinstance(loss_out, tuple):
    result['loss'], result['score_dis'], result['score'] = loss_out
    assert outputs.ndim == 3, f"Expected shape [B, score_level, C], got {outputs.shape}"
    ## NOTE(review): outputs was softmax-normalised over dim 1 above, so summing
    ## over that same axis yields ~1.0 for every (sample, class) pair and the
    ## ranking below is decided by ties/rounding. Confirm the intended
    ## aggregation (e.g. a level-weighted sum or the returned `score`).
    probs = outputs.sum(dim=1)  ## Aggregate over score levels -> [B, C]
  else:
    result['loss'] = loss_out
    probs = outputs  ## [B, C]

  ## Nothing to rank when no k was requested (max() of an empty sequence raises).
  if result['correct_topk']:
    ## topk raises if k exceeds the size of the ranked dimension, so clamp it.
    max_k = min(max(result['correct_topk']), probs.size(1))
    ## Ranking needs no gradients; compare against the labels once for all k.
    top_indices = probs.detach().topk(max_k, dim=1).indices  ## [B, max_k]
    hits = top_indices.eq(labels.reshape(-1, 1))             ## [B, max_k] bool
    for k in result['correct_topk']:
      result['correct_topk'][k] = hits[:, :k].sum().item()

  return result
