import numpy as np
from scipy import ndimage


def get_spacing(affine: np.ndarray) -> np.ndarray:
    """
    Compute per-axis scale factors from an affine matrix.

    The scale of each of the first three axes is the Euclidean length of the
    matching column of ``affine`` (all rows of that column are included).

    Args:
        affine (np.ndarray): Affine matrix with at least three columns;
            presumably a 4x4 voxel-to-world matrix (e.g. NIfTI) — confirm
            against callers.

    Returns:
        np.ndarray: Array of three floats ``[x_scale, y_scale, z_scale]``.
    """
    column_lengths = [np.linalg.norm(affine[:, axis]) for axis in range(3)]
    return np.array(column_lengths)



def resample_scan(scan_data: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, order: int=1) -> np.ndarray:
    """
    Resample the scan data to the target spacing.

    Each spatial axis is scaled by ``current_spacing / target_spacing`` with
    ``scipy.ndimage.zoom``, so the output length along a spatial axis of
    length ``n`` is ``round(n * current_spacing / target_spacing)``. A 4D
    input is treated as channel-first: its leading axis gets a zoom factor of
    1.0, so the number of channels is unchanged. Inputs of any other rank
    (e.g. a 3D volume or a 2D slice) have every axis resampled.

    Args:
        scan_data (np.ndarray): The scan data, channel-first (C, W, H, D) or
            (W, H, D).
        current_spacing (np.ndarray): The current spacing of the scan, one
            value per spatial axis or a single scalar applied to all of them.
            Lists and tuples are accepted.
        target_spacing (np.ndarray): The desired spacing for resampling, in the
            same form and units as ``current_spacing``.
        order (int): The spline interpolation order. Default is 1 (linear
            interpolation); use 0 for nearest neighbour (e.g. label maps).

    Returns:
        np.ndarray: Resampled scan data.

    Raises:
        ValueError: If a spacing value is not finite and positive, or the
            spacings cannot be matched to the spatial axes of ``scan_data``.
    """
    scan_data = np.asarray(scan_data)
    current = np.asarray(current_spacing, dtype=float)
    target = np.asarray(target_spacing, dtype=float)

    for name, spacing in (("current_spacing", current), ("target_spacing", target)):
        # Zero/negative/NaN spacings would otherwise fail obscurely (or give a
        # degenerate output shape) inside ndimage.zoom.
        if not np.all(np.isfinite(spacing) & (spacing > 0)):
            raise ValueError(
                f"{name} must contain only finite, positive values, got {spacing!r}"
            )

    has_channels = scan_data.ndim == 4
    n_spatial = scan_data.ndim - 1 if has_channels else scan_data.ndim
    try:
        # Broadcasting lets a scalar spacing apply to every spatial axis and
        # rejects spacing vectors whose length does not match the data.
        spatial_zoom = np.broadcast_to(current / target, (n_spatial,))
    except ValueError as err:
        raise ValueError(
            f"spacing shapes {current.shape} and {target.shape} do not match the "
            f"{n_spatial} spatial axes of scan_data with shape {scan_data.shape}"
        ) from err

    # The channel axis keeps a zoom factor of 1.0 so the channel count is preserved.
    zoom_factors = np.concatenate(([1.0], spatial_zoom)) if has_channels else spatial_zoom
    return ndimage.zoom(scan_data, zoom_factors, order=order)

