import torch
import numpy as np


def mask_task(x, mask_labels):
    """Zero out every image whose label appears in ``mask_labels``.

    Args:
        x: ``(imgs, labels)`` pair. Each mask label is compared element-wise
            against ``labels``, and the resulting boolean mask indexes
            ``imgs`` — presumably one label per image along the first axis
            (TODO confirm against callers).
        mask_labels: iterable of label values whose images are blanked.

    Returns:
        The same ``(imgs, labels)`` pair. ``imgs`` is modified in place.
    """
    imgs, labels = x
    for mask_label in mask_labels:
        # Assign instead of multiplying by 0: NaN * 0 and inf * 0 are NaN,
        # so multiplication would leave non-finite pixels unmasked.
        imgs[labels == mask_label] = 0
    return imgs, labels
    

def flatten(x: tuple) -> tuple:
    """Flatten every dimension of the data after the first one.

    Args:
        x: ``(d, lbl)`` pair where ``d`` is a tensor; ``lbl`` is passed
            through untouched.

    Returns:
        ``(d_flat, lbl)`` where ``d_flat`` is ``torch.flatten(d, start_dim=1)``,
        i.e. shape ``(d.shape[0], prod(d.shape[1:]))``.
    """
    data, label = x
    return torch.flatten(data, start_dim=1), label


class NumpyImageFFT:
    """Keep the centred low-frequency block of each image's 2-D FFT magnitude.

    Instances are callables applied to ``(images, lbl)`` pairs of NumPy
    arrays. The last two axes of ``images`` are treated as image rows and
    columns. The FFT is centred with ``fftshift`` and a ``crop_shape`` window
    around the centre is kept. Its magnitude is flattened to one row per
    entry of the first axis.
    """

    def __init__(self, input_shape=(28, 28), crop_shape=(8, 8),):
        """Precompute the centred crop window.

        Args:
            input_shape: expected ``(rows, cols)`` of every image.
            crop_shape: ``(rows, cols)`` of the centred frequency window kept.

        Raises:
            ValueError: if ``crop_shape`` exceeds ``input_shape`` in either
                dimension (the slice bounds would otherwise be meaningless).
        """
        # Normalize to tuples so the shape check in __call__ also works when
        # lists are passed: a tuple never compares equal to a list.
        self.input_shape = tuple(input_shape)
        self.crop_shape = tuple(crop_shape)
        if any(c > i for c, i in zip(self.crop_shape, self.input_shape)):
            raise ValueError(
                f"crop_shape {self.crop_shape} exceeds input_shape "
                f"{self.input_shape}"
            )

        # After fftshift the zero-frequency term sits at index n // 2, so the
        # window is centred on it.
        self.crop_start_row = self.input_shape[0] // 2 - self.crop_shape[0] // 2
        self.crop_end_row = self.crop_start_row + self.crop_shape[0]
        self.crop_start_col = self.input_shape[1] // 2 - self.crop_shape[1] // 2
        self.crop_end_col = self.crop_start_col + self.crop_shape[1]

    def __call__(self, x):
        """Transform an ``(images, lbl)`` pair.

        Args:
            x: ``(images, lbl)``. ``images`` is an array whose last two axes
                equal ``input_shape`` and whose first axis indexes samples.
                ``lbl`` is returned unchanged.

        Returns:
            ``(features, lbl)``. ``features`` has shape
            ``(images.shape[0], prod(images.shape[1:-2]) * crop_rows * crop_cols)``.
        """
        images, lbl = x
        trailing = tuple(images.shape[-2:])
        assert trailing == self.input_shape, (
            f"expected trailing image shape {self.input_shape}, got {trailing}"
        )

        spectrum = np.fft.fft2(images)  # transforms the last two axes
        # Move the zero-frequency term to the centre so the crop keeps the
        # lowest frequencies.
        spectrum = np.fft.fftshift(spectrum, axes=(-2, -1))
        # Ellipsis keeps every leading axis, so both (N, H, W) and
        # (N, C, H, W) inputs are accepted.
        crop = spectrum[...,
                        self.crop_start_row: self.crop_end_row,
                        self.crop_start_col: self.crop_end_col]
        magnitude = np.abs(crop)
        # Emit one row per sample. Channels are folded into the feature axis,
        # not the batch axis, so rows stay aligned with ``lbl``. An explicit
        # width (rather than -1) keeps empty batches reshapeable.
        n_features = int(np.prod(magnitude.shape[1:]))
        features = magnitude.reshape(magnitude.shape[0], n_features)

        return features, lbl
    

class TensorImageFFT(NumpyImageFFT):
    """Torch counterpart of ``NumpyImageFFT`` that operates on a bare tensor.

    The crop window comes from the inherited ``__init__``. Unlike the NumPy
    version, ``__call__`` takes and returns only the image tensor, with no
    label.
    """

    def __call__(self, x):
        """Return the flattened, centred FFT-magnitude crop of ``x``.

        Args:
            x: tensor whose last two dimensions equal ``input_shape`` and
                whose first dimension indexes samples.

        Returns:
            Real tensor of shape
            ``(x.shape[0], prod(x.shape[1:-2]) * crop_rows * crop_cols)``.
        """
        images = x
        # Compare as tuples: a torch.Size never equals a list, which would
        # make the check fail whenever input_shape was passed as a list.
        trailing = tuple(images.shape[-2:])
        expected = tuple(self.input_shape)
        assert trailing == expected, (
            f"expected trailing image shape {expected}, got {trailing}"
        )

        spectrum = torch.fft.fft2(images)  # transforms the last two dims
        # Centre the zero-frequency term so the crop keeps low frequencies.
        spectrum = torch.fft.fftshift(spectrum, dim=(-2, -1))
        # Ellipsis keeps every leading dim, so (N, H, W) and (N, C, H, W)
        # inputs are both accepted.
        crop = spectrum[...,
                        self.crop_start_row: self.crop_end_row,
                        self.crop_start_col: self.crop_end_col]
        # One row per sample, with channels folded into the feature axis.
        return torch.flatten(torch.abs(crop), start_dim=1)