#!/usr/bin/env python3
"""
Usage (simulate stenosis & export box)
--------------------------------------
python new_gen_copy.py \
  --input /mnt/CAG_Dataset/datasets/gen_seg_dataset/seed_0021_0_-30.png \
  --out_dir /mnt/CAG_Dataset/datasets/stenosis_edit/new_test \
  --simulate_stenosis \
  --shrink_factor 0.3 \
  --half_window 12 \
  --sten_pos_frac 0.5 \
  --save_overlay \
  --save_radius_map \
  --reconstruct
"""
from __future__ import annotations
import argparse, json, math
from pathlib import Path
from typing import Iterable, List, Tuple, Optional, Dict, Any

import cv2
import numpy as np
from skimage.morphology import skeletonize, medial_axis

from vessel_path_finder import VesselPathFinder

# 3x3 kernel with a zero centre: correlating a 0/1 mask with it counts 8-neighbours.
NEIGH8 = np.ones((3, 3), dtype=np.uint8)
NEIGH8[1, 1] = 0
# The eight (dy, dx) neighbour offsets in row-major order, centre excluded.
OFFSETS: List[Tuple[int,int]] = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]

# ---------- I/O & preprocessing ----------
def read_mask(path: Path):
    """Load *path* as 8-bit grayscale and binarize it with Otsu's threshold.

    Returns:
        ``(gray, bw)``: the raw grayscale image and a uint8 mask holding 1 where
        the Otsu-thresholded image is non-zero and 0 elsewhere.

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode the file.
    """
    gray = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Cannot read image: {path}")
    _otsu_thr, thresholded = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    mask = (thresholded != 0).astype(np.uint8)
    return gray, mask

# ---------- skeleton utilities ----------
def neighbor_count(arr: np.ndarray) -> np.ndarray:
    """Return, for every pixel, the sum of its eight neighbours' values.

    Computed as a 2-D correlation with ``NEIGH8`` (centre weight 0); pixels
    outside the image count as 0. The result keeps the dtype of ``arr``.
    For a 0/1 mask this is the number of 8-connected foreground neighbours.
    """
    same_depth = -1  # output depth == input depth
    return cv2.filter2D(arr, same_depth, NEIGH8, borderType=cv2.BORDER_CONSTANT)

def endpoints(arr: np.ndarray) -> np.ndarray:
    """Return a uint8 0/1 map of pixels equal to 1 that have exactly one 8-neighbour."""
    degree = neighbor_count(arr)
    is_end = np.logical_and(arr == 1, degree == 1)
    return is_end.astype(np.uint8)

def junctions(arr: np.ndarray) -> np.ndarray:
    """Return a uint8 0/1 map of pixels equal to 1 that have three or more 8-neighbours."""
    degree = neighbor_count(arr)
    is_branch = np.logical_and(arr == 1, degree >= 3)
    return is_branch.astype(np.uint8)

def prune_spurs(skel_arr: np.ndarray, max_len: int = 15, passes: int = 2) -> np.ndarray:
    """Remove short terminal branches ("spurs") from a skeleton.

    In each pass every endpoint is traced for at most ``max_len`` steps:
    - if the trace reaches a junction pixel, the traced pixels (junction
      excluded) are removed;
    - if it reaches another endpoint first, the whole short isolated segment
      is removed;
    - otherwise the branch is kept.
    Removals are applied together at the end of each pass.

    Args:
        skel_arr: skeleton mask; pixels equal to 1 are foreground.
        max_len: maximum number of steps traced from each endpoint.
        passes: number of pruning passes.

    Returns:
        A pruned uint8 copy; ``skel_arr`` itself is not modified.
    """
    sk = skel_arr.astype(np.uint8, copy=True)
    h, w = sk.shape
    for _ in range(passes):
        # ``sk`` is only modified after all traces of this pass, so the junction
        # and endpoint maps computed here stay valid for the whole pass (no need
        # to recompute them per step).
        junc = junctions(sk)
        end = endpoints(sk)
        to_remove = np.zeros_like(sk, dtype=np.uint8)
        ys, xs = np.where(end == 1)
        for y0, x0 in zip(ys, xs):
            path = [(y0, x0)]
            prev = None
            y, x = y0, x0
            for _step in range(max_len):
                # Step to the first skeleton neighbour (OFFSETS order) that is
                # not the pixel we just came from.
                found = None
                for dy, dx in OFFSETS:
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and sk[ny, nx] == 1 and (prev is None or (ny, nx) != prev):
                        found = (ny, nx)
                        break
                if found is None:
                    break
                prev = (y, x)
                y, x = found
                path.append((y, x))
                if junc[y, x] == 1:
                    # Spur attached to a branch point: drop it, keep the junction.
                    for py, px in path[:-1]:
                        to_remove[py, px] = 1
                    break
                if end[y, x] == 1 and (y, x) != (y0, x0):
                    # Short isolated segment (endpoint to endpoint): drop it all.
                    for py, px in path:
                        to_remove[py, px] = 1
                    break
        sk[to_remove == 1] = 0
    return sk

