# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import re
from collections import defaultdict
from typing import List, Optional, Union
from datasets.features import Features, Sequence, Image, Value
from PIL import Image as PImage
import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask
import io 
# Module-level flag; it is not read anywhere in the visible part of this file
# (NOTE(review): presumably consulted elsewhere to force "no-think" prompts — confirm).
force_nothink: bool = False

# Module logger, named after this module.
logger = logging.getLogger(name=__name__)


class ShuffledAlternatingDataset(Dataset):
    """Interleave two equally sized datasets under one shared random permutation.

    Index ``2k`` returns ``dataset_a[perm[k]]`` and index ``2k + 1`` returns
    ``dataset_b[perm[k]]``. Paired samples therefore stay adjacent while the
    order of the pairs is shuffled. The combined length is ``2 * len(dataset_a)``.
    """

    def __init__(self, dataset_a, dataset_b, seed=None):
        """Build the interleaved view.

        Args:
            dataset_a: First source dataset, served at even indices.
            dataset_b: Second source dataset, served at odd indices. It must
                have the same length as ``dataset_a``.
            seed: Optional seed for the permutation, for reproducibility. A
                private ``np.random.RandomState`` is used, so the global NumPy
                RNG is not reseeded. The permutation is identical to the one
                produced by ``np.random.seed(seed); np.random.permutation(n)``.

        Raises:
            ValueError: If the two datasets differ in length.
        """
        if len(dataset_a) != len(dataset_b):
            raise ValueError(
                f"Both datasets must have the same length, got {len(dataset_a)} and {len(dataset_b)}"
            )
        self.dataset_a = dataset_a
        self.dataset_b = dataset_b

        num_pairs = len(dataset_a)
        if seed is not None:
            # Private generator: reproducible without clobbering global RNG state.
            permutation = np.random.RandomState(seed).permutation(num_pairs)
        else:
            permutation = np.random.permutation(num_pairs)
        self.shuffled_indices = permutation.tolist()
        self.length = num_pairs * 2  # total length after interleaving

    def __getitem__(self, idx):
        """Return sample ``idx``: even indices come from ``dataset_a``, odd from ``dataset_b``."""
        original_idx = self.shuffled_indices[idx // 2]
        source = self.dataset_a if idx % 2 == 0 else self.dataset_b
        return source[original_idx]

    def __len__(self):
        """Return the interleaved length (twice the length of each source dataset)."""
        return self.length
    

def collate_fn(data_list: list[dict]) -> dict:
    """Collate a list of per-sample dicts into a single batch dict.

    Tensor values are stacked along a new leading batch dimension. All other
    values are gathered into a 1-D ``dtype=object`` NumPy array with exactly
    one entry per sample.

    Args:
        data_list: Samples as produced by the dataset's ``__getitem__``.

    Returns:
        dict: Maps each key to a stacked ``torch.Tensor`` or to a 1-D object
        array of length ``len(data_list)``. If a key holds tensors in some
        samples and non-tensors in others, the object array wins, matching the
        previous ``{**tensors, **non_tensors}`` merge order.
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    batch = {key: torch.stack(vals, dim=0) for key, vals in tensors.items()}

    for key, vals in non_tensors.items():
        # np.array(vals, dtype=object) would turn equal-length lists (e.g. token-id
        # lists) into a 2-D array. Filling a preallocated 1-D object array element
        # by element keeps one opaque entry per sample.
        column = np.empty(len(vals), dtype=object)
        for i, val in enumerate(vals):
            column[i] = val
        batch[key] = column

    return batch

# Prompt templates keyed by name. Entries containing ``{prompt}`` are filled with
# ``str.format(prompt=...)``; "nothink_assistant" and "assistant" use a positional
# ``{}`` field instead. "template_explore" is a suffix appended to the last user
# turn of exploration samples.
templates = {
    "zero":"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer (with \\boxed) here </answer>. User: {prompt}\nAssistant: Let me solve this step by step.\n<think>\n",
    "zero_detail_":"A conversation between User and Assistant. The user asks a question, and the assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed within <think> </think> tags and <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer (with \\boxed) here </answer>. Normally, this thinking process first analyzes what the problem is asking and what knowledge it attempts to examine, and then elaborate on the detailed reasoning steps to solve the problem. The assistant could regularly self-question and self-verify existing reasoning to ensure the correctness.  User: {prompt}\nAssistant: Let me solve this step by step with detailed reasoning and put the final answer within \\boxed.\n<think>\n",
    "zero_detail": "You will be given a question. You need to first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed within <think> </think> tags and <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer (with \\boxed) here </answer>. Normally, this thinking process first analyzes what the problem is asking and what knowledge it attempts to examine, and then elaborate on the detailed reasoning steps to solve the problem. You should regularly revisit and double-check the problem and its prior reasoning.\nQuestion: {prompt}",
    # Unused draft of a more detailed document-reading guideline, kept for reference.
    # Write your thought process according to the following guideline:
    # (Getting the Lay of the Land): First, I'll get a thorough overview of the document itself. I need to scan the title, abstract, introduction, and section headings. This gives me a mental map of the document's structure and its core topic. What's the main argument or purpose here? This initial scan prevents me from getting lost in the details later on.

    # (Defining My Mission): Next, I'll break down the user's question into its fundamental parts. I need to identify the key concepts, entities, and the specific type of information being sought. Is it a "what," "why," "how," or a "compare and contrast" question? Hmm, I should also anticipate synonyms and related terms for my keywords, as the document might not use the exact same phrasing. This step is crucial because it defines the scope and direction of my search.

    # (The Deep Dive): With a clear mission, I'll start methodically gathering evidence from all parts of the document. This isn't a linear read-through; it's a targeted investigation. I will keep in mind the keywords and the key question I aim to address, and then:

    # - Inspecting the Text: I'll start by searching the text for my keywords. But wait, I can't just pull out the sentence with the keyword. That's a rookie mistake. I must read the surrounding sentences and the entire paragraph to understand the full context. The answer might be nuanced, qualified, or pieced together from different sections. I need to be alert for transition words like "however," "in contrast," or "therefore," as they often signal important relationships or conclusions.

    # - Analyzing Tables: When I encounter a table, I need to pause and analyze it carefully.
    # First, what does the title or caption tell me?
    # Then, I'll examine the row and column headers to understand what data is being presented. Let me think... I absolutely have to check for units (e.g., millions, %, kg) and any footnotes, as these often contain critical information that could completely change the interpretation of the numbers.
    # Only then will I locate the specific cells relevant to the question.

    # - Interpreting Figures: For figures like graphs, charts, or diagrams, I must resist the urge to jump to a conclusion based on a quick glance.
    # I need to methodically check the title, the legend (what do the colors or symbols mean?), and, most importantly, the axes.
    # What is being measured on the x-axis and y-axis?
    # What are the scales and units?
    # Understanding this framework is essential before I can accurately interpret the trends, patterns, or relationships shown by the data points, lines, or bars.

    # Synthesizing and Cross-Referencing (Connecting the Dots): The answer is rarely in just one place. The real skill is in synthesis. I need to actively seek connections between the text, tables, and figures. For example, the text might explain an anomaly seen in a graph, a table might provide the raw data that a pie chart visualizes, or a figure might illustrate a concept described in the text. I'll be like a detective, piecing together these different forms of evidence to construct a single, coherent answer.

    # Reflection and Verification (The Final Sanity Check): Before I formulate the final response, I must perform a critical self-review. Let me re-read the user's question one last time. Does my synthesized answer directly and completely address it? Have I overlooked any conflicting data or counterarguments presented elsewhere in the document? Let me quickly double-check the numbers, units, and axis labels I relied on. This final verification step is my quality control; it ensures my reasoning is sound and the answer is accurate and well-supported by the provided document.

    # NOTE: the continuation lines of this triple-quoted string keep their 4-space
    # indentation and trailing newline as part of the prompt text.
    "doc_detail": """You will be given document pages and a question. You need to first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed within <think> </think> tags and <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer (with \\boxed) here </answer>. 
    Normally, this thinking process first get a thorough overview of the document pages (get the lay of the land). 
    Then the thinking process analyzes what the problem is asking and collect evidences from the document.
    After that the thinking process might synthesize and cross-reference the information to reason about the final answer.
    You should regularly revisit and double-check the problem and its prior reasoning.\nQuestion: {prompt}
    """,
    "zero_template":"A conversation between User and Assistant. The user asks a question, and the assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e., <think> reasoning process here </think>. Normally, this thinking process first analyzes what the problem is asking and what knowledge it attempts to examine, and then elaborate on the detailed reasoning steps to solve the problem. The assistant could regularly self-question and self-verify existing reasoning to ensure the correctness.  User: {prompt}",
    "explore": "A conversation between User and Assistant. The user asks a question, and the Assistant explores the solution of it. The assistant first thinks creatively about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {prompt}.\nI think you are a brave and curious explorer. I already have the standard solution in mind. Think bravely and creatively and explore any possible solutions and alternatives you can think of.\nAssistant: Let me think bravely and creatively.\n<think>\n",
    "user":"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer (with \\boxed) here </answer>. Question:\n{prompt}",
    # Fixed typo: the answer body previously started with a literal "/n" instead of a newline.
    "nothink_assistant": "<think>\nI have finished thinking.\n</think>\n<answer>\n{}</answer>",
    "assistant": "{}",
    "default": "Question:\n{prompt}\nLet's think step by step and put your final answer within \\boxed",
    "template_exploit":"",
    "template_explore":"\nI believe you are a brave and curious explorer. I already have the standard solution in mind for the given problem. So you need to think bravely and creatively and explore any possible solutions and alternatives you can think of."
}

    
def bytes_to_pillow_image(image_bytes):
    """Decode in-memory encoded image data into a Pillow image.

    Args:
        image_bytes: Raw encoded image bytes.

    Returns:
        PIL.Image.Image: The image opened from an in-memory ``io.BytesIO`` buffer.
    """
    buffer = io.BytesIO(image_bytes)
    opened = PImage.open(buffer)
    return opened

def _open_image_source(img):
    """Return ``img`` as a ``PImage.Image``, opening it from a path, dict meta, or bytes.

    Accepted inputs:
    - a ``PImage.Image``, returned unchanged;
    - a filesystem path string;
    - a dict whose first non-None entry among ``"path"``, ``"image"`` (a path,
      optionally prefixed with ``file://``) and ``"bytes"`` is used, checked in
      that order.

    Raises:
        ValueError: If the input type or dict layout is not recognised.
        FileNotFoundError: Propagated from ``PImage.open`` for missing paths.
    """
    if isinstance(img, PImage.Image):
        return img
    if isinstance(img, str):
        return PImage.open(img)
    if isinstance(img, dict):
        if img.get("path") is not None:
            return PImage.open(img["path"])
        if img.get("image") is not None:
            # The 'file://' URI prefix is stripped for backward compatibility.
            return PImage.open(img["image"].removeprefix("file://"))
        if img.get("bytes") is not None:
            return bytes_to_pillow_image(img["bytes"])
    raise ValueError(f"cannot handle input image meta: {img}")


def resize_image_if_small(img, threshold=28):
    """Upscale an image whose shorter side is below ``threshold``, keeping the aspect ratio.

    Args:
        img: A ``PImage.Image``, a path string, or an image-meta dict with a
            ``"path"``, ``"image"`` or ``"bytes"`` entry.
        threshold (int): Minimum size, in pixels, allowed for either dimension.

    Returns:
        PIL.Image.Image | None: If either side is below ``threshold``, a resized
        copy in which the shorter side is exactly ``threshold``. Otherwise the
        loaded image itself. ``None`` if the referenced image file does not exist.

    Raises:
        ValueError: If ``img`` cannot be interpreted, or the image has a zero
            dimension and so cannot be upscaled proportionally.
    """
    try:
        img = _open_image_source(img)
    except FileNotFoundError:
        # Best-effort: callers receive None instead of an exception (original contract).
        logger.error("Image file not found: %s", img)
        return None

    width, height = img.size
    if width >= threshold and height >= threshold:
        return img
    if width <= 0 or height <= 0:
        raise ValueError(f"cannot resize image with a zero dimension ({width}x{height})")

    # Scale the shorter side up to exactly `threshold`. Integer arithmetic avoids
    # float rounding that could leave a side one pixel short.
    if width < height:
        new_width, new_height = threshold, height * threshold // width
    else:
        new_width, new_height = width * threshold // height, threshold

    logger.info(
        "Image dimensions (%dx%d) are below the threshold of %dx%d; resizing to %dx%d",
        width, height, threshold, threshold, new_width, new_height,
    )
    # Image.resize returns a new image; the source image is left untouched.
    return img.resize((new_width, new_height))
    
class RLHFDataset(Dataset):
    """
    We assume the dataset contains a column that contains prompts and other information
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
        explore=False,
        do_sft=False,
        sample_rate: Optional[float] = 1.0,
        **kwargs
    ):
        """Configure the dataset from ``config`` and eagerly load ``data_files``.

        Args:
            data_files: One parquet path or a list of them. Each path is
                localized via ``_download`` before loading.
            tokenizer: Tokenizer used for tokenization, e.g. solution-length
                filtering for SFT.
            config: Dataset config. Every option is read with ``config.get``
                and has a default.
            processor: Optional multimodal processor. When set, the vision code
                paths (image-count filtering, pixel limits) are used.
            explore: If True, an exploration copy of every prompt is added and
                the dataset length doubles.
            do_sft: Whether the data is used for SFT. Enables solution-length
                filtering.
            sample_rate: Fraction of each file to keep. ``None`` is treated as 1.0.
            **kwargs: Only ``do_eval`` (default False) is read.
        """
        if not isinstance(data_files, (list, ListConfig)):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # used for resume
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.sft_template = config.get("sft_template", "default")
        self.zero = config.get("zero", True)
        self.doc = config.get("doc", False)
        self.llama_instruct_zero = config.get("llama_ins_zero", False)
        self.do_sft = do_sft
        self.do_eval = kwargs.get("do_eval", False)
        self.explore = explore
        self.min_pixels = config.get("min_pixels", 128*128)
        self.max_pixels = config.get("max_pixels", 1024*1024)

        if self.processor is not None:
            print(f"VL image pixels: min={self.min_pixels}, max={self.max_pixels}")

        # sample_rate is Optional; treat None as "keep everything" rather than
        # failing later on `None < 1.0`.
        self.sample_rate = 1.0 if sample_rate is None else sample_rate

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.solution_key = config.get("solution_key", "solution")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.max_response_length = config.get("max_response_length", 4096)
        self.return_raw_chat = config.get("return_raw_chat", True)
        self.return_full_prompt = config.get("return_full_prompt", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_count = os.cpu_count() or 1
        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, cpu_count // 4))
        self.num_workers = min(self.num_workers, cpu_count)
        self.chat_template_func = config.get("chat_template_func", None)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)
        self.filter_prompts = config.get("filter_prompts", True)
        self.serialize_dataset = False
        self._download()
        self._read_files_and_tokenize()
        print(f'dataset {self.data_files} do_sft={do_sft}, do_eval={self.do_eval}')

    def _download(self, use_origin_parquet=False):
        """Replace each entry of ``self.data_files`` with a local copy under ``self.cache_dir``.

        Args:
            use_origin_parquet: If True, copy from ``self.original_data_files``
                (the paths given at construction) instead of the current
                ``self.data_files`` entries.
        """
        from verl.utils.fs import copy_to_local

        sources = self.original_data_files if use_origin_parquet else self.data_files
        for position, source_path in enumerate(sources):
            local_path = copy_to_local(src=source_path, cache_dir=self.cache_dir)
            self.data_files[position] = local_path

    def _read_files_and_tokenize(self):
        """Load, filter and index every data file into ``self.dataframe``.

        Steps:
          1. Load each parquet file (image-count filtering and subsampling are
             done per file) and concatenate the results.
          2. For SFT training, drop rows whose solution exceeds
             ``self.max_response_length`` tokens. This is skipped in eval mode
             or when ``no_overlong_sft`` is set.
          3. Draw the random permutation used for indexing. When
             ``self.explore`` is set, also build ``self.explore_dataframe``;
             the exposed length then doubles.

        NOTE: prompt-length filtering is not applied here, even though
        ``self.filter_overlong_prompts`` is read in ``__init__``.
        """
        dataframes = [self._load_parquet(parquet_file) for parquet_file in self.data_files]
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
        print(f"dataset len: {len(self.dataframe)}")

        if self.do_sft and not self.do_eval and not self.config.get("no_overlong_sft", False):
            self._filter_overlong_solutions()

        self.shuffled_indices = np.random.permutation(len(self.dataframe)).tolist()
        self.length = len(self.dataframe)

        if self.explore:
            self._build_explore_dataframe()

    def _load_parquet(self, parquet_file):
        """Load one parquet file, drop rows with too many images, and subsample by ``sample_rate``."""
        print('--->', parquet_file)
        dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
        if self.processor is not None and self.image_key in dataframe.column_names:
            # Bind plain locals so the filter closure does not capture `self`
            # (datasets hashes the function for caching).
            image_key = self.image_key
            max_images = self.config.get("max_num_images", 16)
            print('before', len(dataframe))
            dataframe = dataframe.filter(lambda example: len(example[image_key] or []) <= max_images)
            print('after image filtering', len(dataframe))

        print(f"sample rate: {self.sample_rate}, dataframe len: {len(dataframe)}")
        if self.sample_rate < 1.0:
            keep = int(len(dataframe) * self.sample_rate)
            dataframe = dataframe.shuffle(seed=42).select(range(keep))
        return dataframe

    def _filter_overlong_solutions(self):
        """Keep only rows whose solution encodes to at most ``self.max_response_length`` tokens."""
        keep_indices = []
        for idx, solution in enumerate(self.dataframe[self.solution_key]):
            ntoken = len(self.tokenizer.encode(solution))
            if ntoken > self.max_response_length:
                print(f"ntoken={ntoken} exceeds max_response_length={self.max_response_length}; dropping")
            else:
                keep_indices.append(idx)
        self.dataframe = self.dataframe.select(keep_indices)
        print(f"{len(keep_indices)} selected")
        if len(self.dataframe) > 0:  # avoid IndexError when everything was dropped
            print('--- example output ----')
            print(self.dataframe[0][self.solution_key])
            print('------')
        print(f"after filtering: {len(self.dataframe)}")

    def _build_explore_dataframe(self):
        """Build ``self.explore_dataframe``: same rows, with the exploration suffix on a single user turn.

        Each row's ``extra_info["index"]`` is prefixed with ``"explore"``. Its
        prompt is replaced by a single user message containing the last turn's
        content plus ``templates['template_explore']``.
        """
        prompt_key = self.prompt_key
        explore_suffix = templates['template_explore']

        def add_exploration_info(example):
            # Work on a copy of extra_info so the original row dict is not aliased.
            new_extra_info = dict(example['extra_info'])
            # str() formatting also handles integer indices, which `+` rejected.
            new_extra_info['index'] = f"explore{new_extra_info['index']}"
            example['extra_info'] = new_extra_info
            conv = example[prompt_key]
            example[prompt_key] = [dict(role='user', content=conv[-1]['content'] + explore_suffix)]
            return example

        self.explore_dataframe = self.dataframe.map(add_exploration_info)
        self.length = 2 * len(self.explore_dataframe)
        print(f"===> adding exploration prompts {len(self.dataframe)}->{self.length}")
        

    def resume_dataset_state(self):
        """Rebuild the dataset after a dataloader checkpoint is restored.

        If the object has ``original_data_files``, those files are re-downloaded
        and loading/filtering is re-run. Otherwise ``serialize_dataset`` is set
        and only a notice is printed. NOTE(review): objects lacking the attribute
        presumably come from an older serialized checkpoint.
        """
        has_source_files = hasattr(self, "original_data_files")
        self.serialize_dataset = not has_source_files
        if self.serialize_dataset:
            print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")
            return
        # Re-fetch from the original paths, then rebuild the in-memory dataframe.
        self._download(use_origin_parquet=True)
        self._read_files_and_tokenize()

    def __len__(self):
        """Return the number of indexable samples.

        This equals the number of rows, or twice that when exploration prompts
        are enabled (``self.length`` is set in ``_read_files_and_tokenize``).
        """
        total_samples = self.length
        return total_samples

    def _build_messages(self, example: dict, template_key: str):
        """Pop the prompt from ``example`` and turn it into chat messages.

        If ``example`` has an image or video column, each message's string
        content is split on ``<image>``, ``<image N>`` (N in 1-9) and ``<video>``
        placeholders into a list of ``{"type": "image" | "video" | "text", ...}``
        parts. If ``template_key`` names an entry of ``templates``, that
        template is applied exactly once per message, to its first text part
        that is not pure whitespace.

        Args:
            example: Row dict. ``self.prompt_key`` is removed from it in place.
            template_key: Key into ``templates``, or ``None`` for no template.

        Returns:
            The message list. In the multimodal case each ``content`` is
            converted to a part list; otherwise the messages are returned unchanged.
        """
        messages: list = example.pop(self.prompt_key)

        if self.image_key not in example and self.video_key not in example:
            return messages

        template = templates.get(template_key) if template_key is not None else None
        # For mmmu there may be several pictures per question (<image N>);
        # <image_[1-9]> is not needed for now.
        pattern = r"(<image [1-9]>|<image>|<video>)"
        for message in messages:
            content_list = []
            template_pending = template is not None
            for segment in re.split(pattern, message["content"]):
                if segment == "<image>" or segment.startswith("<image "):
                    content_list.append({"type": "image"})
                elif segment == "<video>":
                    content_list.append({"type": "video"})
                elif segment != '':
                    # Template only the first substantive text part. Placeholders
                    # are usually prepended as "<image>\n" * n, so templating every
                    # segment (including the "\n" separators) repeated the whole
                    # instruction once per image.
                    if template_pending and segment.strip():
                        segment = template.format(prompt=segment)
                        template_pending = False
                    content_list.append({"type": "text", "text": segment})
            message["content"] = content_list

        return messages

    def get_vl_token_length(self, row_dict):
        """Return the number of prompt tokens the processor produces for ``row_dict``.

        It follows the document-prompt path:
        1. Missing image placeholders are prepended to the last turn.
        2. Messages are built with the ``doc_detail`` template.
        3. The chat template is rendered and ``"<think>\\n"`` is appended.
        4. Images are loaded, resized if tiny, and processed, then the whole
           prompt is run through ``self.processor``.

        NOTE: mutates ``row_dict``. The prompt and image/video columns are
        popped, and the last turn's content may be rewritten.

        Args:
            row_dict: A dataset row containing at least ``self.prompt_key``.

        Returns:
            int: Length of the last dimension of the processor's ``input_ids``.
        """
        from verl.utils.dataset.vision_utils import process_image, process_video

        # Ensure one placeholder per image by prepending any that are missing.
        if self.image_key in row_dict:
            last_turn = row_dict[self.prompt_key][-1]
            existing_image_count = len(re.findall(r'<image(?:\s+[1-9])?>', last_turn["content"]))
            actual_image_count = len(row_dict.get(self.image_key) or [])
            if existing_image_count < actual_image_count:
                missing_placeholders = actual_image_count - existing_image_count
                last_turn["content"] = "<image>\n" * missing_placeholders + last_turn["content"]

        messages = self._build_messages(row_dict, template_key='doc_detail')
        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + "<think>\n"

        images = None
        if self.image_key in row_dict:
            loaded = [resize_image_if_small(image) for image in row_dict.pop(self.image_key)]
            images = [process_image(image, self.min_pixels, self.max_pixels) for image in loaded]
            for image in images:
                if image.height < 28 or image.width < 28:
                    # Was a leftover breakpoint(), which would hang filter workers.
                    logger.warning("processed image is smaller than 28px: %dx%d", image.width, image.height)

        videos = None
        if self.video_key in row_dict:
            videos = [process_video(video) for video in row_dict.pop(self.video_key)]

        model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")
        # The processor is called with a batch of one prompt (text=[raw_prompt]), so
        # input_ids is batched and len() would return the batch size (1) rather than
        # the token count. Take the sequence dimension instead.
        return model_inputs["input_ids"].shape[-1]

    def __getitem__(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        if self.explore:
            real_idx = self.shuffled_indices[item//2]
            target_df = self.dataframe if item%2==0 else self.explore_dataframe
        else:
            real_idx = self.shuffled_indices[item]
            target_df = self.dataframe
        
        row_dict: dict = target_df[real_idx]
        # messages = self._build_messages(row_dict) # this will work only in VL cases, so i will put inside the branch
        model_inputs = {}
        # ------- VL part -------
        if self.processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            template_key = None
            user = row_dict[self.prompt_key][-1]['content']
            
            # Count existing image placeholders
            existing_image_count = len(re.findall(r'<image(?:\s+[1-9])?>', user))
            
            # Get the actual number of images
            actual_image_count = len(row_dict.get(self.image_key, []))
            
            # If there are images but not enough placeholders, add them
            if actual_image_count > 0 and existing_image_count < actual_image_count:
                # Calculate how many more placeholders we need
                missing_placeholders = actual_image_count - existing_image_count
                
                # Add the missing image placeholders at the beginning
                image_placeholders = "<image>\n" * missing_placeholders
                user = image_placeholders + user
                row_dict[self.prompt_key][-1]['content'] = user
            elif actual_image_count > 0 and existing_image_count == 0:
                # Original logic: add one placeholder if none exist
                user = "<image>\n" + user 
                row_dict[self.prompt_key][-1]['content'] = user

            if self.sft_template=='doc_vl' or self.doc:
                # no think in the input 
                assist_start = "<think>\n"
                image_pattern = r"<image(?: [1-9])?>"
                messages = self._build_messages(row_dict, template_key='doc_detail') # for text, use no template
                raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + assist_start

                
            elif self.zero:
                # we don' t need messages in zero template
                template_key = 'zero_detail'
                if self.explore:
                    template_key = 'explore'
                image_pattern = r"<image(?: [1-9])?>"
                prompt = templates[template_key].format(prompt=row_dict[self.prompt_key][-1]["content"])
                raw_prompt = re.sub(image_pattern, "<|vision_start|><|image_pad|><|vision_end|>", prompt)
                messages = self._build_messages(row_dict, template_key=template_key) # we actually don't use this messages
                # TODO: is the above incorrect? it does not get the raw_prompt using messages
            else:
                # template_key = "zero_template"
                template_key = "default"
                print('using default template')
                assist_start = ""
                # assist_start = "Let me solve this step by step with detailed reasoning.\n<think>\n"
                messages = self._build_messages(row_dict, template_key=template_key)
                raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + assist_start
            multi_modal_data = {}

            images = None
            if self.image_key in row_dict:
                # print(row_dict[self.image_key])
                loaded = [resize_image_if_small(image) for image in row_dict.pop(self.image_key)]
                # print(loaded)
                images = [process_image(image, self.min_pixels, self.max_pixels) for image in loaded]
                multi_modal_data["image"] = images
                print(f"image size:", [x.size for x in images])

            videos = None
            if self.video_key in row_dict:
                videos = [process_video(video) for video in row_dict.pop(self.video_key)]
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            print(f"tokens {input_ids.shape}")
            attention_mask = model_inputs.pop("attention_mask")

            if "second_per_grid_ts" in model_inputs:
                model_inputs.pop("second_per_grid_ts")

            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts isn't used for training, just for mrope
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            messages = row_dict.pop(self.prompt_key)
            if self.llama_instruct_zero:
                # apply zero template
                ver = 'zero_detail'
                if 'explore' in row_dict['extra_info']:
                    ver = 'explore'
                templ = templates[ver]
                raw_prompt = templ.format(prompt=messages[-1]["content"])
                raw_prompt = raw_prompt.replace(r"Please reason step by step and put your final answer within \boxed{}","").replace(r"Let's think step by step and output the final answer within \boxed{}.", "")
                messages[-1]['content'] = raw_prompt
                raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + "Response:\n<think>\n"
            
            elif self.zero:
                # apply zero template
                ver = 'zero_detail'
                if 'explore' in row_dict['extra_info']:
                    ver = 'explore'
                templ = templates[ver]
                raw_prompt = templ.format(prompt=messages[-1]["content"])
                raw_prompt = raw_prompt.replace(r"Please reason step by step and put your final answer within \boxed{}","").replace(r"Let's think step by step and output the final answer within \boxed{}.", "")
            
            elif self.sft_template=="longwriter":
                raw_prompt = f'[INST]{messages[-1]["content"]}[/INST]'
            elif self.sft_template=="qwen":
                raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            elif self.sft_template=='code':
                # messages[-1]["content"] = f'Hint: You will be given a coding problem. You need to first think about the reasoning process in the mind and then provide the final coding solution. The reasoning process is enclosed within <think> </think> tags, i.e., <think> reasoning process here </think>. Normally, this thinking process first analyzes what the problem is asking and what knowledge it attempts to examine, and then elaborate on the detailed reasoning steps to solve the problem. The assistant could regularly self-question and self-verify existing reasoning to ensure the correctness.  User Query:\n{messages[-1]["content"]}\nRemember to enclose your final coding solution within ```python and ``` tags.'
                messages[-1]["content"] = f'You need to first think about the reasoning process in the mind and then provide the final coding solution. The reasoning process is enclosed within <think> </think> tags, i.e., <think> reasoning process here </think>. Normally, this thinking process first analyzes what the problem is asking and what knowledge it attempts to examine, and then elaborate on the detailed reasoning steps to solve the problem. The assistant could regularly self-question and self-verify existing reasoning to ensure the correctness.  User Query:\n{messages[-1]["content"]}'
                raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + "<think>\n"
            elif self.sft_template=='sft2rl':
                if self.do_eval:
                    ver = 'zero_detail'
                    templ = templates[ver]
                    raw_prompt = templ.format(prompt=messages[-1]["content"])
                else:
                    raw_prompt = messages 
            elif self.sft_template!='default':
                user_template = templates[self.sft_template]
                temp = messages[-1]['content']
                new = user_template.format(prompt=temp)
                messages[-1]['content'] = new 
                raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False, enable_thinking=True)
                if (self.do_sft or self.do_eval) and not force_nothink: 
                    raw_prompt += "<think>\n"
            else: # default
                qwen3_nothink = getattr(self.config, "qwen3_nothink", False)
                if qwen3_nothink:
                    raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False, enable_thinking=False) # will add <think></think> itself for qwen3-instruct
                    # raw_prompt = raw_prompt + "<think>\n\n</think>\n"
                else:
                    # if self.do_sft or self.do_eval: 
                    #     messages[-1]['content'] = messages[-1]['content'].split("<think>")[-1].strip()
                    raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False, enable_thinking=True)
                    if (self.do_sft or self.do_eval) and not force_nothink: 
                        raw_prompt += "<think>\n"
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        if item==0:
            print(f"raw messages", messages)
            print(f"raw prompt:\n")
            print(raw_prompt)
            

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]
        
        if self.do_sft:
            templ = templates['assistant'] 
            resp = row_dict[self.solution_key]
            
            raw_prompt = templ.format(resp).split("<think>")[-1].strip()
            if item==0:
                print(f"[response]\n")
                print(raw_prompt)
            resp = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)['input_ids']
            responses = verl_F.postprocess_prompt_only(
                input_ids=resp,
                max_length=self.max_response_length,
                pad_token_id=self.tokenizer.pad_token_id,
                left_pad=False,
                truncation='right',
            )
            row_dict['responses'] = responses.reshape(-1)
            

        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "middle":
                left_half = self.max_prompt_length // 2
                right_half = self.max_prompt_length - left_half
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages
        
        # get prompts with chat template
        if self.return_full_prompt:
            row_dict["full_prompts"] = raw_prompt # array of strings

        # add index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
        return row_dict

    def __getstate__(self):
        """Build the state dict used when this dataset is pickled.

        Returns a shallow copy of the instance ``__dict__``. When
        ``self.serialize_dataset`` is falsy, the ``dataframe`` entry (if any)
        is left out of the copy so it is not pickled. How the dataset gets
        rebuilt after unpickling is not visible here; presumably
        ``__setstate__`` or a lazy reload handles it (TODO confirm).
        The live instance is never modified.
        """
        keep_everything = self.serialize_dataset
        snapshot = dict(self.__dict__)
        if keep_everything:
            return snapshot

        # Drop the in-memory table; missing key is fine.
        snapshot.pop("dataframe", None)
        return snapshot
