import os
import copy
from typing import Optional, Union
from PIL import Image
import datasets
import pandas as pd
import numpy as np
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.dataset.rl_dataset import RLHFDataset
from recipe.fileagent.utils.vision_utils import load_image, create_empty_image


class CustomRLHFDataset(RLHFDataset):
    """RLHF dataset used by the fileagent recipe.

    Extends ``RLHFDataset`` with:

    * an optional "heterogeneous schema" loading path (``allow_heterogeneous_schemas``)
      that reads every parquet file with pandas and concatenates them; columns
      missing from a file are filled with ``None`` and the rows are kept as a
      ``list[dict]`` instead of a ``datasets.Dataset``;
    * optional replacement of the system prompt with the contents of a file
      (``replace_system_prompt`` / ``new_sp_path``);
    * optional ``min_pixels`` / ``max_pixels`` entries on every image dict
      (``enable_image_resize``);
    * a per-sample ``agent_name`` and ``tools_kwargs`` for the ``calc_reward``,
      ``global_tool`` and ``ImageZoomIn`` tools.
    """

    def __init__(
        self,
        data_files: str | list[str],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        # These attributes are set before ``super().__init__`` because the base
        # constructor presumably triggers ``_read_files_and_tokenize`` (overridden
        # below), which reads them. NOTE(review): confirm against RLHFDataset.
        self.agent_name = config.get("agent_name", "fileagent_agent")

        self.allow_heterogeneous_schemas = config.get("allow_heterogeneous_schemas", False)
        self.read_pq_use_threads = config.get("read_pq_use_threads", True)

        self.enable_image_resize = config.get("enable_image_resize", False)
        # Pixel-count bounds written onto each image dict by _maybe_resize_images.
        self.min_image_resolution = config.get("min_image_resolution", 4 * 28 * 28)
        self.max_image_resolution = config.get("max_image_resolution", 8192 * 28 * 28)

        super().__init__(
            data_files=data_files,
            tokenizer=tokenizer,
            config=config,
            processor=processor,
        )

    @staticmethod
    def _ndarray_to_list(value):
        """Return ``value.tolist()`` for numpy arrays and ``value`` unchanged otherwise.

        pandas may hand back list-valued parquet cells as numpy arrays; plain lists
        keep truthiness checks (``if imgs:``) and downstream code well-defined.
        """
        return value.tolist() if isinstance(value, np.ndarray) else value

    @staticmethod
    def _get_nested(mapping, *keys, default=None):
        """Return ``mapping[k1][k2]...`` or ``default`` if any level is missing, ``None`` or not a dict.

        Unlike chained ``dict.get(key, {})`` calls, this also tolerates keys that
        are present with a ``None`` value, which heterogeneous-schema loading
        produces for columns absent from a given file.
        """
        current = mapping
        for key in keys:
            if not isinstance(current, dict):
                return default
            current = current.get(key)
            if current is None:
                return default
        return current

    def _read_files_and_tokenize(self):
        """Load ``self.data_files`` into ``self.dataframe`` and preprocess it.

        Heterogeneous mode reads each parquet file with pandas, concatenates them
        (missing columns become ``None``) and stores the rows as ``list[dict]``.
        Otherwise each file is loaded with ``datasets`` and the results are
        concatenated into a single ``datasets.Dataset``.

        Both paths then apply, in order: system-prompt replacement, image pixel
        bounds, and over-long prompt filtering. Filtering runs exactly once, after
        the prompt has its final form, so lengths are measured on what the model sees.
        """
        if self.allow_heterogeneous_schemas:
            print(
                f"`allow_heterogeneous_schemas` is enabled, it supports input files with different column names. "
                f"For any column that is missing in a given row, the value will be filled with `None`."
            )

            dataframes = [
                pd.read_parquet(parquet_file, use_threads=self.read_pq_use_threads)
                for parquet_file in self.data_files
            ]
            combined_df = pd.concat(dataframes, ignore_index=True)
            combined_df.replace({np.nan: None}, inplace=True)  # replace NaN with None
            # Convert array-valued cells to lists for the prompt and (if present)
            # the media columns, so rows look like the `datasets` path.
            combined_df[self.prompt_key] = combined_df[self.prompt_key].map(self._ndarray_to_list)
            for media_key in (self.image_key, self.video_key):
                if media_key in combined_df.columns:
                    combined_df[media_key] = combined_df[media_key].map(self._ndarray_to_list)
            self.dataframe: list[dict] = combined_df.to_dict(orient="records")

            print(f"dataset columns: {combined_df.columns.tolist()}")
            # Only print image info if the first row actually has image dicts.
            first_images = self.dataframe[0].get(self.image_key) if self.dataframe else None
            if first_images and isinstance(first_images[0], dict):
                print("row 0 images keys: ", first_images[0].keys())
            print(f"dataset len: {len(self.dataframe)}")
        else:
            # read parquet files and cache
            dataframes = [
                datasets.load_dataset("parquet", data_files=parquet_file)["train"]
                for parquet_file in self.data_files
            ]
            self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

            print(f"dataset len: {len(self.dataframe)}")

        self.dataframe = self._maybe_update_system_prompt(self.dataframe)
        self.dataframe = self._maybe_resize_images(self.dataframe)
        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)

    def maybe_filter_out_long_prompts(self, dataframe: Union[datasets.Dataset, list[dict]] = None):
        """Drop samples whose tokenized prompt is longer than ``self.max_prompt_length``.

        A ``datasets.Dataset`` (non-heterogeneous mode) is delegated to the base
        class. In heterogeneous mode ``dataframe`` is a ``list[dict]`` and is
        filtered here. Nothing is dropped unless ``self.filter_overlong_prompts``.

        Returns:
            The filtered rows, in the same container type as the input.
        """
        if not self.allow_heterogeneous_schemas:
            return super().maybe_filter_out_long_prompts(dataframe)

        if not self.filter_overlong_prompts:
            return dataframe

        tokenizer = self.tokenizer
        processor = self.processor
        prompt_key = self.prompt_key
        image_key = self.image_key
        video_key = self.video_key

        if processor is not None:
            from verl.utils.dataset.vision_utils import process_image, process_video

            def doc2len(doc) -> int:
                # Measure a deep copy: `_build_messages` / `process_image` may pop
                # or mutate the row and its nested message/image dicts in place.
                # The base class presumably gets a fresh dict per row from
                # `Dataset.filter`, but here `doc` IS the stored row.
                doc = copy.deepcopy(doc)
                messages = self._build_messages(doc)
                raw_prompt = processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
                )
                images = [process_image(image) for image in doc[image_key]] if doc.get(image_key) else None
                videos = [process_video(video) for video in doc[video_key]] if doc.get(video_key) else None

                return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])

        else:

            def doc2len(doc) -> int:
                return len(
                    tokenizer.apply_chat_template(
                        doc[prompt_key], add_generation_prompt=True, **self.apply_chat_template_kwargs
                    )
                )

        dataframe = [row for row in dataframe if doc2len(row) <= self.max_prompt_length]

        print(f"filter dataset len: {len(dataframe)}")
        return dataframe

    def _maybe_resize_images(self, dataframe: Union[datasets.Dataset, list[dict]] = None):
        """Attach ``max_pixels`` / ``min_pixels`` to every image dict when ``enable_image_resize`` is set.

        Only the pixel bounds are recorded here; any actual resizing is left to
        whatever consumes the image dicts later. Returns the rows unchanged when
        the option is disabled.
        """

        def _add_pixel_info(example):
            imgs = example.get(self.image_key, None)
            if imgs:
                new_imgs = []
                for img_dict in imgs:
                    img_dict["max_pixels"] = self.max_image_resolution
                    img_dict["min_pixels"] = self.min_image_resolution
                    new_imgs.append(img_dict)
                example[self.image_key] = new_imgs
            return example

        if not self.enable_image_resize:
            return dataframe

        if self.allow_heterogeneous_schemas:
            return [_add_pixel_info(row) for row in dataframe]
        return dataframe.map(_add_pixel_info, num_proc=self.num_workers, desc="Adding pixel info")

    def _maybe_update_system_prompt(self, dataframe: Union[datasets.Dataset, list[dict]] = None):
        """Replace (or insert) the system message when ``replace_system_prompt`` is enabled.

        The new system prompt is read from ``config.new_sp_path`` (UTF-8). If a
        prompt already starts with a system message, that message is replaced;
        otherwise the new system message is prepended.

        Raises:
            ValueError: if ``new_sp_path`` does not exist.
        """
        # `.get` like every other config read here, so a config without the key
        # behaves as "disabled" instead of raising.
        if not self.config.get("replace_system_prompt", False):
            return dataframe

        # Load new system prompt from file
        new_sp_path = self.config.get("new_sp_path", "")
        if not os.path.exists(new_sp_path):
            raise ValueError(f"new_sp_path {new_sp_path} does not exist")
        print(f"Warning: replace system prompt is enabled, will replace the system prompt with {new_sp_path}")
        with open(new_sp_path, "r", encoding="utf-8") as f:
            new_sp = f.read()
        print(f"new_sp: {new_sp}")

        # Update system prompt in dataset
        def _update_sp_func(example):
            prompt = example[self.prompt_key]
            # A fresh message dict per row: downstream code may mutate message
            # dicts in place, so rows must not share one object.
            system_message = {"role": "system", "content": new_sp}
            if prompt and prompt[0]["role"] == "system":
                new_prompt = [system_message] + prompt[1:]
            else:
                new_prompt = [system_message] + prompt
            example[self.prompt_key] = new_prompt
            return example

        if self.allow_heterogeneous_schemas:
            return [_update_sp_func(row) for row in dataframe]
        return dataframe.map(
            function=_update_sp_func,
            num_proc=self.num_workers,
            desc="Updating system prompt in dataset",
        )

    def __getitem__(self, item):
        """Return the base-class sample augmented with agent and tool metadata.

        Sets ``row_dict["agent_name"]`` and ``row_dict["tools_kwargs"]`` with
        create-kwargs for the ``calc_reward``, ``global_tool`` and ``ImageZoomIn``
        tools. When the processed sample has images, the tools receive the
        original images loaded from the raw row plus the ``(width, height)`` of
        each processed image.
        """
        # Deep-copy the raw row before the base class sees it. In heterogeneous
        # mode `self.dataframe` is a plain list, so in-place edits by the base
        # class land on the stored row. A shallow copy would still share the
        # nested message/image dicts and the row would be corrupted for later epochs.
        raw_data = copy.deepcopy(self.dataframe[item])
        try:
            row_dict = super().__getitem__(item)
        finally:
            if self.allow_heterogeneous_schemas:
                # Restore the original data, even if the base class raised.
                self.dataframe[item] = raw_data

        # Update agent name
        row_dict["agent_name"] = self.agent_name

        # Tool data for image zoom-in: original images plus model-input sizes.
        processed_images = self._get_nested(row_dict, "multi_modal_data", "image")
        original_images = None
        model_input_sizes = None
        if processed_images:
            original_images = [load_image(image) for image in raw_data[self.image_key]]
            model_input_sizes = [(img.width, img.height) for img in processed_images]

        # Prewrites come from extra_info.tools_kwargs.global_tool.create_kwargs, if present.
        prewrites = self._get_nested(
            row_dict, "extra_info", "tools_kwargs", "global_tool", "create_kwargs", "prewrites", default=[]
        )

        row_dict["tools_kwargs"] = {
            "calc_reward": {
                "create_kwargs": {
                    "ground_truth": self._get_nested(row_dict, "reward_model", "ground_truth")
                    or row_dict.get("ground_truth", ""),
                    "data_source": row_dict.get("data_source", ""),
                    "extra_info": row_dict.get("extra_info") or {},
                },
            },
            "global_tool": {
                "create_kwargs": {
                    "images": original_images,
                    "model_input_sizes": model_input_sizes,
                    "prewrites": prewrites,
                },
            },
            "ImageZoomIn": {
                "create_kwargs": {
                    "images": original_images,
                    "model_input_sizes": model_input_sizes,
                },
            },
        }

        return row_dict


