import copy
import logging
import os
import re
from collections import defaultdict
from typing import Optional
from PIL import Image
import io
import random

import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask
from .jigsaw_utils import slice_image, mask_random_regions, process_video_jigsaw, center_crop,draw_numbered_markers, process_bev_images

logger = logging.getLogger(__name__)


def collate_fn(data_list: list[dict]) -> dict:
    """
    Collate a batch of sample dicts into batched tensors and arrays.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or other values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, *dims) and non-tensor entries are converted to
        np.ndarray of dtype object with shape (batch_size,).
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    batch = {key: torch.stack(vals, dim=0) for key, vals in tensors.items()}

    for key, vals in non_tensors.items():
        # np.array(vals, dtype=object) would build a 2-D array whenever every
        # sample holds an equal-length sequence (e.g. token id lists). Filling a
        # pre-allocated 1-D object array keeps one Python object per sample.
        column = np.empty(len(vals), dtype=object)
        for position, value in enumerate(vals):
            column[position] = value
        # Assigned after the tensors so a non-tensor entry still wins on a key clash.
        batch[key] = column

    return batch

def random_crop_tile(img, tile_size):
    """
    Crop a randomly positioned tile out of an image.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and a ``crop(box)`` method.
        tile_size: ``(tile_w, tile_h)`` of the tile to cut out.

    Returns:
        Whatever ``img.crop`` returns for the chosen ``(left, top, right, bottom)`` box.

    Raises:
        ValueError: If a tile dimension is not positive or the tile is larger than the image.
    """
    image_size = img.size
    img_w, img_h = image_size
    tile_w, tile_h = tile_size

    if tile_w <= 0 or tile_h <= 0:
        raise ValueError("tile_size must be positive (tile_w, tile_h).")
    if tile_w > img_w or tile_h > img_h:
        raise ValueError(f"tile_size {tile_size} does not fit inside image size {(img_w, img_h)}.")

    # Draw the horizontal offset first, then the vertical one.
    left = random.randint(0, img_w - tile_w)
    top = random.randint(0, img_h - tile_h)
    box = (left, top, left + tile_w, top + tile_h)

    return img.crop(box)

def random_rotate(lst):
    """
    Return a copy of ``lst`` cyclically rotated by a random offset.

    Args:
        lst: A sliceable, concatenable sequence (e.g. a list).

    Returns:
        A new sequence holding the same elements, starting at a random position.
        An empty input yields an empty copy instead of raising.
    """
    if not lst:
        # random.randrange(0) raises ValueError; there is nothing to rotate anyway.
        return lst[:]
    k = random.randrange(len(lst))  # random shift amount
    return lst[k:] + lst[:k]

class RLHFDataset(Dataset):
    """
    Load and preprocess RLHF data from Parquet files.

    - Caches files locally.
    - Reads into a HuggingFace Dataset and tokenizes prompts.
    - Optionally handles images/videos via a ProcessorMixin.
    - Filters prompts over a max length.
    - Supports resuming from checkpoints.

    Args:
        data_files (str or list): Path(s) to Parquet file(s).
        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
    """

    def __init__(
        self,
        data_files: str | list[str],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        if not isinstance(data_files, list | ListConfig):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # use for resume
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.return_full_prompt = config.get("return_full_prompt", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
        self.multimodal_folder = config.get("multimodal_folder", None)
        self.max_pixels = config.get("max_pixels", None)
        self.max_frames = config.get("max_frames", None)
        self.dataset_repeat_times = config.get("dataset_repeat_times", None)
        if self.dataset_repeat_times is not None:
            if type(self.dataset_repeat_times) == str:
                self.dataset_repeat_times = eval(self.dataset_repeat_times)
        self.tcs_loader = None
        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
        self.num_workers = min(self.num_workers, os.cpu_count())
        self.use_shm = config.get("use_shm", False)
        self.chat_template_func = config.get("chat_template_func", None)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)
        self.filter_prompts = config.get("filter_prompts", True)
        self.serialize_dataset = False
        self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)

        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)

    def _read_files_and_tokenize(self):
        dataframes = []
        repeat_times = [1] * len(self.data_files)
        if self.dataset_repeat_times is not None:
            repeat_times = self.dataset_repeat_times

        dataset_index = 0
        all_columns = set()

        # --- First pass: collect all column names across datasets
        for parquet_file in self.data_files:
            df = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            all_columns.update(df.column_names)

        # --- Second pass: load, align missing columns, and repeat/sample
        for parquet_file, repeat_time in zip(self.data_files, repeat_times):
            df = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            df = df.add_column("dataset_index", [dataset_index] * len(df))

            # Fill in missing columns with None
            for col in all_columns:
                if col not in df.column_names:
                    df = df.add_column(col, [None] * len(df))

            if repeat_time >= 1:
                for _ in range(int(repeat_time)):
                    dataframes.append(df)
            elif 0 < repeat_time < 1:
                n_samples = int(len(df) * repeat_time)
                sampled = df.shuffle(seed=42).select(range(n_samples))
                dataframes.append(sampled)

            dataset_index += 1

        # Concatenate safely (all datasets now share same columns)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)

    def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
        """
        Drop rows whose templated prompt is longer than ``self.max_prompt_length`` tokens.

        Does nothing (returns ``dataframe`` unchanged) when ``self.filter_overlong_prompts``
        is falsy.

        Args:
            dataframe: Dataset to filter; its ``filter`` method is called with
                ``num_proc=self.num_workers``.

        Returns:
            The filtered dataset, or the input unchanged when filtering is disabled.
        """
        # filter out too long prompts
        if self.filter_overlong_prompts:
            tokenizer = self.tokenizer
            processor = self.processor
            prompt_key = self.prompt_key
            image_key = self.image_key
            video_key = self.video_key

            if processor is not None:
                from verl.utils.dataset.vision_utils import process_image, process_video

                def doc2len(doc) -> int:
                    # _build_messages pops the prompt out of `doc` and returns the message list.
                    messages = self._build_messages(doc)
                    raw_prompt = self.processor.apply_chat_template(
                        messages, add_generation_prompt=True, tokenize=False
                    )
                    # NOTE(review): `messages` is the list returned by _build_messages, so
                    # `image_key in messages` / `video_key in messages` test list membership
                    # and look like they are always False for a list of message dicts, which
                    # leaves `images` and `videos` as None and measures the text-only length.
                    # Presumably `doc` was intended; confirm before changing, because
                    # __getitem__ loads jigsaw images through its own code path rather than
                    # process_image.
                    images = (
                        [process_image(image) for image in messages.pop(image_key)] if image_key in messages else None
                    )
                    videos = (
                        [process_video(video) for video in messages.pop(video_key)] if video_key in messages else None
                    )

                    return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])

            else:

                def doc2len(doc) -> int:
                    # Text-only path: length of the chat-templated prompt as returned by the tokenizer.
                    return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))

            # Keep only rows whose measured length fits within max_prompt_length.
            dataframe = dataframe.filter(
                lambda doc: doc2len(doc) <= self.max_prompt_length,
                num_proc=self.num_workers,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )

            print(f"filter dataset len: {len(dataframe)}")
        return dataframe

    def resume_dataset_state(self):
        self.serialize_dataset = not hasattr(self, "original_data_files")
        # resume dataframe if not it's serialized in data.pt
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)  # download and resume from original parquet files
            self._read_files_and_tokenize()
        else:
            print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")

    def __len__(self):
        return len(self.dataframe)

    def _build_messages(self, example: dict):
        messages: list = example.pop(self.prompt_key)

        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                segments = re.split("(<image>|<video>)", content)
                segments = [item for item in segments if item != ""]
                for segment in segments:
                    if segment == "<image>":
                        content_list.append({"type": "image"})
                    elif segment == "<video>":
                        content_list.append({"type": "video"})
                    else:
                        content_list.append({"type": "text", "text": segment})

                message["content"] = content_list

        return messages

    def __getitem__(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        row_dict: dict = self.dataframe[item]

        if self.image_key in row_dict:
            if 'data_source' in row_dict and 'jigsaw' in row_dict['data_source']:
                raw_images_info = row_dict.get(self.image_key)
                assert len(raw_images_info) == 1

                if 'co3d_jigsaw' in row_dict['data_source']:
                    frames = raw_images_info[0]['path']
                    combinations = raw_images_info[0]['combinations']
                    indices = random.choice(combinations)
                    frames = [frames[index] for index in indices]

                    frames = [Image.open(os.path.join(self.multimodal_folder, frame)).convert('RGB') for frame in frames]
                    if 'resize_size' in raw_images_info[0]:
                        resize_size_list = raw_images_info[0]['resize_size']
                        if resize_size_list[0][0] > 0:
                            frames = [frame.resize(resize_size) for frame, resize_size in zip(frames, resize_size_list)]
                    anchor_image = frames[0]
                    other_images = frames[1:]
                    frame_order = list(range(len(frames)-1))
                    random.shuffle(frame_order)
                    gt = [frame_order.index(_) +1 for _ in range(len(frame_order))]
                    row_dict['reward_model']['ground_truth'] = gt
                    other_images = [other_images[_] for _ in frame_order]
                    row_dict[self.image_key] = [anchor_image] + other_images
                elif 'egoexo4d_jigsaw' in row_dict['data_source']:
                    frames = raw_images_info[0]['path']
                    frame_ids = raw_images_info[0]['frame_ids']
                    frame_id = random.choice(frame_ids)
                    frames = [frame_path.format(frame_id=frame_id) for frame_path in frames]

                    frames = random_rotate(frames)
                    frames = [Image.open(os.path.join(self.multimodal_folder, frame)).convert('RGB') for frame in frames]
                    if 'resize_size' in raw_images_info[0]:
                        resize_size_list = raw_images_info[0]['resize_size']
                        frames = [frame.resize(resize_size) for frame, resize_size in zip(frames, resize_size_list)]

                    anchor_image = frames[0]
                    other_images = frames[1:]
                    frame_order = list(range(len(frames)-1))
                    random.shuffle(frame_order)
                    gt = [frame_order.index(_) +1 for _ in range(len(frame_order))]
                    row_dict['reward_model']['ground_truth'] = gt
                    other_images = [other_images[_] for _ in frame_order]
                    row_dict[self.image_key] = [anchor_image] + other_images
                elif 'scannet_jigsaw' in row_dict['data_source'] or 'scannetpp_jigsaw' in row_dict['data_source']:
                    frames = raw_images_info[0]['path']
                    frames = [Image.open(os.path.join(self.multimodal_folder, frame)).convert('RGB') for frame in frames]
                    if 'resize_size' in raw_images_info[0]:
                        resize_size_list = raw_images_info[0]['resize_size']
                        frames = [frame.resize(resize_size) for frame, resize_size in zip(frames, resize_size_list)]

                    anchor_image = frames[0]
                    other_images = frames[1:]
                    frame_order = list(range(len(frames)-1))
                    random.shuffle(frame_order)
                    gt = [frame_order.index(_) +1 for _ in range(len(frame_order))]
                    row_dict['reward_model']['ground_truth'] = gt
                    other_images = [other_images[_] for _ in frame_order]
                    row_dict[self.image_key] = [anchor_image] + other_images
                elif 'scannet_depth_jigsaw' in row_dict['data_source'] or 'arkitscenes_depth_jigsaw' in row_dict['data_source']:
                    raw_image_info = raw_images_info[0]
                    image = Image.open(os.path.join(self.multimodal_folder, raw_image_info['path'][0])).convert('RGB')

                    points = raw_image_info['points']
                    points_order = list(range(len(points)))
                    random.shuffle(points_order)
                    gt = [points_order.index(_) +1 for _ in range(len(points_order))]
                    points = [points[_] for _ in points_order]
                    row_dict['reward_model']['ground_truth'] = gt
                    image = draw_numbered_markers(image, points, radius=15, alpha=120)

                    if 'resize_size' in raw_image_info:
                        resize_size = raw_image_info['resize_size']
                        if type(resize_size[0]) is list:
                            resize_size = resize_size[0]
                        if resize_size[0] > 0 and resize_size[1] > 0:
                            image = image.resize(resize_size)
                    row_dict[self.image_key] = [image]
                elif 'scannetpp_bev_jigsaw' in row_dict['data_source']:
                    frames = raw_images_info[0]['path']
                    camera_info = raw_images_info[0]['camera_info']
                    K_bev = camera_info['K_bev']
                    w2c_bev = camera_info['w2c_bev']
                    K_cam = camera_info['K_cam']
                    camera_poses = camera_info['camera_poses']
                    frames = [Image.open(os.path.join(self.multimodal_folder, frame)).convert('RGB') for frame in frames]
                    frames[0] = process_bev_images(frames[0], K_bev, w2c_bev, K_cam, camera_poses)
                    if 'resize_size' in raw_images_info[0]:
                        resize_size_list = raw_images_info[0]['resize_size']
                        frames = [frame.resize(resize_size) for frame, resize_size in zip(frames, resize_size_list)]

                    bev_image = frames[0]
                    other_images = frames[1:]
                    frame_order = list(range(len(frames)-1))
                    random.shuffle(frame_order)
                    gt = [index + 1 for index in frame_order]
                    row_dict['reward_model']['ground_truth'] = gt
                    other_images = [other_images[_] for _ in frame_order]
                    row_dict[self.image_key] = [bev_image] + other_images

                else:
                    raw_image_info = raw_images_info[0]
                    if type(raw_image_info['path']) is list:
                        raw_image_info['path'] = raw_image_info['path'][0]
                    image = Image.open(os.path.join(self.multimodal_folder, raw_image_info['path'])).convert('RGB')
                    if 'resize_size' in raw_image_info:
                        resize_size = raw_image_info['resize_size']
                        if type(resize_size[0]) is list:
                            resize_size = resize_size[0]
                        if resize_size[0] > 0 and resize_size[1] > 0:
                            image = image.resize(resize_size)
                    if 'crop_size' in raw_image_info:
                        crop_size = raw_image_info['crop_size']
                        if type(crop_size[0]) is list:
                            crop_size = crop_size[0]
                        if crop_size[0] > 0:
                            image = center_crop(image, crop_size[0], crop_size[1])

                    jigsaw_config = raw_image_info['jigsaw_config']
                    num_tiles = jigsaw_config[0]*jigsaw_config[1]
                    tile_order = list(range(num_tiles))
                    if 'noise_image_candidates' in raw_image_info:
                        noise_tiles = []
                        noise_image_candidates = raw_image_info['noise_image_candidates']
                        noise_tiles_count = raw_image_info['noise_tiles_count']
                        tile_width, tile_height = raw_image_info['tile_width'], raw_image_info['tile_height']
                        while len(noise_tiles) < noise_tiles_count:
                            noise_image = random.choice(noise_image_candidates)
                            noise_image = Image.open(os.path.join(self.multimodal_folder, noise_image)).convert('RGB')
                            noise_tile = random_crop_tile(noise_image, [tile_width, tile_height])
                            noise_tiles.append(noise_tile)

                        random.shuffle(tile_order)
                        gt = [tile_order.index(_) +1 for _ in range(len(tile_order)-noise_tiles_count)]
                        image_tiles = slice_image(image, jigsaw_config[0], jigsaw_config[1]) + noise_tiles
                    else:
                        random.shuffle(tile_order)
                        gt = [tile_order.index(_) +1 for _ in range(len(tile_order))]
                        image_tiles = slice_image(image, jigsaw_config[0], jigsaw_config[1])
                    row_dict['reward_model']['ground_truth'] = gt
                    if 'partial_hint' in raw_image_info:
                        partial_hint = [str(idx) for idx in gt]
                        partial_hint_num = raw_image_info['partial_hint']
                        indices = random.sample(range(len(gt)), len(gt)-partial_hint_num)
                        for i in indices:
                            partial_hint[i] = "_"
                        partial_hint = ", ".join(partial_hint)
                        row_dict['prompt'][-1]['content'] = row_dict['prompt'][-1]['content'].format(partial_hint=partial_hint)
                    image_tiles = [image_tiles[_] for _ in tile_order]
                            
                    row_dict[self.image_key] = image_tiles

            if 'data_source' in row_dict and 'VL-Cogito' in row_dict['data_source']:
                raw_images_info = row_dict.get(self.image_key)
                assert len(raw_images_info) == 1
                raw_image_info = raw_images_info[0]
                if type(raw_image_info['path']) is list:
                    raw_image_info['path'] = raw_image_info['path'][0]
                image = Image.open(os.path.join(self.multimodal_folder, raw_image_info['path'])).convert('RGB')
                if 'resize_size' in raw_image_info:
                    resize_size = raw_image_info['resize_size']
                    if type(resize_size[0]) is list:
                        resize_size = resize_size[0]
                    if resize_size[0] > 0 and resize_size[1] > 0:
                        image = image.resize(resize_size)
                if 'crop_size' in raw_image_info:
                    crop_size = raw_image_info['crop_size']
                    if type(crop_size[0]) is list:
                        crop_size = crop_size[0]
                    if crop_size[0] > 0:
                        image = center_crop(image, crop_size[0], crop_size[1])                        
                row_dict[self.image_key] = [image]

            if 'data_source' in row_dict and 'fillblank' in row_dict['data_source']:
                raw_images_info = row_dict.get(self.image_key)
                assert len(raw_images_info) == 1
                for image_i, raw_image_info in enumerate(raw_images_info):
                    if isinstance(raw_image_info, dict) and 'blank_config' in raw_image_info:
                        blank_config = raw_image_info['blank_config']
                        chosen_indices = raw_image_info['chosen_indices']
                        target_tile_index = raw_image_info['target_tile_index']
                        image = Image.open(os.path.join(self.multimodal_folder, raw_image_info['path'])).convert('RGB')
                        masked_image, image_tiles = mask_random_regions(image, blank_config[0], blank_config[1], blank_config[2], chosen_indices=chosen_indices)
                        image_tile = image_tiles[target_tile_index]
                        
                row_dict[self.image_key] = [masked_image, image_tile]

        messages = self._build_messages(row_dict)
        model_inputs = {}

        if self.processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            multi_modal_data = {}

            images = None
            if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
                images = [process_image(image) for image in row_dict.pop(self.image_key)]

                # due to the image key is "image" instead of "images" in vllm, we need to use "image" here
                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
                multi_modal_data["image"] = images

            videos = None
            fps_list = None
            if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:
                if 'data_source' in row_dict and 'jigsaw' in row_dict['data_source']:
                    if 's3:' not in row_dict[self.video_key][0]['path']:
                        row_dict[self.video_key][0]['path'] = os.path.join(self.multimodal_folder, row_dict[self.video_key][0]['path'])
                    try:
                        videos, fps_list = process_video_jigsaw(row_dict[self.video_key][0], max_pixels=self.max_pixels, fps_max_frames=self.max_frames)
                    except:
                        import traceback
                        traceback.print_exc()
                        random_idx = random.randint(0, self.__len__()-1)
                        return self.__getitem__(random_idx)

                    clip_order = row_dict[self.video_key][0]['clip_order']
                    clip_order = list(range(len(clip_order)))
                    random.shuffle(clip_order)
                    gt = [clip_order.index(_) +1 for _ in range(len(clip_order))]
                    row_dict['reward_model']['ground_truth'] = gt
                    videos = [videos[_] for _ in clip_order]
                    fps_list = [fps_list[_] for _ in clip_order]

                else:
                    videos = [process_video(video) for video in row_dict.pop(self.video_key)]

                # due to the video key is "video" instead of "videos" in vllm, we need to use "video" here
                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            # if "second_per_grid_ts" in model_inputs:
            #     model_inputs.pop("second_per_grid_ts")

            if len(multi_modal_data):
                # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
                row_dict["multi_modal_data"] = multi_modal_data

                # We will do batch.union() in the trainer,
                # so we cannot have "multi_modal_inputs" in row_dict if rollout generates new multi_modal_inputs
                if self.return_multi_modal_inputs:
                    row_dict["multi_modal_inputs"] = dict(model_inputs)

                    # second_per_grid_ts isn't used for training, just for mrope
                    row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

            

        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "middle":
                left_half = self.max_prompt_length // 2
                right_half = self.max_prompt_length - left_half
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # get prompts with chat template
        if self.return_full_prompt:
            row_dict["full_prompts"] = raw_prompt  # array of strings

        # add index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
        interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
        row_dict["interaction_kwargs"] = interaction_kwargs
        return row_dict

    def __getstate__(self):
        if not self.serialize_dataset:
            state = self.__dict__.copy()

            if "dataframe" in state:
                del state["dataframe"]
            return state

        return self.__dict__.copy()
