# -*- coding: utf-8 -*-
"""p2020_high_priority_metrics_p1.py  – ADAS-critical KPI subset (≥ 8 pts)"""

from __future__ import annotations
import cv2, numpy as np
from scipy import fft
from typing import Dict, Sequence, Tuple, Iterable, List, Union

# ───────────────────────── helpers ─────────────────────────
EPS: float = 1e-9  # additive guard keeping denominators away from zero
Frame = np.ndarray  # one image array (BGR or grayscale)
Video = Union[Iterable[Frame], np.ndarray]  # iterable of frames, or a stacked (T,H,W,C) array


def _gray(img: np.ndarray) -> np.ndarray:
    """Return a single-channel float32 copy of *img*.

    Accepts 2-D grayscale arrays, ``(H, W, 1)`` arrays and multi-channel BGR
    images (OpenCV's BGR2GRAY also accepts a 4th alpha channel).

    Values are NOT linearised: the output keeps the input's encoding and scale
    (e.g. 0..255 for uint8); only the channel reduction and dtype change.
    """
    img = np.asarray(img)
    if img.ndim == 3:
        if img.shape[2] == 1:
            # A single channel stored as 3-D: cvtColor(BGR2GRAY) would reject it.
            img = img[..., 0]
        else:
            # cvtColor only supports 8U / 16U / 32F depths; promote anything else.
            if img.dtype not in (np.uint8, np.uint16, np.float32):
                img = img.astype(np.float32)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img.astype(np.float32)


def _safe_div(a, b):
    """Divide *a* by *b* with ``EPS`` added to the denominator.

    Works element-wise on NumPy arrays and plainly on Python scalars. The guard
    only protects non-negative denominators; ``b == -EPS`` still divides by zero.
    """
    denominator = b + EPS
    return a / denominator


def _ensure_frames(video: Video) -> List[Frame]:
    """Normalise *video* into a Python list of per-frame arrays.

    A NumPy array must be 4-D ``(T, H, W, C)`` and is split along axis 0 into
    views. Any other iterable is simply materialised with ``list()``.

    Raises:
        ValueError: if *video* is an ndarray whose rank is not 4.
    """
    if not isinstance(video, np.ndarray):
        return list(video)
    if video.ndim != 4:
        raise ValueError("video ndarray must be (T,H,W,C)")
    # Iterating an ndarray yields video[0], video[1], ... as views.
    return list(video)


# ─────────────── DYNAMIC-RANGE / HDR (§5) ───────────────

def frame_dynamic_range_proxy(img: Frame,
                              p_lo: float = 0.1,
                              p_hi: float = 99.9,
                              assume_gamma: float | None = None) -> float:
    """
    {proxy-DR} Single-frame dynamic-range proxy from histogram percentiles.

    Args:
        img: BGR or grayscale frame.
        p_lo, p_hi: Lower / upper percentiles (0-100) bounding the range.
        assume_gamma: If truthy, code values are divided by 255 and raised to
            this power as a rough sRGB->linear approximation, and the range is
            returned in EV, i.e. ``log2(hi / lo)``. Otherwise the raw code-value
            difference ``hi - lo`` is returned.

    Note:
        In EV mode a lower percentile of 0 (pure black pixels) yields roughly
        ``log2(hi / EPS)`` (~30 EV), so the value is then dominated by ``EPS``.
    """
    # _gray also accepts 2-D grayscale frames; a bare cv2.cvtColor(BGR2GRAY) raised on them.
    gray = _gray(img)
    ev_mode = bool(assume_gamma)
    if ev_mode:
        gray = (gray / 255.0) ** assume_gamma
    lo, hi = np.percentile(gray, [p_lo, p_hi])
    if ev_mode:
        return float(np.log2((hi + EPS) / (lo + EPS)))
    return float(hi - lo)


