import torch

class Data:
    """Synthetic modular-addition dataset with disjoint train/valid/test splits.

    The inputs are all ordered pairs ``(i, j)`` with ``i`` and ``j`` in
    ``range(src_symbols)``.  The target of a pair is
    ``(i + j) % tgt_symbols``.  The pairs are shuffled once with
    ``torch.randperm`` and cut into three consecutive, non-overlapping
    slices of the permutation.

    Attributes set by ``__init__``:
        src_symbols, tgt_symbols: the constructor arguments, stored as-is.
        mask_symbol: ``src_symbols + tgt_symbols``.
        train_src, valid_src, test_src: input pairs of each split, moved
            to ``device``; the last dimension holds the two symbols.
        train_tgt, valid_tgt, test_tgt: targets of each split, computed
            from the matching ``*_src`` tensor with ``keepdim=True`` so the
            last dimension has size 1.
    """

    def __init__(
            self,
            train_dataset_frac: float = .7,
            valid_dataset_frac: float = .1,
            test_dataset_frac: float = .2,
            src_symbols: int = 100,
            tgt_symbols: int = 10,
            device: str = "cuda:0"
        ):
        """Build all pairs, shuffle them and split them into three sets.

        Args:
            train_dataset_frac: Fraction of all pairs used for training.
            valid_dataset_frac: Fraction of all pairs used for validation.
            test_dataset_frac: Fraction of all pairs used for testing.
            src_symbols: Number of distinct input symbols; the dataset
                holds ``src_symbols ** 2`` pairs in total.
            tgt_symbols: Modulus applied to the sum of each pair.
            device: Device the split tensors are moved to.

        Each split size is ``int(total * fraction)`` (truncated).  If the
        fractions add up to more than 1, slicing silently shortens the
        later splits.
        """
        super().__init__()
        self.src_symbols = src_symbols
        self.tgt_symbols = tgt_symbols
        self.mask_symbol = self.src_symbols + self.tgt_symbols

        src = torch.tensor([[i, j] for i in range(src_symbols) for j in range(src_symbols)])

        perm_idxs = torch.randperm(src.size(0))
        n_total = len(perm_idxs)
        n_train = int(n_total * train_dataset_frac)
        n_valid = int(n_total * valid_dataset_frac)
        n_test = int(n_total * test_dataset_frac)

        # Consecutive slices of one permutation: train | valid | test.
        # The test slice must begin after BOTH earlier splits; starting it
        # anywhere before that would reuse training samples for testing.
        valid_start = n_train
        test_start = valid_start + n_valid

        self.train_src = src[perm_idxs[:n_train]].to(device)
        self.valid_src = src[perm_idxs[valid_start:valid_start + n_valid]].to(device)
        self.test_src = src[perm_idxs[test_start:test_start + n_test]].to(device)

        self.train_tgt = self.train_src.sum(-1, keepdim=True) % tgt_symbols
        self.valid_tgt = self.valid_src.sum(-1, keepdim=True) % tgt_symbols
        self.test_tgt = self.test_src.sum(-1, keepdim=True) % tgt_symbols

