# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import debugpy
# try:
#     # Listen on localhost:9501 for a debugger to attach. Note that the VS Code debug configurations default to attach port 5678, so set the port to 9501 there. Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

import os
import re
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration

# Resolve the directory that holds this script and append it to the import search path
script_path = os.path.abspath(__file__)
root_dir = os.path.dirname(script_path)
sys.path.append(root_dir)
import json
import math
import random
from dataclasses import asdict

import numpy as np
import yaml
from trainer.grpo_ma_trainer import GRPOMATrainer
from trl import TrlParser, get_peft_config

from config.model_config import GRPOModelConfig
from config.training_config import GRPOConfig, GRPOScriptArguments

# Import our own modules
from dataset.dataset_grpo import MultiMediaGRPODataset
from model import get_supported_task_types, list_registered_functions

# Use our own monkey patch
from model.qwen2_5vl_monkey_patch import (
    monkey_patch_qwen2_5vl_flash_attn,
    monkey_patch_qwen2_5vl_forward,
    monkey_patch_torch_load,
)
from model.qwen_module import Qwen2VLModule
from model.vlm_module import *

# System prompt asking the assistant to reason inside <think> tags and to put the
# final answer inside <answer> tags.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
)


def get_vlm_module(model_name_or_path):
    """Return the VLM module class to use for ``model_name_or_path``.

    ``Qwen2VLModule`` is returned in every case. When the name does not contain
    "qwen" (compared case-insensitively), a warning is printed first to signal
    that the Qwen2VL module is being used as a fallback.
    """
    is_qwen = "qwen" in model_name_or_path.lower()
    if not is_qwen:
        print(f"Warning: Unsupported model {model_name_or_path}, using Qwen2VL as default")
    return Qwen2VLModule


def validate_reward_functions(vlm_module_cls, detected_task_type, sample_dataset_item):
    """Smoke-test the accuracy and format reward functions of a VLM module.

    Each reward function is called once with a fixed, well-formed dummy
    completion. The outcome is only printed: any exception is caught and
    reported, so this function never raises on a failing reward function and
    always returns ``None``.

    Args:
        vlm_module_cls: Object exposing ``accuracy_reward`` and ``format_reward``.
        detected_task_type: Task type passed to the reward functions as the
            single-element question-type list.
        sample_dataset_item: Mapping from which the optional keys
            ``affordance_label_id``, ``image_path`` and ``problem`` are
            forwarded (each wrapped in a one-element list) as keyword arguments.
    """
    print(f"\n=== Validating reward functions for task type: {detected_task_type} ===")

    try:
        # Dummy inputs shared by both reward-function probes
        dummy_completion = [{"content": "<think>test reasoning</think><answer>test answer</answer>"}]
        dummy_solution = "test answer"
        question_types = [detected_task_type]

        # Forward only those optional fields that the sample actually carries
        optional_keys = ("affordance_label_id", "image_path", "problem")
        extra_kwargs = {
            name: [sample_dataset_item[name]] for name in optional_keys if name in sample_dataset_item
        }

        # Probe the accuracy reward
        try:
            accuracy_reward = vlm_module_cls.accuracy_reward(
                dummy_completion, [dummy_solution], question_types, **extra_kwargs
            )
            print(f"✓ Accuracy reward function working: {accuracy_reward}")
        except Exception as e:
            print(f"✗ Accuracy reward function error: {e}")

        # Probe the format reward
        try:
            format_reward = vlm_module_cls.format_reward(
                dummy_completion, [dummy_solution], question_types, **extra_kwargs
            )
            print(f"✓ Format reward function working: {format_reward}")
        except Exception as e:
            print(f"✗ Format reward function error: {e}")

    except Exception as e:
        # Reached when preparing the dummy inputs itself fails
        print(f"✗ Reward function validation failed: {e}")

    print("=== Validation complete ===\n")


