# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import random
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainerMaskImg, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script with multi-modal masking.

    Args:
        reward_funcs (`list[str]`):
            Names of the reward functions to use; each is looked up in
            `reward_funcs_registry`. Possible values: 'accuracy', 'format',
            'dual_reasoning'.
        max_pixels (`int`), min_pixels (`int`):
            Upper / lower pixel-count bounds for visual inputs; both are passed
            to the trainer constructor.
        temporal (`bool`):
            Whether to use temporal GRPO.
        len_control (`bool`):
            Whether to use the length reward.
        dual_reasoning (`bool`):
            Whether to use dual reasoning with shuffled options.
        dual_reasoning_reward_list (`list[float]`):
            Positional rewards for dual reasoning:
            [consistent & correct, inconsistent but one correct,
             consistent but wrong, inconsistent & both wrong].
        progressive_reward (`bool`):
            Whether to use the progressive reward strategy.
        progressive_reward_stages (`list[list[float]]`):
            One reward list per stage (defaults go lenient -> moderate -> strict).
        progressive_reward_ratios (`list[float]`):
            Fraction of training steps assigned to each progressive stage.
        mask_multimodal (`bool`):
            Whether to mask multi-modal input tokens during final answer generation.
        use_efficient_masking (`bool`):
            Experimental single-pass generation with dynamic attention masking.

    NOTE(review): `temporal`, `len_control`, `dual_reasoning*` and `progressive_*`
    are not read in this module; they presumably reach the trainer through the
    `script_args` object it receives — confirm there.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'dual_reasoning'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )
    dual_reasoning: Optional[bool] = field(
        default=False,
        metadata={"help": "whether using dual reasoning with shuffled options"},
    )
    dual_reasoning_reward_list: Optional[list[float]] = field(
        default_factory=lambda: [1.0, 0.3, 0.2, 0.0],
        metadata={
            "help": (
                "Rewards for dual reasoning, in order: [consistent & correct, inconsistent but one correct, "
                "consistent but wrong, inconsistent & both wrong]"
            )
        },
    )
    progressive_reward: Optional[bool] = field(
        default=False,
        metadata={"help": "whether using progressive reward strategy"},
    )
    progressive_reward_stages: Optional[list[list[float]]] = field(
        default_factory=lambda: [
            [1.0, 0.7, 0.4, 0.1],  # Stage 1: lenient (first 30% steps)
            [1.0, 0.5, 0.2, 0.0],  # Stage 2: moderate (middle 40% steps)
            [1.0, 0.3, 0.1, 0.0],  # Stage 3: strict (last 30% steps)
        ],
        metadata={"help": "Progressive reward stages: [[stage1], [stage2], [stage3]]"},
    )
    progressive_reward_ratios: Optional[list[float]] = field(
        default_factory=lambda: [0.3, 0.4, 0.3],
        metadata={"help": "Ratio of steps for each progressive stage [stage1_ratio, stage2_ratio, stage3_ratio]"},
    )
    mask_multimodal: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to mask multi-modal input tokens during final answer generation"},
    )
    use_efficient_masking: Optional[bool] = field(
        default=False,
        metadata={"help": "Use efficient single-pass generation with dynamic attention masking (experimental)"},
    )



