# data_utils_all_label.py: 将sft中的Q+A一起作为label进行标准的LM训练
import os
import random
from typing import Optional, Dict, Any

import torch
from torch.utils.data import Dataset
from datasets import load_dataset, Dataset as HFDataset

import pandas as pd



# ====================== 通用 collate ======================

def collate_sft(batch):
    """Collate a list of SFT samples into one batch.

    Args:
        batch: Sequence of dicts, each holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors. ``torch.stack``
            requires the tensors of a given field to share one shape.

    Returns:
        Dict with the same three keys, each stacked along a new leading
        (batch) dimension.
    """
    fields = ("input_ids", "attention_mask", "labels")
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in fields
    }


# ====================== MMLU SFT ======================

def format_mmlu_example(dataset: "HFDataset", include_answer: bool = True):
    """Render MMLU rows as multiple-choice prompt strings.

    Each prompt is the question, one ``"<letter>. <option>"`` line per
    choice, and a trailing ``"Answer:"``. When ``include_answer`` is true,
    the correct option's letter AND its text are appended, followed by a
    blank line.

    Args:
        dataset: Column-indexable mapping exposing ``"question"``,
            ``"choices"`` and ``"answer"`` sequences of equal length, where
            ``answer`` is the integer index of the correct choice.
        include_answer: Whether to append the gold answer after ``Answer:``.

    Returns:
        List of prompt strings, one per row.

    Raises:
        IndexError: If a row has more than four choices or an answer index
            outside its choice list.
    """
    letters = ["A", "B", "C", "D"]
    prompts = []
    for question, options, answer in zip(
        dataset["question"], dataset["choices"], dataset["answer"]
    ):
        parts = [question]
        parts.extend(f"{letters[i]}. {option}" for i, option in enumerate(options))
        prompt = "\n".join(parts) + "\nAnswer:"
        if include_answer:
            # Append the correct letter followed by the option *text*.
            # (The previous code formatted the letter twice, e.g. "B. B".)
            prompt += f" {letters[answer]}. {options[answer]}\n\n"
        prompts.append(prompt)
    return prompts

def _load_mmlu_raw(
    data_dir: str,
    seed: int,
    num_samples: Optional[int] = None,
    split: str = "train"
) -> HFDataset:
    """
    逻辑参考你给的 get_mmlu_trainenc：
    - subclass 列表完全一致
    - 支持 num_samples（总样本数），会按 subclass 均匀分配
    - data_dir 下有 mmlu 子目录
    """
    from tqdm import tqdm as tqdm_local

    subclass = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]

    random.seed(seed)
    keys = ["question", "subject", "choices", "answer"]
    traindata = {key: [] for key in keys}
    mmlu_path = os.path.join(data_dir, "mmlu")

    if split == "train":
        print("Loading MMLU Train Set for Training!")
        if num_samples is None:
            # 不限样本数：直接把所有 train / validation 拼一起
            for class_name in tqdm_local(subclass, desc="Loading all MMLU subclasses"):
                try:
                    data = load_dataset(mmlu_path, class_name, split="train")
                except Exception:
                    data = load_dataset(mmlu_path, class_name, split="validation")
                for key in keys:
                    traindata[key].extend(data[key])
        else:
            # 和你原来的 get_mmlu_trainenc 一样：按 subclass 平均分配 nsamples
            subnum = num_samples // len(subclass)
            extra = num_samples % len(subclass)
            num_list = [subnum + 1] * extra + [subnum] * (len(subclass) - extra)
            random.shuffle(num_list)

            for num, class_name in tqdm_local(
                zip(num_list, subclass),
                total=len(subclass),
                desc="Loading the subclass in MMLU",
            ):
                if num == 0:
                    continue
                try:
                    data = load_dataset(mmlu_path, class_name, split="train").shuffle(seed=seed)[:num]
                except Exception:
                    data = load_dataset(mmlu_path, class_name, split="validation").shuffle(seed=seed)[:num]
                for key in keys:
                    traindata[key].extend(data[key])
    
    elif split == "test":
        print("Loading MMLU Test Set for Evaluation!")
        # 不限样本数：直接把所有 train / validation 拼一起
        for class_name in tqdm_local(subclass, desc="Loading all MMLU subclasses"):
            data = load_dataset(mmlu_path, class_name, split="test")
            for key in keys:
                traindata[key].extend(data[key])

    dataset = HFDataset.from_dict(traindata)
    return dataset


