import torch


def fft_2D(x: torch.Tensor, s=None, ortho: bool = False, dim=(-1, -2)) -> torch.Tensor:
    """Return the centred 2D discrete Fourier transform of ``x``.

    Applies ``torch.fft.fftn`` over the axes in ``dim``. It then moves the
    zero-frequency component to the centre of those same axes with
    ``torch.fft.fftshift``.

    Args:
        x: Input tensor. Only the axes listed in ``dim`` are transformed and
            shifted. Any other axes (e.g. batch/channel) are left untouched.
        s: Optional signal size per transformed axis, forwarded to
            ``torch.fft.fftn``. ``s[i]`` pairs with ``dim[i]``, so with the
            default ``dim`` the first entry sizes the LAST axis and the second
            entry sizes the second-to-last axis.
        ortho: If True, use orthonormal scaling (``norm='ortho'``). Otherwise
            use torch's default ``'backward'`` normalisation.
        dim: Axes to transform and shift. Defaults to the last two axes.

    Returns:
        Complex tensor holding the spectrum, with the zero frequency at the
        centre of the ``dim`` axes.
    """
    norm = 'ortho' if ortho else 'backward'
    fft_image = torch.fft.fftn(x, s=s, norm=norm, dim=dim)
    # Shift only the transformed axes. Calling fftshift without ``dim`` would
    # also roll batch/channel axes that were never Fourier-transformed.
    return torch.fft.fftshift(fft_image, dim=dim)


def ifft_2D(x: torch.Tensor, ortho: bool = False, dim=(-1, -2)) -> torch.Tensor:
    """Invert a centred 2D spectrum back to the spatial domain.

    First undoes the centring with ``torch.fft.ifftshift`` over ``dim``. It
    then applies ``torch.fft.ifftn`` over the same axes. This is the inverse
    of ``fft_2D`` when called with matching ``ortho`` and ``dim``.

    Args:
        x: Centred spectrum. Only the axes in ``dim`` are un-shifted and
            inverse-transformed. Other axes (e.g. batch/channel) are untouched.
        ortho: If True, use orthonormal scaling (``norm='ortho'``). Otherwise
            use torch's default ``'backward'`` normalisation (scales by 1/n).
        dim: Axes to un-shift and inverse-transform. Defaults to the last two
            axes.

    Returns:
        Complex tensor in the spatial domain.
    """
    # Undo the centring on the transformed axes only. An unrestricted
    # ifftshift would also roll batch/channel axes.
    f_space = torch.fft.ifftshift(x, dim=dim)
    norm = 'ortho' if ortho else 'backward'
    return torch.fft.ifftn(f_space, norm=norm, dim=dim)


# Batched, orthonormal variants of the transforms above. torch.vmap maps the
# per-sample function over axis 0 of its input (in_dims=0, the vmap default).
b_fft_2D = torch.vmap(
    lambda sample: fft_2D(sample, ortho=True),
    in_dims=0,
)
b_ifft_2D = torch.vmap(
    lambda sample: ifft_2D(sample, ortho=True),
    in_dims=0,
)