# guided_diffusion/psf_measurement.py
# FFT-based linear operator y = K * x with a known PSF.
# Loads .npy/.tif/.png PSF, converts to grayscale, unit-sum normalizes,
# COM-centers, precomputes OTF, and exposes A / AT.

import torch
import torch.fft as fft
import numpy as np
from typing import Tuple, Dict, Any, Optional

try:
    from PIL import Image
    _PIL_OK = True
except Exception:
    _PIL_OK = False

def _load_psf(
    psf_path: str,
    target_hw: Optional[Tuple[int, int]] = None,
    center: bool = True,
    device: str = "cpu",
    dtype = torch.float32
) -> torch.Tensor:
    """Load a PSF from disk as a non-negative, unit-sum 2-D tensor.

    Args:
        psf_path: Path to the PSF. A ``.npy`` suffix (any case) is read with
            ``np.load``; anything else is opened with Pillow.
        target_hw: Optional ``(H, W)``. When given and different from the
            loaded shape, the PSF is box-resampled to that size and
            re-normalized.
        center: When true, circularly shift the PSF by a whole number of
            pixels so its centre of mass sits as close as possible to
            ``((H - 1) / 2, (W - 1) / 2)``.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        A 2-D tensor of shape ``(H, W)`` whose entries are >= 0 and sum to 1.

    Raises:
        RuntimeError: Pillow is needed (image file or resize) but missing.
        ValueError: The PSF is not 2-D after channel reduction, or its sum is
            not positive after clipping negatives.
    """
    if str(psf_path).lower().endswith(".npy"):
        k = np.load(psf_path).astype(np.float32)
    else:
        if not _PIL_OK:
            raise RuntimeError("Pillow is required to load non-.npy PSFs.")
        # Context manager so the underlying file handle is released promptly.
        with Image.open(psf_path) as img:
            arr = np.array(img, dtype=np.float32)
        if arr.ndim == 3:
            if arr.shape[2] >= 3:
                arr = arr[..., :3].mean(axis=2)  # grayscale, alpha ignored
            else:
                arr = arr[..., 0]  # e.g. luminance + alpha: keep luminance
        k = arr

    # Everything below (resize, centring, OTF construction) assumes (H, W).
    if k.ndim != 2:
        raise ValueError(
            f"PSF must be 2-D after channel reduction, got shape {tuple(k.shape)}."
        )

    k = np.clip(k, 0, None)
    s = float(k.sum())
    if s <= 0:
        raise ValueError("PSF sum is zero or negative.")
    k = k / s

    # Optional resize (area)
    if target_hw is not None and (k.shape[0] != target_hw[0] or k.shape[1] != target_hw[1]):
        if not _PIL_OK:
            raise RuntimeError("Pillow is required to resize PSFs.")
        Ht, Wt = target_hw
        # Resample as a 32-bit float image: going through uint8 would round
        # every value below ~1/510 of the peak down to zero.
        im = Image.fromarray(np.ascontiguousarray(k, dtype=np.float32))
        im = im.resize((int(Wt), int(Ht)), resample=Image.BOX)
        k = np.array(im, dtype=np.float32)
        k = np.clip(k, 0, None)
        s = float(k.sum())
        if s <= 0:
            raise ValueError("Resized PSF sum is zero.")
        k = k / s

    if center:
        H, W = k.shape
        yy, xx = np.indices((H, W), dtype=np.float32)
        total = float(k.sum())
        cy = float((k * yy).sum() / total)
        cx = float((k * xx).sum() / total)
        cy_t = (H - 1) / 2.0
        cx_t = (W - 1) / 2.0
        # Integer circular shift: content pushed past an edge wraps around.
        dy = int(round(cy_t - cy))
        dx = int(round(cx_t - cx))
        k = np.roll(k, shift=(dy, dx), axis=(0, 1))

    return torch.from_numpy(k).to(device=device, dtype=dtype)