class MMLUSFTDataset(Dataset):
    """MMLU prompts tokenized for plain causal-LM fine-tuning.

    Every (question, choices, answer) row is rendered to text by
    ``format_mmlu_example`` and tokenized once in ``__init__``; each sample's
    labels are a copy of its input ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged, so pad positions are part of the labels. Confirm
    the training loop masks them (e.g. via ``attention_mask``).
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer,
        max_length: int,
        seed: int = 42,
        num_samples: Optional[int] = None,
        include_answer: bool = True,
        split: str = "train"
    ):
        """Load, format and tokenize the requested MMLU split.

        Args:
            data_dir: Forwarded to ``_load_mmlu_raw``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            seed: Forwarded to ``_load_mmlu_raw``.
            num_samples: Forwarded to ``_load_mmlu_raw``.
            include_answer: Whether the gold answer is appended to prompts.
            split: Forwarded to ``_load_mmlu_raw``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        rows = _load_mmlu_raw(
            data_dir=data_dir, seed=seed, num_samples=num_samples, split=split
        )
        prompts = format_mmlu_example(rows, include_answer=include_answer)

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(prompts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}

# ====================== BoolQ SFT ======================
class BoolQDataset(Dataset):
    """BoolQ rows rendered as Yes/No multiple-choice prompts for SFT.

    Each sample's text is the prompt followed by the correct option letter
    ("A" for yes, "B" for no); labels are a copy of the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(passage: str, question: str) -> str:
        """Build the prompt text that ends right after ``### Answer:``."""
        sections = [
            "### Task:\nRead the following passage and only answer the Yes/No question based on it.",
            f"### Passage:\n{passage}",
            f"### Question:\n{question}",
            "### Options:\nA. Yes\nB. No",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read a BoolQ parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling: shuffle all row indices with a private
        # RNG and keep the first num_samples of them.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        # Text = prompt + " " + answer letter (True -> A, False -> B).
        texts = []
        for row in ds:
            passage, question, is_yes = row["passage"], row["question"], row["answer"]
            suffix = " A" if bool(is_yes) else " B"
            texts.append(self._build_prompt(passage, question) + suffix)

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}


# ====================== PIQA SFT ======================
class PIQADataset(Dataset):
    """PIQA rows rendered as two-option multiple-choice prompts for SFT.

    Each sample's text is the prompt followed by the correct option letter
    (A or B); labels are a copy of the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(question: str, choices):
        """Build the prompt text that ends right after ``### Answer:``.

        Args:
            question: The goal description.
            choices: Exactly two candidate solutions.

        Raises:
            ValueError: If ``choices`` does not hold exactly two entries.
        """
        # Explicit raise instead of ``assert`` so the check also runs when
        # Python is started with -O.
        if len(choices) != 2:
            raise ValueError(f"PIQA choices 应该有 2 个，实际得到 {len(choices)} 个")

        prompt = (
            "### Task:\n"
            "Choose the most physically plausible solution to achieve the goal.\n\n"
            f"### Goal:\n{question}\n\n"
            "### Options:\n"
            f"A. {choices[0]}\n"
            f"B. {choices[1]}\n\n"
            "### Answer:"
        )
        return prompt

    @staticmethod
    def _answer_letter(item) -> str:
        """Map a row to its gold option letter ("A" or "B").

        ``answer_index`` (0 or 1) is used when present and not ``None``;
        otherwise the ``answer`` field is parsed as the letter itself.

        Raises:
            ValueError: If the index is not 0/1 or the letter is not A/B.
        """
        if "answer_index" in item and item["answer_index"] is not None:
            ans_idx = int(item["answer_index"])
            if ans_idx not in (0, 1):
                raise ValueError(f"Unexpected answer_index: {ans_idx}")
            return "A" if ans_idx == 0 else "B"

        # Fallback: parse 'A' / 'B' from the 'answer' field.
        ans = str(item["answer"]).strip().upper()
        if ans not in ("A", "B"):
            raise ValueError(f"Unexpected answer format: {ans}")
        return ans

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read a PIQA parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a row has a malformed choice list or answer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset(
            "parquet",
            data_files=data_path,
            split=split,
        )

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and num_samples < len(ds):
            rng = random.Random(seed)
            indices = list(range(len(ds)))
            rng.shuffle(indices)
            ds = ds.select(indices[:num_samples])

        # Text = prompt + " " + gold option letter.
        texts = []
        for item in ds:
            question = item["question"]
            choices = item["choices"]   # ['xxx', 'yyy']
            answer_letter = self._answer_letter(item)
            prompt = self._build_prompt(question, choices)
            texts.append(prompt + f" {answer_letter}")

        encodings = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

# ====================== ARC-Challenge SFT ======================
class ARCChallengeDataset(Dataset):
    """ARC-Challenge rows rendered as multiple-choice prompts for SFT.

    The question and its labelled options form the prompt; the gold
    ``answerKey`` is appended after ``### Answer:``. Labels are a copy of
    the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(question: str, choice_texts, choice_labels):
        """Build the prompt text that ends right after ``### Answer:``.

        Args:
            question: Question text.
            choice_texts: Option texts.
            choice_labels: Option labels paired positionally with the texts.
        """
        options_str = "\n".join(
            f"{label}. {text}" for label, text in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\nChoose the best answer to the following question.",
            f"### Question:\n{question}",
            f"### Options:\n{options_str}",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read an ARC-Challenge parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a row's ``answerKey`` is not among its labels.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        texts = []
        for item in ds:
            question = item["question"]
            option_block = item["choices"]
            choice_texts, choice_labels = option_block["text"], item["choices"]["label"]
            answer_key = str(item["answerKey"]).strip()

            # Guard: the gold key must be one of the row's own labels.
            if answer_key not in choice_labels:
                raise ValueError(
                    f"answerKey {answer_key} not in labels {choice_labels} for id={item.get('id','?')}"
                )

            texts.append(
                self._build_prompt(question, choice_texts, choice_labels) + f" {answer_key}"
            )

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}

