# train.py
import copy
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from .eval import evaluate_downstream, evaluate_pretrain
from .loss import FocalDiceloss_IoULoss
from .utils import generate_point, set_trainable_parts, tensor_to_device


def _expand_BM(emb_single: torch.Tensor, M: int) -> torch.Tensor:
    """Repeat every sample ``M`` times along the batch axis: [B, C, H, W] -> [B*M, C, H, W].

    Output row ``b * M + m`` equals ``emb_single[b]`` for every ``m`` (the ``M``
    copies of one sample are adjacent). Works for any rank >= 1; only dim 0 grows.
    """
    return torch.repeat_interleave(emb_single, M, dim=0)


def compensate_unified_distillation(
    *,
    args,
    student: nn.Module,  # student: SAM-Med2D
    teacher_anchor: nn.Module,  # anchor (meta-teacher)
    switcher,  # SAMPartSwitcher
    loaders: Dict[str, object],  # must contain "train_pre", "train_down"; optional "val_pre", "val_down"
    accelerator,
    epoch_num: int,
    stage_name: Optional[str] = None,  # reserved placeholder, unused
    pruning_mode: Optional[str] = None,  # reserved placeholder, unused
    teacher_prev: Optional[nn.Module] = None,
    w_pre: Optional[float] = None,
    w_down: Optional[float] = None,
    lr: Optional[float] = None,
    ddp_no_sync_first_backward: bool = True,  # multi-process only: skip the gradient all-reduce on the PRE backward
    logger=None,  # explicitly passed logger (silenced on non-main ranks by the system logger)
    best_ckpt_path_unified: Optional[str] = None,  # where to save the best state_dict (prefer an absolute path)
    reload_best_at_end: bool = True,  # reload the best weights on every rank when the stage ends
    save_full_model: bool = True,  # additionally pickle the whole model on rank 0
) -> Dict[str, float]:
    """Unified distillation compensation of ``student.image_encoder``.

    Each training step consumes one batch from ``train_pre`` and one from
    ``train_down`` (zipped, so the shorter loader bounds the epoch):

    * PRE phase: adapters skipped (``set_skip_adapter_mod``), original prompt
      encoder (``switcher.to_orig_prompt``), ``set_trainable_parts`` enables only
      the non-adapter encoder weights. Its backward may skip the DDP all-reduce
      (see ``ddp_no_sync_first_backward``); the gradients are accumulated locally.
    * DOWN phase: adapters enabled (``cancel_skip_adapter_mod``), interpolated
      prompt encoder (``switcher.to_interp_prompt``), only adapter weights trainable.
    * One ``optimizer.step()`` after both backwards.

    Loss per phase: MSE between the first output of the student encoder and of the
    anchor encoder. While ``(ep + 1) / epoch_num <= 0.25``, the four intermediate
    outputs are also aligned to the mentor with weight 0.2 each. Features whose
    shapes differ are skipped. The mentor is ``teacher_prev`` when given, otherwise
    ``teacher_anchor``.

    After each epoch the student is validated on ``val_pre`` / ``val_down`` (when
    present). The mean Dice of the integer-keyed (multi-point) entries returned by
    ``evaluate_downstream`` selects the best epoch. Its state_dict is saved to
    ``best_ckpt_path_unified`` (rank 0). With ``reload_best_at_end``, every rank
    reloads it when the stage ends.

    Returns:
        ``{"loss_pre": ..., "loss_down": ...}``: the unweighted losses of the last batch.
    """
    device = accelerator.device
    train_pre_dl = loaders.get("train_pre")
    train_down_dl = loaders.get("train_down")
    val_pre_dl = loaders.get("val_pre")
    val_down_dl = loaders.get("val_down")

    # Hyper-parameters: explicit arguments win, otherwise read args (reference defaults).
    w_pre = float(w_pre if w_pre is not None else getattr(args, "w_pre", 1.0))
    w_down = float(w_down if w_down is not None else getattr(args, "w_down", 1.0))
    lr = float(lr if lr is not None else getattr(args, "lr", 1e-4))

    if logger is not None:
        logger.info(
            f"[UnifiedDistill] start | epochs={epoch_num}, w_pre={w_pre}, w_down={w_down}, lr={lr}, mentor={'on' if teacher_prev is not None else 'off'}"
        )

    mse = nn.MSELoss()

    # Only the image encoder is optimized (as in the reference).
    enc = student.image_encoder
    optimizer = torch.optim.Adam(enc.parameters(), lr=lr)

    enc_prep, optimizer, train_pre_dl, train_down_dl, val_pre_dl, val_down_dl = accelerator.prepare(
        enc, optimizer, train_pre_dl, train_down_dl, val_pre_dl, val_down_dl
    )

    # Loop invariants, hoisted out of the batch loop.
    teacher_mentor = teacher_prev if teacher_prev is not None else teacher_anchor
    has_separate_mentor = teacher_mentor is not teacher_anchor
    # PRE gradients stay local; the DOWN backward then performs the single all-reduce.
    use_no_sync = bool(ddp_no_sync_first_backward) and accelerator.num_processes > 1

    metrics_list = list(getattr(args, "metrics", []))
    iou_idx = metrics_list.index("iou") if "iou" in metrics_list else -1
    dice_idx = metrics_list.index("dice") if "dice" in metrics_list else -1

    def _input_image(batch) -> torch.Tensor:
        # A batch is either a bare image tensor or a dict carrying "input_image".
        img = batch if torch.is_tensor(batch) else batch["input_image"]
        return img.to(device=device, dtype=torch.float32)

    def _teacher_targets(img: torch.Tensor):
        """Return (anchor main output, mentor's four intermediate outputs), computed without grad."""
        with torch.no_grad():
            a_main, a_attn_mid, a_attn, a_mlp_mid, a_block, _, _, _ = teacher_anchor.image_encoder(img)
            if has_separate_mentor:
                _, m_attn_mid, m_attn, m_mlp_mid, m_block, _, _, _ = teacher_mentor.image_encoder(img)
            else:
                # Mentor is the anchor: reuse its tensors instead of a second forward pass.
                m_attn_mid, m_attn, m_mlp_mid, m_block = a_attn_mid, a_attn, a_mlp_mid, a_block
        return a_main, (m_attn_mid, m_attn, m_mlp_mid, m_block)

    def _distill_loss(stu_out, anchor_main: torch.Tensor, mentor_mids, align_mid: bool) -> torch.Tensor:
        """Main MSE vs the anchor, plus 0.2-weighted MSE on shape-matching intermediates when align_mid."""
        s_main, s_attn_mid, s_attn, s_mlp_mid, s_block, _, _, _ = stu_out
        loss = mse(s_main, anchor_main)
        if align_mid:
            # Same term order as the reference: attn_mid, attn, mlp_mid, block.
            for s_feat, t_feat in zip((s_attn_mid, s_attn, s_mlp_mid, s_block), mentor_mids):
                if s_feat.shape == t_feat.shape:
                    loss = loss + 0.2 * mse(s_feat, t_feat)
        return loss

    def _fmt_metrics(arr) -> str:
        """Format IoU/Dice of a metric vector as 'IoU=.. | Dice=..', or 'metrics=N/A'."""
        bits = []
        if isinstance(arr, np.ndarray):
            if iou_idx != -1 and arr.size > iou_idx:
                bits.append(f"IoU={float(arr[iou_idx]):.4f}")
            if dice_idx != -1 and arr.size > dice_idx:
                bits.append(f"Dice={float(arr[dice_idx]):.4f}")
        return " | ".join(bits) if bits else "metrics=N/A"

    last_loss_pre, last_loss_down = 0.0, 0.0

    # Best-of-stage tracking (criterion: mean multi-point Dice, "DiceAvgPts").
    best_metric = float("-inf")
    best_epoch = -1

    student.train()
    teacher_anchor.eval()
    if teacher_prev is not None:
        teacher_prev.eval()

    for ep in range(epoch_num):
        # Make distributed sampling reproducible per epoch (as in the reference).
        for dl in (train_pre_dl, train_down_dl):
            if hasattr(dl, "sampler") and hasattr(dl.sampler, "set_epoch"):
                dl.sampler.set_epoch(ep)

        # Intermediate-feature alignment only during the first quarter of the stage.
        align_mid = float(ep + 1) / float(max(1, epoch_num)) <= 0.25

        # zip both loaders; the shorter one bounds the epoch (as in the reference).
        num_steps = min(len(train_pre_dl), len(train_down_dl))
        pbar = tqdm(
            zip(train_pre_dl, train_down_dl),
            total=num_steps,
            desc=f"[Unified] Epoch {ep + 1}/{epoch_num}",
            disable=not accelerator.is_main_process,  # progress bar on the main process only
        )
        for batch_pre, batch_down in pbar:
            optimizer.zero_grad(set_to_none=True)

            # ========= PRE phase: adapters skipped + original prompt =========
            set_trainable_parts(
                accelerator.unwrap_model(student),
                train_image_encoder_adapter=False,
                train_image_encoder_non_adapter=True,
                train_prompt_encoder=False,
                train_mask_decoder=False,
            )
            student.image_encoder.set_skip_adapter_mod()
            switcher.to_orig_prompt()

            img_pre = _input_image(batch_pre)
            with accelerator.autocast():
                stu_out_pre = enc_prep(img_pre)

            # Teachers in the same (adapter-skipped) mode.
            teacher_anchor.image_encoder.set_skip_adapter_mod()
            if has_separate_mentor:
                teacher_mentor.image_encoder.set_skip_adapter_mod()
            anchor_main_pre, mentor_mids_pre = _teacher_targets(img_pre)

            loss_pre = _distill_loss(stu_out_pre, anchor_main_pre, mentor_mids_pre, align_mid)
            with accelerator.no_sync(enc_prep) if use_no_sync else nullcontext():
                accelerator.backward(w_pre * loss_pre)

            # ========= DOWN phase: adapters enabled + interpolated prompt =========
            set_trainable_parts(
                accelerator.unwrap_model(student),
                train_image_encoder_adapter=True,
                train_image_encoder_non_adapter=False,
                train_prompt_encoder=False,
                train_mask_decoder=False,
            )
            student.image_encoder.cancel_skip_adapter_mod()
            switcher.to_interp_prompt()

            img_down = _input_image(batch_down)

            # Teachers in the same (adapter-enabled) mode.
            teacher_anchor.image_encoder.cancel_skip_adapter_mod()
            if has_separate_mentor:
                teacher_mentor.image_encoder.cancel_skip_adapter_mod()

            with accelerator.autocast():
                stu_out_down = enc_prep(img_down)
            anchor_main_down, mentor_mids_down = _teacher_targets(img_down)

            loss_down = _distill_loss(stu_out_down, anchor_main_down, mentor_mids_down, align_mid)
            accelerator.backward(w_down * loss_down)
            optimizer.step()

            last_loss_pre = float(loss_pre.detach().item())
            last_loss_down = float(loss_down.detach().item())

            pbar.set_postfix(
                loss_pre=f"{last_loss_pre:.4f}",
                loss_down=f"{last_loss_down:.4f}",
            )

        # ===== End of epoch: validate and log (eval internals are not modified) =====
        stu_eval = accelerator.unwrap_model(student) if hasattr(accelerator, "unwrap_model") else student

        # 1) Upstream validation (single point)
        if val_pre_dl is not None:
            pre_vec = evaluate_pretrain(
                student=stu_eval,
                switcher=switcher,
                loader=val_pre_dl,
                args=args,
                accelerator=accelerator,
                tag="pre",
            )
            if logger is not None:
                logger.info(f"[UnifiedDistill][E{ep + 1}/{epoch_num}] Pre-validate: " + _fmt_metrics(pre_vec))

        # 2) Downstream validation (box + multi-point)
        if val_down_dl is not None:
            down_dict = evaluate_downstream(
                student=stu_eval,
                switcher=switcher,
                loader=val_down_dl,
                args=args,
                accelerator=accelerator,
                tag="down",
            )

            # Integer keys hold per-point-count metric vectors; computed once for logging and selection.
            pt_keys = sorted(k for k in down_dict.keys() if isinstance(k, int))
            pt_dice_vals = [
                float(down_dict[k][dice_idx])
                for k in pt_keys
                if dice_idx != -1 and isinstance(down_dict[k], np.ndarray) and down_dict[k].size > dice_idx
            ]
            dice_avg_pts = float(sum(pt_dice_vals) / len(pt_dice_vals)) if pt_dice_vals else None

            if logger is not None:
                box_str = "Box: " + _fmt_metrics(down_dict.get("final", None))
                pt_bits = [f"P{k}: " + _fmt_metrics(down_dict[k]) for k in pt_keys]
                tail = f" | DiceAvgPts={dice_avg_pts:.4f}" if dice_avg_pts is not None else ""
                logger.info(f"[UnifiedDistill][E{ep + 1}/{epoch_num}] Down-validate: {box_str} || " + " ; ".join(pt_bits) + tail)

            # ====== Model selection on DiceAvgPts; save on improvement ======
            if dice_avg_pts is not None and dice_avg_pts > best_metric:
                best_metric = dice_avg_pts
                best_epoch = ep + 1
                if best_ckpt_path_unified is not None:
                    # Accelerate documents get_state_dict as a call every process must make under
                    # sharded backends (DeepSpeed ZeRO-3 / FSDP); for DDP it is a cheap state_dict().
                    state = accelerator.get_state_dict(student)
                    if accelerator.is_main_process:
                        p = Path(best_ckpt_path_unified)
                        p.parent.mkdir(parents=True, exist_ok=True)
                        accelerator.save(state, str(p))
                        if logger is not None:
                            logger.info(f"[UnifiedDistill][E{ep + 1}] New BEST DiceAvgPts={best_metric:.4f} -> saved state_dict to {p}")
                        if save_full_model:
                            full_path = p.with_name(p.stem + "_full" + p.suffix)
                            full_model = copy.deepcopy(accelerator.unwrap_model(student)).cpu()
                            torch.save(full_model, str(full_path))
                            if logger is not None:
                                logger.info(f"[UnifiedDistill][E{ep + 1}] Full model saved to {full_path}")
                            del full_model
            # Keep ranks in step after validation (and after a possible rank-0 save).
            accelerator.wait_for_everyone()
        student.train()

    # ====== End of stage: every rank reloads the best weights (path given and a best was recorded) ======
    if reload_best_at_end and best_ckpt_path_unified is not None and best_metric > float("-inf"):
        accelerator.wait_for_everyone()
        state = torch.load(str(best_ckpt_path_unified), map_location="cpu")
        student.load_state_dict(state, strict=True)
        if logger is not None:
            logger.info(f"[UnifiedDistill] Reloaded BEST (E{best_epoch}) DiceAvgPts={best_metric:.4f} from {best_ckpt_path_unified}")

    return {"loss_pre": last_loss_pre, "loss_down": last_loss_down}


