from __future__ import annotations

import copy
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask
from PIL import Image
import io
import json
import re

try:  
    from verl.models.transformers.qwen2_vl import get_rope_index  # type: ignore
except Exception:  
    from verl.models.transformers.qwen2_5_vl import get_rope_index  # type: ignore

# Module-level logger. NOTE(review): not used by the code in this file; progress is reported with print().
logger = logging.getLogger(__name__)

def _to_pil(img):
    """Best-effort conversion of a dataset image value to an RGB ``PIL.Image``.

    Accepted inputs:
      * ``PIL.Image.Image`` -- converted to RGB.
      * ``str`` -- treated as a filesystem path.
      * ``bytes`` / ``bytearray`` -- encoded image data (PNG/JPEG/...).
      * ``dict`` with ``"bytes"`` and/or ``"path"`` keys (the undecoded form of a
        HuggingFace ``datasets.Image`` feature); inline bytes take precedence.
      * ``numpy.ndarray`` of shape ``(H, W)``, ``(H, W, 1)`` or ``(H, W, C)``;
        values are cast to ``uint8`` as-is (no rescaling of float data).

    Any other value is returned unchanged so the caller decides how to fail.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, dict):
        # Undecoded HF datasets image: {"bytes": ..., "path": ...}.
        if img.get("bytes") is not None:
            img = img["bytes"]
        elif img.get("path") is not None:
            img = img["path"]
        else:
            return img
    if isinstance(img, str):
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(img) as opened:
            return opened.convert("RGB")
    if isinstance(img, (bytes, bytearray)):
        with Image.open(io.BytesIO(img)) as opened:
            return opened.convert("RGB")
    if isinstance(img, np.ndarray):
        if img.ndim == 3 and img.shape[-1] == 1:
            # Image.fromarray cannot handle a trailing singleton channel.
            img = img[..., 0]
        if img.ndim == 2:
            # Grayscale -> replicate into 3 channels.
            img = np.stack([img] * 3, axis=-1)
        return Image.fromarray(img.astype("uint8")).convert("RGB")
    return img

def _resize(img_pil: Image.Image, keep_ratio: bool = True, size: int = 560) -> Image.Image:
    """Resize an image to a ``size`` x ``size`` square.

    Args:
        img_pil: Source image.
        keep_ratio: If False, stretch the image directly to ``size`` x ``size``.
            If True (letterbox), scale it uniformly so it fits inside the square,
            then paste it centered on a black RGB canvas.
        size: Output side length in pixels (default 560, the previous
            hard-coded value).

    Returns:
        A new ``size`` x ``size`` image, resampled with bicubic interpolation.
    """
    if not keep_ratio:
        return img_pil.resize((size, size), Image.BICUBIC)

    # Letterbox: scale proportionally, then pad to center.
    w, h = img_pil.size
    scale = min(size / w, size / h)
    # Clamp to >= 1 px: a very thin image could otherwise round a side to 0,
    # which makes PIL's resize raise.
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = img_pil.resize((nw, nh), Image.BICUBIC)

    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas

def collate_fn(data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate a list of per-sample dicts into one batch dict.

    Tensor values are stacked along a new leading batch dimension. Every other
    value is packed into a 1-D ``numpy`` object array, one element per sample.
    Tensor keys come first in the result; if a key holds both tensors and
    non-tensors across samples, the object-array entry wins.
    """
    tensor_groups: Dict[str, List[torch.Tensor]] = defaultdict(list)
    other_groups: Dict[str, List[Any]] = defaultdict(list)

    for item in data_list:
        for key, value in item.items():
            bucket = tensor_groups if isinstance(value, torch.Tensor) else other_groups
            bucket[key].append(value)

    stacked = {key: torch.stack(group, dim=0) for key, group in tensor_groups.items()}
    # fromiter (rather than np.array) keeps one element per sample even when
    # the values are equal-length sequences.
    packed = {
        key: np.fromiter(group, dtype=object, count=len(group))
        for key, group in other_groups.items()
    }
    return {**stacked, **packed}