# ---------- connectivity repair ----------
def _bfs_shortest_path_on_mask(mask: np.ndarray, src: Tuple[int,int], dst: Tuple[int,int],
                               max_len: int, dilate_r: int = 1) -> Optional[List[Tuple[int,int]]]:
    """Find a shortest 8-connected pixel path from ``src`` to ``dst`` inside ``mask``.

    The search runs in a window around the two points (margin ``max_len + 2``
    px). When ``dilate_r > 0`` it walks on ``mask`` dilated with an elliptical
    kernel of radius ``dilate_r``, so 1-px holes in the mask can be crossed.

    Args:
        mask: 2-D mask; pixels equal to 1 (after the optional dilation) are walkable.
        src: (y, x) start pixel.
        dst: (y, x) goal pixel.
        max_len: maximum number of steps allowed in the returned path.
        dilate_r: dilation radius in px; 0 disables dilation.

    Returns:
        The (y, x) pixels from ``src`` to ``dst`` inclusive, or ``None`` if
        ``dst`` cannot be reached inside the window or the shortest path is
        longer than ``max_len`` steps.
    """
    from collections import deque

    H, W = mask.shape
    if dilate_r > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*dilate_r+1, 2*dilate_r+1))
        # ``iterations`` must be passed by keyword: the third positional
        # parameter of cv2.dilate is the optional output array ``dst``.
        walk = cv2.dilate(mask, kernel, iterations=1)
    else:
        walk = mask

    y0, x0 = src
    y1, x1 = dst
    ymin = max(0, min(y0, y1) - (max_len + 2))
    ymax = min(H, max(y0, y1) + (max_len + 3))
    xmin = max(0, min(x0, x1) - (max_len + 2))
    xmax = min(W, max(x0, x1) + (max_len + 3))
    sub = walk[ymin:ymax, xmin:xmax]
    if sub.sum() == 0:
        return None
    src_s = (y0 - ymin, x0 - xmin)
    dst_s = (y1 - ymin, x1 - xmin)

    q = deque([src_s])
    visited = np.zeros_like(sub, dtype=np.uint8)
    visited[src_s] = 1
    prev = {}
    found = False
    while q:
        cy, cx = q.popleft()
        if (cy, cx) == dst_s:
            found = True
            break
        for dy, dx in OFFSETS:
            ny, nx = cy + dy, cx + dx
            if 0 <= ny < sub.shape[0] and 0 <= nx < sub.shape[1] and not visited[ny, nx] and sub[ny, nx] == 1:
                visited[ny, nx] = 1
                prev[(ny, nx)] = (cy, cx)
                q.append((ny, nx))
    if not found:
        return None

    # Walk the predecessor links back from the goal.
    path_s = []
    cur = dst_s
    while cur != src_s:
        path_s.append(cur)
        cur = prev[cur]
    path_s.append(src_s)
    path_s.reverse()
    if len(path_s) - 1 > max_len:
        return None
    return [(py + ymin, px + xmin) for (py, px) in path_s]

def _connect_skeleton_gaps(bw: np.ndarray, skel: np.ndarray,
                           max_bridge_len: int = 10, dilate_r: int = 1) -> np.ndarray:
    """Bridge small gaps between separate skeleton components.

    Repeatedly, for every pair of 8-connected components, the closest endpoint
    pair within ``max_bridge_len`` (Euclidean) is a candidate. Candidates are
    tried shortest first. Each is bridged with the BFS path found by
    ``_bfs_shortest_path_on_mask`` on ``bw`` (optionally dilated by
    ``dilate_r``). This stops when at most one component remains or no
    candidate can be bridged.

    Args:
        bw: binary 0/1 vessel mask the bridges must run through.
        skel: skeleton mask (pixels equal to 1 are foreground).
        max_bridge_len: max endpoint distance and max bridge length, in px.
        dilate_r: dilation radius applied to ``bw`` for the bridge search.

    Returns:
        A uint8 copy of ``skel`` with bridge pixels set to 1.
    """
    sk = skel.copy().astype(np.uint8)

    def comp_labels(arr):
        # ``connectivity`` must be passed by keyword: the second positional
        # parameter of cv2.connectedComponents is the optional output ``labels``.
        return cv2.connectedComponents(arr, connectivity=8)

    changed = True
    while changed:
        changed = False
        num, lbl = comp_labels(sk)
        if num <= 2:  # label 0 is background, so at most one component is left
            break

        # Endpoints grouped by component (raster order within each component);
        # computed once per iteration instead of once per component.
        ep = endpoints(sk).astype(bool)
        comp_eps = {cid: [] for cid in range(1, num)}
        for y, x in zip(*np.where(ep)):
            comp_eps[int(lbl[y, x])].append((y, x))

        cids = list(range(1, num))
        cands = []
        for i, c1 in enumerate(cids):
            for c2 in cids[i+1:]:
                eps1, eps2 = comp_eps[c1], comp_eps[c2]
                if not eps1 or not eps2:
                    continue
                best = None
                for p1 in eps1:
                    for p2 in eps2:
                        d = math.hypot(p1[0] - p2[0], p1[1] - p2[1])
                        if d <= max_bridge_len and (best is None or d < best[0]):
                            best = (d, p1, p2)
                if best is not None:
                    cands.append(best)
        if not cands:
            break

        cands.sort(key=lambda c: c[0])
        for _d, p1, p2 in cands:
            # Labels change as bridges are added, so re-check before each one.
            num_now, lbl_now = comp_labels(sk)
            if num_now <= 2:
                break
            if lbl_now[p1] == lbl_now[p2]:
                continue
            path = _bfs_shortest_path_on_mask(bw, p1, p2, max_bridge_len, dilate_r)
            if path is None:
                continue
            for y, x in path:
                sk[y, x] = 1
            changed = True
    return sk

