from datasets import load_from_disk
import random
import torch


def _encode_chat_pair(question, answer, tokenizer, eos_ids):
    """Tokenize one (question, answer) pair into ``(input_ids, labels)`` tensors.

    The question is wrapped in a user turn and the answer in an assistant turn
    using Llama-3-style header/``<|eot_id|>`` markup. Both pieces are tokenized
    without the tokenizer's automatic special tokens. They are concatenated
    with ``eos_ids`` appended.

    Args:
        question: User-turn text.
        answer: Assistant-turn text.
        tokenizer: Callable tokenizer that accepts ``add_special_tokens`` and
            ``return_tensors="pt"`` and returns a mapping with ``"input_ids"``.
        eos_ids: A ``(1, 1)`` tensor holding the terminator id appended at the end.

    Returns:
        ``(input_ids, labels)``, each of shape ``(1, seq_len)``. In ``labels``,
        every instruction position is set to -100 so only the response (plus
        the trailing terminator) is supervised.
    """
    instruction_str = (
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}"
        f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    response_str = f"{answer}<|eot_id|>"

    # Tokenize the two halves separately so the instruction length is known
    # exactly for masking below.
    instruction_ids = tokenizer(instruction_str, add_special_tokens=False, return_tensors="pt")["input_ids"]
    response_ids = tokenizer(response_str, add_special_tokens=False, return_tensors="pt")["input_ids"]

    # NOTE(review): response_str already ends with <|eot_id|>. If the
    # tokenizer's eos_token_id is also <|eot_id|> (true for some Llama-3-Instruct
    # configs), the terminator appears twice. Confirm this is intended.
    # NOTE(review): no <|begin_of_text|> / BOS token is added anywhere (special
    # tokens are disabled). Confirm the downstream model does not expect one.
    input_ids = torch.cat([instruction_ids, response_ids, eos_ids], dim=1)

    labels = input_ids.clone()
    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so the
    # prompt tokens contribute no loss.
    labels[:, :instruction_ids.shape[1]] = -100
    return input_ids, labels


def get_safe_dataset(nsamples: int, seed: int, tokenizer, *, data_path="./AdvBench"):
    """Build a list of tokenized (input_ids, labels) pairs from a saved dataset.

    Loads the ``'train'`` split saved at ``data_path`` with
    ``datasets.load_from_disk``. It shuffles the split with ``seed`` and keeps
    the first ``nsamples`` rows. Each row's ``"prompt"``/``"response"`` fields
    are then formatted as a Llama-3-style user/assistant exchange and tokenized.

    Args:
        nsamples: Number of rows to keep. Values larger than the split size
            are clamped to the split size.
        seed: Seed for the dataset shuffle. Python's global ``random`` module
            is also seeded with it.
        tokenizer: Tokenizer used to encode the text. It must expose a
            non-``None`` ``eos_token_id``.
        data_path: Directory previously written by ``Dataset.save_to_disk`` /
            ``DatasetDict.save_to_disk``. It must contain a ``'train'`` split
            with ``"prompt"`` and ``"response"`` columns.

    Returns:
        A list of ``(input_ids, labels)`` tuples, each tensor of shape
        ``(1, seq_len)``. Instruction positions in ``labels`` are -100.

    Raises:
        ValueError: If ``tokenizer.eos_token_id`` is ``None``.
    """
    train_split = load_from_disk(data_path)['train']
    # Clamp so requesting more rows than exist does not make ``select`` raise.
    nsamples = min(nsamples, len(train_split))
    advbench_dataset = train_split.shuffle(seed=seed).select(range(nsamples))

    # Only seeds Python's global RNG. Nothing below uses ``random`` (the
    # shuffle takes its own seed), but the call is kept because callers may
    # rely on the global state it sets.
    random.seed(seed)

    if tokenizer.eos_token_id is None:
        raise ValueError("tokenizer.eos_token_id is None; cannot terminate samples")
    # Loop-invariant terminator, built once and concatenated into every sample.
    eos_ids = torch.tensor([[tokenizer.eos_token_id]])

    trainloader = [
        _encode_chat_pair(question, answer, tokenizer, eos_ids)
        for question, answer in zip(advbench_dataset["prompt"], advbench_dataset["response"])
    ]

    print(f"格式化完成，共加载 {len(trainloader)} 条样本。")
    return trainloader