# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import re
import json
import textwrap
import warnings
import numpy as np
from collections import defaultdict, deque
from collections.abc import Sequence, Sized
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Union, List, Dict, Tuple
import time

import datasets
from datetime import datetime
import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from loguru import logger
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.extras.vllm_client import VLLMClient
from trl.import_utils import is_liger_kernel_available, is_vllm_available
from trl.models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from trl.models.utils import _ForwardRedirection
from trl.trainer.callbacks import SyncRefModelCallback
from open_r1.trainer.ts_config import GRPOConfig
from trl.trainer.utils import (
    disable_dropout_in_model,
    entropy_from_logits,
    generate_model_card,
    get_comet_experiment_url,
    pad,
    print_prompt_completions_sample,
    selective_log_softmax,
)


if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

if is_wandb_available():
    import wandb

# A reward function can be specified in three ways:
#   - a string: treated as a model ID and loaded as a pretrained model;
#   - an already-instantiated `PreTrainedModel`;
#   - a callable that takes a list of prompts and a list of completions and returns a list of rewards.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class RepeatSampler(Sampler):
    """
    Sampler that yields dataset indices in repeated, batch-structured order.

    The (optionally shuffled) indices are cut into consecutive groups of `batch_size` unique indices; a trailing
    group with fewer than `batch_size` indices is dropped. Each group is emitted `repeat_count` times in a row, and
    within every emission each index appears `mini_repeat_count` times consecutively.

    Args:
        data_source (`Sized`):
            Dataset to sample from. Only its length is used.
        mini_repeat_count (`int`):
            Number of consecutive repetitions of each index inside a group.
        batch_size (`int`, *optional*, defaults to `1`):
            Number of unique indices per group.
        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times each whole group is emitted before moving on to the next one.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to draw a random permutation of the indices on every iteration.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Seed for the sampler's own random generator (the global RNG is not touched).

    Example (without shuffling, so the output is deterministic):
    ```python
    >>> sampler = RepeatSampler(
    ...     ["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=2, shuffle=False
    ... )
    >>> list(sampler)
    [0, 0, 1, 1, 2, 2,
     0, 0, 1, 1, 2, 2,
     3, 3, 4, 4, 5, 5,
     3, 3, 4, 4, 5, 5]
    ```

    Index `6` never appears because the last group (`[6]`) is incomplete and therefore discarded.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        shuffle: bool = True,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.num_samples = len(data_source)
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.shuffle = shuffle
        self.seed = seed

        if not shuffle:
            return

        # A private generator keeps the shuffling independent from the global torch RNG state.
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        # Ordering of the unique indices for this pass, e.g. [2, 4, 3, 1, 0, 6, 5] for num_samples = 7.
        if self.shuffle:
            order = torch.randperm(self.num_samples, generator=self.generator).tolist()
        else:
            order = list(range(self.num_samples))

        for start in range(0, len(order), self.batch_size):
            group = order[start : start + self.batch_size]
            # Only complete groups are emitted; e.g. the trailing [5] is skipped when batch_size = 3.
            if len(group) != self.batch_size:
                continue
            # [2, 4, 3] -> [2, 2, 4, 4, 3, 3] for mini_repeat_count = 2
            expanded = [index for index in group for _ in range(self.mini_repeat_count)]
            for _ in range(self.repeat_count):
                yield from expanded

    def __len__(self) -> int:
        full_groups = self.num_samples // self.batch_size
        return full_groups * self.batch_size * self.mini_repeat_count * self.repeat_count


# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Sample standard deviation of a 1D tensor, with NaN entries left out of the computation.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Standard deviation of the non-NaN values, using Bessel's correction.
    """
    valid_count = torch.sum(~torch.isnan(tensor))  # how many entries are not NaN
    mean = torch.nanmean(tensor, keepdim=True)
    squared_deviation = (tensor - mean) ** 2
    variance = torch.nanmean(squared_deviation)  # biased variance over the non-NaN entries
    variance.mul_(valid_count / (valid_count - 1))  # Bessel's correction, applied in place
    return variance.sqrt()


def split_tensor_dict(
    tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int
) -> list[dict[str, Optional[torch.Tensor]]]:
    """
    Cut every tensor of a dictionary into `num_chunks` consecutive slices along its first dimension.

    The slice length is the first dimension of the first non-`None` tensor, floor-divided by `num_chunks`. Entries
    whose value is `None` stay `None` in every chunk.

    Example:
    ```python
    >>> x = torch.arange(12).reshape(6, 2)
    >>> y = torch.arange(6).reshape(6, 1)
    >>> split_tensor_dict({"x": x, "y": y}, 3)
    [
        {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])},
        {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])},
        {"x": tensor([[ 8,  9], [10, 11]]), "y": tensor([[4], [5]])}
    ]
    ```
    """
    reference = next(value for value in tensor_dict.values() if value is not None)
    rows_per_chunk = reference.shape[0] // num_chunks

    chunks = []
    for chunk_index in range(num_chunks):
        window = slice(chunk_index * rows_per_chunk, (chunk_index + 1) * rows_per_chunk)
        chunks.append({key: None if value is None else value[window] for key, value in tensor_dict.items()})
    return chunks


def shuffle_sequence_dict(seq_dict: dict[str, Optional[Sequence]]) -> dict[str, Optional[Sequence]]:
    """
    Apply one shared random permutation to every sequence-like value of a dictionary.

    The permutation length is the length of the first non-`None` value. Tensors are reordered by indexing along
    their first dimension, any other sequence is rebuilt as a list, and `None` values are passed through.

    Example:
    ```python
    >>> x = torch.arange(6).reshape(3, 2)
    >>> y = ["a", "b", "c"]
    >>> shuffle_sequence_dict({"x": x, "y": y})
    {'x': tensor([[2, 3],
                  [0, 1],
                  [4, 5]]),
     'y': ['b', 'a', 'c']}
    ```
    """
    first_present = next(value for value in seq_dict.values() if value is not None)
    permutation = torch.randperm(len(first_present))

    shuffled = {}
    for key, value in seq_dict.items():
        if value is None:
            shuffled[key] = None
        elif isinstance(value, torch.Tensor):
            shuffled[key] = value[permutation]
        else:
            shuffled[key] = [value[i] for i in permutation]
    return shuffled


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Smallest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Minimum over the non-NaN entries, or a NaN scalar (same dtype and device as the input) when
        every entry is NaN.
    """
    nan_positions = torch.isnan(tensor)
    if nan_positions.all():
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    finite_part = tensor[~nan_positions]
    return torch.min(finite_part)


def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Largest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Maximum over the non-NaN entries, or a NaN scalar (same dtype and device as the input) when
        every entry is NaN.
    """
    nan_positions = torch.isnan(tensor)
    if nan_positions.all():
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    finite_part = tensor[~nan_positions]
    return torch.max(finite_part)


def identity(x):
    """
    Return `x` unchanged.

    The trainer passes this as its `data_collator`, since no data collation is needed in GRPO.
    """
    return x


def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """
    Break `batch["pixel_values"]` into one tensor per row of `batch["image_grid_thw"]`.

    The number of leading rows given to each piece is the product of the corresponding row of
    `batch["image_grid_thw"]`. All other entries are carried over unchanged. When either key is absent, `batch` is
    returned as is.

    Raises:
        `ValueError`: If the per-row products do not add up to the first dimension of `batch["pixel_values"]`.
    """
    has_grid_and_pixels = "image_grid_thw" in batch and "pixel_values" in batch
    if not has_grid_and_pixels:
        return batch

    lengths = batch["image_grid_thw"].prod(dim=1).tolist()  # [batch_size]
    pixel_values = batch["pixel_values"]  # [total, feature_dim]

    expected_rows = sum(lengths)
    actual_rows = pixel_values.size(0)
    if expected_rows != actual_rows:
        raise ValueError(f"Mismatch: sum(lengths) = {expected_rows} != pixel_values.size(0) = {actual_rows}")

    pieces = list(torch.split(pixel_values, lengths, dim=0))
    return {**batch, "pixel_values": pieces}


def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]:
    """
    Inverse of `split_pixel_values_by_grid`.

    When `batch["pixel_values"]` is a list of tensors, they are concatenated along the first dimension and a new
    dictionary holding the merged tensor is returned. In every other case `batch` is returned untouched.
    """
    pixel_values = batch.get("pixel_values")
    if not isinstance(pixel_values, list):
        return batch
    return {**batch, "pixel_values": torch.cat(pixel_values, dim=0)}


def get_high_entropy_mask(entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    Flag the non-padding tokens whose entropy is at or above a quantile of all non-padding entropies.

    Args:
        entropies (`torch.Tensor`):
            Tensor of shape (batch_size, seq_len) with per-token entropy values.
        mask (`torch.Tensor`):
            Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding.
        threshold (`float`):
            Quantile between `0.0` and `1.0`, evaluated over the non-padding entropies.

    Returns:
        `torch.Tensor`:
            Boolean mask of shape (batch_size, seq_len). `True` marks valid tokens whose entropy is >= the computed
            quantile value; padding positions are always `False`. If there is no valid token at all, the mask is
            entirely `False`.
    """
    valid = mask.bool()
    valid_entropies = entropies[valid].float()
    if valid_entropies.numel() == 0:
        return torch.zeros_like(entropies, dtype=torch.bool)

    cutoff = torch.quantile(valid_entropies, threshold)
    # Padding positions are zeroed before the comparison and removed again afterwards, so they can never be selected.
    zeroed_padding = entropies * mask.float()
    above_cutoff = zeroed_padding >= cutoff
    return above_cutoff & valid


def truncate_with_protected_tokens(
    ids: torch.Tensor, mask: torch.Tensor, target_length: int, protected_tokens: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Truncate tensors to target length while preserving protected tokens.

    For every sequence, all protected tokens are kept, and the remaining budget (`target_length` minus the number of
    protected tokens) is filled with the rightmost non-protected tokens. The relative order of the kept tokens is
    unchanged, and `mask` is filtered with exactly the same positions as `ids`.

    Args:
        ids (`torch.Tensor`):
            Input tensor of token IDs, shape (batch_size, sequence_length).
        mask (`torch.Tensor`):
            Input tensor of attention masks, shape (batch_size, sequence_length).
        target_length (`int`):
            Desired length of the output sequences.
        protected_tokens (`list[int]`):
            List of token IDs that should be preserved in the output.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`:
            The truncated token IDs and the matching truncated attention masks, stacked over the batch.

    Raises:
        `ValueError`: If a sequence contains more protected tokens than `target_length`.
    """
    # Build the protected-token mask for the whole batch in one vectorized call on the same device as `ids`,
    # instead of calling `.item()` on every token. `None` entries can never equal a token ID, so they are dropped.
    protected_ids = [token for token in set(protected_tokens) if token is not None]
    if protected_ids:
        protected_tensor = torch.tensor(protected_ids, dtype=ids.dtype, device=ids.device)
        is_protected = torch.isin(ids, protected_tensor)
    else:
        is_protected = torch.zeros_like(ids, dtype=torch.bool)

    truncated_seq = []
    truncated_mask = []

    for row in range(ids.shape[0]):
        row_protected = is_protected[row]

        num_protected = int(row_protected.sum().item())
        num_non_protected_needed = target_length - num_protected

        if num_non_protected_needed < 0:
            raise ValueError(
                f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). "
                f"Please increase target length to at least {num_protected} or disable truncation."
            )

        # Start from the protected positions and add the rightmost non-protected ones until the budget is used.
        keep_mask = row_protected.clone()
        if num_non_protected_needed > 0:
            non_protected_indices = torch.where(~row_protected)[0]
            keep_mask[non_protected_indices[-num_non_protected_needed:]] = True

        truncated_seq.append(ids[row][keep_mask])
        truncated_mask.append(mask[row][keep_mask])

    return torch.stack(truncated_seq), torch.stack(truncated_mask)


class TSGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
    Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")


    def reward_func(completions, **kwargs):
        # Dummy reward function that rewards completions with more unique letters.
        return [float(len(set(completion))) for completion in completions]


    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_func,
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                path to a *directory* containing model weights saved using
                [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
                using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
                keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. Custom reward
                  functions can also return `None` when the reward is not applicable to those samples. This is useful
                  for multi-task training where different reward functions apply to different types of samples. When a
                  reward function returns `None` for a sample, that reward function is excluded from the reward
                  calculation for that sample. For more details, see [Using a custom reward
                  function](#using-a-custom-reward-function).

                  The trainer's state is also passed to the reward function. The trainer's state is an instance of
                  [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the
                  reward function's signature.
            - A list of reward functions, where each item can independently be any of the above types. Mixing different
            types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
            ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
            padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
            `tokenizer.eos_token` will be used as the default.
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using
            [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
            functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
            are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    _tag_names = ["trl", "grpo"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: Optional[GRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            # config = AutoConfig.from_pretrained(model_id)
            # architecture = getattr(transformers, config.architectures[0])
            model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it
        # Inspect the forward method before we wrap the model with PEFT
        self.model_kwarg_keys = (
            inspect.signature(model.forward).parameters.keys()
            if not hasattr(model, "get_base_model")
            else inspect.signature(model.get_base_model().forward).parameters.keys()
        )

        if peft_config is not None:
            if not is_peft_available():
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, trust_remote_code=True)

        # Handle pad token for processors or tokenizers
        if isinstance(processing_class, ProcessorMixin):
            tokenizer = processing_class.tokenizer
        elif isinstance(processing_class, PreTrainedTokenizerBase):
            tokenizer = processing_class
        else:
            raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        self.pad_token = tokenizer.pad_token
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.image_token = getattr(processing_class, "image_token", None)
        self.image_token_id = getattr(processing_class, "image_token_id", None)
        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_func_names = []
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
            if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
            else:
                self.reward_func_names.append(reward_funcs[i].__name__)
        self.reward_funcs = reward_funcs

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path, trust_remote_code=True)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes
        
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.temperature = args.temperature
        self.top_p = args.top_p
        self.top_k = args.top_k
        self.min_p = args.min_p
        self.repetition_penalty = args.repetition_penalty
        self.use_transformers_paged = args.use_transformers_paged
        self.use_vllm = args.use_vllm
        self.vllm_mode = args.vllm_mode
        self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization  # only applies to colocation mode
        self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size  # only applies to colocation mode
        self.use_liger_loss = args.use_liger_loss

        # Multi-turn tool calling configuration
        self.enable_multi_turn_tools = getattr(args, "enable_multi_turn_tools", False)
        self.max_tool_calls = getattr(args, "max_tool_calls", 8)
        self.tool_stop_string = getattr(args, "tool_stop_string", "</tool_call>")
        self.tools = getattr(args, "tools", None)
        if self.tools is not None and isinstance(self.tools, str):
            with open(self.tools, 'r') as f:
                self.tools = json.load(f)
        # End multi-turn tool calling configuration

        self.loss_type = args.loss_type
        self.scale_rewards = args.scale_rewards
        self.importance_sampling_level = args.importance_sampling_level
        self.mask_truncated_completions = args.mask_truncated_completions
        self.top_entropy_quantile = args.top_entropy_quantile
        if self.use_liger_loss and self.top_entropy_quantile < 1.0:
            raise NotImplementedError(
                "Liger Kernels don't currently support masking token positions based on entropy."
            )
        if self.use_liger_loss and not self.importance_sampling_level == "token":
            raise NotImplementedError(
                "Liger Kernels currently only support token-level importance sampling. Please set"
                "`importance_sampling_level` to 'token'."
            )

        # Datasets
        self.shuffle_dataset = args.shuffle_dataset

        if (
            isinstance(train_dataset, IterableDataset)
            or isinstance(eval_dataset, IterableDataset)
            or (
                isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
            )
        ):
            # See https://github.com/huggingface/trl/issues/3213
            raise NotImplementedError(
                "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
            )

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
        # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = None

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        super().__init__(
            model=model,
            args=args,
            data_collator=identity,  # No data collation is needed in GRPO
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Reference model
        self.beta = args.beta
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            # architecture = getattr(transformers, config.architectures[0])
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)

        # Disable dropout in the models
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        # Liger loss
        if self.use_liger_loss:
            if not is_liger_kernel_available():
                raise ImportError(
                    "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
                )
            # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
            self._forward_redirection = _ForwardRedirection()

            self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
                beta=self.beta,
                epsilon_low=self.epsilon_low,
                epsilon_high=self.epsilon_high,
                temperature=self.temperature,
                use_ref_model=self.beta != 0.0,
                loss_type=self.loss_type,
                max_completion_length=self.max_completion_length,
            )

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0
        self.log_completions = args.log_completions
        self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
        self.num_completions_to_print = args.num_completions_to_print
        self._printed_global_steps = set()
        # Keep logs sized to the generation batch to record only outputs from the latest model update.
        self._logs = {
            "image": deque(maxlen=2048),
            "prompt": deque(maxlen=2048),
            "solution": deque(maxlen=2048),
            "timeseries": deque(maxlen=2048),
            "completion": deque(maxlen=2048),
            "rewards": defaultdict(lambda: deque(maxlen=2048)),
            "advantages": deque(maxlen=2048),
            "rlvr": defaultdict(lambda: deque(maxlen=2048)),
        }

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            # Only support colocate mode for ThinkTime
            # Make sure vllm_tensor_parallel_size group size evenly divides the world size
            if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0:
                raise ValueError(
                    f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size "
                    f"({self.accelerator.num_processes}) evenly."
                )

            if self.vllm_tensor_parallel_size > 1:
                # Create subgroups of ranks for TP
                self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
                    [
                        list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size))
                        for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size)
                    ]
                )

            # Set environment variables for vLLM distributed training
            os.environ["RANK"] = str(self.accelerator.process_index)
            os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
            os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
            os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
            os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345")

            if self.max_prompt_length is not None and self.max_completion_length is not None:
                max_model_len = self.max_prompt_length + self.max_completion_length
            else:
                max_model_len = None

            import open_r1.trainer.ThinkTime_vllm

            # Initialize vLLM with ThinkTime support
            self.llm = LLM(
                model=model.name_or_path,
                enforce_eager=True, 
                tensor_parallel_size=args.vllm_tensor_parallel_size,
                gpu_memory_utilization=self.vllm_gpu_memory_utilization,
                max_num_seqs=self.args.per_device_train_batch_size
                * self.vllm_tensor_parallel_size
                * self.args.steps_per_generation,
                max_model_len=max_model_len,
                distributed_executor_backend="external_launcher",
                seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
                max_num_batched_tokens=4096,
                model_impl=self.args.vllm_model_impl,
                trust_remote_code=True,
                disable_custom_all_reduce=True,
                enable_prefix_caching=False
            )

            # vLLM specific sampling arguments
            self.guided_decoding_regex = args.vllm_guided_decoding_regex
            self._last_loaded_step = -1

            # Synchronize all processes after vLLM initialization
            self.accelerator.wait_for_everyone()
        else:
            generation_kwargs = {
                "max_new_tokens": self.max_completion_length,
                "do_sample": True,
                "pad_token_id": tokenizer.pad_token_id,
                "bos_token_id": tokenizer.bos_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "min_p": self.min_p,
                "repetition_penalty": self.repetition_penalty,
                "cache_implementation": args.cache_implementation,
            }
            if args.use_transformers_paged:
                generation_kwargs["max_batch_tokens"] = 512
                generation_kwargs["num_blocks"] = 1024
                generation_kwargs["block_size"] = 128
            if args.generation_kwargs is not None:
                generation_kwargs.update(args.generation_kwargs)
            self.generation_config = GenerationConfig(**generation_kwargs)

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags to the model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            elif self.is_fsdp_enabled:
                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                if self.is_deepspeed_enabled:
                    self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
                else:
                    # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
                    self.reward_funcs[i] = self.accelerator.prepare_model(
                        reward_func, evaluation_mode=True, device_placement=True
                    )

    @staticmethod
    def _hash_str(s: str):
        import hashlib
        return hashlib.md5(str(s).encode()).hexdigest()[:3]

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt", "image"]

    # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
    # Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), our dataloader loads a
    # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions
    # once every steps_per_generation step—rather than once per accumulation step—which is significantly more
    # efficient. The only change from the original implementation is multiplying the batch size by
    # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
    # splitting internally.
    # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
    # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line
    # apart from the super method, ensuring easier maintenance in the future.
    def get_train_dataloader(self):
        """Build the training dataloader, sized to one *generation* batch.

        The batch size handed to the `DataLoader` is
        `self._train_batch_size * self.args.steps_per_generation`, so a single batch covers every
        optimisation step that reuses one round of generation.

        Returns:
            The `DataLoader` after being passed through `self.accelerator.prepare`.

        Raises:
            ValueError: If `self.train_dataset` is `None`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        # `datasets.Dataset` objects get their unused columns dropped directly; for anything else the collator is
        # wrapped so that the columns are removed at collation time.
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size * self.args.steps_per_generation,  # < this is the change
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # Sampler, drop_last, worker seeding and prefetching are only configured for non-iterable datasets.
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            if version.parse(transformers.__version__) >= version.parse("4.52.0"):
                # from transformers 4.52.0, the `seed_worker` requires the `num_workers` and `rank` arguments
                dataloader_params["worker_init_fn"] = partial(
                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
                )
            else:
                dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

    def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
        # Builds the `RepeatSampler` used for training. The sampler has two jobs:
        # 1. Repeat every prompt `num_generations` times so that copies of the same prompt are spread across
        #    processes. Rewards can then be computed and normalized per prompt group. All processes share the same
        #    seed so that they agree on which prompt goes where and the groups stay consistent.
        # 2. Repeat every batch so that one round of generation can be reused for several updates. See
        #    `_prepare_inputs` for how the generations are buffered and sliced.

        # The figure below shows prompt indices. Each row is one sampled batch, in sampling order.
        #
        #                                      |   GPU 0  |   GPU 1  |
        #
        #                 global_step   step    <-───>  num_generations=2
        #                                       <-───────> per_device_train_batch_size=3
        #  grad_accum    ▲  ▲  0          0     0   0   1   1   2   2   <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
        #     =2         ▼  |  0          1     3   3   4   4   5   5   <- Take the stored generations and use the second slice to compute the loss
        #                   |
        #                   |  1          2     6   6   7   7   8   8   <- Take the stored generations and use the third slice to compute the loss
        #  steps_per_gen=4  ▼  1          3     9   9  10  10  11  11   <- Take the stored generations and use the fourth slice to compute the loss
        #
        #                      2          4    12  12  13  13  14  14   <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
        #                      2          5    15  15  16  16  17  17   <- Take the stored generations and use the second slice to compute the loss
        #                                          ...
        source = self.train_dataset if dataset is None else dataset
        # Number of *unique* prompts in one generation batch.
        unique_prompts_per_batch = self.args.generation_batch_size // self.num_generations
        # Every generation batch is served once per optimisation step that reuses it.
        batch_repeats = self.num_iterations * self.args.steps_per_generation
        return RepeatSampler(
            data_source=source,
            mini_repeat_count=self.num_generations,
            batch_size=unique_prompts_per_batch,
            repeat_count=batch_repeats,
            shuffle=self.shuffle_dataset,
            seed=self.args.seed,
        )

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        """Return the evaluation sampler (unshuffled `RepeatSampler`).

        Each prompt is repeated `self.args.eval_num_generations` times when the config defines that attribute, and
        `self.num_generations` times otherwise. See `_get_train_sampler` for the rationale behind repeating prompts.
        """
        repeats_per_prompt = getattr(self.args, "eval_num_generations", self.num_generations)
        sampler = RepeatSampler(
            data_source=eval_dataset,
            mini_repeat_count=repeats_per_prompt,
            seed=self.args.seed,
            shuffle=False,
        )
        return sampler

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
        """Turn on gradient checkpointing for `model` and return it.

        Args:
            model: The model to configure. For a PEFT model, checkpointing is enabled on `model.base_model`.
            args: Training configuration; only `gradient_checkpointing_kwargs` is read.

        Returns:
            The same `model` object, modified in place.
        """
        # The KV cache is switched off alongside gradient checkpointing.
        model.config.use_cache = False

        checkpointed_module = model.base_model if is_peft_model(model) else model
        checkpointed_module.gradient_checkpointing_enable()

        # Reentrant checkpointing is assumed unless the kwargs explicitly say otherwise; in that mode the inputs
        # are made to require grads.
        checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        if checkpointing_kwargs.get("use_reentrant", True):
            model.enable_input_require_grads()

        return model

    def _convert_to_left_padding(self, input_ids, attention_mask):
        """
        Convert both-side padded tensors to left-padded for ThinkTime model compatibility.
        
        Args:
            input_ids: [B, L] tensor with potential both-side padding
            attention_mask: [B, L] tensor marking valid tokens (1) vs padding (0)
            
        Returns:
            left_padded_input_ids: [B, L] tensor with left padding only
            left_padded_attention_mask: [B, L] tensor with left padding only
            right_shift_map: [B] tensor indicating how many positions each sequence was shifted
        """
        B, L = input_ids.shape
        device = input_ids.device
        
        # Find the valid token spans for each sequence
        left_padded_input_ids = torch.full_like(input_ids, self.pad_token_id)
        left_padded_attention_mask = torch.zeros_like(attention_mask)
        right_shift_map = torch.zeros(B, dtype=torch.long, device=device)
        
        for i in range(B):
            # Find first and last valid token positions
            valid_positions = torch.where(attention_mask[i] == 1)[0]
            if len(valid_positions) == 0:
                continue  # All padding, keep as is
                
            first_valid = valid_positions[0].item()
            last_valid = valid_positions[-1].item()
            valid_length = last_valid - first_valid + 1
            
            # Calculate how much we need to shift right to make it left-padded
            left_pad_positions = L - valid_length
            right_shift = left_pad_positions - first_valid
            right_shift_map[i] = right_shift
            
            # Copy valid tokens to left-padded positions
            valid_tokens = input_ids[i, first_valid:last_valid+1]
            left_padded_input_ids[i, left_pad_positions:] = valid_tokens
            left_padded_attention_mask[i, left_pad_positions:] = 1
            
        return left_padded_input_ids, left_padded_attention_mask, right_shift_map

    def _extract_valid_logits(self, logits, new_token_positions, right_shift_map):
        """
        Extract valid logits based on new_token_positions mapping, handling next token prediction offset.
        
        Args:
            logits: [B, L, V] logits from model (left-padded, with multimodal tokens)
            new_token_positions: [B, L] new_token_positions showing valid positions (not False for original tokens)
            
        Returns:
            aligned_logits: [B, C, V] logits aligned with original completion tokens
        """
        B, L, V = logits.shape
        
        # Initialize output logits
        aligned_logits = torch.zeros_like(logits)
        
        for i in range(B):
            # Find valid positions in labels (where original tokens are)
            valid_positions = new_token_positions[i, torch.where(new_token_positions[i] != -1)[0]]
            
            # logger.error(f"{valid_positions.shape=}, {new_token_positions.shape=}, {valid_positions[:10]=}")

            if len(valid_positions) == 0:
                continue

            # Handle next token prediction offset:
            # logits[t] predicts token at position t+1
            # So we need logits[valid_positions[:-1]] to predict tokens at valid_positions[1:]
            valid_positions = torch.cat([valid_positions[1:] - 1, valid_positions[-1:]], dim=0)  # Shift left by 1
            # Make sure the valid_positions is larger than 0 and less than L, if smaller than 0, then set to 0
            valid_positions = torch.clamp(valid_positions, min=0, max=L-1)

            # print(f"{right_shift_map=}, {i=}, {valid_positions.min()=}, {valid_positions.max()=}, {logits[i].shape=}")
            
            # Extract logits for positions that predict the valid tokens
            if len(valid_positions) > 1:
                # Extract the corresponding logits
                valid_logits = logits[i, valid_positions]
                # Move to the correct positions in the output according to right_shift_map
                shift = right_shift_map[i].item()
                # print(f"{i=}, {valid_logits.shape=}, {valid_positions.min()=}, {valid_positions.max()=}, {valid_positions.shape=}, {logits[i].shape=}, {shift=}, {L=}, {L - valid_logits.size(0) - shift=}, {L - shift=}, {aligned_logits.shape=}, {logits.shape=}")
                aligned_logits[i, L - valid_logits.size(0) - shift: L - shift] = valid_logits
                
        return aligned_logits

    @profiling_decorator
    def _get_last_hidden_state(
        self,
        unwrapped_model,
        input_ids,
        attention_mask,
        logits_to_keep,
        pixel_values=None,
        image_grid_thw=None,
        pixel_attention_mask=None,
        image_sizes=None,
        timeseries_data=None,
    ):
        """Run the backbone and return the hidden states that predict the last `logits_to_keep` tokens.

        The inputs are converted to left padding before the forward pass. The resulting hidden states are then
        mapped back to the original layout with `_extract_valid_logits`, the same way
        `_get_per_token_logps_and_entropies` treats its logits, so that the returned slice lines up with
        `input_ids[:, -logits_to_keep:]`.

        Args:
            unwrapped_model: The model to run; for a PEFT model the wrapped base model is used. Its `.model`
                attribute is called, and its output must expose `last_hidden_state` and `new_token_positions`.
            input_ids: [B, L] token ids (may be padded on both sides).
            attention_mask: [B, L] mask, 1 for valid tokens.
            logits_to_keep: Number of trailing positions to return.
            pixel_values, image_grid_thw, pixel_attention_mask, image_sizes: Optional vision inputs, forwarded
                to the model when provided.
            timeseries_data: Optional nested sequence (one list of series per sample, `None` entries skipped).
                All series are flattened, cast, and stacked into a single `timeseries` model input, which
                requires them to share a shape.

        Returns:
            Tensor with `logits_to_keep` positions along dim 1 (assuming the model output is longer than that)
            and the model's hidden size along the last dim.
        """
        if is_peft_model(unwrapped_model):
            unwrapped_model = unwrapped_model.base_model.model

        # Convert to left padding for ThinkTime compatibility
        left_input_ids, left_attention_mask, right_shift_map = self._convert_to_left_padding(input_ids, attention_mask)

        # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
        model_inputs = {"input_ids": left_input_ids, "attention_mask": left_attention_mask}
        dev = input_ids.device

        # For Qwen models:
        if image_grid_thw is not None and pixel_values is not None:
            model_inputs["image_grid_thw"] = image_grid_thw
        # For Gemma, SmolVLM2, LLaVa-Next etc.:
        if pixel_values is not None:
            model_inputs["pixel_values"] = pixel_values
        # For SmolVLM2
        if pixel_attention_mask is not None:
            model_inputs["pixel_attention_mask"] = pixel_attention_mask
        # For LLaVa-Next
        if image_sizes is not None:
            model_inputs["image_sizes"] = image_sizes

        # Add timeseries data for ThinkTime multimodal support
        if timeseries_data is not None:
            # Pick the dtype to cast the series to: the model's own dtype when it exposes one, otherwise the
            # dtype of its token embedding weights, otherwise the dtype of `input_ids`.
            target_dtype = input_ids.dtype if hasattr(input_ids, 'dtype') else torch.float32
            if hasattr(unwrapped_model, 'dtype'):
                target_dtype = unwrapped_model.dtype
            elif hasattr(unwrapped_model, 'model') and hasattr(unwrapped_model.model, 'embed_tokens'):
                target_dtype = unwrapped_model.model.embed_tokens.weight.dtype

            # Flatten the per-sample lists into one list of tensors on the right device/dtype.
            flat_ts = []
            for ts_list in timeseries_data:
                if ts_list is None:
                    continue
                for ts in ts_list:
                    if not isinstance(ts, torch.Tensor):
                        ts = torch.as_tensor(ts, device=dev, dtype=target_dtype)
                    else:
                        ts = ts.to(device=dev, dtype=target_dtype)
                    flat_ts.append(ts)
            if len(flat_ts) > 0:
                model_inputs["timeseries"] = torch.stack(flat_ts, dim=0)

        # Only add logits_to_keep if the model supports it
        if "logits_to_keep" in self.model_kwarg_keys:
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            model_inputs["logits_to_keep"] = logits_to_keep + 1

        # Get model outputs including hidden states and position mappings
        outputs = unwrapped_model.model(**model_inputs)

        # Map the hidden states back to the original (pre-left-padding) token layout. Slicing the raw output
        # instead would pick positions that are shifted by each row's trailing padding and therefore would not
        # correspond to the completion tokens.
        aligned_hidden_states = self._extract_valid_logits(
            outputs.last_hidden_state, outputs.new_token_positions, right_shift_map
        )

        # Drop the final position (it predicts a token beyond the sequence), then keep the trailing window.
        aligned_hidden_states = aligned_hidden_states[:, :-1, :]
        final_hidden_states = aligned_hidden_states[:, -logits_to_keep:, :]

        return final_hidden_states

    def _print_front_and_tail_zeros_num(self, x: torch.Tensor, name: str = "tensor"):
        """Print the number of leading and trailing zeros in a tensor."""
        if x.dim() == 0:
            return
        
        non_zero_idx = []
        for i in range(x.size(0)):
            if torch.any(x[i] != 0):
                non_zero_idx.append(i)

        if len(non_zero_idx) == 0:
            logger.warning(f"{name} has no non-zero elements.")
            return
        first_non_zero = non_zero_idx[0]
        last_non_zero = non_zero_idx[-1]
        leading_zeros = first_non_zero
        trailing_zeros = x.size(0) - last_non_zero - 1
        logger.warning(
            f"[DEBUG] {name} leading zeros: {leading_zeros}, trailing zeros: {trailing_zeros}, "
            f"total size: {x.size(0)}"
        )

    @profiling_decorator
    def _get_per_token_logps_and_entropies(
        self,
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        batch_size=None,
        compute_entropy=False,
        pixel_values=None,
        image_grid_thw=None,
        pixel_attention_mask=None,
        image_sizes=None,
        timeseries_data=None,
        completion_mask=None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """Compute log-probs and (optionally) entropies for each token."""
        batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
        all_logps = []
        all_entropies = []
        
        cur_rank = self.accelerator.process_index

        for start in range(0, input_ids.size(0), batch_size):
            input_ids_batch = input_ids[start : start + batch_size]
            attention_mask_batch = attention_mask[start : start + batch_size]
            completion_mask_batch = completion_mask[start : start + batch_size] if completion_mask is not None else None
            dev = input_ids_batch.device

            # Output the last logits_to_keep tokens of input_ids for debugging
            completion_ids = input_ids_batch[0, -logits_to_keep:][torch.where(completion_mask_batch[0, -logits_to_keep:])]
            completion_text = self.processing_class.tokenizer.decode(completion_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
           #  logger.warning(f"!!!!!!!!!! [INPUT_IDS_{cur_rank}] {input_ids_batch.shape=}, {attention_mask_batch.shape=}, {completion_ids.shape=}, {completion_text=}")

            # Convert to left padding for ThinkTime compatibility
            left_input_ids_batch, left_attention_mask_batch, right_shift_map = self._convert_to_left_padding(input_ids_batch, attention_mask_batch)

            # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
            model_inputs = {"input_ids": left_input_ids_batch, "attention_mask": left_attention_mask_batch}

            # Add labels for position tracking (handle label shift)
            # labels = left_attention_mask_batch.clone()
            # ignore_idx = self.model.config.ignore_index if hasattr(self.model.config, 'ignore_index') else -100
            # labels[left_attention_mask_batch == 0] = ignore_idx
            # # model_inputs["labels"] = labels

            # logger.warning(f"!!!!!!!!!! [LEFT_PAD_{cur_rank}] {left_input_ids_batch.shape=}, {left_attention_mask_batch.shape=}, {right_shift_map.shape=}")
            # # Print debug info for leading/trailing zeros
            # self._print_front_and_tail_zeros_num(left_input_ids_batch[0], f"left_input_ids_batch_{cur_rank}")
            # self._print_front_and_tail_zeros_num(left_attention_mask_batch[0], f"left_attention_mask_batch_{cur_rank}")
            # self._print_front_and_tail_zeros_num(attention_mask_batch[0], f"attention_mask_batch_{cur_rank}")
            # self._print_front_and_tail_zeros_num(labels[0] - ignore_idx, f"labels_{cur_rank}")

            if image_grid_thw is not None and pixel_values is not None:
                model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
            elif pixel_values is not None:
                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
            if pixel_attention_mask is not None:
                model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
            if image_sizes is not None:
                model_inputs["image_sizes"] = image_sizes[start : start + batch_size]

            # Add timeseries data for ThinkTime multimodal support
            if timeseries_data is not None:
                # timeseries_data slice is List[List[Tensor]]; flatten and move to dev with correct dtype
                slice_ts = timeseries_data[start : start + batch_size]
                flat_ts = []
                target_dtype = input_ids_batch.dtype if hasattr(input_ids_batch, 'dtype') else torch.float32
                # For embedding layers, typically use float32 or the model's default dtype
                if hasattr(model, 'dtype'):
                    target_dtype = model.dtype
                elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
                    target_dtype = model.model.embed_tokens.weight.dtype
                
                for ts_list in slice_ts:
                    if ts_list is None:
                        continue
                    for ts in ts_list:
                        if not isinstance(ts, torch.Tensor):
                            ts = torch.as_tensor(ts, device=dev, dtype=target_dtype)
                        else:
                            ts = ts.to(device=dev, dtype=target_dtype)
                        flat_ts.append(ts)
                if len(flat_ts) > 0:
                    model_inputs["timeseries"] = torch.stack(flat_ts, dim=0)

            # Only add logits_to_keep if the model supports it
            if "logits_to_keep" in self.model_kwarg_keys:
                # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
                model_inputs["logits_to_keep"] = logits_to_keep + 1

            # Get model outputs including logits and position mappings
            outputs = model(**model_inputs)
            logits = outputs.logits
            
            # Extract valid logits aligned with original completion tokens
            completion_ids = input_ids_batch[:, -logits_to_keep:]

            # self._print_front_and_tail_zeros_num(logits[0], f"nonrestored_logits_{cur_rank}")
            # self._print_front_and_tail_zeros_num(outputs.new_token_positions[0], f"new_token_positions_{cur_rank}")
            # self._print_front_and_tail_zeros_num(outputs.labels[0] - ignore_idx, f"outputs.labels_{cur_rank}")

            logits = self._extract_valid_logits(logits, outputs.new_token_positions, right_shift_map)
            
            # logger.warning(f"!!!!!!!!!! [RESTORE_RIGHT_PAD_{cur_rank}] {logits.shape=}, {input_ids_batch.shape=}, {outputs.new_token_positions.shape=}")
            # self._print_front_and_tail_zeros_num(logits[0], f"restored_logits_{cur_rank}")
            
            # =standard processing without multimodal alignment
            logits = logits[:, :-1, :]  # (B, L-1, V)
            logits = logits[:, -logits_to_keep:, :]  # (B, logits_to_keep, V)
            completion_ids = input_ids_batch[:, -logits_to_keep:]
            final_logits = logits

            # DEBUG: Decode logits[0] content, applying completion_mask for color
            if start == 0 and completion_mask is not None:
                active_logits = final_logits[0][torch.where(attention_mask_batch[0, -logits_to_keep:] == 1)]
                active_mask = completion_mask[0][torch.where(attention_mask_batch[0, -logits_to_keep:] == 1)]
                # self._debug_decode_logits(active_logits, f"final_logits_batch0_rank{cur_rank}", completion_mask=active_mask)
            elif start == 0:
                active_logits = final_logits[0][torch.where(attention_mask_batch[0, -logits_to_keep:] == 1)]
                # self._debug_decode_logits(active_logits, f"final_logits_batch0_rank{cur_rank}")
                            
            # Divide logits by sampling temperature and compute log probabilities
            final_logits = final_logits / self.temperature
            
            logps = selective_log_softmax(final_logits, completion_ids)  # compute logprobs
            all_logps.append(logps)

            if compute_entropy:
                with torch.no_grad():
                    entropies = entropy_from_logits(logits)
                all_entropies.append(entropies)

        logps = torch.cat(all_logps, dim=0)
        entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
        
        return logps, entropies

    def enable_multimodal_debug(self, enabled=True):
        """Toggle multimodal token position debugging and announce the new state on stdout.

        Args:
            enabled: Value stored in `self._debug_multimodal_positions`; its truthiness selects the printed message.
        """
        self._debug_multimodal_positions = enabled
        status_line = (
            "[DEBUG] Multimodal token position debugging enabled"
            if enabled
            else "[DEBUG] Multimodal token position debugging disabled"
        )
        print(status_line)

    def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
        extra_prefixes = extra_prefixes or []
        prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
        for prefix in prefixes:
            name = name.replace(prefix, "")
        return name

    def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
        """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
        # For FSDP1, we need to recurse into children and also use summon_full_params
        if visited is None:
            visited = set()
        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            self._sync_fsdp1_params_to_vllm(
                child_module, prefix=child_prefix, visited=visited
            )  # recurse into the child

        if isinstance(module, FSDP):
            with FSDP.summon_full_params(module, recurse=False, writeback=False):
                for param_name, param in module.named_parameters():
                    full_name = f"{prefix}.{param_name}" if prefix else param_name
                    full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])

                    if full_name in visited:
                        continue  # skip FSDP subtrees already traversed
                    visited.add(full_name)

                    if self.vllm_mode == "server" and self.accelerator.is_main_process:
                        self.vllm_client.update_named_param(full_name, param.data)
                    elif self.vllm_mode == "colocate":
                        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                        llm_model.load_weights([(full_name, param.data)])

    def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
        # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
        for name, param in module.state_dict().items():
            if param.is_cpu:
                param = param.to(torch.device("cuda"))
            param = param.full_tensor()

            if self.vllm_mode == "server" and self.accelerator.is_main_process:
                self.vllm_client.update_named_param(name, param)
            elif self.vllm_mode == "colocate":
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights([(name, param)])

    @profiling_decorator
    def _move_model_to_vllm(self):
        """Copy the current policy weights into the colocated vLLM engine and reset its prefix cache.

        Four paths are handled:

        - PEFT + FSDP: adapters are merged, the FSDP sync helper matching the plugin's `fsdp_version` is run,
          adapters are unmerged.
        - PEFT without FSDP: all parameters are gathered at once (a no-op unless DeepSpeed ZeRO-3 is active),
          adapters are merged, adapter-only and `original_module` parameters are skipped, the remaining ones are
          renamed to their base-model names and loaded, then adapters are unmerged.
        - non-PEFT + FSDP: the FSDP sync helper is run directly.
        - non-PEFT without FSDP: parameters are gathered (if ZeRO-3) and loaded one at a time.
        """
        # Under DeepSpeed ZeRO-3 parameters are sharded and must be gathered before being read.
        ds_plugin = self.accelerator.state.deepspeed_plugin
        if ds_plugin is not None and ds_plugin.zero_stage == 3:
            import deepspeed

            gather_if_zero3 = deepspeed.zero.GatheredParameters
        else:
            gather_if_zero3 = nullcontext

        def sync_fsdp_weights():
            # Dispatch on the FSDP major version reported by the plugin; FSDP1 is assumed when it is missing.
            plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
            fsdp_version = getattr(plugin, "fsdp_version", 1) if plugin else 1
            if fsdp_version == 1:
                self._sync_fsdp1_params_to_vllm(self.model)
            elif fsdp_version == 2:
                self._sync_fsdp2_params_to_vllm(self.model)

        def load_into_vllm(weight_name, tensor):
            # Only colocate mode is supported here: weights go straight into the in-process vLLM model.
            runner_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            runner_model.load_weights([(weight_name, tensor)])

        if not is_peft_model(self.model):
            if self.is_fsdp_enabled:
                sync_fsdp_weights()
            else:
                # Gather (when needed) and push each parameter individually.
                for raw_name, param in self.model.named_parameters():
                    vllm_name = self._fix_param_name_to_vllm(raw_name)
                    with gather_if_zero3([param]):
                        load_into_vllm(vllm_name, param.data)
        else:
            # Adapters cannot be merged shard by shard, so the whole model is gathered first.
            # With FSDP, `gather_if_zero3` is `nullcontext` and the FSDP helpers do their own gathering.
            with gather_if_zero3(list(self.model.parameters())):
                self.model.merge_adapter()

                if self.is_fsdp_enabled:
                    sync_fsdp_weights()
                else:
                    for raw_name, param in self.model.named_parameters():
                        # Recover the base-model parameter name from the PEFT-wrapped one.
                        vllm_name = raw_name.removeprefix("base_model.model.").replace(".base_layer", "")
                        # Skip adapter-only weights and the untouched copies kept for modules_to_save.
                        if self.model.prefix in vllm_name or "original_module" in vllm_name:
                            continue
                        vllm_name = self._fix_param_name_to_vllm(
                            vllm_name, extra_prefixes=["modules_to_save.default."]
                        )
                        load_into_vllm(vllm_name, param.data)

                # Restore the unmerged adapters before the gathering context exits.
                self.model.unmerge_adapter()

        # Reset cache on vLLM (only colocate mode supported)
        self.llm.reset_prefix_cache()

    @profiling_decorator
    def _prepare_inputs(
        self, generation_batch: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:

        mode = "train" if self.model.training else "eval"
        # print(f"[TSGRPOTrainer] {mode=}, step={self._step}, global_step={self.state.global_step}, {self.args.steps_per_generation=}, {self.num_iterations=}, {self.args.gradient_accumulation_steps=}")

        if mode == "train":
            generate_every = self.args.steps_per_generation * self.num_iterations
            if self._step % generate_every == 0 or self._buffered_inputs is None:
                # self._buffered_inputs=None can occur when resuming from a checkpoint
                generation_batch = self._generate_and_score_completions(generation_batch)
                generation_batch = split_pixel_values_by_grid(generation_batch)
                generation_batch = shuffle_sequence_dict(generation_batch)
                generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
                self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
            inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
            self._step += 1
        else:
            # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
            # local generation batch == local eval batch
            inputs = self._generate_and_score_completions(generation_batch)
        return inputs

    @profiling_decorator
    def _generate_with_vllm(self, prompts_text, images, timeseries_data, has_images, debug_context=""):
        """Generate one completion per local prompt with the colocated vLLM engine.

        When `vllm_tensor_parallel_size > 1` the work is distributed according to the mode:

        - train: every rank of the TP group gathers the prompts, timeseries and images of all its peers, the
          group generates for the union, and each rank keeps the slice that matches its own inputs.
        - eval: inputs are split across TP groups (the code assumes every process holds identical data), each
          group generates its share, and the group-leader results are all-gathered and concatenated.

        Inputs whose estimated length (tokenized prompt plus `len(series) // 8` per timeseries) exceeds 80% of
        `max_prompt_length + max_completion_length` are not sent to vLLM; they get an EOS-only stub completion.

        Args:
            prompts_text: List of prompt strings for this rank; any non-list value is treated as empty.
            images: Per-prompt images (entries may be `None`), only used when `has_images` is true.
            timeseries_data: Per-prompt timeseries (entries may be `None`); any non-list value is treated as empty.
            has_images: Whether `images` carries data for this rank.
            debug_context: Prefix inserted into log messages.

        Returns:
            A tuple `(completion_ids, prompts, timeseries)`: the completions combined by `pad` with
            `self.pad_token_id` (an empty `(0, 1)` long tensor when there is nothing to return), and the prompt and
            timeseries lists that correspond to this rank's completions.
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"
        tp_size = self.vllm_tensor_parallel_size

        # Anything that is not a list is treated as "no data" so the collectives below always exchange lists.
        prompts_text = prompts_text if isinstance(prompts_text, list) else []
        timeseries_data = timeseries_data if isinstance(timeseries_data, list) else []
        orig_size = len(prompts_text)

        # Only support colocate mode for ThinkTime
        if self.guided_decoding_regex:
            guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
        else:
            guided_decoding = None

        generation_kwargs = {
            "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature if mode == "train" else 0.1,
            "top_p": self.top_p,
            "top_k": -1 if self.top_k is None else self.top_k,
            "min_p": 0.0 if self.min_p is None else self.min_p,
            "max_tokens": 2048,  # fixed default; override through `args.generation_kwargs`
            "guided_decoding": guided_decoding,
        }
        if self.args.generation_kwargs is not None:
            generation_kwargs.update(self.args.generation_kwargs)
        sampling_params = SamplingParams(**generation_kwargs)

        # Handle TP distribution based on mode
        sizes_per_rank = None  # only populated on the train TP path
        if tp_size > 1:
            # One image slot per prompt. Ranks without images contribute `None` slots so that the gathered list
            # stays aligned with the gathered prompts and every rank runs the same sequence of collectives.
            local_images = images if (has_images and isinstance(images, list)) else [None] * orig_size

            if mode == "train":
                # Training: gather all inputs to every rank of the TP group.

                def gather_flat(local_items):
                    # All-gather a list across the TP group and flatten it in rank order.
                    buckets = [None for _ in range(tp_size)]
                    torch.distributed.all_gather_object(buckets, local_items, group=self.tp_group)
                    return [item for bucket in buckets if isinstance(bucket, list) for item in bucket]

                # Per-rank sizes are needed afterwards to slice this rank's share out of the flattened outputs.
                sizes_per_rank = [0] * tp_size
                torch.distributed.all_gather_object(sizes_per_rank, orig_size, group=self.tp_group)

                all_prompts_text = gather_flat(prompts_text)
                all_timeseries = gather_flat(timeseries_data)
                all_images = gather_flat(local_images)
            else:
                # Evaluation: split inputs across TP groups (each process has identical data).
                num_tp_groups = self.accelerator.num_processes // tp_size
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_group_id = self.accelerator.process_index // tp_size

                if orig_size > 0:
                    # Contiguous split; the first `remaining_inputs` groups take one extra input each.
                    inputs_per_group, remaining_inputs = divmod(orig_size, num_tp_groups)
                    start_idx = tp_group_id * inputs_per_group + min(tp_group_id, remaining_inputs)
                    end_idx = start_idx + inputs_per_group + (1 if tp_group_id < remaining_inputs else 0)

                    all_prompts_text = prompts_text[start_idx:end_idx]
                    all_timeseries = timeseries_data[start_idx:end_idx] if timeseries_data else []
                    all_images = local_images[start_idx:end_idx]

                    logger.info(f"[EVAL TP GROUP SPLIT] Group {tp_group_id}/{num_tp_groups}, Rank {local_rank_in_group}: processing inputs {start_idx}-{end_idx} ({len(all_prompts_text)} samples)")
                else:
                    all_prompts_text = []
                    all_timeseries = []
                    all_images = None
        else:
            all_prompts_text = prompts_text
            all_images = images if has_images else None
            all_timeseries = timeseries_data  # Always a list (empty when nothing was provided)

        # Enforce a maximum input length: 80% of (max_prompt_length + max_completion_length).
        safe_max_input_len = None
        if (self.max_prompt_length is not None) and (self.max_completion_length is not None):
            margin = 0.8
            total_max = int(self.max_prompt_length) + int(self.max_completion_length)
            safe_max_input_len = int(total_max * margin)

        tokenizer = getattr(self.processing_class, "tokenizer", None) or self.processing_class
        skip_flags = []
        if safe_max_input_len is not None:
            for p_idx, prompt in enumerate(all_prompts_text):
                try:
                    enc = tokenizer(prompt, add_special_tokens=False, return_attention_mask=False)
                    plen = len(enc["input_ids"])
                except Exception:
                    # Best-effort estimate: fall back to the character count when tokenization fails.
                    plen = len(prompt)

                # Calculate timeseries tokens (patch_size=8)
                timeseries_tokens = 0
                if all_timeseries and p_idx < len(all_timeseries) and all_timeseries[p_idx] is not None:
                    ts_data = all_timeseries[p_idx]
                    if isinstance(ts_data, list):
                        for ts_seq in ts_data:
                            if hasattr(ts_seq, "__len__"):
                                timeseries_tokens += len(ts_seq) // 8
                    elif hasattr(ts_data, "__len__"):
                        timeseries_tokens += len(ts_data) // 8

                skip_flags.append(plen + timeseries_tokens > safe_max_input_len)

            skipped_count = sum(skip_flags)
            if skipped_count > 0:
                logger.warning(
                    f"[{debug_context}Length Check {device}] {skipped_count}/{len(all_prompts_text)} inputs exceed max_input_len={safe_max_input_len}. "
                    "They will be skipped (no new generation) and aligned with stub outputs."
                )
        else:
            skip_flags = [False] * len(all_prompts_text)

        kept_indices = [i for i, skipped in enumerate(skip_flags) if not skipped]

        # Build inputs only for kept indices. Multimodal data is attached based on the (gathered) data itself
        # rather than on a rank-local flag, so every rank of a TP group builds identical inputs.
        vllm_inputs = []
        for i in kept_indices:
            prompt = all_prompts_text[i]
            multi_modal_data = {}

            if all_images and i < len(all_images) and all_images[i] is not None:
                multi_modal_data["image"] = all_images[i]
            if all_timeseries and i < len(all_timeseries) and all_timeseries[i] is not None:
                multi_modal_data["timeseries"] = all_timeseries[i]

            if multi_modal_data:
                vllm_inputs.append({"prompt": prompt, "multi_modal_data": multi_modal_data})
            else:
                vllm_inputs.append(prompt)

        if self.accelerator.is_main_process:
            logger.success(
                f"[{debug_context}vLLM Generate {device}] ==================================================================================\n Starting generation with {len(vllm_inputs)} inputs (kept {len(kept_indices)} / total {len(all_prompts_text)})"
            )

        # Always call llm.generate to maintain TP synchronization.
        # Even if vllm_inputs is empty, all processes must participate in the generation call.
        with profiling_context(self, "vLLM.generate"):
            all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

        # Reconstruct completion_ids aligned to the original order.
        completion_ids = [None] * len(all_prompts_text)
        for k_pos, i in enumerate(kept_indices):
            # Each output entry corresponds to exactly one kept input, with n=1
            out = all_outputs[k_pos]
            completion_ids[i] = out.outputs[0].token_ids if hasattr(out, "outputs") else []
        # Skipped inputs get an EOS-only stub to represent "no new generation".
        for i, skipped in enumerate(skip_flags):
            if skipped:
                completion_ids[i] = [self.eos_token_id]

        # Handle result collection based on mode
        if tp_size > 1:
            if mode == "train":
                # Training: keep this rank's share, located through the per-rank sizes gathered above.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                start_offset = sum(sizes_per_rank[:local_rank_in_group])
                tp_slice = slice(start_offset, start_offset + orig_size)
                completion_ids = completion_ids[tp_slice]
                all_prompts_text_local = all_prompts_text[tp_slice]
                all_timeseries_local = all_timeseries[tp_slice]
            else:
                # Evaluation: all_gather across all processes, then keep the leaders' results.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_group_id = self.accelerator.process_index // tp_size
                num_tp_groups = self.accelerator.num_processes // tp_size

                # Non-leader ranks contribute an empty list so each group's results are counted once.
                all_results = [None for _ in range(self.accelerator.num_processes)]
                torch.distributed.all_gather_object(all_results, completion_ids if local_rank_in_group == 0 else [])

                # Extract results from rank-0 of each TP group and flatten
                full_completion_ids = []
                for group_idx in range(num_tp_groups):
                    leader_rank = group_idx * tp_size
                    if leader_rank < len(all_results) and all_results[leader_rank]:
                        full_completion_ids.extend(all_results[leader_rank])

                completion_ids = full_completion_ids[:orig_size]
                all_prompts_text_local = prompts_text
                all_timeseries_local = timeseries_data

                logger.info(f"[EVAL GATHER] Group {tp_group_id}, Rank {local_rank_in_group}: reconstructed {len(completion_ids)} results for {orig_size} original inputs")
        else:
            all_prompts_text_local = all_prompts_text
            all_timeseries_local = all_timeseries

        # CRITICAL: Handle empty completion_ids case
        if len(completion_ids) == 0:
            empty_tensor = torch.empty((0, 1), dtype=torch.long, device=device)
            return empty_tensor, all_prompts_text_local or [], all_timeseries_local or []

        # SAFETY: Enforce output count equals local input count; pad with EOS-only stubs or truncate if needed.
        expected_count = orig_size
        if len(completion_ids) != expected_count:
            logger.error(
                f"[{debug_context}ALIGN FIX {device}] outputs={len(completion_ids)} != expected={expected_count}. Padding with stubs."
            )
            missing = expected_count - len(completion_ids)
            completion_ids = list(completion_ids[:expected_count]) + [[self.eos_token_id] for _ in range(missing)]

        # Convert to tensors; the dtype is pinned so that an empty completion does not become a float tensor.
        completion_ids = [torch.tensor(ids, dtype=torch.long, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id)

        return completion_ids, all_prompts_text_local, all_timeseries_local

    @profiling_decorator
    def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
        """Score every completion with every configured reward function.

        Reward functions that are `nn.Module`s are fed the tokenized prompt+completion text and their first logit
        is used as the reward. Any other callable is invoked with `prompts`, `completions`, `completion_ids` and
        the remaining dataset columns as keyword arguments; `None` rewards are stored as NaN. A callable whose
        name contains "rlvr" may return a tuple: element 0 holds the rewards and, for 3-tuples, element 2 holds
        a list of details that is collected.

        Args:
            inputs: Per-sample dicts. Every key except "prompt", "completion" and "completion_ids" is forwarded
                to custom reward functions as a list with one value per sample.
            prompts: Prompts, either strings or lists of chat messages.
            completions: Completions in the same format as `prompts`.
            completion_ids_list: Forwarded to custom reward functions as `completion_ids`.

        Returns:
            A tuple `(rewards_per_func, rlvr_details)`. `rewards_per_func` has one row per prompt and one column
            per reward function. In train mode both values are gathered across processes.
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        rlvr_details = []

        # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
        keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
        reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

        # This allows for dynamic reward shaping based on training progress.
        reward_kwargs["trainer_state"] = self.state

        for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
        ):
            with profiling_context(self, reward_func_name):
                if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                    if isinstance(prompts[0], list) and all(isinstance(msg, dict) and "role" in msg for msg in prompts[0]):
                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                        # Tools are only passed to the chat template when they are configured.
                        template_kwargs = {"tools": self.tools} if self.tools is not None else {}
                        texts = [
                            apply_chat_template(x, reward_processing_class, **template_kwargs)["text"]
                            for x in messages
                        ]
                    else:
                        texts = [p + c for p, c in zip(prompts, completions)]
                    reward_inputs = reward_processing_class(
                        text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                    )
                    reward_inputs = super()._prepare_inputs(reward_inputs)
                    with torch.inference_mode():
                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
                else:
                    reward_output = reward_func(
                        prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                    )

                    # Fetch rlvr details from output_reward
                    if "rlvr" in reward_func_name and isinstance(reward_output, tuple):
                        output_reward_func = reward_output[0]
                        logger.warning(f"[rlvr_acc] {output_reward_func=}")
                        if len(reward_output) == 3 and isinstance(reward_output[2], list):
                            rlvr_details.extend(reward_output[2])
                    else:
                        output_reward_func = reward_output

                    # Convert None values to NaN
                    output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]

                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # If all reward functions return None for a given row, issue a detailed warning
        all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)
        if all_nan_rows.any():
            nan_row_idx = int(all_nan_rows.nonzero(as_tuple=True)[0][0])
            # "trainer_state" holds a single state object rather than one value per row, so it cannot be indexed.
            row_reward_kwargs = {
                key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
            }
            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
            row_reward_kwargs["completion"] = completions[nan_row_idx]
            warnings.warn(
                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
                "Please ensure that at least one reward function returns a valid reward."
            )

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        if mode == "train":
            rewards_per_func = gather(rewards_per_func)

            # `gather_object` is a collective, so the decision to call it must be identical on every process.
            # Gate it on the configured reward function names, not on whether this process collected any details.
            if any("rlvr" in name for name in self.reward_func_names):
                rlvr_details = gather_object(rlvr_details)

        return rewards_per_func, rlvr_details

    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Generate completions for a local batch with vLLM, score them and compute group-relative advantages.

        The method renders the prompts with the chat template, optionally truncates them to `max_prompt_length`,
        generates completions through `_generate_with_vllm` (followed by the multi-turn tool loop when
        `enable_multi_turn_tools` is set), builds the completion and attention masks, computes the old/reference
        per-token log-probabilities (train mode only), calls `_calculate_rewards` and normalizes the weighted
        rewards per generation group. Metrics are appended to `self._metrics[mode]` and texts/rewards to
        `self._logs`.

        The prompt message lists stored in `inputs` are modified in place: string contents are converted to the
        multimodal list format, timeseries entries are appended to user messages, and a trailing assistant message
        (generation bootstrap) is popped when the completions are assembled.

        Args:
            inputs (`list[dict[str, Union[torch.Tensor, Any]]]`):
                Examples of the local batch. Each example must contain a `"prompt"` entry; the optional
                `"timeseries"`, `"image"` and `"solution"` entries are read when present.

        Returns:
            `dict[str, Union[torch.Tensor, Any]]`: always contains `"prompt_ids"`, `"prompt_mask"`,
            `"completion_ids"`, `"completion_mask"` and `"advantages"` (local slice). Depending on the configuration
            it also contains `"old_per_token_logps"`, `"ref_per_token_logps"`, the visual processor outputs
            (`"pixel_values"`, `"image_grid_thw"`, `"pixel_attention_mask"`, `"image_sizes"`) and, when the batch
            has timeseries, `"timeseries_data"`, `"all_timeseries_data"` and `"folded_all_processed_timeseries"`.

        Raises:
            NotImplementedError: If vLLM generation is disabled, or if the batch contains timeseries while
                `enable_multi_turn_tools` is off (the processed timeseries are only built by the multi-turn path).
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        # `prompts` holds references to the message lists inside `inputs`, so the in-place edits below are also
        # visible to the chat template, which is applied to `inputs`.
        prompts = [x["prompt"] for x in inputs]

        # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
        # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
        # VLM chat template.
        original_prompts = copy.deepcopy(prompts)

        # Extract timeseries data for ThinkTime multimodal support. `timeseries_data` keeps one entry per example
        # (None when absent) while `unfolded_timeseries_data` is the flat list handed to the processor.
        timeseries_data = []
        unfolded_timeseries_data = []
        has_timeseries = False
        for example in inputs:
            if "timeseries" in example:
                timeseries_data.append(example["timeseries"])
                unfolded_timeseries_data.extend(example["timeseries"])
                has_timeseries = True
            else:
                timeseries_data.append(None)

        # Handle both images and timeseries in multimodal content
        kwargs = {}
        has_images = "image" in inputs[0]
        if has_images:
            images = [example.get("image") for example in inputs]
            kwargs = {"images": [[img] for img in images]}
        else:
            images = None
            # NOTE(review): the {"type": "image"} placeholder is inserted when the batch has NO images, which
            # contradicts the comment above ("If images are present ...") — confirm this is intended.
            for prompt in prompts:
                if isinstance(prompt, list):
                    for message in prompt:
                        if not isinstance(message, dict):
                            continue
                        content = message.get("content")
                        role = message.get("role")
                        if isinstance(content, str):
                            if role == "user":
                                message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
                            elif role == "system":
                                message["content"] = [{"type": "text", "text": content}]

        # For ThinkTime: append the example's timeseries to the content of every user message
        if has_timeseries:
            for prompt, sample_timeseries in zip(prompts, timeseries_data):
                if sample_timeseries is None or not isinstance(prompt, list):
                    continue
                for message in prompt:
                    if not isinstance(message, dict) or message.get("role") != "user":
                        continue
                    content = message.get("content")
                    if isinstance(content, str):
                        # Convert string content to multimodal format before attaching the timeseries
                        content = [{"type": "text", "text": content}]
                        message["content"] = content
                    elif not isinstance(content, list):
                        continue
                    if isinstance(sample_timeseries, list):
                        content.extend({"timeseries": ts} for ts in sample_timeseries)

        template_kwargs = {"tools": self.tools} if self.tools is not None else {}
        prompts_text = [
            maybe_apply_chat_template(example, self.processing_class, **template_kwargs)["prompt"]
            for example in inputs
        ]

        # Preserve full prompts (pre-truncation) for logging, so eval logs show complete questions
        prompts_text_full = list(prompts_text)
        logger.warning(f"[_generate_and_score_completions] {mode=} {len(prompts_text)=}, {len(inputs)=}, {len(prompts_text_full)=}")

        if self.max_prompt_length is not None:
            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
            local_prompt_inputs = self.processing_class.tokenizer(
                text=prompts_text,
                return_tensors="pt",
                padding=True,
                padding_side="left"
            )
            local_prompt_ids, local_prompt_mask = local_prompt_inputs["input_ids"], local_prompt_inputs["attention_mask"]

            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
            protected = [token for token in protected if token is not None]
            local_prompt_ids, local_prompt_mask = truncate_with_protected_tokens(
                local_prompt_ids, local_prompt_mask, self.max_prompt_length, protected
            )

            prompts_text = self.processing_class.batch_decode(
                local_prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )
            prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]

            # The chat template inserts a single image token into the prompt text. However, when this text is later
            # tokenized, the single image token string is expanded into multiple image token IDs, depending on the
            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
            # collapse them back into a single token string to match the original template.
            if self.image_token is not None:
                prompts_text = [
                    re.sub(rf"({re.escape(self.image_token)})+", self.image_token, text) for text in prompts_text
                ]

        prompt_inputs = self.processing_class(
            text=prompts_text,
            timeseries=unfolded_timeseries_data if has_timeseries else None,
            return_tensors="pt",
            padding=True,
            # padding_side="left",
            # add_special_tokens=False,
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # State that only the multi-turn tool path fills in. Initialised here so that every later read is well
        # defined regardless of the path taken.
        used_custom_completion_mask = False
        attention_mask = None
        all_timeseries_data = None
        folded_all_processed_timeseries = None

        # Generate completions using either vLLM or regular generation
        if self.use_vllm:
            if has_timeseries and not self.enable_multi_turn_tools:
                # The processed timeseries consumed by the log-prob computation and returned in the output are only
                # built by the multi-turn path. Fail early with a clear message instead of hitting an
                # UnboundLocalError after generation has already run.
                raise NotImplementedError(
                    "Timeseries inputs are only supported together with `enable_multi_turn_tools=True`."
                )

            # First, update the vLLM weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            completion_ids, all_prompts_text_local, all_timeseries_local = self._generate_with_vllm(
                prompts_text, images, timeseries_data, has_images, debug_context="FIRST "
            )

            if self.enable_multi_turn_tools:
                # Input: Original prompts (unchanged)
                # Output: Complete multi-turn conversation as one completion
                # Loss: Only computed on assistant parts, tool responses masked
                if self.guided_decoding_regex:
                    guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
                else:
                    guided_decoding = None

                generation_kwargs = {
                    "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                    "repetition_penalty": self.repetition_penalty,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": -1 if self.top_k is None else self.top_k,
                    "min_p": 0.0 if self.min_p is None else self.min_p,
                    "max_tokens": 2048,
                    "guided_decoding": guided_decoding,
                }
                if self.args.generation_kwargs is not None:
                    generation_kwargs.update(self.args.generation_kwargs)
                sampling_params = SamplingParams(**generation_kwargs)

                # Execute multi-turn tool calling and build complete conversation
                complete_multi_turn_text, new_timeseries_data, all_multi_turn_text, all_timeseries_data = self._execute_multi_turn_tools_complete(
                    prompts, all_prompts_text_local, sampling_params, timeseries_data, unfolded_timeseries_data,
                    has_timeseries, completion_ids
                )

                # Re-tokenize the complete multi-turn conversation as completion
                completion_enc = self.processing_class(
                    text=complete_multi_turn_text,
                    timeseries=[ts for item in new_timeseries_data for ts in item] if new_timeseries_data else None,
                    return_tensors="pt",
                    padding=True,
                    add_special_tokens=False,
                    padding_side="right"
                )
                completion_ids = completion_enc["input_ids"].to(device)

                # Build custom completion mask from the token ids (see `_build_multi_turn_completion_mask`); it
                # replaces the EOS-based mask used in the single-turn path.
                completion_mask = self._build_multi_turn_completion_mask(completion_ids)

                # DEBUG: Show masked vs unmasked portions (main process only)
                if self.accelerator.is_main_process:
                    # Decode for debug purposes only
                    processed_completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
                    self._debug_completion_mask(completion_ids, completion_mask, processed_completions_text, max_samples=1)

                # We should also process all the timeseries with the complete text
                all_enc = self.processing_class(
                    text=all_multi_turn_text,
                    timeseries=[ts for item in all_timeseries_data for ts in item] if all_timeseries_data else None,
                    return_tensors="pt",
                    padding=True,
                    add_special_tokens=False,
                )
                all_processed_timeseries = all_enc.get("timeseries", None)
                # We also need a folded version for later use: one slice of the flat processed timeseries per
                # example, sized by the number of timeseries that example contributed.
                folded_all_processed_timeseries = []
                cur_idx = 0
                for item in all_timeseries_data:
                    folded_all_processed_timeseries.append(
                        all_processed_timeseries[cur_idx:cur_idx + len(item)]
                    )
                    cur_idx += len(item)

                used_custom_completion_mask = True
            else:
                # Regular single-turn case: convert completion_ids to tensors and pad them
                completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
                completion_ids = pad(completion_ids, padding_value=self.pad_token_id).to(device)

            # Build prompt_completion_ids for model input
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

            if used_custom_completion_mask:
                # Attention mask should include all tokens (assistant + tool responses)
                # Let the model see the full conversation context, only mask loss computation
                completion_attention = (completion_ids != self.pad_token_id).to(dtype=prompt_mask.dtype)
                attention_mask = torch.cat([prompt_mask, completion_attention], dim=1)

        elif self.use_transformers_paged:
            raise NotImplementedError("Paged generation is not supported in this trainer. Please use vLLM or regular generation paths.")
        else:
            raise NotImplementedError("Only vLLM generation is supported in this trainer.")

        if used_custom_completion_mask:
            # Keep the custom completion mask. `is_eos` is only consumed through `is_eos.any(dim=1)` below, so we
            # mark a single position in every row whose mask is non-empty to flag it as terminated.
            lengths = completion_mask.sum(dim=1).to(torch.long)
            is_eos = torch.zeros_like(completion_ids, dtype=torch.bool)
            valid = lengths > 0
            row_idx = torch.arange(is_eos.size(0), device=device)[valid]
            col_idx = lengths[valid] - 1
            is_eos[row_idx, col_idx] = True
        else:
            # Mask everything after the first EOS token
            is_eos = completion_ids == self.eos_token_id
            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
            eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
        # to re-tokenize completions if the reward is computed from tokens.
        completion_ids_list = [
            [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
        ]

        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
        completion_lengths = completion_mask.sum(1)

        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
        if self.mask_truncated_completions and not used_custom_completion_mask:
            truncated_completions = ~is_eos.any(dim=1)
            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
            logger.debug(f"[DEBUG Mask] Zeroed truncated completions: count={truncated_completions.sum().item()}")

        # Concatenate prompt_mask with completion_mask for logit computation, unless the multi-turn path already
        # built an attention mask covering the full conversation.
        if attention_mask is None:
            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

        # In eval mode, preserve all logic and logging but skip heavy per-token logprob computations
        old_per_token_logps = None
        ref_per_token_logps = None
        if mode == "train":
            # Arguments shared by every log-prob computation below
            logps_kwargs = {
                "batch_size": batch_size,
                "pixel_values": prompt_inputs.get("pixel_values"),
                "image_grid_thw": prompt_inputs.get("image_grid_thw"),
                "pixel_attention_mask": prompt_inputs.get("pixel_attention_mask"),
                "image_sizes": prompt_inputs.get("image_sizes"),
                "timeseries_data": folded_all_processed_timeseries if has_timeseries else None,
                "completion_mask": completion_mask,
            }
            with torch.no_grad():
                # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
                # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
                # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
                # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we keep
                # old_per_token_logps as None.
                generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
                if self.args.gradient_accumulation_steps % generate_every != 0:
                    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep, **logps_kwargs
                    )

                # Compute the per-token log probabilities for the reference model
                if self.beta != 0.0:
                    if self.ref_model is not None:
                        ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                            self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep, **logps_kwargs
                        )
                    else:
                        # No separate reference model: use the policy with its adapter disabled
                        with self.accelerator.unwrap_model(self.model).disable_adapter():
                            ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                                self.model, prompt_completion_ids, attention_mask, logits_to_keep, **logps_kwargs
                            )

        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if isinstance(prompts[0], list) and all(isinstance(msg, dict) and "role" in msg for msg in prompts[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                # A trailing assistant message acts as generation bootstrap: it is removed from the prompt and
                # prepended to the completion.
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                if self.args.add_first_think_token and not completion.strip().startswith("<think>"):
                    bootstrap = bootstrap + "<think>\n"
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text

        # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
        # important because rewards will be normalized per group, and completions are distributed. We will later slice
        # rewards_per_func to extract each process's subset.
        rewards_per_func, rlvr_details = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

        # Compute grouped-wise rewards
        group_size = self.num_generations if mode == "train" else getattr(self.args, "eval_num_generations", self.num_generations)
        mean_grouped_rewards = rewards.view(-1, group_size).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, group_size).std(dim=1)
        is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(group_size, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(group_size, dim=0)
        advantages = rewards - mean_grouped_rewards
        if self.scale_rewards:
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
        advantages = advantages[process_slice]

        # Log the metrics
        if mode == "train":
            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

        # Log completion lengths, mean, min, max
        agg_completion_lengths = self.accelerator.gather(completion_lengths)
        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

        # Identify sequences that terminated with EOS and log their lengths
        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
            term_completion_lengths = torch.zeros(1, device=device)
        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
        for i, reward_func_name in enumerate(self.reward_func_names):
            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)

        # Record the rlvr details
        if rlvr_details:
            rlvr_score_names = ["cate", "num", "reason", "other"]
            # Scores pooled per "<ability_type>_<name>" and "all_<name>" to compute the averaged metrics
            score_by_types = defaultdict(list)

            for rlvr_detail in rlvr_details:
                if not isinstance(rlvr_detail, dict):
                    # Missing/malformed detail (including None): keep the log columns aligned with placeholders
                    self._logs["rlvr"]["ability_type"].append("unknown")
                    for m in rlvr_score_names:
                        self._logs["rlvr"][m].append(None)
                    continue

                # Read both fields defensively so that a detail without "ability_type" or "scores" is logged as
                # unknown/None instead of raising a KeyError halfway through the batch.
                ability_type = rlvr_detail.get("ability_type", "unknown")
                scores = rlvr_detail.get("scores") or {}
                self._logs["rlvr"]["ability_type"].append(ability_type)

                for m in rlvr_score_names:
                    v = scores.get(m)
                    if v is None:
                        self._logs["rlvr"][m].append(None)
                        continue
                    score_by_types[f"{ability_type}_{m}"].extend(v)
                    score_by_types[f"all_{m}"].extend(v)
                    self._logs["rlvr"][m].append(float(np.mean(v)))

            for key, value in score_by_types.items():
                # Calculate the mean value for the non-NaN
                value = [v for v in value if v is not None and not np.isnan(v)]

                if len(value):
                    avg_rlvr = np.mean(value)
                    self._metrics[mode][f"rlvr/{key}"].append(float(avg_rlvr))

        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
        self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

        # Log prompt and completion texts
        # Use full prompts for logging to avoid incomplete questions in eval logs
        log_prompts = prompts_text_full
        if mode == "train":
            # In training, completions/prompts are sharded across ranks -> gather for proper logging
            self._logs["prompt"].extend(gather_object(log_prompts))
            self._logs["completion"].extend(gather_object(completions_text))
            self._logs["solution"].extend(gather_object([x.get("solution", "") for x in inputs]))
            self._logs["timeseries"].extend(gather_object(timeseries_data))
        else:
            # In eval, each process reconstructs full results; avoid duplicates by logging only on main process
            if self.accelerator.is_main_process:
                self._logs["prompt"].extend(log_prompts)
                self._logs["completion"].extend(completions_text)
                self._logs["solution"].extend([x.get("solution", "") for x in inputs])
                self._logs["timeseries"].extend(timeseries_data)
        for i, name in enumerate(self.reward_func_names):
            self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        self._logs["advantages"].extend(all_process_advantages.tolist())

        if has_images:
            self._logs["image"].extend(gather_object(images))

        logger.warning(f"[_generate_and_score_completions] {mode=}, {len(self._logs['prompt'])=}, {len(set(self._logs['prompt']))=}")

        # Logging and evaluation (debug output). The gathers are collective calls and must run on every process.
        prompts_to_log = gather_object(prompts_text)
        completions_to_log = gather_object(completions_text)
        rewards_to_log = rewards.tolist()
        timeseries_to_log = gather_object(timeseries_data)

        # Create rewards_per_func_dict for detailed reward logging
        rewards_per_func_dict = {
            reward_func_name: rewards_per_func[:, i].tolist()
            for i, reward_func_name in enumerate(self.reward_func_names)
        }

        # Gather self._printed_global_steps from all processes
        self._printed_global_steps = gather_object(sorted(self._printed_global_steps))
        self._printed_global_steps = set(self._printed_global_steps)

        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
            if self.accelerator.is_main_process:
                # Print to stdout (at most two samples, once per global step)
                if self.state.global_step not in self._printed_global_steps:
                    for i in range(min(len(rewards[:2]), len(prompts_to_log))):
                        print(f"====================Step {self.state.global_step} (Item {i})====================")
                        print(f"【Time】Local Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, Global Step: {self.state.global_step}, Local Step: {self._step}")
                        print(f"【Prompts】{prompts_to_log[i]}")
                        print(f"【Completions】{completions_to_log[i]}")

                        # Print timeseries shape if available
                        if timeseries_to_log[i] is not None:
                            if isinstance(timeseries_to_log[i], list):
                                timeseries_shapes = [ts.shape if hasattr(ts, 'shape') else f"len={len(ts)}" for ts in timeseries_to_log[i]]
                                print(f"【Timeseries Shapes】{timeseries_shapes}")
                            else:
                                ts_shape = timeseries_to_log[i].shape if hasattr(timeseries_to_log[i], 'shape') else f"type={type(timeseries_to_log[i])}"
                                print(f"【Timeseries Shape】{ts_shape}")
                        else:
                            print(f"【Timeseries Shape】None")

                        print(f"【Reward】 {rewards_to_log[i]}, {[(k, v[i]) for k, v in rewards_per_func_dict.items()]}")
                        if "rlvr" in self._logs:
                            print(f"【RLVR Details】 {[(k, v[i]) for k, v in self._logs['rlvr'].items() if len(v) > i]}")
                        print(f"=====================================================================")

                        self._printed_global_steps.add(self.state.global_step)

        output = {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "advantages": advantages,
        }
        if old_per_token_logps is not None:
            output["old_per_token_logps"] = old_per_token_logps
        if ref_per_token_logps is not None:
            output["ref_per_token_logps"] = ref_per_token_logps
        # Forward the visual processor outputs that are present
        for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"):
            if key in prompt_inputs:
                output[key] = prompt_inputs[key]
        if has_timeseries:
            output["timeseries_data"] = timeseries_data
            output["all_timeseries_data"] = all_timeseries_data
            output["folded_all_processed_timeseries"] = folded_all_processed_timeseries
        return output

    def compute_liger_loss(self, unwrapped_model, inputs):
        """
        Compute the GRPO loss through `self.liger_grpo_loss` and record its metrics.

        The prompt and completion tensors from `inputs` are concatenated, the last hidden state of the completion
        positions is obtained from `_get_last_hidden_state`, and the loss is computed against the `lm_head` of
        `unwrapped_model`. The gathered mean of the clip ratio (and of the KL term when `self.beta` is non-zero) is
        appended to `self._metrics` for the current mode.

        Args:
            unwrapped_model: Model exposing `lm_head.weight` and `lm_head.bias`.
            inputs: Mapping with `"prompt_ids"`, `"prompt_mask"`, `"completion_ids"`, `"completion_mask"` and
                `"advantages"`; the visual entries and the old/reference log-probabilities are optional.

        Returns:
            The loss returned by `self.liger_grpo_loss`.
        """
        completion_ids = inputs["completion_ids"]
        completion_mask = inputs["completion_mask"]

        # Full sequences fed to the model: prompt followed by completion
        full_ids = torch.cat([inputs["prompt_ids"], completion_ids], dim=1)
        full_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
        # Only the completion positions are needed for the loss
        num_completion_tokens = completion_ids.size(1)

        optional_visual_keys = ("pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes")
        hidden = self._get_last_hidden_state(
            unwrapped_model,
            full_ids,
            full_mask,
            num_completion_tokens,
            *[inputs.get(key) for key in optional_visual_keys],
        )

        lm_head = unwrapped_model.lm_head
        loss, metrics = self.liger_grpo_loss(
            _input=hidden,
            lin_weight=lm_head.weight,
            selected_token_ids=completion_ids,
            attention_mask=completion_mask,
            advantages=inputs["advantages"],
            bias=lm_head.bias,
            old_per_token_logps=inputs.get("old_per_token_logps"),
            ref_per_token_logps=inputs.get("ref_per_token_logps"),
        )

        # The KL term comes first in `metrics` when beta is non-zero; the clip ratio is always last
        uses_kl = self.beta != 0.0
        kl_metric = metrics[0] if uses_kl else None
        clip_metric = metrics[-1]

        mode_metrics = self._metrics["train" if self.model.training else "eval"]
        if uses_kl:
            mode_metrics["kl"].append(self.accelerator.gather(kl_metric).mean().item())
        mode_metrics["clip_ratio"].append(self.accelerator.gather(clip_metric).mean().item())
        return loss

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the GRPO training loss, or a zero placeholder when the model is in eval mode.

        Args:
            model: The (possibly wrapped) model being trained.
            inputs: Batch dict produced by the generation/scoring step.
            return_outputs: Must be falsy in training mode; returning outputs is not supported.
            num_items_in_batch: Accepted for `Trainer` compatibility and ignored.

        Returns:
            A scalar loss tensor. In eval mode this is a constant `0.0` placed on the device of
            `inputs["completion_ids"]` when available, otherwise on `self.accelerator.device`.

        Raises:
            ValueError: If `return_outputs` is truthy while the model is in training mode.
        """
        # In eval mode, skip loss computation entirely but keep generation, rewards, and logging elsewhere.
        # This check intentionally comes first, so `return_outputs` is not validated during evaluation.
        if not self.model.training:
            # Put the placeholder on the batch's device; fall back to the accelerator's device when the batch
            # carries no completion tensor.
            device = getattr(inputs.get("completion_ids"), "device", None)
            if device is None:
                device = self.accelerator.device
            return torch.tensor(0.0, device=device, requires_grad=False)
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        if self.use_liger_loss:
            # Compute the loss using the liger grpo loss
            unwrapped_model = self.accelerator.unwrap_model(model)
            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
        return self._compute_loss(model, inputs)

    def _compute_loss(self, model, inputs):
        """
        Compute the clipped-surrogate GRPO loss for one batch and append the associated metrics.

        Args:
            model: Model passed through to `self._get_per_token_logps_and_entropies`.
            inputs: Batch dict. Reads "prompt_ids", "prompt_mask", "completion_ids", "completion_mask" and
                "advantages"; reads "ref_per_token_logps" when `self.beta != 0.0`; optionally reads
                "old_per_token_logps", "pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"
                and "folded_all_processed_timeseries".

        Returns:
            The scalar loss, reduced according to `self.loss_type` ("grpo", "bnpo" or "dr_grpo").

        Raises:
            ValueError: If `self.importance_sampling_level` or `self.loss_type` has an unknown value.
        """
        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # The attention mask covers every non-pad completion token and is NOT derived from completion_mask;
        # completion_mask is only applied to the loss/metrics below.
        # NOTE(review): presumably so that tokens excluded from the loss (e.g. tool responses) still serve as
        # context for later tokens — confirm against the caller that builds completion_mask.
        completion_attention = (completion_ids != self.pad_token_id).to(dtype=prompt_mask.dtype)
        attention_mask = torch.cat([prompt_mask, completion_attention], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        # Compute the per_token_logps and the entropy at each position in the completion
        per_token_logps, entropies = self._get_per_token_logps_and_entropies(
            model,
            input_ids,
            attention_mask,
            logits_to_keep,
            compute_entropy=True,
            pixel_values=inputs.get("pixel_values"),
            image_grid_thw=inputs.get("image_grid_thw"),
            pixel_attention_mask=inputs.get("pixel_attention_mask"),
            image_sizes=inputs.get("image_sizes"),
            timeseries_data=inputs.get("folded_all_processed_timeseries"),
            completion_mask=completion_mask,
        )

        # Optionally restrict the policy term to high-entropy tokens
        if self.top_entropy_quantile < 1.0:
            entropy_mask = get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
        else:
            entropy_mask = None

        # Compute the KL divergence between the model and the reference model
        # (per_token_kl is only defined when beta != 0; every later use is guarded by the same condition)
        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        # Compute the loss
        advantages = inputs["advantages"]
        # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
        # old_per_token_logps == per_token_logps, so we can skip its computation
        # (see _generate_and_score_completions) and use per_token_logps.detach() instead.
        old_per_token_logps = inputs.get("old_per_token_logps")
        old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

        log_ratio = per_token_logps - old_per_token_logps
        if self.importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif self.importance_sampling_level == "sequence":
            # Masked mean of the per-token log ratio over each sequence, kept as a trailing singleton dim
            log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )
        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)

        # coef_1 is the unclipped importance ratio, coef_2 the ratio clipped to [1 - eps_low, 1 + eps_high]
        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

        # Two-sided clipping
        if self.args.delta is not None:
            coef_1 = torch.clamp(coef_1, max=self.args.delta)

        # Surrogate objective: negate the element-wise minimum of the unclipped and clipped terms
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if entropy_mask is not None:
            per_token_loss = per_token_loss * entropy_mask
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        # Reduce to a scalar: "grpo" averages per sequence then over the batch, "bnpo" averages over all
        # unmasked tokens, "dr_grpo" divides by batch size * max_completion_length.
        if self.loss_type == "grpo":
            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
        elif self.loss_type == "bnpo":
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
        elif self.loss_type == "dr_grpo":
            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        # Best-effort debug output, executed by every process (it is not gated on the main process): for each
        # sample containing at least one unmasked token whose loss exceeds `args.loss_debug_threshold`
        # (default 10.0), print the decoded completion with the offending tokens highlighted in red together
        # with their loss value. Any exception raised here is deliberately swallowed so that debugging output
        # can never interrupt training.
        # NOTE(review): with importance_sampling_level == "sequence", no entropy mask and beta == 0,
        # per_token_loss appears to stay (B, 1), so the boolean indexing below would raise and the whole
        # block would be skipped silently — confirm.
        try:
            threshold = getattr(self.args, "loss_debug_threshold", 10.0)
            B, T = completion_ids.shape
            for i in range(B):
                valid = completion_mask[i].bool()
                if valid.sum().item() == 0:
                    continue
                ids_i = completion_ids[i][valid].detach().cpu().tolist()
                losses_i = per_token_loss[i][valid].detach().float().cpu()
                abnormal = losses_i > threshold
                if not bool(abnormal.any()):
                    continue
                parts = []
                for tok_id, is_abn, lval in zip(ids_i, abnormal.tolist(), losses_i.tolist()):
                    tok_txt = self.processing_class.decode([int(tok_id)], skip_special_tokens=False)
                    if is_abn:
                        parts.append(f"\033[91m{tok_txt}\033[0m({lval:.2f})")
                    else:
                        parts.append(tok_txt)
                line = "".join(parts)
                print(f"[LOSS-DEBUG] step={self.state.global_step} sample={i} thr={threshold:.2f}")
                print(line)
        except Exception:
            pass

        # Log the metrics
        mode = "train" if self.model.training else "eval"

        completion_token_count = completion_mask.sum().clamp(min=1.0)

        def masked_batch_mean(x):
            # Mean over unmasked completion tokens, or a plain mean for per-sequence (B, 1) tensors
            if x.shape[1] == 1:  # when importance_sampling_level == "sequence"
                return x.mean()
            else:
                return (x * completion_mask).sum() / completion_token_count

        if self.beta != 0.0:
            mean_kl = masked_batch_mean(per_token_kl)
            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())

        mean_entropy = masked_batch_mean(entropies)
        self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

        # Compute the clipped probability ratios
        # (low: ratio below the lower bound with a negative advantage; high: ratio above the upper bound with
        # a positive advantage)
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = masked_batch_mean(is_low_clipped.float())
        high_clip = masked_batch_mean(is_high_clipped.float())
        clip_ratio = masked_batch_mean(is_region_clipped.float())

        # Gather the per-process values so the logged statistics cover all processes
        gathered_low_clip = self.accelerator.gather(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        """
        Run one evaluation step and return only the loss.

        The inputs are prepared with `self._prepare_inputs`, the loss is computed without gradient tracking,
        then averaged and detached. Logits and labels are never returned; `prediction_loss_only` and
        `ignore_keys` are accepted for interface compatibility and not used.

        Returns:
            `(loss, None, None)`.
        """
        prepared_inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                batch_loss = self.compute_loss(model, prepared_inputs)
            reduced_loss = batch_loss.mean().detach()
        return reduced_loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the accumulated GRPO metrics into `logs`, forward them to the parent logger and dump completions.

        The metrics recorded in `self._metrics[mode]` since the previous call are averaged, prefixed with
        "eval_" in eval mode, merged into `logs` and passed to `super().log`. On the main process, when
        `self.log_completions` is set, the buffered completions in `self._logs` are additionally uploaded as
        a wandb table (if wandb is an active reporter) and written to
        `<output_dir>/completions/<mode>/<global_step>.jsonl`. All buffers are cleared afterwards on every
        process.

        Args:
            logs: Values to log, as supplied by the `Trainer`.
            start_time: Forwarded unchanged to `super().log`.
        """
        mode = "train" if self.model.training else "eval"
        # Average the metrics; a metric without any recorded value is skipped instead of dividing by zero
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items() if val}

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        super().log(logs, start_time)
        self._metrics[mode].clear()

        if self.accelerator.is_main_process and self.log_completions:
            prompts = self._logs["prompt"]

            if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                import pandas as pd

                table = {
                    "step": [str(self.state.global_step)] * len(prompts),
                    "prompt": prompts,
                    "solution": self._logs["solution"],
                    "completion": self._logs["completion"],
                    **self._logs["rewards"],
                    "advantage": self._logs["advantages"],
                    **self._logs["rlvr"]
                }

                if self._logs["image"]:
                    # Convert images to wandb Image objects for proper visualization; keep None placeholders
                    table["image"] = [wandb.Image(img) if img is not None else None for img in self._logs["image"]]

                # Report column sizes so that a length mismatch can be diagnosed from the logs
                for k, v in table.items():
                    logger.warning(f"[log] {k}: {len(v)} items")
                    if len(v) != len(self._logs["prompt"]):
                        logger.error(f"[log] Mismatched lengths for key '{k}': expected {len(self._logs['prompt'])}, got {len(v)}: {v}")
                df = pd.DataFrame(table)
                if self.wandb_log_unique_prompts:
                    df = df.drop_duplicates(subset=["prompt"])
                wandb.log({"completions": wandb.Table(dataframe=df)})

            logger.warning(f"[log] {mode=}, {len(self._logs['prompt'])=}, {len(set(self._logs['prompt']))=}")

            # We also create a local file to inspect
            local_json_dir = os.path.join(self.args.output_dir, f"completions/{mode}")
            os.makedirs(local_json_dir, exist_ok=True)
            local_file_path = os.path.join(local_json_dir, f"{self.state.global_step}.jsonl")
            # Entries are dumped with ensure_ascii=False, so the file must be opened as UTF-8 explicitly;
            # the platform default encoding may be unable to represent the text.
            with open(local_file_path, "wt", encoding="utf-8") as f:
                for idx, prompt in enumerate(prompts):
                    # Build the entry with the same structure as wandb logging
                    entry = {
                        "step": str(self.state.global_step),
                        "mode": mode,
                        "prompt": prompt,
                        "completion": self._logs["completion"][idx],
                        "solution": self._logs["solution"][idx],
                    }

                    # Add rewards data (columns shorter than the prompt list are simply left out)
                    for reward_name, reward_values in self._logs["rewards"].items():
                        if idx < len(reward_values):
                            entry[f"reward_{reward_name}"] = reward_values[idx]

                    # Add advantage
                    if idx < len(self._logs["advantages"]):
                        entry["advantage"] = self._logs["advantages"][idx]

                    # Add RLVR details if available
                    for rlvr_key, rlvr_values in self._logs["rlvr"].items():
                        if idx < len(rlvr_values):
                            entry[f"rlvr_{rlvr_key}"] = rlvr_values[idx]

                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")

        # Reset every completion buffer on all processes, whether or not anything was written
        for buffer in self._logs.values():
            buffer.clear()

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        """
        Write the model card, then delegate to the parent checkpoint routine.

        The model name is the last path component of `args.output_dir` when no hub model id is configured,
        otherwise the part of `args.hub_model_id` after the last "/".
        """
        hub_id = self.args.hub_model_id
        card_name = Path(self.args.output_dir).name if hub_id is None else hub_id.split("/")[-1]
        self.create_model_card(model_name=card_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Write a draft model card (`README.md`) into `args.output_dir`.

        Only the world-zero process writes the card; every other process returns immediately.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. The trainer's own tag names are always added, and
                "unsloth" is added when the model config carries an `unsloth_version` attribute.
        """
        if not self.is_world_process_zero():
            return

        # Only report a base model when `_name_or_path` is set and is not a local directory
        config = self.model.config
        base_model = None
        if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
            base_model = config._name_or_path

        # Collect the tags into a fresh, mutable set
        if isinstance(tags, str):
            tag_set = {tags}
        elif tags is not None:
            tag_set = set(tags)
        else:
            tag_set = set()
        if hasattr(config, "unsloth_version"):
            tag_set.add("unsloth")
        tag_set.update(self._tag_names)

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year         = 2024,
                eprint       = {arXiv:2402.03300},
            }
            """
        )

        card_kwargs = dict(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tag_set,
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )
        model_card = generate_model_card(**card_kwargs)

        readme_path = os.path.join(self.args.output_dir, "README.md")
        model_card.save(readme_path)
        
    def _execute_multi_turn_tools_complete(
        self,
        prompts: List[Any],
        initial_prompts_text: List[str], 
        sampling_params,
        timeseries_data: List[Any],
        unfolded_timeseries_data: List[Any],
        has_timeseries: bool,
        first_completion_ids: List[List[int]],
    ) -> Tuple[List[str], List[List[Tuple[int, int]]], List[Any]]:
        """
    Build complete multi-turn conversation as single completion.
        
        Returns:
            complete_conversations: List[str] - Complete multi-turn text for each sample
            tool_response_masks: List[List[Tuple[int, int]]] - (start, end) positions of tool responses to mask
            updated_timeseries_data: List[Any] - Updated timeseries data with tool results
        """
        B = len(prompts)
        # logger.info(f"� [MULTI-TURN COMPLETE] Processing {B} samples")
        
        # Initialize conversations and track assistant generations
        conversations: List[List[Dict[str, Any]]] = []
        
        # Track accumulated timeseries per sample
        ts_accum: List[List[Any]] = [[] for _ in range(B)]
        
        # Extract initial timeseries and build conversations
        for i in range(B):
            # Build initial conversation from prompt
            if isinstance(prompts[i], list) and len(prompts[i]) > 0 and isinstance(prompts[i][0], dict) and "role" in prompts[i][0]:
                conv = copy.deepcopy(prompts[i])
            else:
                # Single text prompt, wrap as user message
                conv = [{"role": "user", "content": prompts[i]}]
            conversations.append(conv)
            
            # Extract timeseries from conversation and accumulate
            for msg in conv:
                content = msg.get("content", [])
                if isinstance(content, list):
                    for item in content:
                        if isinstance(item, dict) and "timeseries" in item:
                            ts_accum[i].append(item["timeseries"])
            
            # Add initial timeseries if not already in conversation
            if has_timeseries and timeseries_data and timeseries_data[i] is not None:
                if len(ts_accum[i]) == 0:  # No timeseries in conversation yet
                    if isinstance(timeseries_data[i], list):
                        ts_accum[i].extend(timeseries_data[i])
                    else:
                        ts_accum[i].append(timeseries_data[i])
        
        # Decode first completions
        first_completion_texts = {}
        for idx, ids in enumerate(first_completion_ids):
            if isinstance(ids, torch.Tensor):
                ids = ids.cpu().tolist()
            # Remove padding tokens
            clean_ids = [token_id for token_id in ids if token_id != self.pad_token_id]
            text = self.processing_class.decode(clean_ids, skip_special_tokens=True)
            first_completion_texts[idx] = text
        
        # Add first assistant response to conversations
        for i, text in first_completion_texts.items():
            if self.args.add_first_think_token and not text.strip().startswith("<think>\n"):
                text = "<think>\n" + text
            conversations[i].append({"role": "assistant", "content": text})
        
        # Multi-turn tool calling loop
        for turn in range(self.max_tool_calls):
            # Turn start debug removed
            
            # Parse tool calls from current assistant responses
            tool_results = []
            for i, text in first_completion_texts.items() if turn == 0 else current_assistant_texts.items():
                tool_spec = self.parse_tool(text)
                if tool_spec:
                    # Set conversation context for tool execution
                    self._set_current_conversation_context(conversations[i])
                    result = self.call_tool(tool_spec, text) or {"text": "", "timeseries": None}
                    
                    # Accumulate tool timeseries
                    if result.get("timeseries"):
                        ts_accum[i].extend(result["timeseries"])
                    
                    tool_results.append({"idx": i, "result": result, "tool_spec": tool_spec})
            
            # Synchronize across all TP processes before deciding to break
            # Each process reports how many tool calls it found
            local_tool_count = len(tool_results)
            if self.accelerator.num_processes > 1:
                # Gather tool counts from all processes in TP group
                gathered_tool_counts = [0] * self.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_tool_counts, local_tool_count)
                total_tool_count = sum(x if isinstance(x, int) else 0 for x in gathered_tool_counts)
            else:
                total_tool_count = local_tool_count
            
            # Only break if ALL processes have no tool calls
            if total_tool_count == 0:
                break
            
            # Add tool responses and generate next assistant responses
            next_assistant_texts = []
            batch_inputs = []
            batch_map = []
            
            for tr in tool_results:
                idx = tr["idx"]
                result = tr["result"]
                
                # Add tool response as user message
                tool_text = f"<tool_response>\n{result.get('text', '')}\n</tool_response>"
                tool_content = [{"type": "text", "text": tool_text}]
                for ts in (result.get("timeseries") or []):
                    tool_content.append({"timeseries": ts})
                
                conversations[idx].append({"role": "user", "content": tool_content})
                
                # Build prompt for next generation
                conv_wrapper = {"prompt": conversations[idx]}
                if self.tools is not None:
                    prompt_text = apply_chat_template(conv_wrapper, self.processing_class, tools=self.tools)["prompt"]
                else:
                    prompt_text = apply_chat_template(conv_wrapper, self.processing_class)["prompt"]
                
                if prompt_text.endswith("<think>\n"):
                    prompt_text = prompt_text[:-len("<think>\n")]

                sample = {"prompt": prompt_text}
                if ts_accum[idx]:
                    sample["multi_modal_data"] = {"timeseries": ts_accum[idx]}
                
                batch_inputs.append(sample)
                batch_map.append(idx)
            
            local_batch_size = len(batch_inputs)
            if self.accelerator.num_processes > 1:
                # Check if any process has inputs to avoid hanging
                gathered_batch_sizes = [0] * self.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_batch_sizes, local_batch_size)
                total_batch_size = sum(x if isinstance(x, int) else 0 for x in gathered_batch_sizes)
            else:
                total_batch_size = local_batch_size
            
            # Skip vLLM generation if no process has any inputs
            if total_batch_size == 0:
                logger.warning(f"[TURN {turn + 1}: {cur_process_id}] No batch inputs across all processes, skipping generation")
                break
            
            # Generate next assistant responses
            batch_prompts_text = [item["prompt"] for item in batch_inputs] if batch_inputs else []
            batch_timeseries = [item.get("multi_modal_data", {}).get("timeseries") for item in batch_inputs] if batch_inputs else []
            
            debug_suffix = "EMPTY " if local_batch_size == 0 else ""
            cur_process_id = self.accelerator.process_index if self.accelerator.num_processes > 1 else 0
            
            next_completion_ids, _, _ = self._generate_with_vllm(
                batch_prompts_text, 
                images=None,
                timeseries_data=batch_timeseries,
                has_images=False,
                debug_context=f"TURN {turn + 1} {debug_suffix}"
            )
            
            # Decode and add assistant responses
            current_assistant_texts = {}
            actual_batch_size = next_completion_ids.size(0)

            # In TP mode, vLLM applies TP slicing to outputs, but batch_map still refers to original indices
            # We need to verify that len(batch_map) == actual_batch_size to prevent index errors
            if len(batch_map) != actual_batch_size:
                logger.error(f"[TURN {turn + 1}] ALIGNMENT ERROR: batch_map length {len(batch_map)} != actual_batch_size {actual_batch_size}")
                logger.error(f"   batch_map={batch_map}, local_batch_size={local_batch_size}")
                # Truncate batch_map to match actual outputs to prevent crashes
                batch_map = batch_map[:actual_batch_size]
            
            # Only process outputs if we have them
            if actual_batch_size > 0:
                for k in range(actual_batch_size):
                    if k >= len(batch_map):
                        logger.error(f"[TURN {turn + 1}] Index out of range: k={k}, batch_map length={len(batch_map)}")
                        break
                    
                    idx = batch_map[k]
                    new_ids = next_completion_ids[k].cpu().tolist()
                    new_ids = [token_id for token_id in new_ids if token_id != self.pad_token_id]
                    new_text = self.processing_class.decode(new_ids, skip_special_tokens=True)
                    current_assistant_texts[idx] = new_text

                    # Print the debug items
                    conv_wrapper = {"prompt": conversations[idx]}
                    if self.tools is not None:
                        prompt_text = apply_chat_template(conv_wrapper, self.processing_class, tools=self.tools)["prompt"]
                    else:
                        prompt_text = apply_chat_template(conv_wrapper, self.processing_class)["prompt"]

                    # Add to conversation history
                    conversations[idx].append({"role": "assistant", "content": new_text})
        
        # Build complete conversation texts
        complete_texts: List[str] = []
        all_texts: List[str] = []


        for i in range(B):
            # Extract only assistant parts for the final completion text
            conv_wrapper = {"prompt": conversations[i]}
            input_conv = {"prompt": conversations[i][:2]}

            if self.tools is not None:
                prompt_text = apply_chat_template(conv_wrapper, self.processing_class, tools=self.tools)["prompt"]
                input_prompt_text = apply_chat_template(input_conv, self.processing_class, tools=self.tools)["prompt"]
            else:
                prompt_text = apply_chat_template(conv_wrapper, self.processing_class)["prompt"]
                input_prompt_text = apply_chat_template(input_conv, self.processing_class)["prompt"]

            if not prompt_text.startswith(input_prompt_text):
                raise ValueError(
                    f"Prompt text does not start with input prompt text: {prompt_text}... vs {input_prompt_text}... This should never happen!"
                )

            complete_texts.append(prompt_text[len(input_prompt_text):])
            all_texts.append(prompt_text)
        
        return complete_texts, [item[len(timeseries_data[i]) if timeseries_data[i] is not None else 0:] for i, item in enumerate(ts_accum)], all_texts, ts_accum

    def _build_multi_turn_completion_mask(
        self, 
        completion_ids: torch.Tensor
    ) -> torch.Tensor:
        """
        Build completion mask that excludes tool response tokens from loss computation.
        
        Uses tokenizer offsets to precisely identify tool response sections:
        - Identifies tokens belonging to tool response patterns  
        - Masks them out (set to 0) while keeping assistant tokens (set to 1)
        - Handles RIGHT PADDING correctly
        
        Args:
            completion_ids: Tokenized completion text [B, T] with RIGHT PADDING
            
        Returns:
            completion_mask: [B, T] with 1 for assistant tokens, 0 for tool responses/padding
        """
        import re
        
        B, T = completion_ids.shape
        device = completion_ids.device

        # Start with basic attention mask (1 for non-padding, 0 for padding)  
        completion_mask = (completion_ids != self.pad_token_id).float()

        # Tool response pattern to identify
        tool_response_pattern = re.compile(
            r'<\|im_start\|>user\s*<tool_response>.*?</tool_response>\s*<\|im_end\|>\s*<\|im_start\|>assistant\s*\n',
            re.DOTALL
        )

        for i in range(B):
            sample_tokens = completion_ids[i].tolist()
            
            # RIGHT PADDING: Find the last non-pad token index
            last_non_pad_idx = -1
            for j in range(T-1, -1, -1):
                if sample_tokens[j] != self.pad_token_id:
                    last_non_pad_idx = j
                    break
            
            if last_non_pad_idx == -1:
                # All tokens are padding
                continue
                
            # Extract non-padding tokens (from start to last valid token)
            non_pad_tokens = sample_tokens[:last_non_pad_idx + 1]
            if not non_pad_tokens:
                continue

            # Get text and offsets using tokenizer
            decoded_text = self.processing_class.tokenizer.decode(non_pad_tokens, skip_special_tokens=False)
            encoding = self.processing_class.tokenizer(
                text=[decoded_text],
                add_special_tokens=False,
                return_offsets_mapping=True,
                padding=False,
                return_tensors=None
            )
            
            # Extract offsets and verify consistency
            enc_ids = encoding["input_ids"][0] if isinstance(encoding["input_ids"], list) else encoding["input_ids"]
            offsets = encoding["offset_mapping"][0] if isinstance(encoding["offset_mapping"], list) else encoding["offset_mapping"]
            
            if hasattr(enc_ids, 'tolist'):
                enc_ids = enc_ids.tolist()
            if hasattr(offsets, 'tolist'):  
                offsets = offsets.tolist()
            
            # Skip if re-encoding doesn't match (fallback to basic mask)
            if len(enc_ids) != len(non_pad_tokens) or any(a != b for a, b in zip(enc_ids, non_pad_tokens)):
                logger.error(f"[MASKING] Sample {i} Token mismatch after re-encoding, skipping tool masking. Encoded IDs: {enc_ids}, Original IDs: {non_pad_tokens}")
                continue
            
            # Find tool response spans and mask overlapping tokens
            matched_cnt = 0
            for match in tool_response_pattern.finditer(decoded_text):
                span_start, span_end = match.span()
                
                # Find tokens that overlap with this span
                for token_idx, (char_start, char_end) in enumerate(offsets):
                    if char_start is not None and char_end is not None:
                        # Token overlaps with tool response span
                        if not (char_end <= span_start or char_start >= span_end):
                            mask_idx = token_idx
                            if mask_idx < T:
                                completion_mask[i, mask_idx] = 0.0
                matched_cnt += 1

        return completion_mask

    def _debug_decode_logits(self, logits_tensor, prefix="logits", completion_mask=None):
        """
        DEBUG helper: greedy-decode a logits tensor and emit the resulting text.

        The arg-max over the last dimension of ``logits_tensor`` gives one token
        id per position; those ids are decoded with ``self.processing_class``.

        Args:
            logits_tensor: Tensor of logits; the last dimension is reduced with
                ``argmax`` (presumably shaped ``(seq_len, vocab)`` - confirm with caller).
            prefix: Label inserted into the emitted message.
            completion_mask: Optional tensor iterated in lockstep with the
                predicted ids. When given, tokens are decoded one at a time:
                positions whose mask value equals 1 are emitted unchanged, all
                other positions are wrapped in ANSI red.
        """
        token_ids = logits_tensor.argmax(dim=-1).cpu().numpy()

        # Without a mask the whole sequence is decoded in a single call.
        if completion_mask is None:
            plain_text = self.processing_class.decode(token_ids)
            logger.success(f"[DEBUG_DECODE] {prefix}: '{plain_text}'")
            return

        mask_values = completion_mask.cpu().numpy()
        pieces = []
        for tok, keep_flag in zip(token_ids, mask_values):
            piece = self.processing_class.decode([tok])
            # Mask value 1 means the token is kept; anything else is shown in red.
            pieces.append(piece if keep_flag == 1 else f"\033[91m{piece}\033[0m")
        rendered = "".join(pieces)

        logger.success(f"---- [DEBUG_DECODE] {'='*80}")
        print(f"---- [DEBUG_DECODE] {prefix}: '{rendered}'")

    def _debug_completion_mask(
        self,
        completion_ids: torch.Tensor,
        completion_mask: torch.Tensor,
        completion_texts: List[str],
        max_samples: int = 10
    ):
        """
        DEBUG: Show masked vs unmasked portions of completion text for debugging.

        For each inspected sample the right padding is stripped, the remaining
        tokens are split into "kept" (mask value > 0.5) and "masked" (everything
        else), both groups are decoded with ``self.processing_class`` and the
        statistics plus truncated previews are logged.

        Args:
            completion_ids: Completion token ids, unpacked as [B, T]
            completion_mask: Completion mask [B, T] with 1 for kept tokens, 0 for masked
            completion_texts: Original completion text for reference; samples
                without an entry are reported as "[NO ORIGINAL TEXT]"
            max_samples: Maximum number of samples to debug (default 10)
        """
        batch_size, _ = completion_ids.shape
        num_samples = min(max_samples, batch_size)

        print(f"\n{'='*80}")
        print(f"[COMPLETION MASK DEBUG] Showing {num_samples}/{batch_size} samples")
        print(f"{'='*80}")

        for i in range(num_samples):
            print(f"\n--- Sample {i} ---")

            # Get tokens and mask for this sample
            sample_tokens = completion_ids[i].cpu().tolist()
            sample_mask = completion_mask[i].cpu().tolist()

            # Strip RIGHT padding only: keep everything up to the last non-pad
            # token. Counting non-pad tokens instead would cut off real tokens
            # whenever a pad id also occurs in the middle of the sequence.
            valid_length = 0
            for pos in range(len(sample_tokens) - 1, -1, -1):
                if sample_tokens[pos] != self.pad_token_id:
                    valid_length = pos + 1
                    break
            sample_tokens = sample_tokens[:valid_length]
            sample_mask = sample_mask[:valid_length]

            # Split into kept (contributes to the loss) and masked tokens
            kept_tokens = []
            masked_tokens = []
            for token, mask_val in zip(sample_tokens, sample_mask):
                if mask_val > 0.5:
                    kept_tokens.append(token)
                else:
                    masked_tokens.append(token)

            # Decoding is best-effort: a failure skips this sample but does not
            # abort the debug dump for the remaining ones.
            try:
                if kept_tokens:
                    kept_text = self.processing_class.decode(kept_tokens, skip_special_tokens=False)
                else:
                    kept_text = "[NO KEPT TOKENS]"

                if masked_tokens:
                    masked_text = self.processing_class.decode(masked_tokens, skip_special_tokens=False)
                else:
                    masked_text = "[NO MASKED TOKENS]"

                full_text = self.processing_class.decode(sample_tokens, skip_special_tokens=False)

            except Exception as e:
                logger.error(f"[DECODE ERROR] Sample {i}: {e}")
                continue

            # Statistics
            total_tokens = len(sample_tokens)
            kept_count = len(kept_tokens)
            masked_count = len(masked_tokens)
            mask_ratio = masked_count / total_tokens if total_tokens > 0 else 0.0

            logger.success(f"Stats: Total={total_tokens}, Kept={kept_count}, Masked={masked_count}, Ratio={mask_ratio:.1%}")

            # Show original text (truncated)
            original_text = completion_texts[i] if i < len(completion_texts) else "[NO ORIGINAL TEXT]"
            logger.success(f"Original ({len(original_text)} chars): {original_text[:200]}{'...' if len(original_text) > 200 else ''}")

            # Show full decoded text (truncated)
            logger.success(f"Full Decoded ({len(full_text)} chars): {full_text[:200]}{'...' if len(full_text) > 200 else ''}")

            # Show kept portions
            logger.success(f"KEPT (for loss): {kept_text[:300]}{'...' if len(kept_text) > 300 else ''}")

            # Show masked portions
            logger.success(f"MASKED (no loss): {masked_text[:300]}{'...' if len(masked_text) > 300 else ''}")

        # Closing separator. This must live inside the method: at class-body
        # indentation it would run once at class definition time instead.
        logger.success(f"{'='*80}\n")

    # Multi-turn tool calling user-implemented hooks based on vllm_tool_using.py logic
    def parse_tool(self, assistant_text: str) -> Optional[dict]:  # Based on vllm_tool_using.py
        """
        Extract the most recent tool call from an assistant message.

        The text is stripped of ``<|im_end|>`` markers and surrounding
        whitespace. A tool call is only recognised when the cleaned text ends
        with ``</tool_call>``; in that case the content of the last
        ``<tool_call>...</tool_call>`` block is parsed as JSON.

        Args:
            assistant_text: Raw assistant output to inspect.

        Returns:
            ``{"name": ..., "arguments": ...}`` built from the parsed JSON
            (``arguments`` defaults to ``{}``), or ``None`` when the text does
            not end with a tool call, no block is found, or anything fails
            while parsing.
        """
        try:
            stripped = assistant_text.replace('<|im_end|>', '').strip()

            # Only a tool call that terminates the message counts.
            if stripped.endswith('</tool_call>'):
                blocks = re.findall(r'<tool_call>(.*?)</tool_call>', stripped, re.MULTILINE | re.DOTALL)

                if blocks:
                    # The last block is the latest tool call; its body must be JSON.
                    payload = json.loads(blocks[-1].strip())
                    return {
                        "name": payload.get('name'),
                        "arguments": payload.get('arguments', {}),
                    }

                logger.error(f"[DEBUG Parse Tool] No tool call blocks found. This is very strange: {stripped}")

            return None

        except Exception as e:
            # Any failure (bad JSON, non-dict payload, non-string input) means "no tool call".
            print(f"[DEBUG Parse Tool] Parsing failed: {e}")
            return None

    def call_tool(self, tool_spec: dict, assistant_text: Optional[str] = None) -> dict:
        """
        Execute tool based on vllm_tool_using.py logic.

        Currently supports get_timeseries_slice and compare_timeseries_slice tools.

        Args:
            tool_spec: Parsed tool call with a ``"name"`` and an ``"arguments"``
                mapping. The mapping is copied before use, so the caller's
                dict is never modified.
            assistant_text: Assistant output that produced the tool call. It is
                forwarded to the tool implementation as the ``assistant_text``
                keyword argument (overriding any argument of the same name).

        Returns:
            The dict returned by the tool implementation. For an unknown tool
            name, or when anything raises while dispatching, a dict
            ``{"text": <error message>, "timeseries": None}`` is returned
            instead of propagating the exception.
        """
        function_name = tool_spec.get("name")
        raw_args = tool_spec.get("arguments", {})

        try:
            # An explicit JSON null is treated like "no arguments".
            if raw_args is None:
                raw_args = {}
            if not isinstance(raw_args, dict):
                raise TypeError(
                    f"tool arguments must be a JSON object, got {type(raw_args).__name__}"
                )

            # Work on a copy: injecting assistant_text directly would leak it
            # into the caller's tool_spec["arguments"].
            args = dict(raw_args)
            args["assistant_text"] = assistant_text

            if function_name == "get_timeseries_slice":
                return self._execute_get_timeseries_slice(**args)

            if function_name == "compare_timeseries_slice":
                return self._execute_compare_timeseries_slice(**args)

            error_msg = f"Error: Unknown tool function '{function_name}'"
            logger.error(f"[TOOL ERROR] Unknown function: {function_name}")
            return {
                "text": error_msg,
                "timeseries": None
            }

        except Exception as e:
            # Tool failures (including wrong/missing arguments) are reported
            # back as text rather than raised.
            error_msg = f"Error executing tool: {str(e)}"
            logger.error(f"[TOOL EXCEPTION] {function_name}: {str(e)}")
            return {
                "text": error_msg,
                "timeseries": None
            }
    
    def _execute_get_timeseries_slice(self, metric_name: str, start: int, end: int, assistant_text: str) -> dict:
        """
        Execute get_timeseries_slice tool based on vllm_tool_using.py logic.
        
        Searches through conversation history to find timeseries data for the specified metric.
        """
        
        try:
            target_timeseries = None
            
            # Access current conversation context if available
            conversation_messages = getattr(self, '_current_conversation_messages', [])
            
            # Print detailed context structure for debugging
            for ctx_idx, ctx_msg in enumerate(conversation_messages):
                role = ctx_msg.get("role", "unknown")
                content = ctx_msg.get("content", [])
                
                if isinstance(content, list):
                    for item_idx, item in enumerate(content):
                        if isinstance(item, dict):
                            if item.get("type") == "text":
                                text_preview = item.get("text", "")[:100] + "..." if len(item.get("text", "")) > 100 else item.get("text", "")
                            elif "timeseries" in item:
                                ts_len = len(item["timeseries"]) if isinstance(item["timeseries"], list) else "unknown"
                elif isinstance(content, str):
                    text_preview = content[:100] + "..." if len(content) > 100 else content
            
            # Search through conversation messages for the target metric
            first_user_message = None
            first_timeseries = None
            for msg_idx, message in enumerate(conversation_messages):
                if message.get("role") == "user":
                    content = message.get("content", [])
                    text_content = ""
                    ts_data = []
                    
                    # Extract text and timeseries from content - handle both dict and string formats
                    if isinstance(content, str):
                        # Simple string content - convert to proper format
                        text_content = content
                    elif isinstance(content, list):
                        # Multimodal content list
                        for item in content:
                            if isinstance(item, dict):
                                if item.get("type") == "text":
                                    text_content = item.get("text", "")
                                elif "timeseries" in item:
                                    ts_data.append(item["timeseries"])
                            elif isinstance(item, str):
                                text_content += item

                    if first_user_message is None:
                        first_user_message = text_content + ''

                    if first_timeseries is None and ts_data:
                        first_timeseries = ts_data[0]

                    # Search for the metric name in the text (case insensitive, partial match)
                    if text_content and metric_name.lower() in text_content.lower():
                        # Find which timeseries index corresponds to this metric
                        # Split by <ts><ts/> markers as in vllm_tool_using.py
                        lines = text_content.split('<ts><ts/>')
                        
                        for i, line in enumerate(lines):
                            if metric_name.lower() in line.lower():
                                # Simple heuristic - assume order in text matches order in timeseries data
                                ts_index = i
                                if ts_index < len(ts_data):
                                    target_timeseries = ts_data[ts_index]
                                    break
                                else:
                                    logger.warning(f"[TIMESERIES INDEX] Index {ts_index} out of range (have {len(ts_data)} timeseries)")
                        if target_timeseries is not None:
                            break
            
            # If no timeseries found in conversation, return error
            if target_timeseries is None:
                # Get hash_str for the conversation
                assistant_text_hash = self._hash_str(assistant_text) if assistant_text else "unknown"
                first_user_message_hash = self._hash_str(first_user_message) if first_user_message else "unknown"

                error_msg = f"Error: Metric '{metric_name}' not found in conversation context"
                logger.error(f"[TIMESERIES NOT FOUND] ({assistant_text_hash=}, {first_user_message_hash=}) {error_msg}, fall back to the first timeseries, context: {first_user_message}")

                if first_timeseries is not None:
                    target_timeseries = first_timeseries
                    logger.warning(f"[FALLBACK] Using first available timeseries for metric '{metric_name}'")
                    metric_name = f"the first available metric (the provided {metric_name} was not found, please check the provided metric name)"
                else:
                    return {
                        "text": error_msg,
                        "timeseries": None
                    }
            
            # Validate and correct bounds
            ts_length = len(target_timeseries)
            original_start, original_end = start, end
            start = max(0, min(start, ts_length - 1))
            end = max(start + 1, min(end, ts_length))
            
            if original_start != start or original_end != end:
                logger.warning(f"[BOUNDS ADJUSTED] From [{original_start}:{original_end}] to [{start}:{end}] for length {ts_length}")
            
            # Extract the slice
            slice_data = target_timeseries[start:end]
            
            result_text = f"The slice of {metric_name} from {start} to {end} is: <ts><ts/>."
            return {
                "text": result_text,
                "timeseries": [slice_data]  # Return as list for consistency with multimodal format
            }
            
        except Exception as e:
            error_msg = f"Error processing timeseries slice for {metric_name}: {str(e)}"
            logger.error(f"[TIMESERIES EXCEPTION] {error_msg}")
            return {
                "text": error_msg,
                "timeseries": None
            }

    def _execute_compare_timeseries_slice(
        self, 
        metric_name_1: str, 
        start_1: int, 
        end_1: int,
        metric_name_2: str, 
        start_2: int, 
        end_2: int,
        assistant_text: str
    ) -> dict:
        """
        Execute compare_timeseries_slice tool based on vllm_tool_using.py logic.
        
        Compares two timeseries slices from the conversation context.
        """
        
        try:
            # Access current conversation context
            conversation_messages = getattr(self, '_current_conversation_messages', [])
            
            # Helper function to find timeseries by metric name
            def find_timeseries_by_metric(metric_name: str):
                for msg_idx, message in enumerate(conversation_messages):
                    if message.get("role") == "user":
                        content = message.get("content", [])
                        text_content = ""
                        ts_data = []
                        
                        # Extract text and timeseries from content
                        if isinstance(content, str):
                            text_content = content
                        elif isinstance(content, list):
                            for item in content:
                                if isinstance(item, dict):
                                    if item.get("type") == "text":
                                        text_content = item.get("text", "")
                                    elif "timeseries" in item:
                                        ts_data.append(item["timeseries"])
                                elif isinstance(item, str):
                                    text_content += item

                        # Search for the metric name in the text (case insensitive, partial match)
                        if text_content and metric_name.lower() in text_content.lower():
                            # Find which timeseries index corresponds to this metric
                            lines = text_content.split('<ts><ts/>')
                            
                            for i, line in enumerate(lines):
                                if metric_name.lower() in line.lower():
                                    ts_index = i
                                    if ts_index < len(ts_data):
                                        return ts_data[ts_index]
                                    else:
                                        logger.warning(f"[COMPARE TIMESERIES] Index {ts_index} out of range for {metric_name}")
                return None

            # Get first timeseries as fallback
            first_timeseries = None
            for message in conversation_messages:
                if message.get("role") == "user":
                    content = message.get("content", [])
                    if isinstance(content, list):
                        for item in content:
                            if isinstance(item, dict) and "timeseries" in item:
                                first_timeseries = item["timeseries"]
                                break
                    if first_timeseries is not None:
                        break

            # Find both timeseries
            target_timeseries_1 = find_timeseries_by_metric(metric_name_1)
            target_timeseries_2 = find_timeseries_by_metric(metric_name_2)
            
            # Handle missing timeseries with fallback logic
            if target_timeseries_1 is None:
                logger.warning(f"[COMPARE FALLBACK] Metric '{metric_name_1}' not found, using first available timeseries")
                target_timeseries_1 = first_timeseries
                metric_name_1 = f"the first available metric (requested {metric_name_1} was not found)"
                
            if target_timeseries_2 is None:
                logger.warning(f"[COMPARE FALLBACK] Metric '{metric_name_2}' not found, using first available timeseries")
                target_timeseries_2 = first_timeseries
                metric_name_2 = f"the first available metric (requested {metric_name_2} was not found)"
            
            if target_timeseries_1 is None or target_timeseries_2 is None:
                error_msg = f"Error: Could not find timeseries data for comparison"
                logger.error(f"[COMPARE TIMESERIES] {error_msg}")
                return {
                    "text": error_msg,
                    "timeseries": None
                }
            
            # Validate and correct bounds for first timeseries
            ts_length_1 = len(target_timeseries_1)
            original_start_1, original_end_1 = start_1, end_1
            start_1 = max(0, min(start_1, ts_length_1 - 1))
            end_1 = max(start_1 + 1, min(end_1, ts_length_1))
            
            # Validate and correct bounds for second timeseries  
            ts_length_2 = len(target_timeseries_2)
            original_start_2, original_end_2 = start_2, end_2
            start_2 = max(0, min(start_2, ts_length_2 - 1))
            end_2 = max(start_2 + 1, min(end_2, ts_length_2))
            
            if (original_start_1 != start_1 or original_end_1 != end_1 or 
                original_start_2 != start_2 or original_end_2 != end_2):
                logger.warning(f"[COMPARE BOUNDS] Adjusted bounds for comparison")
            
            # Extract the slices
            slice_data_1 = target_timeseries_1[start_1:end_1]
            slice_data_2 = target_timeseries_2[start_2:end_2]
            
            result_text = (f"Comparison between {metric_name_1}[{start_1}:{end_1}] and "
                          f"{metric_name_2}[{start_2}:{end_2}]: <ts><ts/> vs <ts><ts/>")
            
            return {
                "text": result_text,
                "timeseries": [slice_data_1, slice_data_2]  # Return both slices for comparison
            }
            
        except Exception as e:
            error_msg = f"Error comparing timeseries slices: {str(e)}"
            logger.error(f"[COMPARE TIMESERIES EXCEPTION] {error_msg}")
            return {
                "text": error_msg,
                "timeseries": None
            }

    def _set_current_conversation_context(self, conversation_messages: List[Dict]):
        """
        Set the current conversation context for tool functions to reference.

        Stores ``conversation_messages`` on the instance as
        ``_current_conversation_messages`` (the same object, not a copy).

        Args:
            conversation_messages: Message dicts; each is read via ``"role"``
                and ``"content"`` keys, where content is either a string or a
                list of items.
        """
        self._current_conversation_messages = conversation_messages
        
        # Per-message summary (timeseries count and a text preview capped at
        # 200 characters).
        # NOTE(review): in the visible part of this method none of the values
        # computed in this loop are used afterwards - looks like leftover debug
        # scaffolding; confirm before removing.
        for ctx_idx, ctx_msg in enumerate(conversation_messages):
            role = ctx_msg.get("role", "unknown")
            content = ctx_msg.get("content", [])
            if isinstance(content, list):
                # Count items carrying a "timeseries" key and join all text items.
                ts_count = sum(1 for item in content if isinstance(item, dict) and "timeseries" in item)
                text_items = [item for item in content if isinstance(item, dict) and item.get("type") == "text"]
                text_content = " ".join(item.get("text", "") for item in text_items)[:200]
            elif isinstance(content, str):
                # Plain-string content: truncate with an ellipsis marker.
                text_preview = content[:200] + "..." if len(content) > 200 else content
