import torch
from torch.utils.data import Dataset, DataLoader
import random

class MultiLayerLogicDataset(Dataset):
    """Truth table of a fixed three-layer boolean circuit over 8 inputs.

    All 2**8 = 256 assignments of x0..x7 are enumerated and each is labelled
    with the circuit

        h1 = x0 and x1, x2 and x3, x4 and x5, x6 and x7,
             not x0 and not x1, not x2 and not x3, not x4 and not x5, not x6 and not x7
        h2 = h10 or h11, h12 or h13, h14 or h15, h16 or h17
        y  = h20 and h21, h22 and h23

    Row k holds the binary digits of k with x0 as the most significant bit,
    i.e. rows are ordered like a nested loop with x0 outermost.

    Args:
        dtype: dtype of the stored ``x``/``y`` tensors. ``None`` (default)
            means ``torch.get_default_dtype()`` at construction time.

    Attributes:
        x: tensor of shape (256, 8) with 0/1 entries.
        y: tensor of shape (256, 2) with 0/1 entries.
    """

    def __init__(self, dtype=None):
        super().__init__()
        if dtype is None:
            # Resolved here, not in the signature: a signature default is evaluated
            # once at import and would ignore later torch.set_default_dtype() calls.
            dtype = torch.get_default_dtype()
        n_inputs = 8
        codes = torch.arange(2 ** n_inputs).unsqueeze(1)   # (256, 1): row index k
        shifts = torch.arange(n_inputs - 1, -1, -1)        # 7, 6, ..., 0 -> x0 is the MSB
        x = ((codes >> shifts) & 1).bool()                 # (256, 8)

        # Layer 1: AND, then NOR, of each adjacent input pair (x0,x1), (x2,x3), ...
        left, right = x[:, 0::2], x[:, 1::2]
        h1 = torch.cat([left & right, ~left & ~right], dim=1)  # (256, 8)
        # Layer 2: OR of adjacent h1 pairs.
        h2 = h1[:, 0::2] | h1[:, 1::2]                         # (256, 4)
        # Output layer: AND of adjacent h2 pairs.
        y = h2[:, 0::2] & h2[:, 1::2]                          # (256, 2)

        self.x, self.y = x.to(dtype), y.to(dtype)

    def __len__(self):
        """Return the number of rows (256)."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y)`` row pair at ``idx``."""
        return self.x[idx, :], self.y[idx, :]
    
class ADDDataset(Dataset):
    """Single-digit addition: every digit pair (i, j) labelled with i + j.

    Each input row is the one-hot code of i followed by the one-hot code of j
    (width ``2 * base``). The target is the one-hot code of i + j (width
    ``2 * base - 1``, because the largest sum is ``2 * (base - 1)``). Rows are
    ordered with i outermost, so row k corresponds to ``(i, j) = divmod(k, base)``.

    Args:
        dtype: dtype of the stored ``x``/``y`` tensors. ``None`` (default)
            means ``torch.get_default_dtype()`` at construction time.
        base: number of digit values. The default of 10 gives decimal digits
            0-9: 100 rows, ``x`` of width 20, ``y`` of width 19.

    Raises:
        ValueError: if ``base < 1``.
    """

    def __init__(self, dtype=None, base=10):
        super().__init__()
        if dtype is None:
            # Resolved here, not in the signature: a signature default is evaluated
            # once at import and would ignore later torch.set_default_dtype() calls.
            dtype = torch.get_default_dtype()
        if base < 1:
            raise ValueError(f"base must be >= 1, got {base}")
        digits = torch.arange(base)
        pairs = torch.cartesian_prod(digits, digits)  # (base**2, 2), j varies fastest
        sums = pairs.sum(dim=1, keepdim=True)         # (base**2, 1)
        one_hot = torch.nn.functional.one_hot
        # Explicit num_classes keeps the encoding widths fixed instead of
        # inferring them from the largest value present in the data.
        self.x = one_hot(pairs, num_classes=base).flatten(1, 2).type(dtype)
        self.y = one_hot(sums, num_classes=2 * base - 1).flatten(1, 2).type(dtype)

    def __len__(self):
        """Return the number of rows (``base ** 2``)."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y)`` row pair at ``idx``."""
        return self.x[idx, :], self.y[idx, :]