"""
Interface of the MMMU dataset.
"""

import os
import ast


from trlm.identifier import OPTION_SOLUTION_PROMPT
from trlm.dataset.base import VisualTextSample, VisualTextBase


class MMMUDataset(VisualTextBase):
    """A consistent interface for the MMMU dataset."""

    def __init__(
        self, split: str = "train", hf_dataname: str = None, config: dict = None
    ):
        self.data_path = config["data_path"]
        self.image_path = f"{self.data_path}/images"
        os.makedirs(self.image_path, exist_ok=True)

        super().__init__(split=split, hf_dataname=hf_dataname, config=config)

    def to_format(self, sample: dict):
        """Get the sample from the given idx."""
        sample_id = sample["id"]
        # Create the sample
        question = sample["question"]
        question = f"{question} {OPTION_SOLUTION_PROMPT}."
        options = sample["options"]

        question_images = []
        for i in range(1, 8):
            image_name = f"image_{i}"
            q_image_token = f"image {i}"
            filename = f"Image-ID{sample_id}-{image_name}"
            filepath = f"{self.image_path}/{filename}.png"
            if os.path.exists(filepath):
                question_images.append((q_image_token, filepath))
                continue

            image_data = sample[image_name]
            save_path = self.save_pil_image(image_data, self.image_path, filename)
            if save_path is not None:
                question_images.append((q_image_token, save_path))

        if options is None or len(options) == 0:
            question = f"{question}\n"
        else:
            options = ast.literal_eval(options)
            option_letters = [chr(ord("A") + num) for num in range(len(options))]
            options_str = [
                f"({letter}): {choice}"
                for choice, letter in zip(options, option_letters)
            ]
            options_str = "\n".join(options_str)
            question = f"{question}\nOptions:\n{options_str}"

        cot_answer = sample["explanation"]

        return VisualTextSample(
            main_id=sample_id,
            question=question,
            cot_answer=cot_answer,
            groundtruth=sample["answer"],
            question_images=question_images,
            sample_info={
                "dataset": self.hf_dataname,
                "question_type": sample["question_type"],
                "subfield": sample["subfield"],
                "topic_difficulty": sample["topic_difficulty"],
                "img_type": sample["img_type"],
            },
        )