def _psf_to_otf(psf: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Convert a centred 2-D PSF into an ``(H, W)`` optical transfer function.

    The PSF is treated as centred on pixel ``(h // 2, w // 2)``. If its shape
    differs from ``(H, W)`` it is zero-padded and/or centre-cropped per axis
    so that this pixel lands on ``(H // 2, W // 2)``. That pixel is then moved
    to index ``(0, 0)`` with ``ifftshift`` before the 2-D FFT, so a centred
    delta yields an all-ones OTF (the identity operator).

    Args:
        psf: 2-D PSF tensor of shape ``(h, w)``.
        H: Target height of the OTF.
        W: Target width of the OTF.

    Returns:
        Complex tensor of shape ``(H, W)``.
    """
    h, w = psf.shape
    if h != H or w != W:
        out = torch.zeros((H, W), device=psf.device, dtype=psf.dtype)
        hh = min(h, H)
        ww = min(w, W)
        # Align centre pixels (n // 2) of source and canvas. A plain
        # (H - h) // 2 offset is one pixel short when H is even and h is odd.
        sy = max(H // 2 - h // 2, 0)
        sx = max(W // 2 - w // 2, 0)
        py = max(h // 2 - H // 2, 0)
        px = max(w // 2 - W // 2, 0)
        out[sy:sy + hh, sx:sx + ww] = psf[py:py + hh, px:px + ww]
        psf = out
    # ifftshift rolls by -(n // 2). Note `-n // 2` parses as (-n) // 2, which
    # over-shifts by one pixel whenever n is odd.
    psf_shift = fft.ifftshift(psf, dim=(0, 1))
    return fft.fft2(psf_shift)

class PSFMeasurement:
    """Circular-convolution measurement operator ``y = K * x`` for a fixed PSF.

    The PSF's OTF and its complex conjugate are computed once at construction
    and applied in the Fourier domain by :meth:`A` and :meth:`AT`.
    """

    def __init__(
        self,
        psf_path: Optional[str] = None,
        psf_tensor: Optional[torch.Tensor] = None,
        img_hw: Optional[Tuple[int, int]] = None,
        device: str = "cpu",
        dtype = torch.float32
    ):
        """Build the operator from a PSF file or an in-memory PSF tensor.

        Args:
            psf_path: Path handed to ``_load_psf``; used only when
                ``psf_tensor`` is None.
            psf_tensor: PSF tensor; takes precedence over ``psf_path``. It is
                moved to ``device``/``dtype`` and otherwise used as given.
            img_hw: ``(H, W)`` of the images the operator is applied to.
                Defaults to the PSF's own shape.
            device: Device for the PSF and OTF.
            dtype: Real dtype for the PSF.

        Raises:
            ValueError: Neither ``psf_path`` nor ``psf_tensor`` was given.
        """
        # Explicit raise rather than `assert`, which is stripped under -O.
        if psf_path is None and psf_tensor is None:
            raise ValueError("Provide psf_path or psf_tensor.")
        if psf_tensor is None:
            psf = _load_psf(psf_path, target_hw=img_hw, device=device, dtype=dtype)
        else:
            psf = psf_tensor.to(device=device, dtype=dtype)

        H, W = img_hw if img_hw is not None else psf.shape
        self.HW = (H, W)
        self.device = device
        self.dtype = dtype

        self.psf = psf
        self.otf = _psf_to_otf(psf, H, W)
        # Conjugate OTF is the multiplier used by AT.
        self.otf_conj = torch.conj(self.otf)

    def A(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the forward operator to ``x``.

        The FFT runs over the last two dimensions of ``x``; the OTF is viewed
        as ``(1, 1, H, W)`` and broadcast against the remaining ones.

        Returns:
            The real part of the inverse FFT of ``FFT(x) * OTF``.
        """
        H, W = self.HW
        X = fft.fft2(x, dim=(-2, -1))
        Y = X * self.otf.view(1, 1, H, W)
        return fft.ifft2(Y, dim=(-2, -1)).real

    def AT(self, y: torch.Tensor) -> torch.Tensor:
        """Apply the conjugate-OTF counterpart of :meth:`A` to ``y``.

        Returns:
            The real part of the inverse FFT of ``FFT(y) * conj(OTF)``.
        """
        H, W = self.HW
        Y = fft.fft2(y, dim=(-2, -1))
        X = Y * self.otf_conj.view(1, 1, H, W)
        return fft.ifft2(X, dim=(-2, -1)).real

    @torch.no_grad()
    def to_y(self, x: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
        """Simulate a measurement: ``A(x)`` plus optional Gaussian noise.

        Args:
            x: Input passed to :meth:`A`.
            sigma: Noise standard deviation; no noise is added when it is
                None or not positive.

        Returns:
            The (noisy) measurement clamped to ``[0, 1]``.
        """
        y = self.A(x)
        if sigma is not None and sigma > 0:
            y = y + sigma * torch.randn_like(y)
        return y.clamp(0.0, 1.0)

def make_psf_measurement(
    psf_path: str,
    img_hw: Tuple[int, int],
    device: str = "cuda",
    dtype = torch.float32
) -> Dict[str, Any]:
    """Create a PSFMeasurement from a PSF file and expose its callables.

    Args:
        psf_path: Path of the PSF file, forwarded to ``PSFMeasurement``.
        img_hw: ``(H, W)`` forwarded to ``PSFMeasurement``.
        device: Device forwarded to ``PSFMeasurement``.
        dtype: Dtype forwarded to ``PSFMeasurement``.

    Returns:
        A dict with the bound methods under ``"A"`` and ``"AT"`` and the
        operator instance itself under ``"op"``.
    """
    operator = PSFMeasurement(
        psf_path=psf_path,
        img_hw=img_hw,
        device=device,
        dtype=dtype,
    )
    handles: Dict[str, Any] = {}
    handles["A"] = operator.A
    handles["AT"] = operator.AT
    handles["op"] = operator
    return handles
