import torch
import numpy as np
from scipy.interpolate import griddata
from kornia.metrics import SSIM3D

@torch.no_grad()
def voxelize_pointcloud_batch(pos, values, grid_size=(64, 64, 64),
                              method='nearest', bounds=None):
    """
    Resample per-point values of a batch of point clouds onto regular 3-D grids.

    Each sample is interpolated independently with SciPy's scattered-data
    interpolators. Grid axis ``d`` has ``grid_size[d]`` evenly spaced nodes
    (endpoints included) spanning ``bounds``, or the sample's own bounding
    box when ``bounds`` is None.

    Args:
        pos:    [B, N, 3] point coordinates.
        values: [B, N, C] per-point values to resample.
        grid_size: (Dx, Dy, Dz) number of grid nodes along x, y, z.
        method: 'nearest' or 'linear'. SciPy implements 'cubic' only for
            1-D/2-D data, so it (like any other name) raises ValueError.
            'linear' writes 0.0 at nodes outside the convex hull of the
            points; 'nearest' gives every node the value of its closest point.
        bounds: ((x_min, y_min, z_min), (x_max, y_max, z_max)) applied to
            every sample, or None to use each sample's own min/max.

    Returns:
        float32 CPU tensor of shape [B, C, Dx, Dy, Dz].

    Raises:
        ValueError: if the inputs are not [B, N, 3] / [B, N, C] with matching
            B and N, or if ``method`` is not supported for 3-D data.
    """
    # Local import keeps the module's top-level imports unchanged.
    from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator

    if pos.ndim != 3 or values.ndim != 3:
        raise ValueError(
            f"pos and values must be 3-D tensors, got shapes "
            f"{tuple(pos.shape)} and {tuple(values.shape)}")
    if pos.shape[2] != 3:
        raise ValueError(f"pos must have shape [B, N, 3], got {tuple(pos.shape)}")
    if values.shape[:2] != pos.shape[:2]:
        raise ValueError(
            f"values must share [B, N] with pos, got {tuple(values.shape)} "
            f"vs {tuple(pos.shape)}")
    if method not in ('nearest', 'linear'):
        raise ValueError(
            f"Unsupported interpolation method {method!r} for 3-D data; "
            "use 'nearest' or 'linear'.")

    C = values.shape[2]
    Dx, Dy, Dz = grid_size

    pos_np = pos.detach().cpu().numpy()
    val_np = values.detach().cpu().numpy()

    grids = []
    for p, v in zip(pos_np, val_np):  # p: [N, 3], v: [N, C]
        if bounds is None:
            lo, hi = p.min(axis=0), p.max(axis=0)
        else:
            lo, hi = bounds
        axes = [np.linspace(lo[d], hi[d], n) for d, n in enumerate((Dx, Dy, Dz))]
        # 'ij' indexing keeps the (x, y, z) axis order, i.e. [Dx, Dy, Dz, 3].
        xi = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)

        # One interpolator for all C channels: the triangulation / KD-tree is
        # built once per sample rather than once per channel.
        if method == 'linear':
            interp = LinearNDInterpolator(p, v, fill_value=0.0)
        else:
            interp = NearestNDInterpolator(p, v)
        out = interp(xi)  # [Dx, Dy, Dz, C]

        grid = np.ascontiguousarray(np.moveaxis(out, -1, 0), dtype=np.float32)
        grids.append(torch.from_numpy(grid))

    if not grids:  # B == 0: torch.stack would fail on an empty list
        return torch.zeros((0, C, Dx, Dy, Dz), dtype=torch.float32)
    return torch.stack(grids, dim=0)  # [B, C, Dx, Dy, Dz]

@torch.no_grad()
def ssim3d_pointcloud_monai(pos, pred, target, grid_size=(64,64,64),
                            method='nearest', max_val=None, bounds=None,
                            window_size=11):
    """
    Mean 3-D SSIM between two per-point fields after voxelizing them.

    Despite the name, SSIM is computed with kornia's ``SSIM3D`` (not MONAI).
    Both fields are voxelized in a single ``voxelize_pointcloud_batch`` call
    by concatenating them along the channel axis. This guarantees identical
    grids for pred and target and reuses the per-sample grid setup across
    all 2*C channels.

    Args:
        pos:    [B, N, 3] point coordinates shared by pred and target.
        pred:   [B, N, C] predicted per-point values.
        target: [B, N, C] reference per-point values.
        grid_size, method, bounds: forwarded to ``voxelize_pointcloud_batch``.
        max_val: dynamic range passed to ``SSIM3D``. If None, it is the
            (max - min) of the voxelized target, or 1.0 when that is not
            positive.
        window_size: SSIM window size (default 11, the previous hard-coded
            value). NOTE(review): presumably must fit inside the grid for
            kornia's 'same' padding — lower it for small ``grid_size``.

    Returns:
        0-d numpy array: SSIM map averaged per sample over (C, Dx, Dy, Dz),
        then over the batch.

    Raises:
        ValueError: if ``pred`` and ``target`` shapes differ.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got "
            f"{tuple(pred.shape)} and {tuple(target.shape)}")
    C = pred.shape[-1]

    # float64 holds every float16/float32 value exactly, so the cast does not
    # change the interpolated results; it just lets mixed dtypes/devices cat.
    both = torch.cat([pred.detach().cpu().double(),
                      target.detach().cpu().double()], dim=-1)
    vol = voxelize_pointcloud_batch(pos, both, grid_size, method, bounds)
    vol_pred = vol[:, :C].contiguous()
    vol_tgt = vol[:, C:].contiguous()

    if max_val is None:
        data_range = float((vol_tgt.max() - vol_tgt.min()).item())
        max_val = data_range if data_range > 0 else 1.0

    ssim3d = SSIM3D(window_size=window_size, max_val=max_val, padding='same')
    ssim_map = ssim3d(vol_pred, vol_tgt)  # [B, C, Dx, Dy, Dz]
    return ssim_map.mean(dim=(1, 2, 3, 4)).mean().cpu().numpy()