def sequence_dynamic_range_proxy(video: Video,
                                 p_lo: float = 0.1,
                                 p_hi: float = 99.9,
                                 assume_gamma: float | None = None) -> float:
    """
    proxy-DR over a whole clip: percentile spread of all pixels of all frames.

    Args:
        video: Frames (BGR or grayscale) as a list/iterable or a (T,H,W,C) array.
        p_lo, p_hi: Lower / upper percentiles (0-100) over the pooled pixels.
        assume_gamma: If truthy, code values are divided by 255 and raised to
            this power (rough linearisation) and the result is in EV
            (``log2(hi / lo)``); otherwise ``hi - lo`` in code values.

    Raises:
        ValueError: if the clip contains no frames.
    """
    frames = _ensure_frames(video)
    if not frames:
        raise ValueError("sequence_dynamic_range_proxy needs at least one frame")

    # Fill one pre-sized float32 buffer frame by frame instead of building a
    # list of per-frame copies, concatenating, and then casting (extra copies).
    sizes = [int(np.prod(np.shape(f)[:2])) for f in frames]
    all_pix = np.empty(sum(sizes), dtype=np.float32)
    pos = 0
    for frame, n in zip(frames, sizes):
        # _gray accepts grayscale frames too (bare cvtColor(BGR2GRAY) did not).
        all_pix[pos:pos + n] = _gray(frame).ravel()
        pos += n

    ev_mode = bool(assume_gamma)
    if ev_mode:
        all_pix /= 255.0
        np.power(all_pix, assume_gamma, out=all_pix)
    lo, hi = np.percentile(all_pix, [p_lo, p_hi])
    if ev_mode:
        return float(np.log2((hi + EPS) / (lo + EPS)))
    return float(hi - lo)


def temporal_exposure_jitter(video: Video) -> float:
    """Frame-to-frame exposure jitter of a clip.

    Computes the mean gray level of every frame and returns the standard
    deviation of the successive differences of those means (in code values).
    Accepts BGR or grayscale frames. Returns 0.0 for clips shorter than 3 frames.
    """
    frames = _ensure_frames(video)
    if len(frames) < 3:
        return 0.0
    # Use the shared _gray helper (like the flicker metrics) so grayscale
    # frames work; accumulate the means in float64.
    lum = np.array([_gray(f).mean(dtype=np.float64) for f in frames])
    return float(np.std(np.diff(lum)))


# ─────────────────── SHARPNESS / RESOLUTION (§6) ───────────────────

def mtf50(img: Frame, axis: int | str = "both") -> float:
    """Sobel-based MTF50 surrogate (threshold 0.5 passed to ``_mtf_sobel``).

    ``axis`` 0 or 1 selects a single gradient direction; any other value
    averages both directions, which is more robust. The result is an FFT bin
    index of the averaged gradient profile, not cycles/pixel.
    """
    if axis not in (0, 1):
        both = _mtf_sobel(img, 0, 0.50) + _mtf_sobel(img, 1, 0.50)
        return float(both * 0.5)
    return _mtf_sobel(img, axis, 0.50)


def mtf10(img: Frame, axis: int | str = "both") -> float:
    """Sobel-based MTF10 surrogate (threshold 0.1 passed to ``_mtf_sobel``).

    ``axis`` 0 or 1 selects a single gradient direction; any other value
    averages both directions. The result is an FFT bin index of the averaged
    gradient profile, not cycles/pixel.
    """
    if axis not in (0, 1):
        both = _mtf_sobel(img, 0, 0.10) + _mtf_sobel(img, 1, 0.10)
        return float(both * 0.5)
    return _mtf_sobel(img, axis, 0.10)


