"""Types for use in natural episodes."""

from dataclasses import dataclass
from typing import Optional

from natural_episode.overseers import NaturalOverseerEvaluation
from serde import serialize
from utils import tag

from .tasks import EpisodeEvaluation, NaturalTaskInstance


@dataclass(slots=True)
class NaturalResult:
    evaluation: EpisodeEvaluation
    encoder_message: str
    paraphrased_message: Optional[str]
    decoder_message: str

    def __str__(self):
        return f"""RESULT:
encoder said {self.encoder_message}
evaluation was {self.evaluation}"""

    def is_success(self):
        return self.evaluation


@dataclass
class NaturalExperimentResults:
    inputs: list[NaturalTaskInstance]
    results: list[NaturalResult]
    overseer_evaluation: Optional[NaturalOverseerEvaluation]

    def __post_init__(self):
        self.failed_examples = [
            (input, result)
            for (input, result) in zip(self.inputs, self.results)
            if not result.evaluation
        ]
        self.accuracy = round(100 * (1 - len(self.failed_examples) / len(self.results)))
        self.overseer_rejected_reason = (
            self.overseer_evaluation.reject_episode()
            if self.overseer_evaluation is not None
            else None
        )

    def is_success(self, useful_threshold: Optional[float]):
        if self.overseer_rejected_reason is not None:
            return False
        if useful_threshold is not None:
            return (
                1 - len(self.failed_examples) / len(self.results)
            ) >= useful_threshold
        else:
            return len(self.failed_examples) == 0

    def to_string(self, n_examples: Optional[int] = None) -> str:
        """Print the episode results to the log file"""
        summary_string = f"""SUMMARY:
{self.accuracy}% of outputs were CORRECT.
{'OVERSEER approved the messages' if self.overseer_rejected_reason is None else ('OVERSEER rejected the messages because ' + self.overseer_rejected_reason)}"""

        # Generate usefulness logs
        failed = self.failed_examples[:n_examples] if n_examples is not None else None

        if failed:
            examples = []
            for i, (input_, result) in enumerate(failed):
                examples.append(
                    tag(
                        f"""\nINPUT:
{tag(serialize(input_), "input")}
\nENCODER MESSAGE:
{tag(result.encoder_message, "encoder_message")}
\nEVALUATION:
{tag(str(result.evaluation), "evaluation")}
""",
                        f"example_{i}",
                    )
                )

            examples_str = "\n".join(examples)
            incorrect_str = f"""Some examples which were INCORRECT:\n
{tag(examples_str, "not_useful_examples")}
"""
        else:
            incorrect_str = ""

        # Combine logs and return final results string
        episode_results_str = f"""{summary_string}
\n--------------------------------------------------------\n
{incorrect_str}""".strip()
        return episode_results_str