# ====================== ARC-Easy SFT ======================
class ARCEasyDataset(Dataset):
    """ARC-Easy rows rendered as multiple-choice prompts for SFT.

    Rows are read with the same fields as ARC-Challenge (``question``,
    ``choices``, ``answerKey``); the gold key is appended after
    ``### Answer:``. Labels are a copy of the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(question: str, choice_texts, choice_labels):
        """Build the prompt text that ends right after ``### Answer:``.

        Args:
            question: Question text.
            choice_texts: Option texts.
            choice_labels: Option labels paired positionally with the texts.
        """
        options_str = "\n".join(
            f"{label}. {text}" for label, text in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\nChoose the best answer to the following question.",
            f"### Question:\n{question}",
            f"### Options:\n{options_str}",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read an ARC-Easy parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a row's ``answerKey`` is not among its labels.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        texts = []
        for item in ds:
            question = item["question"]
            choice_texts = item["choices"]["text"]
            choice_labels = item["choices"]["label"]
            answer_key = str(item["answerKey"]).strip()

            # Guard: the gold key must be one of the row's own labels.
            if answer_key not in choice_labels:
                raise ValueError(f"answerKey {answer_key} not in {choice_labels}, id={item['id']}")

            texts.append(
                self._build_prompt(question, choice_texts, choice_labels) + f" {answer_key}"
            )

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        sample = {"input_ids": ids, "attention_mask": mask}
        sample["labels"] = ids.clone()
        return sample

# ====================== HellaSwag SFT ======================
class HellaSwagDataset(Dataset):
    """HellaSwag rows rendered as multiple-choice prompts for SFT.

    ``ctx`` is the context, ``endings`` the candidate continuations and
    ``label`` the index of the correct ending, which is mapped to a letter
    A-D and appended after ``### Answer:``. Labels are a copy of the token
    ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(ctx: str, endings):
        """Build the prompt text that ends right after ``### Answer:``.

        Args:
            ctx: Context text.
            endings: Candidate continuations; they are paired with the
                letters A-D, so at most four are rendered.
        """
        options_str = "\n".join(
            f"{letter}. {ending}"
            for letter, ending in zip(["A", "B", "C", "D"], endings)
        )
        sections = [
            "### Task:\nChoose the most plausible continuation of the following context.",
            f"### Context:\n{ctx}",
            f"### Options:\n{options_str}",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read a HellaSwag parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a label is not an integer or is out of range for
                the row's endings.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        letters = ["A", "B", "C", "D"]
        texts = []
        for item in ds:
            ctx = item["ctx"]
            endings = item["endings"]
            label_str = str(item["label"]).strip()

            try:
                label_idx = int(label_str)
            except ValueError:
                raise ValueError(f"Unexpected label format: {label_str}")

            if label_idx < 0 or label_idx >= len(endings):
                raise ValueError(
                    f"Label index {label_idx} out of range for endings (len={len(endings)}), ind={item.get('ind','?')}"
                )

            answer_letter = letters[label_idx]
            texts.append(self._build_prompt(ctx, endings) + f" {answer_letter}")

        # Tokenize every text in a single call.
        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}

