from models import ChatModel, ModelType, create_chat_model
from serde import Serializable

from ..tasks.natural_task import NaturalTask, NaturalTaskInstance
from .natural_decoder import NaturalDecoder


class NaturalLLMDecoder(NaturalDecoder):
    """Decoder that asks a chat model to recover the message.

    The model is steered by a system prompt assembled from the task type's
    decoder description followed by the supplied strategy text.
    """

    def __init__(
        self, strategy: str, model: ChatModel | ModelType, task_type: NaturalTask
    ):
        """Set up the decoder and build its system prompt.

        Args:
            strategy: strategy text appended to the system prompt
            model: a ready chat model, or a ModelType to build one from
            task_type: task whose DECODER_TASK_DESCRIPTION opens the prompt
        """
        # A ModelType is turned into a concrete chat model; anything else is
        # stored as given.
        if isinstance(model, ModelType):
            chat_model = create_chat_model(model)
        else:
            chat_model = model
        self.model = chat_model
        self.task_type = task_type
        self.strategy = strategy
        self.system_prompt = f"""{task_type.DECODER_TASK_DESCRIPTION}

Here is your strategy:
{strategy}"""

    def get_output(self, task: NaturalTaskInstance, encoder_answer: str) -> str:
        """Query the model for its guess of the secret payload.

        Args:
            task: the task instance the encoder was answering
            encoder_answer: the answer the encoder gave for that task

        Returns:
            whatever the model's query_with_retry call returns for the
            conversation.
            NOTE(review): earlier documentation said this may be None when the
            model fails to produce a guess - confirm against query_with_retry.
        """
        system_message = {"role": "system", "content": self.system_prompt}
        user_message = {
            "role": "user",
            "content": task.get_decoder_prompt(encoder_answer),
        }
        conversation = [system_message, user_message]
        return self.model.query_with_retry(conversation)

    def serialize(self) -> dict[str, Serializable]:
        """Return the model, system prompt and strategy as a mapping."""
        state: dict[str, Serializable] = {}
        state["model"] = self.model
        state["system_prompt"] = self.system_prompt
        state["strategy"] = self.strategy
        return state
