from matplotlib import pyplot as plt
import torch
from torch import fft


def plot_vol_slices(vol, reduction="slice"):
    """Plot three views of a volume, one per axis, in a 1x3 grayscale figure.

    Args:
        vol: Volume to display. Torch tensors are detached from autograd,
            moved to the CPU and squeezed before plotting. Any other input
            is used as-is; it is presumably a NumPy-like 3-D array exposing
            ``.shape``, ``.sum(axis)`` and multi-dimensional indexing.
        reduction: ``"slice"`` shows the slice at index ``size // 2`` along
            each of the three axes; ``"sum"`` shows the sum along each axis.

    Returns:
        The matplotlib figure containing the three panels.

    Raises:
        ValueError: If ``reduction`` is neither ``"slice"`` nor ``"sum"``.
    """
    # Validate first so an invalid argument does not leave behind a figure
    # that was created but never returned to the caller.
    if reduction not in ("slice", "sum"):
        raise ValueError(f"Reduction '{reduction}' is not supported! Choose one of ['slice', 'sum'].")

    if torch.is_tensor(vol):
        # detach() is required because tensors that track gradients cannot
        # be converted to NumPy, which imshow does internally.
        vol = vol.detach().cpu().squeeze()

    fig, ax = plt.subplots(1, 3)
    if reduction == "slice":
        # Plain Python ints index both torch tensors and NumPy arrays.
        mid = [size // 2 for size in vol.shape]
        views = (vol[mid[0], :, :], vol[:, mid[1], :], vol[:, :, mid[2]])
    else:
        views = (vol.sum(0), vol.sum(1), vol.sum(2))

    for panel, view in zip(ax, views):
        panel.imshow(view, cmap="gray")
    return fig


def fft_3d(vol, norm="ortho"):
    """Return the centered N-dimensional FFT of ``vol`` over its last three dims.

    Args:
        vol: Input tensor; the transform is taken over dimensions -1, -2, -3.
        norm: Normalization mode forwarded to ``torch.fft.fftn``.

    Returns:
        The complex spectrum with ``fftshift`` applied along the same three
        dimensions, so the zero-frequency component sits at the center.
    """
    spatial_axes = (-1, -2, -3)
    spectrum = fft.fftn(vol, dim=spatial_axes, norm=norm)
    centered = fft.fftshift(spectrum, dim=spatial_axes)
    return centered

def plot_vol_fft_slices(vol, reduction="sum"):
    """Plot the magnitude of the centered 3-D FFT of ``vol``.

    Args:
        vol: Input volume passed to ``fft_3d``.
        reduction: Reduction mode forwarded to ``plot_vol_slices``.

    Returns:
        The figure returned by ``plot_vol_slices``.
    """
    spectrum = fft_3d(vol)
    magnitude = spectrum.abs()
    return plot_vol_slices(magnitude, reduction=reduction)