# ====================== Winogrande SFT ======================
class WinograndeDataset(Dataset):
    """Winogrande rows rendered as two-option fill-in-the-blank prompts.

    ``sentence`` contains a ``_`` blank, ``option1`` / ``option2`` are the
    candidate fillers and ``answer`` ('1' or '2') selects the correct one,
    which is mapped to A / B and appended after ``### Answer:``. Labels are
    a copy of the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(sentence: str, option1: str, option2: str) -> str:
        """Build the prompt text that ends right after ``### Answer:``."""
        sections = [
            "### Task:\nChoose the correct option to fill in the blank (\"_\") in the sentence.",
            f"### Sentence:\n{sentence}",
            f"### Options:\nA. {option1}\nB. {option2}",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read a Winogrande parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a row's answer is neither '1' nor '2'.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        # '1' selects option1 (shown as A), '2' selects option2 (shown as B).
        letter_for = {"1": "A", "2": "B"}
        texts = []
        for item in ds:
            sentence = item["sentence"]
            option1 = item["option1"]
            option2 = item["option2"]
            ans_str = str(item["answer"]).strip()

            if ans_str not in letter_for:
                raise ValueError(f"Unexpected answer value: {ans_str}")

            prompt = self._build_prompt(sentence, option1, option2)
            texts.append(prompt + f" {letter_for[ans_str]}")

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}


