# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
# from trainer_drive import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import numpy as np
from typing import List, Dict, Any


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Extends `ScriptArguments` with the reward selection, the image pixel bounds
    and two boolean switches declared below.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
        max_pixels (`int`, *optional*, defaults to `12845056`):
            Maximum number of pixels for the image.
        min_pixels (`int`, *optional*, defaults to `3136`):
            Minimum number of pixels for the image.
        temporal (`bool`, *optional*, defaults to `True`):
            Whether using temporal GRPO.
        len_control (`bool`, *optional*, defaults to `True`):
            Whether using length reward.
    """

    # Names are resolved to callables through `reward_funcs_registry` in `main`.
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    # Forwarded to the trainer constructor as `max_pixels` in `main`.
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    # Forwarded to the trainer constructor as `min_pixels` in `main`.
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    # Not read in this file; presumably consumed by the trainer via `script_args` - TODO confirm.
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    # Not read in this file; presumably consumed by the trainer via `script_args` - TODO confirm.
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )



def accuracy_reward(completions, solution, **kwargs):
    """Score each completion's ``<answer>`` against its reference solution.

    The scoring rule is chosen from ``kwargs["problem_type"][0]`` and applied to
    every sample in the batch:

    * ``"multiple choice"``: 1.0 on exact (stripped) string match, else 0.0.
    * ``"numerical"``: both answers are parsed as lists of ``[x, y]`` waypoints
      and the reward is the sum of three terms, each in ``[0, 1]``:

      - trajectory term: ``1 - min(1, weighted mean squared error)``;
      - steering term: 1.0 when the mean ``dy/dx`` slope between consecutive
        predicted waypoints is below 0.84, else 0.0;
      - acceleration term: 1.0 when the mean magnitude of the second finite
        difference of the predicted waypoints, divided by ``T**2`` with
        ``T = 0.5``, is below 6, else 0.0.

      The reward is 0.0 when only one of the two answers contains ``.``/``,``
      or when the parsed waypoint arrays do not both have the expected shape.
    * ``"OCR"``: ``1 - word error rate``, clamped to ``[0, 1]``.
    * ``"free-form"``: mean of ROUGE-1/2/L F-measures, clamped to ``[0, 1]``.
    * ``"regression"``: ``1 - min(1, relative error)``; 0.0 if either value
      cannot be parsed as a float.
    * anything else: 0.0.

    Args:
        completions: Sequence of completions; the text of each one is read from
            ``completion[0]["content"]``.
        solution: Reference strings, one per completion, each wrapping the
            ground truth in ``<answer> </answer>`` tags.
        **kwargs: Must contain ``problem_type``, a sequence whose first element
            names the scoring rule.

    Returns:
        list[float]: One reward per completion, in input order. Any exception
        raised while scoring a sample is printed and yields 0.0 for that sample.
    """
    step_seconds = 0.5   # T: sampling interval between consecutive waypoints
    steer_limit = 0.84   # upper bound on the mean dy/dx slope
    acc_limit = 6        # upper bound on the mean acceleration magnitude
    # Per-waypoint weights applied to the squared x and y errors.
    waypoint_weights = np.array([
        [0.5, 0.5],
        [0.65, 0.55],
        [0.7, 0.6],
        [0.75, 0.65],
        [0.8, 0.7],
        [0.9, 0.8],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
    ])

    def extract_answer(text):
        """Return the stripped text inside the first <answer> tag, or ""."""
        match = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)
        return match.group(1).strip() if match else ""

    def normalize_number(num_str):
        """Parse a number (thousands separators allowed); None on failure."""
        try:
            return float(num_str.replace(',', ''))
        except (ValueError, TypeError, AttributeError) as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def normalize_points(num_str):
        """Parse every "[x.x, y.y]" pair into an (n, 2) array; None on failure."""
        try:
            matches = re.findall(r'\[(-?\d+\.\d+),\s*(-?\d+\.\d+)\]', num_str)
            return np.array([(float(x), float(y)) for x, y in matches])
        except (ValueError, TypeError) as e:
            print(f"Error converting points '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        """Word error rate: word-level edit distance over reference length."""
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    rouge_scorers = {}  # one scorer per use_stemmer value, built on first use

    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        """Mean of the ROUGE-1, ROUGE-2 and ROUGE-L F-measures."""
        scorer = rouge_scorers.get(use_stemmer)
        if scorer is None:
            scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
            rouge_scorers[use_stemmer] = scorer
        scores = scorer.score(reference, hypothesis)
        return (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3

    def trajectory_reward(pred, ref):
        """Sum of the trajectory, steering and acceleration terms (0.0 on bad shapes)."""
        # Explicit shape check instead of relying on a broadcasting exception.
        if pred.shape != ref.shape or ref.shape != waypoint_weights.shape:
            return 0.0

        sq_err = (pred - ref) ** 2
        traj_r = 1 - min(1.0, float(np.sum(sq_err * waypoint_weights)) / ref.shape[0])

        # Mean slope between consecutive waypoints. A zero x-step yields inf/nan,
        # which fails (or, for -inf, passes) the comparison exactly as before.
        # NOTE(review): the slope is signed, so negative slopes always pass - confirm intended.
        with np.errstate(divide="ignore", invalid="ignore"):
            slopes = np.diff(pred[:, 1]) / np.diff(pred[:, 0])
            mean_slope = np.sum(slopes) / (pred.shape[0] - 1)
        steer_r = 1.0 if mean_slope < steer_limit else 0.0

        # Acceleration from second finite differences of position. The previous
        # code thresholded `steer_r` here (so this term was always 1.0), divided
        # by T and multiplied by T again, and took the sqrt of a telescoping sum
        # that could be negative.
        accel = np.diff(pred, n=2, axis=0) / (step_seconds * step_seconds)
        mean_acc = float(np.mean(np.linalg.norm(accel, axis=1)))
        acc_r = 1.0 if mean_acc < acc_limit else 0.0

        return traj_r + steer_r + acc_r

    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # Debug logging needs both DEBUG_MODE=true and a LOG_PATH; a missing path
    # used to crash the run with open(None).
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    rewards = []

    for content, sol in zip(contents, solution):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "multiple choice":
                reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            elif question_type == "numerical":
                gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
                out_has_decimal = ("." in output_ans) or ("," in output_ans)
                if gt_has_decimal != out_has_decimal:
                    reward = 0.0
                else:
                    gt_number = normalize_points(gt_ans)
                    out_number = normalize_points(output_ans)
                    if gt_number is None or out_number is None:
                        reward = 0.0
                    else:
                        reward = trajectory_reward(out_number, gt_number)
            elif question_type == "OCR":
                error_rate = wer(gt_ans, output_ans)
                reward = max(0.0, min(1.0, 1 - error_rate))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            else:
                reward = 0.0
        except Exception as e:
            # Best-effort boundary: one bad sample must not abort the batch.
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        reward = float(reward)
        rewards.append(reward)

        if log_path:
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Give 1.0 to completions laid out as a think block followed by an answer block.

    The whole completion text (read from ``completion[0]["content"]``) must be a
    ``<think>...</think>`` section, optional whitespace, then an
    ``<answer>...</answer>`` section; anything else scores 0.0.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    texts = [completion[0]["content"] for completion in completions]
    scores = []
    for text in texts:
        if layout.fullmatch(text):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores


# Maps the reward names accepted on the command line to their implementations;
# `main` looks up every entry of `script_args.reward_funcs` in this dict.
reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

# System prompt describing the expected response layout: reasoning inside
# <think> </think> followed by the final answer inside <answer> </answer>.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
    """Prepare the driving dataset, build the GRPO trainer and run training.

    Steps:
        1. Resolve the reward callables named in ``script_args.reward_funcs``.
        2. Load the dataset, either from a local ``.json``/``.jsonl`` file (as a
           single ``"train"`` split) or through ``load_dataset``.
        3. Drop samples whose planning answer is not a well-formed list of ten
           ``[x, y]`` waypoints, then convert every remaining sample into a
           prompt plus the ``problem``/``solution``/... columns set below.
        4. Train (optionally resuming from a checkpoint), save the model and,
           if requested, push it to the hub.

    Args:
        script_args: Parsed `GRPOScriptArguments`.
        training_args: Parsed `GRPOConfig`.
        model_args: Parsed `ModelConfig`.
    """
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith(('.json', '.jsonl')):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "Please think deeply. "
        "Engage in an internal dialogue other natural language thought expressions "
        "It's a reasoning process. "
        "Provide your reasoning between the <think> </think> tags, and then give your answer between the <answer> </answer> tags."
    )

    # Only the "driving" suffix is used below; the other entries are kept for
    # switching the prompt suffix by problem type.
    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "action-detection": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "driving":" Predicted future movement details for the next 5 seconds (sampled at 0.5-second intervals), including BEV location in x and y directions (in meters). Positive x means forward direction while positive y means leftwards. The output is formatted as [x, y]."
    }

    # Text that precedes the waypoint list inside the assistant message.
    output_marker = "The output is formatted as [x, y]: "
    waypoint_pattern = re.compile(r'\[(-?\d+\.\d+),\s*(-?\d+\.\d+)\]')

    def has_valid_waypoints(example):
        """Return True when the assistant message holds exactly ten [x, y] waypoints."""
        points_text = example["messages"][1]["content"].split(output_marker)[1].split("</PLANNING>")[0]
        if not points_text.startswith("["):
            print("************  the solution_slice_points not start with [ .*********************")
            return False
        if not points_text.endswith("]"):
            print("************  the solution_slice_points not end with ] .*********************")
            return False
        waypoints = np.array([(float(x), float(y)) for x, y in waypoint_pattern.findall(points_text)])
        if waypoints.shape != (10, 2):
            print("************  the x_y_num  shape is incorrect.*********************")
            return False
        return True

    def make_conversation_image_and_video(example):
        """Build the image+text prompt and derive the training columns for one sample."""
        question = example["messages"][0]["content"]
        answer = example["messages"][1]["content"]

        msg = {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE["driving"],
                        },
                    ],
                },
            ],
        }

        example["problem_id"] = example["id"]
        example["information"] = question
        example["problem"] = answer.split("<PLANNING>")[1].split(output_marker)[0] + output_marker
        example["data_type"] = "image"
        # "numerical" selects the waypoint scoring branch of the accuracy reward.
        example["problem_type"] = "numerical"
        example["solution"] = "<answer>" + answer.split("</PLANNING>")[0].split(output_marker)[1] + "</answer>"
        example["path"] = "./data/nuccenes_vla/" + example['images'][0]

        # pop unused
        example.pop("id")
        example.pop("messages")

        return msg

    # Validate in a separate filter pass: returning None from a `map` callback
    # does not drop the row, so malformed samples were not actually skipped.
    dataset = dataset.filter(has_valid_waypoints).map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse the three argument groups (script, training, model) from the
    # command line / config file and hand them to `main` in that order.
    cli_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    main(*cli_parser.parse_args_and_config())
