from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
import torch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from datasets import load_from_disk

def preprocess_wikitext(example, tokenizer, max_length=512):
    """Tokenize the ``"text"`` field of one example to a fixed length.

    The raw text is handed to the tokenizer unchanged (no prompt or Q&A
    template is wrapped around it); the tokenizer is asked to truncate and
    pad it to ``max_length``.

    Args:
        example: Mapping with a ``"text"`` entry holding the raw string.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding="max_length", truncation=True, return_tensors="pt")``.
            Its result must provide ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length forwarded to the tokenizer for truncation/padding.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"``, each taken from
        the tokenizer output with ``.squeeze(0)`` applied.
    """
    encoded = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    # squeeze(0) removes the leading axis when it has size 1, so a single
    # encoded text comes back without a batch dimension.
    return {
        field: encoded[field].squeeze(0)
        for field in ("input_ids", "attention_mask")
    }

class LanguageModelingDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized sequences for language modeling.

    Each item is a dict with ``"input_ids"``, ``"labels"`` and
    ``"attention_mask"``. Labels are a masked *copy* of the input ids: every
    position whose attention mask is 0 (i.e. padding) is replaced by
    ``ignore_index`` so that padded positions can be skipped by the loss
    (``-100`` is the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``). Masking is driven by the attention mask
    rather than by a pad token id, so genuine tokens that happen to share the
    pad id are kept.

    Args:
        input_ids_list: Sequence of per-example token id tensors.
        attention_mask_list: Sequence of per-example attention masks, aligned
            index-by-index with ``input_ids_list``.
        ignore_index: Value written into ``labels`` at masked positions.
            Pass ``None`` to disable masking and return the input ids object
            itself as labels.

    Raises:
        ValueError: If the two sequences do not have the same length.
    """

    def __init__(self, input_ids_list, attention_mask_list, ignore_index=-100):
        if len(input_ids_list) != len(attention_mask_list):
            raise ValueError(
                "input_ids_list and attention_mask_list must have the same "
                f"length, got {len(input_ids_list)} and "
                f"{len(attention_mask_list)}"
            )
        self.input_ids_list = input_ids_list
        self.attention_mask_list = attention_mask_list
        self.ignore_index = ignore_index

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        """Return the ``input_ids`` / ``labels`` / ``attention_mask`` dict for ``idx``."""
        input_ids = self.input_ids_list[idx]
        attention_mask = self.attention_mask_list[idx]
        if self.ignore_index is None:
            labels = input_ids
        else:
            # masked_fill returns a new tensor, so the stored input ids are
            # never modified and labels do not alias them.
            labels = torch.as_tensor(input_ids).masked_fill(
                torch.as_tensor(attention_mask) == 0, self.ignore_index
            )
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

def get_wikitext_dataloader(batch_size=2, tokenizer=None, max_length=512, split="test"):
    """Build a shuffled DataLoader over tokenized WikiText-2 text.

    The dataset is read from ``data/wikitext-2-raw-v1`` on disk. Entries whose
    stripped text is shorter than 20 characters (e.g. blank lines) are
    skipped; every remaining entry is tokenized with ``preprocess_wikitext``.

    Args:
        batch_size: Batch size of the returned DataLoader.
        tokenizer: Tokenizer forwarded to ``preprocess_wikitext``. Required.
        max_length: Truncation/padding length forwarded to
            ``preprocess_wikitext``.
        split: Name of the split to use when the saved dataset is a dict of
            splits (``DatasetDict``); ignored when a single split was saved.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` over a
        ``LanguageModelingDataset``.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
        KeyError: If the saved dataset is a dict of splits that does not
            contain ``split``.
    """
    if tokenizer is None:
        raise ValueError("get_wikitext_dataloader requires a tokenizer")

    print("Loading WikiText-2 dataset...")
    raw = load_from_disk("data/wikitext-2-raw-v1")
    # A saved DatasetDict is a dict of splits; iterating it would yield split
    # names instead of examples, so select the requested split first.
    if isinstance(raw, dict):
        if split not in raw:
            raise KeyError(
                f"split {split!r} not found; available splits: {sorted(raw)}"
            )
        raw = raw[split]

    input_ids_list = []
    attention_mask_list = []
    for example in raw:
        # Skip texts that are too short to be useful (such as empty lines).
        if len(example["text"].strip()) < 20:
            continue
        processed = preprocess_wikitext(example, tokenizer, max_length=max_length)
        input_ids_list.append(processed["input_ids"])
        attention_mask_list.append(processed["attention_mask"])

    dataset = LanguageModelingDataset(input_ids_list, attention_mask_list)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def get_wikitext_dataloader_ddp(
    batch_size=2,
    tokenizer=None,
    max_length=512,
    split="test",
    world_size=1,
    rank=0,
    sampler_cls=DistributedSampler  # pass None to fall back to a plain RandomSampler
):
    """Build a sampler-driven DataLoader over tokenized WikiText-2 text.

    The dataset is read from ``data/wikitext-2-raw-v1`` on disk. Entries whose
    stripped text is shorter than 20 characters are skipped; every remaining
    entry is tokenized with ``preprocess_wikitext``.

    Args:
        batch_size: Batch size of the returned DataLoader.
        tokenizer: Tokenizer forwarded to ``preprocess_wikitext``. Required.
        max_length: Truncation/padding length forwarded to
            ``preprocess_wikitext``.
        split: Name of the split to use when the saved dataset is a dict of
            splits (``DatasetDict``); ignored when a single split was saved.
        world_size: Forwarded to ``sampler_cls`` as ``num_replicas``.
        rank: Forwarded to ``sampler_cls`` as ``rank``.
        sampler_cls: Sampler class called as ``sampler_cls(dataset,
            num_replicas=world_size, rank=rank, shuffle=True)``. When
            ``None``, a ``RandomSampler`` over the whole dataset is used.

    Returns:
        A one-element list containing the ``DataLoader``.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
        KeyError: If the saved dataset is a dict of splits that does not
            contain ``split``.

    Note:
        With ``DistributedSampler``, call ``loader.sampler.set_epoch(epoch)``
        at the start of each epoch to get a different shuffle per epoch.
    """
    if tokenizer is None:
        raise ValueError("get_wikitext_dataloader_ddp requires a tokenizer")

    print("Loading WikiText-2 dataset...")
    dataset_raw = load_from_disk("data/wikitext-2-raw-v1")
    # A saved DatasetDict is a dict of splits; iterating it would yield split
    # names instead of examples, so select the requested split first.
    if isinstance(dataset_raw, dict):
        if split not in dataset_raw:
            raise KeyError(
                f"split {split!r} not found; available splits: {sorted(dataset_raw)}"
            )
        dataset_raw = dataset_raw[split]

    input_ids_list = []
    attention_mask_list = []
    for example in dataset_raw:
        # Skip texts that are too short to be useful (such as empty lines).
        if len(example["text"].strip()) < 20:
            continue
        processed = preprocess_wikitext(example, tokenizer, max_length=max_length)
        input_ids_list.append(processed["input_ids"])
        attention_mask_list.append(processed["attention_mask"])

    dataset = LanguageModelingDataset(input_ids_list, attention_mask_list)

    if sampler_cls is not None:
        sampler = sampler_cls(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    else:
        sampler = RandomSampler(dataset)

    # Ordering is delegated entirely to the sampler: DataLoader raises a
    # ValueError when shuffle=True is combined with an explicit sampler, so
    # shuffle is deliberately not passed here.
    return [DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        pin_memory=True,
    )]

# Smoke test: build the dataloader with a GPT-2 tokenizer and print one batch.
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Reuse the EOS token as the pad token so padding to max_length has a
    # token to pad with.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    loader = get_wikitext_dataloader(
        tokenizer=gpt2_tokenizer,
        batch_size=2,
        max_length=512,
    )
    # Only the first batch is inspected.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch)
