import torch
from torch.utils.data import DataLoader, Dataset

from typing import Callable


class ASRDataset(Dataset):
    """Map-style dataset pairing a ``pron`` text prompt with a ``latex`` target.

    Each item is the token ids of ``row['pron']`` followed by the token ids of
    ``str(row['latex']) + tokenizer.eos_token``, plus a boolean mask of the same
    length that is ``False`` over the prompt tokens and ``True`` over the target
    tokens.

    Args:
        df: table accessed positionally via ``.iloc`` with ``'pron'`` and
            ``'latex'`` columns (presumably a pandas DataFrame — confirm
            against the caller).
        tokenizer: object exposing ``eos_token`` and an ``encode`` method that
            accepts ``return_tensors='pt'`` and ``add_special_tokens`` keywords
            (presumably a Hugging Face tokenizer — confirm against the caller).
    """

    def __init__(self, df, tokenizer):
        self.df, self.tokenizer = df, tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(token_ids, target_mask)`` for the row at position ``idx``."""
        row = self.df.iloc[idx]
        prompt = row['pron']
        # Only the target gets the EOS marker, so generation learns where to stop.
        target = str(row['latex']) + self.tokenizer.eos_token

        def encode(text):
            # encode(..., return_tensors='pt') yields a batch of one; keep row 0.
            return self.tokenizer.encode(text, return_tensors='pt', add_special_tokens=False)[0]

        prompt_ids = encode(prompt)
        target_ids = encode(target)

        token_ids = torch.cat((prompt_ids, target_ids))
        target_mask = torch.cat((
            torch.zeros(len(prompt_ids), dtype=torch.bool),
            torch.ones(len(target_ids), dtype=torch.bool),
        ))
        assert len(target_mask) == len(token_ids)
        return token_ids, target_mask


def get_dataset(df, tokenizer):
    """Build an ``ASRDataset`` over ``df`` using ``tokenizer`` for encoding."""
    dataset = ASRDataset(df, tokenizer)
    return dataset


def get_collate_function(tokenizer):
    """Return a ``collate_fn`` that right-pads ``(token_ids, mask)`` pairs.

    The returned function takes a list of ``(token_ids, mask)`` pairs of 1-D
    tensors of equal length per pair (as produced by ``ASRDataset``). It
    returns two tensors of shape ``(batch, max_len)``:

    * token ids, right-padded with the tokenizer's pad id;
    * the masks, right-padded with ``False``.

    The pad id is read from ``tokenizer`` on every call, so setting
    ``tokenizer.pad_token`` after building the collate function still takes
    effect. If ``tokenizer.pad_token_id`` is ``None``, it falls back to
    ``tokenizer.eos_token_id``, then to ``0``. Padded positions are ``False``
    in the returned mask either way.
    """
    def _pad_id():
        # Tokenizers without a pad token would otherwise make padding fail.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(tokenizer, 'eos_token_id', None)
        return 0 if pad_id is None else pad_id

    def collate_fnc(data):
        general_inputs, masks = zip(*data)
        padded_inputs = torch.nn.utils.rnn.pad_sequence(
            list(general_inputs), batch_first=True, padding_value=_pad_id())
        # The default padding_value (0) fills bool tensors with False.
        padded_masks = torch.nn.utils.rnn.pad_sequence(
            list(masks), batch_first=True)
        return padded_inputs, padded_masks

    return collate_fnc

def get_dataloader(dataset: Dataset,
                   batch_size: int,
                   collate_fn: Callable,
                   num_workers: int,
                   train: bool = False) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` configured for training or evaluation.

    With ``train=True`` the loader shuffles samples and drops the final
    incomplete batch. Otherwise it keeps dataset order and yields every sample.
    """
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        shuffle=train,
        drop_last=train,
    )
    return DataLoader(dataset=dataset, **loader_options)