def _mtf_sobel(img: Frame, axis: int, thr: float) -> float:
    """Sobel -> mean gradient profile -> |rFFT| -> first bin (from the peak) <= thr.

    axis=0: x-derivative averaged over rows, giving a profile along x.
    axis=1: y-derivative averaged over columns, giving a profile along y.

    The magnitude spectrum is normalised to its maximum. The return value is
    the FFT bin index where it first drops to ``thr`` or below, searching from
    the spectral peak onwards. It is ``spec.size`` if the spectrum never drops
    that low, and 0.0 for a gradient-free image.
    """
    g = _gray(img)
    sob = cv2.Sobel(g, cv2.CV_32F, 1 - axis, axis, ksize=3)
    prof = sob.mean(axis=axis)
    spec = np.abs(fft.rfft(prof))
    peak = int(np.argmax(spec))
    m = spec[peak]
    if m <= 0:
        return 0.0
    spec = spec / (m + EPS)
    # The DC bin of a derivative profile is a telescoping sum that depends only
    # on pixels near the image borders, so it is often far below the maximum.
    # Searching from bin 0 then returned 0 regardless of sharpness. Starting at
    # the peak is identical when the peak is DC (e.g. single-edge charts).
    idx = np.flatnonzero(spec[peak:] <= thr)
    return float(peak + idx[0] if idx.size else spec.size)


