# data_utils.py
import os
import random
from typing import Optional

import torch
from torch.utils.data import Dataset
from datasets import load_dataset, Dataset as HFDataset
import pandas as pd


# ====================== 通用 collate ======================

def collate_sft(batch):
    """
    Collate a list of SFT samples into one batch.

    Every element of ``batch`` is indexed with the keys ``input_ids``,
    ``attention_mask`` and ``labels``; the per-sample tensors stored under
    each key are stacked along a new leading dimension.

    Returns:
        A dict with the same three keys, each mapped to the stacked tensor.
    """
    fields = ("input_ids", "attention_mask", "labels")
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in fields
    }


# ====================== 公共工具：prompt+answer → input/labels ======================

def _ensure_pad_token(tokenizer):
    """
    Make sure ``tokenizer`` has a pad token.

    When ``pad_token_id`` is None the EOS token is reused as the pad token if
    one exists; otherwise a new ``"[PAD]"`` special token is registered through
    ``tokenizer.add_special_tokens``.
    """
    if tokenizer.pad_token_id is not None:
        return
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        # Last resort: register a dedicated pad token.
        # NOTE(review): this presumably grows the vocabulary - confirm that the
        # caller resizes the model's embeddings accordingly.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})


def build_prompt_answer_encoding(tokenizer, prompt_text: str, answer_text: str, max_length: int):
    """
    Encode a prompt/answer pair so that only answer tokens carry a label.

    The prompt and the answer are tokenized separately without special tokens.
    If the tokenizer has an EOS id it is appended to the answer and supervised
    as well. Prompt positions and padding positions get the label -100.

    Length handling (every returned tensor has exactly ``max_length`` items):
      * If prompt + answer is too long, the prompt is trimmed from the left,
        but at least one prompt token is kept when the prompt is non-empty.
      * If the sequence is still too long (the answer alone fills or exceeds
        the budget), the answer is truncated from the right. The trailing EOS
        is dropped in that case so a cut-off answer is not taught as complete.
      * Shorter sequences are right-padded with ``tokenizer.pad_token_id``.

    Args:
        tokenizer: callable as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with ``"input_ids"``; must expose
            ``pad_token_id`` and ``eos_token_id``.
        prompt_text: text that is fed to the model but not supervised.
        answer_text: text whose tokens are supervised.
        max_length: fixed output length, must be >= 1.

    Returns:
        Tuple ``(input_ids, attention_mask, labels)`` of 1-D ``torch.long``
        tensors, each of length ``max_length``.

    Raises:
        ValueError: if ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    _ensure_pad_token(tokenizer)

    # Encode separately and without special tokens so the prompt/answer
    # boundary in token space is known exactly.
    prompt_ids = list(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
    answer_ids = list(tokenizer(answer_text, add_special_tokens=False)["input_ids"])

    # The EOS marker is part of the supervised answer.
    if tokenizer.eos_token_id is not None:
        answer_ids.append(tokenizer.eos_token_id)

    if len(prompt_ids) + len(answer_ids) > max_length:
        # Drop prompt tokens from the left first, keeping at least one.
        keep = max(max_length - len(answer_ids), 1)
        prompt_ids = prompt_ids[-keep:]
        # Without this cut an over-long answer would produce tensors longer
        # than max_length and break stacking into a batch later on.
        answer_ids = answer_ids[: max_length - len(prompt_ids)]

    real_len = len(prompt_ids) + len(answer_ids)
    pad_len = max_length - real_len  # always >= 0 at this point

    input_ids = prompt_ids + answer_ids + [tokenizer.pad_token_id] * pad_len
    attention_mask = [1] * real_len + [0] * pad_len
    labels = [-100] * len(prompt_ids) + answer_ids + [-100] * pad_len

    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attention_mask, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )


# ====================== MMLU SFT ======================

def format_mmlu_example(dataset: "HFDataset", include_answer: bool = True):
    """
    Legacy MMLU formatter, kept for backward compatibility.

    Builds one text prompt per example in the form::

        <question>
        A. <choice 0>
        B. <choice 1>
        ...
        Answer:

    and, when ``include_answer`` is true, appends
    ``" <letter>. <text of the correct choice>\\n\\n"``.

    Args:
        dataset: object indexable by the column names ``"question"``,
            ``"choices"`` and ``"answer"``, each yielding a sequence of equal
            length. ``answer`` entries are integer indices into the choices.
        include_answer: whether to append the gold answer after "Answer:".

    Returns:
        A list of prompt strings, one per example.
    """
    letters = ["A", "B", "C", "D"]
    prompts = []
    for ques, cho, ans in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        options = "".join(f"\n{letters[i]}. {text}" for i, text in enumerate(cho))
        prompt = ques + options + "\nAnswer:"
        if include_answer:
            # Letter followed by the text of the chosen option (the earlier
            # code printed the letter twice, e.g. "A. A").
            prompt += f" {letters[ans]}. {cho[ans]}\n\n"
        prompts.append(prompt)
    return prompts


def _load_mmlu_raw(
    data_dir: str,
    seed: int,
    num_samples: Optional[int] = None,
    split: str = "train"
) -> HFDataset:
    """
    Load MMLU from ``<data_dir>/mmlu`` and merge all subjects into one dataset.

    Behavior by ``split``:
      * ``"train"`` with ``num_samples=None``: every subject is loaded in full.
      * ``"train"`` with ``num_samples`` set: the total is divided across the
        subjects (``num_samples // n`` each, the remainder handed out as one
        extra sample to some subjects, the assignment shuffled with the global
        ``random`` module); each subject is shuffled with ``seed`` and the
        first ``num`` rows are taken.
      * ``"test"``: every subject's test split is loaded in full.

    For the training path, if loading a subject's "train" split raises any
    exception the subject's "validation" split is used instead.

    Note: this function calls ``random.seed(seed)`` and therefore reseeds the
    process-wide ``random`` generator.

    Returns:
        An ``HFDataset`` with the columns question, subject, choices, answer.

    Raises:
        ValueError: if ``split`` is neither "train" nor "test".
    """
    from tqdm import tqdm as tqdm_local

    subclass = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]

    random.seed(seed)
    keys = ["question", "subject", "choices", "answer"]
    traindata = {key: [] for key in keys}
    mmlu_path = os.path.join(data_dir, "mmlu")

    def _collect(data):
        # Append one subject's columns to the merged column lists.
        for key in keys:
            traindata[key].extend(data[key])

    if split == "test":
        print("Loading MMLU Test Set for Evaluation!")
        for class_name in tqdm_local(subclass, desc="Loading all MMLU subclasses"):
            _collect(load_dataset(mmlu_path, class_name, split="test"))

    elif split == "train":
        print("Loading MMLU Train Set for Training!")
        if num_samples is not None:
            # Spread num_samples over the subjects as evenly as possible, then
            # shuffle which subjects receive the extra sample.
            subnum, extra = divmod(num_samples, len(subclass))
            num_list = [subnum + 1] * extra + [subnum] * (len(subclass) - extra)
            random.shuffle(num_list)

            for num, class_name in tqdm_local(
                zip(num_list, subclass),
                total=len(subclass),
                desc="Loading the subclass in MMLU",
            ):
                if num == 0:
                    continue
                try:
                    data = load_dataset(mmlu_path, class_name, split="train").shuffle(seed=seed)[:num]
                except Exception:
                    # Fall back to the validation split for this subject.
                    data = load_dataset(mmlu_path, class_name, split="validation").shuffle(seed=seed)[:num]
                _collect(data)
        else:
            # No sample limit: take each subject in full.
            for class_name in tqdm_local(subclass, desc="Loading all MMLU subclasses"):
                try:
                    data = load_dataset(mmlu_path, class_name, split="train")
                except Exception:
                    # Fall back to the validation split for this subject.
                    data = load_dataset(mmlu_path, class_name, split="validation")
                _collect(data)

    else:
        raise ValueError(f"Unsupported split for MMLU: {split}")

    return HFDataset.from_dict(traindata)


class MMLUSFTDataset(Dataset):
    """
    MMLU dataset for SFT with loss on the answer tokens only.

    - input_ids = prompt (question + lettered choices + "Answer:") followed by
      the answer letter; the encoding helper appends EOS when the tokenizer
      has one.
    - labels    = -100 on the prompt, real token ids on the answer.

    All examples are encoded eagerly in ``__init__`` and stored as three
    stacked tensors.
    """
    def __init__(
        self,
        data_dir: str,
        tokenizer,
        max_length: int,
        seed: int = 42,
        num_samples: Optional[int] = None,
        include_answer: bool = True,  # accepted for compatibility; not read here
        split: str = "train"
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        raw_dataset = _load_mmlu_raw(
            data_dir=data_dir,
            seed=seed,
            num_samples=num_samples,
            split=split
        )

        letters = ["A", "B", "C", "D"]
        encoded = []

        columns = zip(raw_dataset["question"], raw_dataset["choices"], raw_dataset["answer"])
        for question, options, answer_idx in columns:
            # Prompt without the answer: question, one line per option, "Answer:".
            option_lines = "".join(
                f"\n{letters[pos]}. {text}" for pos, text in enumerate(options)
            )
            prompt = question + option_lines + "\nAnswer:"

            # Only the option letter is supervised.
            answer_text = f" {letters[answer_idx]}"

            encoded.append(
                build_prompt_answer_encoding(
                    self.tokenizer, prompt, answer_text, self.max_length
                )
            )

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== BoolQ SFT ======================

class BoolQDataset(Dataset):
    """
    BoolQ dataset for SFT with loss on the answer tokens only.

    Each record is rendered as a passage/question prompt with the options
    "A. Yes" / "B. No". The supervised answer text is " A" when the record's
    ``answer`` field is truthy and " B" otherwise. All records are encoded
    eagerly in ``__init__`` and stored as three stacked tensors.
    """

    @staticmethod
    def _build_prompt(passage: str, question: str) -> str:
        """Render the instruction prompt for one passage/question pair."""
        sections = [
            "### Task:\nRead the following passage and only answer the Yes/No question based on it.\n\n",
            f"### Passage:\n{passage}\n\n",
            f"### Question:\n{question}\n\n",
            "### Options:\nA. Yes\nB. No\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """Encode one raw record into (input_ids, attention_mask, labels)."""
        passage = item["passage"]
        question = item["question"]
        answer_letter = "A" if item["answer"] else "B"
        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(passage, question),
            f" {answer_letter}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== PIQA SFT ======================

class PIQADataset(Dataset):
    """
    PIQA dataset for SFT with loss on the answer tokens only.

    Each record is rendered as a goal plus two options (A/B); the supervised
    answer text is " A" or " B". The gold answer is taken from the record's
    ``answer_index`` field (0 or 1) when it is present and not None, otherwise
    from its ``answer`` field, which must read "A" or "B" after stripping and
    upper-casing. All records are encoded eagerly in ``__init__``.
    """

    @staticmethod
    def _build_prompt(question: str, choices):
        """Render the instruction prompt for one goal and its two options."""
        assert len(choices) == 2, f"PIQA choices 应该有 2 个，实际得到 {len(choices)} 个"
        sections = [
            "### Task:\n",
            "Choose the most physically plausible solution to achieve the goal.\n\n",
            f"### Goal:\n{question}\n\n",
            "### Options:\n",
            f"A. {choices[0]}\n",
            f"B. {choices[1]}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """Encode one raw record into (input_ids, attention_mask, labels)."""
        question = item["question"]
        choices = item["choices"]

        has_index = "answer_index" in item and item["answer_index"] is not None
        if has_index:
            ans_idx = int(item["answer_index"])
            assert ans_idx in (0, 1)
            answer_letter = "A" if ans_idx == 0 else "B"
        else:
            ans = str(item["answer"]).strip().upper()
            if ans not in ("A", "B"):
                raise ValueError(f"Unexpected answer format: {ans}")
            answer_letter = ans

        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(question, choices),
            f" {answer_letter}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== ARC-Challenge SFT ======================

class ARCChallengeDataset(Dataset):
    """
    ARC-Challenge dataset for SFT with loss on the answer tokens only.

    Each record is rendered as a question followed by its options, labelled
    with the record's own choice labels; the supervised answer text is a space
    plus the record's ``answerKey``. All records are encoded eagerly in
    ``__init__`` and stored as three stacked tensors.
    """

    @staticmethod
    def _build_prompt(question: str, choice_texts, choice_labels):
        """Render the instruction prompt for one question and its options."""
        options_str = "\n".join(
            f"{lab}. {txt}" for lab, txt in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\n",
            "Choose the best answer to the following question.\n\n",
            f"### Question:\n{question}\n\n",
            f"### Options:\n{options_str}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """
        Encode one raw record into (input_ids, attention_mask, labels).

        Raises ValueError when the record's answerKey is not among its labels.
        """
        question = item["question"]
        choice_texts = item["choices"]["text"]
        choice_labels = item["choices"]["label"]
        answer_key = str(item["answerKey"]).strip()

        if answer_key not in choice_labels:
            raise ValueError(
                f"answerKey {answer_key} not in labels {choice_labels} for id={item.get('id','?')}"
            )

        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(question, choice_texts, choice_labels),
            f" {answer_key}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== ARC-Easy SFT ======================

class ARCEasyDataset(Dataset):
    """
    ARC-Easy dataset for SFT with loss on the answer tokens only.

    Uses the same prompt layout as the ARC-Challenge dataset in this module:
    a question followed by its options, labelled with the record's own choice
    labels; the supervised answer text is a space plus the ``answerKey``.
    All records are encoded eagerly in ``__init__``.
    """

    @staticmethod
    def _build_prompt(question: str, choice_texts, choice_labels):
        """Render the instruction prompt for one question and its options."""
        options_str = "\n".join(
            f"{lab}. {txt}" for lab, txt in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\n",
            "Choose the best answer to the following question.\n\n",
            f"### Question:\n{question}\n\n",
            f"### Options:\n{options_str}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """
        Encode one raw record into (input_ids, attention_mask, labels).

        Raises ValueError when the record's answerKey is not among its labels.
        """
        question = item["question"]
        choice_texts = item["choices"]["text"]
        choice_labels = item["choices"]["label"]
        answer_key = str(item["answerKey"]).strip()

        if answer_key not in choice_labels:
            raise ValueError(f"answerKey {answer_key} not in {choice_labels}, id={item.get('id','?')}")

        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(question, choice_texts, choice_labels),
            f" {answer_key}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== HellaSwag SFT ======================

class HellaSwagDataset(Dataset):
    """
    HellaSwag dataset for SFT with loss on the answer tokens only.

    Each record is rendered as a context followed by its endings, labelled
    A-D; the supervised answer text is a space plus the letter selected by the
    record's integer ``label``. All records are encoded eagerly in
    ``__init__`` and stored as three stacked tensors.
    """

    @staticmethod
    def _build_prompt(ctx: str, endings):
        """Render the instruction prompt for one context and its endings."""
        labels = ["A", "B", "C", "D"]
        options_str = "\n".join(
            f"{lab}. {text}" for lab, text in zip(labels, endings)
        )
        sections = [
            "### Task:\n",
            "Choose the most plausible continuation of the following context.\n\n",
            f"### Context:\n{ctx}\n\n",
            f"### Options:\n{options_str}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """
        Encode one raw record into (input_ids, attention_mask, labels).

        Raises ValueError when the label is not an integer or does not index
        one of the record's endings.
        """
        ctx = item["ctx"]
        endings = item["endings"]
        label_str = str(item["label"]).strip()

        try:
            label_idx = int(label_str)
        except ValueError:
            raise ValueError(f"Unexpected label format: {label_str}")

        if label_idx < 0 or label_idx >= len(endings):
            raise ValueError(
                f"Label index {label_idx} out of range for endings (len={len(endings)}), ind={item.get('ind','?')}"
            )

        answer_letter = ["A", "B", "C", "D"][label_idx]
        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(ctx, endings),
            f" {answer_letter}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== Winogrande SFT ======================

class WinograndeDataset(Dataset):
    """
    Winogrande dataset for SFT with loss on the answer tokens only.

    Each record is rendered as a sentence containing a blank plus two options
    (A = option1, B = option2); the supervised answer text is " A" when the
    record's ``answer`` is "1" and " B" when it is "2". All records are
    encoded eagerly in ``__init__`` and stored as three stacked tensors.
    """

    @staticmethod
    def _build_prompt(sentence: str, option1: str, option2: str) -> str:
        """Render the instruction prompt for one sentence and its two options."""
        sections = [
            "### Task:\n",
            "Choose the correct option to fill in the blank (\"_\") in the sentence.\n\n",
            f"### Sentence:\n{sentence}\n\n",
            "### Options:\n",
            f"A. {option1}\n",
            f"B. {option2}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """
        Encode one raw record into (input_ids, attention_mask, labels).

        Raises ValueError when the answer is neither "1" nor "2".
        """
        sentence = item["sentence"]
        option1 = item["option1"]
        option2 = item["option2"]
        ans_str = str(item["answer"]).strip()

        letter_by_answer = {"1": "A", "2": "B"}
        if ans_str not in letter_by_answer:
            raise ValueError(f"Unexpected answer value: {ans_str}")
        answer_letter = letter_by_answer[ans_str]

        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(sentence, option1, option2),
            f" {answer_letter}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== XSum SFT ======================

class XSumDataset(Dataset):
    """
    XSum-style summarization dataset for SFT with loss on the summary tokens.

    - input_ids = prompt ("### Task ... ### Dialogue ... ### Summary:\\n")
      followed by the summary.
    - labels    = -100 on the prompt, real token ids on the summary.

    Data is read from a CSV file that must contain a ``summary`` column and
    one of the text columns ``dialogue`` / ``document`` / ``article`` (checked
    in that order). All rows are encoded eagerly in ``__init__`` and stored as
    three stacked tensors.
    """

    @staticmethod
    def _build_prompt_summarization(text: str) -> str:
        """Render the summarization prompt for one source text."""
        prompt = (
            "### Task:\n"
            "Summarize the following dialogue in one sentence.\n\n"
            f"### Dialogue:\n{text}\n\n"
            "### Summary:\n"
        )
        return prompt

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """
        Args:
            data_path: either a directory containing ``train.csv`` /
                ``validation.csv`` or the path of a single CSV file.
            tokenizer: tokenizer passed on to the encoding helper.
            max_length: fixed sequence length of every encoded example.
            split: "train", "validation" or "test"; only consulted when
                ``data_path`` is a directory. "validation" and "test" both
                map to ``validation.csv``.
            num_samples: optional cap; when smaller than the number of rows a
                seeded random subset of that size is used.
            seed: seed for the subsampling shuffle.

        Raises:
            ValueError: for an unsupported split, a missing file or missing
                required columns.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Resolve the CSV path: directory + split, or a single file.
        if os.path.isdir(data_path):
            if split == 'train':
                csv_file = os.path.join(data_path, "train.csv")
            elif split == 'validation' or split == 'test':
                csv_file = os.path.join(data_path, "validation.csv")
            else:
                # Previously this fell through with csv_file unassigned and
                # crashed with UnboundLocalError on the next statement.
                raise ValueError(f"Unsupported split for XSum: {split}")
        else:
            csv_file = data_path

        if not os.path.exists(csv_file):
            raise ValueError(f"XSum split 文件不存在: {csv_file}")

        df = pd.read_csv(csv_file)

        if "summary" not in df.columns:
            raise ValueError(f"XSum csv 缺少 'summary' 列: {csv_file}, columns={list(df.columns)}")

        if "dialogue" in df.columns:
            text_col = "dialogue"
        elif "document" in df.columns:
            text_col = "document"
        elif "article" in df.columns:
            text_col = "article"
        else:
            raise ValueError(
                f"XSum csv 找不到文本列（期望 'dialogue' / 'document' / 'article' 之一），"
                f"实际 columns={list(df.columns)}"
            )

        # Optional seeded subsample: shuffle all row positions, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(df):
            rng = random.Random(seed)
            indices = list(range(len(df)))
            rng.shuffle(indices)
            indices = indices[:num_samples]
            df = df.iloc[indices].reset_index(drop=True)

        input_ids_list = []
        attention_mask_list = []
        labels_list = []

        for _, row in df.iterrows():
            src_text = str(row.get(text_col, ""))
            summary = str(row.get("summary", ""))

            prompt = self._build_prompt_summarization(src_text)
            answer_text = summary

            ids, mask, lbl = build_prompt_answer_encoding(
                self.tokenizer, prompt, answer_text, self.max_length
            )
            input_ids_list.append(ids)
            attention_mask_list.append(mask)
            labels_list.append(lbl)

        self.input_ids = torch.stack(input_ids_list, dim=0)
        self.attention_mask = torch.stack(attention_mask_list, dim=0)
        self.labels = torch.stack(labels_list, dim=0)

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


