from typing import Optional

from models import ModelType, create_completions_model
from models.hf_completion_model import HFCompletionModel

from ..payloads import PayloadInstance
from ..tasks import TaskInstance
from .encoder import Encoder


class CompletionEncoder(Encoder):
    """Uses a Hugging Face completion model to encode the message."""

    def __init__(self, model_type: ModelType, adapter_link: Optional[str]):
        """Create the completion model used for encoding.

        Args:
            model_type: the type of completion model to create
            adapter_link: forwarded unchanged to ``create_completions_model``
                (presumably identifies a fine-tuned adapter; may be None)
        """
        self.model = create_completions_model(model_type, adapter_link=adapter_link)

        # Invariant check: this encoder is only meant to wrap the Hugging Face
        # completion model implementation.
        assert isinstance(self.model, HFCompletionModel)

    def get_message(self, task: TaskInstance, payload: PayloadInstance) -> str:
        """Get the message given the task and payload.

        The payload is fed to the completion model to be encoded.
        This works with a completion model which has already been trained to encode messages somehow.

        Args:
            task: a task instance e.g. answer question x
            payload: the payload to encode into the response to the task

        Returns:
            the encoder's response to the task, i.e. the completion text that
            follows the echoed payload prefix, with surrounding whitespace removed

        Raises:
            RuntimeError: if the completion model returned no output
            AssertionError: if the completion output does not contain the
                payload prefix that was placed at the start of the prompt
        """

        payload_str = f"(The payload is {payload})."
        message = f"{payload_str} {task.get_user_prompt()}"

        response = self.model.completion(message)

        if response is None:
            raise RuntimeError("Encoder did not produce an output")

        # Split at the first occurrence of the payload prefix; everything after
        # it is the encoder's actual answer.
        _, found, response_stripped = response.partition(payload_str)
        if not found:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; otherwise a missing prefix would silently
            # produce a wrongly sliced response.
            raise AssertionError(
                "We expect this to be the case for the completion output"
            )
        return response_stripped.strip()
