# assp/vis/vis_runner.py
from __future__ import annotations

from typing import Dict, Optional, Sequence, Tuple

import torch
from torch import nn
from tqdm.auto import tqdm

from assp.core.eval import prompt_and_decoder_eval  # 直接复用你的推理子流程
from assp.core.utils import generate_point, tensor_to_device

from .vis import vis_session


def _get_medsz(args) -> int:
    """Return the working image side length configured on ``args``.

    ``args.medsam_image_size`` wins when it is present and not ``None``;
    otherwise ``args.sammed2d_image_size`` is used (an ``AttributeError`` is
    raised if that attribute is missing too). The value is coerced with
    ``int``. The original note states this mirrors the lookup order used in
    eval.py.
    """
    size = getattr(args, "medsam_image_size", None)
    if size is None:
        size = getattr(args, "sammed2d_image_size")
    return int(size)


@torch.no_grad()
def visualize_downstream_once(
    *,
    args,
    student: nn.Module,
    switcher,
    loader,  # val_down loader (or any DataLoader yielding downstream samples)
    accelerator,
    tag: str = "down_vis",
    out_root: Optional[str] = None,
    top_k: int = 16,
    top_k_by_modality: Optional[Dict[str, int]] = None,
    want_stages: Tuple[str, ...] = ("box", "point1", "point3", "point5"),
) -> None:
    """Run one pass over ``loader`` and feed predictions to a ``vis_session``.

    For every batch the image encoder runs once and the decoder runs once per
    requested stage: a box-prompted prediction (stage ``"box"``) and an
    iterative point-prompted prediction observed after 1, 3 and 5 decoding
    rounds (stages ``"point1"``, ``"point3"``, ``"point5"``). Each prediction
    is handed to ``sess.observe`` and ``sess.finalize(accelerator)`` is called
    once at the end. Top-K selection and writing to disk are delegated to the
    session (per the original notes, files are written on the main process
    only -- confirm in ``vis_session``).

    Args:
        args: Config namespace; read for the image size (via ``_get_medsz``),
            for ``point_num`` (default 9) and passed to
            ``prompt_and_decoder_eval``.
        student: Model to visualise; unwrapped with
            ``accelerator.unwrap_model``.
        switcher: Object whose ``to_interp_prompt()`` is called when present.
        loader: Iterable of batches providing ``"label"``, ``"input_image"``
            and optionally ``"boxes"`` / ``"point_coords"`` +
            ``"point_labels"``.
        accelerator: Supplies the device, autocast context and main-process
            flag.
        tag: Session tag, also shown in the progress-bar description.
        out_root: Forwarded to ``vis_session``.
        top_k: Forwarded to ``vis_session``.
        top_k_by_modality: Forwarded to ``vis_session``.
        want_stages: Stages to decode and collect. Recognised names are
            ``"box"``, ``"point1"``, ``"point3"`` and ``"point5"``; other
            names are forwarded to the session but never observed here.

    Side effects:
        The unwrapped model is switched to ``eval()`` mode (and, when
        available, ``cancel_skip_adapter_mod`` / ``to_interp_prompt`` are
        invoked); none of this is reverted on return.
    """
    medsz = _get_medsz(args)
    device = accelerator.device

    # Switch to the "downstream path": adapter enabled + interp prompt.
    stu = accelerator.unwrap_model(student)
    stu.eval()
    if hasattr(stu.image_encoder, "cancel_skip_adapter_mod"):
        stu.image_encoder.cancel_skip_adapter_mod()
    if hasattr(switcher, "to_interp_prompt"):
        switcher.to_interp_prompt()

    # Honour the caller's stage selection instead of always running all four
    # stages; with the default argument this is identical to running them all.
    want_box = "box" in want_stages
    want_points: Sequence[int] = tuple(k for k in (1, 3, 5) if f"point{k}" in want_stages)
    # No need to iterate past the largest requested point stage.
    max_points = max(want_points, default=0)

    sess = vis_session(
        tag=tag,
        out_root=out_root,
        top_k=top_k,
        want_stages=want_stages,
        top_k_by_modality=top_k_by_modality,
    )

    # Gradients are already disabled by the decorator; only autocast is needed.
    with accelerator.autocast():
        pbar = tqdm(loader, desc=f"vis_{tag} (Box+Point)", disable=not accelerator.is_main_process)
        for batch in pbar:
            batch = tensor_to_device(batch, device)
            labels_bm = batch["label"]  # leading dims are (B, M); presumably [B,M,H,W]
            B, M = labels_bm.shape[:2]
            N = B * M
            labels = labels_bm.reshape(N, 1, medsz, medsz)

            # Encode once, decode many times: the 4-D per-image embedding is
            # repeated M times so each of the N = B*M targets has its own copy.
            emb_single = stu.image_encoder(batch["input_image"])[0]
            image_embeddings = emb_single.unsqueeze(1).expand(-1, M, -1, -1, -1).reshape(N, *emb_single.shape[1:])

            # ========== Box stage ==========
            if want_box and "boxes" in batch and batch["boxes"] is not None:
                prompts_box = {
                    "boxes": batch["boxes"].reshape(N, -1),
                    "point_coords": None,
                    "point_labels": None,
                }
                masks_box, _, _ = prompt_and_decoder_eval(args, prompts_box, stu, image_embeddings)
                sess.observe(batch=batch, pred_masks=masks_box, gt_masks=labels, prompts=prompts_box, stage="box", M=M)

            # ========== Point stages (iterative refinement) ==========
            if max_points > 0 and "point_coords" in batch and batch["point_coords"] is not None:
                prompts_pt = {
                    "point_coords": batch["point_coords"].reshape(N, -1, 2),
                    "point_labels": batch["point_labels"].reshape(N, -1),
                    "boxes": None,
                }
                # History of every point set so far; concatenated along dim 1
                # before each subsequent decode.
                pc_hist = [prompts_pt["point_coords"]]
                pl_hist = [prompts_pt["point_labels"]]

                # Decode step by step and collect at the requested step counts.
                for it in range(max_points):
                    cur_pts = it + 1
                    masks_pt, low_res, _ = prompt_and_decoder_eval(args, prompts_pt, stu, image_embeddings)
                    if cur_pts in want_points:
                        stage_name = f"point{cur_pts}"
                        sess.observe(batch=batch, pred_masks=masks_pt, gt_masks=labels, prompts=prompts_pt, stage=stage_name, M=M)

                    if it < max_points - 1:
                        # Generate the prompts for the next step from the
                        # current prediction.
                        prompts_pt = generate_point(masks_pt, labels, low_res, prompts_pt, getattr(args, "point_num", 9))
                        prompts_pt = tensor_to_device(prompts_pt, device)

                        # Accumulate with all earlier points (the original note
                        # says this matches evaluate_downstream).
                        pc_hist.append(prompts_pt["point_coords"])
                        pl_hist.append(prompts_pt["point_labels"])
                        prompts_pt["point_coords"] = torch.cat(pc_hist, dim=1)
                        prompts_pt["point_labels"] = torch.cat(pl_hist, dim=1)

    # The session receives the accelerator and decides which process writes.
    sess.finalize(accelerator)
