from typing import Any, Callable, Optional, Union
import logging
import json
import os
import re
import math
import random
import copy

from collections import defaultdict, deque

from sympy import Q
from transformers.utils import (
    ModelOutput,
    is_peft_available,
    is_rich_available,
    is_torch_mlu_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)

if is_rich_available():
    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table
    from rich.text import Text

from datasets import Dataset, IterableDataset

from contextlib import nullcontext

import torch
import torch.distributed as dist  # Import PyTorch distributed module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)

from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available


from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.extras.vllm_client import VLLMClient
from trl.import_utils import is_liger_kernel_available, is_vllm_available
from trl.models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from trl.models.utils import _ForwardRedirection
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.utils import (
    empty_cache,
    disable_dropout_in_model,
    generate_model_card,
    get_comet_experiment_url,
    pad,
    print_prompt_completions_sample,
    selective_log_softmax,
)


if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

if is_wandb_available():
    import wandb

import wandb


from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd, truncate_with_protected_tokens

from .ref_config import RefGuidedVIConfig
from .utils import get_rank, is_main_process
from .templates import (
    batch_modify_user_messages,
    modify_user_message_for_reasoning,
    modify_system_message
)

# --- Custom logging filter ---
class RankFilter(logging.Filter):
    """Logging filter that tags each record with the current process rank.

    The rank is obtained from ``get_rank()`` and stored as ``record.rank`` so that
    formatters can reference ``%(rank)s``. No record is ever rejected.
    """

    def filter(self, record):
        setattr(record, "rank", get_rank())
        return True

# --- Module logger: configured on every process, no propagation to the root logger ---
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # minimum level handled by this logger in all processes
logger.propagate = False  # prevent duplicate lines via the root logger's handlers

# Remove handlers left over from an earlier import/reload so records are not emitted twice.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)

formatter = logging.Formatter(
    '%(asctime)s - Rank %(rank)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s'
)

# NOTE: the `if is_main_process():` guard that used to wrap this section is commented out, so
# handlers are attached on *every* rank: each rank writes its own log file and logs to the console.
# The file handler is opened as UTF-8 explicitly because messages contain emoji, which the
# platform-default encoding may not be able to represent.
file_handler = logging.FileHandler(f"trainer_rank_{get_rank()}_only.log", encoding="utf-8")
stream_handler = logging.StreamHandler()

for _handler in (file_handler, stream_handler):
    _handler.setLevel(logging.INFO)
    _handler.setFormatter(formatter)
    _handler.addFilter(RankFilter())  # provides %(rank)s for the formatter
    logger.addHandler(_handler)

# --- Print-like helper that routes output through the module logger ---
def rank0_print(*args, **kwargs):
    """
    Log ``args`` like ``print`` would, through ``logger.info``.

    Despite its name, this currently logs on *every* rank: the ``is_main_process()`` guard
    is commented out. The rank is added to each record by the handlers' ``RankFilter``.
    ``stacklevel=2`` makes the record's filename/lineno point at the caller rather than at
    this helper.

    Args:
        *args: Objects converted with ``str`` and joined by ``sep``.
        **kwargs: Only ``sep`` is honored (``None`` or missing means ``" "``, as in ``print``).
            Other ``print`` keywords such as ``end``, ``file`` or ``flush`` are accepted and
            ignored.
    """
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    message = "📝 " + sep.join(map(str, args))
    logger.info(message, stacklevel=2)

if is_peft_available():
    from peft import PeftConfig, get_peft_model


# A reward function is one of:
#   - a callable taking (prompts, completions) and returning a list of float rewards,
#   - a `PreTrainedModel`, or
#   - a string, which is treated as a model ID and loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]