# ====================== OpenBookQA (OBQA) SFT ======================

class OpenBookQADataset(Dataset):
    """
    OpenBookQA dataset for SFT with loss on the answer tokens only.

    Each record is rendered as its ``question_stem`` followed by its options,
    labelled with the record's own choice labels; the supervised answer text
    is a space plus the record's ``answerKey``. All records are encoded
    eagerly in ``__init__`` and stored as three stacked tensors.
    """

    @staticmethod
    def _build_prompt(question_stem: str, choice_texts, choice_labels):
        """Render the instruction prompt for one question and its options."""
        options_str = "\n".join(
            f"{lab}. {txt}" for lab, txt in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\n",
            "Choose the best answer to the following science question.\n\n",
            f"### Question:\n{question_stem}\n\n",
            f"### Options:\n{options_str}\n\n",
            "### Answer:",
        ]
        return "".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional seeded subsample: shuffle all indices, keep the first
        # num_samples of them.
        if num_samples is not None and num_samples < len(ds):
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        encoded = [self._encode_example(item) for item in ds]

        self.input_ids = torch.stack([enc[0] for enc in encoded], dim=0)
        self.attention_mask = torch.stack([enc[1] for enc in encoded], dim=0)
        self.labels = torch.stack([enc[2] for enc in encoded], dim=0)

    def _encode_example(self, item):
        """
        Encode one raw record into (input_ids, attention_mask, labels).

        Raises ValueError when the record's answerKey is not among its labels.
        """
        question = item["question_stem"]
        choice_texts = item["choices"]["text"]
        choice_labels = item["choices"]["label"]
        answer_key = str(item["answerKey"]).strip()

        if answer_key not in choice_labels:
            raise ValueError(
                f"answerKey {answer_key} not in labels {choice_labels} for id={item.get('id','?')}"
            )

        return build_prompt_answer_encoding(
            self.tokenizer,
            self._build_prompt(question, choice_texts, choice_labels),
            f" {answer_key}",
            self.max_length,
        )

    def __len__(self):
        """Number of encoded examples."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict."""
        return {
            field: getattr(self, field)[idx]
            for field in ("input_ids", "attention_mask", "labels")
        }