# ---------- centerline + radius ----------
def skeleton_and_radius(bw: np.ndarray, spur_len: int = 15, passes: int = 2,
                        connect_gaps: bool = True, max_bridge_len: int = 10, bridge_dilate: int = 1):
    """Extract a one-pixel centerline from a binary mask and sample the vessel radius on it.

    Pipeline:
    1. ``skeletonize``.
    2. Optional spur pruning (``prune_spurs``).
    3. Optional gap bridging (``_connect_skeleton_gaps``).
    4. Euclidean distance transform of ``bw`` (maskSize 5), sampled on the final skeleton.

    Args:
        bw: binary mask, 0/1 valued (it is scaled by 255 for the distance transform).
        spur_len: max steps of terminal branches to prune; <= 0 disables pruning.
        passes: number of pruning passes; <= 0 disables pruning.
        connect_gaps: bridge nearby endpoints of separate skeleton components.
        max_bridge_len: max endpoint distance / bridge length in px.
        bridge_dilate: dilation radius of ``bw`` used while searching bridges.

    Returns:
        ``(skel, radius_map)``: a uint8 0/1 skeleton and a float32 map holding the
        distance-transform value on skeleton pixels and 0 elsewhere.
    """
    skel = skeletonize(bw.astype(bool)).astype(np.uint8)
    # These options are exposed on the CLI (--spur_len, --passes, --connect_gaps,
    # --max_bridge_len, --bridge_dilate) and must actually take effect here.
    # Prune first so that spur tips are not used as bridge endpoints.
    if spur_len > 0 and passes > 0:
        skel = prune_spurs(skel, max_len=spur_len, passes=passes)
    if connect_gaps:
        skel = _connect_skeleton_gaps(bw, skel, max_bridge_len=max_bridge_len, dilate_r=bridge_dilate)
    dist = cv2.distanceTransform((bw*255).astype(np.uint8), cv2.DIST_L2, 5)
    radius_map = (dist * skel).astype(np.float32)
    return skel, radius_map

# ---------- skeleton path extraction ----------
def _extract_paths(skel: np.ndarray) -> List[List[Tuple[int,int]]]:
    """Decompose skeleton into polylines (endpoint↔node).

    A "node" is any skeleton pixel whose 8-neighbour count is not 2 (endpoints,
    junctions, isolated pixels). Each returned polyline is a list of (y, x)
    pixels with at least two points.

    NOTE(review): this helper is not called anywhere in the visible code. In the
    second pass, a segment between two nodes (or a loop) is traced in one
    direction only from its first unvisited raster-order pixel, and the walk
    does not check ``visited``. Such segments therefore appear to come out
    split or overlapping. Verify before relying on the output.
    """
    sk = (skel>0).astype(np.uint8)
    H, W = sk.shape
    deg = neighbor_count(sk)
    # Nodes: degree 0 (isolated), 1 (endpoint) or >= 3 (junction).
    is_node = ((sk==1) & (deg!=2))
    visited = np.zeros_like(sk, dtype=np.uint8)
    paths: List[List[Tuple[int,int]]] = []

    def neighbors(y,x):
        # Skeleton pixels among the 8 neighbours of (y, x), in OFFSETS order.
        for dy,dx in OFFSETS:
            ny, nx = y+dy, x+dx
            if 0 <= ny < H and 0 <= nx < W and sk[ny,nx]==1:
                yield (ny,nx)

    # Pass 1: trace from every endpoint until the next node is reached.
    ys, xs = np.where((sk==1) & (deg==1))
    for y0,x0 in zip(ys,xs):
        # Skip endpoints already reached as the far end of an earlier trace.
        if visited[y0,x0]: continue
        path = [(y0,x0)]; visited[y0,x0]=1; prev=None; y,x=y0,x0
        while True:
            nxt=None
            # Follow the first neighbour that is not the pixel we came from.
            for nb in neighbors(y,x):
                if nb!=prev: nxt=nb; break
            if nxt is None: break
            prev=(y,x); y,x=nxt; path.append((y,x)); visited[y,x]=1
            if is_node[y,x] and (y,x)!=(y0,x0): break
        if len(path)>=2: paths.append(path)

    # Pass 2: pixels not reached from any endpoint (node-to-node segments, loops).
    ys, xs = np.where((sk==1) & (visited==0))
    for y0,x0 in zip(ys,xs):
        if visited[y0,x0]: continue
        # ``seen`` is local to this trace and stops the walk when a loop closes.
        path=[(y0,x0)]; visited[y0,x0]=1; prev=None; y,x=y0,x0; seen={(y0,x0)}
        while True:
            nxts=[nb for nb in neighbors(y,x) if nb!=prev]
            nxt = nxts[0] if nxts else None
            if nxt is None: break
            prev=(y,x); y,x=nxt
            if (y,x) in seen: path.append((y,x)); break
            path.append((y,x)); visited[y,x]=1; seen.add((y,x))
            if is_node[y,x]: break
        if len(path)>=2: paths.append(path)
    return paths

