"""
Saliency Visualizer: panels, fg/bg diagnostics, mosaics, suite
Works on outputs saved by src.saliency (npy maps, masks, tokens)
"""
__author__ = "XYZ"

import pdb
import argparse
import csv
import json
import os
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any

import cv2
import numpy as np

from .core._log_ import logger
# Module-level logger obtained from the project's `.core._log_.logger` helper, created with this file's path.
log = logger(__file__)


# ---------------------------
# Patch helpers: per-cell patch canvases & combination
# ---------------------------
def _write_image_variants(out_path: Path, img_bgr: np.ndarray, save_rgb_variant: bool = False) -> None:
  """
  Write an image with OpenCV (which interprets 3-channel arrays as BGR).

  Optionally also write a channel-swapped copy with '.rgb' inserted before the
  extension (e.g. foo.panel.jpg -> foo.panel.rgb.jpg). That copy stores the array in
  RGB order for downstream tools that expect RGB pixel arrays.

  Failures are logged as warnings and never raised (best-effort output). `cv2.imwrite`
  reports many failures (unsupported extension, unwritable path) by returning False
  instead of raising, so its return value is checked explicitly.

  Args:
    out_path: destination path; missing parent directories are created.
    img_bgr: image array in BGR channel order.
    save_rgb_variant: also write the '.rgb' channel-swapped copy.
  """
  params = [cv2.IMWRITE_JPEG_QUALITY, 95]
  try:
    _ensure_dir(out_path.parent)
    if cv2.imwrite(str(out_path), img_bgr, params):
      log.info("Wrote image: %s", out_path)
    else:
      log.warning("Failed to write image %s: cv2.imwrite returned False", out_path)
  except Exception as e:
    log.warning("Failed to write image %s: %s", out_path, e)

  if not save_rgb_variant:
    return
  rgb_path = out_path.with_name(out_path.stem + ".rgb" + out_path.suffix)
  try:
    rgb = img_bgr[..., ::-1].copy()  # BGR -> RGB
    if cv2.imwrite(str(rgb_path), rgb, params):
      log.info("Wrote RGB-variant image: %s", rgb_path)
    else:
      log.warning("Failed to write RGB-variant %s: cv2.imwrite returned False", rgb_path)
  except Exception as e:
    log.warning("Failed to write RGB-variant %s: %s", rgb_path, e)


# ensure _combine_rgb_and_saliency returns BGR (was returning RGB)
def _combine_rgb_and_saliency(patches_np: np.ndarray, salpatches_np: np.ndarray):
  if patches_np is None or salpatches_np is None:
    return None
  p = np.asarray(patches_np).astype(np.float32)
  s = np.asarray(salpatches_np).astype(np.float32)
  # Ensure same N
  N = min(p.shape[0], s.shape[0])
  p = p[:N]
  s = s[:N]
  # Normalize sal per patch to 0..1
  saln = np.zeros_like(s, dtype=np.float32)
  for i in range(s.shape[0]):
    si = s[i]
    mn = float(np.nanmin(si))
    mx = float(np.nanmax(si))
    if mx > mn + 1e-8:
      saln[i] = (si - mn) / (mx - mn)
    else:
      saln[i] = np.zeros_like(si)
  saln = saln[..., None]  # NxHxWx1
  rgb = p.astype(np.float32) / 255.0
  comb = (rgb * saln) * 255.0
  comb_u8 = np.clip(comb, 0, 255).astype(np.uint8)
  # Convert to BGR before returning so all downstream expects BGR
  return comb_u8[..., ::-1].copy()


def _build_patch_canvas_rgb(patches_np: np.ndarray, grid_size: int, preserve_rgb: bool = True):
  """
  patches_np: (N, ph, pw, 3) uint8 (most likely RGB from extract_grid_patches)
  Arrange into canvas (grid_size * ph, grid_size * pw, 3).

  IMPORTANT:
  - When preserve_rgb=True this returns a canvas in **RGB** channel order (no channel swap).
  - When preserve_rgb=False this behaves like the previous helper and returns BGR (swapped)
    which is what OpenCV drawing functions expect.
  """
  if patches_np is None:
    return None
  patches_np = np.asarray(patches_np)
  if patches_np.size == 0:
    return None

  # Ensure at least 4D: (N, H, W, C)
  if patches_np.ndim == 3:
    # Could be single patch HxWx3 or NxHxW grayscale - try to coerce
    if patches_np.shape[-1] == 3:
      patches_np = patches_np[None, ...]
    else:
      # single grayscale patch HxW -> convert to RGB
      patches_np = patches_np[..., None]
      patches_np = np.repeat(patches_np, 3, axis=-1)[None, ...]

  # Ensure dtype uint8 and channels=3
  if patches_np.dtype != np.uint8:
    patches_np = np.clip(patches_np, 0, 255).astype(np.uint8)
  if patches_np.shape[-1] != 3:
    # expand channel dimension if necessary
    patches_np = np.repeat(patches_np[..., None], 3, axis=-1)

  N, ph, pw, c = patches_np.shape
  rows = grid_size
  cols = grid_size
  canvas = np.zeros((rows * ph, cols * pw, 3), dtype=np.uint8)

  # Auto-detect color ordering: if mean red >> mean blue then assume RGB.
  try:
    avg_r = float(patches_np[..., 0].mean())
    avg_b = float(patches_np[..., 2].mean())
    looks_rgb = avg_r > avg_b + 1.0  # small hysteresis
  except Exception:
    looks_rgb = True

  # Decide how to place tiles:
  if looks_rgb:
    patches_use = patches_np.copy()   # preserve RGB ordering
  else:
    patches_use = patches_np[..., ::-1].copy()  # convert BGR->RGB-like for consistency

  # If caller expects BGR (preserve_rgb==False) swap channels before placing
  if not preserve_rgb:
    patches_use = patches_use[..., ::-1].copy()  # RGB -> BGR

  idx = 0
  for r in range(rows):
    for cidx in range(cols):
      if idx < N:
        ph0 = r*ph; ph1 = (r+1)*ph
        pw0 = cidx*pw; pw1 = (cidx+1)*pw
        tile = patches_use[idx]
        # If patch shape mismatches target (rare), try resize with cv2
        if tile.shape[0] != ph or tile.shape[1] != pw:
          try:
            tile = cv2.resize(tile, (pw, ph), interpolation=cv2.INTER_AREA)
          except Exception:
            th, tw = tile.shape[:2]
            t = np.zeros((ph, pw, 3), dtype=np.uint8)
            t[:min(th, ph), :min(tw, pw)] = tile[:min(th, ph), :min(tw, pw)]
            tile = t
        canvas[ph0:ph1, pw0:pw1, :] = tile
      idx += 1
  return canvas


def _build_patch_canvas_gray(salpatches_np: np.ndarray, grid_size: int):
  """
  salpatches_np: (N, ph, pw) float or (N, ph, pw) in [0..1] or [0..255].
  Returns normalized float canvas in [0..1] single-channel (not uint8).
  """
  if salpatches_np is None:
    return None
  salpatches_np = np.asarray(salpatches_np)
  if salpatches_np.size == 0:
    return None
  if salpatches_np.ndim == 2:
    salpatches_np = salpatches_np[None, ...]
  N = salpatches_np.shape[0]
  ph = int(salpatches_np.shape[1])
  pw = int(salpatches_np.shape[2])
  rows = grid_size; cols = grid_size
  canvas = np.zeros((rows * ph, cols * pw), dtype=np.float32)
  idx = 0
  for r in range(rows):
    for cidx in range(cols):
      if idx < N:
        try:
          tile = salpatches_np[idx].astype(np.float32)
          # if tile shape mismatch, resize to (pw, ph)
          if tile.shape[0] != ph or tile.shape[1] != pw:
            tile = cv2.resize(tile, (pw, ph), interpolation=cv2.INTER_LINEAR)
          canvas[r*ph:(r+1)*ph, cidx*pw:(cidx+1)*pw] = tile
        except Exception:
          # skip on error
          pass
      idx += 1

  # Normalize global canvas to [0..1] for consistent display
  mn = float(np.nanmin(canvas))
  mx = float(np.nanmax(canvas))
  if mx > mn + 1e-8:
    canvas = (canvas - mn) / (mx - mn)
  else:
    canvas = np.clip(canvas, 0.0, 1.0)
  return canvas


def _apply_mask_to_patches(patches_np: np.ndarray, maskpatches_np: np.ndarray):
  if patches_np is None or maskpatches_np is None:
    return None
  p = np.asarray(patches_np).astype(np.float32)
  m = np.asarray(maskpatches_np).astype(np.float32)
  if m.max() > 1.1:
    m = m / 255.0
  if m.ndim == 2:
    m = m[None, ...]
  if p.ndim == 3:
    p = p[None, ...]
  # ensure same N
  N = min(p.shape[0], m.shape[0])
  p = p[:N]
  m = m[:N, ..., None]
  out = (p * m)  # still float 0..255
  out = np.clip(out, 0, 255).astype(np.uint8)
  return out


def _combine_rgb_and_saliency(patches_np: np.ndarray, salpatches_np: np.ndarray):
  if patches_np is None or salpatches_np is None:
    return None
  p = np.asarray(patches_np).astype(np.float32)
  s = np.asarray(salpatches_np).astype(np.float32)
  # Ensure same N
  N = min(p.shape[0], s.shape[0])
  p = p[:N]
  s = s[:N]
  # Normalize sal per patch to 0..1
  saln = np.zeros_like(s, dtype=np.float32)
  for i in range(s.shape[0]):
    si = s[i]
    mn = float(np.nanmin(si))
    mx = float(np.nanmax(si))
    if mx > mn + 1e-8:
      saln[i] = (si - mn) / (mx - mn)
    else:
      saln[i] = np.zeros_like(si)
  saln = saln[..., None]  # NxHxWx1
  rgb = p.astype(np.float32) / 255.0
  comb = (rgb * saln) * 255.0
  return np.clip(comb, 0, 255).astype(np.uint8)


def _to_uint8_rgb_from_gray(ary, preserve_rgb: bool = True):
  """
  Convert grayscale or float arrays to uint8 3-channel.
  If preserve_rgb=True it returns **RGB** ordering; otherwise returns BGR (legacy).
  """
  if ary is None:
    return None
  a = np.asarray(ary)
  if a.ndim == 3 and a.shape[2] == 1:
    a = a[..., 0]
  if a.ndim == 2:
    # normalize to 0..1 if float, else assume 0..255
    if np.issubdtype(a.dtype, np.floating):
      mn, mx = float(a.min()), float(a.max())
      if mx > mn + 1e-8:
        a_n = (a - mn) / (mx - mn)
      else:
        a_n = np.clip(a, 0.0, 1.0)
      a_u8 = (a_n * 255.0).round().astype(np.uint8)
    else:
      a_u8 = a.astype(np.uint8)
    rgb = np.stack([a_u8, a_u8, a_u8], axis=-1)
    if preserve_rgb:
      return rgb.copy()
    return rgb[..., ::-1].copy()  # BGR
  # if already HxWx3
  if a.ndim == 3 and a.shape[2] == 3:
    if np.issubdtype(a.dtype, np.floating):
      mn, mx = float(a.min()), float(a.max())
      if mx > mn + 1e-8:
        a_n = (a - mn) / (mx - mn)
      else:
        a_n = np.clip(a, 0.0, 1.0)
      a_u8 = (a_n * 255.0).round().astype(np.uint8)
    else:
      a_u8 = a.astype(np.uint8)
    if preserve_rgb:
      return a_u8.copy()  # assume input RGB
    return a_u8[..., ::-1].copy()  # convert RGB->BGR
  return None


