# assp/vis/vis.py
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib

matplotlib.use("Agg")  # 服务器无显示环境时可安全绘图
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle

# ====== Visualization style kept consistent with the hook (de-normalization / colors / overlay) ======

# Per-channel mean and std on the 0-255 pixel scale; _denorm_image applies
# `x * _IM_STD + _IM_MEAN` over the last (channel) axis and then divides by 255.
# NOTE(review): the values match the common ImageNet RGB statistics — confirm channel order.
_IM_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
_IM_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)


def _to_numpy(x: torch.Tensor | np.ndarray) -> np.ndarray:
    """Return ``x`` as a NumPy array.

    A ``torch.Tensor`` is detached from the autograd graph, moved to the CPU
    and converted; any other object is returned unchanged.
    """
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x


def _denorm_image(chw: torch.Tensor | np.ndarray) -> np.ndarray:
    """Undo normalization: CHW (float32, normalized) -> HWC float in [0, 1].

    Kept identical to the hook's de-normalization.

    Raises:
        ValueError: if the input is not 3-dimensional.
    """
    pixels = _to_numpy(chw).astype(np.float32)  # C,H,W
    if pixels.ndim != 3:
        raise ValueError(f"expect CHW image, got shape={pixels.shape}")
    hwc = np.transpose(pixels, (1, 2, 0))  # H,W,C
    restored = hwc * _IM_STD + _IM_MEAN  # back to the 0-255 pixel scale
    return np.clip(restored, 0.0, 255.0) / 255.0  # HWC in [0,1]


def _bin_from_logit(logit_hw: torch.Tensor | np.ndarray) -> np.ndarray:
    """Threshold logits at 0 into a uint8 binary mask.

    Accepts HxW, 1xHxW or Nx1xHxW (only channel 0 of a 4-D input is used).
    A leading axis of size 1 is squeezed away; a 3-D result with a leading
    axis other than 1 is returned as N,H,W for the caller to handle per sample.

    Raises:
        ValueError: if the array is not 2-D, 3-D or 4-D.
    """
    logits = _to_numpy(logit_hw)
    if logits.ndim == 4:
        logits = logits[:, 0]  # N,H,W
    if logits.ndim not in (2, 3):
        raise ValueError(f"unexpected pred mask shape={logits.shape}")
    if logits.ndim == 3 and logits.shape[0] == 1:
        logits = logits[0]  # H,W
    return (logits > 0).astype(np.uint8)


def _to_bool_mask(x: torch.Tensor | np.ndarray) -> np.ndarray:
    """Convert a ground-truth mask to a 2-D uint8 array of 0/1 values.

    Accepts HxW, 1xHxW or Nx1xHxW with N == 1 (channel 0 is used for 4-D
    input). Every strictly positive value becomes 1, everything else 0.

    Raises:
        ValueError: if the mask cannot be reduced to a single 2-D array.
    """
    arr = np.asarray(_to_numpy(x))
    if arr.ndim == 4:  # N,1,H,W
        arr = arr[:, 0]
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2:
        raise ValueError(f"unexpected gt mask shape={arr.shape}")
    # Threshold in the original dtype and cast afterwards: casting to uint8
    # first truncates fractional foreground values (e.g. 1/255 or 0.5) to 0
    # and can wrap negative values around to 255.
    return (arr > 0).astype(np.uint8)


def _dice_from_bin(pred_bin: np.ndarray, gt_bin: np.ndarray, eps: float = 1e-7) -> float:
    """Return the Dice coefficient of two binary integer masks.

    Computes ``(2 * |pred & gt| + eps) / (|pred| + |gt| + eps)``; ``eps`` keeps
    the ratio defined (and equal to 1.0) when both masks are empty.
    """
    overlap = np.bitwise_and(pred_bin, gt_bin).sum(dtype=np.int64)
    total = np.sum(pred_bin, dtype=np.int64) + np.sum(gt_bin, dtype=np.int64)
    numerator = 2.0 * overlap + eps
    return float(numerator / (total + eps))