def _moving_avg(a: np.ndarray, k: int) -> np.ndarray:
    if k<=1: return a.copy()
    k = int(k)|1; pad=k//2
    aa = np.pad(a, (pad,pad), mode='edge')
    c = np.cumsum(aa, dtype=np.float64)
    return ((c[k:] - c[:-k]) / k).astype(np.float32)

def _fit_rotated_rect(points: List[Tuple[int,int]], width_px: float, pad_px: float):
    """Fit an oriented rectangle around a run of centerline points.

    The long side follows the least-squares line through ``points``
    (``cv2.fitLine`` with ``DIST_L2``). It spans the points' extent along that
    line plus ``pad_px`` at each end; the short side is ``width_px``.

    Args:
        points: centerline points as (y, x).
        width_px: rectangle size across the line.
        pad_px: extra length added at each end along the line.

    Returns:
        ``(rect, box)``. ``rect`` is an OpenCV rotated rect
        ``((cx, cy), (length, width), angle_deg)`` in (x, y) image coordinates.
        ``box`` holds its corners from ``cv2.boxPoints`` as float32. With fewer
        than two points the rect is axis-aligned around their mean (or the
        origin when ``points`` is empty).
    """
    P = np.array([[float(x), float(y)] for (y, x) in points], dtype=np.float32)  # (N,2) in (x,y)
    if len(P) < 2:
        p = P.mean(axis=0) if len(P) > 0 else np.array([0.0, 0.0], dtype=np.float32)
        rect = ((float(p[0]), float(p[1])), (max(1.0, 2*pad_px), max(1.0, width_px)), 0.0)
        box = cv2.boxPoints(rect).astype(np.float32)
        return rect, box
    # cv2.fitLine returns a (4, 1) array. Flatten it before converting to
    # Python floats: float() on an ndim>0 array is deprecated in NumPy.
    vx, vy, x0, y0 = (float(c) for c in cv2.fitLine(P, cv2.DIST_L2, 0, 0.01, 0.01).ravel())
    v = np.array([vx, vy], dtype=np.float32)
    v /= (np.linalg.norm(v) + 1e-8)
    p0 = np.array([x0, y0], dtype=np.float32)
    # Project the points onto the line to find their extent along it.
    t = (P - p0) @ v
    tmin, tmax = float(t.min()), float(t.max())
    L = max(1.0, (tmax - tmin) + 2*pad_px)
    center_xy = p0 + v * ((tmin + tmax) / 2.0)
    angle = math.degrees(math.atan2(v[1], v[0]))
    rect = ((float(center_xy[0]), float(center_xy[1])), (float(L), float(width_px)), float(angle))
    box = cv2.boxPoints(rect).astype(np.float32)
    return rect, box

def draw_and_save_stenosis_vis(gray: np.ndarray, box_pts: np.ndarray, pds: float, out_path: Path):
    """Outline the stenosis box on *gray*, label it with the stenosis percentage, and save it.

    ``box_pts`` holds (x, y) polygon corners. The outline is red and 2 px thick.
    ``pds`` is printed as a whole-number percentage at the mean of the corners.
    ``gray`` is converted with COLOR_GRAY2BGR, so it must be single-channel.
    """
    red = (0, 0, 255)
    canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    corners = np.round(box_pts).astype(np.int32)
    cv2.polylines(canvas, [corners], True, red, 2)
    col_x = box_pts[:, 0]
    col_y = box_pts[:, 1]
    anchor = (int(round(col_x.mean())), int(round(col_y.mean())))
    label = f"{pds*100:.0f}%"
    cv2.putText(canvas, label, anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.6, red, 2, cv2.LINE_AA)
    cv2.imwrite(str(out_path), canvas)


# MARK: - Two-stage path finding
def get_paths_from_two_stage(bw: np.ndarray, min_path_length: int = 0) -> Tuple[List[List[Tuple[int,int]]], Dict[str, Any]]:
    """
    Get all paths starting from the first branch point via two-stage path finding.

    VesselPathFinder is constructed from an image file path, so ``bw`` is
    written to a temporary PNG. That file is always deleted again, even if
    path finding fails.

    Args:
        bw: binary mask; it is scaled by 255 before writing, so presumably 0/1 valued.
        min_path_length: keep only paths whose reported ``length`` is >= this value.

    Returns:
        (paths, info):
        - ``paths``: each kept entry's ``info['path']``, as reported by VesselPathFinder.
        - ``info``: a dict with ``main_start_point``, ``leftmost_point``,
          ``total_paths`` (number kept), ``filtered_from`` (number before
          filtering) and ``path_info`` (the kept per-path dicts).

    Raises:
        RuntimeError: if the temporary mask image cannot be written.
    """
    import os
    import tempfile

    # mkstemp + close, rather than writing while a NamedTemporaryFile handle is
    # still open: reopening an open temp file by name is not portable.
    fd, tmp_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    try:
        if not cv2.imwrite(tmp_path, (bw * 255).astype(np.uint8)):
            raise RuntimeError(f"Failed to write temporary mask image: {tmp_path}")
        # Two-stage mode
        finder = VesselPathFinder(tmp_path)
        finder.run(use_two_stage=True)
        paths_data = finder.get_paths_from_first_branch()
    finally:
        # Best-effort cleanup; a failure to delete must not mask the result.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    # Filter by path length
    filtered_info = [info for info in paths_data['path_info'] if info['length'] >= min_path_length]
    filtered_paths = [info['path'] for info in filtered_info]

    result_info = {
        'main_start_point': paths_data['main_start_point'],
        'leftmost_point': paths_data['leftmost_point'],
        'total_paths': len(filtered_paths),
        'filtered_from': paths_data['num_paths'],
        'path_info': filtered_info
    }

    return filtered_paths, result_info




