import torch

class Data:
    """Synthetic modular-addition dataset split into train/valid/test parts.

    Every ordered pair ``(a, b)`` with ``0 <= a, b < src_symbols`` is one
    example. Its target token is ``(a + b) % tgt_symbols + src_symbols``, so
    target ids occupy ``[src_symbols, src_symbols + tgt_symbols)`` and never
    collide with the source ids. Each split is stored as an integer tensor
    whose rows are ``[a, b, target]``.

    Attributes:
        src_symbols: Number of distinct source symbols.
        tgt_symbols: Number of distinct target symbols.
        mask_symbol: ``src_symbols + tgt_symbols``, the first id past both
            the source and the target id ranges.
        train: Training examples, moved to ``device``.
        valid: Validation examples, moved to ``device``.
        test: Test examples, moved to ``device``.
    """

    def __init__(
            self,
            train_dataset_frac : float = .7       ,
            valid_dataset_frac : float = .1       ,
            test_dataset_frac  : float = .2       ,
            src_symbols        : int   = 100      ,
            tgt_symbols        : int   = 10       ,
            device             : str   = "cuda:0"
        ):
        """Build all pairs, shuffle them and cut three disjoint splits.

        Args:
            train_dataset_frac: Fraction of all pairs used for training.
            valid_dataset_frac: Fraction of all pairs used for validation.
            test_dataset_frac: Fraction of all pairs used for testing.
            src_symbols: Number of source symbols; ``src_symbols ** 2`` pairs
                are generated.
            tgt_symbols: Modulus applied to ``a + b`` to obtain the target.
            device: Device the three split tensors are moved to.

        Each split size is ``int(total * fraction)`` (truncated). The splits
        are taken back to back from one random permutation, so they never
        share an example. If the fractions sum to more than 1, the slices
        that run past the end of the permutation are simply shorter.
        """
        super().__init__()
        self.src_symbols = src_symbols
        self.tgt_symbols = tgt_symbols
        self.mask_symbol = self.src_symbols + self.tgt_symbols

        src = torch.tensor([[i, j] for i in range(src_symbols) for j in range(src_symbols)])

        perm_idxs = torch.randperm(src.size(0))
        total = len(perm_idxs)
        n_train = int(total * train_dataset_frac)
        n_valid = int(total * valid_dataset_frac)
        n_test = int(total * test_dataset_frac)

        # Consecutive, non-overlapping windows of the permutation. The test
        # window must begin after BOTH the train and the validation windows;
        # starting it at the validation size would reuse training examples.
        valid_start = n_train
        test_start = n_train + n_valid
        train_src = src[perm_idxs[:n_train]]
        valid_src = src[perm_idxs[valid_start:test_start]]
        test_src = src[perm_idxs[test_start:test_start + n_test]]

        def with_target(pairs):
            # Append the target column: (a + b) mod tgt_symbols, shifted past
            # the source ids so the two id ranges stay disjoint.
            tgt = (pairs.sum(-1, keepdim=True) % tgt_symbols) + src_symbols
            return torch.cat([pairs, tgt], dim=-1).to(device)

        self.train = with_target(train_src)
        self.valid = with_target(valid_src)
        self.test = with_target(test_src)