# ====================== 统一入口：get_sft_dataset / get_loader ======================

def get_sft_dataset(
    name: str,
    tokenizer,
    max_length: int,
    seed: int = 42,
    num_samples: Optional[int] = None,
    split: str = "train",
):
    """
    Build the SFT dataset selected by ``name``.

    ``name`` is lower-cased and matched by substring, in this order: mmlu,
    boolq, piqa, ARC-Challenge ("arc_challenge", or any name containing "arc"
    but not "easy"), ARC-Easy ("arc_easy" / "arceasy"), hellaswag,
    winogrande ("wino"), xsum, OpenBookQA ("obqa" / "openbookqa").

    For the parquet-backed datasets ``split == "train"`` selects the training
    file and any other value selects the evaluation file (a validation file
    for BoolQ, PIQA, HellaSwag and Winogrande; a test file for ARC and
    OpenBookQA). Those datasets are always constructed with ``split="train"``
    because a single parquet file is passed as ``data_files``. MMLU and XSum
    receive ``split`` unchanged.

    Raises:
        ValueError: if ``name`` matches none of the known datasets.
    """
    name = name.lower()
    want_train = split == "train"

    # Arguments every dataset class receives unchanged.
    shared = {
        "tokenizer": tokenizer,
        "max_length": max_length,
        "num_samples": num_samples,
        "seed": seed,
    }

    if "mmlu" in name:
        return MMLUSFTDataset(data_dir="/seu_nvme/ogai/datasets/", split=split, **shared)

    if "boolq" in name:
        # The validation file doubles as the test set.
        data_path = (
            "/TO/MAY/PATH/MyDatasets/BoolQ/data/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/BoolQ/data/validation-00000-of-00001.parquet"
        )
        return BoolQDataset(data_path=data_path, split="train", **shared)

    if "piqa" in name:
        data_path = (
            "/TO/MAY/PATH/MyDatasets/PIQA/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/PIQA/validation-00000-of-00001.parquet"
        )
        return PIQADataset(data_path=data_path, split="train", **shared)

    if "arc_challenge" in name or ("arc" in name and "easy" not in name):
        data_path = (
            "/TO/MAY/PATH/MyDatasets/ARC-challenge/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/ARC-challenge/test-00000-of-00001.parquet"
        )
        return ARCChallengeDataset(data_path=data_path, split="train", **shared)

    if "arc_easy" in name or "arceasy" in name:
        data_path = (
            "/TO/MAY/PATH/MyDatasets/ARC-easy/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/ARC-easy/test-00000-of-00001.parquet"
        )
        return ARCEasyDataset(data_path=data_path, split="train", **shared)

    if "hellaswag" in name:
        data_path = (
            "/TO/MAY/PATH/MyDatasets/HellaS/data/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/HellaS/data/validation-00000-of-00001.parquet"
        )
        return HellaSwagDataset(data_path=data_path, split="train", **shared)

    if "winogrande" in name or "wino" in name:
        data_path = (
            "/TO/MAY/PATH/MyDatasets/WinoG/winogrande_xl/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/WinoG/winogrande_xl/validation-00000-of-00001.parquet"
        )
        return WinograndeDataset(data_path=data_path, split="train", **shared)

    if "xsum" in name:
        # split may be "train" / "validation" / "test"; the class resolves it.
        return XSumDataset(data_path="/TO/MAY/PATH/MyDatasets/XSum/", split=split, **shared)

    if "obqa" in name or "openbookqa" in name:
        data_path = (
            "/TO/MAY/PATH/MyDatasets/OBQA/main/train-00000-of-00001.parquet"
            if want_train
            else "/TO/MAY/PATH/MyDatasets/OBQA/main/test-00000-of-00001.parquet"
        )
        return OpenBookQADataset(data_path=data_path, split="train", **shared)

    raise ValueError(f"Unknown sft_dataset name: {name}")


def get_loader(*args, **kwargs):
    """
    Alias for ``get_sft_dataset``.

    Forwards all positional and keyword arguments unchanged and returns
    whatever ``get_sft_dataset`` returns (a dataset object, despite the name).
    """
    dataset = get_sft_dataset(*args, **kwargs)
    return dataset