# ---------- simulate stenosis & export box ----------
def simulate_stenosis_and_export(
    gray: np.ndarray,
    skel: np.ndarray,
    radius_map: np.ndarray,
    sten_pos_frac: float = 0.5,
    half_window: int = 12,
    shrink_factor: float = 0.5,
    box_width_scale: float = 2.2,
    box_pad: float = 2.0,
    out_json: Optional[Path] = None,
    out_png: Optional[Path] = None,
    out_png_no_box: Optional[Path] = None,
    given_path: Optional[List[Tuple[int,int]]] = None
) -> Dict[str, Any]:
    """
    Simulate one stenosis along ``given_path`` and export its annotation.

    The radii of the path samples inside a window centred at ``sten_pos_frac``
    along the path are multiplied by ``shrink_factor``. ``radius_map`` is
    modified IN PLACE. The mask is then re-rendered from the edited radius
    map, and an oriented box around the window is exported. Both a version
    with the box drawn and one without are written.

    Args:
        gray: input image; only its shape is used (image size in the JSON).
        skel: unused; kept for interface compatibility.
        radius_map: per-pixel radius map indexed [y, x]; modified in place.
        sten_pos_frac: relative window-centre position along the path (clamped to 0..1).
        half_window: half window length in path samples (window <= 2*half_window+1).
        shrink_factor: multiplier applied to the radii inside the window.
        box_width_scale: box width = max(2, 2 * reference radius * box_width_scale).
        box_pad: padding in px added at both ends of the box along the path.
        out_json: annotation JSON destination, or None.
        out_png: destination for the reconstruction with the box drawn, or None.
        out_png_no_box: destination for the plain reconstruction, or None.
        given_path: ordered list of (y, x) centerline points.

    Returns:
        Annotation dict: box geometry, stenosis index range, reference/minimum
        radius and the approximate percent diameter stenosis.

    Raises:
        RuntimeError: if ``given_path`` is None or has fewer than 3 points.
    """
    path = given_path
    if path is None:
        raise RuntimeError("No centerline path given for stenosis simulation.")

    n = len(path)
    if n < 3:
        raise RuntimeError("Main path too short for stenosis simulation.")
    idx_center = int(np.clip(round(sten_pos_frac * (n-1)), 0, n-1))
    i0 = max(0, idx_center - half_window)
    i1 = min(n-1, idx_center + half_window)
    seg_pts = path[i0:i1+1]

    ys = np.array([p[0] for p in path], dtype=np.int32)
    xs = np.array([p[1] for p in path], dtype=np.int32)
    # Sample and smooth the radii BEFORE the window is shrunk.
    r = radius_map[ys, xs].astype(np.float32)
    r_smooth = _moving_avg(r, 9)

    # Reference radius: median of the smoothed radii 5..25 samples outside the
    # window on both sides (so the narrowing cannot bias it); for short paths
    # (n <= 50) the whole path is used.
    ref_ctx = np.concatenate([r_smooth[max(0,i0-25):max(0,i0-5)], r_smooth[min(n, i1+5):min(n, i1+25)]]) if n>50 else r_smooth
    ref_radius = float(np.median(ref_ctx[ref_ctx>0])) if np.any(ref_ctx>0) else float(np.median(r_smooth[r_smooth>0]))
    if not np.isfinite(ref_radius) or ref_radius <= 0:
        ref_radius = float(np.maximum(1.0, r_smooth.max()))

    # Apply the narrowing directly to radius_map.
    seg_idx = (ys[i0:i1+1], xs[i0:i1+1])
    orig_seg_r = radius_map[seg_idx].copy()
    radius_map[seg_idx] = np.maximum(0.0, orig_seg_r * float(shrink_factor))

    # Percent diameter stenosis, approximated with radii. The window is never
    # empty because n >= 3 and i0 <= idx_center <= i1.
    min_r = float(np.min(radius_map[seg_idx])) if len(seg_pts)>0 else 0.0
    pds = max(0.0, 1.0 - (min_r / (ref_radius + 1e-6)))

    # Rotated box aligned with the centerline.
    width = max(2.0, 2.0 * ref_radius * box_width_scale)
    rect, box = _fit_rotated_rect(seg_pts, width_px=width, pad_px=float(box_pad))

    det = {
        "range_indices": [int(i0), int(i1)],
        "center": [float(rect[0][0]), float(rect[0][1])],   # (x,y)
        "size": [float(rect[1][0]), float(rect[1][1])],     # (L_along, W_across)
        "angle_deg": float(rect[2]),
        "ref_radius_px": float(ref_radius),
        "min_radius_px": float(min_r),
        "percent_diameter_stenosis": float(pds),
        "polygon": np.asarray(box, dtype=np.float32).reshape(-1,2).astype(float).tolist(),
        "shrink_factor": float(shrink_factor),
        "half_window": int(half_window),
        "sten_pos_frac": float(sten_pos_frac)
    }

    if out_json is not None:
        out_json.parent.mkdir(parents=True, exist_ok=True)
        with open(out_json, "w") as f:
            json.dump({
                "image_width": int(gray.shape[1]),
                "image_height": int(gray.shape[0]),
                "stenosis": det
            }, f, ensure_ascii=False, indent=2)

    recon = reconstruct_from_centerline(radius_map)
    recon = (recon*255).astype(np.uint8)

    # Plain segmentation without the box.
    if out_png_no_box is not None:
        out_png_no_box.parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(out_png_no_box), recon)

    # Visualization with the box drawn.
    if out_png is not None:
        out_png.parent.mkdir(parents=True, exist_ok=True)
        draw_and_save_stenosis_vis(recon, np.asarray(det["polygon"], dtype=np.float32), pds, out_png)

    return det