class RLHFDataset(Dataset):
    """Medical-VQA prompt dataset for Qwen2-VL / Qwen2.5-VL RL training with verl.

    Each row becomes a single-turn chat: a system message, then a user message
    with one image placeholder per image and the question wrapped in
    ``user_prompt_open``. The chat is tokenized with the multimodal processor,
    left-padded/truncated to ``max_prompt_length`` and returned together with
    the multimodal inputs, raw prompt ids, ground-truth answer and ``extra_info``.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ) -> None:
        """Read config, copy data files locally, then load and filter them.

        Args:
            data_files: One path or a list of paths. Paths containing
                ``"vqa-rad"`` are passed straight to ``datasets.load_dataset``;
                all others are loaded as parquet.
            tokenizer: Supplies ``pad_token_id`` and the raw prompt token ids.
            config: Dataset config; every key is read with ``config.get`` and a default.
            processor: Multimodal processor. It is required by ``__getitem__`` and
                enables the overlong-prompt filter.
        """
        if not isinstance(data_files, (list, ListConfig)):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # for resume

        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = int(config.get("max_prompt_length", 1024))
        self.return_raw_chat = bool(config.get("return_raw_chat", False))
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = bool(config.get("filter_overlong_prompts", True))
        self.num_workers = int(config.get("filter_overlong_prompts_workers", max(1, (os.cpu_count() or 4) // 4)))
        self.num_workers = min(self.num_workers, os.cpu_count() or 1)
        self.chat_template_func = config.get("chat_template_func", None)
        self.need_tools_kwargs = bool(config.get("need_tools_kwargs", False))
        self.filter_prompts = bool(config.get("filter_prompts", True))
        self.serialize_dataset = False
        self.apply_chat_template_kwargs = dict(config.get("apply_chat_template_kwargs", {}))
        self.return_full_prompt = bool(config.get("return_full_prompt", False))
        self.return_multi_modal_inputs = bool(config.get("return_multi_modal_inputs", True))
        self.use_shm = bool(config.get("use_shm", False))
        # When set, only rows whose answer_type (case-insensitive) matches are kept.
        self.answer_type = config.get("answer_type", None)
        # Instruction template for closed (yes/no, multiple-choice) questions.
        # NOTE(review): not used by this class's prompt construction; kept for API compatibility.
        self.user_prompt_close = (
            "You are an expert clinician analyzing a medical image.\n"
            "Based ONLY on the image, answer the following question: \"{Question}\"\n\n"
            "Follow these strict rules for your answer:\n"
            "- If the question is a yes/no question, answer with ONLY \"yes\" or \"no\".\n"
            "- If the question provides lettered options (e.g., A, B, C), answer with ONLY the single correct letter.\n"
            "- If the question provides word/phrase options, answer with ONLY the single correct option, copied exactly.\n\n"
            "Output your reasoning and final answer in this exact format:\n"
            "<think>Briefly describe your step-by-step reasoning based on visual evidence from the image.</think>\n"
            "<answer>Your final single answer here.</answer>"
        )
        # Instruction template actually used for every prompt (see _build_messages).
        self.user_prompt_open = (
            "You are an expert clinician analyzing a medical image.\n"
            "Based ONLY on the image, answer the following question: \"{Question}\"\n\n"
            "Follow these strict rules for your answer:\n"
            "- Provide a single, concise clinical word, phrase, or value.\n"
            "- Do NOT write a full sentence or add explanatory text.\n"
            "- Match the expected format of the answer precisely (e.g., units, case, symbols).\n\n"
            "Output your reasoning and final answer in this exact format:\n"
            "<think>Briefly describe your step-by-step reasoning based on visual evidence from the image and how you arrived at your answer.</think>\n"
            "<answer>Your final concise answer here.</answer>"
        )
        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        """Copy each data file into ``cache_dir`` and store the local path in ``self.data_files``.

        Args:
            use_origin_parquet: Copy from ``original_data_files`` instead of the
                current ``data_files`` (used when resuming).
        """
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self) -> None:
        """Load and concatenate all data files into ``self.dataframe``, then apply optional filters.

        First filter: ``answer_type``. Second filter: prompt length, which
        requires a processor.
        """
        frames: List[datasets.Dataset] = []
        for f in self.data_files:
            if "vqa-rad" in f:
                # Loaded via load_dataset directly rather than as a parquet file.
                ds = datasets.load_dataset(f)["train"]
            else:
                ds = datasets.load_dataset("parquet", data_files=f)["train"]
            frames.append(ds)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(frames)
        print(f"dataset len: {len(self.dataframe)}")

        if self.answer_type:
            target = str(self.answer_type).upper()
            self.dataframe = self.dataframe.filter(
                lambda d: self._get_answer_type(d) == target,
                num_proc=1,
                desc=f"Filtering by answer_type={target}",
            )
            print(f"[after answer_type filter={target}] len: {len(self.dataframe)}")

        if self.filter_overlong_prompts and self.processor is not None:
            # Uses the same prompt construction as __getitem__, so the measured
            # length matches what training will actually see.
            self.dataframe = self.dataframe.filter(
                lambda d: self._prompt_length(d) <= self.max_prompt_length,
                num_proc=1,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )
            print(f"filter dataset len: {len(self.dataframe)}")

    # ------------------------------------------------------------------ #
    # Row helpers (shared by the length filter and __getitem__)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_extra_info(extra_info: Any) -> Dict[str, Any]:
        """Return ``extra_info`` as a dict.

        A JSON string is decoded. Invalid JSON or any non-dict value yields ``{}``.
        """
        if isinstance(extra_info, str):
            try:
                extra_info = json.loads(extra_info)
            except ValueError:
                return {}
        return extra_info if isinstance(extra_info, dict) else {}

    @staticmethod
    def _get_answer_type(row: Dict[str, Any]) -> str:
        """Upper-cased answer type.

        Read from ``row["answer_type"]``, falling back to
        ``extra_info["answer_type"]``. Returns ``""`` if neither is present.
        """
        answer_type = row.get("answer_type", None)
        if answer_type is None:
            answer_type = RLHFDataset._parse_extra_info(row.get("extra_info", None)).get("answer_type", None)
        return str(answer_type or "").upper()

    @staticmethod
    def _get_question(row: Dict[str, Any]) -> str:
        """Stripped question text from ``new_q``, then ``text``, then ``question``."""
        return (row.get("new_q") or row.get("text") or row.get("question") or "").strip()

    @staticmethod
    def _extract_answer(row: Dict[str, Any]) -> str:
        """Ground-truth answer as ``str``.

        Tries ``new_a`` first, then the first non-None of ``gt_answer``,
        ``label``, ``labels`` and ``answers``. For list/tuple values the first
        element is used (``""`` if empty). Returns ``""`` when nothing is found.
        """
        gt = row.get("new_a", None)
        if gt is None:
            for key in ("gt_answer", "label", "labels", "answers"):
                value = row.get(key, None)
                if value is None:
                    continue
                if isinstance(value, (list, tuple)):
                    gt = value[0] if len(value) > 0 else ""
                else:
                    gt = value
                break
        return "" if gt is None else str(gt)

    def _collect_images(self, row: Dict[str, Any]) -> List[Any]:
        """Resized RGB images for a row.

        Images are read from ``image_key``, falling back to the ``"image``
        column. ``None`` entries are skipped.
        """
        imgs = row.get(self.image_key, None)
        if imgs is None and "image" in row:
            imgs = row["image"]
        imgs_list = imgs if isinstance(imgs, (list, tuple)) else [imgs]
        return [_resize(_to_pil(im), keep_ratio=True) for im in imgs_list if im is not None]

    def _build_messages(self, question: str, num_images: int) -> List[Dict[str, Any]]:
        """Chat messages: a system prompt, then a user turn.

        The user turn holds one image placeholder per image, followed by the
        question rendered with ``user_prompt_open`` (lower-cased, with leading
        and trailing dots stripped).
        """
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in range(num_images)]
        content.append({"type": "text", "text": self.user_prompt_open.format(Question=question.lower().strip("."))})
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": content},
        ]

    def _render(self, messages: List[Dict[str, Any]]) -> str:
        """Render messages to a prompt string with the processor's chat template."""
        return self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
        )

    def _prompt_length(self, row: Dict[str, Any]) -> int:
        """Number of processor input ids for this row's prompt, image tokens included."""
        from verl.utils.dataset.vision_utils import process_image

        images_pil = self._collect_images(row)
        raw_prompt = self._render(self._build_messages(self._get_question(row), len(images_pil)))
        images = [process_image(p) for p in images_pil] or None
        return len(self.processor(text=[raw_prompt], images=images)["input_ids"][0])

    def _truncate_raw_prompt_ids(self, raw_prompt_ids: List[int]) -> List[int]:
        """Truncate ids to ``max_prompt_length`` according to ``self.truncation``.

        Modes: ``"left"`` keeps the tail, ``"right"`` keeps the head,
        ``"middle"`` keeps both ends. ``"error"`` raises RuntimeError. Any other
        mode leaves the ids unchanged.
        """
        limit = self.max_prompt_length
        n = len(raw_prompt_ids)
        if n <= limit:
            return raw_prompt_ids
        if self.truncation == "left":
            # Index from n - limit (not -limit) so that limit == 0 yields [].
            return raw_prompt_ids[n - limit:]
        if self.truncation == "right":
            return raw_prompt_ids[:limit]
        if self.truncation == "middle":
            left = limit // 2
            right = limit - left
            return raw_prompt_ids[:left] + raw_prompt_ids[n - right:]
        if self.truncation == "error":
            raise RuntimeError(
                f"Prompt length {len(raw_prompt_ids)} > max_prompt_length={self.max_prompt_length}."
            )
        return raw_prompt_ids

    # ------------------------------------------------------------------ #

    def resume_dataset_state(self) -> None:
        """Reload data after unpickling.

        Reloads from the original files when they are recorded; otherwise keeps
        the serialized dataframe and prints a warning.
        """
        self.serialize_dataset = not hasattr(self, "original_data_files")
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)
            self._read_files_and_tokenize()
        else:
            print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Build one training sample.

        Returned keys:
            input_ids, attention_mask: sample 0 of ``verl_F.postprocess_data``
                output (left padding, ``max_prompt_length``, ``truncation`` mode).
            position_ids: M-RoPE ids from ``get_rope_index`` for Qwen2-VL-style
                image processors, otherwise ids from the attention mask. A 1-D
                result is repeated to shape ``(3, seq_len)``.
            multi_modal_data: ``{"image": [...]}`` when the row has images, else ``{}``.
            multi_modal_inputs: remaining processor outputs (if ``return_multi_modal_inputs``).
            images_pil: the resized PIL images.
            raw_prompt_ids: tokenizer ids of the rendered prompt, truncated per ``truncation``.
            answer: ground-truth answer string.
            extra_info: parsed ``extra_info`` with ``answer_type`` set to ``"OPEN"``.
            raw_prompt / full_prompts: only if ``return_raw_chat`` / ``return_full_prompt``.

        Raises:
            RuntimeError: if no processor was provided, or if ``truncation == "error"``
                and the raw prompt is too long. ``postprocess_data`` may also raise
                for overlong prompts.
        """
        # Fail fast: every step below needs the multimodal processor.
        if self.processor is None:
            raise RuntimeError("Qwen2.5-VL multimodal training requires providing a processor (AutoProcessor).")
        from verl.utils.dataset.vision_utils import process_image

        row: Dict[str, Any] = self.dataframe[index]
        images_pil = self._collect_images(row)
        messages = self._build_messages(self._get_question(row), len(images_pil))
        raw_prompt: str = self._render(messages)
        images_proc = [process_image(p) for p in images_pil] or None

        model_inputs = self.processor(text=[raw_prompt], images=images_proc, return_tensors="pt")
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")
        # Only needed for M-RoPE; kept out of multi_modal_inputs.
        second_per_grid_ts = model_inputs.pop("second_per_grid_ts", None)

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        image_processor = getattr(self.processor, "image_processor", None)
        image_processor_name = type(image_processor).__name__ if image_processor is not None else ""
        if "Qwen2VLImageProcessor" in image_processor_name or "Qwen2_5_VLImageProcessor" in image_processor_name:
            position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids[0],
                image_grid_thw=model_inputs.get("image_grid_thw"),
                video_grid_thw=model_inputs.get("video_grid_thw"),
                second_per_grid_ts=second_per_grid_ts,
                attention_mask=attention_mask[0],
            )
        else:
            position_ids = compute_position_id_with_mask(attention_mask)[0]
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0).repeat(3, 1)

        multi_modal_data: Dict[str, Any] = {}
        if images_proc:
            multi_modal_data["image"] = images_proc

        out: Dict[str, Any] = {
            "input_ids": input_ids[0],
            "attention_mask": attention_mask[0],
            "position_ids": position_ids,
            "multi_modal_data": multi_modal_data,
        }
        if self.return_multi_modal_inputs:
            out["multi_modal_inputs"] = dict(model_inputs)
        out["images_pil"] = images_pil
        out["raw_prompt_ids"] = self._truncate_raw_prompt_ids(
            self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        )
        out["answer"] = self._extract_answer(row)

        extra_info = self._parse_extra_info(row.get("extra_info", None))
        extra_info["answer_type"] = "OPEN"
        out["extra_info"] = extra_info

        if self.return_raw_chat:
            out["raw_prompt"] = messages
        if self.return_full_prompt:
            out["full_prompts"] = raw_prompt
        return out

    def __getstate__(self):
        """Pickle state.

        The ``dataframe`` is dropped unless ``serialize_dataset`` is set;
        ``resume_dataset_state`` reloads it.
        """
        if not self.serialize_dataset:
            state = self.__dict__.copy()
            state.pop("dataframe", None)
            return state
        return self.__dict__.copy()