def _array_stats(arr: np.ndarray) -> Dict[str, Any]:
  """Compute rounded statistics for saliency/token arrays."""
  arr = arr.astype(np.float32)
  stats = {
    "shape": list(arr.shape),
    "min": round(float(arr.min()), 4),
    "max": round(float(arr.max()), 4),
    "mean": round(float(arr.mean()), 4),
    "std": round(float(arr.std()), 4),
  }
  return stats

def _token_stats(tokens: np.ndarray, grid: str) -> Dict[str, Any]:
  """Aggregate per-grid token statistics and per-token summary."""
  if tokens.ndim != 2 or tokens.shape[1] < 4:
    return {"grid": grid, "error": "invalid token shape"}
  mean_vals = tokens[:, 0].astype(np.float32)
  max_vals = tokens[:, 1].astype(np.float32)
  com_y = tokens[:, 2].astype(np.float32)
  com_x = tokens[:, 3].astype(np.float32)

  def _s(x: np.ndarray) -> Dict[str, float]:
    return {
      "min": round(float(x.min()), 4),
      "max": round(float(x.max()), 4),
      "mean": round(float(x.mean()), 4),
      "std": round(float(x.std()), 4),
    }

  per_token = []
  for i in range(tokens.shape[0]):
    tv = tokens[i].astype(np.float32)
    per_token.append({
      "idx": i,
      "mean": round(float(tv[0]), 4),
      "max": round(float(tv[1]), 4),
      "com_y": round(float(tv[2]), 4),
      "com_x": round(float(tv[3]), 4),
    })

  stats = {
    "grid": grid,
    "shape": list(tokens.shape),
    "mean_stats": _s(mean_vals),
    "max_stats": _s(max_vals),
    "com_y_stats": _s(com_y),
    "com_x_stats": _s(com_x),
    "mean_avg": round(float(mean_vals.mean()), 4),
    "max_avg": round(float(max_vals.mean()), 4),
    "com_y_avg": round(float(com_y.mean()), 4),
    "com_x_avg": round(float(com_x.mean()), 4),
    "per_token": per_token,
  }
  return stats

def _map_centroid(map_arr: np.ndarray) -> Dict[str, float]:
  """Compute centroid (weighted) of a 2D saliency/map array normalized to [0,1]."""
  a = map_arr.astype(np.float64)
  if a.ndim == 3 and a.shape[2] in (1,):
    a = a[..., 0]
  if a.ndim != 2:
    return {"y": 0.0, "x": 0.0}
  a = np.maximum(a, 0.0)
  s = float(a.sum()) + 1e-12
  H, W = a.shape[:2]
  ys = np.arange(H, dtype=np.float64)
  xs = np.arange(W, dtype=np.float64)
  wy = float((a.sum(axis=1) @ ys) / s)
  wx = float((a.sum(axis=0) @ xs) / s)
  # normalize to [0,1]
  ny = round(float(wy / max(1, H - 1)), 4)
  nx = round(float(wx / max(1, W - 1)), 4)
  return {"y": ny, "x": nx}

def _map_stats(map_arr: np.ndarray) -> Dict[str, Any]:
  """Return the {shape, min, max, mean, std} summary of a map (as float32) plus its weighted 'centroid'."""
  as_f32 = map_arr.astype(np.float32)
  summary = _array_stats(as_f32)
  summary["centroid"] = _map_centroid(as_f32)
  return summary

def _safe_load_npy(p: Path) -> Optional[np.ndarray]:
  """Load npy if exists, return None otherwise."""
  try:
    if p.exists():
      return np.load(str(p))
  except Exception as e:
    log.warning(f"Failed loading npy {p}: {e}")
  return None

# =========================
# Filesystem / misc utils
# =========================

def _ensure_dir(p: Path) -> None:
  p.mkdir(parents=True, exist_ok=True)

def _str2bool(v) -> bool:
  if isinstance(v, bool): return v
  return str(v).strip().lower() in {"1","true","t","yes","y","on"}

def _parse_stem_and_class(filename: str) -> Tuple[str, Optional[str]]:
  """
  Accepts: 'img_13144.c7.overlay.jpg' or 'img_13144.c7.norm.phi_relu.npy'
  Returns ('img_13144', 'c7') or ('img_13144', None)
  """
  stem = Path(filename).stem
  parts = stem.split(".")
  if len(parts) >= 2 and parts[1].startswith("c"):
    return parts[0], parts[1]
  return parts[0], None


# =========================
# Normalization & overlays
# =========================

def _normalize01(x: np.ndarray, mode: str = "phi_relu") -> np.ndarray:
  x = x.astype(np.float32)
  if mode == "none":
    return x
  if mode == "minmax":
    mn, mx = float(x.min()), float(x.max())
    return (x - mn) / (mx - mn + 1e-8)
  if mode == "zscore":
    mu, sd = float(x.mean()), float(x.std()) + 1e-8
    return (x - mu) / sd
  if mode == "phi_relu":
    xp = np.maximum(x, 0.0)
    mx = float(xp.max()) + 1e-8
    return xp / mx
  if mode == "pdf":
    xp = np.maximum(x, 0.0)
    s = float(xp.sum()) + 1e-8
    return xp / s
  raise ValueError(f"Unknown normalize mode: {mode}")

def _colorize_01(x01: np.ndarray) -> np.ndarray:
  """
  Map a [0, 1] scalar map to a BGR uint8 image with OpenCV's JET colormap.

  Values outside [0, 1] are clipped. NaN and -inf are mapped to 0 and +inf to 1 before
  the uint8 cast; casting NaN to uint8 is otherwise undefined.
  """
  x = np.nan_to_num(np.asarray(x01), nan=0.0, posinf=1.0, neginf=0.0)
  x = (x * 255.0).clip(0, 255).astype(np.uint8)
  # OpenCV 4.x expects positional 2nd arg
  return cv2.applyColorMap(x, cv2.COLORMAP_JET)

def _overlay_heat_on_bgr(base_bgr: np.ndarray, heat01: np.ndarray, alpha: float = 0.55) -> np.ndarray:
  """
  Blend a JET-colorized [0, 1] heat map onto a BGR image.

  Computes base * 1.0 + colorized_heat * alpha with cv2.addWeighted. The base is NOT
  attenuated by (1 - alpha). `heat01` must already match the image height/width.
  """
  colored = _colorize_01(heat01)
  blended = cv2.addWeighted(base_bgr, 1.0, colored, alpha, 0.0)
  return blended

def _load_base_bgr(group_dir: Path, companions: Dict[str, Path], fallback_hw=(224,224)) -> np.ndarray:
  if "overlay" in companions:
    img = cv2.imread(str(companions["overlay"]), 1)
    if img is not None:
      return img
  if "smap_raw" in companions:
    smap = np.load(str(companions["smap_raw"])).astype(np.float32)
    h, w = smap.shape[:2]
    return np.zeros((h, w, 3), dtype=np.uint8)
  h, w = fallback_hw[1], fallback_hw[0]
  return np.zeros((h, w, 3), dtype=np.uint8)

