import abc
import typing

import torch

# Unconstrained type variable for the auxiliary data a trainer may return
# alongside its model and later receive back at prediction time.
TAux = typing.TypeVar("TAux")
# Unconstrained type variable for the trained model object itself.
TModel = typing.TypeVar("TModel")


class ModelTrainer(typing.Generic[TAux, TModel], metaclass=abc.ABCMeta):
    """Abstract interface for something that trains a model and predicts with it.

    Concrete subclasses must override both :meth:`train` and :meth:`predict`
    (they are declared with ``abc.abstractmethod``, so the class cannot be
    instantiated until they are). ``TModel`` is the type of the trained model
    object; ``TAux`` is the type of optional auxiliary data that ``train`` may
    produce and that is passed back into ``predict``.
    """

    # Defaults to ``None`` at class level; nothing in this class reads or
    # writes it. NOTE(review): the name suggests a (mean, std) pair of tensors
    # used for image normalization — confirm against subclasses/callers.
    images_mean_std: tuple[torch.Tensor, torch.Tensor] | None = None

    @abc.abstractmethod
    def train(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
        seed: int,
        device: torch.device,
    ) -> tuple[TModel, TAux | None]:
        """Fit a model to ``images`` and ``targets``.

        Args:
            images: Input image tensor (shape/dtype not fixed by this
                interface — TODO confirm per implementation).
            targets: Target tensor paired with ``images``.
            seed: Integer seed; presumably used for reproducible training —
                confirm in implementations.
            device: Torch device the implementation is given to work on.

        Returns:
            A ``(model, aux)`` tuple. ``aux`` may be ``None`` when the
            implementation has no auxiliary data to hand to ``predict``.
        """

    @abc.abstractmethod
    def predict(
        self,
        images: torch.Tensor,
        model: TModel,
        aux: TAux | None,
    ) -> torch.Tensor:
        """Produce predictions for ``images`` using a trained ``model``.

        Args:
            images: Input image tensor to predict on.
            model: A model object, as returned by :meth:`train`.
            aux: The auxiliary data returned by :meth:`train`, or ``None``.

        Returns:
            A tensor of predictions (layout defined by the implementation).
        """
