import numpy as np
import pyvista as pv
import os

# Load the array. The path is left blank here and must be filled in.
# Per the original note the layout is (20, 5, 64, 64, 64); the render loop
# indexes it as arr[t, channel] to obtain one 3-D volume.
arr = np.load('')  # shape (20, 5, 64, 64, 64)

# Directory for the rendered PNGs; '' means the current working directory
# (os.path.join('', name) == name).
out_dir = ''
# os.makedirs('') raises FileNotFoundError, so only create a directory when
# a non-empty path was actually given.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
window_size = [2048, 2048]  # screenshot size in pixels, passed to pv.Plotter

def fft_upsample_3d(data, scale=2):
    """Upsample an array by zero-padding its centred Fourier spectrum.

    This is band-limited (trigonometric) interpolation: the array is
    transformed with ``np.fft.fftn``, the centred spectrum is zero-padded to
    ``round(n * scale)`` along every axis, and the result is transformed back.
    Despite the name, any number of dimensions works because ``fftn``
    transforms every axis.

    Args:
        data: Input array, assumed real-valued (the imaginary part of the
            result is discarded, so complex input loses its imaginary part).
        scale: Upsampling factor applied to every axis; must not shrink any
            axis. Non-integer factors are allowed; each output length is
            ``round(n * scale)``.

    Returns:
        Real ndarray of shape ``round(data.shape * scale)``. Amplitudes are
        preserved, so for an integer ``scale`` the samples
        ``out[::scale, ::scale, ...]`` reproduce ``data``.

    Raises:
        ValueError: if ``scale`` would make any axis shorter.
    """
    data = np.asarray(data)
    shape = np.array(data.shape)
    new_shape = np.round(shape * scale).astype(int)
    if np.any(new_shape < shape):
        raise ValueError(f"scale must not shrink any axis, got {scale!r}")

    # Move the zero frequency to the centre so padding surrounds it.
    fshift = np.fft.fftshift(np.fft.fftn(data))

    # fftshift places DC at index n // 2 and ifftshift expects it at m // 2,
    # so the left pad must be m // 2 - n // 2. An even split of (m - n) is
    # off by one for odd n with an even target length.
    pad_width = []
    for n, m in zip(shape, new_shape):
        left = m // 2 - n // 2
        pad_width.append((left, (m - n) - left))
    fpad = np.pad(fshift, pad_width, mode="constant")

    # ifftn normalises by the *padded* size; rescale by new/old element count
    # so the output has the same amplitude as the input.
    gain = new_shape.prod() / shape.prod()
    upsampled = np.fft.ifftn(np.fft.ifftshift(fpad)) * gain

    # Unpaired Nyquist bins of even-length axes leave a small imaginary part.
    # Taking the real part is equivalent to splitting those bins evenly
    # between the +/- frequencies.
    return np.real(upsampled)

# Render selected time steps of one channel as FFT-upsampled volume images.
for t in range(14, 16):
    # Channel -2 (second-to-last) of time step t; change the index to pick
    # a different channel.
    volume = arr[t, -2]
    # 2x along every axis, e.g. (64, 64, 64) -> (128, 128, 128).
    upsampled = fft_upsample_3d(volume, scale=2)

    # An ImageData with N+1 points along an axis has N cells, so each voxel
    # becomes exactly one cell.
    grid = pv.ImageData(
        dimensions=np.array(upsampled.shape) + 1,
        spacing=(1, 1, 1),
        origin=(0, 0, 0),
    )
    # VTK stores cells with the first (x) index varying fastest, which is
    # Fortran order when array axis 0 is treated as x.
    grid.cell_data["values"] = upsampled.flatten(order="F")

    p = pv.Plotter(off_screen=True, window_size=window_size)
    try:
        p.add_volume(
            grid,
            scalars="values",
            cmap="inferno",
            opacity="sigmoid_5",
            shade=True,
            show_scalar_bar=False,
        )
        p.set_background("white")
        p.camera_position = 'iso'
        filename = os.path.join(out_dir, f"ground_truth_476_t{t:02d}_fft2x.png")
        p.screenshot(filename)
    finally:
        # Release the render window even if rendering or saving fails.
        p.close()

print("Done! High-resolution FFT-upsampled images saved in:", out_dir)