def main(script_args, training_args, model_args):
    """Configure and run a GRPO training job, then save the resulting model.

    Steps, in order: validate the requested trainer, seed the RNGs, force the
    cosine learning-rate schedule, dump all arguments to
    ``<output_dir>/training_args.json``, resolve the reward functions, build
    the dataset and the trainer, train, save, and optionally push to the hub.

    Args:
        script_args: Dataclass instance; this function reads ``reward_funcs``,
            ``dataset_name``, ``max_pixels``, ``min_pixels`` and
            ``max_anyres_num`` and hands the whole object to the dataset.
        training_args: Dataclass instance; this function reads ``train_cls``,
            ``seed``, ``output_dir``, ``warmup_ratio``, ``learning_rate`` and
            ``push_to_hub``, and overwrites ``lr_scheduler_type`` with
            ``"cosine"``.
        model_args: Dataclass instance; this function reads
            ``model_name_or_path``, ``freeze_vision_modules``,
            ``attn_implementation`` and ``torch_dtype`` and passes the object
            to ``get_peft_config``.

    Raises:
        ValueError: If ``training_args.train_cls`` is not a supported trainer,
            or if ``script_args.reward_funcs`` names an unknown reward function.
    """
    # Fail fast on an unsupported trainer, before touching the filesystem or
    # loading any data. Previously this surfaced only as an unbound-name error
    # at trainer construction time.
    supported_trainers = ("GRPO-MA",)
    if training_args.train_cls not in supported_trainers:
        raise ValueError(
            f"Unsupported train_cls '{training_args.train_cls}'. "
            f"Supported values: {list(supported_trainers)}"
        )

    # Imported locally because this module has no top-level torch import of its
    # own; relying on the wildcard import of model.vlm_module to provide it is fragile.
    import torch

    # Set seed before any data loading
    if training_args.seed is not None:
        random.seed(training_args.seed)
        np.random.seed(training_args.seed)
        torch.manual_seed(training_args.seed)

    # Force cosine annealing. Done before the arguments are written to disk so
    # that the saved snapshot records the scheduler the run actually uses.
    training_args.lr_scheduler_type = "cosine"

    # Create output directory if it doesn't exist
    os.makedirs(training_args.output_dir, exist_ok=True)

    # Snapshot all arguments for later inspection
    args_dict = {
        "script_args": asdict(script_args),
        "training_args": asdict(training_args),
        "model_args": asdict(model_args),
        "timestamp": datetime.now().isoformat(),
    }

    # default=str stringifies any value json cannot serialize natively
    args_file_path = os.path.join(training_args.output_dir, "training_args.json")
    with open(args_file_path, "w", encoding="utf-8") as f:
        json.dump(args_dict, f, indent=2, ensure_ascii=False, default=str)

    print(f"Training arguments saved to: {args_file_path}")

    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print(f"Using VLM module: {vlm_module_cls.__name__}")

    # Reward functions selectable by name
    reward_funcs_registry = {
        "accuracy": vlm_module_cls.accuracy_reward,
        "format": vlm_module_cls.format_reward,
    }

    # Validate the requested names before looking them up, so an unknown name
    # yields an actionable message instead of a bare KeyError.
    unknown_funcs = [name for name in script_args.reward_funcs if name not in reward_funcs_registry]
    if unknown_funcs:
        raise ValueError(
            f"Unknown reward function(s): {unknown_funcs}. "
            f"Available functions: {list(reward_funcs_registry.keys())}"
        )
    for func_name in script_args.reward_funcs:
        print(f"✓ Reward function '{func_name}' is available")

    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print(f"Loaded reward functions: {script_args.reward_funcs}")

    # Report learning rate scheduler parameters
    print(f"Learning rate scheduler: {training_args.lr_scheduler_type}")
    print(f"Warmup ratio: {training_args.warmup_ratio}")
    print(f"Learning rate: {training_args.learning_rate}")

    # Load the dataset
    print(f"Loading dataset from: {script_args.dataset_name}")

    dataset = MultiMediaGRPODataset(script_args.dataset_name, script_args)
    print(f"Loaded {len(dataset)} samples")

    # "GRPO-MA" is the only supported value (validated at the top of this function)
    trainer_cls = GRPOMATrainer
    print(f"Using trainer: {training_args.train_cls}")

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
        torch_dtype=model_args.torch_dtype,
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse command-line / config input into the three argument objects
    arg_classes = (GRPOScriptArguments, GRPOConfig, GRPOModelConfig)
    parser = TrlParser(arg_classes)
    script_args, training_args, model_args = parser.parse_args_and_config()

    # The flash-attention patch is applied unconditionally
    monkey_patch_qwen2_5vl_flash_attn()

    # The forward and torch-load patches are applied only when the DeepSpeed
    # setting is present and contains "zero3"
    deepspeed_setting = training_args.deepspeed
    if deepspeed_setting and "zero3" in deepspeed_setting:
        print("Zero3 is used, applying qwen2_5vl forward and torch load monkey patches")
        for apply_patch in (monkey_patch_qwen2_5vl_forward, monkey_patch_torch_load):
            apply_patch()

    main(script_args, training_args, model_args)