def _parse_modality_from_name(name: str) -> str:
    """
    Extract the modality key from a sample name.

    Naming rule:
      images/{mod}--{dataset}--{ori}--{slice}.png
    The part of the file stem before the first ``--`` is the modality key;
    a stem without ``--`` is returned whole as a fallback.
    """
    # partition() yields the full stem as its head when the separator is absent.
    modality, _sep, _rest = Path(name).stem.partition("--")
    return modality


@dataclass
class StageView:
    """Outcome of one prompting stage for one sample: score, masks and the prompts drawn on top."""

    dice: float  # Dice score of pred_bin against gt_bin
    pred_bin: np.ndarray  # H x W, uint8 (thresholded prediction)
    gt_bin: np.ndarray  # H x W, uint8 (binarized ground truth)
    points: Optional[List[Tuple[float, float, int]]] = None  # [(x,y,label), ...]; None when no point prompt
    box: Optional[Tuple[float, float, float, float]] = None  # (x0,y0,x1,y1); None when no box prompt


@dataclass
class SampleRecord:
    """One sample tracked for visualization, with its per-stage results."""

    key: str  # unique key of the sample inside a session
    name: str  # sample name as provided by the batch
    modality: str  # modality key parsed from the name
    image: np.ndarray  # H x W x 3, float in [0,1]
    stages: Dict[str, StageView] = field(default_factory=dict)  # stage name -> result

    def combined_score(self, want: Tuple[str, ...]) -> float:
        """Return the mean Dice over the stages in ``want`` that were recorded.

        Returns -1.0 when none of the requested stages is present.
        """
        dice_values = []
        for stage_name in want:
            if stage_name in self.stages:
                dice_values.append(self.stages[stage_name].dice)
        return float(np.mean(dice_values)) if dice_values else -1.0


