from __future__ import annotations

import pathlib
from typing import TYPE_CHECKING

from torch import optim

from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _TemperatureScaledModel(_OODModel):
    """Base class for models with temperature scaling.

    ``forward`` delegates directly to ``self.model``. That attribute is not
    set in this class; presumably ``_OODModel`` or a subclass provides it
    (confirm).
    """

    @staticmethod
    def get_checkpoints(
        dataset_name: str,
        model_name: str,
        parametrization_name: str,
        root_dir: str,
        seed: int,
    ) -> list[pathlib.Path]:
        """Return the checkpoints available for one training run.

        The search is recursive for ``*.ckpt`` files under
        ``<root_dir>/training_logs/<dataset_name>/<model_name>/
        <parametrization_name>/seed_<seed>``.

        Args:
            dataset_name: Dataset directory name under ``training_logs``.
            model_name: Model directory name under the dataset directory.
            parametrization_name: Parametrization directory name under the
                model directory.
            root_dir: Root directory that contains ``training_logs``.
            seed: Seed of the run; selects the ``seed_<seed>`` directory.

        Returns:
            The matching checkpoint paths, sorted so the result does not
            depend on filesystem enumeration order. The list is empty when
            nothing matches, including when the run directory does not exist.

        Raises:
            ValueError: If ``root_dir`` cannot be turned into a path
                (e.g. it is ``None``).
        """
        # Only the Path construction can raise TypeError (non path-like
        # root_dir), so keep the try body limited to it.
        try:
            root = pathlib.Path(root_dir)
        except TypeError as e:
            raise ValueError(f"No checkpoint found for seed: {seed}.") from e

        run_dir = root / (
            f"training_logs/{dataset_name}/{model_name}/"
            f"{parametrization_name}/seed_{seed}"
        )
        # rglob yields in directory-scan order, which is unspecified; sort
        # so callers that index into the list get a reproducible result.
        return sorted(run_dir.rglob("*.ckpt"))

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Run the wrapped ``self.model`` on ``inputs`` and return its output."""
        return self.model(inputs)

    def configure_optimizers(self) -> optim.Optimizer | None:
        """Return no optimizer.

        This method always returns ``None``. If ``_OODModel`` is a Lightning
        ``LightningModule`` (assumed; confirm), a ``None`` result makes
        ``fit`` run without an optimizer.
        """
        return None
