import logging
from dataclasses import dataclass
from typing import Any, Literal

import torch
from torchmetrics import Metric

from lib_llm.inference import iter_masked


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@dataclass
class TokenValues:
    """A tensor of values bundled with the mask tensor that accompanies it.

    Attributes:
        values: Tensor holding the values.
        mask: Tensor passed to ``iter_masked`` together with ``values``.
            NOTE(review): presumably a boolean mask selecting the valid
            positions of ``values`` -- confirm against ``iter_masked``.
    """

    values: torch.Tensor
    mask: torch.Tensor

    @property
    def masked(self) -> list:
        """Return ``iter_masked(self.values, self.mask)`` collected in a list.

        The list is rebuilt on every access; nothing is cached.
        """
        return list(iter_masked(self.values, self.mask))

    def to(self, device: "str | torch.device") -> "TokenValues":
        """Move both tensors to ``device`` in place and return ``self``.

        Args:
            device: Target device, forwarded unchanged to
                ``torch.Tensor.to``. Any device spec that method accepts
                works here (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"`` or a
                ``torch.device``), not only the two bare device names.

        Returns:
            This same instance (not a copy), so calls can be chained.
        """
        # The fields are rebound on this instance rather than building a new
        # TokenValues, so every holder of this object sees the moved tensors.
        self.values = self.values.to(device)
        self.mask = self.mask.to(device)
        return self


# Closed set of argument names a metric can list in its ``required_args``
# (see ``SequenceMetric.__init__``).
# NOTE(review): what each name maps to when a metric is updated is not
# defined in this module -- confirm against the code that supplies the
# metric arguments.
MetricArg = Literal[
    "input",
    "model",
    "output",
    "logits",
    "token_probs",
    "sequence_probs",
]


class SequenceMetric(Metric):
    """Common parent for metrics that compute properties of sequences.

    The ``update`` and ``compute`` methods defined here only raise
    ``NotImplementedError``; concrete metrics are expected to override both.

    Attributes:
        required_args: The ``MetricArg`` names handed to the constructor.
            NOTE(review): presumably the inputs the metric needs passed to
            ``update`` -- confirm against the caller.
    """

    def __init__(self, required_args: list[MetricArg]):
        """Initialise the metric with synchronization on compute disabled.

        Args:
            required_args: ``MetricArg`` names, stored on the instance as-is
                (the list is not copied).
        """
        # Metrics are only computed on the CPU for now, and synchronizing
        # in that setup triggers errors, so it is switched off.
        super().__init__(sync_on_compute=False)
        self.required_args: list[MetricArg] = required_args

    def update(self, **_: Any) -> None:
        """Placeholder to be overridden; accepts and ignores any keywords.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def compute(self) -> torch.Tensor:
        """Placeholder to be overridden.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()


class TokenMetric(SequenceMetric):
    """Common parent for metrics that compute properties of the individual
    tokens of sequences.

    Like its parent, it leaves ``update`` and ``compute`` unimplemented.
    """

    def __init__(self, required_args: list[MetricArg]):
        """Forward ``required_args`` unchanged to ``SequenceMetric``.

        Args:
            required_args: ``MetricArg`` names handed on to the parent
                constructor.
        """
        super().__init__(required_args=required_args)

    def update(self, **_: Any) -> None:
        """Placeholder to be overridden; accepts and ignores any keywords.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def compute(self) -> torch.Tensor:
        """Placeholder to be overridden.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
