from typing import Callable, Optional, List, Dict

import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase


class XsumDataset(Dataset):
    """XSum summarisation examples formatted as ``<article>\\nSummary:<summary>``.

    The data is loaded from the ``venetis/xsum_clean_text`` Hugging Face
    dataset, shuffled with a fixed seed, and augmented with a ``"text"``
    column holding the full prompt + summary string.

    In ``"train"`` mode ``__getitem__`` returns fixed-length token tensors
    (``input_ids``, ``attention_mask``, ``labels``).  In ``"eval"`` mode it
    returns the raw strings (``context``, ``question``, ``reference``) and
    performs no tokenisation.

    Args:
        split: Split specifier forwarded to ``load_dataset``.
        tokenizer: Callable tokenizer exposing ``pad_token``, ``eos_token``,
            ``eos_token_id``, ``pad_token_id`` and (optionally) ``bos_token``.
            Its ``pad_token`` is set to ``eos_token`` if it is ``None``
            (this mutates the tokenizer passed in).
        mode: ``"train"`` or ``"eval"``.
        max_length: Length every training tensor is padded / truncated to.

    Raises:
        ValueError: If ``mode`` is not ``"train"`` or ``"eval"``.
    """

    def __init__(self, split, tokenizer, mode="train", max_length=1024):
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if mode not in {"train", "eval"}:
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
        self.mode = mode
        self.tokenizer = tokenizer
        self.max_length = max_length

        dataset = load_dataset("venetis/xsum_clean_text", split=split)
        dataset = dataset.shuffle(seed=1234)
        self.ds = dataset.map(self.preprocess)

        # __getitem__ pads with pad_token_id, so make sure a pad token exists.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(self, article):
        """Build the prompt for ``article``.

        Returns:
            In ``"train"`` mode a single string ``"{bos}{article}\\nSummary:"``;
            otherwise a ``(context, question)`` tuple whose concatenation is
            that same string.
        """
        # bos_token can be present but None; ``or ""`` keeps the literal text
        # "None" from being interpolated into the prompt.
        bos = getattr(self.tokenizer, "bos_token", None) or ""
        context = f"{bos}{article}\n"
        question = "Summary:"
        if self.mode == "train":
            return context + question
        return context, question

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.ds)

    def preprocess(self, example):
        """Add a ``"text"`` field containing the prompt followed by the summary."""
        prompt = self._build_prompt(example["document"])
        if self.mode != "train":
            # Eval prompts come back as (context, question); join them.
            prompt = "".join(prompt)
        example["text"] = prompt + example["summary"]
        return example

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return example ``idx`` formatted for the current mode.

        Returns:
            ``"eval"`` mode: dict with string values ``context``, ``question``
            and ``reference``.
            ``"train"`` mode: dict with 1-D tensors ``input_ids``,
            ``attention_mask`` and ``labels``, each at most ``max_length``
            long (exactly ``max_length`` when the sequence is shorter and
            gets padded).
        """
        example = self.ds[idx]
        prompt = self._build_prompt(example["document"])

        if self.mode == "eval":
            context, question = prompt
            return {
                "context": context,
                "question": question,
                "reference": example["summary"],
            }

        # Copy into fresh lists so appending EOS never mutates tokenizer output.
        prompt_ids = list(self.tokenizer(prompt, add_special_tokens=False).input_ids)
        ans_ids = list(
            self.tokenizer(example["summary"], add_special_tokens=False).input_ids
        )

        # Terminate the answer with EOS; ``not ans_ids`` covers a summary that
        # tokenises to nothing (previously an IndexError on ans_ids[-1]).
        eos_id = self.tokenizer.eos_token_id
        if not ans_ids or ans_ids[-1] != eos_id:
            ans_ids.append(eos_id)

        sequence = prompt_ids + ans_ids
        seq_len = len(sequence)
        pad_len = max(0, self.max_length - seq_len)
        limit = self.max_length

        # NOTE: over-long sequences are cut from the right, so a prompt of
        # max_length tokens or more leaves no summary token (and no EOS) in
        # the example, i.e. every label is -100.
        input_ids = torch.tensor(sequence + [self.tokenizer.pad_token_id] * pad_len)[:limit]
        attn_mask = torch.tensor([1] * seq_len + [0] * pad_len)[:limit]
        # Prompt and padding positions are labelled -100, which matches the
        # default ignore_index of torch.nn.CrossEntropyLoss -- presumably so
        # only summary tokens contribute to the loss.
        labels = torch.tensor([-100] * len(prompt_ids) + ans_ids + [-100] * pad_len)[:limit]

        return {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "labels": labels,
        }

