# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union, Optional
import re
from omegaconf import DictConfig

from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
from verl.utils.dataset.rl_dataset import RLHFDataset

from meta_researcher.tool.base import BaseToolEnv

class ToolRLDataset(RLHFDataset):
    """
    Dataset for tool use in RLHF
    """
    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
        env: Optional[BaseToolEnv] = None,
    ):
        self.env = env
        self.use_default_tool_template = config.get("use_default_tool_template", True)
        if self.use_default_tool_template and self.env is not None:
            self.tools = [tool.tool_description for tool in self.env.tools]
        self.use_custom_system_prompt = config.get("use_custom_system_prompt", False)
        super().__init__(data_files, tokenizer, config, processor)

    def _build_messages(self, example: dict):
        messages = example.pop(self.prompt_key)
        
        # Apply custom system prompt if needed
        if self.use_custom_system_prompt and self.env is not None:
            if isinstance(messages, list):
                if messages[0]["role"] == "system":
                    messages[0]["content"] = messages[0]["content"] + self.env.tools_format_func()
                else:
                    system_msg = [{"role": "system", "content": self.env.tools_format_func()}]
                    messages = system_msg + messages
        
        # Handle image/video content if present
        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                for segment in re.split("(<image>|<video>)", content):
                    if segment == "<image>":
                        content_list.append({"type": "image"})
                    elif segment == "<video>":
                        content_list.append({"type": "video"})
                    else:
                        content_list.append({"type": "text", "text": segment})

                message["content"] = content_list

        return messages

    def __getitem__(self, item):
        """
        Build the model inputs for the row at position ``item``.

        Note that we also return the raw_input_ids so that it can be combined with other chat template

        Args:
            item: Index passed straight to ``self.dataframe``.

        Returns:
            The row dict extended with ``input_ids``, ``attention_mask``,
            ``position_ids``, ``raw_prompt_ids`` and ``index``; plus
            ``multi_modal_data`` / ``multi_modal_inputs`` when a processor is
            configured, and ``raw_prompt`` when ``self.return_raw_chat`` is set.

        Raises:
            RuntimeError: If the prompt has more than ``self.max_prompt_length``
                tokens and ``self.truncation`` is ``"error"``.
        """
        row_dict = self.dataframe[item]
        messages = self._build_messages(row_dict)

        # Apply the appropriate chat template based on settings.
        # ``self.tools`` only exists when __init__ was given an environment
        # and the default tool template is enabled, hence the hasattr check.
        if self.use_default_tool_template and hasattr(self, "tools"):
            raw_prompt = self.tokenizer.apply_chat_template(
                messages, tools=self.tools, add_generation_prompt=True, tokenize=False
            )
        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        if self.processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            multi_modal_data = {}

            images = None
            if self.image_key in row_dict:
                images = [process_image(image) for image in row_dict.pop(self.image_key)]
                multi_modal_data["image"] = images

            videos = None
            if self.video_key in row_dict:
                videos = [process_video(video) for video in row_dict.pop(self.video_key)]
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            # NOTE: ``second_per_grid_ts`` is intentionally NOT popped from
            # ``model_inputs`` here. It used to be removed at this point, which
            # meant the get_rope_index() call below always received None.

            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts isn't used for training, just for mrope,
            # so it is dropped from the copy that is returned to the caller only.
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        # getattr() keeps processors that expose no ``image_processor``
        # attribute (and the processor-less case) on the generic path.
        image_processor = getattr(self.processor, "image_processor", None)
        if image_processor is not None and image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        # Un-padded prompt token ids, truncated by the same policy name.
        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "avaiable":
                # NOTE: the spelling of this literal is what the configured
                # truncation mode is compared against; changing it would stop
                # existing configurations from matching.
                raw_prompt_ids = self.truncate_with_think_tags(raw_prompt_ids)
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # add index for each prompt; ``or {}`` also covers rows whose
        # ``extra_info`` is present but None.
        extra_info = row_dict.get("extra_info") or {}
        row_dict["index"] = extra_info.get("index", 0)

        return row_dict
    
    def truncate_with_think_tags(self, token_ids):
        """
        当文本超长时，优先从前往后删除 <tool_response></tool_response> 标签内的内容，
        其次删除 <think></think> 标签内的内容，最后再进行截断
        """
        if len(token_ids) <= self.max_prompt_length:
            return token_ids
        
        tool_start = 151665  # <tool_response> 标签的token ID
        tool_end = 151666    # </tool_response> 标签的token ID
        think_start = 151667 # <think> 标签的token ID
        think_end = 151668   # </think> 标签的token ID
        
        token_ids = self._remove_tags(token_ids, tool_start, tool_end)
        
        if len(token_ids) > self.max_prompt_length:
            token_ids = self._remove_tags(token_ids, think_start, think_end)
        
        if len(token_ids) > self.max_prompt_length:
            token_ids = token_ids[-self.max_prompt_length:]
        
        return token_ids

    def _remove_tags(self, token_ids, start_token, end_token):
        """辅助函数：从前往后删除指定标签内的内容"""
        start_indices = [i for i, token in enumerate(token_ids) if token == start_token]
        end_indices = [i for i, token in enumerate(token_ids) if token == end_token]
        
        tags = []
        for start in start_indices:
            for i, end in enumerate(end_indices):
                if end > start and (not tags or end > tags[-1][0]):
                    tags.append((start, end))
                    break
        
        for start, end in tags:
            if len(token_ids) <= self.max_prompt_length:
                break
            token_ids = token_ids[:start] + [self.tokenizer.encode(" ", add_special_tokens=False)[1]] + token_ids[end+1:]
        
        return token_ids