import math
import torch
import matplotlib.pyplot as plt


def _sqrt_grid(B: int) -> int:
    """Return the side length of the square grid that holds exactly ``B`` panels.

    Args:
        B: Number of panels (batch size). Must be a positive perfect square.

    Returns:
        The integer ``g`` such that ``g * g == B``.

    Raises:
        ValueError: If ``B`` is smaller than 1 or is not a perfect square.
    """
    # 0 is technically a perfect square, but a 0x0 grid cannot be drawn, and a
    # negative value would otherwise surface math.isqrt's generic message.
    if B < 1:
        raise ValueError(f"Batch size B={B} must be a positive perfect square.")

    g = math.isqrt(B)
    if g * g != B:
        raise ValueError(f"Batch size B={B} is not a perfect square.")
    return g

def vis_sdf(
    sdf_batch,
    figsize=(8, 8),
    title="Raw SDF Outputs",
    cmap="gray",
    cbar_label="Signed distance",
    vmin=None,
    vmax=None,
    show=True,
):
    """Plot a batch of SDF maps on a square grid with one shared colorbar.

    Args:
        sdf_batch: Tensor (or anything ``torch.tensor`` accepts) of shape
            ``[B, H, W]``, where ``B`` must be a perfect square.
        figsize: Figure size forwarded to ``plt.figure``.
        title: Text used as the figure's suptitle.
        cmap: Colormap forwarded to ``imshow``.
        cbar_label: Label placed on the shared colorbar.
        vmin: Lower color limit. When ``None``, the minimum over the whole
            batch is used.
        vmax: Upper color limit. When ``None``, the maximum over the whole
            batch is used.
        show: If true, call ``plt.show()`` before returning.

    Returns:
        The matplotlib figure that was created.

    Raises:
        ValueError: If ``sdf_batch`` is not 3-dimensional, or if ``B`` is
            rejected by ``_sqrt_grid`` (e.g. not a perfect square).
    """
    if not torch.is_tensor(sdf_batch):
        sdf_batch = torch.tensor(sdf_batch)

    if sdf_batch.ndim != 3:
        raise ValueError(f"Expected shape [B,H,W], got {tuple(sdf_batch.shape)}")

    batch_size = sdf_batch.shape[0]
    grid = _sqrt_grid(batch_size)

    sdf_cpu = sdf_batch.detach().cpu()

    # Resolve each limit independently so that a caller who supplies only one
    # of vmin/vmax keeps it; only the missing limit is taken from the data.
    if vmin is None:
        vmin = sdf_cpu.min().item()
    if vmax is None:
        vmax = sdf_cpu.max().item()

    fig = plt.figure(figsize=figsize)
    # One extra, narrow column on the right is reserved for the colorbar.
    gs = fig.add_gridspec(
        nrows=grid,
        ncols=grid + 1,
        width_ratios=[1] * grid + [0.08],
        wspace=0.05,
        hspace=0.05,
    )

    # Row-major list of panel axes; it has exactly grid * grid == B entries.
    axes = [fig.add_subplot(gs[r, c]) for r in range(grid) for c in range(grid)]
    cax = fig.add_subplot(gs[:, -1])

    im = None
    for ax, sdf in zip(axes, sdf_cpu):
        ax.axis("off")
        # All panels share vmin/vmax, so any one image can drive the colorbar.
        im = ax.imshow(
            sdf,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )

    cbar = fig.colorbar(im, cax=cax)
    cbar.set_label(cbar_label)

    fig.suptitle(title)

    if show:
        plt.show()

    return fig



def vis_mask(
    sdf_batch,
    figsize=(8, 8),
    title="Masked Digits",
    cmap="gray",
    show=True,
):
    """Plot the binary mask ``sdf < 0`` for each SDF map on a square grid.

    Args:
        sdf_batch: Tensor (or anything ``torch.tensor`` accepts) of shape
            ``[B, H, W]``, where ``B`` must be a perfect square.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Text used as the figure's suptitle.
        cmap: Colormap forwarded to ``imshow``.
        show: If true, call ``plt.show()`` before returning.

    Returns:
        The matplotlib figure that was created.

    Raises:
        ValueError: If ``sdf_batch`` is not 3-dimensional, or if ``B`` is
            rejected by ``_sqrt_grid`` (e.g. not a perfect square).
    """
    if not torch.is_tensor(sdf_batch):
        sdf_batch = torch.tensor(sdf_batch)

    if sdf_batch.ndim != 3:
        raise ValueError(f"Expected shape [B,H,W], got {tuple(sdf_batch.shape)}")

    batch_size = sdf_batch.shape[0]
    grid = _sqrt_grid(batch_size)

    # 1.0 where the signed distance is negative, 0.0 elsewhere.
    mask = (sdf_batch < 0).float().detach().cpu()

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; without it
    # plt.subplots returns a bare Axes object for B == 1 and flattening fails.
    fig, axes = plt.subplots(grid, grid, figsize=figsize, squeeze=False)

    for ax, panel in zip(axes.ravel(), mask):
        # Pin the color range to the mask's value set. With per-panel
        # auto-scaling, an all-ones mask would be drawn exactly like an
        # all-zeros mask.
        ax.imshow(panel, cmap=cmap, vmin=0.0, vmax=1.0)
        ax.axis("off")

    fig.suptitle(title)

    if show:
        plt.show()

    return fig