def _visualize_mask_on(base: np.ndarray, mask01: np.ndarray, show_border=True) -> np.ndarray:
  """
  Tint the masked region of a BGR image green and optionally outline it.

  Args:
    base: (H, W, 3) uint8 BGR image; it is not modified.
    mask01: 2D mask in [0, 1] ((H, W, 1) is also accepted).
      - It is resized with nearest-neighbour interpolation when its size differs from
        `base`, the same way map overlays are resized to the image.
      - It is clipped to [0, 1] (NaN -> 0) before scaling, so out-of-range values cannot
        wrap around in the uint8 cast.
    show_border: draw a 1px yellow (BGR 0,255,255) contour around mask > 0.5.

  Returns:
    New BGR uint8 image = base + 0.35 * green tint (cv2.addWeighted).
  """
  out_base = base.copy()
  h, w = out_base.shape[:2]
  m = np.nan_to_num(np.asarray(mask01).astype(np.float32), nan=0.0)
  if m.ndim == 3 and m.shape[2] == 1:
    m = m[..., 0]
  if m.shape[:2] != (h, w):
    m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)
  m = np.clip(m, 0.0, 1.0)
  green = np.zeros_like(out_base)
  green[..., 1] = (m * 255.0).astype(np.uint8)
  out = cv2.addWeighted(out_base, 1.0, green, 0.35, 0.0)
  if show_border:
    cnts, _ = cv2.findContours((m > 0.5).astype(np.uint8),
                               cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(out, cnts, -1, (0,255,255), 1, cv2.LINE_AA)
  return out

def _overlay_map_on(base_bgr: np.ndarray, map_arr: np.ndarray, norm: str = "phi_relu", alpha: float = 0.55) -> np.ndarray:
  """
  Overlay a saliency map on a BGR image as a heat map.

  Steps: normalize `map_arr` with _normalize01(norm); bilinearly resize it to the image
  height/width when they differ; then blend it onto `base_bgr` via _overlay_heat_on_bgr.
  """
  heat = _normalize01(map_arr, norm)
  map_h, map_w = heat.shape[:2]
  img_h, img_w = base_bgr.shape[:2]
  needs_resize = (map_h, map_w) != (img_h, img_w)
  if needs_resize:
    heat = cv2.resize(heat, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
  return _overlay_heat_on_bgr(base_bgr, heat, alpha)


# =========================
# Companions lookup
# =========================

def _token_file(group_dir: Path, patt: str, norm: str, part: str, grid_label: str) -> Optional[Path]:
  p = group_dir / "tokens" / part / norm / grid_label / f"{patt}.tokens.npy"
  return p if p.exists() else None


def _find_companions(group_dir: Path, stem: str, cprefix: Optional[str], norm: str) -> Dict[str, Path]:
  """
  Extended companion lookup. Adds keys for: tokens_<part>_<grid> (existing),
  and patches files if present:
    tokens_<part>_<grid>      -> <stem>.<cX>.tokens.npy
    patches_<part>_<grid>     -> <stem>.<cX>.patches.npy
    salpatches_<part>_<grid>  -> <stem>.<cX>.salpatches.npy
    maskpatches_<part>_<grid> -> <stem>.<cX>.maskpatches.npy
    meta_<part>_<grid>        -> <stem>.<cX>.meta.json
  """
  out: Dict[str, Path] = {}
  patt = f"{stem}.{cprefix}" if cprefix else stem

  # same as before: overlay, smap_raw, whole_norm, mask_npy, part_norms
  ov = group_dir / f"{patt}.overlay.jpg"
  if ov.exists(): out["overlay"] = ov

  smap_raw = group_dir / f"{patt}.smap.npy"
  if smap_raw.exists(): out["smap_raw"] = smap_raw

  whole_norm = group_dir / f"{patt}.norm.{norm}.npy"
  if whole_norm.exists(): out["whole_norm"] = whole_norm

  mnp = group_dir / "masks" / f"{patt}.mask.npy"
  if mnp.exists(): out["mask_npy"] = mnp

  for part in ("whole", "fg", "bg"):
    pth = group_dir / part / "norm" / norm / f"{patt}.npy"
    if pth.exists():
      out[f"{part}_norm_part"] = pth

  # tokens & patches: look under tokens/<part>/<norm>/<grid>/
  for part in ("fg", "bg", "whole"):
    for grid in ("4x4", "8x8"):
      tok_p = group_dir / "tokens" / part / norm / grid / f"{patt}.tokens.npy"
      if tok_p.exists():
        out[f"tokens_{part}_{grid}"] = tok_p
      # patch files stored in same dir with same base name
      patches_p = group_dir / "tokens" / part / norm / grid / f"{patt}.patches.npy"
      salpatches_p = group_dir / "tokens" / part / norm / grid / f"{patt}.salpatches.npy"
      maskpatches_p = group_dir / "tokens" / part / norm / grid / f"{patt}.maskpatches.npy"
      meta_p = group_dir / "tokens" / part / norm / grid / f"{patt}.meta.json"
      if patches_p.exists(): out[f"patches_{part}_{grid}"] = patches_p
      if salpatches_p.exists(): out[f"salpatches_{part}_{grid}"] = salpatches_p
      if maskpatches_p.exists(): out[f"maskpatches_{part}_{grid}"] = maskpatches_p
      if meta_p.exists(): out[f"meta_{part}_{grid}"] = meta_p

      # fallback: older layout: tokens/<part>/<norm>/ files without grid folder
      tok_p2 = group_dir / "tokens" / part / norm / f"{patt}.tokens.npy"
      if tok_p2.exists() and f"tokens_{part}_{grid}" not in out:
        out[f"tokens_{part}_{grid}"] = tok_p2
      patches_p2 = group_dir / "tokens" / part / norm / f"{patt}.patches.npy"
      if patches_p2.exists() and f"patches_{part}_{grid}" not in out:
        out[f"patches_{part}_{grid}"] = patches_p2
      salpatches_p2 = group_dir / "tokens" / part / norm / f"{patt}.salpatches.npy"
      if salpatches_p2.exists() and f"salpatches_{part}_{grid}" not in out:
        out[f"salpatches_{part}_{grid}"] = salpatches_p2
      maskpatches_p2 = group_dir / "tokens" / part / norm / f"{patt}.maskpatches.npy"
      if maskpatches_p2.exists() and f"maskpatches_{part}_{grid}" not in out:
        out[f"maskpatches_{part}_{grid}"] = maskpatches_p2
      meta_p2 = group_dir / "tokens" / part / norm / f"{patt}.meta.json"
      if meta_p2.exists() and f"meta_{part}_{grid}" not in out:
        out[f"meta_{part}_{grid}"] = meta_p2

  return out


# =========================
# Token visualization
# =========================

def _draw_tokens_grid(
    base: np.ndarray,
    tokens: np.ndarray,         # (G,4) = [mean, max, com_y, com_x]
    grid_label: str,            # "HxW"
    mode: str = "mean",
    base_alpha: float = 0.15,
    line_thick: int = 2,
    show_com: bool = True,
    show_grid: bool = True
) -> np.ndarray:
  """
  Draw a token grid over `base`: each of the gh x gw cells is filled with a color
  encoding its token value. A COM marker and black grid lines can optionally be added.

  Args:
    base: (H, W, 3) BGR image. It sets the output size and is the blend background.
    tokens: (G, 4) array [mean, max, com_y, com_x], row-major over cells, with G >= gh*gw.
      The color scale spans the finite values of all G rows.
    grid_label: "<gh>x<gw>", e.g. "4x4".
    mode: "mean" colors by column 0; any other value colors by column 1 (max).
    base_alpha: weight of `base` in the final blend; the fill gets 1 - base_alpha.
    line_thick: grid line thickness in pixels.
    show_com: draw a white diamond at each cell's center of mass.
      - COM values below 1 are read as fractions of the cell.
      - If either coordinate is >= 1, both are read as legacy pixel offsets in the cell.
      - Non-finite COMs are skipped.
    show_grid: draw black lines between cells.

  Returns:
    New image with the same size as `base`.

  Raises:
    ValueError: if `tokens` is not 2D, lacks the needed columns, or has fewer rows than
      gh*gw cells. Previously this surfaced as an opaque IndexError.
  """
  H, W = base.shape[:2]
  gh, gw = [int(x) for x in grid_label.lower().split("x")]
  ph, pw = max(1, H // gh), max(1, W // gw)

  tokens = np.asarray(tokens)
  need_cols = 4 if show_com else (1 if mode == "mean" else 2)
  if tokens.ndim != 2 or tokens.shape[1] < need_cols or tokens.shape[0] < gh * gw:
    raise ValueError(
      f"tokens of shape {tokens.shape} do not fit a {grid_label} grid "
      f"(need >= {gh * gw} rows and >= {need_cols} columns)")

  fill = np.zeros_like(base)
  vals = tokens[:, 0] if mode == "mean" else tokens[:, 1]
  # Color scale over finite values only; non-finite cells are drawn at the low end.
  finite = vals[np.isfinite(vals)]
  vmin = float(finite.min()) if finite.size else 0.0
  vmax = (float(finite.max()) if finite.size else 0.0) + 1e-8

  idx = 0
  for i in range(gh):
    for j in range(gw):
      y0, y1 = i*ph, min((i+1)*ph, H)
      x0, x1 = j*pw, min((j+1)*pw, W)
      v = float((vals[idx] - vmin) / (vmax - vmin))
      if not np.isfinite(v):
        v = 0.0
      color = (int(255*(1.0 - v)), int(255*v), 0)  # BGR: low -> blue, high -> green
      cv2.rectangle(fill, (x0, y0), (x1-1, y1-1), color, thickness=-1)

      if show_com:
        cyf, cxf = float(tokens[idx, 2]), float(tokens[idx, 3])
        if np.isfinite(cyf) and np.isfinite(cxf):
          if cyf >= 1.0 or cxf >= 1.0:  # legacy: pixel coords
            cyf = cyf / max(1.0, (y1 - y0))
            cxf = cxf / max(1.0, (x1 - x0))
          cy = int(round(y0 + cyf * max(1, (y1 - y0) - 1)))
          cx = int(round(x0 + cxf * max(1, (x1 - x0) - 1)))
          cv2.drawMarker(fill, (cx, cy), color=(255, 255, 255),
                         markerType=cv2.MARKER_DIAMOND, markerSize=8, thickness=2, line_type=cv2.LINE_AA)
      idx += 1

  out = cv2.addWeighted(base, base_alpha, fill, 1.0 - base_alpha, 0.0)

  if show_grid:
    for i in range(1, gh):
      y = i * ph
      cv2.line(out, (0, y), (W-1, y), (0, 0, 0), line_thick, cv2.LINE_AA)
    for j in range(1, gw):
      x = j * pw
      cv2.line(out, (x, 0), (x, H-1), (0, 0, 0), line_thick, cv2.LINE_AA)

  return out


# =========================
# Header & labels
# =========================

def _make_header_bar(labels: List[str], total_width: int, tile_count: int, bar_h: int = 32) -> np.ndarray:
  header = 255 * np.ones((bar_h, total_width, 3), dtype=np.uint8)
  if tile_count <= 0:
    return header
  col_w = max(1, total_width // tile_count)
  for i, lab in enumerate(labels[:tile_count]):
    x = i * col_w + 6
    y = int(bar_h * 0.72)
    cv2.putText(header, lab, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
  return header

def _attach_left_label(img: np.ndarray, text: str, pad: int = 64) -> np.ndarray:
  """Prepend a white `pad`-pixel-wide strip carrying `text` (black, x=8) to the left of `img`."""
  height = img.shape[0]
  strip = np.full((height, pad, 3), 255, dtype=np.uint8)
  text_y = int(min(28, height * 0.6))
  cv2.putText(strip, text, (8, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,0), 2, cv2.LINE_AA)
  return cv2.hconcat([strip, img])


# =========================
# Row builder (no text drawn on tiles; tiles are size-normalized before hconcat)
# =========================

def _fit_panel_tile(t: Optional[np.ndarray], H: int, W: int) -> np.ndarray:
  """Coerce one panel tile to (H, W, 3) uint8 and add a 2px white border."""
  if t is None:
    tt = np.ones((H, W, 3), dtype=np.uint8) * 255
  else:
    # 2D gray -> 3 channels
    if t.ndim == 2:
      t = cv2.cvtColor((np.clip(t, 0, 255).astype(np.uint8)), cv2.COLOR_GRAY2BGR)
    try:
      tt = cv2.resize(t, (W, H), interpolation=cv2.INTER_AREA)
    except Exception:
      # fallback: paste (cropping if needed) onto a white canvas
      tt = np.ones((H, W, 3), dtype=np.uint8) * 255
      h0, w0 = min(H, t.shape[0]), min(W, t.shape[1])
      tt[:h0, :w0] = t[:h0, :w0, :3]
    if tt.dtype != np.uint8:
      tt = np.clip(tt, 0, 255).astype(np.uint8)
    if tt.ndim == 2:  # e.g. a single-channel tile after resize
      tt = cv2.cvtColor(tt, cv2.COLOR_GRAY2BGR)
    if tt.shape[2] != 3:
      tt = tt[..., :3]
  return cv2.copyMakeBorder(tt, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=(255,255,255))


def _build_panel_row(group_dir: Path, filename: str, norm: str) -> Tuple[Optional[np.ndarray], List[str], Dict[str, Any]]:
  """
  Build the single horizontal row of diagnostic tiles for `filename`.

  Tile order (each present only when its companion file exists and loads):
    img, mask, whole@<norm>, fg@<norm>, bg@<norm>,
    <part>-tokens@<grid> for part in (fg, bg, whole) and grid in (4x4, 8x8).
  Patch canvases are NOT part of this row (see _build_patch_only_row).

  Map companions are overlaid with norm="none", i.e. as stored on disk; `norm` only
  selects which files are looked up and labels them.

  Guarantees:
   - each tile is 3-channel uint8;
   - all tiles are resized to the largest tile's (W, H) and given a 2px white border
     before hconcat.

  Returns:
    (row, labels, meta). row is the concatenated image, or None when no companions were
    found. labels names each tile in order, and meta = {"stem", "class_prefix", "labels"}.
  """
  stem, cprefix = _parse_stem_and_class(filename)
  comps = _find_companions(group_dir, stem, cprefix, norm=norm)
  if not comps:
    log.warning("No companions found for %s in %s", filename, group_dir)
    return None, [], {}

  base = _load_base_bgr(group_dir, comps)
  if base is None:
    base = np.zeros((224, 224, 3), dtype=np.uint8)
  tiles: List[np.ndarray] = []
  labels: List[str] = []

  # 1) image (BGR already from _load_base_bgr)
  tiles.append(base.copy()); labels.append("img")

  # 2) mask visualization
  if "mask_npy" in comps:
    try:
      m = _safe_load_npy(comps["mask_npy"])
      if m is not None:
        m01 = (m > 0.5).astype(np.float32)
        tiles.append(_visualize_mask_on(base, m01, show_border=True)); labels.append("mask")
    except Exception as e:
      log.warning("Failed to load/visualize mask for %s: %s", filename, e)

  # 3) whole_norm
  if "whole_norm" in comps:
    try:
      arr = _safe_load_npy(comps["whole_norm"])
      if arr is not None:
        tiles.append(_overlay_map_on(base, arr, norm="none", alpha=0.55)); labels.append(f"whole@{norm}")
    except Exception as e:
      log.warning("Failed to load overlay whole_norm for %s: %s", filename, e)

  # 4) per-part norm overlays (fg/bg)
  for part in ("fg", "bg"):
    key = f"{part}_norm_part"
    if key in comps:
      try:
        arr = _safe_load_npy(comps[key])
        if arr is not None:
          tiles.append(_overlay_map_on(base, arr, norm="none", alpha=0.55)); labels.append(f"{part}@{norm}")
      except Exception as e:
        log.warning("Failed to load overlay %s for %s: %s", key, filename, e)

  # 5) tokens' grid visualizations
  for part in ("fg", "bg", "whole"):
    for grid in ("4x4", "8x8"):
      tkey = f"tokens_{part}_{grid}"
      if tkey in comps:
        try:
          toks = _safe_load_npy(comps[tkey])
          if toks is not None:
            tiles.append(_draw_tokens_grid(base, toks.astype(np.float32), grid, mode="mean"))
            labels.append(f"{part}-tokens@{grid}")
        except Exception as e:
          log.warning("Failed to draw tokens grid for %s (%s): %s", filename, tkey, e)

  if not tiles:
    return None, [], {}

  # Unify every tile to the largest height/width so hconcat is safe.
  sized = [t for t in tiles if t is not None and t.ndim >= 2]
  H = max((t.shape[0] for t in sized), default=base.shape[0])
  W = max((t.shape[1] for t in sized), default=base.shape[1])
  row = cv2.hconcat([_fit_panel_tile(t, H, W) for t in tiles])
  meta = {"stem": stem, "class_prefix": cprefix, "labels": labels}
  return row, labels, meta



# =========================
# Public renderers
# =========================

def render_panel(group_dir: Path, filename: str, out_dir: Path, norm: str = "phi_relu",
                 save_rgb: bool = True) -> Optional[Path]:
  """
  Render the diagnostic panel (label header above the tile row) for one image.

  Writes <out_dir>/<stem>.<class prefix lower-cased, or 'c?'>.panel.jpg, a '.rgb'
  channel-swapped copy when `save_rgb` is set, and the .panel.json / .panel.md sidecars.

  Returns:
    Path of the panel image, or None when no tile row could be built.
  """
  parsed_stem, parsed_cls = _parse_stem_and_class(filename)
  companions = _find_companions(group_dir, parsed_stem, parsed_cls, norm=norm)

  row, labels, meta = _build_panel_row(group_dir, filename, norm)
  if row is None:
    return None

  header = _make_header_bar(labels, row.shape[1], len(labels), bar_h=32)
  panel = cv2.vconcat([header, row])

  _ensure_dir(out_dir)
  cls_tag = (meta.get("class_prefix") or "c?").lower()
  out_img = out_dir / f"{meta.get('stem')}.{cls_tag}.panel.jpg"

  # primary (BGR) image plus, if requested, the channel-swapped '.rgb' copy
  _write_image_variants(out_img, panel, save_rgb_variant=save_rgb)
  # sidecars list the companions found for this file
  _write_panel_sidecars(out_img, labels, norm, meta, companions, save_rgb=save_rgb)
  log.info(f"Wrote panel: {out_img}")
  return out_img


def _write_panel_sidecars(out_img: Path, labels: List[str], norm: str,
                          meta: Dict[str, Any], comps: Dict[str, Path],
                          save_rgb: bool = True) -> None:
  """
  Write the JSON and Markdown sidecars describing a rendered panel.

  For out_img = <name>.panel.jpg this writes two files:
    - <name>.panel.json: panel name, columns, legend text, companion file paths, and
      statistics for each loadable companion;
    - <name>.panel.md: a short legend.
  Both are written as UTF-8 because the legend text contains non-ASCII characters
  such as '→' and '×'.

  Args:
    out_img: path of the written panel image.
    labels: column labels in tile order.
    norm: normalization name recorded in the sidecars.
    meta: row metadata from _build_panel_row.
    comps: companion files (key -> path) from _find_companions; None is treated as empty.
    save_rgb: also record the name of the '.rgb' panel variant.
  """
  comps = comps or {}
  base = out_img.with_suffix("")
  # ensure file strings
  files_map = {k: str(v) for k, v in comps.items()}
  info = {
    "panel": out_img.name,
    "image": out_img.name,
    "normalization": norm,
    "columns": labels,
    "explain": {
      "mask": "Foreground mask (green) with yellow contour.",
      "maps": "CAM heatmaps (blue→red) overlaid on the image; 'whole', 'fg', and 'bg' correspond to entire map, foreground-only, and background-only.",
      "tokens": "Image partitioned into an H×W grid; each patch shows mean saliency; a white dot marks the patch Center of Mass (COM).",
      "tokens_order": "Patches are indexed row-major: top-left is 0, then left→right, top→bottom.",
      "token_vec": "Each token is 4 numbers: [mean, max, com_y, com_x] in patch coordinates."
    },
    "files": files_map,
    "meta": meta,
    "stats": {},
  }

  # Add RGB variant path if requested
  if save_rgb:
    try:
      rgb_name = out_img.with_name(out_img.stem + ".rgb" + out_img.suffix).name
      info["panel_rgb"] = rgb_name
      info["files"]["panel_rgb"] = rgb_name
    except ValueError:
      # with_name rejects an empty/invalid name; the variant is then simply not recorded
      pass

  stats = {}
  # gather stats for available companions
  if "mask_npy" in comps:
    p = Path(comps["mask_npy"])
    arr = _safe_load_npy(p)
    if arr is not None:
      m = (arr > 0.5).astype(np.float32)
      ms = _map_stats(m)
      ms["file"] = p.name
      stats["mask"] = ms

  map_keys = [k for k in comps.keys() if k.endswith("_norm_part") or k == "whole_norm" or k == "smap_raw"]
  for k in map_keys:
    p = Path(comps[k])
    arr = _safe_load_npy(p)
    if arr is None:
      continue
    ms = _map_stats(arr)
    ms["file"] = p.name
    stats[k] = ms

  token_keys = [k for k in comps.keys() if k.startswith("tokens_")]
  for tk in token_keys:
    p = Path(comps[tk])
    arr = _safe_load_npy(p)
    if arr is None:
      continue
    parts = tk.split("_")
    grid_label = parts[-1] if len(parts) >= 3 else ""
    tstats = _token_stats(arr.astype(np.float32), grid_label)
    tstats["file"] = p.name
    stats[tk] = tstats

  info["stats"] = stats

  def _jsonable(o: Any) -> Any:
    # last-resort serializer: numpy scalars/arrays expose .tolist(); anything else -> str
    return o.tolist() if hasattr(o, "tolist") else str(o)

  # Serialize before opening the file so a TypeError cannot leave a half-written sidecar.
  try:
    payload = json.dumps(info, indent=2)
  except TypeError:
    payload = json.dumps(info, indent=2, default=_jsonable)
  with open(base.with_suffix(".panel.json"), "w", encoding="utf-8") as f:
    f.write(payload)

  # also write a small markdown explanation as before
  md = [
    f"# Panel: {out_img.name}",
    "",
    f"- normalization: `{norm}`",
    f"- columns: {', '.join(labels)}",
    "",
    "## What am I seeing?",
    "- **Mask**: foreground (green) + yellow contour.",
    "- **Maps**: CAM heatmaps (blue→red) blended on the image for `whole`, `fg`, and `bg`.",
    "- **Tokens**: image split into an H×W grid; each patch is colored by **mean** saliency; a white dot marks the patch **Center of Mass** (COM).",
    "",
  ]
  with open(base.with_suffix(".panel.md"), "w", encoding="utf-8") as f:
    f.write("\n".join(md))



def validate_fg_bg(group_dir: Path, filename: str, out_dir: Path, norm: str = "phi_relu",
                   save_rgb: bool = True) -> Optional[Path]:
  """
  Write a 2x2 fg/bg diagnostic mosaic for one image.

  Tiles are placed row-major, each only if its companion loads: whole-map overlay,
  mask visualization, fg-map overlay, bg-map overlay. Remaining slots are filled white
  and every tile is resized to the largest tile size. Companions are loaded through
  _safe_load_npy (as in _build_panel_row), so a missing or corrupt file only drops
  its tile instead of aborting the whole diagnostic.

  Returns:
    Path of <out_dir>/<stem>.<class prefix lower-cased, or 'c?'>.fgbg.jpg, or None when
    no companions or no tiles are available.
  """
  stem, cprefix = _parse_stem_and_class(filename)
  comps = _find_companions(group_dir, stem, cprefix, norm=norm)
  if not comps:
    log.warning("No companions found for %s in %s", filename, group_dir)
    return None

  base = _load_base_bgr(group_dir, comps)
  tiles: List[np.ndarray] = []

  whole = _safe_load_npy(comps["whole_norm"]) if "whole_norm" in comps else None
  if whole is not None:
    tiles.append(_overlay_map_on(base, whole.astype(np.float32), norm="none", alpha=0.55))

  mask = _safe_load_npy(comps["mask_npy"]) if "mask_npy" in comps else None
  if mask is not None:
    m = (mask.astype(np.float32) > 0.5).astype(np.float32)
    tiles.append(_visualize_mask_on(base, m))

  for part in ("fg", "bg"):
    k = f"{part}_norm_part"
    part_map = _safe_load_npy(comps[k]) if k in comps else None
    if part_map is not None:
      tiles.append(_overlay_map_on(base, part_map.astype(np.float32), norm="none", alpha=0.55))

  if not tiles:
    return None

  while len(tiles) < 4:
    tiles.append(np.full_like(base, 255, dtype=np.uint8))
  H = max(t.shape[0] for t in tiles)
  W = max(t.shape[1] for t in tiles)
  tiles = [cv2.resize(t, (W, H)) for t in tiles]
  grid2x2 = cv2.vconcat([cv2.hconcat(tiles[:2]), cv2.hconcat(tiles[2:4])])

  _ensure_dir(out_dir)
  out_path = out_dir / f"{stem}.{(cprefix or 'c?').lower()}.fgbg.jpg"

  _write_image_variants(out_path, grid2x2, save_rgb_variant=save_rgb)

  log.info("Wrote fg/bg diagnostic: %s", out_path)
  return out_path


def render_group_panels(group_dir: Path, out_dir: Path, norm: str = "phi_relu", limit: int = 0) -> None:
  """
  Render one panel per '*.overlay.jpg' found in `group_dir` into `out_dir`.

  Overlays are processed in sorted name order; when limit > 0, only the first `limit`
  are used. At the end, an info line is logged if any panel was written, otherwise
  a warning.
  """
  candidates = sorted(group_dir.glob("*.overlay.jpg"))
  if limit > 0:
    candidates = candidates[:limit]
  _ensure_dir(out_dir)
  written = [render_panel(group_dir, ov.name, out_dir, norm=norm) for ov in candidates]
  if any(path is not None for path in written):
    log.info(f"Wrote panels to: {out_dir}")
  else:
    log.warning(f"No panels were created under: {group_dir}")

def _build_patch_only_row(group_dir: Path, filename: str, norm: str) -> Tuple[Optional[np.ndarray], List[str], Dict[str, Any]]:
  """
  Build one horizontal strip containing only the patch-level canvases of an image.

  Tile order: for part in (fg, bg, whole), for grid in (4x4, 8x8), whichever exist:
    - ``<part>-patches@<grid>``: RGB patches                 (companion ``patches_<part>_<grid>``)
    - ``<part>-sal@<grid>``:     saliency patches, gray -> 3ch (``salpatches_<part>_<grid>``)
    - ``<part>-masked@<grid>``:  RGB patches with masks applied (also needs ``maskpatches_<part>_<grid>``)
    - ``<part>-comb@<grid>``:    RGB combined with saliency
  A (part, grid) pair is skipped when neither patches nor saliency patches exist. A single
  tile is skipped when its builder raises (logged as a warning) or returns None.

  Every tile is resized to the largest tile height/width (INTER_AREA), cast to uint8 and
  given a 2px white border before horizontal concatenation. Canvases are built with
  ``preserve_rgb=True``.
  NOTE(review): the combined tile comes from ``_combine_rgb_and_saliency``, which per its
  own comment returns BGR. Its channel order may therefore differ from the other tiles; confirm.

  Returns ``(row, labels, meta)`` with ``meta = {"stem", "class_prefix", "labels"}``, or
  ``(None, [], {})`` when no companions or no tiles are found.
  """
  stem, cprefix = _parse_stem_and_class(filename)
  comps = _find_companions(group_dir, stem, cprefix, norm=norm)
  if not comps:
    return None, [], {}

  tiles: List[np.ndarray] = []
  labels: List[str] = []

  def _add(canvas: Optional[np.ndarray], label: str) -> None:
    # Builders may return None. Keeping None out of `tiles` keeps the size scan below safe.
    if canvas is not None:
      tiles.append(canvas)
      labels.append(label)

  def _load(key: str) -> Optional[np.ndarray]:
    return _safe_load_npy(comps[key]) if key in comps else None

  for part in ("fg", "bg", "whole"):
    for grid in ("4x4", "8x8"):
      side = int(grid.split("x")[0])  # grid side passed to the canvas builders (4 for "4x4")
      pkey = f"patches_{part}_{grid}"
      spkey = f"salpatches_{part}_{grid}"
      mkey = f"maskpatches_{part}_{grid}"
      patches = _load(pkey)
      salpatches = _load(spkey)
      maskpatches = _load(mkey)

      if patches is None and salpatches is None:
        continue

      # RGB
      if patches is not None:
        try:
          _add(_build_patch_canvas_rgb(patches, side, preserve_rgb=True), f"{part}-patches@{grid}")
        except Exception as e:
          log.warning("patch-only: failed RGB canvas %s: %s", pkey, e)

      # saliency
      if salpatches is not None:
        try:
          _add(_to_uint8_rgb_from_gray(_build_patch_canvas_gray(salpatches, side), preserve_rgb=True),
               f"{part}-sal@{grid}")
        except Exception as e:
          log.warning("patch-only: failed sal canvas %s: %s", spkey, e)

      # masked
      if patches is not None and maskpatches is not None:
        try:
          masked = _apply_mask_to_patches(patches, maskpatches)
          if masked is not None:
            _add(_build_patch_canvas_rgb(masked, side, preserve_rgb=True), f"{part}-masked@{grid}")
        except Exception as e:
          log.warning("patch-only: failed masked canvas %s: %s", mkey, e)

      # combined
      if patches is not None and salpatches is not None:
        try:
          comb = _combine_rgb_and_saliency(patches, salpatches)
          if comb is not None:
            _add(_build_patch_canvas_rgb(comb, side, preserve_rgb=True), f"{part}-comb@{grid}")
        except Exception as e:
          log.warning("patch-only: failed combined canvas %s: %s", pkey, e)

  if not tiles:
    return None, [], {}

  # Normalize every tile to the same H/W, then hconcat.
  H = max(t.shape[0] for t in tiles)
  W = max(t.shape[1] for t in tiles)
  pads = []
  for t in tiles:
    try:
      tt = cv2.resize(t, (W, H), interpolation=cv2.INTER_AREA)
    except Exception:
      # Fallback: paste the tile's top-left corner onto a white 3-channel canvas.
      # 2-D tiles are broadcast over channels; extra channels are dropped.
      tt = np.full((H, W, 3), 255, dtype=np.uint8)
      src = t[..., None] if t.ndim == 2 else t[..., :3]
      h0 = min(H, t.shape[0])
      w0 = min(W, t.shape[1])
      tt[:h0, :w0] = src[:h0, :w0]
    if tt.dtype != np.uint8:
      tt = np.clip(tt, 0, 255).astype(np.uint8)
    pads.append(cv2.copyMakeBorder(tt, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=(255, 255, 255)))

  row = cv2.hconcat(pads)
  meta = {"stem": stem, "class_prefix": cprefix, "labels": labels}
  return row, labels, meta


# =========================
# Mosaics
# =========================

def build_class_mosaic(class_dir: Path, out_dir: Path, norm: str = "phi_relu",
                       per_class_k: int = 8) -> Optional[Path]:
  """Write the panel mosaic and the patches-only mosaic for one class directory.

  Uses the first ``per_class_k`` ``*.overlay.jpg`` files (sorted by name) in ``class_dir``:
    - ``<out_dir>/<class>.mosaic.jpg``: one ``_build_panel_row`` row per overlay, stacked
      under a header bar, with the class name attached on the left.
    - ``<out_dir>/<class>.patches.mosaic.jpg`` (plus its ``.rgb`` variant): the same for
      ``_build_patch_only_row`` rows. It is written only if at least one such row exists, and
      it is flipped RGB->BGR before the OpenCV write.
  Rows are stretched to the widest row and given a 4px white margin. Header labels are
  taken from the last row that rendered.

  Returns the main mosaic path, or None when there are no overlays or no rows.
  """
  candidates = sorted(class_dir.glob("*.overlay.jpg"))
  if not candidates:
    log.warning(f"No overlays in class dir: {class_dir}")
    return None
  selected = candidates[:per_class_k]

  def _assemble(row_list: List[np.ndarray], header_labels: List[str]) -> np.ndarray:
    widest = max(r.shape[1] for r in row_list)
    framed = [
      cv2.copyMakeBorder(cv2.resize(r, (widest, r.shape[0])), 4, 4, 4, 4,
                         cv2.BORDER_CONSTANT, value=(255, 255, 255))
      for r in row_list
    ]
    body = cv2.vconcat(framed)
    bar = _make_header_bar(header_labels, body.shape[1], len(header_labels), bar_h=36)
    return _attach_left_label(cv2.vconcat([bar, body]), class_dir.name)

  panel_rows: List[np.ndarray] = []
  header: List[str] = []
  for overlay in selected:
    row, row_labels, _ = _build_panel_row(class_dir, overlay.name, norm=norm)
    if row is None:
      continue
    panel_rows.append(row)
    header = row_labels

  if not panel_rows:
    log.warning(f"No rows rendered for {class_dir}")
    return None

  mosaic = _assemble(panel_rows, header)
  _ensure_dir(out_dir)
  out_path = out_dir / f"{class_dir.name}.mosaic.jpg"
  cv2.imwrite(str(out_path), mosaic, [cv2.IMWRITE_JPEG_QUALITY, 95])
  log.info(f"Wrote class mosaic: {out_path}")

  # Second mosaic from the patch-only rows of the same overlays.
  patch_rows: List[np.ndarray] = []
  patch_header: List[str] = []
  for overlay in selected:
    prow, plabels, _ = _build_patch_only_row(class_dir, overlay.name, norm=norm)
    if prow is not None:
      patch_rows.append(prow)
      patch_header = plabels

  if patch_rows:
    pmosaic = _assemble(patch_rows, patch_header)
    # Patch rows are treated as RGB-ordered. Flip to BGR for OpenCV and also keep the RGB variant.
    out_path2 = out_dir / f"{class_dir.name}.patches.mosaic.jpg"
    _write_image_variants(out_path2, pmosaic[..., ::-1].copy(), save_rgb_variant=True)
    log.info(f"Wrote class patches-mosaic: {out_path2} (+ RGB variant)")

  return out_path


def build_overall_mosaic(group_dir: Path, out_dir: Path, norm: str = "phi_relu", limit: int = 0) -> Optional[Path]:
  """Write the panel mosaic and the patches-only mosaic for a whole group directory.

  Uses the ``*.overlay.jpg`` files in ``group_dir`` (sorted; the first ``limit`` when
  ``limit > 0``):
    - ``<out_dir>/overall.mosaic.jpg``: one ``_build_panel_row`` row per overlay, each labelled
      on the left with its class prefix ("c?" when unknown), under a header bar.
    - ``<out_dir>/overall.patches.mosaic.jpg`` (plus its ``.rgb`` variant): the same for
      ``_build_patch_only_row`` rows, written only when at least one such row exists. It is
      handled like ``build_class_mosaic`` does: the RGB-ordered strip is flipped to BGR for
      OpenCV, and the RGB-preserved variant is kept too.
  Rows are stretched to the widest row and given a 4px white margin.

  Returns the main mosaic path, or None when there are no overlays or no rows.
  """
  overlays = sorted(group_dir.glob("*.overlay.jpg"))
  if limit > 0:
    overlays = overlays[:limit]
  if not overlays:
    log.warning(f"No overlays in {group_dir}")
    return None

  def _left(meta: Dict[str, Any]) -> str:
    # The class prefix may be present but None/empty; `.get(key, default)` would
    # return that falsy value, so fall back explicitly.
    return meta.get("class_prefix") or "c?"

  def _stack(rows: List[np.ndarray], header_labels: List[str], lefts: List[str]) -> np.ndarray:
    # Stretch to the widest row, add a margin, attach per-row left labels, then the header bar.
    width = max(r.shape[1] for r in rows)
    framed = [cv2.copyMakeBorder(cv2.resize(r, (width, r.shape[0])), 4, 4, 4, 4,
                                 cv2.BORDER_CONSTANT, value=(255, 255, 255)) for r in rows]
    framed = [_attach_left_label(r, left) for r, left in zip(framed, lefts)]
    body = cv2.vconcat(framed)
    header = _make_header_bar(header_labels, body.shape[1], len(header_labels), bar_h=36)
    return cv2.vconcat([header, body])

  rows: List[np.ndarray] = []
  labels: List[str] = []
  lefts: List[str] = []
  for ov in overlays:
    row, row_labels, meta = _build_panel_row(group_dir, ov.name, norm=norm)
    if row is not None:
      rows.append(row)
      labels = row_labels
      lefts.append(_left(meta))

  if not rows:
    return None

  header_labels = labels if labels else ["img", "mask", f"whole@{norm}"]
  mosaic = _stack(rows, header_labels, lefts)

  _ensure_dir(out_dir)
  out_path = out_dir / "overall.mosaic.jpg"
  cv2.imwrite(str(out_path), mosaic, [cv2.IMWRITE_JPEG_QUALITY, 95])
  log.info(f"Wrote overall mosaic: {out_path}")

  # patches-only mosaic
  patch_rows: List[np.ndarray] = []
  patch_lefts: List[str] = []
  patch_labels: List[str] = []
  for ov in overlays:
    prow, plabels, meta = _build_patch_only_row(group_dir, ov.name, norm=norm)
    if prow is not None:
      patch_rows.append(prow)
      patch_lefts.append(_left(meta))
      patch_labels = plabels

  if patch_rows:
    pmosaic = _stack(patch_rows, patch_labels, patch_lefts)
    out_path2 = out_dir / "overall.patches.mosaic.jpg"
    _write_image_variants(out_path2, pmosaic[..., ::-1].copy(), save_rgb_variant=True)
    log.info(f"Wrote overall patches mosaic: {out_path2} (+ RGB variant)")

  return out_path


# =========================
# Per-class panels (per-image)
# =========================

def render_per_class_panels(class_root: Path, out_root: Path, norm: str = "phi_relu") -> None:
  """Render per-image panels for every class sub-directory of ``class_root``.

  Each sub-directory ``<class_root>/<cK>`` (sorted by name) has one panel rendered per
  ``*.overlay.jpg`` (sorted) via ``render_panel``, written to ``<out_root>/per_class/<cK>/``.
  Non-directory entries of ``class_root`` are ignored.
  """
  subdirs = [entry for entry in sorted(class_root.iterdir()) if entry.is_dir()]
  for subdir in subdirs:
    target = out_root / "per_class" / subdir.name
    _ensure_dir(target)
    for overlay_path in sorted(subdir.glob("*.overlay.jpg")):
      render_panel(subdir, overlay_path.name, target, norm=norm)


# =========================
# Indices export for downstream
# =========================

def _augment_tokens_row(row: Dict[str, str], arch: str, method: str) -> Dict[str, str]:
  group = row.get("group","")
  parts = group.split("/")
  correctness = parts[0] if parts else ""
  group_type = parts[1] if len(parts) > 1 else ""
  class_id = parts[2] if len(parts) > 2 else ""

  out = dict(row)
  out.update({
    "arch": arch,
    "method": method,
    "correctness": correctness,
    "group_type": group_type,          # overall | per_class
    "class_id": class_id,              # cK or ""
  })
  return out

def _clean_name_to_stem(name: str) -> str:
  """Normalize a filename (overlay/image/panel) into a canonical stem to search for npy files."""
  s = str(name).strip()
  # drop directory path
  s = Path(s).name
  # common suffixes produced by your pipeline
  for suf in (".overlay.jpg", ".overlay.png", ".overlay", ".panel.jpg",
              ".fgbg.jpg", ".mosaic.jpg", ".jpg", ".jpeg", ".png"):
    if s.lower().endswith(suf):
      s = s[: -len(suf)]
  # also drop trailing class prefix if it's like 'img_123.c7' -> keep 'img_123'
  parts = s.split(".")
  if len(parts) >= 2 and parts[1].startswith("c") and parts[1][1:].isdigit():
    s = parts[0]
  return s

def _infer_npy_path_from_row(row: Dict[str, str], sal_arch: Path, norm: str = "phi_relu", verbose: bool = False) -> Optional[Path]:
  """Heuristic lookup for .npy companion given a CSV row.

  Uses CSV columns first (stem, class_prefix, part, norm, grid) to build exact candidate
  filenames/paths that match this repo's output layout. Falls back to limited rglob only if needed.
  """
  tried: List[str] = []

  # prefer explicit stem/class_prefix fields if present
  stem = (row.get("stem") or row.get("file") or row.get("filename") or "").strip()
  class_prefix = (row.get("class_prefix") or row.get("class") or "").strip()
  part = (row.get("part") or "").strip()
  grid = (row.get("grid") or "").strip()
  row_norm = (row.get("norm") or norm or "").strip()

  # normalize: if full filename passed, strip extensions to get stem
  def _strip_common(s: str) -> str:
    s = Path(s).name
    for suf in (".overlay.jpg", ".overlay.png", ".panel.jpg", ".fgbg.jpg",
                ".mosaic.jpg", ".jpg", ".jpeg", ".png"):
      if s.lower().endswith(suf):
        s = s[: -len(suf)]
    return s

  stem = _strip_common(stem)

  stems = []
  if stem:
    stems.append(stem)
    if class_prefix:
      stems.append(f"{stem}.{class_prefix}")
  # also collect other possible candidate names from other CSV fields
  for k in ("file", "filename", "image", "overlay", "tokens", "token_file", "map", "smap"):
    v = (row.get(k) or "").strip()
    if not v:
      continue
    vv = _strip_common(v)
    if vv not in stems:
      stems.append(vv)

  # 1) Try direct fields that might already be .npy paths
  keys_to_try = ["npy", "tokens_file", "map_file", "file", "filename", "path"]
  for k in keys_to_try:
    v = (row.get(k) or "").strip()
    if not v:
      continue
    tried.append(f"direct:{k}='{v}'")
    p = Path(v)
    if p.suffix.lower() == ".npy":
      if p.is_absolute() and p.exists():
        if verbose: log.debug(f"infer: direct absolute npy for {k}: {p}")
        return p
      cand = sal_arch / p
      if cand.exists():
        if verbose: log.debug(f"infer: direct relative npy for {k}: {cand}")
        return cand

  # 2) For each stem candidate try exact companion filenames (most likely locations)
  common_suffixes = [
    ".smap.npy",
    f".norm.{row_norm}.npy" if row_norm else f".norm.{norm}.npy",
    ".npy",
    ".tokens.npy",
  ]
  for s in stems:
    for suf in common_suffixes:
      cand = sal_arch / f"{s}{suf}"
      tried.append(str(cand))
      if cand.exists():
        if verbose: log.debug(f"infer: matched companion {cand} for stem '{s}'")
        return cand

    # masks live under masks/ sometimes
    m_cand = sal_arch / "masks" / f"{s}.mask.npy"
    tried.append(str(m_cand))
    if m_cand.exists():
      if verbose: log.debug(f"infer: matched mask {m_cand} for stem '{s}'")
      return m_cand

    # whole norm maps often live at sal_arch/<stem>.norm.<norm>.npy (already tried above)
    # token files often live under tokens/<part>/<norm>/<grid>/
    if part and row_norm and grid:
      tok_cand = sal_arch / "tokens" / part / row_norm / grid / f"{s}.tokens.npy"
      tried.append(str(tok_cand))
      if tok_cand.exists():
        if verbose: log.debug(f"infer: matched tokens path {tok_cand} for stem '{s}'")
        return tok_cand
      # older layout: tokens/<part>/<norm>/<s>.tokens.npy (without grid)
      tok_cand2 = sal_arch / "tokens" / part / row_norm / f"{s}.tokens.npy"
      tried.append(str(tok_cand2))
      if tok_cand2.exists():
        if verbose: log.debug(f"infer: matched tokens path (no-grid) {tok_cand2} for stem '{s}'")
        return tok_cand2

  # 3) Look in tokens dir using pattern "<stem>*.tokens.npy" (limited rglob)
  for s in stems:
    pattern = f"*{s}*.tokens.npy"
    tried.append(f"rglob:tokens:{pattern}")
    for match in sal_arch.rglob(pattern):
      if "saliency-analysis" in str(match):
        continue
      if match.exists():
        if verbose: log.debug(f"infer: rglob tokens match {match} for stem '{s}'")
        return match

  # 4) Generic fallback rglob for <stem>*.npy (stop at first)
  for s in stems:
    pattern = f"*{s}*.npy"
    tried.append(f"rglob:generic:{pattern}")
    for match in sal_arch.rglob(pattern):
      if "saliency-analysis" in str(match):
        continue
      if match.exists():
        if verbose: log.debug(f"infer: rglob generic match {match} for stem '{s}'")
        return match

  if verbose:
    log.debug(f"infer: no npy found for row; tried: {tried}")
  return None


def _flatten_map_stats(p: Path, cache: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
  """Return flattened, CSV-ready stats for a saliency-map .npy file.

  Results are memoized in ``cache`` under the resolved path string. Each call returns a
  fresh copy, so callers can add keys (e.g. ``npy_relpath``) without altering the cache.

  Output keys: shape (JSON list of ints), min, max, mean, std, centroid_y, centroid_x
  (all from ``_map_stats``), count (always "" for maps), npy_file, and npy_relpath
  ("" here; callers fill it). Float and integer values, including NumPy scalars such as
  float32, are stringified: floats to 4 decimals. Other values pass through unchanged.
  Returns {} when the array cannot be loaded.
  """
  key = str(p.resolve())
  if key in cache:
    return dict(cache[key])
  arr = _safe_load_npy(p)
  if arr is None:
    cache[key] = {}
    return {}
  ms = _map_stats(arr)
  # Tolerate keys that are missing or explicitly None.
  cent = ms.get("centroid") or {}
  # ints() so json.dumps also accepts NumPy integer dims.
  shape_list = [int(dim) for dim in (ms.get("shape") or [])]

  out: Dict[str, Any] = {
    "shape": json.dumps(shape_list),
    "min": ms.get("min", ""),
    "max": ms.get("max", ""),
    "mean": ms.get("mean", ""),
    "std": ms.get("std", ""),
    "centroid_y": cent.get("y", ""),
    "centroid_x": cent.get("x", ""),
    "count": "",                     # maps don't have 'count' in this context
    "npy_file": Path(p).name,
    "npy_relpath": "",
  }
  for k, v in list(out.items()):
    # np.float32 is not a subclass of float, so test np.floating explicitly.
    if isinstance(v, (float, np.floating)):
      out[k] = f"{round(float(v), 4):.4f}"
    elif isinstance(v, (int, np.integer)):
      out[k] = str(v)
  cache[key] = out
  return dict(out)


def _flatten_token_stats(p: Path, cache: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
  """Return flattened, CSV-ready stats for a token .npy file.

  Results are memoized in ``cache`` under the resolved path string. Each call returns a
  fresh copy, so callers can add keys (e.g. ``npy_relpath``) without altering the cache.

  Output keys:
    shape       JSON list of ints from ``_token_stats(...)["shape"]``
    min/max/mean/std  taken from ``_token_stats(...)["mean_stats"]`` (presumably summary
                stats over per-token means; confirm in ``_token_stats``)
    centroid_y/centroid_x  ``com_y_avg`` / ``com_x_avg``
    count       first dimension of shape (number of tokens, assuming tokens are stacked
                along axis 0; verify)
    npy_file, npy_relpath ("" here; callers fill it)
  Float and integer values, including NumPy scalars, are stringified: floats to 4 decimals.
  Returns {} when the array cannot be loaded.
  """
  key = str(p.resolve())
  if key in cache:
    return dict(cache[key])
  arr = _safe_load_npy(p)
  if arr is None:
    cache[key] = {}
    return {}
  ts = _token_stats(arr.astype(np.float32), grid="")

  # ints() so json.dumps also accepts NumPy integer dims.
  shape_list = [int(dim) for dim in (ts.get("shape") or [])]
  mean_stats = ts.get("mean_stats") or {}

  out: Dict[str, Any] = {
    "shape": json.dumps(shape_list),
    "min": mean_stats.get("min", ""),
    "max": mean_stats.get("max", ""),
    "mean": mean_stats.get("mean", ""),
    "std": mean_stats.get("std", ""),
    "centroid_y": ts.get("com_y_avg", ""),
    "centroid_x": ts.get("com_x_avg", ""),
    "count": str(shape_list[0]) if shape_list else "",
    "npy_file": Path(p).name,
    "npy_relpath": "",
  }
  for k, v in list(out.items()):
    # np.float32 is not a subclass of float, so test np.floating explicitly.
    if isinstance(v, (float, np.floating)):
      out[k] = f"{round(float(v), 4):.4f}"
    elif isinstance(v, (int, np.integer)):
      out[k] = str(v)
  cache[key] = out
  return dict(out)


def export_indices_to_analysis(sal_root: Path, arch: str, method: str) -> None:
  """Copy + augment saliency CSV indices into ``<sal_root.parent>/saliency-analysis/<method>/index/``.

  Supported input layouts:
     old: <base_logs>/saliency/<arch>/<method>/...   (CSVs and .npy under sal_root/<arch>)
     new: <base_logs>/<model_id>/saliency/<method>/... (CSVs and .npy under sal_root)
  The old layout is chosen when ``sal_root/<arch>/<method>`` exists.

  Inputs (each optional; a missing file is logged and skipped):
    saliency_tokens_index.csv, saliency_maps_index.csv
  Only rows whose ``method`` column equals ``method`` (case-insensitive) are kept.

  Outputs:
    tokens.index.csv; tokens.part-<part>.grid-<grid>.csv (or tokens.part-<part>.csv when the
    grid is empty); maps.index.csv

  Each output row contains:
    - the input columns;
    - arch/method/correctness/group_type/class_id, parsed from ``group`` by ``_augment_tokens_row``;
    - pred/gt/confidence, each the first non-empty value among several alias columns, with
      confidence formatted to 4 decimals when numeric;
    - a fixed stats schema (npy_file, npy_relpath, npy_abspath, shape, min, max, mean, std,
      centroid_y, centroid_x, count), filled from the inferred .npy companion and "" when
      none is found.
  """
  sal_root = Path(sal_root)

  # Old layout keeps CSVs and companions under sal_root/<arch>; new layout directly under sal_root.
  sal_arch = sal_root / arch if (sal_root / arch / method).exists() else sal_root

  # NOTE(review): this path has no arch/model-id component, so in the old layout several
  # archs under one sal_root write into the same index directory.
  anal_index = sal_root.parent / "saliency-analysis" / method / "index"
  _ensure_dir(anal_index)

  src_tokens = sal_arch / "saliency_tokens_index.csv"
  src_maps = sal_arch / "saliency_maps_index.csv"

  npy_cache: Dict[str, Dict[str, Any]] = {}

  # Stats schema in output column order. Every augmented row receives all of these keys.
  empty_stats: Dict[str, str] = dict.fromkeys(
    ("npy_file", "npy_relpath", "npy_abspath", "shape", "min", "max", "mean", "std",
     "centroid_y", "centroid_x", "count"), "")
  pred_keys = ["pred", "prediction", "predicted", "pred_class", "pred_label",
               "predicted_class", "top1", "label_pred"]
  gt_keys = ["gt", "label", "target", "true_label", "ground_truth", "class"]
  conf_keys = ["conf", "confidence", "score", "prob", "probability", "softmax_top1"]

  def _write_csv(dst: Path, rows: List[Dict[str, str]]) -> None:
    if not rows:
      log.info(f"No rows to write to {dst}")
      return
    # Union of all row keys in first-seen order. DictWriter raises ValueError for keys
    # missing from fieldnames, and rows read from a CSV can carry different columns.
    fieldnames = list(dict.fromkeys(k for r in rows for k in r))
    with open(dst, "w", newline="") as f:
      w = csv.DictWriter(f, fieldnames=fieldnames)
      w.writeheader()
      w.writerows(rows)
    log.info(f"Wrote augmented index CSV: {dst}")

  def _pick_first(row: Dict[str, str], keys: List[str]) -> str:
    for k in keys:
      v = (row.get(k) or "").strip()
      if v != "":
        return v
    return ""

  def _format_conf(v: str) -> str:
    if v is None or v == "":
      return ""
    try:
      return f"{round(float(v), 4):.4f}"
    except Exception:
      return v

  def _read_method_rows(src: Path) -> List[Dict[str, str]]:
    # newline="" is what the csv module expects for file objects it reads.
    with open(src, "r", newline="") as f:
      return [r for r in csv.DictReader(f) if (r.get("method") or "").lower() == method.lower()]

  def _augment(r: Dict[str, str], flatten: Any) -> Dict[str, str]:
    """Augment one row; ``flatten`` is _flatten_token_stats or _flatten_map_stats."""
    rr = _augment_tokens_row(r, arch, method)
    rr["pred"] = _pick_first(r, pred_keys)
    rr["gt"] = _pick_first(r, gt_keys)
    rr["confidence"] = _format_conf(_pick_first(r, conf_keys))
    rr.update(empty_stats)

    npy_path = _infer_npy_path_from_row(r, sal_arch)
    if npy_path:
      # Copy so the shared cache entry is never mutated.
      stats = dict(flatten(npy_path, npy_cache))
      try:
        rel = str(npy_path.relative_to(sal_arch))
      except ValueError:
        rel = str(npy_path)
      stats["npy_relpath"] = rel
      stats["npy_file"] = npy_path.name
      stats["npy_abspath"] = str(npy_path.resolve())
      rr.update(stats)
    return rr

  # --- Tokens ---
  if not src_tokens.exists():
    log.info(f"Tokens source not found: {src_tokens}")
  else:
    rows = _read_method_rows(src_tokens)
    if not rows:
      log.info(f"No token rows matched method='{method}' in {src_tokens}")
    else:
      aug_rows = [_augment(r, _flatten_token_stats) for r in rows]
      _write_csv(anal_index / "tokens.index.csv", aug_rows)

      # split by part/grid for convenience
      by_pg: Dict[Tuple[str, str], List[Dict[str, str]]] = {}
      for r in aug_rows:
        by_pg.setdefault((r.get("part") or "", r.get("grid") or ""), []).append(r)
      for (part, grid), rlist in by_pg.items():
        name = f"tokens.part-{part}.grid-{grid}.csv" if grid else f"tokens.part-{part}.csv"
        _write_csv(anal_index / name, rlist)

  # --- Maps ---
  if not src_maps.exists():
    log.info(f"Maps source not found: {src_maps}")
  else:
    rows = _read_method_rows(src_maps)
    if not rows:
      log.info(f"No map rows matched method='{method}' in {src_maps}")
    else:
      aug_maps = [_augment(r, _flatten_map_stats) for r in rows]
      _write_csv(anal_index / "maps.index.csv", aug_maps)


# =========================
# Suite (single combo)
# =========================

def run_viz_suite_from_context(context: Dict[str, Any], arch: Optional[str] = None,
                               method: Optional[str] = None, correctness: str = "both",
                               norm: str = "phi_relu", k: int = 8, limit_overall: int = 0) -> Dict[str, Any]:
  """
  Pipeline-friendly entry point: run the visualizer suite from values stored in ``context``.

  - Saliency root: the first truthy of context['out_root'], ['saliency_dir'], ['to_path'].
    Raises ValueError when ``context`` is not a dict or has none of these keys.
  - Model id: ``arch`` if given, else context['model_id'], else the parent folder name of
    the saliency root ("" when that parent does not exist).
  - Methods:
    * ``method`` of None, empty or "all": every sub-directory of ``<root>/<model_id>`` when
      it exists, otherwise of ``<root>``. If that directory is missing, a single run with
      method "all" is made.
    * otherwise a comma-separated list of names (blank entries ignored; "all" if none remain).
  Each method is run via ``run_viz_suite``. A failure is logged and the next method continues.
  Returns ``context`` unchanged.
  """
  sal_root = (
    (context.get("out_root") or context.get("saliency_dir") or context.get("to_path"))
    if isinstance(context, dict) else None
  )
  if sal_root is None:
    raise ValueError("Context does not contain 'out_root' / 'saliency_dir' / 'to_path' to locate saliency files")

  sal_root = Path(sal_root)
  parent_name = sal_root.parent.name if sal_root.parent.exists() else ""
  model_id = arch or context.get("model_id") or parent_name

  requested = (method or "").strip()
  if not requested or requested.lower() == "all":
    # Old layout nests methods under <root>/<model_id>; new layout keeps them under <root>.
    nested = sal_root / model_id
    candidate_dir = nested if nested.exists() else sal_root
    if not candidate_dir.exists():
      log.warning(f"Couldn't discover methods for model_id='{model_id}': candidate_dir '{candidate_dir}' does not exist. Falling back to single call with method='all'")
      methods = ["all"]
    else:
      methods = sorted(entry.name for entry in candidate_dir.iterdir() if entry.is_dir())
      log.info(f"Discovered methods for model_id='{model_id}' under '{candidate_dir}': {methods}")
  else:
    methods = [name.strip() for name in requested.split(",") if name.strip()] or ["all"]

  for meth in methods:
    try:
      run_viz_suite(sal_root, model_id, meth, correctness, norm=norm, k=k, limit_overall=limit_overall)
    except Exception as e:
      # Deliberately best-effort: keep going so the caller's pipeline continues.
      log.error(f"Visualization suite failed for method='{meth}': {e}")
  return context



def run_viz_suite_one(sal_root: Path, arch: str, method: str, correctness: str,
                      norm: str = "phi_relu", k: int = 8, limit_overall: int = 0) -> None:
  """Render panels, mosaics and CSV indices for one (arch, method, correctness) combination.

  Input layout: ``sal_root/<arch>/<method>`` (old) when it exists, else ``sal_root/<method>``
  (new). Inside it, ``<correctness>/overall`` is required and ``<correctness>/per_class`` is
  optional. The function returns early with a warning when the method dir or the overall
  group is missing.

  Outputs under ``<sal_root.parent>/saliency-analysis/<method>/<correctness>/``:
    overall/           per-image panels and the overall mosaics for the ``overall`` group
    per_class/<cK>/    per-image panels for each class
    mosaics/           one class mosaic per class directory
  It then exports the augmented CSV indices via ``export_indices_to_analysis``.
  """
  sal_root = Path(sal_root)

  if (sal_root / arch / method).exists():
    method_dir = sal_root / arch / method   # old layout
  elif (sal_root / method).exists():
    method_dir = sal_root / method          # new layout: sal_root is model_base/saliency
  else:
    log.warning(f"[skip] method dir not found for arch='{arch}' under sal_root='{sal_root}'. Tried '{sal_root/arch/method}' and '{sal_root/method}'")
    return

  corr_dir = method_dir / correctness
  overall = corr_dir / "overall"
  class_root = corr_dir / "per_class"

  if not overall.exists():
    log.warning(f"[skip] overall group not found: {overall}")
    return

  # Analysis root sits next to 'saliency'.
  # NOTE(review): no arch/model-id component, so in the old layout different archs under one
  # sal_root share the same analysis directory.
  analysis_root = sal_root.parent / "saliency-analysis" / method / correctness
  out_overall = analysis_root / "overall"
  out_mosaics = analysis_root / "mosaics"
  for d in (analysis_root, out_overall, out_mosaics):
    _ensure_dir(d)

  # Overall group: per-image panels + mosaics
  render_group_panels(overall, out_overall, norm=norm, limit=limit_overall)
  build_overall_mosaic(overall, out_overall, norm=norm, limit=limit_overall)

  if class_root.exists():
    render_per_class_panels(class_root, analysis_root, norm=norm)  # per-image panels
    for cdir in sorted(d for d in class_root.iterdir() if d.is_dir()):
      build_class_mosaic(cdir, out_mosaics, norm=norm, per_class_k=k)
  else:
    log.warning(f"No per_class directory under: {corr_dir}")

  # Export indices for downstream
  export_indices_to_analysis(sal_root, arch, method)


# =========================
# Suite (arch/method wildcards)
# =========================

def _list_archs(sal_root: Path, arch: str) -> List[str]:
  if arch.lower() == "all":
    return sorted([d.name for d in Path(sal_root).iterdir() if d.is_dir()])
  return [arch]

def _list_methods(arch_dir: Path, method: str) -> List[str]:
  if method.lower() == "all":
    return sorted([d.name for d in arch_dir.iterdir() if d.is_dir()])
  return [method]

def _list_corr(correctness: str) -> List[str]:
  if correctness.lower() in ("both","all"):
    return ["correct","incorrect"]
  return [correctness]

def run_viz_suite(sal_root: Path, arch: str, method: str, correctness: str,
                  norm: str = "phi_relu", k: int = 8, limit_overall: int = 0) -> None:
  """
  Run ``run_viz_suite_one`` for every selected (arch, method, correctness) combination.

  ``arch`` and ``method`` accept "all" (expand to sub-directories); ``correctness`` accepts
  "both"/"all" (expand to correct + incorrect).

  Layouts:
    OLD:  <base_logs>/saliency/<arch>/<method>/...
    NEW:  <base_logs>/<model_id>/saliency/<method>/...
  For each arch, ``sal_root/<arch>`` is used when it is an existing directory (old layout).
  Otherwise methods are listed directly under ``sal_root`` (new layout). An arch is skipped
  with a warning only when ``sal_root`` itself is not a directory either.
  """
  root = Path(sal_root)

  for arch_name in _list_archs(root, arch):
    nested = root / arch_name
    use_nested = nested.exists() and nested.is_dir()
    arch_dir = nested if use_nested else root
    layout = "old" if use_nested else "new-fallback"

    if not (arch_dir.exists() and arch_dir.is_dir()):
      log.warning(f"[skip] arch dir not found for arch='{arch_name}'. Tried '{nested}' and fallback '{root}'.")
      continue

    for meth in _list_methods(arch_dir, method):
      for corr in _list_corr(correctness):
        log.info(f"[suite] layout={layout} arch={arch_name} method={meth} corr={corr} (sal_root={root})")
        run_viz_suite_one(root, arch_name, meth, corr, norm=norm, k=k, limit_overall=limit_overall)



def main(args: argparse.Namespace) -> None:
  """
  Dispatch a parsed CLI namespace (as produced by parse_args) to the matching renderer.

  Sub-commands:
    panel          -> render_panel; without --filename, the lexicographically first
                      *.overlay.jpg in --group_dir is used.
    validate       -> validate_fg_bg
    panels         -> render_group_panels
    class_mosaic   -> build_class_mosaic for every sub-directory of --class_root
    overall_mosaic -> build_overall_mosaic
    suite          -> run_viz_suite

  Conditions that make a sub-command impossible (no overlay to pick, missing
  --class_root, unknown command) are logged as errors and the function returns.
  """
  cmd = args.cmd
  save_rgb = getattr(args, "save_rgb", False)

  # These sub-commands accept --save-rgb on the CLI but it is not passed to their
  # renderers below; say so instead of silently dropping the flag.
  if save_rgb and cmd in ("class_mosaic", "overall_mosaic", "suite"):
    log.warning("--save-rgb is not forwarded for '%s'; ignoring it", cmd)

  if cmd == "panel":
    if not args.filename:
      # Deterministic default: the lexicographically first overlay in the group.
      ov = min(args.group_dir.glob("*.overlay.jpg"), default=None)
      if ov is None:
        log.error("No overlays in --group_dir; provide --filename")
        return
      args.filename = ov.name
    render_panel(args.group_dir, args.filename, args.out, norm=args.norm, save_rgb=args.save_rgb)

  elif cmd == "validate":
    validate_fg_bg(args.group_dir, args.filename, args.out, norm=args.norm, save_rgb=args.save_rgb)

  elif cmd == "panels":
    render_group_panels(args.group_dir, args.out, norm=args.norm, limit=args.limit, save_rgb=args.save_rgb)

  elif cmd == "class_mosaic":
    class_root = Path(args.class_root)
    if not class_root.is_dir():
      # Previously iterdir() raised FileNotFoundError here.
      log.error("--class_root is not an existing directory: %s", class_root)
      return
    class_dirs = [d for d in sorted(class_root.iterdir()) if d.is_dir()]
    _ensure_dir(args.out)
    for cdir in class_dirs:
      build_class_mosaic(cdir, args.out, norm=args.norm, per_class_k=args.k)

  elif cmd == "overall_mosaic":
    build_overall_mosaic(args.group_dir, args.out, norm=args.norm, limit=args.limit)

  elif cmd == "suite":
    run_viz_suite(args.sal_root, args.arch, args.method, args.correctness,
                  norm=args.norm, k=args.k, limit_overall=args.limit_overall)

  else:
    # argparse enforces a valid sub-command; this only triggers for programmatic callers.
    log.error("Unknown command: %s", cmd)

# =========================
# CLI
# =========================

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
  """
  Build the saliency-visualization CLI and parse its arguments.

  Args:
    argv: arguments to parse, excluding the program name. None (the default) makes
      argparse read sys.argv[1:], so existing no-argument callers are unaffected.

  Returns:
    The parsed namespace. ``cmd`` holds the chosen sub-command, and every sub-command
    exposes a boolean ``save_rgb``.
  """
  p = argparse.ArgumentParser(description="Saliency visualization: panels, fg/bg, mosaics, suite")
  sub = p.add_subparsers(dest="cmd", required=True)

  def _add_save_rgb(sp: argparse.ArgumentParser) -> None:
    # Shared by every sub-command. '--save_rgb' is an alias that matches the underscore
    # style of the other flags; '--save-rgb' keeps working and dest stays 'save_rgb'.
    sp.add_argument("--save-rgb", "--save_rgb", dest="save_rgb", action="store_true",
                    help="Also write an RGB-preserved variant (*.rgb.jpg)")

  p1 = sub.add_parser("panel", help="Render a panel for one image inside a group")
  p1.add_argument("--group_dir", type=Path, required=True)
  p1.add_argument("--filename", type=str, required=False)
  p1.add_argument("--norm", type=str, default="phi_relu")
  p1.add_argument("--out", type=Path, required=True)
  _add_save_rgb(p1)

  p2 = sub.add_parser("validate", help="2x2 diagnostic for fg/bg sanity")
  p2.add_argument("--group_dir", type=Path, required=True)
  p2.add_argument("--filename", type=str, required=True)
  p2.add_argument("--norm", type=str, default="phi_relu")
  p2.add_argument("--out", type=Path, required=True)
  _add_save_rgb(p2)

  p3 = sub.add_parser("panels", help="Batch: panel for each overlay in a group")
  p3.add_argument("--group_dir", type=Path, required=True)
  p3.add_argument("--norm", type=str, default="phi_relu")
  p3.add_argument("--limit", type=int, default=0)
  p3.add_argument("--out", type=Path, required=True)
  _add_save_rgb(p3)

  p4 = sub.add_parser("class_mosaic", help="Build mosaics for per_class/*")
  p4.add_argument("--class_root", type=Path, required=True)
  p4.add_argument("--norm", type=str, default="phi_relu")
  p4.add_argument("--k", type=int, default=8)
  p4.add_argument("--out", type=Path, required=True)
  _add_save_rgb(p4)

  p5 = sub.add_parser("overall_mosaic", help="Build overall mosaic for a group (adds row class labels)")
  p5.add_argument("--group_dir", type=Path, required=True)
  p5.add_argument("--norm", type=str, default="phi_relu")
  p5.add_argument("--limit", type=int, default=0)
  p5.add_argument("--out", type=Path, required=True)
  _add_save_rgb(p5)

  p6 = sub.add_parser("suite", help="Panels + overall mosaic + per-class mosaics + indices export")
  p6.add_argument("--sal_root", type=Path, required=True, help=".../saliency")
  p6.add_argument("--arch", type=str, required=True, help="'all' or a specific arch folder")
  p6.add_argument("--method", type=str, required=True, help="'all' or a specific method folder")
  p6.add_argument("--correctness", type=str, required=True,
                  help="'correct'|'incorrect'|'both' ('all' is accepted as a synonym for 'both')")
  p6.add_argument("--norm", type=str, default="phi_relu")
  p6.add_argument("--k", type=int, default=8)
  p6.add_argument("--limit_overall", type=int, default=0)
  _add_save_rgb(p6)

  return p.parse_args(argv)

def print_args(args: argparse.Namespace) -> None:
  """Log a header line, then one ``name: value`` line per parsed argument, in namespace order."""
  entries = vars(args)
  log.info("Arguments:")
  for name in entries:
    log.info(f"{name}: {entries[name]}")


if __name__ == "__main__":
  # Script entry: parse the CLI, echo the resolved arguments to the log, then dispatch.
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