def simulate_stenosis_multi_paths(
    gray: np.ndarray,
    skel: np.ndarray,
    radius_map: np.ndarray,
    bw: np.ndarray,
    min_path_length: int = 50,
    sten_positions: Optional[List[float]] = None,
    half_window: int = 12,
    shrink_factor: float = 0.5,
    box_width_scale: float = 2.2,
    box_pad: float = 2.0,
    out_dir: Optional[Path] = None,
    stem: str = "output",
    top_n_paths: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Simulate independent stenoses on every two-stage vessel path.

    Paths come from ``get_paths_from_two_stage(bw, min_path_length)``; with
    ``top_n_paths`` only the N longest are kept. For every path and every
    fraction in ``sten_positions``, one stenosis is simulated on a fresh copy
    of ``radius_map``, so stenoses never accumulate and ``radius_map`` itself
    is not modified. A failure at one position is reported and skipped.

    Args:
        out_dir: directory for the per-stenosis JSON/PNG files and the summary
            JSON. When None, nothing is written to disk.
        sten_positions: relative positions along each path; defaults to [0.5].
        Other arguments are forwarded to ``simulate_stenosis_and_export``.

    Returns:
        One annotation dict per successful stenosis, extended with
        ``path_index``, ``path_length`` and ``position_frac``.
    """
    if sten_positions is None:
        sten_positions = [0.5]  # default: middle of each path

    # Candidate paths
    paths, path_info = get_paths_from_two_stage(bw, min_path_length)
    print(f"[TwoStage] Found {path_info['total_paths']} paths (filtered from {path_info['filtered_from']}, min_len={min_path_length})")
    print(f"[TwoStage] Main start point: {path_info['main_start_point']}")

    if not paths:
        print("Warning: No valid paths found for stenosis generation")
        return []

    # With top_n_paths, keep only the N longest paths (stable sort: ties keep
    # their original order).
    if top_n_paths is not None and top_n_paths > 0:
        paths_with_length = [(len(path), i, path) for i, path in enumerate(paths)]
        paths_with_length.sort(key=lambda x: x[0], reverse=True)
        selected = paths_with_length[:top_n_paths]
        print(f"[TopN Filter] Selected top {min(top_n_paths, len(paths))} longest paths out of {len(paths)} total paths")
        for rank, (length, orig_idx, _) in enumerate(selected):
            print(f"  Path {rank+1}: length={length} pixels (original index={orig_idx})")
        paths = [path for _, _, path in selected]

    def _out_path(name: str) -> Optional[Path]:
        # out_dir is optional: without it no files are produced.
        return None if out_dir is None else out_dir / name

    all_detections: List[Dict[str, Any]] = []

    # Pristine copy: each stenosis starts from the unmodified radius map.
    original_radius_map = radius_map.copy()
    print("[Independent Mode] Each stenosis will be calculated independently")

    for path_idx, path in enumerate(paths):
        print(f"\n[Path {path_idx+1}/{len(paths)}] Length: {len(path)} pixels")

        for pos_frac in sten_positions:
            prefix = f"{stem}_path{path_idx+1:02d}_pos{pos_frac:.2f}"
            out_json = _out_path(f"{prefix}_stenosis.json")
            out_png = _out_path(f"{prefix}_stenosis_with_box.png")
            out_png_no_box = _out_path(f"{prefix}_stenosis.png")

            # simulate_stenosis_and_export edits the radius map in place.
            radius_map_to_use = original_radius_map.copy()

            try:
                det = simulate_stenosis_and_export(
                    gray, skel, radius_map_to_use,
                    sten_pos_frac=pos_frac,
                    half_window=half_window,
                    shrink_factor=shrink_factor,
                    box_width_scale=box_width_scale,
                    box_pad=box_pad,
                    out_json=out_json,
                    out_png=out_png,
                    out_png_no_box=out_png_no_box,
                    given_path=path
                )
            except Exception as e:
                # Best-effort: one failed position must not abort the others.
                print(f"  Warning: Failed to generate stenosis at position {pos_frac:.2f}: {e}")
                continue

            # Attach path information
            det['path_index'] = path_idx
            det['path_length'] = len(path)
            det['position_frac'] = pos_frac
            all_detections.append(det)

            print(f"  Position {pos_frac:.2f}: PDS={det['percent_diameter_stenosis']*100:.1f}%, "
                  f"ref_r={det['ref_radius_px']:.1f}px, min_r={det['min_radius_px']:.1f}px")

    # Summary JSON
    if out_dir is not None and all_detections:
        summary_json = out_dir / f"{stem}_all_stenosis_summary.json"
        with open(summary_json, 'w') as f:
            json.dump({
                'image_width': int(gray.shape[1]),
                'image_height': int(gray.shape[0]),
                'min_path_length': min_path_length,
                'total_stenosis': len(all_detections),
                'stenosis_list': all_detections
            }, f, ensure_ascii=False, indent=2)
        print(f"\n[Summary] Saved {len(all_detections)} stenosis annotations to {summary_json}")

    return all_detections



# ---------- reconstruction ----------
def reconstruct_from_centerline(radius_map: np.ndarray) -> np.ndarray:
    """Rebuild a binary mask by stamping a filled disc at every positive-radius pixel.

    The disc radius is the pixel's radius rounded to the nearest integer. When
    that rounds to 0, only the pixel itself is set.

    Args:
        radius_map: 2-D radius map; pixels <= 0 are ignored.

    Returns:
        A uint8 mask of the same shape holding 0/1.
    """
    height, width = radius_map.shape
    canvas = np.zeros((height, width), dtype=np.uint8)
    rows, cols = np.nonzero(radius_map > 0)
    for row, col in zip(rows, cols):
        rad = int(max(0, round(float(radius_map[row, col]))))
        if rad <= 0:
            canvas[row, col] = 1
            continue
        cv2.circle(canvas, (int(col), int(row)), rad, 1, thickness=-1, lineType=cv2.LINE_8)
    return canvas

def dice_iou(pred: np.ndarray, gt: np.ndarray):
    """Return ``(dice, iou)`` between two masks, treating every value > 0 as foreground.

    A 1e-8 epsilon in the denominators keeps two empty masks from dividing by
    zero (they score 0.0, 0.0).
    """
    p = pred > 0
    g = gt > 0
    overlap = int(np.count_nonzero(p & g))
    either = int(np.count_nonzero(p | g))
    total = int(np.count_nonzero(p)) + int(np.count_nonzero(g))
    dice = 2 * overlap / (total + 1e-8)
    iou = overlap / (either + 1e-8)
    return float(dice), float(iou)

# ---------- save helpers ----------
def save_centerline_images(gray, skel, out_centerline: Path, out_overlay: Path|None):
    """Write the skeleton as a 0/255 image and, optionally, an overlay on the input.

    The parent directory of ``out_centerline`` is created if needed. The overlay
    (written only when ``out_overlay`` is given) is ``gray`` converted to BGR,
    with skeleton pixels (value 1) painted white.
    """
    out_centerline.parent.mkdir(parents=True, exist_ok=True)
    centerline_img = (skel * 255).astype(np.uint8)
    cv2.imwrite(str(out_centerline), centerline_img)
    if out_overlay is None:
        return
    overlay = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    overlay[skel == 1] = (255, 255, 255)
    cv2.imwrite(str(out_overlay), overlay)

# ---------- CLI ----------
def parse_args():
    """Build and parse the command-line interface.

    Exactly one of ``--input`` / ``--input_dir`` is required, plus ``--out_dir``.
    Gap repair is on by default; ``--no_connect_gaps`` turns it off (with
    ``store_true`` and ``default=True`` alone, ``--connect_gaps`` could never
    be disabled).
    """
    p = argparse.ArgumentParser(description="Extract centerline (+radius), reconstruct, and simulate stenosis with oriented rotated box")
    in_grp = p.add_mutually_exclusive_group(required=True)
    in_grp.add_argument("--input", type=Path, help="Path to a single PNG/JPG mask")
    in_grp.add_argument("--input_dir", type=Path, help="Directory containing mask images")
    p.add_argument("--exts", nargs="*", default=["png","jpg","jpeg","bmp"])
    p.add_argument("--out_dir", type=Path, required=True)

    p.add_argument("--open_iter", type=int, default=1)
    p.add_argument("--spur_len", type=int, default=15)
    p.add_argument("--passes", type=int, default=2)
    p.add_argument("--save_overlay", action="store_true")
    p.add_argument("--save_radius_map", action="store_true",
                   help="Save radius map as .npy (float pixels)")
    p.add_argument("--reconstruct", action="store_true",
                   help="Reconstruct mask from (centerline, radius) and report Dice/IoU")

    # connectivity repair
    p.add_argument("--connect_gaps", action="store_true", default=True,
                   help="Repair broken centerlines by bridging endpoints within a short range (default: on)")
    p.add_argument("--no_connect_gaps", dest="connect_gaps", action="store_false",
                   help="Disable centerline gap repair")
    p.add_argument("--max_bridge_len", type=int, default=10,
                   help="Max gap (px) to bridge between endpoint pairs")
    p.add_argument("--bridge_dilate", type=int, default=1,
                   help="Dilate mask by r px before BFS to help crossing 1px holes")

    # simulate stenosis & export rotated box
    p.add_argument("--simulate_stenosis", action="store_true",
                   help="Simulate a stenosis at a location and export an oriented rotated box (JSON + PNG)")
    p.add_argument("--shrink_factor", type=float, default=0.5,
                   help="Radius shrink factor inside stenosis (0~1, smaller = narrower)")
    p.add_argument("--half_window", type=int, default=12,
                   help="Half window length along centerline for stenosis segment")
    p.add_argument("--sten_pos_frac", type=float, default=0.5,
                   help="Relative position (0~1) along the longest path to place stenosis")
    p.add_argument("--box_width_scale", type=float, default=2.2,
                   help="Width scale (multiplies reference radius to set box width)")
    p.add_argument("--box_pad", type=float, default=2.0,
                   help="Padding (px) added to box length along centerline")
    p.add_argument("--save_stenosis_vis", action="store_true", default=True,
                   help="Save stenosis visualization image with rotated box")

    # multi-path stenosis generation
    p.add_argument("--use_two_stage", action="store_true",
                   help="Use two-stage vessel path finding (from first branch point)")
    p.add_argument("--min_path_length", type=int, default=0,
                   help="Minimum path length to consider for stenosis generation")
    p.add_argument("--multi_path", action="store_true",
                   help="Generate stenosis on multiple paths instead of just the longest one")
    p.add_argument("--sten_positions", type=str, default=None,
                   help="Comma-separated stenosis positions (0-1), e.g., '0.3,0.5,0.7'. If not set, uses --sten_pos_frac")
    p.add_argument("--top_n_paths", type=int, default=None,
                   help="Only process the top N longest paths. If not set, process all paths. E.g., --top_n_paths 2 to process only the 2 longest paths")
    return p.parse_args()

def iter_images(input_dir: Path, exts: Iterable[str]) -> Iterable[Path]:
    """Lazily yield files found recursively under *input_dir*.

    Files are grouped by extension in the order given by *exts*, and sorted
    within each group. Matching uses ``Path.rglob`` with the pattern ``*.<ext>``.
    """
    for suffix in exts:
        pattern = f"*.{suffix}"
        for found in sorted(input_dir.rglob(pattern)):
            yield found

def process_one(img_path: Path, out_dir: Path, args):
    """Run the whole pipeline on one mask image.

    Steps:
    1. Read and Otsu-binarize the image.
    2. Extract the skeleton and radius map.
    3. Save the centerline (plus the optional overlay and radius ``.npy``).
    4. Optionally simulate stenoses on the two-stage paths.
    5. Optionally reconstruct the mask from (centerline, radius) and report
       Dice/IoU against the binarized input.

    Args:
        img_path: input mask image.
        out_dir: output directory. ``np.save`` below does not create it;
            ``main`` creates it beforehand.
        args: namespace returned by ``parse_args``.
    """
    gray, bw = read_mask(img_path)

    skel, radius_map = skeleton_and_radius(
        bw, spur_len=args.spur_len, passes=args.passes,
        connect_gaps=args.connect_gaps,
        max_bridge_len=args.max_bridge_len,
        bridge_dilate=args.bridge_dilate
    )

    stem = img_path.stem
    out_centerline = out_dir / f"{stem}_centerline.png"
    out_overlay    = (out_dir / f"{stem}_overlay.png") if args.save_overlay else None
    save_centerline_images(gray, skel, out_centerline, out_overlay)

    if args.save_radius_map:
        np.save(out_dir / f"{stem}_radius_map.npy", radius_map.astype(np.float32))

    # ---- simulate stenosis & export rotated box (JSON + PNG) ----
    if args.simulate_stenosis:
        # --sten_positions is optional (default None). As its help text says,
        # fall back to the single --sten_pos_frac when it is missing or empty.
        raw_positions = args.sten_positions or ""
        sten_positions = [float(tok) for tok in (s.strip() for s in raw_positions.split(',')) if tok]
        if not sten_positions:
            sten_positions = [float(args.sten_pos_frac)]
        all_dets = simulate_stenosis_multi_paths(
            gray, skel, radius_map, bw,
            min_path_length=args.min_path_length,
            sten_positions=sten_positions,
            half_window=int(args.half_window),
            shrink_factor=float(args.shrink_factor),
            box_width_scale=float(args.box_width_scale),
            box_pad=float(args.box_pad),
            out_dir=out_dir,
            stem=stem,
            top_n_paths=args.top_n_paths
        )
        print(f"[MultiPath] Generated {len(all_dets)} stenosis annotations")

    if args.reconstruct:
        # radius_map is unmodified here: the stenosis simulation works on copies.
        recon = reconstruct_from_centerline(radius_map)
        cv2.imwrite(str(out_dir / f"{stem}_recon.png"), (recon*255).astype(np.uint8))
        d, i = dice_iou(recon, bw)
        with open(out_dir / f"{stem}_metrics.txt", "w") as f:
            f.write(f"Dice: {d:.6f}\nIoU: {i:.6f}\n")
        print(f"Reconstruction Dice={d:.4f}, IoU={i:.4f}")

def main():
    """Command-line entry point.

    Parses the arguments and creates ``--out_dir``. It then runs ``process_one``
    on the single ``--input`` image, or on every matching image under
    ``--input_dir``. A failure on one image is printed and the batch continues.

    Raises:
        FileNotFoundError: if no input image is found.
    """
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)
    if args.input:
        img_paths = [args.input]
    else:
        img_paths = list(iter_images(args.input_dir, args.exts))
    if len(img_paths) == 0:
        raise FileNotFoundError("No input images found.")
    for img_path in img_paths:
        try:
            process_one(img_path, args.out_dir, args)
            print(f"[OK] {img_path}")
        except Exception as exc:  # best-effort batch: report and continue
            print(f"[FAIL] {img_path}: {exc}")

if __name__ == "__main__":
    main()