from typing import Union, Tuple, List, Dict

import math
import json

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from transformers.models.t5.tokenization_t5_fast import T5TokenizerFast


class BaseDataset(Dataset):
    """Question/answer dataset that tokenizes rows and pads them into batches.

    Every row of ``self.data`` is indexed with the keys ``"question"`` and
    ``"answer"``. ``__getitem__`` tokenizes one row; ``collate_fn`` pads a
    group of tokenized rows to a common length and stacks them into tensors
    created on ``self.device``.
    """

    def __init__(
        self,
        tokenizer,
        device,
        path: str,
        data=None,
        pad_side="left",
        num: Union[int, None] = None,
    ):
        """Store the rows, either passed in directly or loaded from JSON.

        Args:
            tokenizer: Callable tokenizer; it is invoked with
                ``return_tensors="pt"`` and must expose ``bos_token_id`` and
                ``unk_token_id``.
            device: Device on which padded batch tensors are created.
            path: JSON file to load the rows from. Only read when ``data``
                is ``None``.
            data: Already-loaded rows. When given, ``path`` is ignored.
            pad_side: ``"left"`` pads in front of each sequence; any other
                value pads behind it.
            num: If not ``None``, keep only the first ``num`` rows.
        """
        super().__init__()
        if data is None:
            with open(path, encoding="utf-8") as file:
                data = json.load(file)
        if num is not None:
            # Slicing builds a new sequence, so a caller-supplied `data`
            # object is never truncated in place.
            data = data[:num]
        # Assigned unconditionally: caller-supplied rows used to be dropped,
        # leaving the instance without a `data` attribute.
        self.data = data
        self.pad_side = pad_side
        self.tok = tokenizer
        self.device = device

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    @staticmethod
    def _as_sequence(ten):
        """Return ``ten`` as a 1-D sequence without collapsing length-1 rows.

        Only a leading batch axis is squeezed. An unconditional
        ``squeeze(0)`` would turn a single-token sequence into a 0-dim
        scalar, which has no length and cannot be padded.
        """
        if ten.dim() == 0:
            return ten.reshape(1)
        if ten.dim() > 1:
            return ten.squeeze(0)
        return ten

    def collate_fn(
        self,
        tuples: Tuple[Dict[str, Dict[str, torch.LongTensor]]]
    ) -> Dict[str, List[Dict[str, torch.LongTensor]]]:
        """Pad and stack the items produced by ``__getitem__``.

        Args:
            tuples: Non-empty sequence of items, each holding a ``"tok_data"``
                dict of tensors and a ``"raw_data"`` entry. The keys of the
                first item's ``"tok_data"`` decide which fields are batched.

        Returns:
            ``{"tok_data": {field: padded tensor}, "raw_data": [raw, ...]}``.
            ``"labels"`` is padded with ``-100``; every other field with ``0``.
        """
        batch = {}
        for key in tuples[0]["tok_data"]:
            pad_value = -100 if key == "labels" else 0
            rows = [self._as_sequence(tup["tok_data"][key]) for tup in tuples]
            batch[key] = self.pad_data(
                tensors=rows,
                padding_value=pad_value,
                pad_side=self.pad_side,
            )

        return {
            "tok_data": batch,
            "raw_data": [tup["raw_data"] for tup in tuples],
        }

    def pad_data(self,
                 tensors: List,
                 padding_value,
                 pad_side: str):
        """Pad sequences to the longest one and stack them into one tensor.

        Args:
            tensors: Non-empty list of tensors, each a 1-D sequence or a
                sequence with a leading axis of size one.
            padding_value: Value used to fill the missing positions.
            pad_side: ``"left"`` pads in front; any other value pads behind.

        Returns:
            Tensor of shape ``(len(tensors), max_len)`` on ``self.device``
            with the dtype of the first input tensor.

        Raises:
            ValueError: If ``tensors`` is empty.
        """
        rows = [self._as_sequence(ten) for ten in tensors]
        max_len = max(row.size()[-1] for row in rows)
        dtype = tensors[0].dtype

        padded = []
        for row in rows:
            values = row.tolist()
            filler = [padding_value] * (max_len - len(values))
            padded.append(filler + values if pad_side == "left" else values + filler)
        return torch.tensor(padded, device=self.device, dtype=dtype)

    def __getitem__(self, idx) -> Dict[str, Dict[str, torch.LongTensor]]:
        """Return the tokenized and the raw form of row ``idx``."""
        row = self.data[idx]

        prompt = row["question"]
        answer = row["answer"]

        return {
            "tok_data": self.tok_tuples(prompt, answer),
            "raw_data": {"question": prompt, "answer": answer},
        }

    def tok_tuples(
        self,
        prompt: str,
        answer: str
    ) -> Dict[str, torch.LongTensor]:
        """Tokenize one prompt/answer pair into model inputs and labels.

        For a ``T5TokenizerFast`` the result holds separate encoder fields
        (``input_ids``, ``attention_mask``) and decoder fields
        (``decoder_input_ids``, ``decoder_attention_mask``, ``labels``).
        For any other tokenizer the prompt and answer are concatenated into a
        single sequence and ``labels`` masks the prompt part with ``-100``.

        NOTE: no leading space is prepended to ``answer``; callers that need
        one for their tokenizer must include it themselves.

        Raises:
            IndexError: If ``answer`` tokenizes to zero tokens, because the
                first answer token is inspected unconditionally.
        """
        tok_prompt = self.tok(
            prompt,
            return_tensors="pt",
        )
        tok_answer = self.tok(
            answer,
            return_tensors="pt",
            add_special_tokens=False
        )

        # Drop a leading BOS/UNK token from the answer so that it is not
        # inserted between the prompt and the answer.
        first_answer_id = tok_answer["input_ids"][0][0]
        if first_answer_id == self.tok.bos_token_id or first_answer_id == self.tok.unk_token_id:
            for key in tok_answer:
                tok_answer[key] = tok_answer[key][:, 1:]

        if isinstance(self.tok, T5TokenizerFast):
            # Decoder inputs are the labels shifted right by one position.
            # NOTE(review): 0 is presumably T5's decoder start token id;
            # confirm against the model config.
            tok_tuples = {
                "input_ids": tok_prompt["input_ids"],
                "attention_mask": tok_prompt["attention_mask"],
                "decoder_input_ids": torch.cat((
                    torch.LongTensor([[0]]),
                    tok_answer["input_ids"][:, :-1]
                ), -1),
                "decoder_attention_mask": tok_answer["attention_mask"],
                "labels": tok_answer["input_ids"]
            }
        else:
            # Inputs are the prompt followed by the answer without its last
            # token; this is applied to every field the tokenizer returned.
            tok_tuples = {
                key: torch.cat((value, tok_answer[key][:, :-1]), -1)
                for key, value in tok_prompt.items()
            }

            # Labels are already shifted by one: labels[i] is the token that
            # follows inputs[i]. Prompt positions are masked with -100,
            # except the last one, whose target is the first answer token.
            tok_tuples["labels"] = torch.cat((
                torch.full(tok_prompt["input_ids"].shape, -100)[:, 1:],
                tok_answer["input_ids"]
            ), -1)

        return tok_tuples
    