from abc import ABC, abstractmethod
from typing import Dict, Type
from utils.decorators import register
from functools import partial
import torch.nn.functional as F


class BaseModel(ABC):
    """Abstract root of the model hierarchy.

    ``model_registry`` maps names to subclasses of this class. The class
    itself declares no abstract members, so it can be instantiated directly.
    """


class PretrainedMixin:
    """Encapsulate a pretrained model from Hugging Face.

    Subclasses supply the Hugging Face model name and a batch-processing
    method. The helpers below turn model outputs (any object exposing a
    ``.logits`` attribute) into logits, probabilities and predicted labels.

    Note: this mixin does not derive from ``ABC``, so its ``@abstractmethod``
    markers are only enforced when a subclass's metaclass is ``ABCMeta``
    (e.g. when the subclass also inherits from ``BaseModel``).
    """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """A model name from Hugging Face."""
        pass

    @abstractmethod
    def process(self, examples: list) -> dict:
        """Process a batch of examples from a Hugging Face dataset.

        Args:
            examples: The batch of examples to process.

        Returns:
            A dict holding the processed batch.
        """
        pass

    @staticmethod
    def get_latent_space(truncated_outputs):
        """Return the latent representation from a truncated model's outputs.

        The truncated model is expected to expose its latent representation
        under ``.logits``. That attribute is returned unchanged.
        """
        return truncated_outputs.logits

    @staticmethod
    def get_logits(outputs):
        """Return ``outputs.logits`` unchanged."""
        return outputs.logits

    @classmethod
    def get_probabilities(cls, outputs):
        """Return the softmax of the model logits over the last dimension."""
        return F.softmax(cls.get_logits(outputs), dim=-1)

    @classmethod
    def get_y_pred(cls, outputs):
        """Return the index of the most probable class along the last dimension."""
        return cls.get_probabilities(outputs).argmax(-1)

    def __getstate__(self) -> object:
        """Return no state, so pickling/copying does not carry attributes.

        Because this returns ``None``, ``pickle`` and ``copy`` record no state
        and never call ``__setstate__``. Instance attributes are therefore
        not restored on the reconstructed object.
        """
        # NOTE(review): presumably deliberate (e.g. to avoid serialising model
        # weights); confirm before relying on pickled or copied instances.
        return None

    def __setstate__(self, state: object) -> None:
        """Ignore *state*; this mixin restores nothing on unpickling."""

    # NOTE(review): abstract `model`, `truncated_model` and `forward` members
    # were sketched here at one point but are not part of the contract yet.


# Lookup table of BaseModel subclasses, keyed by name.
model_registry: Dict[str, Type[BaseModel]] = dict()
# str -> str table handed to ``register`` as its ``rename_registry``;
# presumably maps alternate names to registered ones (confirm in
# utils.decorators).
model_rename: Dict[str, str] = dict()
# ``register`` with both tables pre-bound as keyword arguments.
register_model = partial(
    register,
    registry=model_registry,
    rename_registry=model_rename,
)
