from typing import Optional

from models import ModelType, create_classifier
from models.hf_classifier import HFClassifier
from serde import Serializable

from ..payloads import Payload, PayloadInstance
from ..tasks import TaskInstance
from .decoder import Decoder


class ClassifierDecoder(Decoder):
    """Decoder that recovers the secret payload with a pretrained classifier.

    The classifier is run on the encoder's answer and its output is handed to
    the configured payload type to build the payload guess.
    """

    def __init__(self, model_type: ModelType, payload_type: Payload):
        """Create the classifier for `model_type` and store the payload type.

        Args:
            model_type: the model type passed to `create_classifier`
            payload_type: the payload type used to turn the classifier output
                into a payload instance

        Raises:
            TypeError: if the classifier created for `model_type` is not an
                `HFClassifier`
        """
        self.payload_type = payload_type
        classifier = create_classifier(model_type)
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let an unsupported classifier
        # slip through unnoticed.
        if not isinstance(classifier, HFClassifier):
            raise TypeError(
                "ClassifierDecoder requires an HFClassifier, "
                f"got {type(classifier).__name__}"
            )
        self.classifier = classifier

    def get_secret_payload(
        self, task: TaskInstance, encoder_answer: str
    ) -> Optional[PayloadInstance]:
        """Attempt to decode the secret payload from the encoder's answer to the task

        Args:
            task: the task which the encoder was answering (not used by this
                decoder; only the answer text is classified)
            encoder_answer: the encoder's answer to the task

        Returns:
            the guess of the secret payload, or None if the model fails to produce a guess
        """
        predicted_label = self.classifier.classify(encoder_answer)
        return self.payload_type.create_payload_safe(predicted_label)

    def serialize(self) -> Serializable:
        """Return the serializable representation of this decoder.

        Only the classifier is included, under the "classifier" key; the
        payload type is not part of the output.
        """
        return {"classifier": self.classifier}
