
"""metrics.py

Common metrics used in the AWML experiments. Designed to be small and dependency-light.
"""
import math
import torch
import numpy as np
from typing import Dict, Any

def accuracy(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of predictions that equal *targets*.

    Args:
        preds: Either predicted class indices, or scores with classes on
            dim 1. Any tensor with ``ndim > 1`` and ``shape[1] > 1`` is
            reduced with ``argmax`` over dim 1.
        targets: Target class indices, one per (reduced) prediction.

    Returns:
        Accuracy as a Python float (``nan`` for empty input).

    Raises:
        ValueError: if predictions and targets have different element counts.
    """
    preds = preds.detach().cpu()
    targets = targets.detach().cpu()
    if preds.ndim > 1 and preds.shape[1] > 1:
        preds = preds.argmax(dim=1)
    # Flatten both sides so e.g. (N, 1) vs (N,) is compared element-wise
    # instead of broadcasting into an (N, N) matrix.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    if preds.numel() != targets.numel():
        raise ValueError(
            f"accuracy: got {preds.numel()} predictions but {targets.numel()} targets"
        )
    return (preds == targets).float().mean().item()

def rmse(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the root-mean-squared error between *preds* and *targets*.

    Both tensors are flattened, so any shapes with the same number of
    elements are accepted (including non-contiguous tensors). Integer or
    bool tensors are cast to float before the difference is taken.

    Returns:
        RMSE as a Python float (``nan`` for empty input).

    Raises:
        ValueError: if the two tensors have different element counts.
    """
    # reshape (not view): view fails on non-contiguous tensors.
    preds = preds.detach().cpu().reshape(-1)
    targets = targets.detach().cpu().reshape(-1)
    if preds.numel() != targets.numel():
        raise ValueError(
            f"rmse: got {preds.numel()} predictions but {targets.numel()} targets"
        )
    # torch.mean rejects integer dtypes; floating inputs keep their precision.
    if not preds.is_floating_point():
        preds = preds.float()
    if not targets.is_floating_point():
        targets = targets.float()
    return torch.sqrt(torch.mean((preds - targets) ** 2)).item()

class MetricsFn:
    """Wrapper object passed to training loop for evaluation metrics.

    Runs a model over a dataloader and reports one scalar metric:
    ``val_acc`` when ``task == "classification"``, otherwise ``val_rmse``.
    """

    def __init__(self, task: str = "regression"):
        # Any task other than "classification" is treated as regression.
        self.task = task

    @staticmethod
    def metrics_keys(task: str = "regression"):
        """Return the list of metric keys ``compute`` reports for *task*.

        Kept a staticmethod for backward compatibility: called with no
        argument it returns ``["val_rmse"]``, as before.
        """
        return ["val_acc"] if task == "classification" else ["val_rmse"]

    @staticmethod
    def _split_batch(batch, device):
        """Return ``(x, y)`` moved to *device*; ``y`` is None if the batch has no target."""
        if isinstance(batch, (list, tuple)):
            x = batch[0].to(device)
            y = batch[1].to(device) if len(batch) > 1 else None
            return x, y
        return batch.to(device), None

    @staticmethod
    def _extract_pred(out):
        """Pick the prediction from a model output: dict ("recon", then "pred") or the output itself."""
        if isinstance(out, dict):
            return out.get("recon", out.get("pred", None))
        return out

    def compute(self, model, dataloader, device="cpu") -> Dict[str, float]:
        """Evaluate *model* on *dataloader* and return the task metric.

        Each batch is either a tensor (inputs only) or a list/tuple whose
        first item is the input and optional second item is the target.
        When a target is present it is also passed to the model as
        ``model(x, y)``. Batches whose output yields no prediction are
        skipped. The model's train/eval mode is restored on return.

        Args:
            model: Object with ``eval()``, called as ``model(x)`` or ``model(x, y)``.
            dataloader: Iterable of batches as described above.
            device: Device the batch tensors are moved to.

        Returns:
            ``{"val_acc": ...}`` for classification, ``{"val_rmse": ...}``
            otherwise, or ``{}`` when no predictions were produced.

        Raises:
            ValueError: if only some of the used batches carried targets,
                so predictions and targets cannot be aligned.
        """
        was_training = getattr(model, "training", None)
        model.eval()
        preds = []
        trues = []
        missing_target = False
        try:
            with torch.no_grad():
                for batch in dataloader:
                    x, y = self._split_batch(batch, device)
                    # NOTE(review): targets are forwarded to the model when
                    # present; this is the existing caller contract.
                    out = model(x) if y is None else model(x, y)
                    pred = self._extract_pred(out)
                    if pred is None:
                        continue
                    preds.append(pred.detach().cpu())
                    if y is None:
                        missing_target = True
                    else:
                        trues.append(y.detach().cpu())
        finally:
            # Don't leave a training model stuck in eval mode after validation.
            if was_training is not None:
                model.train(was_training)

        if not preds:
            return {}
        if trues and missing_target:
            raise ValueError(
                "MetricsFn.compute: some batches had targets and some did not"
            )
        preds = torch.cat(preds, dim=0)
        if trues:
            trues = torch.cat(trues, dim=0)
        elif self.task == "classification":
            trues = torch.zeros(preds.shape[0])
        else:
            # NOTE(review): with no targets the metric is measured against
            # zeros; shaped like preds so multi-dim outputs don't fail.
            trues = torch.zeros_like(preds)

        if self.task == "classification":
            return {"val_acc": accuracy(preds, trues)}
        return {"val_rmse": rmse(preds, trues)}