# ====================== XSum SFT ======================
class XSumDataset(Dataset):
    """CSV summarization data rendered as prompt + summary for SFT.

    The source text column is ``dialogue`` when present, otherwise
    ``document`` or ``article``; the target column is ``summary``. Each
    sample's text is the prompt followed directly by the summary, and labels
    are a copy of the token ids.

    NOTE(review): the prompt wording says "dialogue" for every text column;
    change ``_build_prompt_summarization`` if articles need other wording.
    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt_summarization(text: str) -> str:
        """Build the summarization prompt ending with ``### Summary:\\n``."""
        prompt = (
            "### Task:\n"
            "Summarize the following dialogue in one sentence.\n\n"
            f"### Dialogue:\n{text}\n\n"
            "### Summary:\n"
        )
        return prompt

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read a summarization CSV and tokenize all samples.

        Args:
            data_path: Either a directory, in which case ``{split}.csv``
                inside it is read, or a CSV file path, in which case
                ``split`` is ignored.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: Name of the CSV file (without extension) when
                ``data_path`` is a directory.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If the CSV does not exist, lacks a ``summary``
                column, or has none of the supported text columns.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Resolve the CSV: "<dir>/<split>.csv" for a directory, else the
        # given file path itself.
        if os.path.isdir(data_path):
            csv_file = os.path.join(data_path, f"{split}.csv")
        else:
            csv_file = data_path

        if not os.path.exists(csv_file):
            raise ValueError(f"XSum split 文件不存在: {csv_file}")

        df = pd.read_csv(csv_file)

        if "summary" not in df.columns:
            raise ValueError(f"XSum csv 缺少 'summary' 列: {csv_file}, columns={list(df.columns)}")

        # Pick the first supported text column, in priority order.
        text_col = next(
            (col for col in ("dialogue", "document", "article") if col in df.columns),
            None,
        )
        if text_col is None:
            raise ValueError(
                f"XSum csv 找不到文本列（期望 'dialogue' / 'document' / 'article' 之一），"
                f"实际 columns={list(df.columns)}"
            )

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and num_samples < len(df):
            rng = random.Random(seed)
            indices = list(range(len(df)))
            rng.shuffle(indices)
            df = df.iloc[indices[:num_samples]].reset_index(drop=True)

        def as_text(value) -> str:
            # Empty CSV cells are read as NaN; str(NaN) would inject the
            # literal word "nan" into the training text, so map them to "".
            return "" if pd.isna(value) else str(value)

        # Text = prompt followed directly by the target summary.
        texts = [
            self._build_prompt_summarization(as_text(src_text)) + as_text(summary)
            for src_text, summary in zip(df[text_col], df["summary"])
        ]

        encodings = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        labels = input_ids.clone()  # standard autoregressive LM target

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# ====================== OpenBookQA (OBQA) SFT ======================
class OpenBookQADataset(Dataset):
    """OpenBookQA (OBQA) rows rendered as multiple-choice prompts for SFT.

    ``question_stem`` is the question, ``choices`` holds parallel ``text``
    and ``label`` lists, and ``answerKey`` is the gold label that gets
    appended after ``### Answer:``. Labels are a copy of the token ids.

    NOTE(review): sequences are padded to ``max_length`` and labels copy the
    padded ids unchanged — confirm the training loop masks pad positions.
    """

    @staticmethod
    def _build_prompt(question_stem: str, choice_texts, choice_labels):
        """Build the prompt text that ends right after ``### Answer:``.

        Args:
            question_stem: Question text.
            choice_texts: Option texts.
            choice_labels: Option labels paired positionally with the texts.
        """
        options_str = "\n".join(
            f"{label}. {text}" for label, text in zip(choice_labels, choice_texts)
        )
        sections = [
            "### Task:\nChoose the best answer to the following science question.",
            f"### Question:\n{question_stem}",
            f"### Options:\n{options_str}",
            "### Answer:",
        ]
        return "\n\n".join(sections)

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int,
        split: str = "train",
        num_samples: Optional[int] = None,
        seed: int = 42,
    ):
        """Read an OBQA parquet file and tokenize all samples.

        Args:
            data_path: Parquet file(s), passed to ``load_dataset`` as
                ``data_files``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Truncation / padding length.
            split: ``split`` argument for ``load_dataset``.
            num_samples: Optional cap; a seeded random subset is kept when
                the file holds more rows.
            seed: Seed for the subset selection.

        Raises:
            ValueError: If a row's ``answerKey`` is not among its labels.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        ds = load_dataset("parquet", data_files=data_path, split=split)

        # Optional down-sampling with a private, seeded RNG.
        if num_samples is not None and len(ds) > num_samples:
            order = list(range(len(ds)))
            random.Random(seed).shuffle(order)
            ds = ds.select(order[:num_samples])

        texts = []
        for item in ds:
            question = item["question_stem"]
            choice_texts = item["choices"]["text"]
            choice_labels = item["choices"]["label"]
            answer_key = str(item["answerKey"]).strip()

            # Guard: the gold key must be one of the row's own labels.
            if answer_key not in choice_labels:
                raise ValueError(
                    f"answerKey {answer_key} not in labels {choice_labels} for id={item.get('id','?')}"
                )

            texts.append(
                self._build_prompt(question, choice_texts, choice_labels) + f" {answer_key}"
            )

        tokenizer_kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        batch = self.tokenizer(texts, **tokenizer_kwargs)
        self.input_ids = batch["input_ids"]
        self.attention_mask = batch["attention_mask"]

    def __len__(self):
        """Number of tokenized samples."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; labels are a fresh copy of the ids."""
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return {"input_ids": ids, "attention_mask": mask, "labels": ids.clone()}

# ====================== 统一入口：get_sft_dataset / get_loader ======================

