
"""
This module provides utilities to compute:
    * Classification reward (R_cls)
    * Format match reward (R_fmt)
    * Error span reward (R_span)
    * Claim‑level reward R(i)  = R_cls + R_fmt + R_span
    * Answer‑level reward     = mean of claim‑level rewards
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence, Optional, Set


@dataclass
class ClaimPrediction:
    """
    Per-claim record bundling everything the reward functions read.

    Attributes
    ----------
    pred_label : str
        Classification label produced by the model for this claim.
    true_label : str
        Gold (reference) classification label.
    format_match : str
        How well the output followed the reference format:

        * ``'hard'``    – exact match to the reference format
        * ``'soft'``    – minor deviations tolerated
        * ``'invalid'`` – required structure is missing; any value other
          than ``'hard'`` / ``'soft'`` is scored as invalid
    pred_spans : sequence of str, optional
        Text (or token) spans the model marked as evidence/errors.
        ``None`` is handled like an empty collection by the span reward.
    true_spans : sequence of str, optional
        Gold spans the predicted spans are compared against.

    Notes
    -----
    Labels are expected to be strings such as 'Supported', 'Refuted',
    'Not Supported', …  Only 'Not Supported' triggers special handling,
    and only inside the span reward.
    """

    # --- labels ---------------------------------------------------------
    pred_label: str
    true_label: str

    # --- format quality: one of 'hard', 'soft', 'invalid' ----------------
    format_match: str

    # --- spans: optional, default to "nothing provided" ------------------
    pred_spans: Optional[Sequence[str]] = None
    true_spans: Optional[Sequence[str]] = None


# --------------------------------------------------------------------------- #
# Helper utilities
# --------------------------------------------------------------------------- #
def _f1(pred: Sequence[str] | None, gold: Sequence[str] | None) -> float:
    """
    Set-based F1 between predicted and gold span collections.

    Both inputs are de-duplicated into sets before scoring, so repeated
    spans count once and ordering is irrelevant.

    Parameters
    ----------
    pred : list-like or None
        Predicted span identifiers (must be hashable). ``None`` or an empty
        collection means "no spans predicted".
    gold : list-like or None
        Ground-truth span identifiers, with the same conventions as ``pred``.

    Returns
    -------
    float
        F1 in [0, 1]: 1.0 when both sides are empty, 0.0 when exactly one
        side is empty or nothing overlaps, otherwise the harmonic mean of
        precision and recall.
    """
    predicted: Set[str] = set(pred) if pred else set()
    reference: Set[str] = set(gold) if gold else set()

    # Nothing predicted and nothing expected counts as full agreement.
    if not (predicted or reference):
        return 1.0
    # Exactly one side empty: either precision or recall is undefined/zero.
    if not (predicted and reference):
        return 0.0

    overlap = len(predicted.intersection(reference))
    precision = overlap / len(predicted)
    recall = overlap / len(reference)

    denominator = precision + recall
    # A zero denominator means no overlap at all; avoid dividing by it.
    return 0.0 if denominator == 0 else (2 * precision * recall) / denominator


# --------------------------------------------------------------------------- #
# Individual reward components
# --------------------------------------------------------------------------- #
def classification_reward(pred_label: str, true_label: str) -> float:
    """
    R_cls : +2 when the predicted label matches the gold label, −2 otherwise.

    Labels are compared after stripping surrounding whitespace and
    lower-casing. This is the same normalisation ``span_reward`` uses to
    recognise 'Not Supported', so a pair such as ``"Not Supported"`` /
    ``" not supported "`` is scored as a match by both reward components
    instead of being rewarded by one and penalised by the other.

    Parameters
    ----------
    pred_label : str
        Model-predicted classification label.
    true_label : str
        Gold classification label.

    Returns
    -------
    float
        ``2.0`` on a (normalised) match, ``-2.0`` otherwise.
    """
    labels_match = pred_label.strip().lower() == true_label.strip().lower()
    return 2.0 if labels_match else -2.0


def format_reward(match_type: str) -> float:
    """
    R_fmt : +1 (hard), +0.5 (soft), −0.5 (invalid).

    Parameters
    ----------
    match_type : str
        Format-match category. Compared case-insensitively and with
        surrounding whitespace ignored, so ``" Hard "`` counts as ``'hard'``.

    Returns
    -------
    float
        ``1.0`` for ``'hard'``, ``0.5`` for ``'soft'`` and ``-0.5`` for
        anything else, including non-string values such as ``None``.
    """
    # A missing or non-string category cannot match any valid format, so it
    # is scored as invalid rather than raising on ``.strip()``.
    if not isinstance(match_type, str):
        return -0.5

    rewards = {"hard": 1.0, "soft": 0.5}
    # Every unrecognised category falls back to the invalid penalty.
    return rewards.get(match_type.strip().lower(), -0.5)


def span_reward(
    pred_label: str,
    true_label: str,
    pred_spans: Sequence[str] | None,
    true_spans: Sequence[str] | None,
) -> float:
    """
    R_span : span-level reward, driven by the 'Not Supported' label.

    Labels are matched against 'Not Supported' case-insensitively and with
    surrounding whitespace ignored.

    ==================  ==================  ================================
    prediction          gold                R_span
    ==================  ==================  ================================
    'Not Supported'     'Not Supported'     2·F1(pred_spans, true_spans) − 1
    'Not Supported'     anything else       −1
    anything else       any                 0
    ==================  ==================  ================================

    Returns
    -------
    float
        A value in [−1, +1].
    """
    target = "not supported"
    # Normalise both labels up front, prediction first and gold second.
    predicted_ns, gold_ns = (
        label.strip().lower() == target for label in (pred_label, true_label)
    )

    # No 'Not Supported' prediction: spans play no role.
    if not predicted_ns:
        return 0.0
    # Predicted 'Not Supported' on a claim whose gold label is different.
    if not gold_ns:
        return -1.0
    # Both agree on 'Not Supported': rescale F1 from [0, 1] to [-1, +1].
    return 2.0 * _f1(pred_spans, true_spans) - 1.0


# --------------------------------------------------------------------------- #
# Aggregated reward calculations
# --------------------------------------------------------------------------- #
def claim_reward(claim: ClaimPrediction) -> float:
    """
    Total reward for one claim: R_cls + R_fmt + R_span.

    Parameters
    ----------
    claim : ClaimPrediction
        Labels, format-match category and spans for the claim.

    Returns
    -------
    float
        Sum of the classification, format and span rewards.
    """
    # Accumulate in the order classification -> format -> span.
    total = classification_reward(claim.pred_label, claim.true_label)
    total += format_reward(claim.format_match)
    total += span_reward(
        claim.pred_label,
        claim.true_label,
        claim.pred_spans,
        claim.true_spans,
    )
    return total


def answer_reward(claims: List[ClaimPrediction]) -> float:
    """
    Answer-level reward: the mean of the claim-level rewards.

    Parameters
    ----------
    claims : list of ClaimPrediction
        All claims belonging to the same answer/sample.

    Returns
    -------
    float
        Arithmetic mean of ``claim_reward`` over ``claims``.

    Raises
    ------
    ValueError
        If ``claims`` is empty (a mean over zero claims is undefined).
    """
    if claims:
        total = sum(map(claim_reward, claims))
        return total / len(claims)
    raise ValueError("`claims` must contain at least one element.")


# Public API of this module. The `_f1` helper is deliberately not listed,
# so `from <module> import *` does not export it.
__all__ = [
    "ClaimPrediction",
    "classification_reward",
    "format_reward",
    "span_reward",
    "claim_reward",
    "answer_reward",
]
