from chip.utils.fourier import fft_2D
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch
from chip.datasets.superres_dataset import SuperresolutionDS

class FourierTomogramDataset(Dataset):
    """Dataset yielding randomly rotated targets together with their masked 2D spectra.

    Wraps a ``SuperresolutionDS``. For each item only the second element of the
    wrapped sample (``y``) is used: it is randomly rotated, passed through
    ``fft_2D``, and every coefficient outside the circle inscribed in the
    ``w x w`` grid is set to zero.
    """

    def __init__(self, files, data_path):
        """
        Args:
            files: Forwarded unchanged to ``SuperresolutionDS``.
            data_path: Forwarded unchanged to ``SuperresolutionDS``.

        Note:
            Reads ``self.data[0]`` once to size the frequency mask, so the
            wrapped dataset must be non-empty.
        """
        self.data = SuperresolutionDS(files, data_path)
        # Random in-plane rotation only: no translation and no scaling.
        self.train_transform = transforms.Compose(
            [
                transforms.RandomAffine(
                    degrees=(-180, 180),
                    translate=(0, 0),
                    scale=(1.0, 1.0),
                    interpolation=InterpolationMode.BILINEAR,
                ),
            ]
        )
        # The mask is applied to the spectrum of ``y`` (element 1 of each
        # sample), so its size must come from ``y``, not ``x``. In a
        # super-resolution dataset the two can differ in size.
        w = self.data[0][1].shape[-1]
        # NOTE(review): assumes ``y`` is square (H == W); ``view(-1)`` in
        # ``_zero_outside_circle`` requires exactly w * w spectrum entries.
        self.circle_mask = self._outside_circle_mask(w)

    @staticmethod
    def _outside_circle_mask(width):
        """Return a flat boolean mask selecting grid cells outside the inscribed circle.

        The mask has ``width * width`` entries in row-major ``(row, col)`` order,
        matching ``torch.cartesian_prod(arange(width), arange(width))``. An entry
        is True when ``(row - width/2)**2 + (col - width/2)**2 > (width/2)**2``.
        """
        coords = torch.cartesian_prod(torch.arange(width), torch.arange(width))
        # NOTE(review): the centre is width/2, not (width-1)/2. Presumably this
        # is chosen to match a centred (fftshift-style) spectrum from fft_2D;
        # confirm against fft_2D.
        radius = width / 2
        rows, cols = coords[:, 0], coords[:, 1]
        return (rows - radius) ** 2 + (cols - radius) ** 2 > radius ** 2

    def _zero_outside_circle(self, spectrum):
        """Zero, in place, the entries of ``spectrum`` selected by ``self.circle_mask``.

        ``spectrum`` must support ``view(-1)`` and hold exactly as many elements
        as the mask. Assigning zero (instead of multiplying by zero) makes
        masked entries exactly 0 even if they were NaN or inf. Returns
        ``spectrum`` for convenience.
        """
        spectrum.view(-1)[self.circle_mask] = 0
        return spectrum

    def test_transform(self):
        """Return a transform that only applies ``transforms.ToTensor()``.

        Not used by ``__getitem__`` in this class.
        """
        return transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, fourier_sample, file)`` for item ``idx``.

        ``sample`` is ``y`` after a random rotation. ``fourier_sample`` is
        ``fft_2D(sample)`` with the coefficients outside the inscribed circle
        set to zero. The wrapped ``x`` is not used.
        """
        _, y, file = self.data[idx]

        # Add a leading channel dim for the torchvision transform, then drop it.
        sample = self.train_transform(y.unsqueeze(0))[0]
        fourier_sample = self._zero_outside_circle(fft_2D(sample))

        return sample, fourier_sample, file