from __future__ import annotations

from typing import Literal

import numpy as np


def risk_mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Population risk proxy: half the mean squared error.

    Computes ``0.5 * mean((y_true - y_pred) ** 2)``. The factor of one half
    is the squared-loss convention ``l(y, y_hat) = (y - y_hat)**2 / 2``; it is
    NOT the plain MSE.

    Args:
        y_true: Target values (any array-like; converted with ``np.asarray``).
        y_pred: Predicted values. Combined with ``y_true`` via NumPy
            broadcasting, so pass matching shapes (e.g. ``(n,)`` with
            ``(n,)``) to avoid an unintended outer difference.

    Returns:
        The half-MSE as a Python ``float``.
    """
    residual = np.asarray(y_true) - np.asarray(y_pred)
    return float(np.mean(residual ** 2) * 0.5)


def risk_ce_logits(logits: np.ndarray, y_true: np.ndarray) -> float:
    """Cross-entropy risk from logits (softmax).

    Args:
        logits: Unnormalized class scores of shape ``(n, C)``, one row per
            sample (any array-like; converted with ``np.asarray``).
        y_true: Integer labels in ``0..C-1``, expected shape ``(n,)``.
            Values are cast with ``astype(int)``, so floats are truncated.

    Returns:
        The mean over rows of ``-log softmax(logits)[i, y_true[i]]`` as a
        Python ``float``.

    Raises:
        IndexError: If any label is negative or ``>= C``.
    """
    logits = np.asarray(logits)
    labels = np.asarray(y_true).astype(int)
    n_classes = logits.shape[1]
    # NumPy would silently interpret negative labels as "count from the end"
    # and pick the wrong class; reject them together with too-large labels.
    if np.any((labels < 0) | (labels >= n_classes)):
        raise IndexError(f"labels must lie in 0..{n_classes - 1}")

    # Stable log-softmax. Keep everything relative to the row max: adding the
    # max back and then subtracting it from the raw logits cancels
    # catastrophically for large logits (float32 values near 1e8 lose the
    # log-sum term entirely).
    max_logit = np.max(logits, axis=1, keepdims=True)
    stabilized = logits - max_logit
    log_norm = np.log(np.sum(np.exp(stabilized), axis=1, keepdims=True))
    log_probs = stabilized - log_norm

    row_idx = np.arange(logits.shape[0])
    chosen = log_probs[row_idx, labels]
    return float(-np.mean(chosen))




