"""
Evaluation metric utilities.
Common regression and classification evaluation metrics implemented with NumPy.
"""

import numpy as np
from typing import Tuple


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Return the mean of the squared element-wise differences.

    ``pred`` and ``target`` are coerced with ``np.asarray`` and broadcast
    against each other, so a scalar target is compared with every prediction.
    """
    residual = np.asarray(pred) - np.asarray(target)
    return float(np.square(residual).mean())


def accuracy(pred: np.ndarray, target: np.ndarray) -> float:
    """Return the fraction of positions where ``pred`` equals ``target``.

    Both inputs are coerced with ``np.asarray`` and compared element-wise
    (with broadcasting); the boolean match mask is then averaged.
    """
    predicted, expected = np.asarray(pred), np.asarray(target)
    hit_mask = predicted == expected
    return float(hit_mask.mean())


def top_k_accuracy(logits: np.ndarray, target: np.ndarray, k: int = 5) -> float:
    """Return the fraction of samples whose true label is among the k top scores.

    Args:
        logits: Scores of shape ``(N, C)``. A single 1-D row of shape ``(C,)``
            is treated as one sample.
        target: Integer class labels of shape ``(N,)``. An ``(N, 1)`` column
            is flattened rather than broadcast.
        k: Number of highest-scoring classes that count as a hit. Values
            larger than ``C`` are clamped to ``C``; ``k <= 0`` never hits.

    Returns:
        Top-k accuracy in ``[0, 1]`` (NaN for empty input, as ``np.mean``).

    Raises:
        ValueError: If ``logits`` has more than two dimensions or the number
            of labels differs from the number of score rows.
    """
    logits = np.atleast_2d(np.asarray(logits))
    target = np.asarray(target).reshape(-1)
    if logits.ndim != 2:
        raise ValueError(f"logits must be 1-D or 2-D, got shape {logits.shape}")
    if target.shape[0] != logits.shape[0]:
        raise ValueError(
            f"got {target.shape[0]} labels for {logits.shape[0]} score rows"
        )

    num_classes = logits.shape[1]
    k = min(max(int(k), 0), num_classes)
    if k == 0:
        return float(np.zeros(target.shape[0], dtype=bool).mean())

    # argpartition moves the k largest scores of each row to the tail in O(C)
    # without negating the scores (negation wraps for unsigned integer dtypes
    # and is unsupported for booleans). Order within the top k is irrelevant.
    kth = num_classes - k
    top_k = np.argpartition(logits, kth, axis=1)[:, kth:]
    hits = (top_k == target[:, None]).any(axis=1)
    return float(hits.mean())


def absolute_error(pred: np.ndarray, target: np.ndarray) -> float:
    """Return the mean absolute difference between ``pred`` and ``target``.

    Inputs are coerced with ``np.asarray`` and broadcast before the
    element-wise difference is taken.
    """
    diff = np.asarray(pred) - np.asarray(target)
    return float(np.absolute(diff).mean())


def brier_score(probs: np.ndarray, targets: np.ndarray) -> float:
    """
    Calculate the multi-class Brier score (MSE for classification tasks).

    For each sample, compare the predicted probability vector with the one-hot
    encoding of the true label: square the differences and sum them over
    classes. Return the mean of these per-sample sums.

    probs: (N, C) predicted probabilities
    targets: (N,) integer class labels. An (N, 1) column is flattened, and
        integral floats such as 2.0 are accepted. Negative labels index from
        the last class, as NumPy indexing does.
    Returns: the Brier score as a Python float (0.0 for a perfect forecast).
    Raises: ValueError if probs is not 2-D, the label count differs from N,
        or a label is not integral; IndexError for labels outside [-C, C).
    """
    probs = np.asarray(probs)
    targets = np.asarray(targets).reshape(-1)
    if probs.ndim != 2:
        raise ValueError(f"probs must be 2-D (N, C), got shape {probs.shape}")
    num_samples, num_classes = probs.shape
    if targets.shape[0] != num_samples:
        raise ValueError(
            f"got {targets.shape[0]} labels for {num_samples} probability rows"
        )

    # Labels loaded as floats (e.g. 2.0) cannot be used as indices directly;
    # accept them only when the cast to integers is lossless.
    if not np.issubdtype(targets.dtype, np.integer):
        labels = targets.astype(np.intp)
        if not np.array_equal(labels, targets):
            raise ValueError("targets must be integer class labels")
        targets = labels

    # Build the (N, C) one-hot matrix directly. Indexing np.eye(C) would first
    # allocate a (C, C) identity matrix, which is prohibitive for large C.
    targets_onehot = np.zeros((num_samples, num_classes), dtype=float)
    targets_onehot[np.arange(num_samples), targets] = 1.0
    return float(np.mean(np.sum((probs - targets_onehot) ** 2, axis=1)))