from torch.utils.data import Dataset
from datasets import load_dataset
import torch
from transformers import AutoTokenizer



class PiqaDataset(Dataset):
    """PIQA examples rendered as two-option (A/B) prompts.

    Each example's ``goal``, ``sol1`` and ``sol2`` fields are formatted into a
    prompt asking which solution is better; ``label == 0`` selects option A,
    any other label selects option B.

    ``mode == "train"`` yields fixed-length ``input_ids`` / ``attention_mask``
    / ``labels`` tensors.  Any other mode is treated as evaluation and yields
    the prompt text split into ``context`` and ``question`` together with the
    ``reference`` answer string.
    """

    def __init__(self, split, tokenizer, mode="train", max_length=512):
        """Load the requested split and attach a ``text`` column to it.

        Args:
            split: Split specifier forwarded to ``load_dataset``.
            tokenizer: Callable tokenizer whose result exposes ``input_ids``;
                its ``bos_token``, ``eos_token``, ``eos_token_id``,
                ``pad_token`` and ``pad_token_id`` attributes are read.
            mode: ``"train"`` for tokenized training samples; any other value
                selects the evaluation output format.
            max_length: Length every training sample is padded or truncated to.
        """
        self.tokenizer = tokenizer
        self.mode = mode
        self.max_length = max_length
        # NOTE(review): trust_remote_code=True presumably lets the dataset
        # repository's loading script run locally -- confirm the source is
        # trusted before using this in a sensitive environment.
        self.ds = load_dataset("nthngdy/piqa", split=split, trust_remote_code=True)
        self.ds = self.ds.map(self.preprocess)
        # Padding below needs a pad token; fall back to EOS when none is set.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _build_answer(example):
        """Return the reference answer text ("A. ..." or "B. ...") for *example*."""
        if example["label"] == 0:
            return "A. " + example["sol1"]
        return "B. " + example["sol2"]

    def _build_prompt(self, goal, sol1, sol2):
        """Format one example as a prompt.

        Returns:
            In ``"train"`` mode, a single string ending in ``"Answer: "``.
            In any other mode, a ``(context, question)`` tuple where
            ``question`` is ``"Answer: "``.
        """
        # A tokenizer may define bos_token as None; `or ""` keeps the literal
        # text "None" from being interpolated into the prompt.
        bos = getattr(self.tokenizer, "bos_token", None) or ""
        context = (
            f"{bos}Question: {goal.strip()}\n"
            f"A. {sol1.strip()}\n"
            f"B. {sol2.strip()}\n"
            "Which is better? (A or B)\n"
        )
        question = "Answer: "
        if self.mode == "train":
            return context + question
        return context, question

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.ds)

    def preprocess(self, example):
        """Add a ``text`` field holding the full prompt followed by the answer."""
        prompt = self._build_prompt(example["goal"], example["sol1"], example["sol2"])
        if self.mode != "train":
            # Evaluation prompts come back as (context, question); rejoin them.
            prompt = "".join(prompt)
        example["text"] = prompt + self._build_answer(example)
        return example

    def __getitem__(self, idx):
        """Return the sample at *idx* in the format selected by ``self.mode``.

        Returns:
            Training mode: dict with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors of length ``max_length`` (padding positions use
            the pad token id, mask ``0`` and label ``-100``).
            Other modes: dict with ``context`` and ``question`` strings, an
            all-ones ``attention_mask`` with one entry per prompt token, and
            the ``reference`` answer string.
        """
        example = self.ds[idx]
        prompt = self._build_prompt(
            example["goal"],
            example["sol1"],
            example["sol2"],
        )
        answer = self._build_answer(example)

        # Same train / non-train split as _build_prompt and preprocess, so an
        # unrecognised mode behaves as evaluation instead of crashing.
        if self.mode != "train":
            context, question = prompt
            # Tokenize the joined prompt: handing the tokenizer the tuple
            # itself would be treated as a batch of two texts and the mask
            # length would be 2 rather than the prompt's token count.
            prompt_ids = self.tokenizer(
                context + question, add_special_tokens=False
            ).input_ids
            return {
                "context": context,
                "question": question,
                "attention_mask": torch.ones(len(prompt_ids), dtype=torch.long),
                "reference": answer,
            }

        prompt_ids = list(self.tokenizer(prompt, add_special_tokens=False).input_ids)
        ans_ids = list(self.tokenizer(answer, add_special_tokens=False).input_ids)
        eos_id = self.tokenizer.eos_token_id
        # Terminate the answer with EOS when the tokenizer defines one.
        if eos_id is not None and (not ans_ids or ans_ids[-1] != eos_id):
            ans_ids.append(eos_id)

        # Truncate from the right, then pad up to max_length.
        token_ids = (prompt_ids + ans_ids)[: self.max_length]
        pad_len = self.max_length - len(token_ids)
        pad_id = self.tokenizer.pad_token_id
        return {
            "input_ids": torch.tensor(token_ids + [pad_id] * pad_len),
            "attention_mask": torch.tensor([1] * len(token_ids) + [0] * pad_len),
            # -100 marks padding positions in the labels.
            "labels": torch.tensor(token_ids + [-100] * pad_len),
        }

