"""
Saliency module: CAM generation → optional fg/bg → normalization → patch tokens
Standalone CLI + SAGE-orchestrator wrappers.
"""
__author__ = "XYZ"

import pdb
import argparse
import inspect
import json
import glob
import math
import os
import re
import csv


from datetime import datetime
from importlib import import_module
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

import cv2
import numpy as np
import torch

from tqdm import tqdm

from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

from .core._log_ import logger
log = logger(__file__)

from .utils.torchapi import loadmodel, unloadmodel, get_target_layers
from .utils import yoloapi  ## API stays array-based; file I/O lives here


## ---------------------------
## Filesystem / misc utils
## ---------------------------

def _str2bool(v):
  if isinstance(v, bool): return v
  s = str(v).strip().lower()
  return s in {"1","true","t","yes","y","on"}

def _parse_size(s: str) -> Tuple[int, int]:
  try:
    val = eval(s)
    if isinstance(val, tuple) and len(val) == 2:
      return int(val[0]), int(val[1])
  except Exception:
    pass
  raise ValueError("input_size must look like '(224,224)'")

def _parse_grids(glist: List[str]) -> List[List[int]]:
  out = []
  for g in glist:
    if isinstance(g, str) and "x" in g:
      a, b = g.split("x")
      out.append([int(a), int(b)])
    elif isinstance(g, (list, tuple)) and len(g) == 2:
      out.append([int(g[0]), int(g[1])])
  return out

def _ensure_dir(p: Path) -> None:
  p.mkdir(parents=True, exist_ok=True)

def _read_index_file(tsv_path: Path) -> List[Dict[str, Any]]:
  items = []
  if not tsv_path.exists():
    return items
  for ln in tsv_path.read_text().splitlines():
    parts = ln.strip().split("\t")
    if len(parts) < 5:
      continue
    idx, relpath, gt, pr, conf = parts[:5]
    items.append({"idx": int(idx), "path": relpath, "gt": int(gt), "pr": int(pr), "conf": float(conf)})
  return items

def _abs_path(dataset_root: str, rel_or_abs: str) -> str:
  if os.path.isabs(rel_or_abs):
    return rel_or_abs
  return os.path.normpath(os.path.join(dataset_root or "", rel_or_abs))

def _infer_class_prefix(img_path: str, meta: Dict[str, Any], class_prefix_regex: Optional[str]) -> str:
  c_re = re.compile(class_prefix_regex, re.IGNORECASE) if class_prefix_regex else None
  cls = meta.get("class_id")
  if c_re and isinstance(cls, str) and c_re.match(cls):
    return cls
  parent = Path(img_path).parent.name
  if c_re and c_re.match(parent):
    return parent
  stem = Path(img_path).stem
  tok = stem.split("_")[0]
  if c_re and c_re.match(tok):
    return tok
  return "cX"

def _autodetect_weights(net: str, roots: List[str], pattern: str) -> Optional[str]:
  """
  Search for weights across multiple roots using a glob pattern.
  pattern may include '{net}', e.g. '{net}-final.pth' or '**/{net}-final.pth'
  """
  pat = pattern.replace("{net}", net)
  seen = set()
  for r in roots:
    if not r: 
      continue
    root = os.path.abspath(os.path.expanduser(r))
    if root in seen: 
      continue
    seen.add(root)
    hits = glob.glob(os.path.join(root, "**", pat), recursive=True)
    if hits:
      hits.sort(key=lambda p: (-os.path.getmtime(p), -len(p)))  ## newest first, then shorter path
      log.info(f"Autodetected weights under '{root}': {hits[0]}")
      return hits[0]
  return None


## ---------------------------
## Target layer resolution
## ---------------------------

def _resolve_target_layers(model: torch.nn.Module, arch: str, sal_cfg: Dict[str, Any]) -> List[torch.nn.Module]:
  target_cfg = (sal_cfg.get("target") or {})
  selection = target_cfg.get("selection", "auto")
  per_arch = target_cfg.get("per_arch", {}) or {}
  sel = per_arch.get(arch, selection)

  if sel in ("auto", "last_conv"):
    return get_target_layers(model, arch)

  def _resolve_one(path_str: str) -> torch.nn.Module:
    cur = model
    for attr in path_str.split("."):
      if attr.endswith("]"):
        seg, idx = attr.split("[")
        idx = int(idx[:-1])
        cur = getattr(cur, seg)[idx]
      else:
        cur = getattr(cur, attr)
    return cur

  if isinstance(sel, str):
    return [_resolve_one(sel)]
  if isinstance(sel, list):
    return [_resolve_one(s) for s in sel]
  raise ValueError(f"Unsupported target selection: {sel}")


## ---------------------------
## Image I/O helpers
## ---------------------------

def _reshape_to_4d(t: torch.Tensor) -> torch.Tensor:
  """
  Ensure CAM backends always get (N,C,H,W).
  If a target layer outputs (N,C,L) or (N,L,C), try to square L; else use H=1.
  """
  if not isinstance(t, torch.Tensor):
    return t
  if t.ndim == 3:
    # Try (N,C,L)
    if t.shape[1] <= t.shape[2]:
      n, c, l = t.shape
      s = int(math.sqrt(l))
      if s * s == l:
        return t.view(n, c, s, s)
      return t.unsqueeze(2)  # (N,C,1,L)
    # Else (N,L,C) -> (N,C,L)
    n, l, c = t.shape
    s = int(math.sqrt(l))
    if s * s == l:
      return t.permute(0, 2, 1).contiguous().view(n, c, s, s)
    return t.permute(0, 2, 1).contiguous().unsqueeze(2)  # (N,C,1,L)
  return t  # Already 4D (or 5D for 3D CAM)


def _load_image_as_rgb01(path: str, input_size: Tuple[int, int]) -> Tuple[np.ndarray, torch.Tensor]:
  """Read an image and return it as float RGB in [0, 1] plus a model-ready tensor.

  The file is read with OpenCV as a 3-channel image and converted BGR->RGB.
  When `input_size` is truthy the image is resized; the pair is reversed
  before being handed to `cv2.resize`, which takes (width, height).

  Returns:
    (rgb01, tensor): the float32 image scaled to [0, 1], and the output of
    `preprocess_image(rgb01)`.

  Raises:
    FileNotFoundError: if OpenCV cannot read `path`.
  """
  bgr = cv2.imread(path, 1)
  if bgr is None:
    raise FileNotFoundError(f"Image not found: {path}")

  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
  if input_size:
    width_height = input_size[::-1]
    rgb = cv2.resize(rgb, width_height, interpolation=cv2.INTER_LINEAR)

  scaled = rgb.astype(np.float32) / 255.0
  rgb01 = scaled.clip(0, 1)
  return rgb01, preprocess_image(rgb01)  ## tensor presumably NCHW — TODO confirm against preprocess_image

def _overlay_on_image(rgb01: np.ndarray, gray_cam01: np.ndarray) -> np.ndarray:
  """Blend a CAM over its image via `show_cam_on_image` and return the result in BGR channel order."""
  blended_rgb = show_cam_on_image(rgb01, gray_cam01, use_rgb=True)
  blended_bgr = cv2.cvtColor(blended_rgb, cv2.COLOR_RGB2BGR)
  return blended_bgr


## ---------------------------
## Normalization
## ---------------------------