def accuracy_reward(completions, solution, **kwargs):
    """Score each completion's ``<answer>`` against its reference solution.

    The metric is chosen from ``kwargs['problem_type'][0]``; the first entry
    decides the metric for the whole batch:

    * ``"multiple choice"``: exact string match -> 1.0 / 0.0.
    * ``"numerical"``: 0.0 if exactly one side contains a decimal separator
      (``.`` or ``,``). Otherwise 1.0 when both parse as floats and agree after
      rounding to 2 decimals, else 0.0.
    * ``"OCR"``: ``1 - WER`` clipped to [0, 1].
    * ``"free-form"``: mean ROUGE-1/2/L F-measure clipped to [0, 1].
    * ``"regression"``: ``1 - relative error`` clipped to [0, 1]; 0.0 when
      either side is not a number.
    * any other type: 0.0.

    An exception while scoring one sample is printed and yields 0.0 for it.

    Args:
        completions: sequence of completions; ``completion[0]["content"]`` is
            the generated text.
        solution: reference strings with the answer wrapped in ``<answer>`` tags.
        **kwargs: must provide ``problem_type`` (indexable, non-empty).

    Returns:
        A list with one reward per ``(completion, solution)`` pair; ``zip``
        stops at the shorter input.

    If ``DEBUG_MODE == "true"`` and ``LOG_PATH`` is set, every reward is
    appended to that file. Without ``LOG_PATH``, logging is skipped rather
    than crashing.
    """

    def extract_answer(text):
        # First <answer>...</answer> block (may span lines), trimmed; "" if absent.
        pattern = r'<answer>\s*(.*?)\s*</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def normalize_number(num_str):
        # Commas are removed before parsing ("1,000" -> 1000.0); None if unparsable.
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        # Word error rate: word-level Levenshtein distance / reference word count.
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    rouge = None  # built lazily, once per call: only free-form samples need it

    def compute_rouge_score(reference, hypothesis):
        # Mean F-measure of ROUGE-1, ROUGE-2 and ROUGE-L (with stemming).
        nonlocal rouge
        if rouge is None:
            rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        scores = rouge.score(reference, hypothesis)
        return (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3

    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # Debug logging needs both the flag and a destination; open(None) would raise.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    rewards = []

    for content, sol in zip(contents, solution):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "multiple choice":
                reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            elif question_type == "numerical":
                gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
                out_has_decimal = ("." in output_ans) or ("," in output_ans)
                if gt_has_decimal != out_has_decimal:
                    reward = 0.0
                else:
                    gt_number = normalize_number(gt_ans)
                    out_number = normalize_number(output_ans)
                    if gt_number is None or out_number is None:
                        reward = 0.0
                    else:
                        reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            elif question_type == "OCR":
                error_rate = wer(gt_ans, output_ans)
                reward = max(0.0, min(1.0, 1 - error_rate))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:
                    # Relative error; the 1e-9 terms keep gt == 0 from dividing by zero.
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            else:
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if log_path:
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Return 1.0 per completion laid out as ``<think>…</think><answer>…</answer>``, else 0.0.

    The entire text must match (``fullmatch``). Nothing may precede
    ``<think>`` or follow ``</answer>``, and only whitespace may separate
    ``</think>`` from ``<answer>``. Tag bodies may span several lines.
    Extra keyword arguments are accepted and ignored.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    texts = [item[0]["content"] for item in completions]
    return [float(layout.fullmatch(text) is not None) for text in texts]


def dual_reasoning_reward(completions, solution, **kwargs):
    """
    Reward function for the dual reasoning algorithm (multiple-choice questions).

    ``completions`` must hold the original completions first, followed by their
    shuffled-option counterparts: completion ``i + len(completions) // 2`` is the
    shuffled twin of completion ``i``. The intended pipeline is:
    1. Generate normal completion (with original options)
    2. Extract <think> part from the completion
    3. Generate second completion with shuffled options and the extracted <think> part
    4. Compare answers and assign rewards based on consistency and correctness

    The shuffled answer is translated back to an original option letter via
    ``kwargs['shuffled_mappings'][i + len(completions) // 2]`` (a dict; a missing
    or empty mapping leaves it unchanged). Each original completion then gets
    ``dual_reasoning_reward_list[k]``, where k is:
      0 - consistent and correct
      1 - inconsistent, but either answer correct
      2 - consistent but wrong
      3 - inconsistent and both wrong

    Args:
        completions: completions; ``completion[0]["content"]`` is the text.
        solution: reference answers wrapped in ``<answer>`` tags. Either one per
            completion (same order as ``completions``) or fewer. With fewer, they
            are spread evenly over the original completions (assumes e.g. one
            solution per prompt with contiguous generation groups).
        **kwargs: ``problem_type`` (first entry selects the path),
            ``shuffled_mappings``, optional ``dual_reasoning_reward_list``
            (default ``[1, 0.3, 0.2, 0.0]``).

    Returns:
        For multiple choice: one reward per ORIGINAL completion; the shuffled half
        only serves the consistency check. Otherwise: ``accuracy_reward`` on the
        same inputs.

    Raises:
        ValueError: if there are original completions but ``solution`` is empty.
    """
    def extract_answer(text):
        # First <answer>...</answer> block (may span lines), trimmed; "" if absent.
        pattern = r'<answer>\s*(.*?)\s*</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    problem_types = kwargs.get('problem_type', [''])
    question_type = problem_types[0] if len(problem_types) else ''
    shuffled_mappings = kwargs.get('shuffled_mappings', [])
    dual_reasoning_reward_list = kwargs.get('dual_reasoning_reward_list', [1, 0.3, 0.2, 0.0])

    if question_type != "multiple choice":
        # NOTE(review): this returns one reward per (completion, solution) pair, not one
        # per original completion like the multiple-choice path — confirm the caller expects it.
        return accuracy_reward(completions, solution, **kwargs)

    contents = [completion[0]["content"] for completion in completions]
    num_original_generations = len(contents) // 2
    if num_original_generations == 0:
        return []
    total_samples = len(solution)
    if total_samples == 0:
        raise ValueError("dual_reasoning_reward: got completions but `solution` is empty")

    def solution_index(i):
        # Reference answer for original completion i.
        if total_samples >= len(contents):
            # One solution per completion, aligned with `contents`.
            return i
        # Fewer solutions than completions: spread them evenly over the originals
        # (one solution per prompt with G contiguous generations gives i // G).
        return min(i * total_samples // num_original_generations, total_samples - 1)

    # NOTE(review): this list is only appended to here; presumably the trainer reads
    # (and clears) it — otherwise it grows for the whole run.
    if not hasattr(dual_reasoning_reward, '_debug_infos'):
        dual_reasoning_reward._debug_infos = []
    debug_infos = dual_reasoning_reward._debug_infos

    rewards = []
    for i in range(num_original_generations):
        original_content = contents[i]
        shuffled_content = contents[i + num_original_generations]
        gt_solution = solution[solution_index(i)]

        # Defaults keep the debug record defined (and not stale from the previous
        # sample) when scoring raises part-way through.
        original_answer = shuffled_answer = mapped_shuffled_answer = gt_answer = ""
        answers_consistent = original_correct = shuffled_correct = False
        reward = 0.0

        try:
            original_answer = extract_answer(original_content)
            shuffled_answer = extract_answer(shuffled_content)
            gt_answer = extract_answer(gt_solution)

            # Mapping from shuffled option letter back to the original option letter.
            shuffled_mapping_idx = i + num_original_generations
            reverse_mapping = shuffled_mappings[shuffled_mapping_idx] if shuffled_mapping_idx < len(shuffled_mappings) else {}

            mapped_shuffled_answer = shuffled_answer
            if reverse_mapping and shuffled_answer.strip() in reverse_mapping:
                mapped_shuffled_answer = reverse_mapping[shuffled_answer.strip()]

            answers_consistent = (original_answer.strip() == mapped_shuffled_answer.strip())
            original_correct = (original_answer.strip() == gt_answer.strip())
            shuffled_correct = (mapped_shuffled_answer.strip() == gt_answer.strip())

            if answers_consistent and original_correct:
                # Both consistent and correct
                reward = dual_reasoning_reward_list[0]
            elif not answers_consistent and (original_correct or shuffled_correct):
                # Inconsistent but one correct
                reward = dual_reasoning_reward_list[1]
            elif answers_consistent and not original_correct:
                # Consistent but wrong
                reward = dual_reasoning_reward_list[2]
            else:
                # Inconsistent and both wrong
                reward = dual_reasoning_reward_list[3]
        except Exception as e:
            print(f"Error in dual_reasoning_reward: {e}")
            reward = 0.0

        rewards.append(reward)

        # Collected instead of written to a file here, to avoid concurrent writes.
        debug_infos.append({
            'original_answer': original_answer,
            'shuffled_answer': shuffled_answer,
            'mapped_shuffled_answer': mapped_shuffled_answer,
            'gt_answer': gt_answer,
            'answers_consistent': answers_consistent,
            'original_correct': original_correct,
            'shuffled_correct': shuffled_correct,
            'reward': reward,
        })

    return rewards


# Name -> reward callable; main() resolves each entry of `script_args.reward_funcs` here.
reward_funcs_registry = dict(
    accuracy=accuracy_reward,
    format=format_reward,
    dual_reasoning=dual_reasoning_reward,
)

# System prompt describing the expected <think> ... </think><answer> ... </answer> output format.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
    """Build the multimodal GRPO dataset and trainer, train, then save (and optionally push).

    Args:
        script_args: parsed ``GRPOScriptArguments``. Provides the dataset name,
            config and splits, reward function names, pixel bounds and masking flags.
        training_args: parsed ``GRPOConfig``. ``use_vllm`` selects the trainer
            class; the eval split is only used when ``eval_strategy != "no"``.
        model_args: parsed ``ModelConfig``. Provides the model path, attention
            implementation and PEFT settings.
    """
    # Resolve reward callables by name; an unknown name raises KeyError here.
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith(('.json', '.jsonl')):
        # Local JSON/JSONL file: expose it as a single "train" split.
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "Please think about this question as if you were a human pondering deeply. "
        "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
        "It's encouraged to include self-reflection or verification in the reasoning process. "
        "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    )

    # Answer-format instruction appended per problem type (unknown types raise KeyError in the mapper).
    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
    }

    def make_conversation_image_and_video(example):
        """Build a one-turn user prompt: a media placeholder plus the formatted question."""
        if example["problem_type"] == 'multiple choice':
            # NOTE(review): no separator is inserted between the problem and "Options:";
            # kept byte-identical so prompts match existing runs.
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        # Media placeholder; the type comes from the dataset
                        # (presumably "image" or "video").
                        {"type": example['data_type']},
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']],
                        },
                    ],
                }
            ]
        }

    dataset = dataset.map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainerMaskImg if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        mask_multimodal=script_args.mask_multimodal,
        use_efficient_masking=script_args.use_efficient_masking,
    )

    if training_args.resume_from_checkpoint is not None:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse the script, GRPO-training and model argument groups, then run training.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    parsed = arg_parser.parse_args_and_config()
    script_args, training_args, model_args = parsed
    main(script_args, training_args, model_args)