from typing import Any, Dict, List, Optional
import json
import re
import os
from datasets import load_dataset  # pip install datasets -U


# ------------------------------
# 数据加载器 (JSON → prompt)
# ------------------------------
class BaseDataset:
    """Load records from a local JSON/JSONL file or the Hugging Face Hub and expose them as prompts.

    Indexing an instance returns ``{"prompt": <str>, "meta": <raw record>}``.
    Subclasses adapt to their own record schema by overriding ``_build_prompt``
    and, if the model output needs parsing, ``output_processing``.
    """

    FILE_PATH = "data/demo.jsonl"  # Default source used when ``json_path`` is None; override as needed.

    def __init__(self, json_path: Optional[str] = None, load_num: int = -1) -> None:
        """Load the records and optionally keep only the first ``load_num`` of them.

        Args:
            json_path: Local ``.json``/``.jsonl`` path or an ``hf:`` URI. When
                ``None``, the class attribute ``FILE_PATH`` is used.
            load_num: Maximum number of records to keep. Any value ``<= 0``
                keeps every record.
        """
        source = self.FILE_PATH if json_path is None else json_path
        self.records = self.load(source)
        if load_num > 0:
            self.records = self.records[:load_num]

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the prompt built from record ``idx`` together with the raw record."""
        rec = self.records[idx]
        return {
            "prompt": self._build_prompt(rec),
            "meta": rec,  # Keep every original field so results can be traced back.
        }

    @classmethod
    def load(cls, file_path: str) -> List[Dict[str, Any]]:
        """Single loading entry point.

        Resolution order:
            1. A path starting with ``hf:`` is always loaded from the Hub.
            2. Otherwise, an existing local path is read from disk.
            3. Otherwise, the path is retried as a Hub repository id.

        Raises:
            FileNotFoundError: The path does not exist locally and the Hub
                fallback failed; the Hub error is attached as ``__cause__``.
            ValueError: The local file has an unsupported extension.
        """
        if file_path.startswith("hf:"):
            return cls._load_hf(file_path)

        if os.path.exists(file_path):
            return cls._load_local(file_path)

        # A path that looks local but does not exist may actually be a Hub repo id.
        try:
            return cls._load_hf(f"hf:{file_path}")
        except Exception as e:
            # Chain the original error so the real Hub failure stays visible.
            raise FileNotFoundError(
                f"File '{file_path}' not found locally and failed to load from HF: {e}"
            ) from e

    @staticmethod
    def _load_local(file_path: str) -> List[Dict[str, Any]]:
        """Read a local ``.jsonl`` (one JSON value per line) or ``.json`` file.

        Blank or whitespace-only lines in a ``.jsonl`` file are skipped, so a
        trailing newline or spacer lines do not abort the load. For ``.json``
        the parsed top-level value is returned as-is.

        Raises:
            ValueError: The file extension is neither ``.jsonl`` nor ``.json``.
            json.JSONDecodeError: A line or the file is not valid JSON.
        """
        # "utf-8-sig" reads plain UTF-8 unchanged and drops a leading BOM,
        # which would otherwise make the JSON parser reject the first value.
        if file_path.endswith(".jsonl"):
            with open(file_path, "r", encoding="utf-8-sig") as f:
                return [json.loads(line) for line in f if line.strip()]

        if file_path.endswith(".json"):
            with open(file_path, "r", encoding="utf-8-sig") as f:
                return json.load(f)

        raise ValueError(
            "Unsupported file format. Please provide a .jsonl or .json file."
        )

    @staticmethod
    def _load_hf(uri: str) -> List[Dict[str, Any]]:
        """Load a dataset split from the Hugging Face Hub as a list of dicts.

        URI grammar (brackets mark optional parts)::

            hf:{repo_id}[@revision][:split]

        Examples::

            hf:my-org/my-dataset            # revision defaults to main, split to train
            hf:my-org/my-dataset@dev:valid  # explicit revision and split

        Raises:
            ValueError: ``uri`` lacks the ``hf:`` prefix or names no repository.
        """
        if not uri.startswith("hf:"):
            raise ValueError("HF uri must start with 'hf:'")
        # Drop the "hf:" prefix.
        spec = uri[3:]

        # Split off the optional ":split" first, then the optional "@revision".
        repo_part, *split_part = spec.split(":")
        repo_id, *rev_part = repo_part.split("@")
        if not repo_id:
            # Fail early with a clear message instead of handing an empty id to the Hub client.
            raise ValueError(f"HF uri '{uri}' does not contain a repository id")
        revision = rev_part[0] if rev_part else "main"
        split = split_part[0] if split_part else "train"

        print(
            f"Loading dataset from HuggingFace Hub: {repo_id} "
            f"(revision={revision}, split={split})"
        )
        # NOTE(review): trust_remote_code=True allows a repository's loading script
        # to run locally, so only point this at trusted repositories. Newer
        # `datasets` releases may deprecate this flag; confirm against the
        # installed version.
        ds = load_dataset(repo_id, split=split, revision=revision, trust_remote_code=True)

        # `datasets.Dataset` -> List[dict]
        return [dict(r) for r in ds]

    @staticmethod
    def _build_prompt(rec: Dict[str, Any]) -> str:
        """Turn one record into a prompt string.

        A record that already has a ``"prompt"`` key is used verbatim.
        Otherwise the default template is filled from the ``"question"`` and
        ``"context"`` keys, each defaulting to an empty string when absent.
        """
        if "prompt" in rec:
            return rec["prompt"]  # Use the ready-made prompt directly.
        question = rec.get("question", "")
        context = rec.get("context", "")
        return (
            f"请根据以下上下文回答问题：\n\n{context}\n\n问题：{question}\n\n答案："
        )

    def output_processing(self, output: str) -> Any:
        """Post-process a model output; the default returns it unchanged.

        Override in subclasses when the output needs parsing.
        """
        return output
    

class DOCCIMCQAGenerationDataset(BaseDataset):
    """Dataset whose prompts ask a model to write one multiple-choice question per image description.

    Each record's ``description`` field is inserted into ``MCQA_PROMPT``;
    ``output_processing`` extracts the JSON object from the model's fenced reply.
    """

    FILE_PATH = "PATH TO CAPTION DATASET"  # Placeholder default path; replace with the real dataset location.

    MCQA_PROMPT = '''
You are an assessment‑item writer.
Your job is to create a single multiple‑choice question (MCQ) that can only be answered by carefully understanding the image described below.

# Image description
{description}

# Your task
1. Read the description attentively.
2. Write one question that targets either
    - a spatial relationship in the image (e.g., relative positions, directions, sizes, distances), or
    - a fine detail visible in the scene (e.g., colors, numbers, small objects, text).
3. Create four answer options (keys A, B, C, D) that all sound plausible but only one is correct.
    - Make the distractors non‑trivial: a casual glance at the description should not reveal the answer.
    - Match the wording and specificity of the correct answer.
4. Mark the correct option with the field answer, whose value is the single capital letter A / B / C / D.
5. Return your result only as JSON inside a Markdown code block fenced with ```json—no additional text.

# Output schema
```json
{{
  "question": "string",
  "options": {{
    "A": "string",
    "B": "string",
    "C": "string",
    "D": "string"
  }},
  "answer": "A|B|C|D"
}}
```

# Example output
```json
{{
  "question": "What is located directly to the left of the red bicycle in the image?",
  "options": {{
    "A": "A yellow fire hydrant",
    "B": "A blue mailbox",
    "C": "A green trash bin",
    "D": "A small potted plant"
  }},
  "answer": "C"
}}
```

# Reminder
Output nothing except the single json‑fenced code block containing your object, with double‑quoted keys and values, and exactly four options labeled A–D.
'''

    @staticmethod
    def _build_prompt(rec: Dict[str, Any]) -> str:
        """Fill ``MCQA_PROMPT`` with the record's ``description`` value.

        A record without a ``description`` key yields a prompt with an empty
        description section.
        """
        template = DOCCIMCQAGenerationDataset.MCQA_PROMPT
        return template.format(description=rec.get("description", ""))

    def output_processing(self, output: str) -> dict:
        """Extract and parse the first json-fenced object in ``output``.

        Raises:
            ValueError: No json-fenced code block containing an object is
                present, or the block's content is not valid JSON.
        """
        fenced = re.search(r"```json\s*(\{.*?\})\s*```", output, re.DOTALL)
        if fenced is None:
            raise ValueError("No JSON code block found in the supplied text.")

        payload = fenced.group(1)
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON content: {e}") from e
        return parsed



class PixmoCapMCQAGenerationDataset(BaseDataset):
    """Dataset whose prompts ask a model to write one multiple-choice question per image caption.

    Each record's ``caption`` field is inserted into ``MCQA_PROMPT``;
    ``output_processing`` extracts the JSON object from the model's fenced reply.
    """

    FILE_PATH = "hf:anthracite-org/pixmo-cap-images:train"  # Default source (an ``hf:`` URI); override as needed.

    MCQA_PROMPT = '''
You are an assessment‑item writer.
Your job is to create a single multiple‑choice question (MCQ) that can only be answered by carefully understanding the image described below.

# Image description
{description}

# Your task
1. Read the description attentively.
2. Write one question that targets either
    - a spatial relationship in the image (e.g., relative positions, directions, sizes, distances), or
    - a fine detail visible in the scene (e.g., colors, numbers, small objects, text).
3. Create four answer options (keys A, B, C, D) that all sound plausible but only one is correct.
    - Make the distractors non‑trivial: a casual glance at the description should not reveal the answer.
    - Match the wording and specificity of the correct answer.
4. Mark the correct option with the field answer, whose value is the single capital letter A / B / C / D.
5. Return your result only as JSON inside a Markdown code block fenced with ```json—no additional text.

# Output schema
```json
{{
  "question": "string",
  "options": {{
    "A": "string",
    "B": "string",
    "C": "string",
    "D": "string"
  }},
  "answer": "A|B|C|D"
}}
```

# Example output
```json
{{
  "question": "What is located directly to the left of the red bicycle in the image?",
  "options": {{
    "A": "A yellow fire hydrant",
    "B": "A blue mailbox",
    "C": "A green trash bin",
    "D": "A small potted plant"
  }},
  "answer": "C"
}}
```

# Reminder
Output nothing except the single json‑fenced code block containing your object, with double‑quoted keys and values, and exactly four options labeled A–D.
'''

    @staticmethod
    def _build_prompt(rec: Dict[str, Any]) -> str:
        """Fill ``MCQA_PROMPT`` with the record's ``caption`` value.

        A record without a ``caption`` key yields a prompt with an empty
        description section.
        """
        template = PixmoCapMCQAGenerationDataset.MCQA_PROMPT
        return template.format(description=rec.get("caption", ""))

    def output_processing(self, output: str) -> dict:
        """Extract and parse the first json-fenced object in ``output``.

        Raises:
            ValueError: No json-fenced code block containing an object is
                present, or the block's content is not valid JSON.
        """
        fenced = re.search(r"```json\s*(\{.*?\})\s*```", output, re.DOTALL)
        if fenced is None:
            raise ValueError("No JSON code block found in the supplied text.")

        payload = fenced.group(1)
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON content: {e}") from e
        return parsed