import warnings
from typing import Callable, Optional, Union, Any, List

from accelerate.utils import broadcast_object_list, gather, gather_object
from datasets import Dataset, IterableDataset
from peft import PeftConfig
import torch
from torch import nn
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available
)

from vllm import LLM, SamplingParams

from open_r1.envs.environment import Environment
from open_r1.inference.vllm_client import VLLMClient
from open_r1.utils.logging_utils import print_prompt_completions_sample

# Monkey patch vllm client BEFORE importing trl
import trl.extras.vllm_client
trl.extras.vllm_client.VLLMClient = VLLMClient

# Now import trl after monkey patching
from trl import GRPOTrainer, GRPOConfig
from trl.data_utils import maybe_apply_chat_template, apply_chat_template, is_conversational
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_rich_available
from trl.trainer.utils import pad, selective_log_softmax
from trl.models import unwrap_model_for_generation

if is_wandb_available():
    import wandb


# A reward function scores a batch: given a list of prompts and a list of completions it returns one
# float per completion. A `str` entry is a model ID (presumably loaded as a pretrained reward model by the
# parent GRPOTrainer), and a `PreTrainedModel` entry is an already-loaded reward model.
# NOTE(review): MultiTurnGRPOTrainer._generate_and_score_completions calls every entry directly as
# `reward_func(prompts=..., completions=..., **kwargs)`, so model-based entries may not work there — confirm.
RewardFunc = Union[
    str,
    PreTrainedModel,
    Callable[[list, list], list[float]],
]


