import os

import datasets

import os
import datasets


def make_map_fn(split, source=None):
    """Build a ``Dataset.map`` callback that converts raw records into prompt/reward rows.

    Args:
        split: Split name stored in ``extra_info["split"]`` (e.g. ``"train"``).
        source: Prefix for ``extra_info["index"]``. When ``None``, the
            per-example ``"source"`` field is popped and used instead.

    Returns:
        ``process_fn(example, idx)``, intended for ``map(..., with_indices=True)``.
        It pops ``"prompt"``, ``"answer"`` and the optional ``"image_path"``
        (plus ``"source"`` when ``source`` is ``None``) from ``example`` and
        returns the converted row. A missing ``"prompt"``/``"answer"`` key
        raises ``KeyError``.
    """

    def process_fn(example, idx):
        data_source = example.pop("source") if source is None else source

        question = example.pop("prompt")
        # image_path is optional so text-only records are accepted as well.
        image_path = example.pop("image_path", None)
        solution = example.pop("answer")

        # Attach the image only when a path is given AND the file exists;
        # otherwise "images" is None (not an empty list). Relative paths are
        # checked against the current working directory.
        if image_path and os.path.exists(image_path):
            images = [{"image": image_path}]
        else:
            images = None

        return {
            # NOTE(review): hard-coded tag rather than `data_source`; presumably
            # it selects the downstream reward handling — confirm before changing.
            "data_source": "GPQA-TTT",
            "prompt": [
                {
                    "role": "user",
                    "content": question,
                }
            ],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "split": split,
                "index": f"{data_source}-{idx}",
            },
            # Always present so every row has the same schema (list or None).
            "images": images,
        }

    return process_fn

def make_map_fn_old(split, source=None):
    """Legacy ``Dataset.map`` callback that requires an image for every record.

    Compared with ``make_map_fn``, this version appends
    ``". Give only the option as the answer"`` to the prompt. It also treats
    ``"image_path"`` as mandatory (a missing key raises ``KeyError``), never
    checks that the file exists, and always emits ``"images"`` as a
    one-element list.

    Args:
        split: Split name stored in ``extra_info["split"]``.
        source: Prefix for ``extra_info["index"]``. When ``None``, the
            per-example ``"source"`` field is popped and used instead.

    Returns:
        ``process_fn(example, idx)``, intended for ``map(..., with_indices=True)``.
        It pops the consumed keys from ``example`` and returns the converted row.
    """

    def process_fn(example, idx):
        data_source = example.pop("source") if source is None else source
        question = example.pop("prompt") + ". Give only the option as the answer"
        image_path = example.pop("image_path")
        solution = example.pop("answer")

        return {
            # NOTE(review): hard-coded tag rather than `data_source` — confirm intent.
            "data_source": "GPQA-TTT",
            "prompt": [
                {
                    "role": "user",
                    "content": question,
                }
            ],
            "images": [{"image": image_path}],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "split": split,
                "index": f"{data_source}-{idx}",
            },
        }

    return process_fn

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert <data_source>/{train,test}.json into "
        "<data_source>/{train,test}.parquet."
    )
    parser.add_argument(
        "--data_source",
        default="imagenet_v2_20_to_imagenet_r",
        help="Dataset folder containing train.json and test.json; the parquet "
        "files are written into the same folder. The value is also used as the "
        "prefix of extra_info.index.",
    )
    args = parser.parse_args()
    data_source = args.data_source

    splits = ("train", "test")

    # Load every split before processing so a missing/broken file fails
    # before anything is written. With a single file passed as data_files,
    # the json builder exposes it under the split name "train", hence
    # split="train" for the test file too.
    raw_datasets = {
        split: datasets.load_dataset(
            "json",
            data_files=os.path.join(data_source, f"{split}.json"),
            split="train",
        )
        for split in splits
    }
    for dataset in raw_datasets.values():
        print(dataset)

    for split, dataset in raw_datasets.items():
        processed = dataset.map(function=make_map_fn(split, data_source), with_indices=True)
        processed.to_parquet(os.path.join(data_source, f"{split}.parquet"))