import torch
from torch.utils.data import Dataset
from basic_shapes import basic_structured_shapes, test_structured_shapes

def test_data(stype='train'):
    shapes = test_structured_shapes()
    shapes = torch.cat([torch.tensor(shape.copy(), dtype=torch.int8).unsqueeze(0) for shape in shapes], dim=0)
    group = []
    idx = torch.arange(9)
    for i in range(3):
        if stype == 'train':
            x = torch.zeros(8, 8)
            x[:4, :4] = shapes[idx[3*i]]
            x[4:, 4:] = shapes[idx[3*i+1]]
            group.append(x.clone().to(torch.float32))
            x = torch.zeros(8, 8)
            x[:4, :4] = shapes[idx[3*i+1]]
            x[4:, 4:] = shapes[idx[3*i+2]]
            group.append(x.clone().to(torch.float32))
            x = torch.zeros(8, 8)
            x[:4, :4] = shapes[idx[3*i+2]]
            x[4:, 4:] = shapes[idx[3*i]]
            group.append(x.clone().to(torch.float32))
        elif stype == 'test':
            x = torch.zeros(8, 8)
            x[:4, 4:] = shapes[idx[3*i]]
            x[4:, :4] = shapes[idx[3*i+1]]
            group.append(x.clone().to(torch.float32))
            x = torch.zeros(8, 8)
            x[:4, 4:] = shapes[idx[3*i+1]]
            x[4:, :4] = shapes[idx[3*i+2]]
            group.append(x.clone().to(torch.float32))
            x = torch.zeros(8, 8)
            x[:4, 4:] = shapes[idx[3*i+2]]
            x[4:, :4] = shapes[idx[3*i]]
            group.append(x.clone().to(torch.float32))
        else:
            raise NotImplementedError
    group = torch.stack(group)
    return group.unsqueeze(0)

# Dataset of randomly sampled 9-panel groups (3 rows x 3 panels), built by default from basic_structured_shapes()
class ThreeRowShapeDataset(Dataset):
    """Randomly generated groups of 9 two-shape panels arranged in 3 rows.

    For each group, 9 distinct shapes are drawn (the first 9 entries of a
    fresh ``torch.randperm``) and split into 3 rows of three shapes
    (a, b, c). Each row yields the panels (a, b), (b, c), (c, a). A panel
    is an 8x8 float32 canvas with the pair placed on one diagonal, selected
    by ``stype``:

    * ``'train'``: first shape at top-left, second at bottom-right.
    * ``'test'``: first shape at top-right, second at bottom-left.

    Each stored group is a ``(9, 8, 8)`` float32 tensor.
    """

    _STYPES = ('train', 'test')

    def __init__(self, num_groups=100, stype="train", shapes=None):
        """
        num_groups: number of groups (samples) to generate
        stype: 'train' or 'test'; selects the panel layout (see class docstring)
        shapes: optional sequence of shape arrays (each must support
            ``.copy()``, convert with ``torch.tensor`` and fit a 4x4
            quadrant); defaults to ``basic_structured_shapes()``

        Raises NotImplementedError for an unknown ``stype`` (checked even when
        ``num_groups`` is 0) and ValueError if fewer than 9 shapes are given.
        """
        super().__init__()
        if stype not in self._STYPES:
            raise NotImplementedError(f"unsupported stype: {stype!r}")
        if shapes is None:
            shapes = basic_structured_shapes()
        tensors = [torch.tensor(shape.copy(), dtype=torch.int8) for shape in shapes]
        if len(tensors) < 9:
            raise ValueError(f"need at least 9 shapes, got {len(tensors)}")
        bank = torch.stack(tensors)

        self.groups = []
        for _ in range(num_groups):
            # Fresh permutation per group; taking its first 9 entries keeps
            # the shapes within one group distinct.
            idx = torch.randperm(len(bank))[:9]
            self.groups.append(self._make_group(bank[idx], stype))

    @staticmethod
    def _make_group(chosen, stype):
        """Lay out 9 shapes (indexed along dim 0) as a (9, 8, 8) float32 panel stack."""
        panels = []
        for row in range(3):
            a, b, c = chosen[3 * row], chosen[3 * row + 1], chosen[3 * row + 2]
            # Cyclic pairing within the row: (a, b), (b, c), (c, a).
            for first, second in ((a, b), (b, c), (c, a)):
                panel = torch.zeros(8, 8, dtype=torch.float32)
                if stype == 'train':
                    panel[:4, :4] = first
                    panel[4:, 4:] = second
                else:
                    panel[:4, 4:] = first
                    panel[4:, :4] = second
                panels.append(panel)
        return torch.stack(panels)

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        """
        Return a tuple ``(x, y)`` for group ``idx``.

        ``x`` is the group without its last panel (first 8 panels) and ``y``
        is the full 9-panel group. Both are slices of the stored tensor, not
        copies.
        """
        data = self.groups[idx]
        return data[:-1], data
