import torch
from chip.utils.fourier import fft_2D, ifft_2D
from chip.utils import create_circle_filter, create_gaussian_filter


def fourier_filtering(image, filter):
    """
    Apply a filter to an image (or batch of images) in Fourier space.

    The image is transformed with ``fft_2D``, multiplied element-wise by the
    filter, and transformed back with ``ifft_2D``. The filter only needs to be
    at least as large as the image in the two trailing (spatial) dimensions:
    it is center-cropped to the image's spatial size. This lets one square
    filter be reused for images with a different aspect ratio.

    Assumes ``fft_2D``/``ifft_2D`` act on the last two dimensions and that
    ``fft_2D`` returns a centered spectrum (DC component at index ``n // 2``),
    which is what the center crop relies on -- TODO confirm against
    ``chip.utils.fourier``.

    Parameters:
    image (torch.Tensor): Image(s) whose last two dimensions are
        (height, width). Any leading (batch/channel) dimensions are kept.
    filter (torch.Tensor): Filter to be applied in Fourier space. Its last two
        dimensions are (height, width) and must be >= the image's. Leading
        dimensions, if any, must broadcast against the image's leading
        dimensions. It is moved to the image spectrum's device if needed.

    Returns:
    torch.Tensor: Real part of the filtered image. Its shape is the broadcast
        shape of ``image`` and ``filter`` (normally just the image's shape).

    Raises:
    ValueError: If ``filter`` has fewer than two dimensions or is smaller
        than the image in either spatial dimension.
    """
    fft_image = fft_2D(image)

    if filter.dim() < 2:
        raise ValueError(
            f"filter must have at least 2 dimensions, got shape {tuple(filter.shape)}"
        )

    height, width = fft_image.shape[-2:]
    filter_height, filter_width = filter.shape[-2:]
    if filter_height < height or filter_width < width:
        raise ValueError(
            f"filter spatial size {(filter_height, filter_width)} is smaller than "
            f"image spatial size {(height, width)}"
        )

    # Center-crop only the spatial dims. Leading dims are left to broadcasting,
    # so a 2-D filter works for any batch size. The crop keeps index n//2 of
    # the filter aligned with index n//2 of the image spectrum.
    cropped = filter.narrow(-2, filter_height // 2 - height // 2, height)
    cropped = cropped.narrow(-1, filter_width // 2 - width // 2, width)
    # Move to the spectrum's device. narrow() returns a view, but the multiply
    # below is out-of-place, so the caller's filter is never modified.
    cropped = cropped.to(fft_image.device)

    return ifft_2D(fft_image * cropped).real
