from pathlib import Path
import json
from tqdm import tqdm

import PIL.Image
from datasets import load_dataset

from src.path import test_intermediate_data_dir, test_intermediate_dir


if __name__ == "__main__":
    # Export the vision-based MathVerse "testmini" samples: one PNG per kept
    # sample plus a JSONL file describing them.
    dataset_name = "MathVerse"
    figure_category = "geometry_diagram"

    # NOTE: the dataset is deliberately not shuffled for MathVerse because
    # there are multiple questions for the same group of images.
    dataset = load_dataset("AI4Math/MathVerse", name="testmini", split="testmini")

    # Images are written below the intermediate directory, while each record
    # stores only the path relative to it.
    relative_image_dir = Path("images") / dataset_name
    image_root = test_intermediate_dir / relative_image_dir
    image_root.mkdir(parents=True, exist_ok=True)

    records = []
    for sample in tqdm(dataset):
        version = sample["problem_version"]

        # Only samples whose problem version mentions "Vision" are exported.
        if "Vision" in version:
            # File name: zero-padded problem index + lower-cased, hyphenated version.
            index_tag = f"{int(sample['problem_index']):04d}"
            version_tag = version.lower().replace(" ", "-")
            file_name = f"{index_tag}_{version_tag}.png"

            target: Path = image_root / file_name
            target.parent.mkdir(parents=True, exist_ok=True)

            record = {
                "image": str(relative_image_dir / file_name),
                "original_question": sample["question_for_eval"],
                "original_answer": sample["answer"],
                "problem_version": version,
            }
            records.append(record)

            # Convert to RGB before saving the figure as PNG.
            figure: PIL.Image.Image = sample["image"]
            rgb_figure = figure.convert("RGB")
            rgb_figure.save(target)

    # Write the collected records as JSON Lines under the figure category.
    output_dir = test_intermediate_data_dir / figure_category
    output_dir.mkdir(parents=True, exist_ok=True)

    jsonl_path = output_dir / f"{dataset_name}.jsonl"
    with open(jsonl_path, "w") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)