class Session:
    """
    Maintain a per-modality Top-K pool of samples to visualize.

      - want_stages: stages whose Dice scores are averaged to rank a sample
        (default: box / point1 / point3 / point5).
      - top_k is the pool size of every modality unless a modality has its own
        entry in top_k_by_modality.
    """

    def __init__(
        self,
        tag: str,
        out_root: Optional[str | Path] = None,
        top_k: int = 16,
        want_stages: Tuple[str, ...] = ("box", "point1", "point3", "point5"),
        top_k_by_modality: Optional[Dict[str, int]] = None,
    ) -> None:
        """Create the session and its output directory ``<out_root>/<tag>``.

        Args:
            tag: sub-directory name under ``out_root``.
            out_root: output root; defaults to ``./visualizations``.
            top_k: default pool size per modality.
            want_stages: stage names averaged into the ranking score.
            top_k_by_modality: optional per-modality pool sizes.
        """
        self.tag = tag
        self.top_k = int(top_k)
        self.want = tuple(want_stages)
        self.top_k_by_modality = dict(top_k_by_modality) if top_k_by_modality else {}

        if out_root is None:
            out_root = "./visualizations"
        self.root = Path(out_root) / str(tag)
        self.root.mkdir(parents=True, exist_ok=True)

        # One TOP pool per modality: modality -> { key: (score, record) }
        self._top_by_mod: Dict[str, Dict[str, Tuple[float, SampleRecord]]] = {}

        # Records seen so far: key -> record.
        # NOTE(review): entries are never evicted, so every observed sample
        # (including its image) stays in memory for the session's lifetime.
        # A safe eviction policy depends on the order in which callers feed
        # stages — confirm before adding one.
        self._active: Dict[str, SampleRecord] = {}

    @staticmethod
    def _safe_filename_part(text: str) -> str:
        """Replace path separators so ``text`` is usable inside a single file name."""
        return str(text).replace("/", "_").replace("\\", "_")

    # ---------- active record management ----------
    def _ensure_active(self, key: str, name: str, modality: str, img_chw: torch.Tensor | np.ndarray) -> SampleRecord:
        """Return the record stored under ``key``, creating it (and de-normalizing the image) on first use."""
        rec = self._active.get(key)
        if rec is None:
            rec = SampleRecord(
                key=key,
                name=name,
                modality=modality,
                image=_denorm_image(img_chw),
            )
            self._active[key] = rec
        return rec

    def _commit_record(self, rec: SampleRecord) -> None:
        """Insert ``rec`` into, or refresh it inside, its modality's Top-K pool.

        Records without any wanted stage (negative score) are ignored. When the
        pool is full the lowest-scored entry is replaced only by a strictly
        better record.
        """
        score = rec.combined_score(self.want)
        if score < 0:
            return

        bucket = self._top_by_mod.get(rec.modality)
        if bucket is not None and rec.key in bucket:
            # The pooled record is mutated in place as stages arrive, so the
            # stored score must follow it even when the mean goes down;
            # otherwise ranking and manifest disagree with the stages shown.
            bucket[rec.key] = (score, rec)
            return

        k_lim = self.top_k_by_modality.get(rec.modality, self.top_k)
        if k_lim <= 0:
            # Nothing may be kept for this modality (also avoids min() on an empty pool).
            return

        if bucket is None:
            bucket = self._top_by_mod.setdefault(rec.modality, {})
        if len(bucket) < k_lim:
            bucket[rec.key] = (score, rec)
            return

        worst_key = min(bucket, key=lambda k: bucket[k][0])
        if score > bucket[worst_key][0]:
            del bucket[worst_key]
            bucket[rec.key] = (score, rec)

    # ---------- public API: observe the results of one stage (batched) ----------
    def observe(
        self,
        *,
        batch: Dict,
        pred_masks: torch.Tensor,  # [N,1,H,W] logits
        gt_masks: torch.Tensor,  # [N,1,H,W]
        prompts: Dict,  # may contain "boxes" and/or "point_coords" + "point_labels"
        stage: str,  # "box" | "point1" | "point3" | "point5"
        M: int,  # masks per image in batch["label"]: [B,M,H,W]
    ) -> None:
        """Record the predictions of one stage for every mask in the batch.

        Mask ``i`` belongs to image ``i // M``. For each image the best-scoring
        mask of the stage is kept, and the record is pushed to the Top-K pool
        after every update.

        Raises:
            ValueError: if ``pred_masks`` and ``gt_masks`` disagree on N.
        """
        imgs = batch.get("input_image")  # [B,3,H,W]
        names = batch.get("name", None)  # list[str] or None
        ids = batch.get("sample_id", None)  # optional

        # Tensor/array ids are expanded so every sample gets its own key
        # instead of the string form of the whole batch.
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()

        if isinstance(names, (list, tuple)) and len(names) == 1:
            names = names[0]
        if isinstance(ids, (list, tuple)) and len(ids) == 1:
            ids = ids[0]

        N = int(pred_masks.shape[0])
        if gt_masks.shape[0] != N:
            # Explicit check: an assert would be stripped under `python -O`.
            raise ValueError(f"pred_masks has {N} masks but gt_masks has {int(gt_masks.shape[0])}")

        # Parse the prompts (used to draw boxes / points)
        boxes = None
        if prompts is not None and "boxes" in prompts and prompts["boxes"] is not None:
            boxes = _to_numpy(prompts["boxes"])
        pts_all = None
        if prompts is not None and "point_coords" in prompts and prompts["point_coords"] is not None:
            pts_all = (_to_numpy(prompts["point_coords"]), _to_numpy(prompts["point_labels"]))

        per_image = isinstance(M, int) and M > 0
        for i in range(N):
            # Index of the owning image along the batch (B) axis
            b_idx = i // M if per_image else 0

            # Name / key of the sample
            if names is None:
                name_i = f"sample_{b_idx}"
            elif isinstance(names, (list, tuple)):
                name_i = str(names[b_idx])
            else:
                name_i = str(names)
            if ids is None:
                key_i = name_i
            elif isinstance(ids, (list, tuple)):
                key_i = str(ids[b_idx])
            else:
                key_i = str(ids)
            modality = _parse_modality_from_name(name_i)

            # Activate the record; a batched (4-D) image stack is indexed, a
            # single CHW image is passed through as is.
            img_chw_i = imgs[b_idx] if getattr(imgs, "ndim", None) == 4 else imgs
            rec = self._ensure_active(key_i, name_i, modality, img_chw_i)

            # Prepare mask / gt / prompts
            pred_bin = _bin_from_logit(pred_masks[i])  # H,W
            gt_bin = _to_bool_mask(gt_masks[i])  # H,W

            box_one = None
            if boxes is not None:
                if boxes.ndim == 2:  # N,4
                    box_one = tuple(map(float, boxes[i].tolist()))
                elif boxes.ndim == 1 and N == 1:
                    box_one = tuple(map(float, boxes.tolist()))

            points = None
            if pts_all is not None:
                pc_all, pl_all = pts_all
                pts_one = None
                if pc_all.ndim == 3:  # N,P,2
                    pts_one = (pc_all[i], pl_all[i])
                elif pc_all.ndim == 2 and N == 1:
                    pts_one = (pc_all, pl_all)
                if pts_one is not None:
                    points = [(float(x), float(y), int(lab)) for (x, y), lab in zip(pts_one[0], pts_one[1])]

            dice = _dice_from_bin(pred_bin, gt_bin)
            st = StageView(
                dice=dice,
                pred_bin=pred_bin,
                gt_bin=gt_bin,
                points=points,
                box=box_one,
            )

            # Keep the best mask of this stage for the image
            old = rec.stages.get(stage)
            if old is None or dice >= old.dice:
                rec.stages[stage] = st

            # Push to the Top pool after every update so weak samples are
            # dropped progressively, even on large datasets.
            self._commit_record(rec)

    # ---------- finalize: write each modality's TOP-K to disk as six-panel figures ----------
    def finalize(self, accelerator) -> None:
        """Render every pooled record and write one ``manifest.json`` per modality.

        Does nothing unless ``accelerator.is_main_process`` is truthy (a missing
        attribute counts as the main process).
        """
        if not getattr(accelerator, "is_main_process", True):
            return

        for modality, bucket in self._top_by_mod.items():
            out_dir = self.root / self._safe_filename_part(modality)
            out_dir.mkdir(parents=True, exist_ok=True)

            ranked = sorted(bucket.items(), key=lambda kv: kv[1][0], reverse=True)
            manifest = []
            for rank_idx, (key, (score, rec)) in enumerate(ranked):
                # Names/keys may carry directory parts (e.g. "images/xxx.png");
                # left as is they would point into a non-existent sub-directory.
                file_name = f"{rank_idx:03d}_{self._safe_filename_part(rec.name)}_{self._safe_filename_part(key)}.png"
                out_path = out_dir / file_name
                self._render_six_panel(rec, out_path)
                manifest.append(
                    dict(
                        key=key,
                        name=rec.name,
                        modality=rec.modality,
                        score=score,
                        stages={s: view.dice for s, view in rec.stages.items()},
                        file=str(out_path.name),
                    )
                )

            with open(out_dir / "manifest.json", "w", encoding="utf-8") as f:
                json.dump(manifest, f, ensure_ascii=False, indent=2)

    # ---------- six-panel figure (Input / GT / Box / 1 / 3 / 5) ----------
    def _render_six_panel(self, rec: SampleRecord, out_png: Path) -> None:
        """Draw the 2x3 overview of ``rec`` and save it to ``out_png``."""
        order = ("box", "point1", "point3", "point5")
        title_map = {"box": "Box", "point1": "1 point", "point3": "3 points", "point5": "5 points"}

        fig, axes = plt.subplots(2, 3, figsize=(12, 8), dpi=150)
        try:
            ax_input, ax_gt, ax_box, ax_p1, ax_p3, ax_p5 = axes.reshape(-1)

            # 1) Input
            ax_input.axis("off")
            ax_input.imshow(rec.image)
            ax_input.set_title("Input", fontsize=9)

            # Take gt_bin from the first recorded stage (all stages are
            # expected to share the same ground truth).
            first_stage = next(iter(rec.stages.values()), None)
            any_gt = first_stage.gt_bin if first_stage is not None else None

            # 2) GT
            ax_gt.axis("off")
            ax_gt.imshow(rec.image)
            if any_gt is not None:
                self._draw_contour(ax_gt, any_gt)  # lime contour
            ax_gt.set_title("GT", fontsize=9)

            # 3~6) prediction panels: Box / 1 / 3 / 5
            panes = [ax_box, ax_p1, ax_p3, ax_p5]
            for ax, stage_name in zip(panes, order):
                ax.axis("off")
                ax.imshow(rec.image)
                st = rec.stages.get(stage_name)
                if st is None:
                    ax.set_title(f"{title_map[stage_name]} | N/A", fontsize=9)
                    continue

                # Semi-transparent prediction overlay (same as the hook)
                ax.imshow(np.ma.masked_where(st.pred_bin == 0, st.pred_bin), alpha=0.35)
                # GT contour
                self._draw_contour(ax, st.gt_bin)

                # Points / box (same colors and styles as the hook)
                if st.points is not None:
                    for x, y, lab in st.points:
                        if int(lab) == 1:
                            ax.scatter([x], [y], marker="o", s=14, facecolors="none", edgecolors="yellow", linewidths=1.5)
                        else:
                            ax.scatter([x], [y], marker="x", s=12, c="yellow", linewidths=1.0)
                if st.box is not None:
                    x0, y0, x1, y1 = st.box
                    ax.add_patch(Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor="gold", linewidth=1.5))

                ax.set_title(f"{title_map[stage_name]} | Dice={st.dice:.3f}", fontsize=9)

            # Address the figure directly instead of pyplot's "current figure".
            fig.suptitle(f"{rec.name} [{rec.modality}] | AvgDice={rec.combined_score(self.want):.3f}", fontsize=10)
            fig.tight_layout(rect=[0, 0.03, 1, 0.97])
            fig.savefig(str(out_png))
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

    @staticmethod
    def _draw_contour(ax, mask_bin: np.ndarray) -> None:
        """Best-effort: outline ``mask_bin`` at level 0.5 in lime on ``ax``."""
        try:
            mask = np.asarray(mask_bin, dtype=np.float32)
            if mask.size == 0 or mask.min() == mask.max():
                # A constant mask has no 0.5 level crossing; nothing to draw.
                return
            ax.contour(mask, levels=[0.5], colors=["lime"], linewidths=1.0)
        except Exception as exc:
            # Deliberately non-fatal: a failed outline must not abort the
            # figure, but the failure is reported instead of being hidden.
            import warnings

            warnings.warn(f"GT contour skipped: {exc}", RuntimeWarning, stacklevel=2)


# ---------------- public entry point ----------------
def vis_session(
    *,
    tag: str,
    out_root: Optional[str | Path] = None,
    top_k: int = 16,
    want_stages: Tuple[str, ...] = ("box", "point1", "point3", "point5"),
    top_k_by_modality: Optional[Dict[str, int]] = None,
) -> Session:
    """Build a ``Session`` from keyword-only arguments, forwarded unchanged."""
    session = Session(
        tag=tag,
        out_root=out_root,
        top_k=top_k,
        want_stages=want_stages,
        top_k_by_modality=top_k_by_modality,
    )
    return session