def get_sft_dataset(
    name: str,
    tokenizer,
    max_length: int,
    seed: int = 42,
    num_samples: Optional[int] = None,
    split: str = "train",
):
    """Build the SFT dataset selected by ``name``.

    ``name`` is lower-cased and matched by substring, in this order: mmlu,
    boolq, piqa, arc-challenge, arc-easy, hellaswag, winogrande, xsum, obqa.
    Dataset locations are hard-coded below.

    Args:
        name: Dataset identifier (case-insensitive substring match).
        tokenizer: Tokenizer handed to the dataset class.
        max_length: Truncation / padding length.
        seed: Seed for sub-sampling.
        num_samples: Optional cap on the number of samples.
        split: ``"train"`` or ``"test"`` for the parquet-backed datasets;
            forwarded unchanged to the MMLU and XSum datasets.

    Returns:
        The constructed ``torch.utils.data.Dataset``.

    Raises:
        ValueError: If ``name`` matches no dataset, or ``split`` has no
            file registered for a parquet-backed dataset.
    """
    name = name.lower()

    def parquet_for(label, train_file, test_file):
        """Return the parquet file registered for the requested split."""
        files = {"train": train_file, "test": test_file}
        if split not in files:
            # Previously this fell through with the path variable unbound
            # and surfaced as an UnboundLocalError.
            raise ValueError(
                f"Unsupported split {split!r} for {label}; expected 'train' or 'test'"
            )
        return files[split]

    def from_parquet(dataset_cls, path):
        """Instantiate a parquet-backed dataset class on ``path``."""
        return dataset_cls(
            data_path=path,
            tokenizer=tokenizer,
            max_length=max_length,
            # The dataset classes forward this to load_dataset; it is always
            # "train" here because the requested split was already used to
            # pick the file.
            split="train",
            num_samples=num_samples,
            seed=seed,
        )

    if "mmlu" in name:
        return MMLUSFTDataset(
            data_dir="/seu_nvme/ogai/datasets/",
            tokenizer=tokenizer,
            max_length=max_length,
            seed=seed,
            num_samples=num_samples,
            split=split
        )

    if "boolq" in name:
        path = parquet_for(
            "boolq",
            "/TO/MAY/PATH/MyDatasets/BoolQ/data/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/BoolQ/data/validation-00000-of-00001.parquet",
        )
        return from_parquet(BoolQDataset, path)

    if "piqa" in name:
        print("Loading PIQA!")
        path = parquet_for(
            "piqa",
            "/TO/MAY/PATH/MyDatasets/PIQA/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/PIQA/validation-00000-of-00001.parquet",
        )
        return from_parquet(PIQADataset, path)

    if "arcchallenge" in name or "arc_challenge" in name:
        print("Loading arc-challenge!")
        path = parquet_for(
            "arc_challenge",
            "/TO/MAY/PATH/MyDatasets/ARC-challenge/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/ARC-challenge/test-00000-of-00001.parquet",
        )
        return from_parquet(ARCChallengeDataset, path)

    if "arc_easy" in name or "arceasy" in name:
        print("Loading arc-easy!")
        path = parquet_for(
            "arc_easy",
            "/TO/MAY/PATH/MyDatasets/ARC-easy/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/ARC-easy/test-00000-of-00001.parquet",
        )
        return from_parquet(ARCEasyDataset, path)

    if "hellaswag" in name:
        print("Load HellaS!")
        path = parquet_for(
            "hellaswag",
            "/TO/MAY/PATH/MyDatasets/HellaS/data/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/HellaS/data/validation-00000-of-00001.parquet",
        )
        return from_parquet(HellaSwagDataset, path)

    # "wino" also matches every name that contains "winogrande".
    if "wino" in name:
        path = parquet_for(
            "winogrande",
            "/TO/MAY/PATH/MyDatasets/WinoG/winogrande_xl/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/WinoG/winogrande_xl/validation-00000-of-00001.parquet",
        )
        return from_parquet(WinograndeDataset, path)

    if "xsum" in name:
        return XSumDataset(
            data_path="/TO/MAY/PATH/MyDatasets/XSum/",
            tokenizer=tokenizer,
            max_length=max_length,
            split=split,          # selects "<split>.csv" inside the directory
            num_samples=num_samples,
            seed=seed,
        )

    if "obqa" in name or "openbookqa" in name:
        path = parquet_for(
            "openbookqa",
            "/TO/MAY/PATH/MyDatasets/OBQA/main/train-00000-of-00001.parquet",
            "/TO/MAY/PATH/MyDatasets/OBQA/main/test-00000-of-00001.parquet",
        )
        return from_parquet(OpenBookQADataset, path)

    raise ValueError(f"Unknown sft_dataset name: {name}")

def get_loader(*args, **kwargs):
    """Alias for ``get_sft_dataset``; forwards every argument unchanged.

    Despite the name, this returns whatever ``get_sft_dataset`` returns
    (a dataset object), not a ``DataLoader``.
    """
    dataset = get_sft_dataset(*args, **kwargs)
    return dataset