if __name__ == "__main__":
    # Smoke test: build the dataset from a few parquet files, then print the first
    # stored row and the first processed sample.
    #
    # Paths can be overridden without editing this file:
    #   FILEAGENT_MODEL_PATH  - path of the HF model whose tokenizer/processor are loaded
    #   FILEAGENT_DATA_FILES  - comma-separated list of parquet files
    from verl.utils import hf_processor, hf_tokenizer
    from omegaconf import OmegaConf

    model_path = os.environ.get(
        "FILEAGENT_MODEL_PATH",
        "/mnt/hdfs/fileagent_storage/shared/models/Qwen2.5-VL-7B-Instruct",
    )
    tokenizer = hf_tokenizer(model_path)
    processor = hf_processor(model_path, use_fast=True)

    config = OmegaConf.load("verl/trainer/config/data/legacy_data.yaml")
    config.replace_system_prompt = True
    config.new_sp_path = "recipe/fileagent/prompts/sp_v1.md"
    config.allow_heterogeneous_schemas = True
    config.enable_image_resize = True
    config.filter_overlong_prompts = False

    default_data_files = [
        "/mnt/hdfs/fileagent_storage/users/<your_username>/data/simpleqa_norm/train.parquet",
        "/mnt/hdfs/fileagent_storage/users/<your_username>/data/DeepEyes-Datasets-47k-clean/data_0.1.2_visual_toolbox_v2.train.parquet",
    ]
    env_data_files = os.environ.get("FILEAGENT_DATA_FILES", "")
    data_files = [path.strip() for path in env_data_files.split(",") if path.strip()] or default_data_files

    dataset = CustomRLHFDataset(
        data_files=data_files,
        tokenizer=tokenizer,
        config=config,
        processor=processor,
    )
    print(dataset.dataframe[0])
    print(dataset[0])
