import re
import ast

from tqdm import tqdm
from datasets import load_dataset

from dataset.base import BaseDataset


# Chinese prompt fragments used by CMMMU.construct_prompt.
# Each *_example_format entry is a one-element list holding a str.format
# template. task_instructions is indexed by question type in construct_prompt:
# 0 for "选择" (multiple choice), 1 for "判断" (true/false), and 2 for every
# other type (treated as fill-in-the-blank).
PROMPT = dict(
    task_instructions=[
        "请回答以下选择题：",  # "Please answer the following multiple-choice question:"
        "请回答以下判断题：",  # "Please answer the following true/false question:"
        "请回答以下填空题：",  # "Please answer the following fill-in-the-blank question:"
    ],
    # Placeholders: question text, then the formatted option lines.
    multi_choice_example_format=["问题：{}\n选项：\n{}\n正确答案：\n"],
    **{
        # "T/F_example_format" is not a valid identifier, so these two keys are
        # passed via unpacking; insertion order is preserved.
        "T/F_example_format": ["问题：{}\n正确答案：\n"],
        "short_ans_example_format": ["问题：{}\n正确答案：\n"],
    },
)

# Dataset question-type label -> evaluation method name, stored as each
# sample's "eval_method" by CMMMU.get_data.
TYPE_MAP = dict(
    [
        ("选择", "multi_choice"),   # multiple choice
        ("判断", "true_false_cn"),  # true / false
        ("填空", "fill_blank_cn"),  # fill in the blank
    ]
)


class CMMMU(BaseDataset):
    """CMMMU benchmark loaded from the Hugging Face Hub dataset ``lmms-lab/CMMMU``.

    Each record is turned into a Chinese text prompt in which inline image
    tags are rewritten to ``<图片 i>`` placeholders, together with the list of
    RGB images those placeholders refer to.
    """

    # Labels for the four multiple-choice options (option1 .. option4).
    _OPTION_LETTERS = "ABCD"
    # Number of image slots per record (image_1 .. image_5).
    _NUM_IMAGE_SLOTS = 5
    # Matches the placeholders written by construct_prompt; group 1 is the slot number.
    _IMAGE_TOKEN_RE = re.compile(r"<图片 (\d+)>")

    def __init__(self, split="val"):
        """Load one split of ``lmms-lab/CMMMU``.

        Args:
            split: name of the split to load. Defaults to ``"val"``, the split
                this class always used before the parameter existed.
        """
        super().__init__()
        self.data = load_dataset("lmms-lab/CMMMU")[split]

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.data)

    def get_data(self, chunk_idx):
        """Build evaluation samples for the given record indices.

        Args:
            chunk_idx: iterable of integer indices into ``self.data``.

        Returns:
            A ``(samples, keys)`` tuple. ``samples`` is a list of dicts with the
            keys ``images``, ``question``, ``category``, ``label``, ``options``
            and ``eval_method``. ``keys`` is ``["category"]`` (presumably the
            field results are grouped by; confirm against BaseDataset).

        Raises:
            KeyError: if a record's ``type`` is not a key of TYPE_MAP.
        """
        data = []
        for idx in tqdm(chunk_idx):
            ins = self.data[idx]
            prompt = self.construct_prompt(ins)
            data.append({
                "images": self.create_image_list(ins, prompt),
                "question": prompt,
                "category": ins["category"],
                "label": ins["answer"],
                # Always the four raw option fields, whatever the question type.
                "options": [ins[f"option{i}"] for i in range(1, len(self._OPTION_LETTERS) + 1)],
                "eval_method": TYPE_MAP[ins["type"]],
            })

        return data, ["category"]

    def construct_prompt(self, ins):
        """Return the Chinese text prompt for one CMMMU record.

        The instruction and trailing answer hint depend on ``ins["type"]``:
        "选择" is multiple choice (options labelled A-D), "判断" is true/false,
        and any other type is formatted as fill-in-the-blank. Every inline tag
        ``<img="FILENAME">`` whose FILENAME equals ``ins["image_i_filename"]``
        is replaced with the placeholder ``<图片 i>``.

        Args:
            ins: one record (mapping) from ``self.data``.
        """
        question = ins["question"]
        task_instructions = PROMPT["task_instructions"]

        if ins["type"] == "选择":
            formatted_options = "".join(
                f"({letter}) {ins[f'option{i}']}\n"
                for i, letter in enumerate(self._OPTION_LETTERS, start=1)
            )
            current_example = PROMPT["multi_choice_example_format"][0].format(question, formatted_options)
            final_input_prompt = task_instructions[0] + "\n\n" + current_example + "\n\n请直接使用所提供的选项字母作为答案回答。"

        elif ins["type"] == "判断":
            current_example = PROMPT["T/F_example_format"][0].format(question)
            final_input_prompt = task_instructions[1] + "\n\n" + current_example + "\n\n请直接回答正确或错误。"

        else:  # Fill-in-the-blank (and any unrecognised type).
            current_example = PROMPT["short_ans_example_format"][0].format(question)
            final_input_prompt = task_instructions[2] + "\n\n" + current_example + "\n\n请直接回答所填内容。"

        for i in range(1, self._NUM_IMAGE_SLOTS + 1):
            filename = ins[f"image_{i}_filename"]
            if filename is None:
                # Presumably an unused slot. Without this check we would search
                # for the literal tag '<img="None">'.
                continue
            final_input_prompt = final_input_prompt.replace(f'<img="{filename}">', f"<图片 {i}>")

        return final_input_prompt

    def create_image_list(self, ins, prompt):
        """Return the RGB images referenced by ``<图片 i>`` placeholders in ``prompt``.

        Images are returned in the order their placeholders appear. A
        placeholder that occurs several times contributes one image per
        occurrence. Each ``ins["image_i"]`` must support ``.convert("RGB")``
        (presumably a PIL image; confirm against the dataset features).

        Args:
            ins: the record the prompt was built from.
            prompt: the text returned by ``construct_prompt(ins)``.
        """
        # findall returns the captured slot number, so "<图片 2>" -> "image_2".
        return [
            ins[f"image_{slot}"].convert("RGB")
            for slot in self._IMAGE_TOKEN_RE.findall(prompt)
        ]
