"""
Assess denoised STEM image quality (periodicity + denoise quality).

Public API:
- assess_image(img, min_periodicity=0.35, min_snr=2.0, min_cnr=1.5) -> dict
- assess_image_path(path, min_periodicity=0.35, min_snr=2.0, min_cnr=1.5) -> dict

Metrics:
- periodicity_score ∈ [0,1]: 0.6 × FFT Bragg-peakiness + 0.4 × autocorrelation off-center ratio
- snr: std of a Gaussian low-pass image / robust (MAD-based) noise sigma of the high-pass residual
- cnr: contrast-to-noise ratio between the two classes of an Otsu threshold

Decision:
PASS if periodicity_score ≥ min_periodicity AND (snr ≥ min_snr OR cnr ≥ min_cnr)
"""
from __future__ import annotations
import json, os
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
import cv2
from scipy import fft as spfft
from skimage.feature import peak_local_max
from skimage.filters import threshold_otsu

# -------------------- Core utilities --------------------
def imread_gray(path: str) -> np.ndarray:
    """Load *path* as a single-channel float32 image stretched to [0, 1].

    The intensity range between the 1st and 99th percentiles is mapped onto
    [0, 1] and values outside it are clipped. If those two percentiles are
    equal, a plain min-max stretch is used instead; a constant image maps
    to all zeros.

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode the file.
    """
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise FileNotFoundError(f"Cannot read image: {path}")
    gray = raw.astype(np.float32)

    lo, hi = np.percentile(gray, (1, 99))
    if hi > lo:
        # Robust stretch: ignore the extreme 1% tails on each side.
        return np.clip((gray - lo) / (hi - lo), 0, 1)

    # Degenerate percentile range: fall back to min-max (epsilon avoids 0/0).
    floor = gray.min()
    return (gray - floor) / (gray.max() - floor + 1e-8)

def _robust_noise_sigma(x: np.ndarray) -> float:
    med = np.median(x)
    mad = np.median(np.abs(x - med)) + 1e-12
    return 1.4826 * mad

def _fft_logmag(img: np.ndarray) -> np.ndarray:
    F = spfft.fftshift(spfft.fft2(img))
    return np.log1p(np.abs(F))

def _bandpass_mask(h: int, w: int, rmin=6, rmax=None) -> np.ndarray:
    if rmax is None:
        rmax = min(h, w)//2 - 2
    cy, cx = h//2, w//2
    yy, xx = np.ogrid[:h, :w]
    rr2 = (yy - cy)**2 + (xx - cx)**2
    return ((rr2 >= rmin**2) & (rr2 <= rmax**2)).astype(np.float32)

