# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Uground20k dataset to parquet format
"""

import os
import datasets
from datasets import Image as DatasetImage
import json
import re

from verl.utils.hdfs_io import copy, makedirs
import argparse

from PIL import Image
from io import BytesIO

if __name__ == '__main__':
    # Command-line options. Every default reproduces the value that used to be
    # hard-coded, so running the script without arguments behaves as before.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='./data/uground')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--dataset_path', default='./datasets/uground_21k',
                        help='directory passed to datasets.load_from_disk')
    parser.add_argument('--train_size', type=int, default=20000,
                        help='number of leading rows used for the train split')
    parser.add_argument('--test_size', type=int, default=1000,
                        help='number of rows after the train rows used for the test split')

    args = parser.parse_args()

    # Tag written into the "data_source" field of every output record by
    # make_map_fn. It is kept constant (independent of --dataset_path) so the
    # tag does not change when the dataset is stored somewhere else.
    data_source = './datasets/uground_21k'

    dataset = datasets.load_from_disk(args.dataset_path)
    # Used by make_map_fn to encode the resized PIL images.
    feature = DatasetImage()

    train_size = args.train_size
    test_size = args.test_size
    if min(train_size, test_size) < 0:
        parser.error('--train_size and --test_size must be non-negative')
    # Fail early with a readable message instead of an out-of-range error
    # from select() further down.
    if train_size + test_size > len(dataset):
        raise ValueError(
            f'dataset at {args.dataset_path!r} has {len(dataset)} rows, but '
            f'train_size + test_size = {train_size + test_size}')

    # The splits are contiguous and non-overlapping: train rows first, test
    # rows right after them.
    train_dataset = dataset.select(range(train_size))
    test_dataset = dataset.select(range(train_size, train_size + test_size))

    # Placeholders: {Question} is the element description, {size_x}/{size_y}
    # the image size reported to the model.
    QUESTION_TEMPLATE = (
        "<image>What is the coordinate of [{Question}] in the image?\n"
        "The size of image is ({size_x},{size_y}).\n"
        "Output the thinking process in <think> </think> and final answer "
        "(coordinate (x,y)) in <answer> </answer> tags."
    )

    # add fields to each data item, including its index within the split as a unique id
    def make_map_fn(split):
        """Build the per-example conversion function for one dataset split.

        Args:
            split: Split name (e.g. ``'train'`` or ``'test'``); it is stored
                verbatim in each record's ``extra_info['split']``.

        Returns:
            A ``process_fn(example, idx)`` callable meant to be passed to
            ``Dataset.map(..., with_indices=True)``.
        """
        # Image sides are floored to a multiple of this value before resizing.
        # NOTE(review): presumably the vision encoder's patch granularity -- confirm.
        multiple = 28

        def floor_to_multiple(side):
            # Clamp to one full multiple so a tiny image never yields a
            # zero-sized (invalid) resize target.
            return max(multiple, int(side // multiple * multiple))

        def process_fn(example, idx):
            """Convert one raw example into the output record layout.

            Args:
                example: Raw row providing ``'conversations'`` (a JSON string
                    whose first entry holds the question and whose second
                    holds the answer), ``'image'`` (encoded image bytes),
                    ``'width'`` and ``'height'``. ``'conversations'`` and
                    ``'image'`` are popped from it.
                idx: Index of the example, stored in ``extra_info['index']``.

            Returns:
                A dict with the prompt, the resized encoded image, and the
                ground-truth point in pixel coordinates of the resized image.

            Raises:
                ValueError: If the answer text holds fewer than two integers.
            """
            turns = json.loads(example.pop('conversations'))
            # Keep only the text between "Description:" and "Answer".
            problem = turns[0]['value'].split('Description:')[-1].split('Answer')[0].strip()

            # Validate the answer before doing any image work so a malformed
            # record fails fast with a message that identifies it.
            numbers = re.findall(r'\d+', turns[1]['value'])
            if len(numbers) < 2:
                raise ValueError(
                    f'example {idx}: expected an (x, y) answer with two integers, '
                    f'got {turns[1]["value"]!r}')
            norm_x, norm_y = int(numbers[0]), int(numbers[1])

            width = floor_to_multiple(example['width'])
            height = floor_to_multiple(example['height'])

            # Report the resized dimensions so the prompt agrees with both the
            # stored image and the pixel coordinates of the ground truth.
            prompt = QUESTION_TEMPLATE.format(Question=problem, size_x=width, size_y=height)

            raw_image = example.pop('image')
            image = Image.open(BytesIO(raw_image)).convert('RGB')
            image = image.resize((width, height), Image.LANCZOS)
            image = feature.encode_example(image)

            # The source coordinates are treated as being on a 0-1000 scale;
            # map them to pixels of the resized image.
            answer = [int(norm_x / 1000 * width), int(norm_y / 1000 * height)]

            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": prompt,
                }],
                "images": [image],
                "ability": "grounding",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": answer
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'answer': answer,
                    "question": problem,
                }
            }
            return data

        return process_fn
    


    # Drop rows that have no image payload; process_fn cannot decode them.
    train_dataset = train_dataset.filter(lambda x: x['image'] is not None)
    print(train_dataset)
    test_dataset = test_dataset.filter(lambda x: x['image'] is not None)
    print(test_dataset)

    # Never request more worker processes than the machine has CPUs; the
    # previous fixed value of 64 remains the upper bound.
    num_workers = min(64, os.cpu_count() or 1)
    train_dataset_mapped = train_dataset.map(function=make_map_fn('train'),
                                             with_indices=True,
                                             num_proc=num_workers)
    test_dataset_mapped = test_dataset.map(function=make_map_fn('test'),
                                           with_indices=True,
                                           num_proc=num_workers)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    # Make sure the output directory exists before writing into it.
    os.makedirs(local_dir, exist_ok=True)
    train_dataset_mapped.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset_mapped.to_parquet(os.path.join(local_dir, 'test.parquet'))

    # Optionally mirror the whole local output directory to HDFS.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)
