import torch
from typing import List, Tuple
from torch.nn.utils.rnn import pad_sequence
from abc import ABC, abstractmethod


def tokens_to_tensor(token_ids, sp):
    return torch.cat((torch.tensor([sp.bos_id()]),
                      torch.tensor(token_ids),
                      torch.tensor([sp.eos_id()])))


class DecoderDataset(torch.utils.data.Dataset, ABC):
    def __init__(self, data: List[str], tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    @abstractmethod
    def collate_fn(self, batch: List[torch.Tensor]) -> torch.Tensor:
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> torch.Tensor:
        pass


class EncoderDecoderDataset(torch.utils.data.Dataset, ABC):
    def __init__(self, data: List[str], input_tokenizer, output_tokenizer, split="="):
        self.tok_in = input_tokenizer
        self.tok_out = output_tokenizer
        self.data = data
        # where to split the input and output
        # should be added back to the input after splitting
        self.split = split

    def __len__(self):
        return len(self.data)

    @abstractmethod
    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        pass


class DecoderDatasetImpl(DecoderDataset):
    def collate_fn(self, batch):
        res_batch = []
        for ex in batch:
            res_batch.append(ex)

        res_batch = pad_sequence(
            res_batch, padding_value=self.tokenizer.pad_id())
        return res_batch

    def __getitem__(self, idx):
        ex = self.data[idx]
        ids = self.tokenizer.encode_as_ids(ex)
        return tokens_to_tensor(ids, self.tokenizer)


class EncoderDecoderDatasetImpl(EncoderDecoderDataset):
    def collate_fn(self, batch):
        src_batch, tgt_batch = [], []
        for src, tgt in batch:
            src_batch.append(src)
            tgt_batch.append(tgt)

        src_batch = pad_sequence(src_batch, padding_value=self.tok_in.pad_id())
        tgt_batch = pad_sequence(
            tgt_batch, padding_value=self.tok_out.pad_id())
        return src_batch, tgt_batch

    def __getitem__(self, idx):
        lhs, rhs = self.data[idx].split(self.split)
        lhs += self.split

        lhs_ints = self.tok_in.encode_as_ids(lhs)
        rhs_ints = self.tok_out.encode_as_ids(rhs)

        return tokens_to_tensor(lhs_ints, self.tok_in), tokens_to_tensor(rhs_ints, self.tok_out)