class RefGuidedVITrainer(GRPOTrainer):
    """
    Trainer class for Variational Chain-of-Thought Training with reference-guided reasoning.
    
    Inherits from GRPOTrainer and implements the specific training logic for the
    variational inference framework with reference guidance.
    
    This version includes detailed logging to track the complete flow of samples.
    """
    
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: Optional[RefGuidedVIConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        """
        Build the trainer on top of ``GRPOTrainer`` and set up the extra state used by the
        reference-guided variational objective.

        All arguments are forwarded to ``GRPOTrainer.__init__``. ``args`` must be a
        ``RefGuidedVIConfig`` whose ``prob_model`` is ``"self"`` or ``"ref"``.

        Raises:
            ValueError: if ``args`` is None or ``args.prob_model`` is not ``"self"``/``"ref"``.

        Side effects on ``args``:
            - ``args.answer_prefix`` is resolved in place: ``"simple_prefix"`` (or the
              already-resolved ``"The answer is"``) becomes ``"The answer is"``; anything
              else becomes ``""``.
            - ``args.beta`` is changed only temporarily and is restored afterwards.
        """
        rank0_print("🔧 Initializing RefGuidedVITrainer with detailed logging")

        if args is None:
            raise ValueError("RefGuidedVITrainer requires `args` (a RefGuidedVIConfig); got None.")
        if args.prob_model not in ["self", "ref"]:
            raise ValueError("prob_model must be either 'self' or 'ref'.")

        # prob_model == "ref" needs a reference model even when beta == 0.0. A non-zero beta is
        # used to make GRPOTrainer create ref_model; the real beta is restored right after.
        self.real_beta = None
        if args.prob_model == "ref" and args.beta == 0.0:
            self.real_beta = args.beta  # i.e., 0.0
            args.beta = 1e-6

        try:
            super().__init__(
                model=model,
                reward_funcs=reward_funcs,
                args=args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                processing_class=processing_class,
                reward_processing_classes=reward_processing_classes,
                callbacks=callbacks,
                optimizers=optimizers,
                peft_config=peft_config,
            )
        finally:
            # Undo the temporary bump on the caller's config object. Otherwise, reusing `args` for
            # another trainer would see beta=1e-6 and skip the real_beta path entirely.
            if self.real_beta is not None:
                args.beta = self.real_beta

        if self.real_beta is not None:
            self.beta = self.real_beta  # back to the original value now that ref_model exists
        rank0_print(f"🔧 RefGuidedVITrainer initialized with prob_model={self.args.prob_model} and beta={self.beta}")

        # Extra per-sample logging buffers for the P and Q rollouts, sized like the parent's.
        maxlen = self._logs["prompt"].maxlen
        self._logs.update({
            # P part
            "p(y,z|x)_prompt": deque(maxlen=maxlen),
            "p(y,z|x)_completion": deque(maxlen=maxlen),
            "p(y,z|x)_rewards": defaultdict(lambda: deque(maxlen=maxlen)),
            "p(y,z|x)_advantages": deque(maxlen=maxlen),
            # Q part
            "q(z|x,y)_prompt": deque(maxlen=maxlen),
            "q(z|x,y)_completion": deque(maxlen=maxlen),
            "q(z|x,y)_rewards": defaultdict(lambda: deque(maxlen=maxlen)),
            "q(z|x,y)_advantages": deque(maxlen=maxlen),
        })

        # Streaming datasets (e.g. datasets.IterableDataset) have no __len__.
        try:
            num_train_samples = len(train_dataset) if train_dataset is not None else 0
        except TypeError:
            num_train_samples = "unknown (iterable dataset)"

        rank0_print("✅ RefGuidedVITrainer initialized successfully")
        rank0_print(f"📊 Dataset info: {num_train_samples} training samples")
        rank0_print(f"🎯 Reward functions: {self.reward_func_names}")
        # Report the class actually in use (the parent may have loaded one when None was passed).
        rank0_print(f"🔧 Processing class: {type(self.processing_class).__name__}")

        # ---- Special-tag token ids used to build and parse reasoning trajectories ----
        self.think_start_tag = "<think>"
        think_start_tokens = self.processing_class(self.think_start_tag, add_special_tokens=False)["input_ids"]
        self.think_tag_start_ids = torch.tensor(think_start_tokens, dtype=torch.long, device=self.accelerator.device)
        rank0_print(f"🏷️{self.think_start_tag=} -> {self.think_tag_start_ids=}")

        self.think_end_tag = "</think>"
        think_end_tokens = self.processing_class(self.think_end_tag, add_special_tokens=False)["input_ids"]
        self.think_tag_end_ids = torch.tensor(think_end_tokens, dtype=torch.long, device=self.accelerator.device)
        rank0_print(f"🏷️{self.think_end_tag=} -> {self.think_tag_end_ids=}")

        # Resolve the answer-prefix option. The already-resolved value is accepted as well, so that
        # reusing the same config object for a second trainer keeps the prefix instead of dropping it.
        simple_prefix = "The answer is"
        self.args.answer_prefix = (
            simple_prefix if self.args.answer_prefix in ("simple_prefix", simple_prefix) else ""
        )
        self.synthetic_answer_prefix = f"</think>\n\n{self.args.answer_prefix}"
        synthetic_answer_prefix_tokens = self.processing_class(self.synthetic_answer_prefix, add_special_tokens=False)["input_ids"]
        self.synthetic_answer_prefix_ids = torch.tensor(synthetic_answer_prefix_tokens, dtype=torch.long, device=self.accelerator.device)
        rank0_print(f"🏷️{self.synthetic_answer_prefix=} -> {self.synthetic_answer_prefix_ids=}")
    
    def _prepare_q_messages(self, inputs: list[dict[str, Union[torch.Tensor, Any]]]):
        """
        Produce the Q-side view of each example.

        Every example is first passed through ``modify_system_message`` (system prompt). The
        results are then passed through ``modify_user_message_for_reasoning`` (user message,
        i.e. question and answer). Both passes run over the whole list, in that order.
        """
        with_system = [modify_system_message(example) for example in inputs]
        return [modify_user_message_for_reasoning(example) for example in with_system]
    
    def _prepare_p_messages(self, inputs: list[dict[str, Union[torch.Tensor, Any]]]):
        """
        Produce the P-side view of each example.

        Only the system prompt is changed, via ``modify_system_message``. User messages are
        left as they are.
        """
        return list(map(modify_system_message, inputs))

    def _generate_completions(
        self,
        inputs,
    ) -> dict[str, Any]:
        """
        Generate completions for a batch of examples and precompute loss-side tensors.

        Each example is rendered with ``maybe_apply_chat_template`` and ``"<think>\\n"`` is
        appended, so generation starts inside a thinking block. Completions come from vLLM
        (``server`` or ``colocate`` mode), transformers paged generation, or regular
        ``generate``.

        Args:
            inputs: List of examples, each with a ``"prompt"`` entry accepted by
                ``maybe_apply_chat_template``.

        Returns:
            dict with keys:

            - ``"prompts"``: the ``x["prompt"]`` objects of ``inputs``. For conversational
              inputs, a trailing assistant message is popped off in place (see below).
            - ``"prompt_ids"`` / ``"prompt_mask"``: left-padded, possibly truncated, prompt
              tokens.
            - ``"prompts_text"``: the rendered prompt strings that were sent to generation.
            - ``"completions_text"``: completions decoded with ``skip_special_tokens=True``.
            - ``"completions"``: conversational-format completions, or the plain strings.
            - ``"completion_ids"`` / ``"completion_mask"``: completion tokens and a mask that
              is 1 up to and including the first EOS. The mask is zeroed for truncated
              completions when ``mask_truncated_completions`` is set.
            - ``"completion_ids_list"``: per-sample token-id lists up to and including the
              first EOS (computed before truncation masking).
            - ``"completion_lengths"``: tokens up to and including the first EOS, per sample.
            - ``"attention_mask"``: ``cat([prompt_mask, completion_mask], dim=1)``.
            - ``"is_eos"``: boolean tensor marking EOS positions in ``completion_ids``.
            - ``"old_per_token_logps"``: current-policy log-probs, or None (see below).
            - ``"ref_per_token_logps"``: reference log-probs when ``beta != 0``, else None.
        """
        mode = "train" if self.model.training else "eval"
        device = self.accelerator.device

        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

        # Open the reasoning block in the prompt so every completion starts inside <think>.
        prompts_text = [f"{prompt}{self.think_start_tag}\n" for prompt in prompts_text]

        prompt_inputs = self.processing_class(
            text=prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        # Call transformers.Trainer._prepare_inputs directly, skipping GRPOTrainer's override.
        prompt_inputs = super(GRPOTrainer, self)._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            # Keep only the last `max_prompt_length` tokens (via TRL's truncate_with_protected_tokens).
            # Then decode back to text for the vLLM / paged paths. Leading pad tokens are stripped
            # manually, because skip_special_tokens=True would also drop special tokens that
            # generation still needs.
            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
            protected = [token for token in protected if token is not None]
            prompt_ids, prompt_mask = truncate_with_protected_tokens(
                prompt_ids, prompt_mask, self.max_prompt_length, protected
            )

            prompts_text = self.processing_class.batch_decode(
                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )
            prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]
            # NOTE: collapsing of repeated image tokens is intentionally not done here; no images
            # are passed to generation below.

        if self.use_vllm:
            # Push the latest policy weights into vLLM at most once per optimizer step.
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            if self.vllm_mode == "server":
                # Gather every process's prompts, generate once on the main process, then broadcast.
                all_prompts_text = gather_object(prompts_text)

                if self.accelerator.is_main_process:
                    # Each prompt appears `num_generations` times in a row. Generate
                    # `num_generations` samples per unique prompt instead of one per duplicate.
                    ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                    ordered_set_of_images = None

                    with profiling_context(self, "vLLM.generate"):
                        completion_ids = self.vllm_client.generate(
                            prompts=ordered_set_of_prompts,
                            images=ordered_set_of_images,
                            n=self.num_generations,
                            repetition_penalty=self.repetition_penalty,
                            temperature=self.temperature,
                            top_p=self.top_p,
                            top_k=-1 if self.top_k is None else self.top_k,
                            min_p=0.0 if self.min_p is None else self.min_p,
                            max_tokens=self.max_completion_length,
                            guided_decoding_regex=self.guided_decoding_regex,
                            generation_kwargs=self.args.generation_kwargs,
                        )
                else:
                    completion_ids = [None] * len(all_prompts_text)
                # Broadcast from the main process; every process keeps its own contiguous slice.
                completion_ids = broadcast_object_list(completion_ids, from_process=0)
                process_slice = slice(
                    self.accelerator.process_index * len(prompts),
                    (self.accelerator.process_index + 1) * len(prompts),
                )
                completion_ids = completion_ids[process_slice]

            elif self.vllm_mode == "colocate":
                # Each device holds its own vLLM engine and generates for its own prompts.
                if self.guided_decoding_regex:
                    guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
                else:
                    guided_decoding = None

                generation_kwargs = {
                    "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                    "repetition_penalty": self.repetition_penalty,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": -1 if self.top_k is None else self.top_k,
                    "min_p": 0.0 if self.min_p is None else self.min_p,
                    "max_tokens": self.max_completion_length,
                    "guided_decoding": guided_decoding,
                }
                if self.args.generation_kwargs is not None:
                    generation_kwargs.update(self.args.generation_kwargs)
                sampling_params = SamplingParams(**generation_kwargs)

                if self.vllm_tensor_parallel_size > 1:
                    # Gather the prompts of all ranks in this TP group. Every rank then generates
                    # for the whole group and keeps only its own share afterwards.
                    orig_size = len(prompts_text)
                    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
                else:
                    all_prompts_text = prompts_text

                with profiling_context(self, "vLLM.generate"):
                    all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=True)

                completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
                del all_outputs

                if self.vllm_tensor_parallel_size > 1:
                    # Keep this rank's slice of the group-wide outputs.
                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                    completion_ids = completion_ids[tp_slice]

            # Pad the completions and concatenate them with the prompts.
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        elif self.use_transformers_paged:
            # Paged generation consumes un-padded token lists, so re-tokenize the prompt texts.
            paged_prompt_inputs = self.processing_class(text=prompts_text)
            previous_attn = self.model_wrapped.config._attn_implementation

            if is_flash_attn_2_available():
                self.model_wrapped.config._attn_implementation = "paged_attention"
            else:
                self.model_wrapped.config._attn_implementation = "sdpa_paged"
            with (
                profiling_context(self, "transformers.generate_batch"),
                unwrap_model_for_generation(
                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model,
                torch.no_grad(),
                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                # Cast to the dtype selected by the training configuration.
                if self.args.bf16:
                    unwrapped_model.to(torch.bfloat16)
                elif self.args.fp16:
                    unwrapped_model.to(torch.float16)
                with torch.inference_mode():
                    all_outputs = unwrapped_model.generate_batch(
                        paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
                    )
            completion_ids = [output.generated_tokens for output in all_outputs.values()]
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
            prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
            # Restore the original attention implementation.
            self.model_wrapped.config._attn_implementation = previous_attn
        else:
            # Regular transformers generation.
            with (
                profiling_context(self, "transformers.generate"),
                unwrap_model_for_generation(
                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model,
                torch.no_grad(),
                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
                prompt_completion_ids = unwrapped_model.generate(
                    **prompt_inputs, generation_config=self.generation_config, disable_compile=True
                )
            # Split the returned sequences back into prompt and completion parts.
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

        # Mask everything after the first EOS token (the EOS itself stays unmasked).
        is_eos = completion_ids == self.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Token-id lists passed to reward functions, so they need not re-tokenize. Move the tensors
        # to the host once, instead of syncing with one `.item()` call per token.
        completion_ids_cpu = completion_ids.cpu()
        completion_mask_cpu = completion_mask.bool().cpu()
        completion_ids_list = [
            row[row_mask].tolist() for row, row_mask in zip(completion_ids_cpu, completion_mask_cpu)
        ]

        # Completion length per sequence (up to and including the first EOS), used for logging.
        completion_lengths = completion_mask.sum(1)

        # Optionally exclude completions that never produced an EOS (truncated) from the loss.
        if self.mask_truncated_completions:
            truncated_completions = ~is_eos.any(dim=1)
            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        logits_to_keep = completion_ids.size(1)  # logits are only needed for completion tokens
        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

        with torch.no_grad():
            # If generation does not happen at the end of a full optimizer step (i.e.
            # gradient_accumulation_steps is not a multiple of generate_every), samples may come
            # from an older policy. In that case, keep old_per_token_logps for importance
            # sampling; otherwise it is not needed and stays None.
            generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
            if self.args.gradient_accumulation_steps % generate_every != 0:
                old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size,
                )
            else:
                old_per_token_logps = None

            # Reference-model log-probs (only needed when the KL term is active).
            if self.beta != 0.0:
                if self.ref_model is not None:
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.ref_model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=batch_size,
                    )
                else:
                    # With PEFT, the base model with adapters disabled acts as the reference.
                    with self.accelerator.unwrap_model(self.model).disable_adapter():
                        ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                            self.model,
                            prompt_completion_ids,
                            attention_mask,
                            logits_to_keep,
                            batch_size=batch_size,
                        )
            else:
                ref_per_token_logps = None

        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                # NOTE: `prompt.pop()` removes a trailing assistant "bootstrap" message from the
                # caller's prompt list in place, and prepends its content to the completion.
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text

        return {
            "prompts": prompts,
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "prompts_text": prompts_text,
            "completions_text": completions_text,
            "completions": completions,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "completion_ids_list": completion_ids_list,
            "completion_lengths": completion_lengths,
            "attention_mask": attention_mask,
            "is_eos": is_eos,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps
        }
    
    def _build_synthetic_enhanced_reasoning_trajectory(self, q_staffs: dict, raw_inputs: list[dict]):
        '''
        Build "enhanced" synthetic completions from the Q rollouts.

        For each Q completion, the tokens *after* the first ``</think>`` tag are taken as
        ``reasoning_tokens``. They run up to the last valid token, with a trailing EOS
        removed. A new completion is then assembled:

            {reasoning_tokens}</think>\\n\\n{answer_prefix} \\boxed{gt}.<eos>

        ``reasoning_tokens`` is empty when ``</think>`` does not occur, or when the row has no
        valid tokens (e.g. a truncated completion whose mask was zeroed).

        Args:
            q_staffs: Q rollout results. Requires ``completion_ids`` (B, L) and
                ``completion_mask`` (B, L).
            raw_inputs: Raw samples, indexed 0..B-1. Each must provide
                ``["reward_model"]["ground_truth"]``.

        Returns:
            ``{"input_ids": LongTensor[B, L'], "attention_mask": BoolTensor[B, L']}``.
            ``input_ids`` is right-padded with ``pad_token_id``. ``attention_mask`` is True
            exactly on each row's real tokens.
        '''
        device = self.accelerator.device
        completion_ids: torch.Tensor = q_staffs["completion_ids"].to(device)          # (B, L)
        completion_mask: torch.Tensor = q_staffs["completion_mask"].to(device).bool() # (B, L)
        batch_size = completion_ids.size(0)

        # Start index (first token) of the first </think> occurrence per row; -1 when absent.
        think_end_start_pos = self._find_first_subsequence(completion_ids, self.think_tag_end_ids)  # (B,)
        think_end_len = self.think_tag_end_ids.shape[0]

        # One host transfer per tensor, instead of per-row `.item()` / `.tolist()` calls.
        ids_rows = completion_ids.tolist()
        valid_lens = completion_mask.int().sum(dim=1).tolist()  # 0 when the mask row is all zero
        end_positions = think_end_start_pos.tolist()
        think_end_tokens = self.think_tag_end_ids.tolist()
        eos_token_id = self.processing_class.eos_token_id

        # Tokenize all answer suffixes in one batched call. Without padding, each entry is the
        # same as tokenizing that text on its own.
        answer_texts = []
        for i in range(batch_size):
            gt = raw_inputs[i]['reward_model']['ground_truth']
            answer_texts.append(f"\n\n{self.args.answer_prefix} \\boxed{{{gt}}}.")
        answer_ids_batch = (
            self.processing_class(answer_texts, add_special_tokens=False)["input_ids"] if answer_texts else []
        )

        synthesized_list: list[torch.Tensor] = []
        for ids_row, valid_len, end_pos, answer_ids in zip(ids_rows, valid_lens, end_positions, answer_ids_batch):
            # Drop a trailing EOS; a fresh one is appended below.
            if valid_len > 0 and ids_row[valid_len - 1] == self.eos_token_id:
                valid_len -= 1
            if end_pos == -1:
                reasoning_start = valid_len  # no </think>: empty reasoning
            else:
                reasoning_start = min(end_pos + think_end_len, valid_len)
            reasoning_tokens = ids_row[reasoning_start:valid_len]
            new_ids = reasoning_tokens + think_end_tokens + list(answer_ids) + [eos_token_id]
            synthesized_list.append(torch.tensor(new_ids, dtype=torch.long, device=device))

        padded_ids = pad(
            tensors=synthesized_list,
            padding_value=self.processing_class.pad_token_id,
            padding_side="right",
        )
        # Derive the mask from the true sequence lengths. Comparing against pad_token_id would
        # also mask real tokens equal to the pad id (e.g. the appended EOS when pad == eos).
        lengths = torch.tensor([seq.size(0) for seq in synthesized_list], device=device)
        positions = torch.arange(padded_ids.size(1), device=device)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        return {
            "input_ids": padded_ids,
            "attention_mask": attention_mask
        }

    def _build_synthetic_naive_reasoning_trajectory(self, p_staffs: dict, raw_inputs: list[dict]):
        r'''
        Build "P reasoning + ground-truth answer" trajectories from the P rollout completions.

        For each row, the valid span is the masked-in prefix of the completion, with a trailing
        eos removed if present. The answer text is f"\n\n{answer_prefix} \boxed{gt}.".
          1. If </think> is found and fits inside the valid span: keep everything up to and
             including the first </think>, drop the rest, then append the answer text and <eos>.
          2. Otherwise: replace the last len(suffix) valid tokens with
             suffix = </think> + answer text + <eos>. When the valid span is shorter than the
             suffix (e.g. a fully masked, truncated completion), the result is just the suffix.
        Sequences are then right-padded to the longest one.

        Args:
            p_staffs: P rollout dict with "completion_ids" (B, L) and "completion_mask" (B, L).
            raw_inputs: raw samples; raw_inputs[i]['reward_model']['ground_truth'] is the answer.

        Return: {"input_ids": Tensor[B, L'], "attention_mask": BoolTensor[B, L']}. The mask
        covers exactly the tokens of each built sequence, including the final <eos>.
        '''
        device = self.accelerator.device
        completion_ids: torch.Tensor = p_staffs["completion_ids"].to(device)          # (B, L)
        completion_mask: torch.Tensor = p_staffs["completion_mask"].to(device).bool() # (B, L)
        batch_size = completion_ids.size(0)

        # Start position of the first </think> in each row, -1 if absent.
        think_end_pos = self._find_first_subsequence(completion_ids, self.think_tag_end_ids)  # (B,)
        think_end_ids = self.think_tag_end_ids.tolist()
        think_end_len = len(think_end_ids)
        eos_id = self.processing_class.eos_token_id

        new_sequences: list[torch.Tensor] = []
        for i in range(batch_size):
            row_ids = completion_ids[i]
            # valid_len can be 0 if the completion is truncated (fully masked)
            valid_len = int(completion_mask[i].sum().item())
            # Remove trailing eos if present
            if valid_len > 0 and row_ids[valid_len - 1].item() == self.eos_token_id:
                valid_len -= 1

            gt = raw_inputs[i]['reward_model']['ground_truth']
            answer_text = f"\n\n{self.args.answer_prefix} \\boxed{{{gt}}}."
            answer_ids = self.processing_class(answer_text, add_special_tokens=False, return_tensors='pt')['input_ids'][0].tolist()

            pos = think_end_pos[i].item()
            if pos != -1 and pos + think_end_len <= valid_len:
                # Case 1: keep the reasoning up to and including </think>.
                new_ids = row_ids[: pos + think_end_len].tolist() + answer_ids + [eos_id]
            else:
                # Case 2: overwrite the tail of the valid span with </think> + answer + eos.
                suffix_ids = think_end_ids + answer_ids + [eos_id]
                # Clamp at 0: a negative stop index would slice from the END of the padded
                # row and drag unmasked/padding tokens into the trajectory.
                keep_len = max(valid_len - len(suffix_ids), 0)
                new_ids = row_ids[:keep_len].tolist() + suffix_ids

            new_sequences.append(torch.tensor(new_ids, dtype=torch.long, device=device))

        padded_ids = pad(
            tensors=new_sequences,
            padding_value=self.processing_class.pad_token_id,
            padding_side="right",
        )
        # Build the mask from true lengths rather than `!= pad_token_id`: if pad == eos the
        # latter would mask out the appended eos that downstream code treats as the last token.
        lengths = torch.tensor([seq.numel() for seq in new_sequences], device=padded_ids.device)
        positions = torch.arange(padded_ids.size(1), device=padded_ids.device).unsqueeze(0)
        attention_mask = positions < lengths.unsqueeze(1)

        return {"input_ids": padded_ids, "attention_mask": attention_mask}

    def _calculate_advantages(
        self,
        rewards: torch.Tensor,
        prompts: list[dict[str, Any]],
    ):
        '''
        Compute advantages.
        '''
        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
        is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = rewards - mean_grouped_rewards
        if self.scale_rewards:
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
        advantages = advantages[process_slice]

        return {
            "mean_grouped_rewards": mean_grouped_rewards,
            "std_grouped_rewards": std_grouped_rewards,
            "is_std_zero": is_std_zero,
            "all_process_advantages": all_process_advantages,
            "advantages": advantages,
            "process_slice": process_slice,
        }

    def _log_base_metrics(
        self,
        variables: dict,
        prefix: str,
    ):
        '''
        Record token counts, completion-length stats, reward stats and sample texts.

        Every entry is keyed as f"{prefix}_..." in ``self._metrics[mode]`` / ``self._logs``,
        so several rollout types can be logged side by side.

        Args:
            variables: rollout dict providing "attention_mask", "completion_lengths", "is_eos",
                "rewards_per_func", "prompts_text", "completions_text",
                "all_process_advantages", "mean_grouped_rewards", "std_grouped_rewards"
                and "is_std_zero".
            prefix: metric-name prefix, e.g. "p(y,z|x)".
        '''
        mode = "train" if self.model.training else "eval"
        (
            attention_mask,
            completion_lengths,
            is_eos,
            rewards_per_func,
            prompts_text,
            completions_text,
            all_process_advantages,
            mean_grouped_rewards,
            std_grouped_rewards,
            is_std_zero,
        ) = (
            variables[name]
            for name in (
                "attention_mask",
                "completion_lengths",
                "is_eos",
                "rewards_per_func",
                "prompts_text",
                "completions_text",
                "all_process_advantages",
                "mean_grouped_rewards",
                "std_grouped_rewards",
                "is_std_zero",
            )
        )

        # Running token counter, only advanced in training (gather is a collective call).
        if mode == "train":
            tokens_this_step = self.accelerator.gather(attention_mask.sum()).sum().item()
            self.state.num_input_tokens_seen += tokens_this_step
        metrics = self._metrics[mode]
        metrics[f"{prefix}_num_tokens"] = [self.state.num_input_tokens_seen]

        # Completion lengths across all processes.
        lengths = self.accelerator.gather(completion_lengths).float()
        for stat_name in ("mean", "min", "max"):
            metrics[f"{prefix}_completions/{stat_name}_length"] = [getattr(lengths, stat_name)().item()]

        # Lengths of the sequences that emitted EOS; the rest count as clipped.
        eos_terminated = self.accelerator.gather(is_eos.any(dim=1))
        terminated_lengths = lengths[eos_terminated]
        metrics[f"{prefix}_completions/clipped_ratio"].append(1 - len(terminated_lengths) / len(lengths))
        if len(terminated_lengths) == 0:
            # No sequence terminated: report zeros instead of reducing an empty tensor.
            terminated_lengths = torch.zeros(1, device=self.accelerator.device)
        for stat_name in ("mean", "min", "max"):
            metrics[f"{prefix}_completions/{stat_name}_terminated_length"].append(
                getattr(terminated_lengths, stat_name)().item()
            )

        # Per-function reward stats; NaN entries (function not applied) are ignored.
        for col, func_name in enumerate(self.reward_func_names):
            column = rewards_per_func[:, col]
            metrics[f"{prefix}_rewards/{func_name}/mean"].append(torch.nanmean(column).item())
            metrics[f"{prefix}_rewards/{func_name}/std"].append(nanstd(column).item())
        metrics[f"{prefix}_reward"].append(mean_grouped_rewards.mean().item())
        metrics[f"{prefix}_reward_std"].append(std_grouped_rewards.mean().item())
        metrics[f"{prefix}_frac_reward_zero_std"].append(is_std_zero.float().mean().item())

        # Texts, per-function rewards and advantages for the completion tables.
        self._logs[f"{prefix}_prompt"].extend(gather_object(prompts_text))
        self._logs[f"{prefix}_completion"].extend(gather_object(completions_text))
        for col, func_name in enumerate(self.reward_func_names):
            self._logs[f"{prefix}_rewards"][func_name].extend(rewards_per_func[:, col].tolist())
        self._logs[f"{prefix}_advantages"].extend(all_process_advantages.tolist())
    # override
    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        '''
        Generate P and Q rollouts, score them, and assemble the tensors needed by the loss.

        Steps:
          1. Build P prompts (``_prepare_p_messages``) and Q prompts (``_prepare_q_messages``).
          2. Generate completions for both.
          3. P reward: answer_accuracy when thinking_format_correct == 1.0, otherwise
             ``args.format_wrong_reward``.
          4./5. Build synthetic trajectories ending in the ground-truth answer, from the Q
             completions ("enhanced") and from the P completions ("naive").
          6. Score both with ``_calculate_response_likelihood`` against the P prompt and turn
             the enhanced likelihood into a probability-gain reward (optionally relative to
             the per-group mean of the naive likelihood), clipped to
             [log(min_prob_reward_ratio), log(max_prob_reward_ratio)].
          7. Q reward: the clipped gain when format is correct and there is no reference
             leakage, ``format_wrong_reward / 2`` when only the format is correct, and
             ``format_wrong_reward`` otherwise.
          8. Log metrics and return this process's slice of everything the loss needs.

        Args:
            inputs: local batch of raw samples; each must provide
                sample['reward_model']['ground_truth'].

        Returns:
            dict with rollout tensors under "p(y,z|x)_*" / "q(z|x,y)_*" keys, the enhanced
            synthetic trajectory under "x_(z|x,y)_y_*", plus "q(z|x,y)_format_correct_mask"
            and "z_kl_sample_weight".

        Raises:
            ValueError: for an unsupported ``args.prob_reward_baseline`` or
                ``args.z_kl_sample_weight``. Both are checked before any generation, instead
                of failing with an UnboundLocalError after all rollouts have been computed.
        '''
        if self.args.prob_reward_baseline not in ("naive_group_mean", "none"):
            raise ValueError(f"Unknown prob_reward_baseline: {self.args.prob_reward_baseline}")
        if self.args.z_kl_sample_weight not in ("clipped_prob_gain", "none"):
            raise ValueError(f"Unknown z_kl_sample_weight: {self.args.z_kl_sample_weight}")

        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        # 1. Prepare the two prompt variants (per the naming, P ~ p(y,z|x), Q ~ q(z|x,y)).
        p_inputs = self._prepare_p_messages(inputs)
        q_inputs = self._prepare_q_messages(inputs)

        # 2. Generate P and Q completions and print a few samples for inspection.
        p_staffs = self._generate_completions(p_inputs)
        q_staffs = self._generate_completions(q_inputs)
        self._rank0_print_rollout_samples(p_staffs, "p(y,z|x)")
        self._rank0_print_rollout_samples(q_staffs, "q(z|x,y)")

        # Fallback reward for malformed rollouts, shared by P and Q.
        format_wrong_reward = torch.tensor(self.args.format_wrong_reward, device=device)

        # 3. P rollout reward; _calculate_rewards returns rewards gathered across processes.
        #    Columns used: 0 = thinking_format_correct, 1 = answer_accuracy.
        p_rewards_per_func = self._calculate_rewards(
            p_inputs, p_staffs["prompts"], p_staffs["completions"], p_staffs["completion_ids_list"]
        )
        p_format_correct_mask = p_rewards_per_func[:, 0] == 1.0  # (B_global,)
        p_rewards = torch.where(p_format_correct_mask, p_rewards_per_func[:, 1], format_wrong_reward)
        p_staffs.update(self._calculate_advantages(p_rewards, p_staffs["prompts"]))
        p_staffs["rewards_per_func"] = p_rewards_per_func

        rank0_print(
            f"p_rewards_per_func[:, 0] format reward: \n"
            f"{p_rewards_per_func[:, 0][p_staffs['process_slice']]=}"
        )
        rank0_print(
            f"p_rewards_per_func[:, 1] accuracy reward: \n"
            f"{p_rewards_per_func[:, 1][p_staffs['process_slice']]=}"
        )
        rank0_print(
            f"p_rewards final reward: \n"
            f"{p_rewards[p_staffs['process_slice']]=}"
        )

        # 4. Synthetic trajectory from the Q completions, ending in the ground-truth answer.
        synthetic_enhanced_reasoning_trajectory = self._build_synthetic_enhanced_reasoning_trajectory(q_staffs, inputs)
        self._rank0_print_synthetic_samples(synthetic_enhanced_reasoning_trajectory, "enhanced")

        # 5. Synthetic trajectory from the P completions, ending in the ground-truth answer.
        synthetic_naive_reasoning_trajectory = self._build_synthetic_naive_reasoning_trajectory(p_staffs, inputs)
        self._rank0_print_synthetic_samples(synthetic_naive_reasoning_trajectory, "naive")

        # 6. Answer likelihood after each kind of reasoning, both conditioned on the P prompt
        #    (question only), then gathered across processes.
        enhanced_reasoning_response_likelihood = gather(
            self._calculate_response_likelihood(
                prompt_ids=p_staffs["prompt_ids"],
                prompt_mask=p_staffs["prompt_mask"],
                completion_ids=synthetic_enhanced_reasoning_trajectory["input_ids"],
                completion_mask=synthetic_enhanced_reasoning_trajectory["attention_mask"],
                prefix="Enhanced Reasoning Trajectory",
            )
        )  # (B_global,)
        self._metrics[mode]["q(z|x,y)_enhanced_reasoning_response_likelihood"].append(
            enhanced_reasoning_response_likelihood.mean().item()
        )
        self._metrics[mode]["q(z|x,y)_enhanced_reasoning_response_likelihood_std"].append(
            enhanced_reasoning_response_likelihood.std().item()
        )
        rank0_print(f"enhanced reasoning response likelihood: \n{enhanced_reasoning_response_likelihood}")

        naive_reasoning_response_likelihood = gather(
            self._calculate_response_likelihood(
                prompt_ids=p_staffs["prompt_ids"],
                prompt_mask=p_staffs["prompt_mask"],
                completion_ids=synthetic_naive_reasoning_trajectory["input_ids"],
                completion_mask=synthetic_naive_reasoning_trajectory["attention_mask"],
                prefix="Naive Reasoning Trajectory",
            )
        )  # (B_global,)
        self._metrics[mode]["p(z|x)_naive_reasoning_response_likelihood"].append(
            naive_reasoning_response_likelihood.mean().item()
        )
        self._metrics[mode]["p(z|x)_naive_reasoning_response_likelihood_std"].append(
            naive_reasoning_response_likelihood.std().item()
        )
        rank0_print(f"naive reasoning response likelihood: \n{naive_reasoning_response_likelihood}")

        # Per-group mean of the naive likelihood, repeated so there is one value per sample.
        naive_group_mean_response_likelihood = (
            naive_reasoning_response_likelihood.view(-1, self.num_generations)
            .mean(dim=1)
            .repeat_interleave(self.num_generations)
        )
        rank0_print(f"group mean naive reasoning response likelihood: \n{naive_group_mean_response_likelihood}")

        # Probability-gain reward for the Q reasoning (option value validated at the top).
        if self.args.prob_reward_baseline == "naive_group_mean":
            prob_gain = enhanced_reasoning_response_likelihood - naive_group_mean_response_likelihood
        else:  # "none"
            prob_gain = enhanced_reasoning_response_likelihood
        self._metrics[mode]["q(z|x,y)_enhanced_reasoning_prob_gain"].append(prob_gain.mean().item())
        self._metrics[mode]["q(z|x,y)_enhanced_reasoning_prob_gain_std"].append(prob_gain.std().item())
        rank0_print(f"reasoning response likelihood gain: \n{prob_gain}")

        # Clip the gain to [log(min_prob_reward_ratio), log(max_prob_reward_ratio)].
        clipped_prob_gain = prob_gain.clamp(
            math.log(self.args.min_prob_reward_ratio), math.log(self.args.max_prob_reward_ratio)
        )
        self._metrics[mode]["q(z|x,y)_cliped_enhanced_reasoning_prob_gain"].append(clipped_prob_gain.mean().item())
        self._metrics[mode]["q(z|x,y)_cliped_enhanced_reasoning_prob_gain_max"].append(clipped_prob_gain.max().item())
        self._metrics[mode]["q(z|x,y)_cliped_enhanced_reasoning_prob_gain_std"].append(clipped_prob_gain.std().item())
        rank0_print(f"clipped reasoning response likelihood gain: \n{clipped_prob_gain}")

        # 7. Q rollout reward (gathered across processes by _calculate_rewards).
        #    Columns used: 0 = thinking_format_correct, 2 = no_reference_leakage; column 1
        #    (answer_accuracy) is unused and a column-3 valid-reasoning check is disabled.
        #      format ok & no leakage -> clipped probability gain
        #      format ok & leakage    -> format_wrong_reward / 2
        #      format wrong           -> format_wrong_reward
        q_rewards_per_func = self._calculate_rewards(
            q_inputs, q_staffs["prompts"], q_staffs["completions"], q_staffs["completion_ids_list"]
        )
        q_format_correct_mask = q_rewards_per_func[:, 0] == 1.0  # (B_global,)
        q_no_reference_leakage_mask = q_rewards_per_func[:, 2] == 1.0  # (B_global,)
        q_rewards = torch.where(
            q_format_correct_mask & q_no_reference_leakage_mask,
            clipped_prob_gain,
            torch.where(q_format_correct_mask, format_wrong_reward / 2, format_wrong_reward),
        )
        q_staffs.update(self._calculate_advantages(q_rewards, q_staffs["prompts"]))
        q_staffs["rewards_per_func"] = q_rewards_per_func

        rank0_print(
            f"q_rewards_per_func[:, 0] format reward: \n"
            f"{q_rewards_per_func[:, 0][q_staffs['process_slice']]=}"
        )
        rank0_print(
            f"q_rewards_per_func[:, 2] no reference leakage reward: \n"
            f"{q_rewards_per_func[:, 2][q_staffs['process_slice']]=}"
        )
        rank0_print(
            f"q_rewards final reward: \n"
            f"{q_rewards[q_staffs['process_slice']]=}"
        )

        # 8. Log metrics for both rollout types.
        self._log_base_metrics(variables=p_staffs, prefix="p(y,z|x)")
        self._log_base_metrics(variables=q_staffs, prefix="q(z|x,y)")

        # Everything the loss needs, restricted to this process where the source is global.
        needed_keys = (
            "prompt_ids",
            "prompt_mask",
            "completion_ids",
            "completion_mask",
            "advantages",
            "old_per_token_logps",
            "ref_per_token_logps",
        )
        outputs = {f"p(y,z|x)_{key}": p_staffs[key] for key in needed_keys if key in p_staffs}
        outputs.update({f"q(z|x,y)_{key}": q_staffs[key] for key in needed_keys if key in q_staffs})
        q_slice = q_staffs["process_slice"]
        outputs["q(z|x,y)_format_correct_mask"] = q_format_correct_mask[q_slice]
        outputs.update(
            {f"x_(z|x,y)_y_{key}": value for key, value in synthetic_enhanced_reasoning_trajectory.items()}
        )

        # Per-sample weight for the z-KL term. With "clipped_prob_gain" it is the local clipped
        # gain floored at 0: samples without a gain get weight 0, larger gains weigh more.
        # This weight itself does not check format/leakage; the format mask is returned
        # separately as "q(z|x,y)_format_correct_mask".
        local_clipped_prob_gain = clipped_prob_gain[q_slice]
        if self.args.z_kl_sample_weight == "clipped_prob_gain":
            outputs["z_kl_sample_weight"] = local_clipped_prob_gain.clamp(min=0)
        else:  # "none"
            outputs["z_kl_sample_weight"] = torch.ones_like(local_clipped_prob_gain)

        return outputs

    def _rank0_print_rollout_samples(self, staffs: dict, tag: str, num_samples: int = 4) -> None:
        '''
        Print up to ``num_samples`` prompt/completion pairs of a rollout dict via rank0_print.

        Also decodes the first 3 non-padding prompt tokens and the last 3 non-padding
        completion tokens (special tokens kept) so the masks can be checked by eye.
        '''
        for i in range(min(num_samples, len(staffs["prompts_text"]))):
            first_idx = staffs["prompt_mask"][i].nonzero(as_tuple=True)[0][:3]
            last_idx = staffs["completion_mask"][i].nonzero(as_tuple=True)[0][-3:]
            first_text = self.processing_class.batch_decode(
                staffs["prompt_ids"][i][first_idx], skip_special_tokens=False
            )
            last_text = self.processing_class.batch_decode(
                staffs["completion_ids"][i][last_idx], skip_special_tokens=False
            )
            rank0_print(
                f"{tag} prompt {i}: \n"
                f"{staffs['prompts_text'][i]}\n\n"
                f"{tag} completion {i}: \n"
                f"{staffs['completions_text'][i]}\n\n"
                f"{tag} first and last few tokens: \n"
                f"{first_text}...{last_text}"
            )

    def _rank0_print_synthetic_samples(self, trajectory: dict, kind: str, num_samples: int = 4) -> None:
        '''
        Print up to ``num_samples`` synthetic trajectories via rank0_print.

        ``kind`` names the trajectory ("enhanced" / "naive"). Shows the decoded text (special
        tokens skipped) plus the first/last 3 masked-in tokens (special tokens kept).
        '''
        for i in range(min(num_samples, len(trajectory["input_ids"]))):
            valid_idx = trajectory["attention_mask"][i].nonzero(as_tuple=True)[0]
            first_text = self.processing_class.batch_decode(
                trajectory["input_ids"][i][valid_idx[:3]], skip_special_tokens=False
            )
            last_text = self.processing_class.batch_decode(
                trajectory["input_ids"][i][valid_idx[-3:]], skip_special_tokens=False
            )
            full_text = self.processing_class.decode(trajectory["input_ids"][i], skip_special_tokens=True)
            rank0_print(
                f"synthetic {kind} reasoning trajectory {i}: \n"
                f"{full_text}"
                f"\n\nsynthetic_{kind}_reasoning_text(completion) first and last few tokens: \n"
                f"{first_text}...{last_text}"
            )

    def _calculate_response_likelihood(
        self,
        prompt_ids,
        prompt_mask,
        completion_ids,
        completion_mask,
        prefix
    ):
        '''
        Return the mean per-token log-probability of the synthetic answer span of each row.

        Prompt and completion are concatenated and scored without gradients. Scoring uses
        ``self.model`` when ``args.prob_model == "self"`` and ``self.ref_model`` when it is
        ``"ref"``. The answer span is [start, end):
          * start: first token after the LAST occurrence of ``self.synthetic_answer_prefix_ids``
            in the completion. The synthetic builders append the answer at the very end, while
            the kept reasoning may contain the same token pattern earlier.
          * end: index of the last masked-in completion token. This is the trailing eos
            appended by the builders, so eos itself is excluded.

        Args:
            prompt_ids: (B, P) prompt token ids.
            prompt_mask: (B, P) prompt attention mask.
            completion_ids: (B, C) completion token ids.
            completion_mask: (B, C) completion mask; assumed right-padded, since ``end`` is
                derived from its row sum.
            prefix: label used in the debug prints.

        Returns:
            (B,) tensor: sum of span log-probs divided by the span length (clamped to >= 1),
            i.e. 0.0 for an empty span.

        Raises:
            ValueError: if ``self.args.prob_model`` is neither "self" nor "ref".
        '''
        mode = "train" if self.model.training else "eval"
        device = self.accelerator.device
        full_input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        full_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        completion_seq_len = completion_ids.size(1)

        # 1. Per-token log-probs of all completion positions.
        infer_batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

        if self.args.prob_model == "self":
            infer_model = self.model
        elif self.args.prob_model == "ref":
            infer_model = self.ref_model
        else:
            raise ValueError(f"Unknown prob_model: {self.args.prob_model}")

        with torch.no_grad():
            all_logps, _ = self._get_per_token_logps_and_entropies(
                model=infer_model,
                input_ids=full_input_ids,
                attention_mask=full_attention_mask,
                logits_to_keep=completion_seq_len,
                batch_size=infer_batch_size
            )  # (B, completion_seq_len)

        # 2. Locate the LAST occurrence of the answer prefix: search the reversed rows for the
        #    reversed prefix; a hit at reversed index r starts at C - r - prefix_len.
        prefix_ids = self.synthetic_answer_prefix_ids
        prefix_len = prefix_ids.shape[0]
        reversed_pos = self._find_first_subsequence(
            completion_ids.flip(dims=[1]), prefix_ids.flip(dims=[0])
        )  # (B,)
        found = reversed_pos >= 0
        # Not-found rows keep -1 (legacy behavior), so their span is not anchored on the answer.
        prefix_start = torch.where(found, completion_seq_len - reversed_pos - prefix_len, reversed_pos)  # (B,)
        num_missing = int((~found).sum().item())
        if num_missing > 0:
            rank0_print(
                f"⚠️ {prefix}: synthetic answer prefix not found in {num_missing} sample(s); "
                f"their response span is not anchored on the answer"
            )
        response_start_pos = (prefix_start + prefix_len).unsqueeze(1)  # (B, 1)
        # Last masked-in token is the trailing eos; the half-open interval excludes it.
        response_end_pos = completion_mask.sum(dim=1, keepdim=True) - 1  # (B, 1)

        # 3. Response token mask over [start, end).
        sequence_indices = torch.arange(completion_seq_len, device=device).unsqueeze(0)  # (1, C)
        response_mask = (sequence_indices >= response_start_pos) & (sequence_indices < response_end_pos)  # (B, C)

        # Debug: print the tokens selected for each row.
        for i in range(response_mask.size(0)):
            response_tokens = completion_ids[i][response_mask[i]]
            response_text = self.processing_class.batch_decode(response_tokens, skip_special_tokens=False)
            rank0_print(f"📊 {prefix} Response tokens {i=}: {response_text}")

        # 4. Mean log-prob over the span. masked_fill (not multiplication) keeps a -inf at a
        #    masked position from turning into NaN.
        masked_logps = all_logps.masked_fill(~response_mask, 0.0)
        token_counts = response_mask.sum(dim=1).clamp(min=1)
        return masked_logps.sum(dim=1) / token_counts  # (B,)

    def _find_first_subsequence(
        self, 
        sequence_tensor: torch.Tensor, 
        subsequence_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Find the first occurrence position of a subsequence within each sequence in a batch.

        Args:
            sequence_tensor (torch.Tensor): Main sequence with shape (batch_size, seq_len).
            subsequence_tensor (torch.Tensor): Subsequence to search for with shape (sub_seq_len,).

        Returns:
            torch.Tensor: Shape (batch_size,). Each element is the starting position of the subsequence,
                          or -1 if not found.
        """
        batch_size, seq_len = sequence_tensor.shape
        sub_seq_len = subsequence_tensor.shape[0]

        # If subsequence is longer than the main sequence, no match is possible
        if sub_seq_len > seq_len:
            return torch.full((batch_size,), -1, dtype=torch.long, device=sequence_tensor.device)
        
        # Create a boolean tensor that marks all possible starting points using a sliding window
        windows = sequence_tensor.unfold(dimension=1, size=sub_seq_len, step=1)
        matches = (windows == subsequence_tensor).all(dim=2)
        
        # Use argmax to find the first True (match) position
        # argmax returns 0 when there is no True in a row, so add a post-check
        first_match_indices = torch.argmax(matches.int(), dim=1)
        
        # Check which rows really have a match
        has_match = matches.any(dim=1)
        
        # Set to -1 if no match
        first_match_indices[~has_match] = -1
        
        return first_match_indices

    def _calculate_q_z_kl_divergence(
        self,
        q_per_token_logps: torch.Tensor,
        p_per_token_logps: torch.Tensor,
        q_z_start_and_end_idx: torch.Tensor,
        p_z_start_and_end_idx: torch.Tensor,
        constraint_coef: float = None,
        learning_coef: float = None,
        sample_weight: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the KL divergence on the token spans specified by start_and_end_idx in q and p.
        This implementation is fully vectorized and suitable for DDP environments.

        Args:
            q_per_token_logps (torch.Tensor): shape [batch_size, max_q_seq_len]
            p_per_token_logps (torch.Tensor): shape [batch_size, max_p_seq_len]
            q_z_start_and_end_idx (torch.Tensor): shape [batch_size, 2]
            p_z_start_and_end_idx (torch.Tensor): shape [batch_size, 2]

        Returns:
            torch.Tensor: Mean KL loss over valid samples (scalar).
        """
        # Get batch_size and device info
        batch_size = q_per_token_logps.shape[0]
        device = self.accelerator.device

        # --- 1. Create mask for q ---
        # Get q max sequence length
        max_q_seq_len = q_per_token_logps.shape[1]
        
        # Extract q start and end indices; keep shape [batch_size, 1] for broadcasting
        q_starts = q_z_start_and_end_idx[:, 0:1]
        q_ends = q_z_start_and_end_idx[:, 1:2]
        
        # Create q token index grid, shape: [1, max_q_seq_len]
        q_arange = torch.arange(max_q_seq_len, device=device).unsqueeze(0)
        
        # Broadcast to generate q mask: compare q_arange (1, max_q_seq_len) with q_starts/q_ends (batch_size, 1)
        # result has shape [batch_size, max_q_seq_len]
        q_mask = (q_arange >= q_starts) & (q_arange < q_ends)

        # --- 2. Create mask for p (analogous) ---
        max_p_seq_len = p_per_token_logps.shape[1]
        p_starts = p_z_start_and_end_idx[:, 0:1]
        p_ends = p_z_start_and_end_idx[:, 1:2]
        p_arange = torch.arange(max_p_seq_len, device=device).unsqueeze(0)
        # if start_idx and end_idx are -1, the mask will be all False
        p_mask = (p_arange >= p_starts) & (p_arange < p_ends) 
        # --- 3. Apply masks and select logp values ---
        # q_mask and p_mask select True positions and flatten into 1D tensors.
        # Since lengths match by construction, the two 1D tensors are aligned.
        selected_q_logps = q_per_token_logps[q_mask]
        selected_p_logps = p_per_token_logps[p_mask]
        
        # --- 4. Compute logp differences ---
        if constraint_coef is None and learning_coef is None:
            logr = selected_p_logps - selected_q_logps
            logr = logr.clamp(min=math.log(self.args.min_r), max=math.log(self.args.max_r)) # 
            kl = -logr if self.args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
        else:
            assert constraint_coef is not None and learning_coef is not None, \
                "Both constraint_coef and learning_coef must be provided for the constrained KL divergence."
            # Constrained KL divergence
            logr_constraint = selected_p_logps.detach() - selected_q_logps
            logr_constraint = logr_constraint.clamp(min=math.log(self.args.min_r), max=math.log(self.args.max_r))
            kl_constraint = -logr_constraint if self.args.kl_estimator == "k1" else (logr_constraint.exp() - 1) - logr_constraint

            logr_learning = selected_p_logps - selected_q_logps.detach()
            logr_learning = logr_learning.clamp(min=math.log(self.args.min_r), max=math.log(self.args.max_r))
            kl_learning = -logr_learning if self.args.kl_estimator == "k1" else (logr_learning.exp() - 1) - logr_learning

            kl = constraint_coef * kl_constraint + learning_coef * kl_learning
        
        # # --- 5. Compute KL loss ---
        # # Option A: total loss over the batch (simplest)
        # total_loss = kl.sum()

        # Option B: per-sample loss (more flexible)
        # First compute z segment length for each sample
        z_lengths = q_z_start_and_end_idx[:, 1] - q_z_start_and_end_idx[:, 0]
        
        # Create indices mapping each KL element to its sample
        # e.g., for lengths [2, 3], batch_indices is [0, 0, 1, 1, 1]
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(z_lengths)
        
        # Accumulate per-sample sums with scatter_add_
        # If sample i has z length 0, per_sample_kl[i] will be 0
        per_sample_kl_sum = kl.new_zeros(batch_size).scatter_add_(0, batch_indices, kl)
        per_sample_kl_mean = per_sample_kl_sum / z_lengths.clamp(min=1.)  # per-sample mean KL loss

        rank0_print(f"sample_weight:\n{sample_weight}\n")
        # Apply sample weights or masks
        kl_loss = (per_sample_kl_mean * sample_weight).sum() / ((z_lengths > 0) * sample_weight).sum().clamp(min=1.)

        del q_mask, p_mask, selected_q_logps, selected_p_logps, kl, batch_indices, z_lengths, per_sample_kl_sum, per_sample_kl_mean

        return kl_loss

    def _compute_grpo_loss(self, model, inputs, metric_prefix):
        """
        Compute the GRPO policy loss for one batch of rollouts and log its diagnostics.

        Args:
            model: Policy model used to score ``prompt_ids + completion_ids``.
            inputs (dict): Batch containing ``"prompt_ids"``, ``"prompt_mask"``, ``"completion_ids"``,
                ``"completion_mask"`` and ``"advantages"``. It may contain ``"old_per_token_logps"``; when
                that is missing or None, the detached current log-probs are used. ``"ref_per_token_logps"``
                is required when ``self.beta != 0``.
            metric_prefix (str): Prefix for every metric key appended to ``self._metrics[mode]``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The scalar loss and the per-token log-probs of the completion
            tokens. The log-probs are not detached, so callers can build further loss terms on them.

        Raises:
            ValueError: On an unknown ``importance_sampling_level`` or ``loss_type``.
        """
        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        # Compute the per_token_logps and the entropy at each position in the completion
        per_token_logps, entropies = self._get_per_token_logps_and_entropies(
            model,
            input_ids,
            attention_mask,
            logits_to_keep,
            compute_entropy=True,
        )

        # Optionally restrict the policy-gradient term to the highest-entropy tokens.
        if self.top_entropy_quantile < 1.0:
            entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
        else:
            entropy_mask = None

        # Compute the KL divergence between the model and the reference model (k3 estimator)
        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        # Compute the loss
        advantages = inputs["advantages"]
        # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
        # old_per_token_logps == per_token_logps, so its computation can be skipped
        # (see _generate_and_score_completions) and per_token_logps.detach() is used instead.
        old_per_token_logps = inputs.get("old_per_token_logps")
        old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

        log_ratio = per_token_logps - old_per_token_logps
        if self.importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif self.importance_sampling_level == "sequence":
            log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )
        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

        # Two-sided clipping: additionally cap the unclipped ratio from above at delta.
        if self.args.delta is not None:
            coef_1 = torch.clamp(coef_1, max=self.args.delta)

        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if entropy_mask is not None:
            per_token_loss = per_token_loss * entropy_mask
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        if self.loss_type == "grpo":
            # Mean over tokens within each sequence, then mean over sequences.
            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
        elif self.loss_type == "bnpo":
            # Mean over all unmasked tokens of the batch.
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
        elif self.loss_type == "dr_grpo":
            # Normalize by a constant: batch size times max_completion_length.
            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        # Log the metrics
        mode = "train" if self.model.training else "eval"

        completion_token_count = completion_mask.sum().clamp(min=1.0)

        def masked_batch_mean(x):
            # Sequence-level tensors (coef_1 & co. when importance_sampling_level == "sequence") are narrower
            # than completion_mask and are averaged over sequences. Everything as wide as completion_mask is
            # per-token and is averaged over unmasked tokens only. This includes a batch whose completion
            # width is 1, which a plain `x.shape[1] == 1` test would misclassify as sequence-level.
            if x.shape[1] != completion_mask.shape[1]:
                return x.mean()
            return (x * completion_mask).sum() / completion_token_count

        if self.beta != 0.0:
            mean_kl = masked_batch_mean(per_token_kl)
            self._metrics[mode][f"{metric_prefix}_kl"].append(self.accelerator.gather(mean_kl).nanmean().item())

        mean_entropy = masked_batch_mean(entropies)
        self._metrics[mode][f"{metric_prefix}_entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

        # Compute the clipped probability ratios
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = masked_batch_mean(is_low_clipped.float())
        high_clip = masked_batch_mean(is_high_clipped.float())
        clip_ratio = masked_batch_mean(is_region_clipped.float())

        gathered_low_clip = self.accelerator.gather(low_clip)
        self._metrics[mode][f"{metric_prefix}_clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode][f"{metric_prefix}_clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather(high_clip)
        self._metrics[mode][f"{metric_prefix}_clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode][f"{metric_prefix}_clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
        self._metrics[mode][f"{metric_prefix}_clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())

        # Drop local references to intermediates early (memory housekeeping).
        if self.beta != 0.0:
            del ref_per_token_logps, per_token_kl
        del entropies
        del input_ids, attention_mask, log_ratio, coef_1, coef_2, per_token_loss1, per_token_loss2, per_token_loss
        del is_low_clipped, is_high_clipped, is_region_clipped, low_clip, high_clip, clip_ratio
        del gathered_low_clip, gathered_high_clip, gathered_clip_ratio

        return loss, per_token_logps

    def _compute_loss(self, model, inputs):
        """
        Compute the total loss: GRPO losses for p and q, plus a KL term between their z spans.

        1. The standard GRPO loss of p(y,z|x) on its own rollouts/rewards (``inputs`` keys prefixed
           ``"p(y,z|x)_"``) and of q(z|x,y) on its own rollouts/rewards (keys prefixed ``"q(z|x,y)_"``).
        2. A KL term between the z tokens generated by q (after its closing think tag, up to eos) and the
           same z tokens scored by the model on the synthetic trajectory (keys prefixed ``"x_(z|x,y)_y_"``,
           z span = start of the completion up to the first closing think tag). It is computed only on
           valid q samples (format correct and non-empty completion mask), and only when every process has
           at least one valid sample. Per-sample weights come from ``inputs["z_kl_sample_weight"]``.

        Returns:
            Scalar tensor ``p_grpo_loss_coef * p_loss + q_grpo_loss_coef * q_loss``, plus
            ``z_kl_beta * z_kl_loss`` when the KL term was computed.
        """
        p_staffs_for_loss = {
            key.removeprefix("p(y,z|x)_"): inputs[key] for key in inputs if key.startswith("p(y,z|x)_")
        }
        q_staffs_for_loss = {
            key.removeprefix("q(z|x,y)_"): inputs[key] for key in inputs if key.startswith("q(z|x,y)_")
        }
        synthetic_enhanced_reasoning_trajectory = {
            key.removeprefix("x_(z|x,y)_y_"): inputs[key] for key in inputs if key.startswith("x_(z|x,y)_y_")
        }

        mode = "train" if model.training else "eval"
        device = self.accelerator.device

        # 1. Standard GRPO loss for p and q, each on its own rewards and trajectories.
        p_grpo_loss, p_per_token_logps = self._compute_grpo_loss(model, p_staffs_for_loss, metric_prefix="p(y,z|x)")
        q_grpo_loss, q_per_token_logps = self._compute_grpo_loss(model, q_staffs_for_loss, metric_prefix="q(z|x,y)")

        # 2. KL divergence between p(z|x) and q(z|x,y).
        # A q sample is valid when (1) its format is correct and (2) it is not truncated, i.e. q(z|x,y) produced
        # an intact reasoning process. Truncated samples have completion_mask.sum() == 0 (this is done via
        # self.mask_truncated_completions).
        rank0_print(f"{q_staffs_for_loss['format_correct_mask']=}")
        rank0_print(f"{(q_staffs_for_loss['completion_mask'].sum(dim=1) > 0 )=}")
        valid_sample_mask = torch.logical_and(
            q_staffs_for_loss["format_correct_mask"],
            q_staffs_for_loss["completion_mask"].sum(dim=1) > 0
        )

        local_valid_count = valid_sample_mask.sum().item()
        local_valid_tensor = torch.tensor(local_valid_count, device=self.accelerator.device, dtype=torch.long)
        # Gather valid sample counts from all processes.
        all_valid_counts = self.accelerator.gather(local_valid_tensor)
        # If any process has no valid sample, the KL term is skipped on every process alike.
        all_gpu_have_valid_sample = (all_valid_counts >= 1).all().item()

        # Log whether all processes had valid samples.
        self._metrics[mode]["all_gpu_have_valid_q_sample"].append(int(all_gpu_have_valid_sample))

        rank0_print(f"🤖 Valid sample counts across all GPUs: {all_valid_counts.tolist()}")
        rank0_print(f"🤖All GPU have valid samples: {all_gpu_have_valid_sample}")

        if all_gpu_have_valid_sample:
            p_prompt_ids, p_prompt_mask = p_staffs_for_loss["prompt_ids"], p_staffs_for_loss["prompt_mask"]

            p_completion_ids, p_completion_mask = synthetic_enhanced_reasoning_trajectory["input_ids"], synthetic_enhanced_reasoning_trajectory["attention_mask"]
            p_input_ids = torch.cat([p_prompt_ids, p_completion_ids], dim=1)
            p_attention_mask = torch.cat([p_prompt_mask, p_completion_mask], dim=1)
            p_logits_to_keep = p_completion_ids.size(1)  # we only need to compute the logits for the completion tokens
            # Per-token log-probs of the synthetic trajectory under P(y,z|x), valid samples only.
            valid_p_input_ids = p_input_ids[valid_sample_mask]
            valid_p_attention_mask = p_attention_mask[valid_sample_mask]
            valid_p_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                model=model,
                input_ids=valid_p_input_ids,
                attention_mask=valid_p_attention_mask,
                logits_to_keep=p_logits_to_keep
            )  # Shape: (num_valid_samples, p_completion_seq_len)

            # Locate the z tokens in both completions. Validity is decided by q's masks.
            valid_q_completion_ids = q_staffs_for_loss["completion_ids"][valid_sample_mask]
            valid_p_completion_ids = p_completion_ids[valid_sample_mask]

            # q z span, half-open [start, end): start is right after the first closing think tag, and end is
            # the first eos token (or the full completion width when no eos is present).
            think_tag_len = self.think_tag_end_ids.shape[0]
            q_z_start = self._find_first_subsequence(
                valid_q_completion_ids,
                self.think_tag_end_ids
            )
            q_z_end = self._find_first_subsequence(
                valid_q_completion_ids,
                torch.tensor([self.eos_token_id], device=device)
            )
            q_z_end[q_z_end == -1] = valid_q_completion_ids.shape[1]  # if not found, set to the max length
            # Skip the whole closing tag, not just its first token. Rows where the tag was not found (-1)
            # keep the previous fallback of starting at position 0.
            q_z_start = torch.where(q_z_start >= 0, q_z_start + think_tag_len, torch.zeros_like(q_z_start))
            q_z_start_and_end_idx = torch.stack([q_z_start, q_z_end], dim=1)

            # p z span, half-open [start, end): from the first completion token up to (excluding) the first
            # closing think tag. If the tag is missing, end is -1 and the span is empty.
            p_z_start = valid_p_completion_ids.new_zeros(valid_p_completion_ids.shape[0])
            p_z_end = self._find_first_subsequence(
                valid_p_completion_ids,
                self.think_tag_end_ids
            )
            p_z_start_and_end_idx = torch.stack([p_z_start, p_z_end], dim=1)

            rank0_print(
                f"q_z_start_and_end_idx: \n"
                f"{q_z_start_and_end_idx}\n"
                f"{q_z_start_and_end_idx[:,1] - q_z_start_and_end_idx[:,0]}\n"
                f"p_z_start_and_end_idx: \n"
                f"{p_z_start_and_end_idx}\n"
                f"{p_z_start_and_end_idx[:,1] - p_z_start_and_end_idx[:,0]}\n",
            )

            # Debug logging: decode the z tokens of each valid sample on both sides.
            q_z_texts = []
            for i in range(valid_q_completion_ids.shape[0]):
                start, end = q_z_start_and_end_idx[i]
                z_tokens = valid_q_completion_ids[i, start:end]
                z_tokens_text = self.processing_class.decode(z_tokens, skip_special_tokens=False)
                q_z_texts.append(z_tokens_text)

            p_z_texts = []
            for i in range(valid_p_completion_ids.shape[0]):
                start, end = p_z_start_and_end_idx[i]
                z_tokens = valid_p_completion_ids[i, start:end]
                z_tokens_text = self.processing_class.decode(z_tokens, skip_special_tokens=False)
                p_z_texts.append(z_tokens_text)

            for q_z_text, p_z_text in zip(q_z_texts, p_z_texts):
                rank0_print(f"\nq_z_text:\n{q_z_text}\n\np_z_text:\n{p_z_text}")

            valid_q_per_token_logps = q_per_token_logps[valid_sample_mask]

            z_kl_loss = self._calculate_q_z_kl_divergence(
                q_per_token_logps=valid_q_per_token_logps,
                p_per_token_logps=valid_p_per_token_logps,
                q_z_start_and_end_idx=q_z_start_and_end_idx,
                p_z_start_and_end_idx=p_z_start_and_end_idx,
                constraint_coef=self.args.z_kl_constraint_coef,
                learning_coef=self.args.z_kl_learning_coef,
                sample_weight=inputs["z_kl_sample_weight"][valid_sample_mask]
            )  # weighted mean over valid samples

            self._metrics[mode]["z_kl"].append(self.accelerator.gather(z_kl_loss).nanmean().item())

            del valid_p_per_token_logps, valid_q_per_token_logps
            del valid_p_input_ids, valid_p_attention_mask
            del valid_q_completion_ids, valid_p_completion_ids
            del q_z_start_and_end_idx, p_z_start_and_end_idx

        # Put the loss together.
        total_loss = self.args.p_grpo_loss_coef * p_grpo_loss + self.args.q_grpo_loss_coef * q_grpo_loss
        if all_gpu_have_valid_sample:
            total_loss += self.args.z_kl_beta * z_kl_loss

        # Release references and empty the cache before the backward pass.
        del p_per_token_logps, q_per_token_logps
        empty_cache()

        return total_loss
    # override
    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged custom metrics into ``logs``, forward them, and optionally log sample completions.

        Metrics collected in ``self._metrics[mode]`` are averaged, prefixed with ``"eval_"`` in evaluation
        mode, and passed to ``super(GRPOTrainer, self).log``. That call resolves to the implementation
        after ``GRPOTrainer`` in the MRO (presumably ``Trainer.log`` — confirm against the class hierarchy).
        The collected metrics are then cleared. On the main process with ``log_completions`` enabled, sample
        prompts/completions/rewards/advantages of both p(y,z|x) and q(z|x,y) are printed (when rich is
        available) and logged as a wandb table (when reporting to an active wandb run).

        Args:
            logs (dict[str, float]): Metrics produced by the base trainer.
            start_time (float, optional): Forwarded unchanged to the parent ``log``.
        """
        mode = "train" if self.model.training else "eval"
        # Average the metrics; keys with no recorded values are skipped instead of dividing by zero.
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items() if val}

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        super(GRPOTrainer, self).log(logs, start_time)
        self._metrics[mode].clear()

        if self.accelerator.is_main_process and self.log_completions:
            if is_rich_available():
                # P
                print_prompt_completions_sample(
                    self._logs["p(y,z|x)_prompt"],
                    self._logs["p(y,z|x)_completion"],
                    self._logs["p(y,z|x)_rewards"],
                    self._logs["p(y,z|x)_advantages"],
                    self.state.global_step,
                    self.num_completions_to_print,
                )

                # Q
                print_prompt_completions_sample(
                    self._logs["q(z|x,y)_prompt"],
                    self._logs["q(z|x,y)_completion"],
                    self._logs["q(z|x,y)_rewards"],
                    self._logs["q(z|x,y)_advantages"],
                    self.state.global_step,
                    self.num_completions_to_print,
                )

            if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                import pandas as pd

                table = {
                    "step": [str(self.state.global_step)] * len(self._logs["p(y,z|x)_prompt"]),
                    # P
                    "p(y,z|x)_prompt": self._logs["p(y,z|x)_prompt"],
                    "p(y,z|x)_completion": self._logs["p(y,z|x)_completion"],
                    **{f"p(y,z|x)_rewards_{key}": val for key, val in self._logs["p(y,z|x)_rewards"].items()},
                    "p(y,z|x)_advantage": self._logs["p(y,z|x)_advantages"],
                    # Q
                    "q(z|x,y)_prompt": self._logs["q(z|x,y)_prompt"],
                    "q(z|x,y)_completion": self._logs["q(z|x,y)_completion"],
                    **{f"q(z|x,y)_rewards_{key}": val for key, val in self._logs["q(z|x,y)_rewards"].items()},
                    "q(z|x,y)_advantage": self._logs["q(z|x,y)_advantages"],
                }

                df = pd.DataFrame(table)
                if self.wandb_log_unique_prompts:
                    # The table has no plain "prompt" column; deduplicate on the p(y,z|x) prompt (x).
                    df = df.drop_duplicates(subset=["p(y,z|x)_prompt"])
                wandb.log({"completions": wandb.Table(dataframe=df)}, step=self.state.global_step)
