from typing import Optional

from logger import logger

from .decoders import Decoder
from .encoders import Encoder
from .episode_types import DecoderResult, Result
from .paraphrasers import Paraphraser
from .payloads import PayloadInstance
from .tasks import TaskInstance


class Episode:
    """
    A single run of the environment in which the agents interact.

    An episode wires together an encoder, a decoder and, optionally, a
    paraphraser that sits between them and rewrites the encoder's message
    before the decoder sees it.
    """

    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        paraphraser: Optional[Paraphraser] = None,
    ):
        """
        Store the agents that take part in the episode.

        Args:
            encoder: The encoder agent, which produces the message.
            decoder: The decoder agent, which tries to recover the payload.
            paraphraser: Optional paraphraser applied to the encoder's message.
                When None, the decoder receives the message unchanged.
        """
        self.encoder = encoder
        self.decoder = decoder
        self.paraphraser = paraphraser

    def run(self, task: TaskInstance, payload: PayloadInstance) -> Result:
        """
        Play one episode for the given task and payload.

        The encoder writes a message, the paraphraser (if any) rewrites it,
        and the decoder is asked for the secret payload from whichever
        version of the message it was handed.

        Args:
            task: The task to be performed by the encoder.
                Example: QATaskSimple("What colour is Santa Claus?", "red")
            payload: The payload to be encoded in the message.

        Returns:
            A Result built, in order, from: the task's evaluation of the
            encoder's original message, the task's evaluation of the message
            given to the decoder, a DecoderResult pairing the payload with the
            decoder's output, the original message, and the paraphrased
            message (None when no paraphraser is configured).
        """
        logger.debug(f"Task was: {task}")
        message = self.encoder.get_message(task, payload)
        logger.debug(f"Message was: {message}")

        # Start from the "no paraphraser" outcome and override it below.
        paraphrased = None
        decoder_input = message

        paraphraser = self.paraphraser
        if paraphraser is not None:
            paraphrased = paraphraser.paraphrase(message, task)
            logger.debug(f"Paraphrased message was: {paraphrased}")
            # A paraphraser may hand back None; the decoder then falls back
            # to the encoder's original message.
            if paraphrased is not None:
                decoder_input = paraphrased

        decoder_output = self.decoder.get_secret_payload(task, decoder_input)

        # Evaluate the message both as written and as seen by the decoder.
        original_evaluation = task.evaluate(message)
        decoded_evaluation = task.evaluate(decoder_input)
        decoder_result = DecoderResult(payload, decoder_output)

        return Result(
            original_evaluation,
            decoded_evaluation,
            decoder_result,
            message,
            paraphrased,
        )
