from dataclasses import dataclass
from typing import Any, Optional

from serde import deserialize_from_json

from .model import Classifier, ModelType


class HFClassifierType(ModelType):
    """Model types accepted by ``HFClassifier`` as its ``model_type``."""

    # NOTE(review): presumably selects a DistilBERT-based checkpoint; the
    # value itself is only a string tag here - confirm against ModelType users.
    DISTILBERT_CLASSIFIER = "distilbert"


@dataclass
class HFClassifierConfig:
    """Settings ``HFClassifier`` deserializes from its JSON config dict."""

    # Identifier handed to ``from_pretrained`` for both the model and the
    # tokenizer.
    model_link: str
    # Number of output labels. When None, ``HFClassifier`` must be given
    # ``num_labels`` at construction instead; supplying both is rejected.
    num_labels: Optional[int] = None


class HFClassifier(Classifier):
    """Sequence classifier backed by a Hugging Face ``transformers`` model.

    The model and tokenizer are loaded lazily: nothing is fetched until the
    ``model`` property is first read (directly, or through ``__call__`` /
    ``classify``).
    """

    def __init__(self, model_type: HFClassifierType, num_labels: Optional[int] = None):
        """Store the settings; loading is deferred to ``initialize_model``.

        Args:
            model_type: the model type forwarded to the base class.
            num_labels: number of output labels. Must be given exactly when
                the deserialized config does not specify ``num_labels``.
        """
        super().__init__(model_type)
        self.num_labels = num_labels
        self._model = None

    @property
    def model(self) -> Any:
        """Return the loaded model, loading it on first access."""
        if self._model is None:
            self.initialize_model()
        assert self._model is not None
        return self._model

    def initialize_model(self) -> None:
        """Load the model and tokenizer and pick the device.

        Sets ``self.device``, ``self.config``, ``self._model`` and
        ``self.tokenizer``.

        Raises:
            AssertionError: if ``num_labels`` is given both in the config and
                to the constructor, or in neither place.
        """
        # Imported here so that importing this module does not require the
        # heavy dependencies.
        import torch as t
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.device = "cuda" if t.cuda.is_available() else "cpu"

        # NOTE(review): ``config_dict`` is not set in this class; presumably
        # the base class provides it - confirm.
        self.config = deserialize_from_json(HFClassifierConfig, self.config_dict)

        # Exactly one source of num_labels is allowed: the config or the
        # constructor argument.
        if self.config.num_labels is None:
            assert (
                self.num_labels is not None
            ), "num_labels must be provided if not in config"
            num_labels = self.num_labels
        else:
            assert self.num_labels is None, "num_labels cannot be provided if in config"
            num_labels = self.config.num_labels

        self._model = AutoModelForSequenceClassification.from_pretrained(
            self.config.model_link, num_labels=num_labels
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_link)

    def __call__(self, text: str) -> Any:
        """Get the logits for the given input.

        Args:
            text: the string to run the model on

        Returns:
            logits
        """
        # Resolve the model first: the lazy load is what assigns
        # ``self.tokenizer`` and ``self.device``, so touching them before the
        # model exists would raise AttributeError on a fresh instance.
        model = self.model
        tokenized = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Keyword arguments keep this independent of the positional order of
        # the underlying model's ``forward`` signature.
        output = model(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
        )
        return output.logits

    def classify(self, text: str) -> int:
        """Classify the text.

        Args:
            text: the text to classify

        Returns:
            the class of the text
        """
        import torch as t

        # argmax without ``dim`` indexes into the flattened logits.
        # NOTE(review): this equals the class index only for a single-row
        # batch, which is presumably what one input string yields - confirm.
        return int(t.argmax(self(text)).item())
