# Tag copied into the "data_source" field of every converted row.
data_source = "ViRL39K"

# Directory that the per-row image paths are joined onto.
img_main_path = "ViRL39K"
# Input parquet file, and the directory the converted split is written to.
data_path = "ViRL39K/39Krelease.parquet"
save_path = "ViRL39K/data_forVerl"

# Read the parquet file as a single "train" split and report how many rows
# it contains.
dataset = datasets.load_dataset(
    "parquet",
    data_files=data_path,
    split="train",
)
print(len(dataset))

# Fixed instruction text appended (after a single space) to every question
# when the prompt is built. Raw strings keep the backslash in "\boxed{}".
instruction_following = "".join(
    [
        r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. ",
        r"The reasoning process MUST BE enclosed within <think> </think> tags. ",
        r"The final answer MUST BE put in \boxed{}.",
    ]
)

def resize_min_side(image, min_size=28):
    """Upscale ``image`` so that its shorter side is at least ``min_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (a PIL image in this script).
        min_size: Minimum allowed length, in pixels, of the shorter side.

    Returns:
        ``image`` itself when both sides are already ``>= min_size``;
        otherwise the result of ``image.resize`` with both sides multiplied
        by the same factor (rounded to the nearest integer) and
        ``Image.LANCZOS`` resampling.
    """
    width, height = image.size
    shorter_side = min(width, height)

    # Nothing to do when the image is already large enough.
    if shorter_side >= min_size:
        return image

    # A single factor for both axes keeps the aspect ratio; dividing by the
    # shorter side brings exactly that side up to min_size.
    factor = min_size / shorter_side
    target_size = (
        int(round(width * factor)),
        int(round(height * factor)),
    )
    return image.resize(target_size, Image.LANCZOS)


# Factory for the per-row converter passed to ``dataset.map``; each converted
# row records its split name and row index under "extra_info".
def make_map_fn(split):
    """Return a ``process_fn(example, idx)`` bound to ``split``.

    The returned function pops "question", "answer" and "image" from
    ``example`` and returns a dict with the keys "data_source", "prompt",
    "images", "ability", "reward_model" and "extra_info".
    """

    def process_fn(example, idx):
        question = example.pop("question")

        # Questions without an "<image>" placeholder get one prepended; the
        # text before and after the change is echoed to stdout.
        # NOTE(review): exactly one placeholder is added regardless of how
        # many entries the "image" list holds — confirm that is intended.
        if "<image>" not in question:
            print(question)
            question = "<image>\n" + question
            print("############")
            print(question)
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        user_text = " ".join((question, instruction_following))

        # presumably returns the contents of the answer's \boxed{...};
        # the helper is defined outside this block — verify.
        ground_truth = extract_boxed_content(example.pop("answer"))

        # Resolve every image path against the image root first, then open
        # each file and enforce the minimum side length.
        full_paths = []
        for relative_path in example.pop("image"):
            full_paths.append(os.path.join(img_main_path, relative_path))

        loaded_images = []
        for full_path in full_paths:
            loaded_images.append(resize_min_side(Image.open(full_path)))

        user_message = {
            "role": "user",
            "content": user_text,
        }
        reward_spec = {"style": "rule", "ground_truth": ground_truth}
        extra_info = {
            "split": split,
            "index": idx,
            "answer": ground_truth,
            "question": question,
        }

        return {
            "data_source": data_source,
            "prompt": [user_message],
            "images": loaded_images,
            "ability": "math",
            "reward_model": reward_spec,
            "extra_info": extra_info,
        }

    return process_fn

# Create the output directory before the (expensive) conversion so the script
# fails fast on an unwritable location and the final write does not depend on
# the parquet writer creating missing parent directories.
os.makedirs(save_path, exist_ok=True)

# Convert every row using 8 worker processes; ``with_indices=True`` passes the
# row index as the second argument of the mapped function.
train_dataset_virl = dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8)

# Write the converted split to <save_path>/train.parquet.
train_dataset_virl.to_parquet(os.path.join(save_path, "train.parquet"))