def _normalize_map(x: np.ndarray, mode: str) -> np.ndarray:
  x = x.astype(np.float32)
  if mode == "none":
    return x
  if mode == "minmax":
    mn, mx = np.min(x), np.max(x)
    return (x - mn) / (mx - mn + 1e-8)
  if mode == "zscore":
    mu, sd = np.mean(x), np.std(x) + 1e-8
    return (x - mu) / sd
  if mode == "phi_relu":
    x = np.maximum(x, 0.0)
    mx = np.max(x) + 1e-8
    return x / mx
  if mode == "pdf":
    x = np.maximum(x, 0.0)
    s = np.sum(x) + 1e-8
    return x / s
  raise ValueError(f"Unknown normalization mode: {mode}")


## ---------------------------
## Patch tokenization
## ---------------------------

def _grid_tokens(smap: np.ndarray, grid: Tuple[int, int]) -> Tuple[np.ndarray, Dict[str, Any]]:
  gh, gw = int(grid[0]), int(grid[1])
  H, W = smap.shape[-2], smap.shape[-1]
  ph, pw = max(1, H // gh), max(1, W // gw)
  Hc, Wc = ph * gh, pw * gw
  S = smap[:Hc, :Wc]

  tokens, meta = [], {"grid": [gh, gw], "patch": [ph, pw], "orig_hw": [H, W]}
  for i in range(gh):
    for j in range(gw):
      patch = S[i*ph:(i+1)*ph, j*pw:(j+1)*pw]
      s = patch.astype(np.float32)
      s_sum = float(np.sum(s)) + 1e-8

      yy, xx = np.mgrid[0:patch.shape[0], 0:patch.shape[1]].astype(np.float32)
      com_y_pix = float(np.sum(yy * s) / s_sum)           # in [0, ph)
      com_x_pix = float(np.sum(xx * s) / s_sum)           # in [0, pw)
      com_y = com_y_pix / max(1.0, patch.shape[0])        # normalize to [0,1)
      com_x = com_x_pix / max(1.0, patch.shape[1])        # normalize to [0,1)

      feat = [float(np.mean(s)), float(np.max(s)), com_y, com_x]  # COM is fractional
      tokens.append(feat)

  return np.array(tokens, dtype=np.float32), meta


def extract_grid_patches(
  img_abspath: str,
  smap: np.ndarray,
  mask: Optional[np.ndarray] = None,
  grid: Tuple[int, int] = (4, 4),
  patch_out_hw: Tuple[int, int] = (32, 32),
  save_dir: Optional[Path] = None,
  prefix: str = "img_x"
) -> Dict[str, Any]:
  """Extract per-cell RGB/saliency/mask patches, per-cell stats, and save to disk when save_dir provided.

  The source image is resized to the saliency map's (H, W) so that crops
  line up, then both are cut along a `grid` = (rows, cols) lattice whose
  edges come from `np.linspace(..., dtype=int)`. Every crop is resized to
  `patch_out_hw`.

  Args:
    img_abspath: path of the source image; must exist and be readable.
    smap: saliency map, (H, W) or (1, H, W) (a leading singleton is dropped).
    mask: optional mask; binarised at 0.5 and resized (nearest) to (H, W)
      when its shape differs.
    grid: (rows, cols) of the lattice.
    patch_out_hw: (height, width) every patch is resized to.
    save_dir: when given, arrays are written there as .npy plus a meta JSON.
    prefix: file-name stem used for everything written to `save_dir`.

  Returns:
    Without `save_dir`: a dict holding the arrays "stats" (N, 4: mean, max,
    com_y, com_x), "patches", "salpatches", "maskpatches" (None when no mask
    was given) and the "meta" dict.
    With `save_dir`: a dict of file paths under the keys "tokens", "patches",
    "salpatches", "maskpatches" ("" when no mask was given) and "meta".

  Raises:
    FileNotFoundError: if the image is missing or cannot be decoded.
  """
  img_p = Path(img_abspath)
  if not img_p.exists():
    raise FileNotFoundError(f"Image not found: {img_abspath}")

  # Ensure smap is HxW
  if smap.ndim == 3 and smap.shape[0] in (1,) :
    smap = smap[0]
  H, W = int(smap.shape[0]), int(smap.shape[1])

  # Read original image and resize to match smap dims so crops align
  img_bgr = cv2.imread(str(img_p), cv2.IMREAD_COLOR)
  if img_bgr is None:
    raise FileNotFoundError(f"Unable to read image: {img_abspath}")
  img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  if (img_rgb.shape[0], img_rgb.shape[1]) != (H, W):
    img_rgb = cv2.resize(img_rgb, (W, H), interpolation=cv2.INTER_AREA)

  # Binarise the mask (threshold 0.5), resizing first when it does not match the map
  if mask is not None:
    if mask.shape != (H, W):
      mask_rs = cv2.resize(mask.astype(np.float32), (W, H), interpolation=cv2.INTER_NEAREST)
      mask_arr = (mask_rs > 0.5).astype(np.uint8)
    else:
      mask_arr = (mask > 0.5).astype(np.uint8)
  else:
    mask_arr = None

  gh, gw = int(grid[0]), int(grid[1])
  ph_out, pw_out = int(patch_out_hw[0]), int(patch_out_hw[1])

  # Integer cell edges; unlike an integer-division split these cover the full map
  ys = np.linspace(0, H, gh + 1, dtype=int)
  xs = np.linspace(0, W, gw + 1, dtype=int)

  # Pre-allocated outputs, one slot per cell in row-major order
  N = gh * gw
  stats = np.zeros((N, 4), dtype=np.float32)       # mean, max, com_y, com_x
  patches = np.zeros((N, ph_out, pw_out, 3), dtype=np.uint8)
  salpatches = np.zeros((N, ph_out, pw_out), dtype=np.float32)
  maskpatches = np.zeros((N, ph_out, pw_out), dtype=np.uint8) if mask_arr is not None else None
  bboxes = []

  idx = 0
  for ry in range(gh):
    for rx in range(gw):
      y0, y1 = int(ys[ry]), int(ys[ry + 1])
      x0, x1 = int(xs[rx]), int(xs[rx + 1])
      bboxes.append({"y0": y0, "y1": y1, "x0": x0, "x1": x1})

      crop_img = img_rgb[y0:y1, x0:x1, :].copy()
      crop_sal = smap[y0:y1, x0:x1].astype(np.float32)
      crop_mask = None
      if mask_arr is not None:
        crop_mask = mask_arr[y0:y1, x0:x1].astype(np.uint8)

      # stats: mean/max plus saliency-weighted centre of mass, as a fraction of the crop size
      mean_sal = float(np.mean(crop_sal))
      max_sal = float(np.max(crop_sal))
      s_sum = float(np.sum(crop_sal)) + 1e-12  # epsilon guards all-zero crops
      yy, xx = np.indices(crop_sal.shape).astype(np.float32)
      com_y_pix = float(np.sum(yy * crop_sal) / s_sum)
      com_x_pix = float(np.sum(xx * crop_sal) / s_sum)
      com_y = com_y_pix / max(1.0, crop_sal.shape[0])
      com_x = com_x_pix / max(1.0, crop_sal.shape[1])

      stats[idx, :] = [mean_sal, max_sal, com_y, com_x]

      # resize patches to canonical output (cv2.resize takes (width, height))
      p_rgb = cv2.resize(crop_img, (pw_out, ph_out), interpolation=cv2.INTER_AREA)
      p_sal = cv2.resize(crop_sal, (pw_out, ph_out), interpolation=cv2.INTER_LINEAR)
      patches[idx] = p_rgb
      salpatches[idx] = p_sal.astype(np.float32)
      if maskpatches is not None:
        # scale to 0/255 for the resize, then threshold back to 0/1
        p_mask = cv2.resize(crop_mask.astype(np.uint8) * 255, (pw_out, ph_out),
                            interpolation=cv2.INTER_NEAREST)
        maskpatches[idx] = (p_mask > 127).astype(np.uint8)

      idx += 1

  meta = {
    "grid": [gh, gw],
    "patch_out_hw": [ph_out, pw_out],
    "orig_hw": [H, W],
    "bboxes": bboxes,
    "src_image": str(img_p),
    "version": "tokens_v1"
  }

  # In-memory mode: hand the arrays back and touch nothing on disk
  if save_dir is None:
    return {"stats": stats, "patches": patches, "salpatches": salpatches, "maskpatches": maskpatches, "meta": meta}

  outp = Path(save_dir)
  outp.mkdir(parents=True, exist_ok=True)
  base = prefix

  # NOTE: the per-cell stats are stored under the ".tokens.npy" name
  np.save(outp / f"{base}.tokens.npy", stats)
  np.save(outp / f"{base}.patches.npy", patches)
  np.save(outp / f"{base}.salpatches.npy", salpatches)
  if maskpatches is not None:
    np.save(outp / f"{base}.maskpatches.npy", maskpatches)

  with open(outp / f"{base}.meta.json", "w") as fh:
    json.dump(meta, fh, indent=2)

  log.debug("Wrote patch files for %s: %s", prefix, outp)
  return {
    "tokens": str(outp / f"{base}.tokens.npy"),
    "patches": str(outp / f"{base}.patches.npy"),
    "salpatches": str(outp / f"{base}.salpatches.npy"),
    "maskpatches": str(outp / f"{base}.maskpatches.npy") if maskpatches is not None else "",
    "meta": str(outp / f"{base}.meta.json")
  }


## ---------------------------
## CAM runner
## ---------------------------

def _get_cam_class(method_name: str):
  """Look up `method_name` as an attribute of the `pytorch_grad_cam` package.

  Raises:
    NotImplementedError: if the package exposes no attribute of that name.
  """
  cam_pkg = import_module("pytorch_grad_cam")
  missing = object()  ## sentinel: tells "absent" apart from any real attribute value
  cam_cls = getattr(cam_pkg, method_name, missing)
  if cam_cls is missing:
    raise NotImplementedError(f"CAM method `{method_name}` not found in pytorch_grad_cam")
  return cam_cls


def _instantiate_cam(CamCls, model, target_layers, device: str):
  """
  Create a CAM instance that is compatible with multiple pytorch-grad-cam APIs.
  Some versions expect `use_cuda=bool`, others `device=torch.device`, others neither.
  """
  sig = inspect.signature(CamCls.__init__)
  kwargs = {}
  if "use_cuda" in sig.parameters:
    kwargs["use_cuda"] = (device == "cuda")
  if "device" in sig.parameters:
    kwargs["device"] = torch.device(device)
  return CamCls(model=model, target_layers=target_layers, **kwargs)


def _run_cam_for_paths(
  model: torch.nn.Module,
  target_layers: List[torch.nn.Module],
  device: str,
  img_paths: List[str],
  input_size: Tuple[int, int],
  methods: List[str],
  cam_batch_size: int
) -> Dict[str, List[np.ndarray]]:
  """Run each requested CAM method over every image, one image at a time.

  For every name in `methods` the CAM class is resolved, built with
  `_reshape_to_4d` as its reshape transform and used as a context manager.
  If the instance exposes `batch_size` it is set to `cam_batch_size`.

  Returns:
    {method name: [float32 map per image, in `img_paths` order]}.
  """
  results: Dict[str, List[np.ndarray]] = {name: [] for name in methods}
  for name in methods:
    cam_cls = _get_cam_class(name)
    # NOTE: the local CAM classes take reshape_transform, not use_cuda
    cam_ctx = cam_cls(model=model, target_layers=target_layers, reshape_transform=_reshape_to_4d)
    with cam_ctx as cam:
      if hasattr(cam, "batch_size"):
        cam.batch_size = int(cam_batch_size)
      bucket = results[name]
      progress = tqdm(img_paths, desc=f"CAM:{name}", leave=False, unit="img")
      for img_path in progress:
        _, tensor = _load_image_as_rgb01(img_path, input_size)
        heat = cam(input_tensor=tensor.to(device))[0]  # first (only) item of the batch
        bucket.append(heat.astype(np.float32))
  return results


## ---------------------------
## YOLO masks (runtime / precomputed)
## ---------------------------

def _largest_mask(masks: Optional[np.ndarray]) -> Optional[np.ndarray]:
  if masks is None or len(masks) == 0:
    return None
  areas = [m.sum() for m in masks]
  return masks[int(np.argmax(areas))]

def _get_runtime_masks(
  img_paths: List[str],
  device: str = "cuda",
  weights: str = "yolov8s-seg.pt",
  imgsz: int = 640,
  conf: float = 0.25,
  iou: float = 0.45,
  prefer_class: Optional[int] = 0,
  save: bool = False
) -> List[np.ndarray]:
  """Compute one float32 foreground mask per image with a model loaded via `yoloapi`.

  For each image the largest predicted mask (by sum) is kept; when the
  prediction carries no masks, an all-ones mask of the image size is used.
  With `save=True` each mask is also written next to its image as
  `<image stem>.mask.png` (values scaled to 0..255).

  Raises:
    FileNotFoundError: if an image cannot be read.
  """
  model, device, _info = yoloapi.load_model(weights, device=device)
  collected = []
  for img_path in tqdm(img_paths, desc="YOLO-masks", leave=False):
    bgr = cv2.imread(img_path, 1)
    if bgr is None:
      raise FileNotFoundError(f"Image not found: {img_path}")

    class_filter = None if prefer_class is None else [prefer_class]
    preds = yoloapi.predict(
      model, bgr, device=device, imgsz=imgsz,
      conf=conf, iou=iou,
      classes_filter=class_filter
    )

    candidates = preds[0].get("masks") if preds else None
    if candidates is not None and len(candidates) > 0:
      foreground = _largest_mask(candidates)
    else:
      foreground = None

    if foreground is None:
      ## nothing detected: treat the whole image as foreground
      height, width = bgr.shape[:2]
      foreground = np.ones((height, width), dtype=np.float32)
    foreground = foreground.astype(np.float32)
    collected.append(foreground)

    if save:
      mask_file = Path(img_path).with_suffix(".mask.png")
      cv2.imwrite(str(mask_file), (foreground * 255).astype(np.uint8))
  return collected

def _get_precomputed_masks(img_paths: List[str]) -> List[np.ndarray]:
  """Load the mask stored next to each image as `<image stem>.mask.png`.

  A mask file is read as grayscale and binarised at 127 into a float32
  0/1 array. When the mask file is missing or cannot be decoded, an
  all-ones mask is used instead, sized like the image, or 224x224 if the
  image itself is unreadable.

  Returns:
    One float32 mask per entry of `img_paths`, in the same order.
  """
  out = []
  for p in img_paths:
    mpath = Path(p).with_suffix(".mask.png")
    ## cv2.imread returns None for an undecodable file, so existence alone is not enough
    raw = cv2.imread(str(mpath), 0) if mpath.exists() else None
    if raw is not None:
      m = (raw > 127).astype(np.float32)
    else:
      img = cv2.imread(p, 0)
      h, w = (img.shape[:2] if img is not None else (224, 224))
      m = np.ones((h, w), dtype=np.float32)  ## fallback: whole image
    out.append(m)
  return out

def _apply_masks(group_maps: Dict[str, List[np.ndarray]], masks: List[np.ndarray]) -> Dict[str, Dict[str, List[np.ndarray]]]:
  """Split every saliency map into whole / foreground / background variants.

  Each mask is resized (nearest neighbour) to its map's height and width;
  "fg" is map * mask, "bg" is map * (1 - mask), and "whole" holds copies of
  the unmasked maps. Maps and masks are paired positionally.

  Returns:
    {"whole": {...}, "fg": {...}, "bg": {...}}, each keyed by method name.
  """
  whole = {name: [arr.copy() for arr in maps] for name, maps in group_maps.items()}
  fg: Dict[str, List[np.ndarray]] = {}
  bg: Dict[str, List[np.ndarray]] = {}

  for name, maps in group_maps.items():
    fg_maps, bg_maps = [], []
    for smap, mask in zip(maps, masks):
      height, width = smap.shape[0], smap.shape[1]
      fitted = cv2.resize(mask.astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
      fg_maps.append(smap * fitted)
      bg_maps.append(smap * (1.0 - fitted))
    fg[name] = fg_maps
    bg[name] = bg_maps

  return {"whole": whole, "fg": fg, "bg": bg}


## ---------------------------
## Normalization + tokens for all groups
## ---------------------------

def _normalize_all(groups: Dict[str, Any], modes: List[str]) -> Dict[str, Any]:
  """Apply every normalisation mode to every map.

  Input is {part: {method: [maps]}}; output is
  {part: {method: {mode: [normalised maps]}}}.
  """
  def _per_mode(arrs):
    return {mode: [_normalize_map(a, mode) for a in arrs] for mode in modes}

  result: Dict[str, Any] = {}
  for part, methods in groups.items():
    result[part] = {mname: _per_mode(arrs) for mname, arrs in methods.items()}
  return result


def _tokenize_all(normed: Dict[str, Any], grids: List[List[int]]) -> Dict[str, Any]:
  """Tokenise every normalised map on every grid.

  Input is {part: {method: {norm: [maps]}}}; output is
  {part: {method: {norm: {"RxC": [{"vec": tokens, "meta": meta}, ...]}}}}
  with one list entry per map, in the same order.
  """
  def _per_grid(arrs):
    by_label = {}
    for g in grids:
      rows, cols = g[0], g[1]
      entries = []
      for a in arrs:
        vec, meta = _grid_tokens(a, (rows, cols))
        entries.append({"vec": vec, "meta": meta})
      by_label[f"{rows}x{cols}"] = entries
    return by_label

  tokens: Dict[str, Any] = {}
  for part, methods in normed.items():
    part_out = {}
    for mname, by_norm in methods.items():
      part_out[mname] = {norm_name: _per_grid(arrs) for norm_name, arrs in by_norm.items()}
    tokens[part] = part_out
  return tokens


## ---------------------------
## Group resolution from SAGE context
## ---------------------------

def _collect_groups_from_context(context: Dict[str, Any], use_clusters: bool) -> Dict[str, Dict[str, Any]]:
  """
  Discover groups (paths + meta) from an inference/cluster run.

  It supports both layouts:
    A) <BASE_DIR>/correct.txt, incorrect.txt, clusters/...
    B) <BASE_DIR>/<NET>/correct.txt, incorrect.txt, clusters/...

  Sources are tried in order and the first one that yields any group wins:
    1. cluster-analysis CSVs (only when `use_clusters` is true)
    2. correct/incorrect TSV index files (overall and per-class)
    3. the dataset's own `imgs` listing, as a single "all" group

  Args:
    context: reads "groups_base_dir" (falling back to "to_path"), "args"
      (for `.net`) and "dataset" (for `.root` and `.imgs`).
    use_clusters: whether source 1 is consulted.

  Returns:
    {group key: {"paths": [absolute paths], "meta": {...}}}; empty when
    nothing was found.
  """
  base_dir = Path(context.get("groups_base_dir", context["to_path"]))
  ## nesting has not been checked yet at this point, so do not claim it in the log
  log.info(f"Using base directory for saliency groups: {base_dir}")

  args = context.get("args")
  net = getattr(args, "net", None)

  # Prefer <BASE_DIR>/<NET> when it exists and looks like an inference run
  cand = base_dir / net if net else None
  if cand and cand.exists():
    has_signals = (cand / "clusters").exists() or (cand / "correct.txt").exists() or (cand / "incorrect.txt").exists()
    if has_signals:
      log.info(f"Using nested model directory for saliency groups: {cand}")
      base_dir = cand

  dataset = context.get("dataset")
  dataset_root = getattr(dataset, "root", "")

  def _paths_from_csv(csv_path: Path) -> List[str]:
    ## pandas is imported lazily so it is only required when cluster CSVs are read
    import pandas as pd
    if not csv_path.exists():
      return []
    df = pd.read_csv(csv_path)
    if "path" not in df.columns:
      return []
    return [_abs_path(dataset_root, p) for p in df["path"].tolist()]

  groups: Dict[str, Dict[str, Any]] = {}

  # 1) cluster-analysis CSVs
  if use_clusters:
    for correctness in ["correct", "incorrect"]:
      csv_all = base_dir / "clusters" / "analysis" / correctness / f"clusters.{correctness}.csv"
      plist = _paths_from_csv(csv_all)
      if plist:
        groups[f"{correctness}/overall"] = {"paths": plist, "meta": {"correctness": correctness}}
      per_class_dir = base_dir / "clusters" / "analysis" / correctness / "per_class"
      if per_class_dir.exists():
        for cdir in per_class_dir.iterdir():
          if cdir.is_dir():
            cls_id = cdir.name
            csv_c = cdir / f"clusters.{correctness}-{cls_id}.csv"
            pclist = _paths_from_csv(csv_c)
            if pclist:
              groups[f"{correctness}/per_class/{cls_id}"] = {
                "paths": pclist, "meta": {"correctness": correctness, "class_id": cls_id}
              }
  if groups:
    return groups

  # 2) correct/incorrect TSVs
  for correctness in ["correct", "incorrect"]:
    base_file = base_dir / f"{correctness}.txt"
    entries = _read_index_file(base_file)
    if entries:
      groups[f"{correctness}/overall"] = {
        "paths": [_abs_path(dataset_root, e["path"]) for e in entries],
        "meta": {"correctness": correctness}
      }
    per_class_dir = base_dir / f"{correctness}_per_class"
    if per_class_dir.exists():
      for cls_file in per_class_dir.glob("c*.txt"):
        cls_id = cls_file.stem
        e2 = _read_index_file(cls_file)
        if e2:
          groups[f"{correctness}/per_class/{cls_id}"] = {
            "paths": [_abs_path(dataset_root, e["path"]) for e in e2],
            "meta": {"correctness": correctness, "class_id": cls_id}
          }

  if groups:
    return groups

  # 3) final fallback: use dataset listing if available
  if dataset is not None and getattr(dataset, "imgs", None):
    ## entries may be bare paths or (path, label) pairs — e.g. torchvision's
    ## ImageFolder stores `imgs` as pairs; verify against the dataset class in use
    listing = [p[0] if isinstance(p, (tuple, list)) else p for p in dataset.imgs]
    groups["all"] = {"paths": [_abs_path(dataset_root, p) for p in listing], "meta": {"correctness": "mixed"}}
    return groups

  log.warning(f"No groups found under {base_dir} (clusters/correct-incorrect/dataset empty?)")
  return {}



## ---------------------------
## Persistence
## ---------------------------

def _save_group_outputs(
  out_root: Path,
  group_key: str,
  img_paths: List[str],
  raw_maps: Dict[str, List[np.ndarray]],
  normed: Dict[str, Any],
  tokens: Dict[str, Any],
  masks_used: Optional[List[np.ndarray]],
  save_raw: bool,
  save_norms: List[str],
  save_overlays: bool,
  save_tokens: bool,
  input_size: Tuple[int, int],
  group_meta: Dict[str, Any],
  class_prefix_regex: Optional[str],
  indices_maps: List[Dict[str, Any]],
  indices_tokens: List[Dict[str, Any]],
  stats_accum: Dict[str, Any],
  save_patches: bool = False,
  patch_out_hw: Tuple[int, int] = (32, 32),
) -> None:
  """
  Persist saliency artifacts for a single group (group_key).
  Writes raw maps, overlays, normalized maps, masks, tokens, and optionally patch files.
  Updates indices_maps and indices_tokens in-place.

  Everything for a method lands under `out_root / <method> / <group_key>`.
  File names are `<image stem>.<class prefix>.<kind>`, the class prefix
  coming from `_infer_class_prefix`. Individual write failures are logged
  as warnings and do not abort the group.

  Args:
    out_root: root directory for all saliency output.
    group_key: slash-separated group name; becomes nested directories.
    img_paths: images of the group; every per-image list is paired with it
      positionally.
    raw_maps: {method: [raw map per image]}.
    normed: {part: {method: {norm mode: [map per image]}}}.
    tokens: {part: {method: {norm: {grid label: [{"vec", "meta"} per image]}}}}.
    masks_used: one mask per image, or None when masking is off.
    save_raw / save_overlays / save_tokens: switches for those artifacts.
    save_norms: normalisation modes to store next to the raw maps (step 1).
    input_size: size used when re-reading the image for the overlay.
    group_meta: group metadata passed to `_infer_class_prefix`.
    class_prefix_regex: pattern passed to `_infer_class_prefix`.
    indices_maps / indices_tokens: index row lists, appended to in place.
    stats_accum: running counters; "by_group" and "total_images" are updated.
    save_patches: also write per-cell patch files via `extract_grid_patches`.
    patch_out_hw: patch size forwarded to `extract_grid_patches`.
  """

  def _group_dir(method: str) -> Path:
    # out_root/<method>/<group_key>, with the key's '/' forming sub-directories
    parts = group_key.split("/")
    return out_root / method / "/".join(parts)

  # 1) Save raw CAM maps + overlays + normalized whole maps
  for mname, maps in raw_maps.items():
    gdir = _group_dir(mname)
    _ensure_dir(gdir)

    for pth, smap in zip(img_paths, maps):
      stem = Path(pth).stem
      cls_prefix = _infer_class_prefix(pth, group_meta, class_prefix_regex)

      if save_raw:
        try:
          np.save(gdir / f"{stem}.{cls_prefix}.smap.npy", smap.astype(np.float32))
        except Exception:
          log.warning("Failed to save raw smap for %s under %s", pth, gdir)

      if save_overlays:
        try:
          # the overlay always uses the phi_relu normalisation, whatever save_norms holds
          rgb01, _ = _load_image_as_rgb01(pth, input_size)
          ov = _overlay_on_image(rgb01, _normalize_map(smap, "phi_relu"))
          cv2.imwrite(str(gdir / f"{stem}.{cls_prefix}.overlay.jpg"), ov)
        except Exception:
          log.warning("Failed to write overlay for %s under %s", pth, gdir)

      # normalised copies are recomputed from the raw map here (independent of `normed`)
      for nmode in save_norms:
        try:
          nmap = _normalize_map(smap, nmode)
          np.save(gdir / f"{stem}.{cls_prefix}.norm.{nmode}.npy", nmap.astype(np.float32))
        except Exception:
          log.warning("Failed to save normalized map (%s) for %s", nmode, pth)

      # index row for maps (whole/raw path + shapes)
      try:
        indices_maps.append({
          "group": group_key, "part": "whole", "method": mname,
          "stem": stem, "class_prefix": cls_prefix,
          "smap_shape": list(smap.shape), "overlay": bool(save_overlays), "norms": save_norms
        })
      except Exception:
        log.warning("Failed to append indices_maps for %s", pth)

  # 2) Save masks (if provided)
  if masks_used is not None:
    for mname, maps in raw_maps.items():
      mdir = _group_dir(mname) / "masks"
      _ensure_dir(mdir)
      for pth, smap, msk in zip(img_paths, maps, masks_used):
        try:
          # stored at the saliency map's resolution, not the mask's original one
          msk_rs = cv2.resize(msk.astype(np.float32), (smap.shape[1], smap.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
          stem = Path(pth).stem
          cls_prefix = _infer_class_prefix(pth, group_meta, class_prefix_regex)
          np.save(mdir / f"{stem}.{cls_prefix}.mask.npy", msk_rs.astype(np.float32))
        except Exception:
          log.warning("Failed to save mask for %s", pth)

  # 3) Save per-part normalized maps (whole/fg/bg)
  for part, methods in normed.items():  # part ∈ {"whole", "fg", "bg"}
    for mname, by_norm in methods.items():
      for nmode, arrs in by_norm.items():
        pdir = _group_dir(mname) / f"{part}/norm/{nmode}"
        _ensure_dir(pdir)
        for pth, a in zip(img_paths, arrs):
          try:
            stem = Path(pth).stem
            cls_prefix = _infer_class_prefix(pth, group_meta, class_prefix_regex)
            np.save(pdir / f"{stem}.{cls_prefix}.npy", a.astype(np.float32))
          except Exception:
            log.warning("Failed to save normalized part map %s for %s", nmode, pth)

  # 4) Save tokens + optionally per-cell patches
  if save_tokens and tokens:
    for part, methods in tokens.items():           # whole/fg/bg
      for mname, by_norm in methods.items():
        for norm_name, grids in by_norm.items():
          for grid_label, per_img in grids.items():
            tdir = _group_dir(mname) / f"tokens/{part}/{norm_name}/{grid_label}"
            _ensure_dir(tdir)

            # per-image loop uses index so we can match up with normed maps and masks_used
            for idx_img, (pth, item) in enumerate(zip(img_paths, per_img)):
              arr = item.get("vec")
              meta = item.get("meta", {})
              stem = Path(pth).stem
              cls_prefix = _infer_class_prefix(pth, group_meta, class_prefix_regex)

              # Save token vector
              try:
                np.save(tdir / f"{stem}.{cls_prefix}.tokens.npy", arr.astype(np.float32))
              except Exception:
                log.warning("Failed to save tokens for %s", pth)
                # still continue to try patch extraction / indexing

              patches_path = ""

              if save_patches:
                try:
                  # fetch the normalized map list that produced these tokens
                  smap_list = normed.get(part, {}).get(mname, {}).get(norm_name, [])
                  smap = smap_list[idx_img] if idx_img < len(smap_list) else None
                  if smap is None:
                    log.warning("No normalized smap available for %s (part=%s, method=%s, norm=%s). Skipping patches.",
                                pth, part, mname, norm_name)
                  else:
                    # corresponding mask if available
                    msk = masks_used[idx_img] if (masks_used is not None and idx_img < len(masks_used)) else None

                    # call helper that will resize image->smap dims and save multiple patch files
                    # NOTE(review): the helper writes "<prefix>.tokens.npy" into the same
                    # directory and with the same prefix as the token vector saved above,
                    # so that file appears to be overwritten — confirm this is intended
                    patch_res = extract_grid_patches(
                      img_abspath=pth,
                      smap=smap,
                      mask=msk,
                      grid=tuple(int(x) for x in grid_label.split("x")),
                      patch_out_hw=patch_out_hw,
                      save_dir=tdir,
                      prefix=f"{stem}.{cls_prefix}"
                    )
                    patches_path = patch_res.get("patches", "") if isinstance(patch_res, dict) else ""
                    log.info("Saved patches for %s -> %s", stem, patches_path or "<none>")
                except Exception as e:
                  import traceback
                  log.warning("Failed to extract/save patches for %s: %s\n%s", pth, e, traceback.format_exc())

              # append token index row (with patches_path possibly empty)
              try:
                indices_tokens.append({
                  "group": group_key,
                  "part": part,
                  "method": mname,
                  "norm": norm_name,
                  "grid": grid_label,
                  "stem": stem,
                  "class_prefix": cls_prefix,
                  "tokens_shape": list(arr.shape) if isinstance(arr, (np.ndarray, list, tuple)) else [],
                  "orig_hw": meta.get("orig_hw"),
                  "patch": meta.get("patch"),
                  "grid_hw": meta.get("grid"),
                  "patches_path": patches_path
                })
              except Exception:
                log.warning("Failed to append indices_tokens entry for %s", pth)

  # 5) update stats (counts every listed image, whether or not its writes succeeded)
  stats = stats_accum["by_group"].setdefault(group_key, {"count": 0})
  stats["count"] += len(img_paths)
  stats_accum["total_images"] += len(img_paths)


## ---------------------------
## Public API (SAGE wrappers)
## ---------------------------

def generate_saliency(context: Dict[str, Any], use_clusters: bool = True, **kwargs) -> Dict[str, Any]:
  """Generate and persist saliency artifacts for every image group of a run.

  Per group the pipeline is: raw CAMs -> optional fg/bg masking ->
  normalisation -> optional part filtering -> grid tokens -> persistence.
  A failing group is logged and skipped unless `fail_fast` is set in the
  saliency config, in which case the exception propagates.

  Args:
    context: orchestrator state. Reads "args" (`net`, `input_size`, optional
      `parts`), "model", "device", "cfg" -> "saliency_cfg", "to_path",
      "out_root", and the accumulators "_indices_maps", "_indices_tokens"
      and "_stats". Sets "saliency_dir".
    use_clusters: forwarded to `_collect_groups_from_context`.
    **kwargs: accepted and ignored.

  Returns:
    The same `context` object.
  """
  args = context["args"]
  model = context["model"]
  device = context["device"]
  arch = args.net
  input_size = tuple(args.input_size)

  sal_cfg = (context.get("cfg", {}) or {}).get("saliency_cfg", {}) or {}
  methods = sal_cfg.get("methods", ["GradCAM"])  ## case-sensitive class names from pytorch-grad-cam
  target_layers = _resolve_target_layers(model, arch, sal_cfg)

  normalize_modes = sal_cfg.get("normalize", ["phi_relu"])
  mask_cfg = sal_cfg.get("masks", {}) or {}
  mask_mode = mask_cfg.get("mode", "none")  ## none | precomputed | runtime
  save_runtime_masks = bool(mask_cfg.get("save_runtime", False))
  yolo_weights = mask_cfg.get("weights", "yolov8s-seg.pt")
  yolo_imgsz = int(mask_cfg.get("imgsz", 640))
  yolo_conf = float(mask_cfg.get("conf", 0.25))
  yolo_iou = float(mask_cfg.get("iou", 0.45))
  yolo_prefer_class = mask_cfg.get("prefer_class", 0)

  patch_cfg = sal_cfg.get("patchify", {}) or {}
  patch_enabled = bool(patch_cfg.get("enabled", True))
  grids = patch_cfg.get("grids", [[4, 4]])

  out_cfg = sal_cfg.get("outputs", {}) or {}
  save_raw = bool(out_cfg.get("save_raw_maps", True))
  save_overlays = bool(out_cfg.get("save_overlays", True))
  save_tokens = bool(out_cfg.get("save_tokens", True))

  cam_bs = int((sal_cfg.get("batching", {}) or {}).get("cam_batch_size", 16))
  fail_fast = bool(sal_cfg.get("fail_fast", False))
  class_prefix_regex = sal_cfg.get("class_prefix_regex")

  model_out = Path(context["to_path"])
  _ensure_dir(model_out / "saliency")

  groups = _collect_groups_from_context(context, use_clusters=use_clusters)
  for gkey, ginfo in groups.items():
    img_paths = ginfo["paths"]
    if not img_paths:
      continue

    try:
      ## 1) raw CAMs
      raw_maps = _run_cam_for_paths(
        model=model,
        target_layers=target_layers,
        device=device,
        img_paths=img_paths,
        input_size=input_size,
        methods=methods,
        cam_batch_size=cam_bs
      )

      ## 2) masks (optional)
      masks_used = None
      if mask_mode == "runtime":
        masks_used = _get_runtime_masks(
          img_paths,
          device=device,
          weights=yolo_weights,
          imgsz=yolo_imgsz,
          conf=yolo_conf,
          iou=yolo_iou,
          prefer_class=yolo_prefer_class,
          save=save_runtime_masks
        )
        grouped = _apply_masks(raw_maps, masks_used)
      elif mask_mode == "precomputed":
        masks_used = _get_precomputed_masks(img_paths)
        grouped = _apply_masks(raw_maps, masks_used)
      else:
        grouped = {"whole": raw_maps}

      ## 3) normalization
      normed = _normalize_all(grouped, normalize_modes)

      ## optionally filter parts by user request without if/else in the shell
      # robustly read parts from args; default to ["whole","fg","bg"] if missing or falsy
      parts_wanted = set(getattr(context["args"], "parts", ["whole", "fg", "bg"]) or ["whole", "fg", "bg"])
      normed = {k: v for k, v in normed.items() if k in parts_wanted and k in grouped}

      ## 4) tokens
      tokens = _tokenize_all(normed, grids) if patch_enabled else {}

      ## 5) persist (with class-id prefixing across all artifacts)
      _save_group_outputs(
        out_root=Path(context["out_root"]),
        group_key=gkey,
        img_paths=img_paths,
        raw_maps=raw_maps,
        normed=normed,
        tokens=tokens,
        masks_used=masks_used,
        save_raw=save_raw,
        save_norms=normalize_modes,
        save_overlays=save_overlays,
        save_tokens=save_tokens,
        input_size=input_size,
        group_meta=ginfo.get("meta", {}),
        class_prefix_regex=class_prefix_regex,
        indices_maps=context["_indices_maps"],
        indices_tokens=context["_indices_tokens"],
        stats_accum=context["_stats"],
        ## reuse the None-safe patch_cfg: a `patchify: null` config used to raise here
        save_patches=bool(patch_cfg.get("save_patches", False)),
        patch_out_hw=tuple(patch_cfg.get("patch_out_hw", (32, 32))),
      )

    except Exception as e:
      log.error(f"Saliency failed for group `{gkey}`: {e}")
      if fail_fast:
        raise

  context["saliency_dir"] = str(Path(context["out_root"]))
  return context

def save_saliency(context: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  """
  Write a minimal `manifest.json` (output dir + CAM methods) for the saliency stage.

  The output directory is resolved lazily, first match wins:
    1. context["saliency_dir"]
    2. context["out_root"]
    3. <context["to_path"]>/saliency

  Args:
    context: orchestrator context dict. Reads the directory keys above and,
      optionally, context["cfg"]["saliency_cfg"]["methods"].
    **kwargs: accepted for wrapper-signature compatibility; ignored here.

  Returns:
    The same `context` object, unmodified.

  Raises:
    KeyError: if none of `saliency_dir`, `out_root`, `to_path` is present.
  """
  ## Resolve lazily: `to_path` is only required when neither of the more
  ## specific keys is available (an eager `.get(k, default)` would raise
  ## KeyError on `to_path` even when `saliency_dir` is set).
  if "saliency_dir" in context:
    sal_dir = Path(context["saliency_dir"])
  elif "out_root" in context:
    sal_dir = Path(context["out_root"])
  else:
    sal_dir = Path(context["to_path"]) / "saliency"

  ## Tolerate a missing or explicitly-None cfg / saliency_cfg
  sal_cfg = (context.get("cfg") or {}).get("saliency_cfg") or {}
  manifest = {
    "dir": str(sal_dir),
    "methods": sal_cfg.get("methods", []),
  }

  ## NOTE(review): `_write_indices_and_manifest` also writes `manifest.json`
  ## under context["out_root"]; if both run against the same directory the
  ## later call overwrites the earlier file — confirm intended ordering.
  sal_dir.mkdir(parents=True, exist_ok=True)
  (sal_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))
  return context


def _write_indices_and_manifest(context: Dict[str, Any]) -> None:
  """
  Persist the per-run CSV indices and the run-level `manifest.json`.

  Everything is written directly under context["out_root"]:
    - saliency_maps_index.csv    rows taken from context["_indices_maps"]
    - saliency_tokens_index.csv  rows taken from context["_indices_tokens"]
    - manifest.json              CLI args, saliency cfg, stats, naming scheme

  Args:
    context: run context; must provide "out_root", "to_path", "args",
      "_indices_maps", "_indices_tokens" and "_stats". The saliency config
      is read from context["cfg"]["saliency_cfg"] when present.
  """
  sal_dir = Path(context["out_root"])  ## e.g., logs/.../saliency/resnet18
  _ensure_dir(sal_dir)

  ## indices: (destination, context key holding the rows, CSV columns)
  maps_csv = sal_dir / "saliency_maps_index.csv"
  tokens_csv = sal_dir / "saliency_tokens_index.csv"
  index_specs = (
    (
      maps_csv,
      "_indices_maps",
      ["group","part","method","stem","class_prefix","smap_shape","overlay","norms"],
    ),
    (
      tokens_csv,
      "_indices_tokens",
      [
        "group","part","method","norm","grid","stem","class_prefix",
        "tokens_shape","orig_hw","patch","grid_hw","patches_path"
      ],
    ),
  )
  for csv_path, rows_key, columns in index_specs:
    with open(csv_path, "w", newline="") as fh:
      writer = csv.DictWriter(fh, fieldnames=columns)
      writer.writeheader()
      writer.writerows(context[rows_key])

  ## manifest
  args = context["args"]
  cfg = (context.get("cfg", {}) or {}).get("saliency_cfg", {})
  ## saliency-cfg entries copied verbatim (None when the key is absent)
  cfg_keys = (
    "methods", "target", "normalize", "masks", "patchify",
    "outputs", "batching", "fail_fast", "class_prefix_regex",
  )
  manifest = {
    "arch": args.net,
    "run_dir": context["to_path"],
    "saliency_dir": str(sal_dir),
    "created_at": f"{datetime.utcnow().isoformat()}Z",
    "dataset": args.dataset or "unknown",
    "input_size": args.input_size,
    "num_class": args.num_class,
    **{key: cfg.get(key) for key in cfg_keys},
    "parts_enabled": getattr(args, "parts", ["whole", "fg", "bg"]),
    "stats": context["_stats"],
    "indices": {
      "maps_csv": str(maps_csv),
      "tokens_csv": str(tokens_csv)
    },
    "file_naming": {
      "smap":       "<stem>.<cX>.smap.npy",
      "overlay":    "<stem>.<cX>.overlay.jpg",
      "norm":       "<stem>.<cX>.norm.<normalize>.npy",
      "mask":       "masks/<stem>.<cX>.mask.npy",
      "tokens":     "tokens/<part>/<normalize>/<grid>/<stem>.<cX>.tokens.npy",
      "parts_note": "part ∈ {whole, fg, bg}; group ∈ {correct/overall, correct/per_class/cK, incorrect/...}"
    }
  }
  manifest_path = sal_dir / "manifest.json"
  manifest_path.write_text(json.dumps(manifest, indent=2))


def main(args: argparse.Namespace) -> None:
  """
  Standalone CLI driver: resolve weights, load the model, run saliency, write indices.

  Steps:
    1. Parse `args.input_size` from its string form into an (h, w) tuple.
    2. Build the weight search roots: `--to`, `--weights_roots`, the
       `SAGE_WEIGHTS_ROOTS` env var (split on the platform path separator),
       then fallbacks derived from `__DATAHUB_ROOT__` / `__CODEHUB_ROOT__`.
    3. If `--weights_path` is "auto" or not an existing file, autodetect it.
    4. Load the model and assemble the orchestrator-style `context`.
    5. If `--from` is given, synthesize `<to_path>/correct.txt` from the
       listed images so the pipeline has a group to process.
    6. Run `generate_saliency`, unload the model, write indices + manifest.

  Args:
    args: namespace produced by `parse_args()`; `input_size` and
      `weights_path` are overwritten in place with their resolved values.

  Raises:
    FileNotFoundError: when weights need autodetection and nothing matches.
    ValueError: when `input_size` / `patch_out_hw` cannot be parsed.
  """
  ## coerce input_size to tuple
  args.input_size = _parse_size(args.input_size)

  ## derive weight roots: [--to] + --weights_roots + env SAGE_WEIGHTS_ROOTS + common fallbacks
  roots = [args.to_path]
  roots += list(args.weights_roots or [])
  env_roots = os.getenv("SAGE_WEIGHTS_ROOTS", "")
  if env_roots:
    ## os.pathsep is ":" on POSIX (unchanged) and ";" on Windows, where a
    ## literal ":" split would cut drive-letter paths such as "C:\\logs" apart.
    roots += [r for r in env_roots.split(os.pathsep) if r]
  ## common fallbacks (kept last)
  _dh = os.getenv("__DATAHUB_ROOT__")
  _ch = os.getenv("__CODEHUB_ROOT__")
  if _dh: roots.append(os.path.join(_dh, "sage-experimental-logs"))
  if _ch: roots += [os.path.join(_ch, "external", "rflownet", "logs")]

  ## resolve weights if requested or missing
  if args.weights_path == "auto" or not os.path.isfile(args.weights_path):
    cand = _autodetect_weights(args.net, roots=roots, pattern=args.weights_glob)
    if not cand:
      msg = (
        f"Could not find weights matching pattern '{args.weights_glob.replace('{net}', args.net)}' "
        f"under roots: {roots}. Pass --weights_path explicitly."
      )
      raise FileNotFoundError(msg)
    args.weights_path = cand

  ## fall back to CPU when CUDA is requested but unavailable
  device = "cuda" if args.gpu and torch.cuda.is_available() else "cpu"

  ## Load model (torchapi expects these fields on args)
  model = loadmodel(args).to(device).eval()

  ## Build context for orchestrator-style flow
  sal_cfg = {
    "methods": args.methods,
    "target": {"selection": args.target},
    "normalize": args.normalize,
    "masks": {
      "mode": args.mask_mode,
      "save_runtime": args.save_runtime_masks,
      "weights": args.yolo_weights,
      "imgsz": args.yolo_imgsz,
      "conf": args.yolo_conf,
      "iou": args.yolo_iou,
      "prefer_class": args.yolo_prefer_class,
    },
    "patchify": {
      "enabled": (not args.no_tokens),
      "grids": _parse_grids(args.grids),
      "save_patches": bool(args.save_patches),
      "patch_out_hw": _parse_size(args.patch_out_hw)
    },
    "outputs": {
      "save_raw_maps": args.save_raw_maps,
      "save_overlays": args.save_overlays,
      "save_tokens": args.save_tokens
    },
    "batching": {"cam_batch_size": args.cam_batch_size},
    "fail_fast": args.fail_fast,
    "class_prefix_regex": args.class_prefix_regex,
  }

  groups_base_dir = Path(args.to_path) / args.net
  # Put saliency under the model/arch folder, e.g. <to_path>/<net>/saliency
  out_root = groups_base_dir / "saliency"

  context = {
    "args": args,
    "device": device,
    "to_path": str(args.to_path),          # run root
    "groups_base_dir": str(groups_base_dir),
    "out_root": str(out_root),
    "model": model,
    "cfg": {"saliency_cfg": sal_cfg},
    ## accumulators filled during generation, consumed by _write_indices_and_manifest
    "_indices_maps": [],
    "_indices_tokens": [],
    "_stats": {"by_group": {}, "total_images": 0}
  }

  ## If --from provided, synthesize a simple group so pipeline can run
  if args.from_path:
    files: List[str] = []
    ## --from is a comma-separated mix of directories (walked recursively) and files
    for part in args.from_path.split(","):
      part = part.strip()
      if os.path.isdir(part):
        for root, _, fnames in os.walk(part):
          for f in fnames:
            if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
              files.append(os.path.join(root, f))
      elif os.path.isfile(part):
        files.append(part)
    if files:
      Path(args.to_path).mkdir(parents=True, exist_ok=True)
      ## one row per image: idx, path, gt, pr, conf (tab-separated — the
      ## 5-column layout `_read_index_file` parses); labels are placeholders
      with open(os.path.join(args.to_path, "correct.txt"), "w") as f:
        for i, p in enumerate(files):
          f.write(f"{i}\t{p}\t0\t0\t1.0\n")

  try:
    generate_saliency(context, use_clusters=True)
  finally:
    ## release the model even when generation raises (e.g. under --fail_fast)
    unloadmodel(model)
  _write_indices_and_manifest(context)


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
  """
  Build the saliency CLI parser and parse the command line.

  Boolean options (`--gpu`, `--save_overlays`, `--no_tokens`, ...) take an
  explicit value that is converted by `_str2bool`, e.g. `--gpu false`.
  `--input_size` and `--patch_out_hw` are kept as strings here and parsed
  into tuples later by `main`.

  Args:
    argv: argument list to parse. `None` (the default) parses `sys.argv[1:]`,
      matching the previous zero-argument behaviour; pass an explicit list
      to drive the parser programmatically.

  Returns:
    The populated `argparse.Namespace`.
  """
  p = argparse.ArgumentParser(
    description="Saliency generation (CAM → masks → normalize → tokens)"
  )

  # ---- core model/data/run args (used by loadmodel + shell script) ----
  p.add_argument("--net", type=str, default="resnet18")
  p.add_argument("--weights_path", type=str, default="auto")
  p.add_argument("--num_class", type=int, default=10)
  p.add_argument("--input_size", type=str, default="(224,224)",
                 help="Like '(224,224)'; parsed later to a tuple")

  p.add_argument("--dataset", type=str, default=None)
  p.add_argument("--datasetcfg", type=str, default="data/ddd-datasets.yml")
  p.add_argument("--split", type=str, default="test")

  # accept --from and --from_path; both map to args.from_path
  # (single action with two option strings, same pattern as --to below)
  p.add_argument("--from", "--from_path", dest="from_path", type=str, default=None)

  # accept --to and --to_path; both map to args.to_path
  p.add_argument("--to", "--to_path", dest="to_path", type=str, default=".",
                 help="Run root containing per-arch outputs (clusters, correct.txt, etc.)")

  p.add_argument("--loss", type=str, default="CrossEntropyLoss")
  p.add_argument("--score_level", type=int, default=1)
  p.add_argument("--in_channels", type=int, default=3)
  p.add_argument("--pretrain", type=_str2bool, default=False)
  p.add_argument("--gpu", type=_str2bool, default=True)

  # ---- autodetect knobs ----
  p.add_argument("--weights_glob", type=str, default="**/{net}-final.pth")
  p.add_argument("--weights_roots", nargs="+", default=[])

  # ---- saliency cfg ----
  p.add_argument("--methods", nargs="+", default=["GradCAM"])
  p.add_argument("--target", type=str, default="auto")
  p.add_argument("--normalize", nargs="+", default=["phi_relu"])
  p.add_argument("--class_prefix_regex", type=str, default="^c[0-9]+$")

  # ---- masks (default ON) ----
  p.add_argument("--mask_mode", type=str, default="runtime",
                 choices=["none", "precomputed", "runtime"])
  p.add_argument("--save_runtime_masks", type=_str2bool, default=True)
  p.add_argument("--yolo_weights", type=str, default="/codehub/external/rflownet/yolov8s-seg.pt")
  p.add_argument("--yolo_imgsz", type=int, default=640)
  p.add_argument("--yolo_conf", type=float, default=0.25)
  p.add_argument("--yolo_iou", type=float, default=0.45)
  p.add_argument("--yolo_prefer_class", type=int, default=0)

  # ---- tokens ----
  p.add_argument("--grids", nargs="+", default=["4x4", "8x8"])
  p.add_argument("--no_tokens", type=_str2bool, default=False)
  p.add_argument("--save_patches", type=_str2bool, default=False,
                 help="Save per-cell RGB/saliency/mask patches alongside tokens (default: False)")
  p.add_argument("--patch_out_hw", type=str, default="(32,32)",
                 help="Patch output size (h,w) for saved patches; parsed like '(32,32)'")

  # ---- outputs / batching ----
  p.add_argument("--save_raw_maps", type=_str2bool, default=True)
  p.add_argument("--save_overlays", type=_str2bool, default=True)
  p.add_argument("--save_tokens", type=_str2bool, default=True)
  p.add_argument("--cam_batch_size", type=int, default=16)
  p.add_argument("--fail_fast", type=_str2bool, default=False)

  # ---- which parts to emit ----
  p.add_argument("--parts", nargs="+", default=["whole", "fg", "bg"])

  return p.parse_args(argv)



def print_args(args: argparse.Namespace) -> None:
  """Log a header line followed by one `name: value` line per attribute of `args`."""
  entries = vars(args)
  log.info("Arguments:")
  for name in entries:
    value = entries[name]
    log.info(f"{name}: {value}")


## Script entry point: parse the CLI, log the effective arguments, run the pipeline.
if __name__ == "__main__":
  args = parse_args()
  print_args(args)
  main(args)