def _bragg_peakiness_score(logmag: np.ndarray, k_peaks: int = 12, neighborhood: int = 9) -> float:
    """Score how strongly a log-magnitude spectrum shows sharp, isolated peaks.

    Intended as a Bragg-spot detector: the spectrum is restricted to an
    annulus (inner radius 6 px, outer radius per ``_bandpass_mask``'s
    default), up to ``k_peaks`` local maxima are picked, and each peak's
    height above the median of a surrounding ring is averaged and divided by
    the median positive in-band value.

    Args:
        logmag: 2-D spectrum with the zero frequency at (h//2, w//2), as
            produced by ``_fft_logmag`` (the annulus is centred there).
        k_peaks: maximum number of peaks considered.
        neighborhood: only used to set the peak separation,
            ``min_distance = max(4, neighborhood // 2)``.

    Returns:
        ``tanh(0.3 * mean_contrast / norm)`` in [0, 1); 0.0 when no peak
        passes the threshold.
    """
    h, w = logmag.shape
    # Zero the low-frequency core (r < 6) and everything beyond the outer radius.
    band = logmag * _bandpass_mask(h, w, rmin=6)
    # threshold_rel=0.25: keep only maxima above 25% of the band's maximum.
    # NOTE(review): for real-valued images the spectrum is point-symmetric,
    # so the selected peaks presumably come in +/- pairs.
    coords = peak_local_max(
        band, min_distance=max(4, neighborhood//2),
        threshold_rel=0.25, num_peaks=k_peaks
    )
    if coords.size == 0:
        return 0.0

    def local_contrast(y, x, r_in=2, r_out=6):
        # Peak height: max of logmag in a (2*r_in+1)^2 window, cropped to the image.
        y0, y1 = max(0, y-r_in), min(h, y+r_in+1)
        x0, x1 = max(0, x-r_in), min(w, x+r_in+1)
        peak_val = float(np.max(logmag[y0:y1, x0:x1]))
        # Background: median over the ring 0.7*r_out <= R <= r_out around the peak.
        # Out-of-range coordinates are clamped, so near the border the ring
        # re-samples edge pixels instead of shrinking.
        Y, X = np.ogrid[y-r_out:y+r_out+1, x-r_out:x+r_out+1]
        Yc = np.clip(Y, 0, h-1); Xc = np.clip(X, 0, w-1)
        R = np.sqrt((Y - y)**2 + (X - x)**2)
        ring = logmag[Yc, Xc][(R >= 0.7*r_out) & (R <= 1.0*r_out)]
        bg = float(np.median(ring)) if ring.size else 0.0
        # Negative contrast (peak below its surroundings) counts as zero.
        return max(0.0, peak_val - bg)

    contrasts = [local_contrast(y, x) for y, x in coords]
    # Typical in-band spectral level. band > 0 is presumably non-empty here,
    # since any returned peak lies above 0.25 * max(band) >= 0.
    norm = float(np.median(band[band > 0])) + 1e-6
    # Mean contrast relative to the typical level; >= 0 because contrasts are clamped.
    peakiness = float(np.sum(contrasts) / (len(contrasts) * norm))
    return float(np.tanh(0.3 * peakiness))  # [0,1)-ish

def _autocorr_offcenter_ratio(img: np.ndarray) -> float:
    F = spfft.fft2(img)
    ac = np.real(spfft.ifft2(np.abs(F)**2))
    ac = np.fft.fftshift(ac)
    ac = (ac - ac.min()) / (ac.max() - ac.min() + 1e-12)
    h, w = ac.shape
    cy, cx = h//2, w//2
    center = float(ac[cy, cx])
    rr = min(h, w)//2 - 4
    yy, xx = np.ogrid[:h, :w]
    R = np.sqrt((yy - cy)**2 + (xx - cx)**2)
    ring = ac[(R >= 6) & (R <= rr)]
    if ring.size == 0:
        return 0.0
    offmax = float(np.max(ring))
    return float(np.tanh(0.8 * (offmax / (center + 1e-12))))

def _periodicity_metrics(img: np.ndarray) -> Dict[str, float]:
    """Return the blended periodicity score together with its two components.

    ``periodicity_score`` is a fixed 60/40 mix of the FFT Bragg-peakiness
    and the autocorrelation off-center ratio; both inputs are reported too.
    """
    bragg = _bragg_peakiness_score(_fft_logmag(img))
    ac_ratio = _autocorr_offcenter_ratio(img)
    blended = 0.6 * bragg + 0.4 * ac_ratio
    return dict(
        periodicity_score=float(blended),
        bragg_peakiness=float(bragg),
        autocorr_peak_ratio=float(ac_ratio),
    )

def _snr_cnr_metrics(img: np.ndarray) -> Tuple[float, float, float]:
    """Estimate a robust SNR and an Otsu-based CNR for *img*.

    SNR = std of a Gaussian low-pass copy (sigma = 2 px) divided by the
    MAD-based noise sigma of the high-pass residual ``img - lowpass``.
    CNR is computed by ``_otsu_cnr`` and is 0.0 for degenerate images.

    Returns:
        ``(snr, cnr, noise_sigma)`` as Python floats.
    """
    lp = cv2.GaussianBlur(img, (0, 0), sigmaX=2.0, sigmaY=2.0)
    hp = img - lp
    noise_sigma = _robust_noise_sigma(hp)
    snr = float(np.std(lp)) / (noise_sigma + 1e-12)
    return float(snr), _otsu_cnr(img), float(noise_sigma)


def _otsu_cnr(img: np.ndarray, min_class_size: int = 20) -> float:
    """Contrast-to-noise ratio between the two classes of an Otsu threshold.

    Pixels ``>= t`` form the foreground and pixels ``< t`` the background;
    CNR = |mean_fg - mean_bg| / sqrt(var_fg + var_bg).

    Returns 0.0 (the degenerate fallback) when Otsu cannot produce a
    threshold or when either class has fewer than ``min_class_size`` pixels.
    """
    try:
        t = threshold_otsu(img)
    except ValueError:
        # Only the thresholding failure is treated as "degenerate"; other
        # errors propagate instead of silently becoming cnr = 0.
        # NOTE(review): assumed to cover inputs skimage cannot histogram,
        # e.g. non-finite intensities.
        return 0.0
    fg = img[img >= t]
    bg = img[img < t]
    if fg.size < min_class_size or bg.size < min_class_size:
        return 0.0
    contrast = abs(float(fg.mean() - bg.mean()))
    # The epsilon keeps the denominator non-zero for noise-free two-level images.
    spread = np.sqrt(float(fg.var() + bg.var()) + 1e-12)
    return float(contrast / spread)

# -------------------- Public API --------------------
def assess_image(
    img: np.ndarray,
    min_periodicity: float = 0.35,
    min_snr: float = 2.0,
    min_cnr: float = 1.5,
) -> Dict[str, object]:
    """Assess a denoised STEM image and return metrics plus a PASS/FAIL decision.

    Args:
        img: 2-D grayscale image (any array-like; an (H, W, 1) array is also
            accepted). It is converted to float32. If its maximum exceeds 1.5
            it is assumed to be 0-255 data and divided by 255; otherwise it is
            used as-is (expected range [0, 1]).
        min_periodicity: minimum ``periodicity_score`` required to pass.
        min_snr: SNR threshold (either SNR or CNR must reach its threshold).
        min_cnr: CNR threshold.

    Returns:
        A dict with the rounded metrics, a nested ``"thresholds"`` dict, and
        ``"decision"``, which is "PASS" iff periodicity_score >= min_periodicity
        and (snr >= min_snr or cnr >= min_cnr), evaluated on unrounded values.

    Raises:
        ValueError: if *img* is not a non-empty 2-D image.
    """
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 3 and img.shape[-1] == 1:
        # Tolerate an explicit singleton channel axis.
        img = img[..., 0]
    if img.ndim != 2 or img.size == 0:
        # 2-D FFTs and filters below would otherwise silently run over the
        # wrong axes (e.g. RGB) or fail obscurely on empty input.
        raise ValueError(
            f"expected a non-empty 2-D grayscale image, got shape {img.shape}"
        )
    if img.max() > 1.5:  # heuristic: likely 0~255 data
        img = img / 255.0

    per = _periodicity_metrics(img)
    snr, cnr, noise_sigma = _snr_cnr_metrics(img)
    periodic_enough = per["periodicity_score"] >= min_periodicity
    clean_enough = (snr >= min_snr) or (cnr >= min_cnr)
    decision = periodic_enough and clean_enough

    return {
        "periodicity_score": round(per["periodicity_score"], 4),
        "bragg_peakiness":   round(per["bragg_peakiness"], 4),
        "autocorr_peak_ratio": round(per["autocorr_peak_ratio"], 4),
        "snr": round(snr, 3),
        "cnr": round(cnr, 3),
        "noise_sigma": round(noise_sigma, 4),
        "thresholds": {
            "min_periodicity": float(min_periodicity),
            "min_snr": float(min_snr),
            "min_cnr": float(min_cnr),
        },
        "decision": "PASS" if decision else "FAIL",
    }

def assess_image_path(
    path: str | os.PathLike,
    min_periodicity: float = 0.35,
    min_snr: float = 2.0,
    min_cnr: float = 1.5,
) -> Dict[str, object]:
    """Load an image file with ``imread_gray`` and assess it with ``assess_image``.

    Args:
        path: image file path; ``pathlib.Path`` and other path-like objects
            are converted with ``os.fspath`` before reading.
        min_periodicity, min_snr, min_cnr: forwarded to ``assess_image``.

    Returns:
        The ``assess_image`` result, with a leading ``"image"`` key holding
        ``str(path)``.

    Raises:
        FileNotFoundError: if the file cannot be read (from ``imread_gray``).
    """
    img = imread_gray(os.fspath(path))
    res = assess_image(img, min_periodicity=min_periodicity, min_snr=min_snr, min_cnr=min_cnr)
    return {"image": str(path), **res}

# -------------------- Optional CLI --------------------
if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser(description="Assess denoised STEM image quality.")
    # required=True: without a path there is nothing to assess, and argparse
    # then reports a clear usage error instead of the loader failing on None.
    ap.add_argument("--image", required=True, help="path to denoised STEM image")
    ap.add_argument("--min_periodicity", type=float, default=0.35)
    ap.add_argument("--min_snr", type=float, default=2.0)
    ap.add_argument("--min_cnr", type=float, default=1.5)
    ap.add_argument("--save_viz", action="store_true", help="save FFT/autocorr visualizations next to image")
    args = ap.parse_args()

    out = assess_image_path(
        args.image,
        min_periodicity=args.min_periodicity,
        min_snr=args.min_snr,
        min_cnr=args.min_cnr,
    )
    print(json.dumps(out, indent=2, ensure_ascii=False))

    if args.save_viz:
        # matplotlib is only needed for the optional figures, so import it lazily.
        import matplotlib.pyplot as plt

        def _save_gray(arr, title, dest):
            """Save *arr* as a titled, axis-free grayscale figure at 200 dpi."""
            plt.figure()
            plt.imshow(arr, cmap="gray")
            plt.title(title)
            plt.axis("off")
            plt.tight_layout()
            plt.savefig(dest, dpi=200)
            plt.close()

        img = imread_gray(args.image)
        image_path = Path(args.image)
        base = image_path.stem
        out_dir = image_path.parent
        # FFT viz
        _save_gray(_fft_logmag(img), "FFT log magnitude", out_dir / f"{base}_fft.png")
        # Autocorr viz: inverse FFT of the power spectrum, centred, min-max normalised.
        ac = np.real(spfft.ifft2(np.abs(spfft.fft2(img)) ** 2))
        ac = np.fft.fftshift(ac)
        ac = (ac - ac.min()) / (ac.max() - ac.min() + 1e-12)
        _save_gray(ac, "Autocorrelation", out_dir / f"{base}_autocorr.png")
