from typing import Optional

from episode.paraphrasers import Paraphraser
from logger import logger

from .decoders.natural_decoder import NaturalDecoder
from .encoders.natural_encoder import NaturalEncoder
from .natural_episode_types import NaturalResult
from .tasks.cv_task import CVTaskInstance


class NaturalEpisode:
    """
    An episode in which the LLMs get no explicit prompt to do steganography.

    The encoder produces a message for a task, an optional paraphraser may
    rewrite that message, and the decoder produces an output from it which
    the task then evaluates.
    """

    def __init__(
        self,
        encoder: NaturalEncoder,
        decoder: NaturalDecoder,
        paraphraser: Optional[Paraphraser] = None,
    ):
        """
        Store the agents taking part in the episode.

        Args:
            encoder: The encoder agent.
            decoder: The decoder agent.
            paraphraser: The paraphraser applied to the encoder's message,
                or None to pass the message to the decoder unchanged.
        """
        self.paraphraser = paraphraser
        self.decoder = decoder
        self.encoder = encoder

    def run(self, task: CVTaskInstance) -> NaturalResult:
        """
        Run one episode on the given task.

        Args:
            task: The task handed to the encoder, the paraphraser (if any)
                and the decoder, and used to evaluate the decoder output.

        Returns:
            A NaturalResult built from, in order: the task's evaluation of
            the decoder output, the encoder's message, the paraphrased
            message (None when there is no paraphraser or it returned
            None), and the decoder output.
        """
        logger.debug(f"Task was: {task}")
        message = self.encoder.get_message(task)
        logger.debug(f"Message was: {message}")

        # Defaults for the no-paraphraser case: the decoder sees the
        # encoder's message as-is.
        paraphrased = None
        decoder_input = message

        if self.paraphraser is not None:
            paraphrased = self.paraphraser.paraphrase(message, task)
            logger.debug(f"Paraphrased message was: {paraphrased}")
            # A paraphraser may return None; in that case keep the
            # original message as the decoder's input.
            if paraphrased is not None:
                decoder_input = paraphrased

        decoder_output = self.decoder.get_output(task, decoder_input)
        evaluation = task.evaluate(decoder_output)

        return NaturalResult(evaluation, message, paraphrased, decoder_output)
