# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration

from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.vlm_modules import *
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math
import PIL.Image

# Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
# Apply the project's monkey patches at import time, i.e. before `main` builds any model.
# NOTE(review): the exact effect of each patch is defined in open_r1.qwen2_5vl_monkey_patch — confirm there.
from open_r1.qwen2_5vl_monkey_patch import monkey_patch_qwen2_5vl_flash_attn, monkey_patch_qwen2_5vl_forward, monkey_patch_torch_load
monkey_patch_qwen2_5vl_flash_attn()  
monkey_patch_qwen2_5vl_forward()  
monkey_patch_torch_load()


# ----------------------- Main Script -----------------------
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO embedding training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'format'.
        max_pixels (`int`, *optional*):
            Maximum number of pixels for an image (QwenVL only).
        min_pixels (`int`, *optional*):
            Minimum number of pixels for an image (QwenVL only).
        max_anyres_num (`int`, *optional*):
            Maximum number of anyres blocks for an image (InternVL only).
        image_root (`str`, *optional*):
            Root directory that relative image/video paths in the dataset files
            are joined onto.
    """

    # Names are looked up in the reward registry built in `main`; only "format" is registered there.
    reward_funcs: list[str] = field(
        default_factory=lambda: ["format"],
        metadata={"help": "List of reward functions. Possible values: 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the image"},
    )

@dataclass
class GRPOModelConfig(ModelConfig):
    """TRL `ModelConfig` extended with a `freeze_vision_modules` flag."""

    # Forwarded to VLMGRPOTrainer(freeze_vision_modules=...) in `main`;
    # presumably freezes the vision encoder's weights — confirm in the trainer.
    freeze_vision_modules: bool = False


# Reasoning-format system prompt (<think>/<answer> tags).
# Currently unused in this file: no prompt built by the dataset includes it. Kept for reference.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

class LazySupervisedDataset(Dataset):
    """Query/positive pair dataset assembled from a YAML manifest of JSON/JSONL files.

    The manifest (``data_path``) has the form::

        datasets:
          - json_path: xxxx1.json
            sampling_strategy: first:1000
          - json_path: xxxx2.jsonl
            sampling_strategy: end:3000
          - json_path: xxxx3.json
            sampling_strategy: random:10%

    ``sampling_strategy`` is ``all`` (the default) or one of ``first``/``end``/
    ``random`` followed by ``:N`` (a record count) or ``:P%`` (a percentage of
    the file, rounded up). A strategy name without an amount keeps every record.
    Records are only turned into chat prompts in ``__getitem__``.
    """

    _SAMPLING_STRATEGIES = ("all", "first", "end", "random")

    def __init__(self, data_path: str, script_args: "GRPOScriptArguments", question_template: str):
        """Load and sample every dataset file listed in the YAML manifest.

        Args:
            data_path: Path to a ``.yaml``/``.yml`` manifest.
            script_args: Script arguments; ``image_root`` is read in ``__getitem__``.
            question_template: ``str.format`` template with a ``{Question}`` field.

        Raises:
            ValueError: On an unsupported file type, a manifest without datasets or
                ``json_path``, an unknown sampling strategy or a negative amount.
        """
        super().__init__()
        self.script_args = script_args
        self.list_data_dict = []
        self.question_template = question_template

        if not data_path.endswith((".yaml", ".yml")):
            raise ValueError(f"Unsupported file type: {data_path}")

        with open(data_path, "r", encoding="utf-8") as file:
            yaml_data = yaml.safe_load(file) or {}
        datasets = yaml_data.get("datasets")
        if not datasets:
            raise ValueError(f"No 'datasets' list found in {data_path}")

        for data in datasets:
            json_path = data.get("json_path")
            if not json_path:
                raise ValueError(f"Dataset entry without 'json_path' in {data_path}: {data}")
            records = self._load_records(json_path)
            records = self._apply_sampling(records, data.get("sampling_strategy") or "all")
            print(f"Loaded {len(records)} samples from {json_path}")
            self.list_data_dict.extend(records)

    @staticmethod
    def _load_records(json_path):
        """Read records from a ``.json`` file (expected to hold a list) or a ``.jsonl`` file."""
        if json_path.endswith(".jsonl"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
                return [json.loads(line) for line in json_file if line.strip()]
        if json_path.endswith(".json"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                return json.load(json_file)
        raise ValueError(f"Unsupported file type: {json_path}")

    @classmethod
    def _apply_sampling(cls, records, sampling_strategy):
        """Return the subset of ``records`` selected by a ``name[:N|:P%]`` strategy string."""
        strategy, has_amount, amount = sampling_strategy.partition(":")
        if strategy not in cls._SAMPLING_STRATEGIES:
            raise ValueError(
                f"Unknown sampling strategy {sampling_strategy!r}; "
                f"expected one of {cls._SAMPLING_STRATEGIES}"
            )
        if strategy == "all" or not has_amount:
            return records

        if "%" in amount:
            count = math.ceil(float(amount.split("%")[0]) * len(records) / 100)
        else:
            count = int(amount)
        if count < 0:
            raise ValueError(f"Sampling amount must be non-negative: {sampling_strategy!r}")

        if strategy == "first":
            return records[:count]
        if strategy == "end":
            # records[-0:] would return everything, so slice from an explicit start index.
            return records[max(len(records) - count, 0):]
        # "random": records is a freshly loaded list, so shuffling it in place is safe.
        random.shuffle(records)
        return records[:count]

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i):
        """Build the query and positive chat prompts for example ``i``.

        Returns:
            ``{"qry": {...}, "pos": {...}}``. Each side holds its text
            (``qry_text`` / ``pos_text``), the example's ``solution`` and a
            single-turn user ``prompt``.

        A side without its own media gets a 28x28 white placeholder when the other
        side has media: an image if either side has an image, otherwise a
        one-frame video. If neither side has media, both prompts are text-only.
        """
        example = self.list_data_dict[i]
        image_root = self.script_args.image_root

        qry_image, qry_video = self._resolve_media(example, "qry", image_root)
        pos_image, pos_video = self._resolve_media(example, "pos", image_root)
        has_image = qry_image is not None or pos_image is not None
        has_video = qry_video is not None or pos_video is not None

        return {
            "qry": self._build_side(example, "qry", qry_image, qry_video, has_image, has_video),
            "pos": self._build_side(example, "pos", pos_image, pos_video, has_image, has_video),
        }

    @staticmethod
    def _resolve_media(example, side, image_root):
        """Return ``(image_path, video_paths)`` for ``side``; at most one is not ``None``.

        Missing or empty entries count as absent; an image takes precedence over a
        video. Paths are joined onto ``image_root`` unless it is ``None``.
        """
        def resolve(path):
            return path if image_root is None else os.path.join(image_root, path)

        image = example.get(f"{side}_image")
        if image:
            return resolve(image), None
        video = example.get(f"{side}_video")
        if video and video[0] != "":
            return None, [resolve(frame) for frame in video]
        return None, None

    def _build_side(self, example, side, image_path, video_path, has_image, has_video):
        """Assemble the ``{"<side>_text", "solution", "prompt"}`` dict for one side."""
        text = example[f"{side}_text"]
        if image_path is not None:
            prompt = self._make_prompt(text, "image", image_path)
        elif video_path is not None:
            prompt = self._make_prompt(text, "video", video_path)
        elif has_image:
            prompt = self._make_prompt(text, "image", self._placeholder_image())
        elif has_video:
            prompt = self._make_prompt(text, "video", [self._placeholder_image()])
        else:
            # Neither side carries media: plain text prompt.
            prompt = self._make_prompt(text)
        return {f"{side}_text": text, "solution": example["solution"], "prompt": prompt}

    def _make_prompt(self, text, media_type=None, media=None):
        """One user turn: an optional ``{"type": media_type, media_type: media}`` item, then the question."""
        content = []
        if media_type is not None:
            content.append({"type": media_type, media_type: media})
        content.append({"type": "text", "text": self.question_template.format(Question=text)})
        return [{"role": "user", "content": content}]

    @staticmethod
    def _placeholder_image():
        """Return a fresh 28x28 white RGB image used for sides without their own media."""
        return PIL.Image.new("RGB", (28, 28), (255, 255, 255))


def get_vlm_module(model_name_or_path):
    """Return the VLM adapter class matching a checkpoint name or path.

    Matching is a case-insensitive substring test; "qwen" is checked before
    "internvl".

    Raises:
        ValueError: If no supported model family appears in the name.
    """
    normalized = model_name_or_path.lower()
    if "qwen" in normalized:
        return Qwen2VLModule
    if "internvl" in normalized:
        return InvernVLModule
    raise ValueError(f"Unsupported model: {model_name_or_path}")

def main(script_args, training_args, model_args):
    """Build the VLM module, reward functions and dataset, then run GRPO training.

    Args:
        script_args: Parsed `GRPOScriptArguments`; `dataset_name` is the YAML manifest
            path handed to `LazySupervisedDataset`.
        training_args: Parsed `GRPOConfig`, passed to the trainer as `args`.
        model_args: Parsed `GRPOModelConfig` (model path, PEFT, attention and dtype options).

    Raises:
        ValueError: If the model family or a requested reward function is unsupported.
    """
    # Pick the model-family adapter from the checkpoint name.
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)

    # Resolve reward functions by name; report typos clearly instead of a bare KeyError.
    reward_funcs_registry = {
        "format": vlm_module_cls.format_reward_embed,
    }
    unknown = [name for name in script_args.reward_funcs if name not in reward_funcs_registry]
    if unknown:
        raise ValueError(
            f"Unsupported reward function(s) {unknown}; available: {sorted(reward_funcs_registry)}"
        )
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the dataset with the module's embedding-task question template.
    dataset = LazySupervisedDataset(
        script_args.dataset_name,
        script_args,
        question_template=vlm_module_cls.get_question_template(task_type="embed"),
    )

    trainer = VLMGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
        torch_dtype=model_args.torch_dtype,
    )

    trainer.train()

    # Save the final model, then push to the Hub only when requested.
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse CLI flags and/or a config file into the three argument dataclasses, then train.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    parsed_script, parsed_training, parsed_model = arg_parser.parse_args_and_config()
    main(parsed_script, parsed_training, parsed_model)