def contrast_transfer_accuracy(
        img: Frame,
        tgt_contrasts: Sequence[float] = (0.1, 0.5, 0.9),
        patch_size: int = 16) -> float:
    """Multi-level CTA proxy based on patch Michelson contrast.

    The gray image is downsampled 4x (INTER_AREA) and tiled into
    non-overlapping ``patch_size`` x ``patch_size`` patches. Partial patches at
    the right/bottom edges are dropped. Each patch's Michelson contrast
    ``(max - min) / (max + min)`` is averaged into ``mean_c``. The result is
    the mean of ``mean_c / t`` over ``t`` in ``tgt_contrasts``. This is a ratio
    to each target, not a fitted regression slope.

    Returns 0.0 when the downsampled image is smaller than one patch or when
    ``tgt_contrasts`` is empty.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")
    targets = np.asarray(tgt_contrasts, dtype=np.float64).ravel()
    if targets.size == 0:
        return 0.0

    g = _gray(img)
    small = cv2.resize(g, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
    h, w = small.shape
    nh, nw = h // patch_size, w // patch_size
    if nh == 0 or nw == 0:
        return 0.0

    # Regroup the patch grid as (nh, ps, nw, ps) so each patch's max/min is one
    # vectorised reduction over axes 1 and 3 (same patches as a y/x loop).
    tiles = small[:nh * patch_size, :nw * patch_size].reshape(
        nh, patch_size, nw, patch_size)
    hi = tiles.max(axis=(1, 3)).astype(np.float64)
    lo = tiles.min(axis=(1, 3)).astype(np.float64)
    mean_c = float(_safe_div(hi - lo, hi + lo).mean())
    return float(np.mean(mean_c / (targets + EPS)))


def edge_rise_time(img: Frame, window: int = 12) -> float:
    """10-90 % edge-rise width in pixels, averaged over both directions.

    The pixel with the strongest x-gradient provides a horizontal profile
    across a vertical edge. The pixel with the strongest y-gradient provides a
    vertical profile across a horizontal edge. Each profile extends up to
    ``window`` pixels on either side, clipped at the image border.

    For each profile, falling edges are flipped so the profile rises. The
    width is the linearly interpolated distance between the first crossings of
    the 10 % and 90 % levels of the profile's min->max span. A flat or
    shorter-than-2 profile contributes 0.
    """
    g = _gray(img)
    h, w = g.shape
    sob_x = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3)
    sob_y = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3)
    y_v, x_v = np.unravel_index(np.argmax(np.abs(sob_x)), sob_x.shape)
    y_h, x_h = np.unravel_index(np.argmax(np.abs(sob_y)), sob_y.shape)

    prof_vert = g[y_v, max(0, x_v - window): min(w, x_v + window + 1)]
    prof_hori = g[max(0, y_h - window): min(h, y_h + window + 1), x_h]

    def _crossing(p: np.ndarray, level: float) -> float:
        """Sub-pixel position where rising profile *p* first reaches *level*."""
        i = int(np.argmax(p >= level))  # first True; callers guarantee one exists
        if i == 0:
            return 0.0
        a, b = p[i - 1], p[i]           # a < level <= b, hence b - a > 0
        return (i - 1) + (level - a) / (b - a)

    def _erd(p: np.ndarray) -> float:
        """10-90 % rise width of one 1-D profile (0.0 if undefined)."""
        p = np.asarray(p, dtype=np.float64)
        if p.size < 2:
            return 0.0
        if p[-1] < p[0]:
            p = p[::-1]                 # orient as a rising edge
        lo, hi = float(p.min()), float(p.max())
        span = hi - lo
        if span <= 0:
            return 0.0
        # Levels are fractions of the intensity span, not sample percentiles:
        # percentiles of the samples sit on the plateaus, which made the old
        # value track the distance to the edge (~window) rather than its width.
        x10 = _crossing(p, lo + 0.1 * span)
        x90 = _crossing(p, lo + 0.9 * span)
        return float(max(x90 - x10, 0.0))

    return float((_erd(prof_vert) + _erd(prof_hori)) * 0.5)


# ───────────────────────── GEOMETRY (§3) ─────────────────────────

def total_distortion(img: Frame, outer_frac: float = 0.8) -> float:
    """
    Total radial distortion (proxy, **signed**, relative units).

    Method: Canny edge pixels are ranked by their distance ``r`` from the image
    centre and paired with evenly spaced reference radii ``s`` in
    ``[0, Rmax]``, where ``Rmax`` is the half-diagonal. A straight line is
    fitted to ``r(s)``. The result is the median of ``(r - fit) / s`` over the
    outermost ``1 - outer_frac`` share of the ranked edge pixels.

    Sign: positive when the outermost edge radii lie above the linear trend.
    The original convention reads + as barrel and - as pincushion.
    NOTE(review): that mapping has not been validated against a calibrated
    target; confirm before relying on the sign.

    Returns NaN when fewer than 50 edge pixels are found.
    """
    g = _gray(img).astype(np.uint8)
    h, w = g.shape[:2]
    edges = cv2.Canny(g, 50, 150)
    ys, xs = np.where(edges)
    if xs.size < 50:
        return float("nan")

    cx, cy = w * 0.5, h * 0.5
    r_sorted = np.sort(np.hypot(xs - cx, ys - cy))   # observed radii, ascending
    Rmax = float(np.hypot(cx, cy))
    M = r_sorted.size
    s = np.linspace(0.0, Rmax, M, dtype=np.float64)

    # np.polyfit returns coefficients highest power first: [slope, intercept].
    # The previous `a + b * s` swapped them (slope used as offset, intercept as slope).
    slope, intercept = np.polyfit(s, r_sorted, 1)
    r_ref = slope * s + intercept

    rel = (r_sorted - r_ref) / (s + 1e-6)
    k0 = int(np.clip(outer_frac * M, 0, M - 1))      # always < M, so the slice is non-empty
    return float(np.nanmedian(rel[k0:]))             # sign preserved


# ───────────────────────── FLARE / STRAY-LIGHT (§2) ─────────────────────────

def flare_attenuation(
    img: Frame,
    inner_ratio: float = 0.10,
    outer_band: tuple[float, float] = (0.40, 0.45),
) -> float:
    """
    Flare Attenuation (proxy): central-disc to outer-ring mean-luminance ratio.

    The original note says lower is better. Radii are measured from the image
    centre as fractions of the full image diagonal:

    * disc: ``r < inner_ratio * diag``; if empty, the whole frame is used.
    * ring: ``outer_band[0] * diag < r < outer_band[1] * diag``; if empty,
      everything outside the disc is used.
    """
    gray = _gray(img)
    rows, cols = gray.shape
    diagonal = float(np.hypot(rows, cols))
    yy, xx = np.ogrid[:rows, :cols]
    radius = np.hypot(xx - cols * 0.5, yy - rows * 0.5)

    centre = radius < inner_ratio * diagonal
    if not centre.any():
        centre = np.ones_like(gray, bool)

    ring = (radius > outer_band[0] * diagonal) & (radius < outer_band[1] * diagonal)
    if not ring.any():
        ring = ~centre

    centre_mean = float(gray[centre].mean())
    ring_mean = float(gray[ring].mean())
    return centre_mean / (ring_mean + EPS)


# ───────────────────────── TEXTURE (§4) ─────────────────────────

def gradient_entropy(img: Frame) -> float:
    """Shannon entropy (bits) of the gradient-magnitude histogram; larger = richer texture.

    Sobel gradient magnitudes are binned into 256 equal bins over
    ``[0, max + EPS]``. The bin counts are normalised to probabilities, so the
    result lies in ``[0, 8]`` bits and is largely independent of the absolute
    gradient scale. Returns 0.0 for a gradient-free (flat) image.
    """
    g = _gray(img)
    gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3)
    mag = np.hypot(gx, gy)
    top = float(mag.max())
    if top <= 0:
        return 0.0
    counts, _ = np.histogram(mag, bins=256, range=(0.0, top + EPS))
    # Use probabilities, not density=True: densities scale with 1 / bin width
    # (i.e. with 1 / max gradient), so the old sum was neither in bits nor
    # comparable across images. Empty bins contribute 0 (p*log p -> 0).
    p = counts[counts > 0].astype(np.float64) / counts.sum()
    return float(-(p * np.log2(p)).sum())


def blur_extent(img: Frame) -> float:
    """Low/high-frequency magnitude ratio (smaller = sharper).

    The low band is a 21x21 Gaussian blur with OpenCV's automatic sigma
    (sigma=0 -> 0.3*((21-1)*0.5-1)+0.8 = 3.5); the high band is the residual
    ``gray - low``. Returns ``mean|low| / mean|high|`` (EPS-guarded).
    """
    gray = _gray(img)
    smooth = cv2.GaussianBlur(gray, (21, 21), 0)
    detail = gray - smooth
    low_mag = float(np.mean(np.abs(smooth)))
    high_mag = float(np.mean(np.abs(detail)))
    return _safe_div(low_mag, high_mag)


def chroma_aberration(img: Frame, on_empty: str = "nan") -> float:
    """Std-dev of (R - B) over Canny edge pixels; larger suggests stronger colour fringing.

    Channels are read as BGR(A): index 0 is blue and index 2 is red, so an
    extra alpha channel is ignored.

    ``on_empty`` selects the value returned when nothing can be measured: NaN
    if it equals ``"nan"``, otherwise 0.0. That applies to:

    * frames with fewer than three colour channels (e.g. grayscale);
    * frames without Canny edges;
    * edge sets with no finite R - B value.
    """
    empty_value = float('nan') if on_empty == "nan" else 0.0
    img = np.asarray(img)
    # The old 3-way cv2.split unpacking raised on grayscale and BGRA input.
    if img.ndim != 3 or img.shape[2] < 3:
        return empty_value
    edges = cv2.Canny(_gray(img).astype(np.uint8), 50, 150)
    ys, xs = np.nonzero(edges)
    if ys.size == 0:
        return empty_value
    diff = img[ys, xs, 2].astype(np.float32) - img[ys, xs, 0].astype(np.float32)
    diff = diff[np.isfinite(diff)]
    if diff.size == 0:
        return empty_value
    return float(np.std(diff))


# ─────────────────── FLICKER / TEMPORAL (§7) ───────────────────

def flicker_modulation_power(video: Video, fps: float = 10.0) -> float:
    """Share of temporal luma power concentrated around the dominant flicker peak.

    The per-frame mean luma is mean-removed and transformed with an rFFT. The
    strongest non-DC bin defines ``peak`` (Hz). The result is the fraction of
    total power within ``peak +/- max(0.5, 0.2 * peak)`` Hz. Larger values
    suggest more visible flicker; a perfectly steady clip gives ~0.

    Returns 0.0 for fewer than 3 frames, ``fps <= 0``, or fewer than 3
    spectral bins.
    """
    frames = _ensure_frames(video)
    if len(frames) < 3 or fps <= 0:
        return 0.0
    luma = np.array([_gray(f).mean(dtype=np.float64) for f in frames])
    # Remove the mean first. Otherwise spec[0] = (sum of luma)^2 swamps every
    # flicker component, and whether the band happened to reach DC flipped the
    # result between ~0 and ~1 depending on fps and clip length.
    spec = np.abs(fft.rfft(luma - luma.mean())) ** 2
    freqs = fft.rfftfreq(luma.size, 1.0 / fps)
    if freqs.size < 3:
        return 0.0
    peak_idx = int(np.argmax(spec[1:])) + 1
    peak = freqs[peak_idx]
    bw = max(0.5, 0.2 * peak)
    band = (freqs > peak - bw) & (freqs < peak + bw)
    return float(spec[band].sum() / (spec.sum() + EPS))


def fmp_alias(video: Video, fps: float = 10.0, min_peak: float = 0.2) -> float:
    """Low-frame-rate FMP proxy that ignores slow drift below ``min_peak`` Hz.

    The per-frame mean luma is mean-removed and transformed with an rFFT. If
    the strongest non-DC bin lies below ``min_peak`` Hz the clip is treated as
    drift only and 0.0 is returned. Otherwise the result is the fraction of
    total power within ``peak +/- max(0.5, 0.2 * peak)`` Hz.

    Returns 0.0 for fewer than 3 frames, ``fps <= 0``, or fewer than 3
    spectral bins.
    """
    frames = _ensure_frames(video)
    if len(frames) < 3 or fps <= 0:
        return 0.0
    lum = np.array([_gray(f).mean(dtype=np.float64) for f in frames])
    # Mean removal keeps the DC term out of the denominator (and out of the band
    # when peak_f - bw < 0). Otherwise the ratio is ~0 or ~1 regardless of flicker.
    spec = np.abs(fft.rfft(lum - lum.mean())) ** 2
    freqs = fft.rfftfreq(lum.size, 1.0 / fps)
    if freqs.size < 3:
        return 0.0
    peak_idx = int(np.argmax(spec[1:])) + 1
    peak_f = freqs[peak_idx]
    if peak_f < min_peak:
        return 0.0
    bw = max(0.5, 0.2 * peak_f)
    band = (freqs > peak_f - bw) & (freqs < peak_f + bw)
    return float(spec[band].sum() / (spec.sum() + EPS))


def modulation_mitigation_probability(video: Video) -> float:
    """Whole-clip luminance stability.

    Returns the fraction of frames whose mean gray level deviates from the
    average of all frame means by less than 5 % in relative terms
    (``|mu_t - mu| / (mu + EPS) < 0.05``). Returns 0.0 for an empty clip.
    """
    frames = _ensure_frames(video)
    if len(frames) == 0:
        return 0.0
    frame_means = np.asarray([_gray(frame).mean() for frame in frames], dtype=np.float32)
    clip_mean = frame_means.mean()
    relative_dev = np.abs(frame_means - clip_mean) / (clip_mean + EPS)
    stable = relative_dev < 0.05
    return float(stable.mean())


def mmp_alias(video: Video, fps: float = 10.0, band_hz: float = 0.5, thr: float = 0.05,
              win_sec: float = 3.0, hop_ratio: float = 0.5,
              min_peak: float = 0.2) -> float:
    """proxy-MMP: fraction of sliding windows whose in-band flicker power share < ``thr``.

    1. The dominant non-DC frequency ``peak_f`` is taken from the whole clip's
       mean-luma spectrum. If it is below ``min_peak`` Hz (slow drift only), the
       clip counts as fully mitigated and 1.0 is returned.
    2. Windows of ``win_sec`` seconds are used: at least 8 frames, and clamped
       to the clip length, in which case the hop becomes half the clip. They
       advance by ``hop_ratio`` of the window.
    3. Each window is mean-removed. A window is a hit when its share of power
       inside ``(peak_f - band_hz, peak_f + band_hz)`` is below ``thr``.

    Returns 0.0 for fewer than 4 frames or ``fps <= 0``.
    """
    frames = _ensure_frames(video)
    T = len(frames)
    if T < 4 or fps <= 0:
        return 0.0

    lum = np.array([_gray(f).mean(dtype=np.float64) for f in frames])
    spec = np.abs(fft.rfft(lum - lum.mean())) ** 2
    freqs = fft.rfftfreq(T, 1.0 / fps)
    peak_f = freqs[int(np.argmax(spec[1:])) + 1]
    if peak_f < min_peak:
        return 1.0

    win = max(8, int(round(win_sec * fps)))
    hop = max(1, int(round(win * hop_ratio)))
    if win > T:
        win = T
        hop = max(1, T // 2)

    # Every window has the same length, so its frequency axis and band mask
    # can be computed once.
    win_freqs = fft.rfftfreq(win, 1.0 / fps)
    in_band = (win_freqs > peak_f - band_hz) & (win_freqs < peak_f + band_hz)

    hits = total = 0
    for start in range(0, T - win + 1, hop):
        x = lum[start:start + win]
        # Without removing the window mean, the DC bin dominates X.sum(), the
        # share is ~0 and practically every window counted as a hit.
        X = np.abs(fft.rfft(x - x.mean())) ** 2
        share = X[in_band].sum() / (X.sum() + EPS)
        hits += int(share < thr)
        total += 1
    return hits / total if total > 0 else 0.0


# ─────────────────── Convenience wrappers ───────────────────

def single_frame_metrics(img: Frame) -> Dict[str, float]:
    """Evaluate every single-image high-priority KPI on *img*.

    Returns a dict keyed by each metric function's ``__name__``, in the order
    listed below. A metric that raises is recorded as NaN. This is deliberate
    best-effort behaviour, so one failing KPI does not abort the others.
    """
    metric_fns = (
        frame_dynamic_range_proxy,
        mtf50,
        mtf10,
        contrast_transfer_accuracy,
        edge_rise_time,
        total_distortion,
        flare_attenuation,
        gradient_entropy,
        blur_extent,
        chroma_aberration,
    )
    results: Dict[str, float] = {}
    for metric in metric_fns:
        key = metric.__name__
        try:
            value = metric(img)
        except Exception:
            value = float('nan')
        results[key] = value
    return results


def video_metrics(frames: Video, fps: float = 30.0) -> Dict[str, float]:
    """Evaluate the clip-level KPIs; returns ``{}`` for clips with fewer than 2 frames.

    ``fps`` is forwarded to the frequency-domain metrics (fmp_alias, mmp_alias,
    flicker_modulation_power). A metric that raises is recorded as NaN, as a
    deliberate best-effort so the remaining KPIs are still reported.
    """
    clip = _ensure_frames(frames)
    if len(clip) < 2:
        return {}
    registry = {
        "sequence_dynamic_range_proxy": sequence_dynamic_range_proxy,
        "fmp_alias": lambda v: fmp_alias(v, fps=fps),
        "mmp_alias": lambda v: mmp_alias(v, fps=fps),
        "temporal_exposure_jitter": temporal_exposure_jitter,
        "flicker_modulation_power": lambda v: flicker_modulation_power(v, fps=fps),
        "modulation_mitigation_probability": modulation_mitigation_probability,
    }
    results: Dict[str, float] = {}
    for key, metric in registry.items():
        try:
            results[key] = metric(clip)
        except Exception:
            results[key] = float('nan')
    return results