def compensate_downstream(
    *,
    args,
    student: nn.Module,  # full SAM-Med2D (image_encoder / prompt_encoder / mask_decoder)
    switcher,  # SAMPartSwitcher
    loaders: Dict[str, object],  # must contain "train_down"; "val_down" is optional
    accelerator,
    epoch_num: int,
    stage_name: Optional[str] = None,  # reserved placeholder, unused
    pruning_mode: Optional[str] = None,  # reserved placeholder, unused
    logger=None,  # explicitly passed logger (silenced on non-main ranks by the system logger)
    best_ckpt_path_down: Optional[str] = None,  # where to save the best state_dict
    reload_best_at_end: bool = True,  # reload the best weights on every rank when the stage ends
    save_full_model: bool = True,  # additionally pickle the whole model on rank 0
) -> float:
    """Downstream compensation of the full SAM-Med2D student.

    Every batch runs two phases:

    1. Adapter phase: ``set_trainable_parts`` enables only the image-encoder
       adapters. One encoder pass plus one prompt/decoder pass, prompted with
       either the box or the given points (chosen at random), then one optimizer step.
    2. Refinement phase: ``set_trainable_parts`` enables prompt encoder + mask
       decoder. ``iter_point - 1`` rounds of point refinement
       (``generate_point``) on the detached image embeddings, each followed by an
       optimizer step. In this phase the prompt encoder runs under
       ``torch.no_grad`` (``decoder_iter=True``), so only the decoder gets gradients.

    The optimizer learning rate is ``0.1 * args.lr`` (``args.lr`` defaults to 1e-4).
    This is the value the start-of-stage log reports and the one
    ``compensate_downstream_adapter_only`` uses.

    After every epoch, ``val_down`` (when present) is evaluated. The mean Dice of
    the integer-keyed (multi-point) entries selects the best epoch. Its state_dict
    is saved to ``best_ckpt_path_down`` (rank 0). With ``reload_best_at_end``,
    every rank reloads it when the stage ends.

    Returns:
        The mean over the last epoch's batches of each batch's final loss.
    """
    device = accelerator.device
    train_dl = loaders.get("train_down")
    val_down_dl = loaders.get("val_down")

    # Switch to the downstream path (mode switches act on the raw, unwrapped modules).
    student.train()
    student.image_encoder.cancel_skip_adapter_mod()
    switcher.to_interp_prompt()

    # Enable every part used by either phase *before* building the optimizer,
    # so the optimizer owns adapter + prompt-encoder + mask-decoder parameters.
    set_trainable_parts(
        student,
        train_image_encoder_adapter=True,  # phase 1
        train_image_encoder_non_adapter=False,
        train_prompt_encoder=True,  # phase 2
        train_mask_decoder=True,  # phase 2
    )

    # Downstream LR = 0.1 * args.lr. Previously the optimizer used args.lr while the
    # log claimed 0.1 * args.lr.
    base_lr = float(getattr(args, "lr", 1e-4))
    down_lr = base_lr * 0.1
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, student.parameters()),
        lr=down_lr,
    )

    # Prepare the sub-modules, optimizer and loaders separately. Freezing and mode
    # switches still target the unwrapped `student`, but every forward that must
    # back-propagate goes through the *_prepared modules.
    image_encoder_prepared, prompt_encoder_prepared, mask_decoder_prepared, optimizer_prepared, train_dl, val_down_dl = accelerator.prepare(
        student.image_encoder, student.prompt_encoder, student.mask_decoder, optimizer, train_dl, val_down_dl
    )

    lr_gamma = float(getattr(args, "lr_gamma", 1.0))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_prepared, gamma=lr_gamma) if lr_gamma != 1.0 else None

    seg_loss = FocalDiceloss_IoULoss()

    image_size = int(getattr(args, "sammed2d_image_size", 256))
    multimask = bool(getattr(args, "multimask", False))
    iter_point = int(getattr(args, "iter_point", 7))
    point_num = int(getattr(args, "point_num", 1))

    metrics_list = list(getattr(args, "metrics", []))
    iou_idx = metrics_list.index("iou") if "iou" in metrics_list else -1
    dice_idx = metrics_list.index("dice") if "dice" in metrics_list else -1

    if logger is not None:
        logger.info(f"[Downstream] start | epochs={epoch_num}, lr={down_lr}, iter_point={iter_point}, point_num={point_num}, multimask={multimask}")

    def _prompt_and_decoder_train(
        batched_input: dict,
        prompt_encoder_prepared: nn.Module,
        mask_decoder_prepared: nn.Module,
        image_embeddings: torch.Tensor,
        decoder_iter: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run prompt encoder + mask decoder; return (masks at image_size, low-res masks, IoU preds).

        With ``decoder_iter=True`` the prompt encoder runs under ``torch.no_grad``.
        With ``multimask`` only the mask with the highest predicted IoU is kept per sample.
        """
        points = None
        if batched_input.get("point_coords") is not None:
            points = (batched_input["point_coords"], batched_input.get("point_labels"))

        with torch.no_grad() if decoder_iter else nullcontext():
            sparse_emb, dense_emb = prompt_encoder_prepared(
                points=points,
                boxes=batched_input.get("boxes"),
                masks=batched_input.get("mask_inputs"),
            )

        # get_dense_pe is called on the underlying module when the encoder is wrapped (has `.module`).
        pe_owner = prompt_encoder_prepared.module if hasattr(prompt_encoder_prepared, "module") else prompt_encoder_prepared
        image_pe = pe_owner.get_dense_pe()

        low_res_masks, iou_pred = mask_decoder_prepared(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=multimask,
        )

        if multimask:
            # Per sample, keep the mask whose predicted IoU is highest -> [N, 1, h, w].
            iou_pred, best_idx = torch.max(iou_pred, dim=1, keepdim=True)
            rows = torch.arange(low_res_masks.shape[0], device=low_res_masks.device)
            low_res_masks = low_res_masks[rows, best_idx.squeeze(1)].unsqueeze(1)

        masks = F.interpolate(
            low_res_masks,
            size=(image_size, image_size),
            mode="bilinear",
            align_corners=False,
        )
        return masks, low_res_masks, iou_pred

    def _fmt_metrics(arr) -> str:
        """Format IoU/Dice of a metric vector as 'IoU=.. | Dice=..', or 'metrics=N/A'."""
        bits = []
        if isinstance(arr, np.ndarray):
            if iou_idx != -1 and arr.size > iou_idx:
                bits.append(f"IoU={float(arr[iou_idx]):.4f}")
            if dice_idx != -1 and arr.size > dice_idx:
                bits.append(f"Dice={float(arr[dice_idx]):.4f}")
        return " | ".join(bits) if bits else "metrics=N/A"

    # Best-of-stage tracking (criterion: mean multi-point Dice, "DiceAvgPts").
    best_metric = float("-inf")
    best_epoch = -1
    last_avg_loss = 0.0

    for ep in range(epoch_num):
        if hasattr(train_dl, "sampler") and hasattr(train_dl.sampler, "set_epoch"):
            train_dl.sampler.set_epoch(ep)

        student.train()
        ep_loss_sum, num_steps = 0.0, 0

        pbar = tqdm(
            train_dl,
            total=len(train_dl),
            desc=f"[Downstream] Epoch {ep + 1}/{epoch_num}",
            disable=not accelerator.is_main_process,
        )

        for batch in pbar:
            batch = tensor_to_device(batch, device)
            label = batch["label"]  # [B, M, H, W] per the original annotation
            B, M = label.shape[:2]

            # ---- Phase 1: adapters only (encode + one decode) ----
            set_trainable_parts(
                student,  # toggles act on the raw, unwrapped model
                train_image_encoder_adapter=True,
                train_image_encoder_non_adapter=False,
                train_prompt_encoder=False,
                train_mask_decoder=False,
            )
            student.image_encoder.cancel_skip_adapter_mod()
            switcher.to_interp_prompt()

            with accelerator.autocast():
                enc_out = image_encoder_prepared(batch["input_image"])  # goes through the prepared encoder
                emb_single = enc_out[0] if isinstance(enc_out, (tuple, list)) else enc_out
                img_emb = _expand_BM(emb_single, M)

                pc = batch.get("point_coords", None)
                pl = batch.get("point_labels", None)
                bx = batch.get("boxes", None)

                work = {
                    # reshape (not view) so a non-contiguous label tensor is accepted as well
                    "label": label.reshape(B * M, 1, image_size, image_size),  # [N,1,H,W]
                    "point_coords": None if pc is None else pc.reshape(B * M, -1, 2),
                    "point_labels": None if pl is None else pl.reshape(B * M, -1),
                    "boxes": None if bx is None else bx.reshape(B * M, -1),
                    "mask_inputs": None,
                }
                # Prompt this batch with either the box (when available) or the points, at random.
                use_box_first = (torch.rand(1, device=device) > 0.5).item() and (work["boxes"] is not None)
                if use_box_first:
                    work["point_coords"], work["point_labels"] = None, None
                else:
                    work["boxes"] = None

                masks, low_res, iou_pred = _prompt_and_decoder_train(
                    work,
                    prompt_encoder_prepared,
                    mask_decoder_prepared,
                    img_emb,
                    decoder_iter=False,
                )
                loss = seg_loss(masks, work["label"], iou_pred)

            accelerator.backward(loss)
            optimizer_prepared.step()
            optimizer_prepared.zero_grad(set_to_none=True)

            # Phase 2 reuses the embeddings / predictions as constants.
            with torch.no_grad():
                emb_single = emb_single.detach()
                img_emb_p2 = _expand_BM(emb_single, M)
                masks = masks.detach()
                low_res = low_res.detach()
                iou_pred = iou_pred.detach()

            # ---- Phase 2: prompt encoder + mask decoder (iterative point refinement) ----
            set_trainable_parts(
                student,  # toggles act on the raw, unwrapped model
                train_image_encoder_adapter=False,
                train_image_encoder_non_adapter=False,
                train_prompt_encoder=True,
                train_mask_decoder=True,
            )

            pts_hist, lbl_hist = [], []
            if work.get("point_coords") is not None:
                pts_hist.append(work["point_coords"])
            if work.get("point_labels") is not None:
                lbl_hist.append(work["point_labels"])

            K = max(iter_point - 1, 0)
            for _ in range(K):
                work = generate_point(masks, work["label"], low_res, work, point_num=point_num)
                work = tensor_to_device(work, device)

                # Accumulate all points generated so far into the prompt.
                if work.get("point_coords") is not None:
                    pts_hist.append(work["point_coords"])
                if work.get("point_labels") is not None:
                    lbl_hist.append(work["point_labels"])
                if len(pts_hist) > 0:
                    work["point_coords"] = torch.cat(pts_hist, dim=1)
                if len(lbl_hist) > 0:
                    work["point_labels"] = torch.cat(lbl_hist, dim=1)

                with accelerator.autocast():
                    masks, low_res, iou_pred = _prompt_and_decoder_train(
                        work,
                        prompt_encoder_prepared,
                        mask_decoder_prepared,
                        img_emb_p2,
                        decoder_iter=True,
                    )
                    loss = seg_loss(masks, work["label"], iou_pred)

                # This phase trains prompt encoder + decoder, not the adapters.
                pbar.set_postfix(loss=f"{loss.item():.6f} (prompt+decoder)")
                accelerator.backward(loss)
                optimizer_prepared.step()
                optimizer_prepared.zero_grad(set_to_none=True)

                with torch.no_grad():
                    masks = masks.detach()
                    low_res = low_res.detach()
                    iou_pred = iou_pred.detach()

            # Final loss of the batch (phase-1 loss when K == 0).
            ep_loss_sum += float(loss.item())
            num_steps += 1

        if scheduler is not None:
            scheduler.step()

        last_avg_loss = ep_loss_sum / max(num_steps, 1)

        if logger is not None:
            cur_lr = optimizer_prepared.param_groups[0]["lr"]
            logger.info(f"[Downstream][E{ep + 1}/{epoch_num}] train: loss={last_avg_loss:.6f} | lr={cur_lr:.2e}")

        # ===== Validation & model selection =====
        if val_down_dl is not None:
            # evaluate_downstream receives the raw student
            down_dict = evaluate_downstream(
                student=student,
                switcher=switcher,
                loader=val_down_dl,
                args=args,
                accelerator=accelerator,
                tag="down",
            )

            # Integer keys hold per-point-count metric vectors; computed once for logging and selection.
            pt_keys = sorted(k for k in down_dict.keys() if isinstance(k, int))
            pt_dice_vals = [
                float(down_dict[k][dice_idx])
                for k in pt_keys
                if dice_idx != -1 and isinstance(down_dict[k], np.ndarray) and down_dict[k].size > dice_idx
            ]
            dice_avg_pts = float(sum(pt_dice_vals) / len(pt_dice_vals)) if pt_dice_vals else None

            if logger is not None:
                box_str = _fmt_metrics(down_dict.get("final", None))
                pt_bits = [f"P{k}: " + _fmt_metrics(down_dict[k]) for k in pt_keys]
                tail = f" | DiceAvgPts={dice_avg_pts:.4f}" if dice_avg_pts is not None else ""
                logger.info(f"[Downstream][E{ep + 1}/{epoch_num}] validate: Box: {box_str} || " + " ; ".join(pt_bits) + tail)

            if dice_avg_pts is not None and dice_avg_pts > best_metric:
                best_metric = dice_avg_pts
                best_epoch = ep + 1
                if best_ckpt_path_down is not None:
                    # Accelerate documents get_state_dict as a call every process must make under
                    # sharded backends (DeepSpeed ZeRO-3 / FSDP); for DDP it is a cheap state_dict().
                    state = accelerator.get_state_dict(student)
                    if accelerator.is_main_process:
                        p = Path(best_ckpt_path_down)
                        p.parent.mkdir(parents=True, exist_ok=True)
                        accelerator.save(state, str(p))
                        if logger is not None:
                            logger.info(f"[Downstream][E{ep + 1}] New BEST DiceAvgPts={best_metric:.4f} -> saved state_dict to {p}")
                        if save_full_model:
                            full_path = p.with_name(p.stem + "_full" + p.suffix)
                            full_model = copy.deepcopy(accelerator.unwrap_model(student)).cpu()
                            torch.save(full_model, str(full_path))
                            if logger is not None:
                                logger.info(f"[Downstream][E{ep + 1}] Full model saved to {full_path}")
                            del full_model
        accelerator.wait_for_everyone()

    # End of stage: every rank reloads the best weights.
    if reload_best_at_end and best_ckpt_path_down is not None and best_metric > float("-inf"):
        accelerator.wait_for_everyone()
        state = torch.load(str(best_ckpt_path_down), map_location="cpu")
        student.load_state_dict(state, strict=True)
        if logger is not None:
            logger.info(f"[Downstream] Reloaded BEST (E{best_epoch}) DiceAvgPts={best_metric:.4f} from {best_ckpt_path_down}")

    return last_avg_loss


def compensate_downstream_adapter_only(
    *,
    args,
    student: nn.Module,  # full SAM-Med2D (image_encoder / prompt_encoder / mask_decoder)
    switcher,  # SAMPartSwitcher
    loaders: Dict[str, object],  # must contain "train_down"; "val_down" is optional
    accelerator,
    epoch_num: int,
    stage_name: Optional[str] = None,  # placeholder, unused
    pruning_mode: Optional[str] = None,  # placeholder, unused
    logger=None,  # logger passed explicitly (non-main processes are silenced by the system logger)
    best_ckpt_path_down: Optional[str] = None,  # where the best state_dict is saved
    reload_best_at_end: bool = True,  # reload the best checkpoint at the end of the stage
    save_full_model: bool = True,  # additionally save a full-model snapshot on rank 0
) -> float:
    """
    Downstream compensation, adapter-only variant.

      1. Only the image_encoder adapter is trainable for the whole stage
         (prompt_encoder and mask_decoder are frozen via set_trainable_parts).
      2. Each batch keeps the iterative multi-point refinement: one box- or
         point-prompted decode pass, then ``iter_point - 1`` passes with points
         produced by ``generate_point``, all reusing the same image embeddings.
         The K + 1 losses are averaged and backpropagated once, followed by a
         single optimizer step; only the adapter is updated.

    Evaluation / best selection / saving follow the original version: after each
    epoch with a ``val_down`` loader, ``evaluate_downstream`` is run and the
    checkpoint with the highest mean Dice over the integer-keyed (point) entries
    (DiceAvgPts) is saved by the main process to ``best_ckpt_path_down``; with
    ``reload_best_at_end`` every rank reloads it at the end.

    Returns:
        Mean training loss of the last epoch (0.0 if ``epoch_num`` is 0).
    """

    device = accelerator.device
    train_dl = loaders.get("train_down")
    val_down_dl = loaders.get("val_down")

    # Switch to the downstream path (same as the original version).
    student.train()
    student.image_encoder.cancel_skip_adapter_mod()  # use the adapter path throughout
    switcher.to_interp_prompt()

    # Train only the required parts (adapter-only for the whole stage).
    # NOTE: this runs before accelerator.prepare below; keep that order, since
    # DDP-style wrappers register trainable parameters when they wrap the module.
    set_trainable_parts(
        student,
        train_image_encoder_adapter=True,  # train the adapter
        train_image_encoder_non_adapter=False,
        train_prompt_encoder=False,  # freeze prompt_encoder
        train_mask_decoder=False,  # freeze mask_decoder
    )

    # Downstream learning rate is 0.1 * args.lr (default 1e-4), matching the
    # reference implementation and the start-of-stage log line below.
    base_lr = float(getattr(args, "lr", 1e-4))
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, student.parameters()),
        lr=base_lr * 0.1,
    )

    # Only the image encoder, optimizer and dataloaders go through accelerate;
    # the frozen prompt encoder / mask decoder are just moved to the device.
    image_encoder_prepared, optimizer_prepared, train_dl, val_down_dl = accelerator.prepare(student.image_encoder, optimizer, train_dl, val_down_dl)

    student.prompt_encoder.to(accelerator.device)
    student.mask_decoder.to(accelerator.device)
    # Local aliases so the calls below keep their existing names.
    prompt_encoder_prepared = student.prompt_encoder
    mask_decoder_prepared = student.mask_decoder

    # Optional per-epoch exponential LR decay (disabled when lr_gamma == 1.0).
    lr_gamma = float(getattr(args, "lr_gamma", 1.0))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_prepared, gamma=lr_gamma) if lr_gamma != 1.0 else None

    seg_loss = FocalDiceloss_IoULoss()

    image_size = int(getattr(args, "sammed2d_image_size", 256))
    multimask = bool(getattr(args, "multimask", False))
    iter_point = int(getattr(args, "iter_point", 7))
    point_num = int(getattr(args, "point_num", 1))

    if logger is not None:
        logger.info(
            f"[Downstream] start | epochs={epoch_num}, lr={base_lr * 0.1}, "
            f"iter_point={iter_point}, point_num={point_num}, multimask={multimask} | mode=AdapterOnly"
        )

    # === Helper: prompt encoder + mask decoder forward (decoder_iter=True runs prompt_encoder under no_grad) ===
    def _prompt_and_decoder_train(
        batched_input: dict,
        prompt_encoder_prepared: nn.Module,
        mask_decoder_prepared: nn.Module,
        image_embeddings: torch.Tensor,
        decoder_iter: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run prompt_encoder (optionally under no_grad) and mask_decoder, as in the original version.

        With ``multimask`` the mask with the highest predicted IoU is kept per sample.
        The low-res masks are bilinearly upsampled to (image_size, image_size).
        For adapter-only training both stages pass decoder_iter=True (prompt encoder
        under no_grad, decoder still in the autograd graph).

        Returns:
            (masks, low_res_masks, iou_pred)
        """
        points = None
        if batched_input.get("point_coords") is not None:
            points = (batched_input["point_coords"], batched_input.get("point_labels"))

        if decoder_iter:
            with torch.no_grad():  # prompt_encoder is frozen; no_grad saves memory
                sparse_emb, dense_emb = prompt_encoder_prepared(
                    points=points,
                    boxes=batched_input.get("boxes"),
                    masks=batched_input.get("mask_inputs"),
                )
        else:
            # Kept for interface compatibility (every call in this function passes True).
            sparse_emb, dense_emb = prompt_encoder_prepared(
                points=points,
                boxes=batched_input.get("boxes"),
                masks=batched_input.get("mask_inputs"),
            )

        pe_owner = prompt_encoder_prepared.module if hasattr(prompt_encoder_prepared, "module") else prompt_encoder_prepared
        image_pe = pe_owner.get_dense_pe()

        # The decoder must NOT run under no_grad: gradients have to flow
        # loss -> decoder -> image_embeddings to reach the adapter.
        low_res_masks, iou_pred = mask_decoder_prepared(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=multimask,
        )

        if multimask:
            # Keep, per sample, the mask whose predicted IoU is highest.
            max_vals, max_idxs = torch.max(iou_pred, dim=1, keepdim=True)
            iou_pred = max_vals
            take = []
            for b_idx, idx in enumerate(max_idxs):
                take.append(low_res_masks[b_idx : b_idx + 1, idx])
            low_res_masks = torch.cat(take, dim=0)

        masks = F.interpolate(
            low_res_masks,
            size=(image_size, image_size),
            mode="bilinear",
            align_corners=False,
        )
        return masks, low_res_masks, iou_pred

    # Best-checkpoint tracking (same as the original version).
    best_metric = float("-inf")
    best_epoch = -1
    last_avg_loss = 0.0
    for ep in range(epoch_num):
        if hasattr(train_dl, "sampler") and hasattr(train_dl.sampler, "set_epoch"):
            train_dl.sampler.set_epoch(ep)

        student.train()
        ep_loss_sum, num_steps = 0.0, 0
        steps_per_epoch = len(train_dl)
        pbar = tqdm(
            train_dl,
            total=steps_per_epoch,
            desc=f"[Downstream] Epoch {ep + 1}/{epoch_num}",
            disable=not accelerator.is_main_process,
        )

        for step_idx, batch in enumerate(pbar):
            global_step = ep * steps_per_epoch + step_idx
            batch = tensor_to_device(batch, device)
            label = batch["label"]  # [B,M,H,W]
            B, M = label.shape[:2]

            # ==== Stage 1: adapter-only (encode + one decode pass) ====
            student.image_encoder.cancel_skip_adapter_mod()
            switcher.to_interp_prompt()

            with accelerator.autocast():
                enc_out = image_encoder_prepared(batch["input_image"])
                emb_single = enc_out[0] if isinstance(enc_out, (tuple, list)) else enc_out
                img_emb = _expand_BM(emb_single, M)

                pc = batch.get("point_coords", None)
                pl = batch.get("point_labels", None)
                bx = batch.get("boxes", None)

                # Flatten the (B, M) object axis into one batch axis N = B * M.
                work = {
                    "label": label.view(B * M, 1, image_size, image_size),  # [N,1,H,W]
                    "point_coords": None if pc is None else pc.reshape(B * M, -1, 2),
                    "point_labels": None if pl is None else pl.reshape(B * M, -1),
                    "boxes": None if bx is None else bx.reshape(B * M, -1),
                    "mask_inputs": None,
                }
                # Even global steps use the box prompt (when boxes exist); otherwise points only.
                use_box_first = ((global_step & 1) == 0) and (work["boxes"] is not None)
                if use_box_first:
                    work["point_coords"], work["point_labels"] = None, None
                else:
                    work["boxes"] = None

                # Prompt encoder under no_grad, decoder in the autograd graph.
                masks, low_res, iou_pred = _prompt_and_decoder_train(
                    work,
                    prompt_encoder_prepared,
                    mask_decoder_prepared,
                    img_emb,
                    decoder_iter=True,
                )
                loss_first = seg_loss(masks, work["label"], iou_pred)

            # No backward/step after stage 1 -- start accumulating the loss instead.
            loss_total = loss_first

            # Cut the gradient chain of masks/low_res/iou before generating the
            # next points, but keep emb_single's graph so stage 2 reaches the adapter.
            img_emb_p2 = _expand_BM(emb_single, M)  # do NOT detach emb_single / img_emb
            with torch.no_grad():
                masks = masks.detach()
                low_res = low_res.detach()
                iou_pred = iou_pred.detach()

            # ==== Stage 2: iterative multi-point refinement (still updates only the adapter) ====
            set_trainable_parts(
                student,
                train_image_encoder_adapter=True,
                train_image_encoder_non_adapter=False,
                train_prompt_encoder=False,
                train_mask_decoder=False,
            )

            # History of every point set seen so far; concatenated along dim=1 below.
            # NOTE(review): if generate_point already appends its new points to
            # work["point_coords"] / work["point_labels"], earlier points get
            # duplicated by this history -- verify against .utils.generate_point.
            pts_hist, lbl_hist = [], []
            if work.get("point_coords") is not None:
                pts_hist.append(work["point_coords"])
            if work.get("point_labels") is not None:
                lbl_hist.append(work["point_labels"])

            K = max(iter_point - 1, 0)
            for _ in range(K):
                work = generate_point(masks, work["label"], low_res, work, point_num=point_num)
                work = tensor_to_device(work, device)

                if work.get("point_coords") is not None:
                    pts_hist.append(work["point_coords"])
                if work.get("point_labels") is not None:
                    lbl_hist.append(work["point_labels"])
                if len(pts_hist) > 0:
                    work["point_coords"] = torch.cat(pts_hist, dim=1)
                if len(lbl_hist) > 0:
                    work["point_labels"] = torch.cat(lbl_hist, dim=1)

                with accelerator.autocast():
                    masks, low_res, iou_pred = _prompt_and_decoder_train(
                        work,
                        prompt_encoder_prepared,
                        mask_decoder_prepared,
                        img_emb_p2,
                        decoder_iter=True,
                    )
                    loss_k = seg_loss(masks, work["label"], iou_pred)

                # Accumulate only; no backward yet.
                loss_total = loss_total + loss_k

                with torch.no_grad():
                    masks = masks.detach()
                    low_res = low_res.detach()
                    iou_pred = iou_pred.detach()

            # ==== Combine both stages and backpropagate once ====
            loss_total = loss_total / float(K + 1)  # average over K + 1 passes so the step size does not scale with iter_point
            accelerator.backward(loss_total)
            optimizer_prepared.step()
            optimizer_prepared.zero_grad(set_to_none=True)

            # Progress bar / statistics use the combined (averaged) loss.
            pbar.set_postfix(loss=f"{loss_total.item():.6f} (adapter accum)")
            ep_loss_sum += float(loss_total.detach().item())
            num_steps += 1

        if scheduler is not None:
            scheduler.step()

        last_avg_loss = ep_loss_sum / max(num_steps, 1)

        if logger is not None:
            cur_lr = optimizer_prepared.param_groups[0]["lr"]
            logger.info(f"[Downstream][E{ep + 1}/{epoch_num}] train: loss={last_avg_loss:.6f} | lr={cur_lr:.2e} | mode=AdapterOnly")

        # ===== Validation & best-checkpoint saving (same as the original version) =====
        if val_down_dl is not None:
            stu_eval = student  # evaluate_downstream is expected to handle no_grad / unwrapping itself
            down_dict = evaluate_downstream(
                student=stu_eval,
                switcher=switcher,
                loader=val_down_dl,
                args=args,
                accelerator=accelerator,
                tag="down",
            )

            # Column positions of IoU / Dice inside each metrics array (-1 = not reported).
            metrics_list = list(getattr(args, "metrics", []))
            iou_idx = metrics_list.index("iou") if "iou" in metrics_list else -1
            dice_idx = metrics_list.index("dice") if "dice" in metrics_list else -1

            if logger is not None:
                # Box-only
                box_arr = down_dict.get("final", None)
                if isinstance(box_arr, np.ndarray):
                    bits = []
                    if iou_idx != -1 and box_arr.size > iou_idx:
                        bits.append(f"IoU={float(box_arr[iou_idx]):.4f}")
                    if dice_idx != -1 and box_arr.size > dice_idx:
                        bits.append(f"Dice={float(box_arr[dice_idx]):.4f}")
                    box_str = " | ".join(bits) if bits else "metrics=N/A"
                else:
                    box_str = "metrics=N/A"

                # Points (multi-point): integer keys of down_dict
                pt_keys = sorted([k for k in down_dict.keys() if isinstance(k, int)])
                pt_bits, pt_dice_vals = [], []
                for k in pt_keys:
                    arr = down_dict[k]
                    if isinstance(arr, np.ndarray):
                        parts = []
                        if iou_idx != -1 and arr.size > iou_idx:
                            parts.append(f"IoU={float(arr[iou_idx]):.4f}")
                        if dice_idx != -1 and arr.size > dice_idx:
                            dv = float(arr[dice_idx])
                            parts.append(f"Dice={dv:.4f}")
                            pt_dice_vals.append(dv)
                        pt_bits.append(f"P{k}: " + (" | ".join(parts) if parts else "metrics=N/A"))
                    else:
                        pt_bits.append(f"P{k}: metrics=N/A")
                tail = f" | DiceAvgPts={float(sum(pt_dice_vals) / len(pt_dice_vals)):.4f}" if pt_dice_vals else ""

                logger.info(f"[Downstream][E{ep + 1}/{epoch_num}] validate: Box: {box_str} || " + " ; ".join(pt_bits) + tail)

            # Best-checkpoint selection (mean Dice over the integer-keyed point entries).
            if dice_idx != -1:
                pt_keys = sorted([k for k in down_dict.keys() if isinstance(k, int)])
                pt_dice_vals = []
                for k in pt_keys:
                    arr = down_dict[k]
                    if isinstance(arr, np.ndarray) and arr.size > dice_idx:
                        pt_dice_vals.append(float(arr[dice_idx]))
                if pt_dice_vals:
                    dice_avg_pts = float(sum(pt_dice_vals) / len(pt_dice_vals))
                    if dice_avg_pts > best_metric:
                        best_metric = dice_avg_pts
                        best_epoch = ep + 1
                        # Only the main process writes checkpoints.
                        if best_ckpt_path_down is not None and accelerator.is_main_process:
                            p = Path(best_ckpt_path_down)
                            p.parent.mkdir(parents=True, exist_ok=True)
                            state = accelerator.get_state_dict(student)
                            accelerator.save(state, str(p))
                            if logger is not None:
                                logger.info(f"[Downstream][E{ep + 1}] New BEST DiceAvgPts={best_metric:.4f} -> saved state_dict to {p}")
                            if save_full_model:
                                # Pickles the whole (unwrapped) module as a CPU copy next to the state_dict.
                                full_path = p.with_name(p.stem + "_full" + p.suffix)
                                full_model = copy.deepcopy(accelerator.unwrap_model(student)).cpu()
                                torch.save(full_model, str(full_path))
                                if logger is not None:
                                    logger.info(f"[Downstream][E{ep + 1}] Full model saved to {full_path}")
                                del full_model
        # Barrier: all ranks wait here, after rank 0 may have written a checkpoint.
        accelerator.wait_for_everyone()

    # End of stage: every rank reloads the best checkpoint (same as the original version).
    if reload_best_at_end and best_ckpt_path_down is not None and best_metric > float("-inf"):
        accelerator.wait_for_everyone()
        state = torch.load(str(best_ckpt_path_down), map_location="cpu")
        student.load_state_dict(state, strict=True)
        if logger is not None:
            logger.info(f"[Downstream] Reloaded BEST (E{best_epoch}) DiceAvgPts={best_metric:.4f} from {best_ckpt_path_down} | mode=AdapterOnly")

    return last_avg_loss


def compensate_distillation_slimsam(
    *,
    args,
    student: nn.Module,  # student: SAM-Med2D
    teacher_anchor: nn.Module,  # anchor (meta-teacher)
    switcher,  # SAMPartSwitcher
    loaders: Dict[str, object],  # must contain "train_down"; "val_down" is optional
    accelerator,
    epoch_num: int,
    stage_name: Optional[str] = None,  # placeholder, unused
    pruning_mode: Optional[str] = None,  # placeholder, unused
    teacher_prev: Optional[nn.Module] = None,  # mentor (student from the previous round)
    w_pre: Optional[float] = None,  # kept for backward compatibility; ignored
    w_down: Optional[float] = None,  # main-loss weight (None -> args.w_down -> 1.0)
    lr: Optional[float] = None,  # learning rate (None -> args.lr -> 1e-4)
    ddp_no_sync_first_backward: bool = True,  # placeholder, unused
    logger=None,  # logger passed explicitly (silent on non-main processes)
    best_ckpt_path_unified: Optional[str] = None,  # where the best state_dict is saved
    reload_best_at_end: bool = True,  # reload the best checkpoint at the end of the stage
    save_full_model: bool = True,  # additionally save a full-model snapshot on rank 0
    # ===== Auxiliary distillation weights (None -> read from args -> built-in default) =====
    lambda_attn_win: Optional[float] = None,
    lambda_attn_glb: Optional[float] = None,
    lambda_mlp_mid: Optional[float] = None,
    lambda_block: Optional[float] = None,
    lambda_adapt_mid_ch: Optional[float] = None,
    lambda_adapt_ch: Optional[float] = None,
    lambda_adapt_sp_mid: Optional[float] = None,
    aux_warmup_frac: Optional[float] = None,  # fraction of training during which aux terms are active (default 0.25)
    prefer_mentor: Optional[bool] = None,  # align intermediates to the mentor when available (default True)
) -> Dict[str, float]:
    """Downstream-only unified distillation compensation with full image-encoder tuning.

    * Data: only the downstream loaders (``train_down`` / ``val_down``).
    * Trainable: adapter and non-adapter parameters of the student's image
      encoder; prompt encoder and mask decoder stay frozen. Only the image
      encoder is given to the optimizer and to ``accelerator.prepare``.
    * Loss: ``w_down * MSE(student_main, anchor_main)`` plus, while the epoch
      progress ``(ep + 1) / epoch_num`` is ``<= aux_warmup_frac``, seven weighted
      MSE terms on the encoder's intermediate outputs. The reference for those
      terms is the mentor (``teacher_prev``) when ``prefer_mentor`` is true; with
      no mentor the anchor is used. A term contributes 0 when the pair is not
      comparable (non-tensor, meta, empty, shape mismatch, non-finite values).
    * After each epoch with a ``val_down`` loader: run ``evaluate_downstream``,
      log Box / per-point metrics, and keep the checkpoint with the best mean
      Dice over the integer-keyed (point) entries. Rank 0 writes it to
      ``best_ckpt_path_unified``; with ``reload_best_at_end`` every rank reloads
      it after training.

    Returns:
        ``{"loss_pre": 0.0, "loss_down": <loss of the last training step>}``
        (legacy return structure).
    """

    device = accelerator.device
    train_down_dl = loaders.get("train_down")
    val_down_dl = loaders.get("val_down")

    # ---- Hyper-parameters: explicit argument -> args attribute -> default ----
    def _get(hint, argname, default):
        if hint is not None:
            return hint
        if hasattr(args, argname):
            v = getattr(args, argname)
            if v is not None:
                return v
        return default

    w_down = float(_get(w_down, "w_down", 1.0))
    lr = float(_get(lr, "lr", 1e-4))

    lambda_attn_win = float(_get(lambda_attn_win, "lambda_attn_win", 0.2))
    lambda_attn_glb = float(_get(lambda_attn_glb, "lambda_attn_glb", 0.2))
    lambda_mlp_mid = float(_get(lambda_mlp_mid, "lambda_mlp_mid", 0.2))
    lambda_block = float(_get(lambda_block, "lambda_block", 0.2))
    lambda_adapt_mid_ch = float(_get(lambda_adapt_mid_ch, "lambda_adapt_mid_ch", 0.1))
    lambda_adapt_ch = float(_get(lambda_adapt_ch, "lambda_adapt_ch", 0.1))
    lambda_adapt_sp_mid = float(_get(lambda_adapt_sp_mid, "lambda_adapt_sp_mid", 0.1))
    aux_warmup_frac = float(_get(aux_warmup_frac, "aux_warmup_frac", 0.25))
    prefer_mentor = bool(_get(prefer_mentor, "prefer_mentor", True))

    # ---- Log ----
    if logger is not None:
        logger.info(
            "[UnifiedDistill-DownOnly] start | "
            f"epochs={epoch_num}, lr={lr}, w_main={w_down}, mentor={'on' if teacher_prev is not None else 'off'} | "
            f"λ(win={lambda_attn_win}, glb={lambda_attn_glb}, mlp={lambda_mlp_mid}, block={lambda_block}, "
            f"ad_mid_ch={lambda_adapt_mid_ch}, ad_ch={lambda_adapt_ch}, ad_sp_mid={lambda_adapt_sp_mid}); "
            f"aux_warmup={aux_warmup_frac}"
        )

    def _set_trainable() -> None:
        """Make the whole image encoder trainable (adapter + non-adapter); freeze prompt/decoder."""
        set_trainable_parts(
            accelerator.unwrap_model(student),
            train_image_encoder=False,  # use the fine-grained switches below
            train_image_encoder_adapter=True,  # adapter trainable
            train_image_encoder_non_adapter=True,  # non-adapter trainable
            train_prompt_encoder=False,
            train_mask_decoder=False,
        )

    # Trainability MUST be fixed before accelerator.prepare: DDP registers only
    # the parameters that require grad when it wraps the module, so flipping
    # requires_grad afterwards leaves unfrozen params unsynchronized across ranks
    # (or leaves the reducer waiting for grads that never arrive).
    _set_trainable()

    # Only the image encoder is optimized ("full tuning" = adapter + non-adapter).
    enc = student.image_encoder
    optimizer = torch.optim.Adam(enc.parameters(), lr=lr)

    # DDP/AMP preparation: downstream only.
    enc_prep, optimizer, train_down_dl, val_down_dl = accelerator.prepare(enc, optimizer, train_down_dl, val_down_dl)

    mse = nn.MSELoss()
    last_loss_down = 0.0

    # Best-checkpoint tracking (by DiceAvgPts).
    best_metric = float("-inf")
    best_epoch = -1

    student.train()
    teacher_anchor.eval()
    if teacher_prev is not None:
        teacher_prev.eval()

    # ====== Helper: MSE that yields 0 for non-comparable pairs ======
    def _safe_mse(stu_t: torch.Tensor, ref_t: torch.Tensor, w: float) -> torch.Tensor:
        if not (isinstance(stu_t, torch.Tensor) and isinstance(ref_t, torch.Tensor)):
            return torch.tensor(0.0, device=device)
        if stu_t.is_meta or ref_t.is_meta:
            return torch.tensor(0.0, device=device)
        if stu_t.numel() == 0 or ref_t.numel() == 0:
            return torch.tensor(0.0, device=device)
        if stu_t.shape != ref_t.shape:
            return torch.tensor(0.0, device=device)
        # Align device & dtype of the reference with the student tensor.
        if ref_t.device != stu_t.device:
            ref_t = ref_t.to(stu_t.device, non_blocking=True)
        if ref_t.dtype != stu_t.dtype:
            ref_t = ref_t.to(stu_t.dtype)
        # Skip the term when either side contains inf/nan.
        if not (torch.isfinite(stu_t).all() and torch.isfinite(ref_t).all()):
            return torch.tensor(0.0, device=device)
        return w * mse(stu_t, ref_t)

    # ====== Training loop (downstream only) ======
    for ep in range(epoch_num):
        # Seed the distributed sampler.
        if hasattr(train_down_dl, "sampler") and hasattr(train_down_dl.sampler, "set_epoch"):
            train_down_dl.sampler.set_epoch(ep)

        num_steps = len(train_down_dl)
        pbar = tqdm(
            train_down_dl,
            total=num_steps,
            desc=f"[Unified-Down] Epoch {ep + 1}/{epoch_num}",
            disable=not accelerator.is_main_process,
        )

        # Downstream prompt, adapters on for student and teachers, full encoder
        # tuning (prompt / mask decoder stay frozen: only the encoder is distilled).
        switcher.to_interp_prompt()
        student.image_encoder.cancel_skip_adapter_mod()
        teacher_anchor.image_encoder.cancel_skip_adapter_mod()
        if teacher_prev is not None:
            teacher_prev.image_encoder.cancel_skip_adapter_mod()
        # Re-assert the same flags the wrapper saw at prepare time (idempotent),
        # in case anything between epochs changed requires_grad.
        _set_trainable()

        # Progress is epoch-granular: the aux terms are active for the whole
        # epoch whenever (ep + 1) / epoch_num <= aux_warmup_frac.
        frac = float(ep + 1) / float(max(1, epoch_num))
        use_aux = frac <= aux_warmup_frac

        for batch_down in pbar:
            optimizer.zero_grad(set_to_none=True)

            # Input: either a bare tensor or a dict with "input_image".
            img_down = (batch_down if torch.is_tensor(batch_down) else batch_down.get("input_image")).to(device=device, dtype=torch.float32)

            # Forward: student (main output + 7 intermediate outputs).
            with accelerator.autocast():
                (
                    stu_main,
                    stu_attn_mid_win,  # idx 1
                    stu_attn_mid_glb,  # idx 2
                    stu_mlp_mid,  # idx 3
                    stu_block_emb,  # idx 4
                    stu_ad_mid_ch,  # idx 5
                    stu_ad_ch,  # idx 6
                    stu_ad_sp_mid,  # idx 7
                ) = enc_prep(img_down)

            # Forward: teachers (anchor & mentor), no gradients.
            with torch.no_grad():
                (
                    anch_main,
                    anch_attn_mid_win,
                    anch_attn_mid_glb,
                    anch_mlp_mid,
                    anch_block_emb,
                    anch_ad_mid_ch,
                    anch_ad_ch,
                    anch_ad_sp_mid,
                ) = teacher_anchor.image_encoder(img_down)

                if teacher_prev is not None:
                    (
                        men_main,
                        men_attn_mid_win,
                        men_attn_mid_glb,
                        men_mlp_mid,
                        men_block_emb,
                        men_ad_mid_ch,
                        men_ad_ch,
                        men_ad_sp_mid,
                    ) = teacher_prev.image_encoder(img_down)
                else:
                    # No mentor: mentor == anchor (avoids branching below).
                    men_main = anch_main
                    men_attn_mid_win = anch_attn_mid_win
                    men_attn_mid_glb = anch_attn_mid_glb
                    men_mlp_mid = anch_mlp_mid
                    men_block_emb = anch_block_emb
                    men_ad_mid_ch = anch_ad_mid_ch
                    men_ad_ch = anch_ad_ch
                    men_ad_sp_mid = anch_ad_sp_mid

            # ====== Loss: main + auxiliary ======
            with accelerator.autocast():
                # The main loss always aligns with the anchor's main output.
                loss_main = mse(stu_main, anch_main)

                loss_aux = torch.tensor(0.0, device=device)
                if use_aux:
                    # Reference tensors: mentor first (controlled by prefer_mentor).
                    ref_attn_win = men_attn_mid_win if prefer_mentor else anch_attn_mid_win
                    ref_attn_glb = men_attn_mid_glb if prefer_mentor else anch_attn_mid_glb
                    ref_mlp_mid = men_mlp_mid if prefer_mentor else anch_mlp_mid
                    ref_block_emb = men_block_emb if prefer_mentor else anch_block_emb
                    ref_ad_mid_ch = men_ad_mid_ch if prefer_mentor else anch_ad_mid_ch
                    ref_ad_ch = men_ad_ch if prefer_mentor else anch_ad_ch
                    ref_ad_sp_mid = men_ad_sp_mid if prefer_mentor else anch_ad_sp_mid

                    loss_aux = (
                        _safe_mse(stu_attn_mid_win, ref_attn_win, lambda_attn_win)
                        + _safe_mse(stu_attn_mid_glb, ref_attn_glb, lambda_attn_glb)
                        + _safe_mse(stu_mlp_mid, ref_mlp_mid, lambda_mlp_mid)
                        + _safe_mse(stu_block_emb, ref_block_emb, lambda_block)
                        + _safe_mse(stu_ad_mid_ch, ref_ad_mid_ch, lambda_adapt_mid_ch)
                        + _safe_mse(stu_ad_ch, ref_ad_ch, lambda_adapt_ch)
                        + _safe_mse(stu_ad_sp_mid, ref_ad_sp_mid, lambda_adapt_sp_mid)
                    )

                loss = w_down * loss_main + loss_aux

            accelerator.backward(loss)
            optimizer.step()

            last_loss_down = float(loss.detach().item())
            pbar.set_postfix(loss_down=f"{last_loss_down:.4f}")

        # ===== End of epoch: downstream validation + best selection =====
        if val_down_dl is not None:
            stu_eval = accelerator.unwrap_model(student) if hasattr(accelerator, "unwrap_model") else student
            down_dict = evaluate_downstream(
                student=stu_eval,
                switcher=switcher,
                loader=val_down_dl,
                args=args,
                accelerator=accelerator,
                tag="down",
            )

            # Column positions of IoU / Dice inside each metrics array (-1 = not reported).
            metrics_list = list(getattr(args, "metrics", []))
            iou_idx = metrics_list.index("iou") if "iou" in metrics_list else -1
            dice_idx = metrics_list.index("dice") if "dice" in metrics_list else -1

            # Points (multi-point): integer keys of down_dict. Build log fragments
            # and the Dice values used for best selection in one pass.
            pt_keys = sorted(k for k in down_dict.keys() if isinstance(k, int))
            pt_bits = []
            pt_dice_vals = []
            for k in pt_keys:
                arr = down_dict[k]
                one = []
                if isinstance(arr, np.ndarray):
                    if iou_idx != -1 and arr.size > iou_idx:
                        one.append(f"IoU={float(arr[iou_idx]):.4f}")
                    if dice_idx != -1 and arr.size > dice_idx:
                        dv = float(arr[dice_idx])
                        one.append(f"Dice={dv:.4f}")
                        pt_dice_vals.append(dv)
                pt_bits.append(f"P{k}: " + (" | ".join(one) if one else "metrics=N/A"))
            dice_avg_pts = float(sum(pt_dice_vals) / len(pt_dice_vals)) if pt_dice_vals else None

            if logger is not None:
                # Box-only
                box_arr = down_dict.get("final", None)
                box_str = "Box: "
                if isinstance(box_arr, np.ndarray):
                    box_bits = []
                    if iou_idx != -1 and box_arr.size > iou_idx:
                        box_bits.append(f"IoU={float(box_arr[iou_idx]):.4f}")
                    if dice_idx != -1 and box_arr.size > dice_idx:
                        box_bits.append(f"Dice={float(box_arr[dice_idx]):.4f}")
                    box_str += " | ".join(box_bits) if box_bits else "metrics=N/A"
                else:
                    box_str += "metrics=N/A"

                tail = f" | DiceAvgPts={dice_avg_pts:.4f}" if dice_avg_pts is not None else ""
                logger.info(f"[UnifiedDistill-DownOnly][E{ep + 1}/{epoch_num}] Down-validate: {box_str} || " + " ; ".join(pt_bits) + tail)

            # Best selection & saving (only the main process writes files).
            if dice_avg_pts is not None and dice_avg_pts > best_metric:
                best_metric = dice_avg_pts
                best_epoch = ep + 1
                if best_ckpt_path_unified is not None and accelerator.is_main_process:
                    p = Path(best_ckpt_path_unified)
                    p.parent.mkdir(parents=True, exist_ok=True)
                    state = accelerator.get_state_dict(student)
                    accelerator.save(state, str(p))
                    if logger is not None:
                        logger.info(f"[UnifiedDistill-DownOnly][E{ep + 1}] New BEST DiceAvgPts={best_metric:.4f} -> saved state_dict to {p}")
                    if save_full_model:
                        full_path = p.with_name(p.stem + "_full" + p.suffix)
                        full_model = copy.deepcopy(accelerator.unwrap_model(student)).cpu()
                        torch.save(full_model, str(full_path))
                        if logger is not None:
                            logger.info(f"[UnifiedDistill-DownOnly][E{ep + 1}] Full model saved to {full_path}")
                        del full_model
        # Barrier every epoch: keeps ranks in step after a possible rank-0 save.
        accelerator.wait_for_everyone()
        student.train()

    # ====== End: reload the best checkpoint on every rank if requested ======
    if reload_best_at_end and best_ckpt_path_unified is not None and best_metric > float("-inf"):
        accelerator.wait_for_everyone()
        state = torch.load(str(best_ckpt_path_unified), map_location="cpu")
        student.load_state_dict(state, strict=True)
        if logger is not None:
            logger.info(f"[UnifiedDistill-DownOnly] Reloaded BEST (E{best_epoch}) DiceAvgPts={best_metric:.4f} from {best_ckpt_path_unified}")

    # Legacy return structure.
    return {"loss_pre": 0.0, "loss_down": last_loss_down}