# PyTorch ships no `nanstd`, so a NaN-aware sample standard deviation is provided here.
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return the sample standard deviation of a 1D tensor, skipping NaN entries.

    Bessel's correction (divide by ``n - 1``) is applied, where ``n`` is the number of non-NaN values.
    With a single non-NaN value the correction divides by zero and the result is NaN.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`; may contain NaNs.

    Returns:
        `torch.Tensor`:
            Scalar tensor holding the NaN-ignoring standard deviation.
    """
    n_valid = (~torch.isnan(tensor)).sum()
    deviations = tensor - torch.nanmean(tensor, keepdim=True)
    var = deviations.square().nanmean()
    # In-place so the result keeps the dtype of the (biased) variance.
    var *= n_valid / (n_valid - 1)
    return var.sqrt()


class MultiTurnGRPOTrainer(GRPOTrainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
    This class extends the GRPOTrainer to support multi-turn conversations, which is crucial for our use case to enable the model
    to first reason on reverse-augmented questions before answering the original question.

    Example:
    """
    def __init__(
            self,
            model: Union[str, PreTrainedModel],
            env: Environment,
            reward_funcs: Union[RewardFunc, list[RewardFunc]],
            scale_rewards: bool = False,
            args: Optional[GRPOConfig] = None,
            train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
            eval_dataset: Optional[Union[Dataset, IterableDataset]] = None,
            processing_class: Optional[PreTrainedTokenizerBase] = None,
            callbacks: Optional[list[TrainerCallback]] = None,
            optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
            peft_config: Optional["PeftConfig"] = None,
            **kwargs,
    ):
        """
        Build a multi-turn GRPO trainer on top of `GRPOTrainer`.

        Args:
            model: Model (or model ID) forwarded to `GRPOTrainer`.
            env: Environment stored as `self.env`; used to run the multi-turn rollouts.
            reward_funcs: Reward function(s) forwarded to `GRPOTrainer`.
            scale_rewards: Stored as `self.scale_rewards`.
            args: Training config. Required, and `args.use_vllm` must be True. Two optional extra fields are
                read: `max_completion_length_each_turn` (per-turn vLLM token budget; -1 or absent means
                `max_completion_length`) and `sft_weight` (weight of the SFT loss; absent means 0.0).
            train_dataset, eval_dataset, processing_class, callbacks, optimizers, peft_config:
                Forwarded unchanged to `GRPOTrainer`.
            **kwargs: Forwarded to `GRPOTrainer`, except `verbose` (default False), which is consumed here
                and stored as `self.verbose`.

        Raises:
            ValueError: if `args` is None or vLLM is not enabled.
        """
        self.vllm_client = None
        if args is None:
            raise ValueError("MultiTurnGRPOTrainer requires an explicit `args` config with `use_vllm=True`.")
        if not args.use_vllm:
            raise ValueError("vLLM must be enabled for MultiTurnGRPOTrainer!")

        # Consume `verbose` so it is not forwarded to GRPOTrainer.__init__. That method does not accept it
        # (true for the TRL versions this targets; verify on upgrade).
        verbose = kwargs.pop("verbose", False)

        super().__init__(
            model=model,
            reward_funcs=reward_funcs,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config,
            **kwargs,
        )
        self.env = env
        # NOTE(review): this class's `_generate_and_score_completions` decides on std-scaling via
        # `self.args.scale_rewards`, not this attribute. Confirm which one is intended.
        self.scale_rewards = scale_rewards

        # Per-turn generation budget; -1 falls back to the overall completion budget.
        self.max_completion_length_each_turn = getattr(args, "max_completion_length_each_turn", -1)
        if self.max_completion_length_each_turn == -1:
            self.max_completion_length_each_turn = self.max_completion_length
        self.sampling_params = SamplingParams(
            max_tokens=self.max_completion_length_each_turn,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=-1 if self.top_k is None else self.top_k,  # vLLM uses -1 for "disabled"
            min_p=0.0 if self.min_p is None else self.min_p,
            repetition_penalty=self.repetition_penalty
        )
        self.verbose = verbose
        self.sft_weight = getattr(args, "sft_weight", 0.0)

    @profiling_decorator
    def _generate_and_score_completions(
        self, inputs: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Roll out multi-turn episodes with vLLM, score them, and compute group-relative advantages.

        [Data Format]
        Despite the annotation (kept identical to GRPOTrainer's), `inputs` is iterated as the local
        (per-process) list of examples. Each example dict must contain:
        - "prompt": the conversational prompt (list of message dicts) that starts the episode.
        - "metadata": per-example metadata forwarded to `self.env`.
        - "rule_indices": per-example rule indices forwarded to `self.env`.
        Every other key (e.g. answer/question columns) is forwarded unchanged as a keyword argument to
        each reward function.

        The batch holds `num_generations` consecutive copies of each prompt, which are used for the
        group-wise advantage. Each process must hold whole groups.

        Not currently considered (for MCQA problems): "options", "answer_type", "answer_options".

        Returns:
            dict with:
            - "prompt_ids" and "prompt_mask".
            - "completion_ids" and "completion_mask".
            - "old_per_token_logps" and "ref_per_token_logps" (None when not needed).
            - This process's "advantages".
            - The local "metadata".
            - This process's slice of "off_policy_results": one ground-truth multi-turn result per
              unique prompt, used by the SFT loss.

        Raises:
            NotImplementedError: if `args.use_vllm` is False.
            ValueError: if the local batch size is not a multiple of `num_generations`.
        """
        device = self.accelerator.device

        if not self.args.use_vllm:
            # Multi-turn rollouts are only implemented through vLLM. The former HF-generate fallback never
            # produced `completion_mask` / `completion_messages` and always crashed with a NameError later.
            raise NotImplementedError(
                "MultiTurnGRPOTrainer only supports multi-turn generation through vLLM; set `use_vllm=True`."
            )

        prompts = [x["prompt"] for x in inputs]
        metadata = [x["metadata"] for x in inputs]
        rule_indices = [x["rule_indices"] for x in inputs]

        # Off-policy results are produced once per unique prompt and sliced per process below. That only
        # lines up if every process holds whole groups of `num_generations` duplicates. This is an explicit
        # raise (not `assert`) so the check survives `python -O`.
        if len(prompts) % self.num_generations != 0:
            raise ValueError(
                f"The per-process batch size ({len(prompts)}) must be a multiple of num_generations "
                f"({self.num_generations}) so that duplicated prompts stay on the same process."
            )

        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Push the current policy weights to vLLM at most once per optimizer step.
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Gather the global batch; the main process drives the environment for every prompt in it.
        all_prompts = gather_object(prompts)
        all_metadata = gather_object(metadata)
        all_rule_indices = gather_object(rule_indices)
        self.accelerator.wait_for_everyone()

        # Prompts come in runs of `num_generations` duplicates, so every `num_generations`-th entry is unique.
        ordered_set_of_prompts = all_prompts[:: self.num_generations]
        if self.accelerator.is_main_process:
            # On-policy: one multi-turn rollout per (duplicated) prompt, used for the group-wise advantage.
            env_result = self.env.generate(
                prompts=all_prompts,
                metadata=all_metadata,
                rule_indices=all_rule_indices,
                llm=self.vllm_client,
                sampling_params=self.sampling_params,
            )
            # Off-policy (ground-truth) multi-turn results for the SFT loss: one per unique prompt suffices.
            off_policy_results = self.env.generate_off_policy(
                prompts=ordered_set_of_prompts,
                metadata=all_metadata[:: self.num_generations],
                rule_indices=all_rule_indices[:: self.num_generations],
            )
            completion_ids = env_result["ids"]
            completion_messages = env_result["messages"]
            completion_mask = env_result["mask"]
        else:
            # Correctly sized placeholders, overwritten by the broadcast from the main process below.
            completion_ids = [None] * len(all_prompts)
            completion_messages = [None] * len(all_prompts)
            completion_mask = [None] * len(all_prompts)
            off_policy_results = [None] * len(ordered_set_of_prompts)

        # Share the main process's results with every process.
        self.accelerator.wait_for_everyone()
        completion_ids = broadcast_object_list(completion_ids, from_process=0)
        completion_messages = broadcast_object_list(completion_messages, from_process=0)
        completion_mask = broadcast_object_list(completion_mask, from_process=0)
        off_policy_results = broadcast_object_list(off_policy_results, from_process=0)

        # Keep this process's share: `len(prompts)` rollouts and the matching whole groups of off-policy results.
        process_index = self.accelerator.process_index
        local_batch_size = len(prompts)
        local_num_groups = local_batch_size // self.num_generations
        process_slice = slice(process_index * local_batch_size, (process_index + 1) * local_batch_size)
        process_slice_off_policy = slice(process_index * local_num_groups, (process_index + 1) * local_num_groups)

        completion_ids = completion_ids[process_slice]
        completion_messages = completion_messages[process_slice]
        completion_mask = completion_mask[process_slice]
        off_policy_results = off_policy_results[process_slice_off_policy]

        # Right-pad the variable-length completions into tensors and append them to the prompts.
        completion_ids = pad(
            [torch.tensor(ids, device=device) for ids in completion_ids],
            padding_value=self.processing_class.pad_token_id,
        )
        completion_mask = pad([torch.tensor(mask, device=device) for mask in completion_mask], padding_value=0)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
        logits_to_keep = completion_ids.size(1)  # only completion tokens need logits

        with torch.no_grad():
            # With num_iterations == 1, old_per_token_logps equals per_token_logps, so compute_loss uses
            # per_token_logps.detach() instead and the extra forward pass is skipped here.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Rewards are computed on the content of the LAST message of each rollout (the answer to the original question).
        completions = [messages[-1]["content"] for messages in completion_messages]

        # Extra dataset columns are forwarded to every reward function. They are built once (not per
        # function) so they are also defined for the NaN warning below.
        excluded_keys = {"prompt", "completion", "metadata", "rule_indices"}
        reward_kwargs = {
            key: [example[key] for example in inputs] for key in inputs[0] if key not in excluded_keys
        }

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, reward_func in enumerate(self.reward_funcs):
            output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
            # None means "not applicable to this sample"; store it as NaN so nan-aware reductions skip it.
            output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
            rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # If all reward functions return None for a given row, issue a detailed warning
        if torch.isnan(rewards_per_func).all(dim=1).any():
            nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
            row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
            row_reward_kwargs["completion"] = completions[nan_row_idx]
            warnings.warn(
                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
                "Please ensure that at least one reward function returns a valid reward."
            )

        # Gather rewards from all processes: groups are normalized globally and may span processes.
        rewards_per_func = gather(rewards_per_func)

        # Weighted sum over reward functions; NaN (not applicable) entries contribute nothing.
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

        # Group statistics over each run of `num_generations` rollouts of the same prompt.
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = rewards - mean_grouped_rewards

        # NOTE(review): this reads `self.args.scale_rewards`; the constructor's `scale_rewards` argument
        # (stored on `self.scale_rewards`) is not consulted here. Confirm which one is intended.
        if self.args.scale_rewards:
            # Divide by the group standard deviation; the epsilon guards zero-variance groups.
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Keep only this process's advantages.
        advantages = advantages[process_slice]

        # Log the metrics
        mode = "eval" if self.control.should_evaluate else "train"

        if mode == "train":
            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics[mode]["completion_length"].append(completion_length)

        # Per-function mean/std, ignoring samples where the function returned None (NaN).
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
            std_rewards = nanstd(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
        self._metrics[mode]["reward"].append(rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
            prompts_to_log = gather_object(prompts)
            completions_to_log = gather_object(completions)
            rewards_to_log = rewards.tolist()

            if self.accelerator.is_main_process:
                if is_rich_available():
                    # Show the first local rollout twice. First: the prompt's last message paired with the
                    # rollout's first message. Second: the message preceding the final one, paired with the
                    # scored final message.
                    first_episode = completion_messages[0]
                    print_prompt_completions_sample(
                        [str(prompts[0][-1]["content"])],
                        [str(first_episode[0]["content"])],
                        [rewards_to_log[0]],
                        self.state.global_step,
                    )
                    print_prompt_completions_sample(
                        [str(first_episode[-2]["content"])],
                        [str(completions_to_log[0])],
                        [rewards_to_log[0]],
                        self.state.global_step,
                    )
                if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                    import pandas as pd

                    table = {
                        "step": [str(self.state.global_step)] * len(rewards),
                        "prompt": prompts_to_log,
                        "completion": completions_to_log,
                        "reward": rewards_to_log,
                    }
                    df = pd.DataFrame(table)
                    wandb.log({"completions": wandb.Table(dataframe=df)})

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "metadata": metadata,
            "off_policy_results": off_policy_results
        }
    
    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the GRPO policy loss, optionally combined with a turn-level SFT loss.

        The returned loss is ``grpo_loss + self.sft_weight * sft_loss``. The SFT term is computed only when
        ``self.sft_weight > 0``, from the ground-truth turns in ``inputs["off_policy_results"]``.

        Args:
            model: The model being trained (possibly wrapped, e.g. by DDP — do not rely on `model.device`).
            inputs: The dict produced by `_generate_and_score_completions`.
            return_outputs: Must be False; returning outputs is not supported.
            num_items_in_batch: Unused; accepted for compatibility with `Trainer.compute_loss`.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if `return_outputs` is True.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Per-token KL estimate to the reference model: exp(r - p) - (r - p) - 1, where r is the reference
        # log-prob and p the policy log-prob.
        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        # Clipped surrogate objective. With num_iterations == 1, old_per_token_logps == per_token_logps, so
        # the generation step skips computing it and per_token_logps.detach() is used instead.
        advantages = inputs["advantages"]
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl
        grpo_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        # Turn-by-turn SFT loss on the off-policy (ground-truth) turns.
        sft_weight = self.sft_weight
        total_sft_loss = torch.tensor(0.0, device=grpo_loss.device)
        if sft_weight > 0.0:
            # Tensors are placed on the GRPO inputs' device: `model.device` does not exist on DDP-wrapped models.
            total_sft_loss = self._compute_sft_loss(model, inputs["off_policy_results"], input_ids.device)

        loss = grpo_loss + sft_weight * total_sft_loss

        # Log the metrics
        mode = "eval" if self.control.should_evaluate else "train"

        if self.beta != 0.0:
            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        is_clipped = (per_token_loss1 < per_token_loss2).float()
        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())

        # Log individual loss components
        self._metrics[mode]["grpo_loss"].append(grpo_loss.item())
        self._metrics[mode]["sft_loss"].append(total_sft_loss.item())
        self._metrics[mode]["total_loss"].append(loss.item())

        return loss

    def _compute_sft_loss(self, model, off_policy_results_batch, device):
        """
        Average the turn-level SFT cross-entropy over a batch of off-policy (ground-truth) episodes.

        Each element of `off_policy_results_batch` is a list of turn dicts with the keys "history" (list of
        message dicts), "user" and "assistant" (message contents).

        A turn is skipped when any of the following holds:
        - processing it raises (a warning is emitted);
        - no assistant tokens remain after truncation, or the assistant marker is not found;
        - its loss is NaN.

        Returns:
            Scalar tensor: the mean over samples of each sample's mean turn loss, or 0.0 when no turn produced
            a valid loss.
        """
        assistant_marker_ids = self._sft_assistant_marker_ids()
        loss_fct = nn.CrossEntropyLoss()  # label -100 is ignored (default ignore_index)
        sample_sft_losses = []
        for sample_turns in off_policy_results_batch:
            turn_losses = []
            for turn_data in sample_turns:
                # Construct messages for this turn (schema errors here are not swallowed).
                turn_messages = turn_data["history"] + [
                    {"role": "user", "content": turn_data["user"]},
                    {"role": "assistant", "content": turn_data["assistant"]},
                ]
                try:
                    turn_loss = self._sft_turn_loss(model, turn_messages, assistant_marker_ids, loss_fct, device)
                except Exception as e:  # best-effort: one bad turn must not abort the training step
                    warnings.warn(f"Error during SFT loss calculation for one turn: {e}")
                    continue
                if turn_loss is not None and not torch.isnan(turn_loss):  # avoid NaN propagation
                    turn_losses.append(turn_loss)
            if turn_losses:
                # Accumulate in float32 even if the model emits lower-precision losses.
                sample_sft_losses.append(torch.stack(turn_losses).float().mean())
        if not sample_sft_losses:
            return torch.tensor(0.0, device=device)
        return torch.stack(sample_sft_losses).mean()

    def _sft_assistant_marker_ids(self):
        """
        Return token-id sequences marking the start of an assistant reply in the chat template.

        Only the "<start_of_turn>model\\n" marker is tried (the marker used by Gemma-style templates).
        Markers the tokenizer fails to encode, or encodes to nothing, are dropped.
        """
        marker_ids_list = []
        for marker in ("<start_of_turn>model\n",):
            try:
                marker_ids = self.processing_class.encode(marker, add_special_tokens=False)
            except Exception:
                continue
            if marker_ids:
                marker_ids_list.append(marker_ids)
        return marker_ids_list

    def _sft_turn_loss(self, model, turn_messages, assistant_marker_ids, loss_fct, device):
        """
        Return the next-token cross-entropy on the final assistant reply of one turn.

        Returns None if the assistant marker is not found or no assistant tokens survive truncation.
        """
        # The chat template already inserts special tokens, so they are not added again when tokenizing.
        formatted_turn_string = self.processing_class.apply_chat_template(
            turn_messages, add_generation_prompt=False, tokenize=False
        )
        if self.max_prompt_length is not None and self.max_completion_length is not None:
            max_length = self.max_prompt_length + self.max_completion_length
        else:
            max_length = None  # no configured budget: do not truncate

        # Truncate from the left so the end of the turn (the assistant reply) is kept. The tokenizer
        # setting is always restored, even if tokenization raises.
        original_truncation_side = self.processing_class.truncation_side
        self.processing_class.truncation_side = "left"
        try:
            tokenized_output = self.processing_class(
                formatted_turn_string,
                return_tensors="pt",
                max_length=max_length,
                add_special_tokens=False,
                truncation=max_length is not None,
            )
        finally:
            self.processing_class.truncation_side = original_truncation_side
        turn_input_ids = tokenized_output["input_ids"].to(device)
        turn_attention_mask = tokenized_output["attention_mask"].to(device)

        # Supervise only tokens after the last assistant marker (like DataCollatorForCompletionOnlyLM).
        assistant_start_idx = self._find_last_marker_end(turn_input_ids[0].tolist(), assistant_marker_ids)
        if assistant_start_idx is None:
            return None
        labels = turn_input_ids.masked_fill(turn_attention_mask != 1, -100)
        labels[:, :assistant_start_idx] = -100
        if (labels != -100).sum() == 0:
            return None

        logits = model(input_ids=turn_input_ids, attention_mask=turn_attention_mask).logits
        # Shift so that token t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    @staticmethod
    def _find_last_marker_end(token_ids, marker_ids_list):
        """
        Return the index just past the last occurrence of a marker in `token_ids`, or None if none occurs.

        Markers are tried in order, and the first marker that occurs at all wins.
        """
        for marker_ids in marker_ids_list:
            marker_len = len(marker_ids)
            if marker_len == 0:
                continue
            for start in range(len(token_ids) - marker_len, -1, -1):
                if token_ids[start:start + marker_len] == marker_ids:
                    return start + marker_len
        return None

    @profiling_decorator
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        """
        Compute log-probabilities of the last ``logits_to_keep`` tokens of ``input_ids``.

        Workaround: the model is called WITHOUT the ``logits_to_keep`` keyword argument, because some
        model / transformers combinations reject it. Logits are therefore computed for every position
        and sliced afterwards. The result is the same, but it costs extra compute: models such as
        ``MistralForCausalLM`` can apply ``lm_head`` only to the kept positions when the argument is
        passed through (see the ``slice_indices`` logic in
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py).

        Args:
            model: Causal LM called as ``model(input_ids=..., attention_mask=...)``; its output must
                expose ``.logits``.
            input_ids (`torch.Tensor`): Token ids, indexed as ``(batch, seq_len)``.
            attention_mask (`torch.Tensor`): Attention mask, forwarded unchanged to ``model``.
            logits_to_keep (`int`): Number of trailing tokens (e.g. the completion) to score.

        Returns:
            `torch.Tensor`: Log-probabilities of each of the trailing ``logits_to_keep`` tokens under the
            temperature-scaled distribution ``softmax(logits / self.temperature)``, one value per token.
        """
        if self.accelerator.is_main_process and self.verbose:
            # Input sanity checks before the forward pass (verbose mode, main process only).
            print("--- Debug Info ---")
            print(f"input_ids shape: {input_ids.shape}, dtype: {input_ids.dtype}")
            print(f"attention_mask shape: {attention_mask.shape}, dtype: {attention_mask.dtype}")
            print(f"input_ids min/max: {input_ids.min().item()}, {input_ids.max().item()}")
            # attention_mask values should ideally be only 0 or 1.
            unique_vals = torch.unique(attention_mask)
            print(f"attention_mask unique values: {unique_vals}")
            print("--- End Debug Info ---")

        # Full forward pass without `logits_to_keep`; the unneeded positions are discarded below.
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        # The logit at position t predicts token t + 1. The last `logits_to_keep` tokens are therefore
        # scored by the `logits_to_keep` logits that precede the final one. The final logit (a prediction
        # past the end of the sequence) is dropped.
        logits = logits[:, -logits_to_keep - 1 : -1, :]

        # Slice from an explicit start index rather than `-logits_to_keep`, so that logits_to_keep == 0
        # yields an empty slice (matching the empty logits) instead of the whole sequence.
        target_ids = input_ids[:, input_ids.size(1) - logits_to_keep :]

        # Divide logits by the sampling temperature so log-probs match the sampling distribution.
        # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
        logits = logits / self.temperature
        return selective_log_softmax(logits, target_ids)
