# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import textwrap
import threading
from collections import defaultdict
from typing import Any, Callable, Optional, Union
import random

import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad

from qwen_vl_utils import process_vision_info

import copy


if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_wandb_available():
    import wandb

# Global lock for debug logging to avoid multi-threading issues. It is held by
# `Qwen2VLGRPOTrainer._output_complete_debug_info` while that method appends to the debug log file.
_debug_lock = threading.Lock()

# What we call a reward function is a callable that takes a list of prompts and a list of completions and returns a
# list of rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class Qwen2VLGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs="weqweasdas/RM-Gemma-2B",
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model class is
              picked from the id (`Qwen2-VL`, `Qwen2.5-VL` or `Aria`; any other id is loaded as Qwen2.5-VL) and
              loaded with its `from_pretrained` method using the keyword arguments in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                path to a *directory* containing model weights saved using
                [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
                using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
                keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details, see
                  [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing different
            types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
            ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`].
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
            For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
            the corresponding entries in `reward_processing_classes` are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: Optional[GRPOConfig] = None,
        script_args=None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        max_pixels: Optional[int] = 12845056,
        min_pixels: Optional[int] = 3136,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Set up the policy model, reference model, processor, reward functions and generation settings.

        Args:
            model: Model id / path (loaded here) or an already instantiated model.
            reward_funcs: One reward function or a list of them; string entries are loaded with
                `AutoModelForSequenceClassification.from_pretrained(..., num_labels=1)`.
            args: Training configuration. If `None`, a `GRPOConfig` named `"<model name>-GRPO"` is created.
            script_args: Optional object carrying script-level options. The attributes read here are `temporal`,
                `len_control`, `dual_reasoning`, `dual_reasoning_reward_list`, `use_kl_check`, `kl_lambda`,
                `progressive_reward`, `progressive_reward_stages` and `progressive_reward_ratios`. Every attribute is
                optional; `temporal` and `len_control` fall back to `False` when absent or when `script_args` is
                `None`.
            train_dataset, eval_dataset, callbacks, optimizers: Forwarded to `Trainer.__init__`.
            processing_class: Processor to use. If `None`, it is loaded with `AutoProcessor.from_pretrained`.
            reward_processing_classes: Tokenizer(s) for model-based reward functions.
            peft_config: If given, the model is wrapped with `get_peft_model` and no reference model is created
                (unless DeepSpeed ZeRO-3 is enabled).
            max_pixels, min_pixels: Written to `processing_class.image_processor` when the processor is loaded here
                and the model id contains `"Qwen"`.
            attn_implementation: Stored in the model init kwargs under `"attn_implementation"`.

        Raises:
            ValueError: On an invalid `torch_dtype`, on `model_init_kwargs` combined with an instantiated model, or
                when the number of reward processing classes does not match the number of reward functions.
            ImportError: If `peft_config` is given but PEFT is not installed.
        """
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        def _model_class_for(name):
            # Pick the architecture from the checkpoint id; unrecognised ids are loaded as Qwen2.5-VL.
            if "Qwen2-VL" in name:
                return Qwen2VLForConditionalGeneration
            if "Qwen2.5-VL" in name:
                return Qwen2_5_VLForConditionalGeneration
            if "Aria" in name:
                return AriaForConditionalGeneration
            return Qwen2_5_VLForConditionalGeneration

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
            )
            model_cls = _model_class_for(model_id)
            if model_cls is AriaForConditionalGeneration:
                # The Aria path is loaded without the `use_cache` kwarg.
                model_init_kwargs.pop("use_cache")
            model = model_cls.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            if not is_peft_available():
                # `get_peft_model` is only imported when PEFT is installed; fail with a clear message instead of a
                # NameError.
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

        # Reference model
        if is_deepspeed_zero3_enabled():
            # Under ZeRO-3 the reference model is loaded from the checkpoint again.
            self.ref_model = _model_class_for(model_id).from_pretrained(model_id, **model_init_kwargs)
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model_id)
            pad_token_id = processing_class.tokenizer.pad_token_id
            # Mirror the tokenizer's special-token ids on the processor itself.
            processing_class.pad_token_id = pad_token_id
            processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
            if "Qwen" in model_id:
                processing_class.image_processor.max_pixels = max_pixels
                processing_class.image_processor.min_pixels = min_pixels
        else:
            # A caller-supplied processing class still has to provide the pad token id used by the generation
            # configs below. Read it from the wrapped tokenizer when there is one, else from the object itself.
            inner_tokenizer = getattr(processing_class, "tokenizer", None)
            pad_source = processing_class if inner_tokenizer is None else inner_tokenizer
            pad_token_id = pad_source.pad_token_id

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        # `script_args` is optional, so read these the same way as the other script options below.
        self.temporal = getattr(script_args, "temporal", False)
        self.len_control = getattr(script_args, "len_control", False)

        # The three generation configs share the sampling settings and differ only in length / sample count.
        sampling_kwargs = {
            "do_sample": True,
            "top_p": 0.95,
            "temperature": 1,  # HACK
            "pad_token_id": pad_token_id,
        }
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            num_return_sequences=self.num_generations,
            **sampling_kwargs,
        )
        self.shuffled_num_generations = self.num_generations // 2
        self.shuffled_generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            num_return_sequences=self.shuffled_num_generations,
            **sampling_kwargs,
        )
        self.dummy_generation_config = GenerationConfig(
            max_new_tokens=1,
            num_return_sequences=1,
            **sampling_kwargs,
        )

        self.dual_reasoning = getattr(script_args, 'dual_reasoning', False)
        self.dual_reasoning_reward_list = getattr(script_args, 'dual_reasoning_reward_list', [1, 0.3, 0.2, 0.0])

        # KL consistency check parameters for dual reasoning
        self.use_kl_check = getattr(script_args, 'use_kl_check', False)
        self.kl_lambda = getattr(script_args, 'kl_lambda', 0.3)

        # Progressive reward strategy
        self.progressive_reward = getattr(script_args, 'progressive_reward', False)

        # Default stages, used both when the option is absent and as the fallback on parse errors
        default_stages = [
            [1.0, 0.7, 0.4, 0.1],  # Stage 1: lenient
            [1.0, 0.5, 0.2, 0.0],  # Stage 2: moderate
            [1.0, 0.3, 0.1, 0.0],  # Stage 3: strict
        ]

        # Parse progressive_reward_stages (handle string input from command line)
        stages_raw = getattr(script_args, 'progressive_reward_stages', default_stages)
        if isinstance(stages_raw, str):
            import ast

            try:
                parsed_stages = ast.literal_eval(stages_raw)
                # Ensure all elements are float
                self.progressive_reward_stages = [[float(x) for x in stage] for stage in parsed_stages]
                print(f"Successfully parsed progressive_reward_stages: {self.progressive_reward_stages}")
            except (ValueError, SyntaxError, TypeError) as e:
                print(f"Warning: Failed to parse progressive_reward_stages string: {e}")
                print(f"Original string: {stages_raw}")
                print("Using default stages")
                self.progressive_reward_stages = default_stages
        else:
            try:
                # Ensure all elements are float even for non-string input
                self.progressive_reward_stages = [[float(x) for x in stage] for stage in stages_raw]
            except (ValueError, TypeError) as e:
                print(f"Warning: Failed to convert progressive_reward_stages to float: {e}")
                print("Using default stages")
                self.progressive_reward_stages = default_stages

        # Parse progressive_reward_ratios; `float()` handles both numeric and string entries from the command line
        default_ratios = [0.3, 0.4, 0.3]
        ratios_raw = getattr(script_args, 'progressive_reward_ratios', default_ratios)
        try:
            self.progressive_reward_ratios = [float(x) for x in ratios_raw]
            print(f"Successfully parsed progressive_reward_ratios: {self.progressive_reward_ratios}")
        except (ValueError, TypeError) as e:
            print(f"Warning: Failed to convert progressive_reward_ratios to float: {e}")
            print(f"Original ratios: {ratios_raw}")
            print("Using default ratios")
            self.progressive_reward_ratios = default_ratios

        self.beta = args.beta

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _get_current_dual_reasoning_reward_list(self):
        """Get the current dual reasoning reward list based on training progress."""
        default_reward_list = [1.0, 0.3, 0.2, 0.0]
        
        if not self.progressive_reward:
            reward_list = self.dual_reasoning_reward_list
        else:
            # Calculate current progress
            current_step = self.state.global_step
            max_steps = self.state.max_steps
            
            if max_steps == 0:
                reward_list = self.dual_reasoning_reward_list
            else:
                progress = current_step / max_steps
                
                # Determine current stage
                cumulative_ratio = 0
                reward_list = None
                for i, ratio in enumerate(self.progressive_reward_ratios):
                    cumulative_ratio += ratio
                    if progress <= cumulative_ratio:
                        reward_list = self.progressive_reward_stages[i]
                        # Log stage changes
                        if hasattr(self, '_last_stage') and self._last_stage != i:
                            print(f"Progressive reward: Stage {i+1} at step {current_step} ({progress:.2%}): {reward_list}")
                        self._last_stage = i
                        break
                
                # If we're past all stages, use the last one
                if reward_list is None:
                    reward_list = self.progressive_reward_stages[-1]
        
        # Final safety check: ensure all elements are float
        try:
            reward_list = [float(x) for x in reward_list]
        except (ValueError, TypeError) as e:
            print(f"Warning: Error converting reward list to float in _get_current_dual_reasoning_reward_list: {e}")
            print(f"Problematic reward_list: {reward_list}, type: {type(reward_list)}")
            reward_list = default_reward_list
        
        return reward_list

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]


    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, **kwargs):
        # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
        # import pdb
        # pdb.set_trace()
        logits = model(input_ids, **kwargs).logits
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)
    
    def remove_none_from_data(self, data):
        for entry in data:
            if "content" in entry and isinstance(entry["content"], list):
                for sub_entry in entry["content"]:
                    if isinstance(sub_entry, dict):
                        keys_to_remove = [k for k, v in sub_entry.items() if v is None]
                        for k in keys_to_remove:
                            del sub_entry[k]
        return data


    # Trainer "prepares" the inputs before calling `compute_loss`: it converts them to tensors and moves them to the
    # device. Since we preprocess the data in `compute_loss`, we override this method to skip that step.
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        return inputs
    
    def _prepare_shuffled_inputs(self, inputs, think_completions):
        """Prepare inputs with shuffled options for dual reasoning."""
        import random
        
        # Create shuffled version of the prompt
        shuffled_inputs = []
        for i, example in enumerate(inputs):
            if example['problem_type'] == 'multiple choice' and 'options' in example:
                # Extract original options content from the problem text
                options = example['options'].copy()
                original_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'][:len(options)]
                
                # Parse original options to get their content
                original_option_contents = []
                original_option_letters = []
                for option in options:
                    if '. ' in option:
                        letter, content = option.split('. ', 1)
                        original_option_contents.append(content)
                        original_option_letters.append(letter)
                    else:
                        original_option_contents.append(option)
                        original_option_letters.append('')
                
                # Create derangement: shuffle the content ensuring no option stays in its original position
                def is_derangement(perm, n):
                    """Check if permutation is a derangement (no element in its original position)"""
                    for i in range(n):
                        if perm[i] == i:
                            return False
                    return True

                def generate_derangement(n):
                    """Generate a random derangement of n elements"""
                    if n == 1:
                        # No derangement possible for n=1
                        return [0]

                    max_attempts = 1000  # Prevent infinite loop
                    for _ in range(max_attempts):
                        perm = list(range(n))
                        random.shuffle(perm)
                        if is_derangement(perm, n):
                            return perm

                    # Fallback: construct a derangement deterministically
                    # Simple rotation works as derangement for n>1
                    return list(range(1, n)) + [0]

                # Generate derangement indices
                num_options = len(original_option_contents)
                shuffled_content_indices = generate_derangement(num_options)

                # Create mapping from original letter to shuffled letter
                # This tells us which original option content is now at each position
                content_mapping = {}  # original_letter -> shuffled_letter
                reverse_mapping = {}  # shuffled_letter -> original_letter
                
                for new_idx, original_idx in enumerate(shuffled_content_indices):
                    if original_idx < len(original_option_letters) and new_idx < len(original_letters):
                        orig_letter = original_option_letters[original_idx]
                        new_letter = original_letters[new_idx]
                        if orig_letter:  # Skip empty letters
                            content_mapping[orig_letter] = new_letter
                            reverse_mapping[new_letter] = orig_letter
                
                # Reconstruct problem with shuffled option content
                # Handle the actual format: Question\nA. Option1\nB. Option2\n...
                import re
                
                # Split by options pattern (find where options start)
                lines = example['problem'].strip().split('\n')
                question_lines = []
                option_lines = []
                
                # Find where options start (first line that matches "A. ")
                options_start_idx = -1
                for i, line in enumerate(lines):
                    if re.match(r'^[A-H]\.\s', line.strip()):
                        options_start_idx = i
                        break
                
                if options_start_idx != -1:
                    question_lines = lines[:options_start_idx]
                    option_lines = lines[options_start_idx:]
                    
                    # Extract question part
                    question_part = '\n'.join(question_lines)
                    if question_part:
                        question_part += '\n'
                    
                    # Create new problem with shuffled options
                    new_problem = question_part
                    
                    # Create new options with shuffled content but fixed labels A, B, C, D
                    for j, original_idx in enumerate(shuffled_content_indices):
                        if j < len(original_letters) and original_idx < len(original_option_contents):
                            option_line = f"{original_letters[j]}. {original_option_contents[original_idx]}\n"
                            new_problem += option_line
                    
                    # Create new input with shuffled options and mapping information
                    new_example = example.copy()
                    new_example['problem'] = new_problem
                    new_example['shuffled_options'] = [f"{original_letters[j]}. {original_option_contents[shuffled_content_indices[j]]}" 
                                                     for j in range(min(len(original_letters), len(shuffled_content_indices)))]
                    new_example['original_options'] = options
                    new_example['option_mapping'] = content_mapping  # original -> shuffled
                    new_example['reverse_mapping'] = reverse_mapping  # shuffled -> original
                    
                    # Store shuffled problem for later logging
                    new_example['_original_problem_for_debug'] = example['problem']
                    new_example['_shuffled_problem_for_debug'] = new_problem
                    
                    shuffled_inputs.append(new_example)
                else:
                    shuffled_inputs.append(example)
            else:
                shuffled_inputs.append(example)
        
        # Convert back to the format expected by the model  
        formatted_prompts = self._format_inputs_for_generation(shuffled_inputs, think_completions)
        return shuffled_inputs, formatted_prompts
    
    def _output_complete_debug_info(self, device, dual_reward_func):
        """Append a dual-reasoning debug report to the log file, then reset the debug state.

        One report is written per group of ``self.num_generations`` first-pass
        completions, always using the first entry of the group. A report lists
        the original problem, the first generation, the shuffled problem, the
        second generation and the reward breakdown; a section is left out when
        its data is not available at that index. The log file is ``$LOG_PATH``
        (default ``./debug_log_dual_reasoning.txt``) and is opened in append
        mode. Everything runs while holding the module-level ``_debug_lock``.

        Args:
            device: Not used by this method.
            dual_reward_func: Reward callable. Its ``_debug_infos`` list, when
                present, supplies the reward details and is cleared afterwards.
        """
        with _debug_lock:
            target_path = os.getenv("LOG_PATH", "./debug_log_dual_reasoning.txt")

            # Reward details recorded by the reward function (empty when it kept none)
            reward_records = getattr(dual_reward_func, '_debug_infos', [])

            # Completions are stored flat, `group_size` consecutive entries per sample
            group_size = self.num_generations
            state = self._dual_reasoning_debug_info
            first_pass = state.get('first_generations', [])
            report_count = len(first_pass) // group_size

            shuffled_examples = state.get('shuffled_inputs_with_mapping', [])
            second_pass = state.get('second_completions', [])

            with open(target_path, "a", encoding="utf-8") as log_file:
                emit = log_file.write
                for report_no in range(report_count):
                    # Every section of a report reads the first entry of the sample's group
                    offset = report_no * group_size

                    emit(f"======== DUAL REASONING DEBUG (Sample {report_no + 1}) ========\n")

                    # Section 1: the unshuffled problem text
                    emit(f"1. ORIGINAL PROBLEM:\n{state['original_problem']}\n\n")

                    # Section 2: first-pass completion
                    if offset < len(first_pass):
                        emit(f"2. FIRST GENERATION (Reasoning + Answer):\n{first_pass[offset]}\n\n")

                    # Section 3: problem text after the options were shuffled
                    if (
                        offset < len(shuffled_examples)
                        and '_shuffled_problem_for_debug' in shuffled_examples[offset]
                    ):
                        emit(f"3. SHUFFLED PROBLEM:\n{shuffled_examples[offset]['_shuffled_problem_for_debug']}\n\n")

                    # Section 4: second-pass completion produced from the shuffled problem
                    if offset < len(second_pass):
                        emit(f"4. SECOND GENERATION (Shuffled Options + Reasoning + Answer):\n{second_pass[offset]}\n\n")

                    # Section 5: reward breakdown, written line by line
                    if offset < len(reward_records):
                        record = reward_records[offset]
                        emit("5. REWARD CALCULATION:\n")
                        emit(f"Original Answer: {record['original_answer']}\n")
                        emit(f"Shuffled Answer: {record['shuffled_answer']} -> Mapped: {record['mapped_shuffled_answer']}\n")
                        emit(f"Ground Truth: {record['gt_answer']}\n")
                        emit(f"Consistent: {record['answers_consistent']}, Original Correct: {record['original_correct']}, Shuffled Correct: {record['shuffled_correct']}\n")
                        emit(f"Final Reward: {record['reward']}\n")

                    emit("========================================\n\n")

            # Drop everything that was just logged so the next batch starts clean
            if hasattr(dual_reward_func, '_debug_infos'):
                dual_reward_func._debug_infos.clear()
            if hasattr(self, '_dual_reasoning_debug_info'):
                del self._dual_reasoning_debug_info
    
    def _format_inputs_for_generation(self, inputs, think_completions):
        """Format inputs for generation with thinking completions."""
        # This needs to return the same format as the original prompt_inputs
        # We need to reconstruct the prompts with shuffled options and add thinking part
        
        prompts = []
        for i, example in enumerate(inputs):
            # Get the thinking part
            thinking_part = think_completions[i] if i < len(think_completions) else ""
            
            # Create content list with proper media type
            content = []
            
            # Add media content (video or image)
            if example['data_type'] == 'video':
                content.append({
                    "type": "video",
                })
            elif example['data_type'] == 'image':
                content.append({
                    "type": "image",
                })
            
            # Add text content with thinking and problem  
            prompt_text = f"{example['problem']}\nThinking: <think>{thinking_part}</think>\nBased on the above reasoning, please provide your final answer as a single letter (A, B, C, or D) between <answer> and </answer> tags:"
            

            
            content.append({
                "type": "text",
                "text": prompt_text
            })
            
            # Create new conversation format with thinking part
            new_prompt = [{
                "role": "user",
                "content": content
            }]
            prompts.append(new_prompt)
        
        return prompts

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
    
        # Define device early to avoid UnboundLocalError
        device = self.accelerator.device
        
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

                
        
        input_copy = copy.deepcopy(inputs[0]['prompt'])
        
        input_copy = self.remove_none_from_data(input_copy)
        
        if inputs[0]['data_type'] == 'image':
            input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:] 
        elif inputs[0]['data_type'] == 'video':
            input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:] 
            
        try:
            image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
        except Exception as e:
            print(f"process_vision_info error, using fixed data, {e}")
            if inputs[0]['data_type'] == 'image':
                input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + '/Math/Multimath-300k/17ff4c7d14c388134de02381b1fc2824.png'
            elif inputs[0]['data_type'] == 'video':
                input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + '/LLaVA-Video-178K/liwei_youtube_videos/videos/youtube_video_2024/ytb_7nRmsEw7nsE.mp4'
                
            image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
        
        
        prompt_inputs = self.processing_class(
            text=copy.deepcopy(prompts_text),
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        
        
        prompt_inputs = super()._prepare_inputs(prompt_inputs)


        # fix prompt_inputs["input_ids"] length issue
        if self.max_prompt_length is not None:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]

        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
            
        if self.temporal and video_inputs:
            indices = torch.randperm(video_inputs[0].size(0))
            shuffled_video_inputs = [video_inputs[0][indices]]
            temporal_shuffled_prompt_inputs = self.processing_class(
                text=copy.deepcopy(prompts_text),
                images=image_inputs,
                videos=shuffled_video_inputs,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                add_special_tokens=False,
            )
            temporal_shuffled_prompt_inputs = super()._prepare_inputs(temporal_shuffled_prompt_inputs)
            shuffled_prompt_ids, shuffled_prompt_mask = temporal_shuffled_prompt_inputs["input_ids"], temporal_shuffled_prompt_inputs["attention_mask"]
            if self.max_prompt_length is not None:
                shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
                shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
        
        
        # Generate completions
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            if self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
                # For dual reasoning, we generate twice:
                # 1. First generate normal completion
                # 2. Then generate with shuffled options using extracted thinking
                

                
                # First normal generation
                prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
                prompt_length = prompt_ids.size(1)
                prompt_ids = prompt_completion_ids[:, :prompt_length]
                completion_ids = prompt_completion_ids[:, prompt_length:]
                prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
                
                # Decode first completions and extract thinking parts
                first_completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
                think_parts = []
                for completion in first_completions:
                    think_match = re.search(r'<think>(.*?)</think>', completion, re.DOTALL)
                    if think_match:
                        think_parts.append(think_match.group(1).strip())
                    else:
                        think_parts.append("")
                
                # Store first generation results for later logging
                self._dual_reasoning_debug_info = {
                    'original_problem': inputs[0]['problem'],
                    'first_generations': first_completions.copy()
                }
                
                # Generate second completions with shuffled options
                # For each thinking part, we need to create a shuffled version of the original input
                shuffled_inputs_with_mapping = []
                shuffled_prompt_inputs = []
                
                # Create shuffled inputs for each thinking part
                # Fix: Each think_part corresponds to one original completion, 
                # so we need to create one shuffled input per think_part, not per input example
                for i, think_part in enumerate(think_parts):
                    # Use only the first input example since we're generating for one sample at a time
                    # In GRPO, typically batch size is 1 for generation
                    input_example = inputs[0] if len(inputs) > 0 else inputs[i % len(inputs)]
                    shuffled_input_with_mapping, shuffled_prompt_input = self._prepare_shuffled_inputs([input_example], [think_part])
                    shuffled_inputs_with_mapping.extend(shuffled_input_with_mapping)
                    shuffled_prompt_inputs.extend(shuffled_prompt_input)
                
                # Process shuffled inputs to get tensors
                shuffled_prompts_text = []
                for shuffled_input in shuffled_prompt_inputs:
                    shuffled_prompt_text = maybe_apply_chat_template({"prompt": shuffled_input}, self.processing_class)["prompt"]
                    shuffled_prompts_text.append(shuffled_prompt_text)
                
                # Repeat image/video inputs to match the number of shuffled text inputs
                num_shuffled_texts = len(shuffled_prompts_text)
                shuffled_image_inputs = None
                shuffled_video_inputs = None
                
                if image_inputs is not None:
                    # Repeat each image for all shuffled texts
                    shuffled_image_inputs = image_inputs * num_shuffled_texts
                
                if video_inputs is not None:
                    # Repeat each video for all shuffled texts  
                    shuffled_video_inputs = video_inputs * num_shuffled_texts
                
                # Create shuffled prompt tensor inputs
                shuffled_tensor_inputs = self.processing_class(
                    text=shuffled_prompts_text,
                    images=shuffled_image_inputs,
                    videos=shuffled_video_inputs,
                    return_tensors="pt",
                    padding=True,
                    padding_side="left",
                    add_special_tokens=False,
                )
                shuffled_tensor_inputs = super()._prepare_inputs(shuffled_tensor_inputs)
                
                # Create generation config for dual reasoning (only one generation per prompt)
                dual_reasoning_generation_config = GenerationConfig(
                    max_new_tokens=self.max_completion_length,
                    do_sample=True,
                    top_p=0.95,  
                    temperature=1,
                    num_return_sequences=1,  # Only one generation per prompt
                    pad_token_id=self.processing_class.pad_token_id,
                )
                

                
                # Generate second completions
                second_completion_ids = unwrapped_model.generate(**shuffled_tensor_inputs, generation_config=dual_reasoning_generation_config)
                second_completion_only = second_completion_ids[:, shuffled_tensor_inputs["input_ids"].size(1):]
                
                # Store second generation for later complete debug output
                second_completions = self.processing_class.batch_decode(second_completion_only, skip_special_tokens=True)
                self._dual_reasoning_debug_info['second_completions'] = second_completions.copy()
                self._dual_reasoning_debug_info['shuffled_inputs_with_mapping'] = shuffled_inputs_with_mapping.copy()
                
                # For second completion, we want the full sequence: 
                # prompt + thinking + answer (already generated by the model)
                # So second_completion_ids already contains what we want
                
                # Pad completions to same length before concatenating
                max_length = max(completion_ids.size(1), second_completion_only.size(1))
                
                # Pad first completions
                if completion_ids.size(1) < max_length:
                    padding_length = max_length - completion_ids.size(1)
                    completion_ids = torch.cat([
                        completion_ids, 
                        torch.full((completion_ids.size(0), padding_length), 
                                 self.processing_class.pad_token_id, 
                                 device=device, dtype=completion_ids.dtype)
                    ], dim=1)
                
                # Pad  completions  
                if second_completion_only.size(1) < max_length:
                    padding_length = max_length - second_completion_only.size(1)
                    second_completion_only = torch.cat([
                        second_completion_only,
                        torch.full((second_completion_only.size(0), padding_length),
                                 self.processing_class.pad_token_id,
                                 device=device, dtype=second_completion_only.dtype)
                    ], dim=1)
                
                # Don't concatenate shuffled completions to completion_ids for training
                # Only original completions should be used for training
                # Store shuffled completions separately for reward calculation
                self._dual_reasoning_debug_info['shuffled_completion_ids'] = second_completion_only
                
                # For dual reasoning, we need to handle prompt_completion_ids properly
                # The first generation already created prompt_completion_ids
                # The second generation created second_completion_ids (full sequence)
                # We need to ensure they have the same sequence length
                
                # Ensure both prompt_completion_ids and second_completion_ids have same length
                max_total_length = max(prompt_completion_ids.size(1), second_completion_ids.size(1))
                
                # Pad prompt_completion_ids if needed
                if prompt_completion_ids.size(1) < max_total_length:
                    padding_length = max_total_length - prompt_completion_ids.size(1)
                    prompt_completion_ids = torch.cat([
                        prompt_completion_ids,
                        torch.full((prompt_completion_ids.size(0), padding_length),
                                 self.processing_class.pad_token_id,
                                 device=device, dtype=prompt_completion_ids.dtype)
                    ], dim=1)
                
                # Pad second_completion_ids if needed
                if second_completion_ids.size(1) < max_total_length:
                    padding_length = max_total_length - second_completion_ids.size(1)
                    second_completion_ids = torch.cat([
                        second_completion_ids,
                        torch.full((second_completion_ids.size(0), padding_length),
                                 self.processing_class.pad_token_id,
                                 device=device, dtype=second_completion_ids.dtype)
                    ], dim=1)
                
                # Don't concatenate shuffled sequences to prompt_completion_ids for training
                # Only original completions should be used for training
                # Store shuffled sequences separately
                self._dual_reasoning_debug_info['shuffled_prompt_completion_ids'] = second_completion_ids
                
            else:
                # Normal generation
                prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
                prompt_length = prompt_ids.size(1)
                prompt_ids = prompt_completion_ids[:, :prompt_length]
                completion_ids = prompt_completion_ids[:, prompt_length:]
                prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
            
            if self.temporal:
                
                if video_inputs:
            
                    shuffled_prompt_completion_ids = unwrapped_model.generate(**temporal_shuffled_prompt_inputs, generation_config=self.shuffled_generation_config)
                    shuffled_prompt_length = shuffled_prompt_ids.size(1)
                    shuffled_prompt_ids = shuffled_prompt_completion_ids[:, :shuffled_prompt_length]
                    shuffled_completion_ids = shuffled_prompt_completion_ids[:, shuffled_prompt_length:]
                    shuffled_prompt_mask = prompt_mask.repeat_interleave(self.shuffled_num_generations, dim=0)
                    
                else:
                    
                    shuffled_prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.dummy_generation_config)

        
        print('path:', input_copy[0]['content'][0][inputs[0]['data_type']])   
        print('problem_id:', inputs[0]['problem_id'])       
        print('prompt_length:', prompt_length)
                
        
        
        
        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
        # pixel_values = prompt_inputs["pixel_values"].repeat(self.num_generations, 1)
        # image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0)
        

        
        prompt_inputs.pop("input_ids")
        prompt_inputs.pop("attention_mask")
        
        if inputs[0]['data_type'] == 'image':
            prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
            prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
        # import pdb; pdb.set_trace()
        

        if inputs[0]['data_type'] == 'video':
            prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
            prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
            if 'second_per_grid_ts' in prompt_inputs:
                del prompt_inputs["second_per_grid_ts"]
                # prompt_inputs["second_per_grid_ts"] = torch.tensor(prompt_inputs["second_per_grid_ts"]).repeat(len(prompt_completion_ids), 1)
        
        
        
        
        try:
            per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
            per_token_logps = per_token_logps[:, prompt_length - 1 :]
            
            # In dual reasoning, ensure completion_mask matches per_token_logps dimensions
            if self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
                # Extract completion part from prompt_completion_ids for mask calculation
                completion_from_prompt = prompt_completion_ids[:, prompt_length:]
                
                # Recalculate completion_mask based on the same tensor used for per_token_logps
                is_eos_prompt = completion_from_prompt == self.processing_class.eos_token_id
                eos_idx_prompt = torch.full((is_eos_prompt.size(0),), is_eos_prompt.size(1), dtype=torch.long, device=device)
                eos_idx_prompt[is_eos_prompt.any(dim=1)] = is_eos_prompt.int().argmax(dim=1)[is_eos_prompt.any(dim=1)]
                sequence_indices_prompt = torch.arange(is_eos_prompt.size(1), device=device).expand(is_eos_prompt.size(0), -1)
                completion_mask = (sequence_indices_prompt <= eos_idx_prompt.unsqueeze(1)).int()
                
        except Exception as e:
            print(f"Error computing per_token_logps: {e}. Setting output to zero.")
            # per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device, requires_grad=True)
            per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
        
        with torch.inference_mode():
            try:
                if self.ref_model is not None:
                    ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
                else:
                    with self.accelerator.unwrap_model(model).disable_adapter():
                        ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
                ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
            except Exception as e:
                print(f"Error computing ref_per_token_logps: {e}. Setting output to zero.")
                # ref_per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device)
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
                ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

        # Compute the KL divergence between the model and the reference model
        
        x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10)  # 限制 x 的范围
        per_token_kl = torch.exp(x_clamped) - x_clamped - 1
        
        # Fix: Don't run temporal shuffled reward calculation in dual reasoning mode
        # as dual reasoning already handles its own shuffled completions
        if self.temporal and video_inputs and not (self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice'):
            shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
            if is_conversational(inputs[0]):
                shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
                
            # Compute the rewards
            shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
            shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
            for i, (reward_func, reward_processing_class) in enumerate(
                zip(self.reward_funcs, self.reward_processing_classes)
            ):
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in shuffled_reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column for `num_generations` times
                        shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
                shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
                shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)

        
        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]
            
        # Compute the rewards
        if self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
            # For dual reasoning, we need to combine original and shuffled completions for reward calculation
            # but only train on original completions
            shuffled_completion_ids = self._dual_reasoning_debug_info.get('shuffled_completion_ids', torch.tensor([]))
            shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
            if is_conversational(inputs[0]):
                shuffled_completions = [[{"role": "assistant", "content": completion}] for completion in shuffled_completions]
            
            # Combine for reward calculation: [original_completions + shuffled_completions]
            combined_completions = completions + shuffled_completions
            
            # Find dual_reasoning reward function
            dual_reward_func = None
            for reward_func in self.reward_funcs:
                if hasattr(reward_func, '__name__') and reward_func.__name__ == 'dual_reasoning_reward':
                    dual_reward_func = reward_func
                    break
            
            if dual_reward_func is not None:
                # Use dual reasoning reward
                # We pass combined_completions (original + shuffled) but only get rewards for original
                prompts = [prompt for prompt in prompts for _ in range(self.num_generations * 2)]
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Extend for both original and shuffled completions
                        reward_kwargs[key].extend([example[key]] * (self.num_generations * 2))
                
                # Add mapping information for dual reasoning if available
                if 'shuffled_inputs_with_mapping' in locals():
                    # Add mapping information from shuffled inputs
                    # For dual reasoning: first num_generations are original completions, 
                    # second num_generations are shuffled completions
                    reward_kwargs['shuffled_mappings'] = []
                    reward_kwargs['original_questions'] = []
                    reward_kwargs['shuffled_questions'] = []
                    
                    # First num_generations (original completions) - no mapping needed
                    for i in range(self.num_generations):
                        reward_kwargs['shuffled_mappings'].append({})
                        # Add original question (all original generations use the same question)
                        if len(inputs) > 0:
                            reward_kwargs['original_questions'].append(inputs[0]['problem'])
                        else:
                            reward_kwargs['original_questions'].append("")
                        reward_kwargs['shuffled_questions'].append("")  # No shuffled question for original completions
                    
                    # Second num_generations (shuffled completions) - add reverse mapping to convert shuffled answer back to original
                    for i in range(self.num_generations):
                        if i < len(shuffled_inputs_with_mapping) and 'reverse_mapping' in shuffled_inputs_with_mapping[i]:
                            reward_kwargs['shuffled_mappings'].append(shuffled_inputs_with_mapping[i]['reverse_mapping'])
                        else:
                            reward_kwargs['shuffled_mappings'].append({})
                        reward_kwargs['original_questions'].append("")  # No original question for shuffled completions
                        # Add shuffled question
                        if i < len(shuffled_inputs_with_mapping):
                            reward_kwargs['shuffled_questions'].append(shuffled_inputs_with_mapping[i]['problem'])
                        else:
                            reward_kwargs['shuffled_questions'].append("")
                    
                    # Add dual reasoning reward list parameter (with progressive support)
                    reward_kwargs['dual_reasoning_reward_list'] = self._get_current_dual_reasoning_reward_list()
                    

                
                # Call dual reasoning reward function with combined completions
                output_reward_func = dual_reward_func(prompts=prompts, completions=combined_completions, **reward_kwargs)
                
                # Output complete debug info immediately after reward calculation
                if os.getenv("DEBUG_MODE") == "true" and hasattr(self, '_dual_reasoning_debug_info'):
                    self._output_complete_debug_info(device, dual_reward_func)
                
                # dual_reasoning_reward function returns num_generations rewards
                # (one reward for each original completion only)
                # Shuffled completions are only used for consistency check, not training
                
                rewards_per_func = torch.tensor(output_reward_func, dtype=torch.float32, device=device).unsqueeze(1)
            else:
                # Fallback to normal reward if dual_reasoning reward not found
                prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
                rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
                for i, (reward_func, reward_processing_class) in enumerate(
                    zip(self.reward_funcs, self.reward_processing_classes)
                ):
                    reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                    for key in reward_kwargs:
                        for example in inputs:
                            reward_kwargs[key].extend([example[key]] * self.num_generations)
                    # Add dual reasoning reward list parameter for dual_reasoning reward function
                    if hasattr(reward_func, '__name__') and reward_func.__name__ == 'dual_reasoning_reward':
                        reward_kwargs['dual_reasoning_reward_list'] = self._get_current_dual_reasoning_reward_list()
                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
        else:
            # Normal reward computation
            prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
            rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
            for i, (reward_func, reward_processing_class) in enumerate(
                zip(self.reward_funcs, self.reward_processing_classes)
            ):
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column for `num_generations` times
                        reward_kwargs[key].extend([example[key]] * self.num_generations)
                # Add dual reasoning reward list parameter for dual_reasoning reward function
                if hasattr(reward_func, '__name__') and reward_func.__name__ == 'dual_reasoning_reward':
                    reward_kwargs['dual_reasoning_reward_list'] = self._get_current_dual_reasoning_reward_list()
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
        

        
        
        if self.temporal and video_inputs:
            temporal_rewards_per_func = rewards_per_func.clone()
            
            acc_mean = temporal_rewards_per_func[:, 0].mean()
            
            # Fix: Handle temporal logic for dual reasoning mode
            if self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
                # In dual reasoning mode, we don't have separate shuffled rewards for temporal comparison
                # Instead, we can compare the original vs shuffled completions within the dual reasoning rewards
                # For now, use a different baseline or skip temporal enhancement in dual reasoning mode
                temporal_rewards = torch.tensor([0.5]).to('cuda')
            else:
                # Original temporal logic for non-dual reasoning mode
                shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
                if acc_mean >= 0.8 * shuffled_acc_mean:
                    mask = temporal_rewards_per_func[:, 0] > 0.1
                    temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
                    temporal_rewards = torch.tensor([1.0]).to('cuda')
                else:
                    temporal_rewards = torch.tensor([0.0]).to('cuda')
        else:
            temporal_rewards =  torch.tensor([0.5]).to('cuda')
        
        # Sum the rewards from all reward functions
        if self.temporal and video_inputs and not (self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice'):
            rewards = temporal_rewards_per_func.sum(dim=1)
        else:
            rewards = rewards_per_func.sum(dim=1)
    
        
        if self.len_control:
            mem_rewards = [0] * self.num_generations
            mask = rewards_per_func[:, 0] > 0.1
            lenth_list = completion_mask.sum(1)
            selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
            #             if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
            # if len(selected_indices) > 1:
            #     selected_items = [(i, lenth_list[i]) for i in selected_indices]
            #     sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
            #     N = len(sorted_items)
            #     for rank, (idx, length) in enumerate(sorted_items):
            #         reward = 0.2 - 0.2 * (rank / N)
            #         rewards[idx] += reward
            #         mem_rewards[idx] = reward
            # for idx in range(len(lenth_list)):
            #     if lenth_list[idx] >= 512:
            #         rewards[idx] -= 0.5
                    
            if len(selected_indices) > 1:     
                for idx in selected_indices:
                    if 320 <= lenth_list[idx] <= 512:
                        rewards[idx] += 0.2
        
        print(rewards)
        print(completion_mask.sum(1))

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
        
        # if self.len_control and len(selected_indices) == self.num_generations:
        #     for idx in range(len(rewards)):
        #         advantages[idx] += (mem_rewards[idx] - 0.2) * 2

        # x - x.detach() allows for preserving gradients from x
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        # per_token_loss = -per_token_loss
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Add KL consistency loss if enabled and in dual reasoning mode
        if self.use_kl_check and self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
            # Extract the generated reasoning from completions
            # Use the first completions (before shuffling) as they contain the thinking
            if hasattr(self, '_dual_reasoning_debug_info') and 'first_generations' in self._dual_reasoning_debug_info:
                generated_reasoning = self._dual_reasoning_debug_info['first_generations']

                # Compute KL consistency loss
                kl_consistency_loss = self.compute_kl_consistency_loss(
                    model=model,
                    batch_inputs=inputs,
                    generated_reasoning=generated_reasoning
                )

                # Add to total loss with lambda coefficient
                loss = loss + self.kl_lambda * kl_consistency_loss

                # Log the KL consistency loss
                self._metrics["kl_consistency_loss"].append(kl_consistency_loss.item())


        # import pdb
        # pdb.set_trace()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        
        # Handle dual reasoning case where we only have one effective reward function
        if self.dual_reasoning and inputs[0]['problem_type'] == 'multiple choice':
            # In dual reasoning mode, we only use dual_reasoning reward function
            # So reward_per_func has shape (1,) but self.reward_funcs might have multiple functions
            dual_reward_name = "dual_reasoning"
            self._metrics[f"rewards/{dual_reward_name}"].append(reward_per_func[0].item())
        else:
            # Normal case: iterate through all reward functions
            for i, reward_func in enumerate(self.reward_funcs):
                if i < reward_per_func.size(0):  # Safety check to avoid index error
                    if isinstance(reward_func, PreTrainedModel):
                        reward_func_name = reward_func.config._name_or_path.split("/")[-1]
                    else:
                        reward_func_name = reward_func.__name__
                    self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
        
        gathered_rewards = self.accelerator.gather_for_metrics(rewards)
        
        num_devices = gathered_rewards.size(0) // self.num_generations 
        rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
        wrong_devices = (rewards_per_device <= 1).all(dim=1)
        wrong_ratio = wrong_devices.sum().item() / num_devices
        
        correct_devices = (rewards_per_device >= 2).all(dim=1)
        correct_ratio = correct_devices.sum().item() / num_devices
        
        self._metrics["all_wrong"].append(wrong_ratio)
        self._metrics["all_correct"].append(correct_ratio)
        
        if self.temporal:
            temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
            self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
        
        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        

        return loss

    def compute_kl_consistency_loss(self, model, batch_inputs, generated_reasoning, temperature=1.0):
        """
        Compute a KL consistency loss between answer distributions for original and shuffled option orders.

        For every multiple-choice sample, the generated reasoning is appended to the prompt twice: once with
        the options in their original order and once with the options shuffled. The last-position logits of
        the option-letter tokens are read for both prompts, the shuffled ones are mapped back to the original
        option order, and KL(original || shuffled) is computed.

        The original-order distribution is a fixed target (computed under `torch.no_grad()`), while the
        shuffled-order forward pass keeps gradients so that the returned loss can actually influence training.

        Args:
            model: The policy model. It is called as `model(**inputs)` and must return an object with `.logits`.
            batch_inputs: Batch examples. Each one is read for `problem_type` and `options`, and is forwarded
                to `_prepare_prompt_with_reasoning`.
            generated_reasoning: Reasoning strings aligned by index with `batch_inputs`. A missing entry is
                treated as an empty string.
            temperature: Temperature applied to the answer logits before the softmax.

        Returns:
            A scalar tensor: the mean KL divergence over the multiple-choice samples, or a zero tensor on the
            accelerator device when the batch contains no such sample.
        """
        import torch.nn.functional as F

        device = self.accelerator.device
        kl_losses = []

        for idx, inputs in enumerate(batch_inputs):
            # Only multiple-choice problems with explicit options can be shuffled
            if inputs['problem_type'] != 'multiple choice' or 'options' not in inputs:
                continue

            reasoning = generated_reasoning[idx] if idx < len(generated_reasoning) else ""

            # Keep only the content of the <think> tags when they are present
            think_match = re.search(r'<think>(.*?)</think>', reasoning, re.DOTALL)
            thinking_part = think_match.group(1).strip() if think_match else reasoning

            # `_prepare_prompt_with_reasoning` always returns a (prompt, mapping) pair; the mapping is None here.
            original_prompt, _ = self._prepare_prompt_with_reasoning(inputs, thinking_part, shuffle=False)

            # The original-order pass is the fixed target, so no gradient is needed. It is run before the
            # shuffled prompt is built so that it cannot be affected by prompt edits made for that variant.
            with torch.no_grad():
                original_inputs = self._process_single_prompt(original_prompt)
                original_logits = model(**original_inputs).logits
                original_answer_logits = self._extract_answer_logits(original_logits, inputs['options'])
                original_probs = F.softmax(original_answer_logits / temperature, dim=-1)

            shuffled_prompt, shuffle_mapping = self._prepare_prompt_with_reasoning(
                inputs, thinking_part, shuffle=True
            )

            # Gradients are kept for this pass: with both passes under `no_grad` the loss would be a constant
            # and adding it to the training loss would have no effect on the model.
            shuffled_inputs = self._process_single_prompt(shuffled_prompt)
            shuffled_logits = model(**shuffled_inputs).logits
            shuffled_answer_logits = self._extract_answer_logits(shuffled_logits, inputs['options'])

            # Bring the shuffled answer logits back into the original option order before comparing
            aligned_shuffled_logits = self._reorder_logits(shuffled_answer_logits, shuffle_mapping)
            # `log_softmax` is numerically safer than `log(softmax(x) + eps)`
            shuffled_log_probs = F.log_softmax(aligned_shuffled_logits / temperature, dim=-1)

            # `F.kl_div` expects log-probabilities as input and probabilities as target
            kl_losses.append(F.kl_div(shuffled_log_probs, original_probs, reduction='batchmean'))

        if kl_losses:
            return torch.stack(kl_losses).mean()

        # No multiple-choice problem in the batch
        return torch.tensor(0.0, device=device)

    def _prepare_prompt_with_reasoning(self, inputs, reasoning, shuffle=False):
        """Prepare prompt with reasoning and optionally shuffle options."""
        import random

        prompt = inputs['prompt'].copy()

        if shuffle and 'options' in inputs:
            # Create derangement: shuffle the options ensuring no option stays in its original position
            def is_derangement(perm, n):
                """Check if permutation is a derangement (no element in its original position)"""
                for i in range(n):
                    if perm[i] == i:
                        return False
                return True

            def generate_derangement(n):
                """Generate a random derangement of n elements"""
                if n == 1:
                    # No derangement possible for n=1
                    return [0]

                max_attempts = 1000  # Prevent infinite loop
                for _ in range(max_attempts):
                    perm = list(range(n))
                    random.shuffle(perm)
                    if is_derangement(perm, n):
                        return perm

                # Fallback: construct a derangement deterministically
                # Simple rotation works as derangement for n>1
                return list(range(1, n)) + [0]

            # Shuffle options using derangement
            options = inputs['options'].copy()
            original_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'][:len(options)]
            shuffled_indices = generate_derangement(len(options))

            # Create shuffle mapping
            shuffle_mapping = {i: shuffled_indices[i] for i in range(len(options))}

            # Replace options in prompt
            shuffled_options = [options[i] for i in shuffled_indices]

            # Modify prompt with shuffled options
            problem_text = inputs['problem']
            for i, (orig_letter, shuffled_option) in enumerate(zip(original_letters, shuffled_options)):
                problem_text = problem_text.replace(f"{orig_letter}. {options[i]}", f"{orig_letter}. {shuffled_option}")

            # Update prompt with shuffled problem
            prompt[0]['content'][0]['text'] = problem_text

            # Add reasoning to prompt
            if reasoning:
                prompt.append({
                    'role': 'assistant',
                    'content': [{'type': 'text', 'text': f"<think>{reasoning}</think>\n<answer>"}]
                })

            return prompt, shuffle_mapping
        else:
            # Add reasoning to original prompt
            if reasoning:
                prompt.append({
                    'role': 'assistant',
                    'content': [{'type': 'text', 'text': f"<think>{reasoning}</think>\n<answer>"}]
                })

            return prompt, None

    def _process_single_prompt(self, prompt):
        """
        Render a single conversational prompt and convert it into model inputs.

        The chat template is applied through `maybe_apply_chat_template`. Image or video inputs are extracted
        with `process_vision_info` when the first content entry of the first message has an `image` or
        `video` key. The rendered text and any media are then passed to `self.processing_class` (left
        padded, no extra special tokens) and the result is handed to the parent class' `_prepare_inputs`.

        Args:
            prompt: List of chat messages.

        Returns:
            Whatever the parent class' `_prepare_inputs` returns for the processed batch of one prompt.
        """
        rendered = maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"]

        images = None
        videos = None

        # Only the first content entry of the first message is probed for media
        has_first_entry = len(prompt) > 0 and 'content' in prompt[0] and len(prompt[0]['content']) > 0
        if has_first_entry:
            first_entry = prompt[0]['content'][0]
            if 'image' in first_entry:
                images, _, _ = process_vision_info([prompt[0]], return_video_kwargs=True)
            elif 'video' in first_entry:
                _, videos, _ = process_vision_info([prompt[0]], return_video_kwargs=True)

        encoded = self.processing_class(
            text=[rendered],
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        return super()._prepare_inputs(encoded)

    def _extract_answer_logits(self, logits, options):
        """Extract logits for answer tokens (A, B, C, D, etc.)."""
        # Get the last token's logits (where the answer should be)
        last_logits = logits[:, -1, :]  # Shape: [batch_size, vocab_size]

        # Get token IDs for option letters
        option_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'][:len(options)]
        option_token_ids = [self.processing_class.encode(letter, add_special_tokens=False)[0] for letter in option_letters]

        # Extract logits for these specific tokens
        answer_logits = last_logits[:, option_token_ids]  # Shape: [batch_size, num_options]

        return answer_logits

    def _reorder_logits(self, logits, shuffle_mapping):
        """Reorder shuffled logits according to the shuffle mapping."""
        if shuffle_mapping is None:
            return logits

        # Create inverse mapping
        inverse_mapping = {v: k for k, v in shuffle_mapping.items()}

        # Reorder logits
        reordered = torch.zeros_like(logits)
        for orig_idx, shuffled_idx in shuffle_mapping.items():
            reordered[:, orig_idx] = logits[:, shuffled_idx]

        return reordered

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged accumulated metrics into `logs`, forward them to the parent logger, and reset.

        Args:
            logs: Values to log; accumulated metrics with the same key take precedence.
            start_time: Forwarded to the parent `log` when the installed `transformers` version accepts it.
        """
        # Average every metric collected since the previous call
        averaged = {}
        for name, values in self._metrics.items():
            averaged[name] = sum(values) / len(values)

        merged = {**logs, **averaged}

        accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if accepts_start_time:
            super().log(merged, start_time)
        else:  # transformers<=4.46
            super().log(merged)

        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on processes other
        than the world process zero.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. A list passed here is not modified.
        """
        if not self.is_world_process_zero():
            return

        # A local directory path is not a meaningful base model identifier
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Normalize into a fresh list so that appending below never mutates the caller's list
        if not tags:
            tags = []
        elif isinstance(tags, str):
            tags = [tags]
        else:
            tags = list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year         = 2024,
                eprint       = {arXiv:2402.03300},
            }
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
