# assp/core/eval.py
from __future__ import annotations

from collections import OrderedDict
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

# Project utility helpers (adjust the import path to match your project layout).
from assp.core.utils import generate_point, tensor_to_device

# SAM transform helper: used below to rescale prompt point coordinates into the resized-image frame.
from assp.models.sammed2d.utils.transforms import ResizeLongestSide

_EPS = 1e-12


# ---------- helpers ----------
def _get_medsz(args) -> int:
    """Return the model input image size configured on ``args``.

    ``args.medsam_image_size`` takes precedence when it exists and is not
    ``None``; otherwise ``args.sammed2d_image_size`` is used. The chosen
    value is converted with ``int``. An ``AttributeError`` propagates when
    the fallback attribute is missing.
    """
    size = getattr(args, "medsam_image_size", None)
    if size is None:
        size = getattr(args, "sammed2d_image_size")
    return int(size)


def _get_metrics_list(args) -> List[str]:
    """Return the metric names to evaluate as a fresh list.

    Falls back to ``["iou", "dice"]`` when ``args`` has no ``metrics``
    attribute or it is ``None``. A single metric given as a bare string is
    wrapped in a one-element list (``list("dice")`` would otherwise split it
    into characters); this mirrors how ``SegMetrics`` treats a string.
    Any other iterable is copied with ``list``.
    """
    metrics = getattr(args, "metrics", None)
    if metrics is None:
        return ["iou", "dice"]
    if isinstance(metrics, str):
        return [metrics]
    return list(metrics)


def ddp_weighted_mean(
    vec: torch.Tensor,  # shape [K]
    count: torch.Tensor,  # shape [1]
    accelerator,
) -> torch.Tensor:
    """Return ``sum(vec) / sum(count)`` with both sums taken across processes.

    Both tensors are moved to ``accelerator.device`` as float32 and summed
    with ``accelerator.reduce(..., reduction="sum")``. The summed count is
    clamped to at least 1 so an empty evaluation yields the summed ``vec``
    unchanged instead of dividing by zero.
    """
    target = accelerator.device
    local_vec = vec.to(target, dtype=torch.float32)
    local_cnt = count.to(target, dtype=torch.float32)
    total_vec = accelerator.reduce(local_vec, reduction="sum")
    total_cnt = accelerator.reduce(local_cnt, reduction="sum")
    return total_vec / total_cnt.clamp(min=1.0)


# ---------- Mask metrics (accept NumPy arrays or torch tensors) ----------
def iou_score(pred: np.ndarray | torch.Tensor, gt: np.ndarray | torch.Tensor) -> float:
    """Intersection-over-union of two binarised masks, pooled over all elements.

    Binarisation rules:
      * NumPy input (prediction or ground truth): foreground where ``> 0``.
      * Tensor prediction: foreground where ``sigmoid(pred) > 0.5``.
      * Tensor ground truth: foreground where ``> 0.5``.

    ``_EPS`` is added to numerator and denominator, so two empty masks
    score 1.0 rather than dividing by zero.
    """
    if isinstance(pred, np.ndarray):
        pred_mask = (pred > 0).astype(np.bool_)
    else:
        pred_mask = (torch.sigmoid(pred) > 0.5).cpu().numpy()

    if isinstance(gt, np.ndarray):
        gt_mask = (gt > 0).astype(np.bool_)
    else:
        gt_mask = (gt > 0.5).cpu().numpy()

    overlap = np.logical_and(pred_mask, gt_mask).sum()
    combined = np.logical_or(pred_mask, gt_mask).sum()
    return (overlap + _EPS) / (combined + _EPS)


def dice_score(pred: np.ndarray | torch.Tensor, gt: np.ndarray | torch.Tensor) -> float:
    """Dice coefficient of two binarised masks, pooled over all elements.

    Binarisation rules:
      * NumPy input (prediction or ground truth): foreground where ``> 0``.
      * Tensor prediction: foreground where ``sigmoid(pred) > 0.5``.
      * Tensor ground truth: foreground where ``> 0.5``.

    ``_EPS`` is added to numerator and denominator, so two empty masks
    score 1.0 rather than dividing by zero.
    """
    if isinstance(pred, np.ndarray):
        pred_mask = (pred > 0).astype(np.bool_)
    else:
        pred_mask = (torch.sigmoid(pred) > 0.5).cpu().numpy()

    if isinstance(gt, np.ndarray):
        gt_mask = (gt > 0).astype(np.bool_)
    else:
        gt_mask = (gt > 0.5).cpu().numpy()

    overlap = np.logical_and(pred_mask, gt_mask).sum()
    total_area = pred_mask.sum() + gt_mask.sum()
    return (2 * overlap + _EPS) / (total_area + _EPS)


def SegMetrics(pred: torch.Tensor, gt: torch.Tensor, metrics: List[str]) -> np.ndarray:
    """Compute the requested segmentation metrics for ``pred`` against ``gt``.

    ``metrics`` may be a single name or a sequence of names; supported names
    are ``"iou"`` and ``"dice"``. Returns a float64 array with one score per
    requested metric, in request order.

    Raises:
        ValueError: if a requested name is not supported.
    """
    names = [metrics] if isinstance(metrics, str) else metrics

    results = []
    for name in names:
        if name == "iou":
            scorer = iou_score
        elif name == "dice":
            scorer = dice_score
        else:
            raise ValueError(f"Unknown metric {name}")
        # The scorers return a NumPy scalar; .mean() leaves it unchanged.
        results.append(scorer(pred, gt).mean())
    return np.array(results, dtype=np.float64)


def postprocess_masks_pre(
    masks: torch.Tensor,
    padded_size: Tuple[int, int],
    original_size: Tuple[int, int],
    model_img_size: int,
) -> torch.Tensor:
    """Upscale masks to the model frame, crop the padding, resize to the original image.

    Steps:
      1. Bilinearly resize ``masks`` to ``model_img_size`` x ``model_img_size``.
      2. Keep only the top-left ``padded_size[0]`` x ``padded_size[1]`` region
         of the last two dimensions.
      3. Bilinearly resize that region to ``original_size``.
    """
    square = (model_img_size, model_img_size)
    upscaled = F.interpolate(masks, size=square, mode="bilinear", align_corners=False)

    keep_h, keep_w = padded_size[0], padded_size[1]
    unpadded = upscaled[..., :keep_h, :keep_w]

    return F.interpolate(unpadded, size=original_size, mode="bilinear", align_corners=False)


def prompt_and_decoder_eval(args, batched_input, model, image_embeddings):
    """Run the prompt encoder and mask decoder once and upscale the result.

    Args:
        args: Run configuration. Reads ``args.multimask`` and the model input
            size (resolved through ``_get_medsz`` so that either
            ``medsam_image_size`` or ``sammed2d_image_size`` works).
        batched_input: Mapping with optional ``"point_coords"`` /
            ``"point_labels"``, ``"boxes"`` and ``"mask_inputs"`` entries.
            A missing or ``None`` ``"point_coords"`` means "no point prompt".
        model: Object exposing ``prompt_encoder`` and ``mask_decoder``.
        image_embeddings: Image embeddings passed to the mask decoder.

    Returns:
        ``(masks, low_res_masks, iou_predictions)`` where ``masks`` is
        ``low_res_masks`` bilinearly resized to the model input size. With
        ``args.multimask`` set, only the candidate with the highest predicted
        IoU is kept per sample, so both tensors have a channel dimension of 1.
    """
    point_coords = batched_input.get("point_coords", None)
    if point_coords is not None:
        points = (point_coords, batched_input["point_labels"])
    else:
        points = None

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = model.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )

        low_res_masks, iou_predictions = model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=args.multimask,
        )

    if args.multimask:
        # Keep, per sample, the candidate mask with the best predicted IoU.
        max_values, max_indexs = torch.max(iou_predictions, dim=1)
        iou_predictions = max_values.unsqueeze(1)
        rows = torch.arange(low_res_masks.shape[0], device=low_res_masks.device)
        # [N, H, W] -> [N, 1, H, W], same layout the per-sample stack produced.
        low_res_masks = low_res_masks[rows, max_indexs].unsqueeze(1)

    # Use the shared size resolver instead of reading args.medsam_image_size
    # directly, so configs that only define sammed2d_image_size also work.
    medsz = _get_medsz(args)
    masks = F.interpolate(
        low_res_masks,
        (medsz, medsz),
        mode="bilinear",
        align_corners=False,
    )
    return masks, low_res_masks, iou_predictions


# ---------- Upstream evaluation (single positive point) ----------
def _load_annotations(annos):
    """Return one image's annotation list, reading the JSON file when given a path."""
    if isinstance(annos, str):
        import json  # function-scope: json is not imported at file level

        with open(annos, "r", encoding="utf-8") as f:
            return json.load(f)["annotations"]
    return annos


def _decode_segmentation(seg):
    """Decode an RLE dict (one with a ``"counts"`` key) to a mask; pass anything else through."""
    if isinstance(seg, dict) and "counts" in seg:
        # Imported lazily so pycocotools is only required when RLE masks occur.
        from pycocotools import mask as mask_utils

        return mask_utils.decode(seg)
    return seg


def evaluate_pretrain(
    student: torch.nn.Module,
    switcher,  # SAMPartSwitcher
    loader,
    args,
    *,
    accelerator,  # Accelerate handles device placement, autocast and reduction
    tag: str = "pre",
) -> np.ndarray:
    """Evaluate single-positive-point segmentation and return the mean metrics.

    For every annotation of every image in ``loader`` one positive point
    prompt is fed to the prompt encoder and mask decoder; the predicted mask
    is resized to the original image and scored against the ground truth.
    (The original author note says this targets the SA-1B validation set.)

    Aggregation: scores are averaged over all annotations in a batch, those
    per-batch means are summed, and ``ddp_weighted_mean`` divides by the
    number of contributing batches across all processes.

    Args:
        student: Model (possibly wrapped); unwrapped via ``accelerator``.
        switcher: Object that may expose ``to_orig_prompt()``.
        loader: Iterable of batches with ``"input_image"``,
            ``"annotation_path"``, ``"original_image_size"`` and
            ``"input_image_size"`` entries.
        args: Run configuration; ``args.metrics`` selects the metrics.
        accelerator: Accelerate object used for device, autocast and reduce.
        tag: Suffix for the progress-bar description.

    Returns:
        Array with one mean score per metric, in ``metrics`` order.

    Side effects:
        Switches the model to the upstream path when the hooks exist and
        leaves it in ``train()`` mode on return.
    """
    student = accelerator.unwrap_model(student)
    student.eval()

    # Upstream path: skip adapters and use the original prompt, when available.
    if hasattr(student.image_encoder, "set_skip_adapter_mod"):
        student.image_encoder.set_skip_adapter_mod()
    if hasattr(switcher, "to_orig_prompt"):
        switcher.to_orig_prompt()

    # One size drives both the point-coordinate rescaling and the mask
    # post-processing, so the two can never drift apart.
    model_img_size = 1024
    transform = ResizeLongestSide(model_img_size)
    thr = getattr(student, "mask_threshold", 0.0)
    metrics_list = _get_metrics_list(args)
    device = accelerator.device

    metric_sum = torch.zeros(len(metrics_list), device=device, dtype=torch.float32)
    samp_total = torch.zeros(1, device=device, dtype=torch.float32)

    with accelerator.autocast(), torch.no_grad():
        pbar = tqdm(loader, disable=not accelerator.is_main_process, desc=f"val_{tag}")
        for batch in pbar:
            batch = tensor_to_device(batch, device)
            img_emb = student.image_encoder(batch["input_image"])[0]  # [B, C, H', W']

            sub_scores = torch.zeros(len(metrics_list), device=device, dtype=torch.float32)
            sub_cnt = 0

            for k, annos in enumerate(batch["annotation_path"]):
                annos = _load_annotations(annos)

                # Per-image values: computed once, not once per annotation.
                orig_img_sz = tuple(batch["original_image_size"][k].tolist())
                input_img_sz = tuple(batch["input_image_size"][k].tolist())
                emb_k = img_emb[k][None]

                for anno in annos:
                    # Ground-truth mask -> [1, 1, H, W]
                    gt = _decode_segmentation(anno["segmentation"])
                    gt = torch.as_tensor(gt, device=device).unsqueeze(0).unsqueeze(0)

                    # Single positive point, rescaled into the model frame.
                    coords = np.array(anno["point_coords"])
                    lbls = np.array([1], dtype=np.int64)
                    pt_xy = transform.apply_coords(coords, orig_img_sz)
                    points = (
                        torch.as_tensor(pt_xy, device=device)[None],
                        torch.as_tensor(lbls, device=device)[None],
                    )

                    # Forward: prompt encoder + mask decoder.
                    sparse, dense = student.prompt_encoder(points=points, boxes=None, masks=None)
                    low_res, _ = student.mask_decoder(
                        image_embeddings=emb_k,
                        image_pe=student.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse,
                        dense_prompt_embeddings=dense,
                        multimask_output=False,
                    )

                    # Back to the original image size, then hard-threshold.
                    pred = postprocess_masks_pre(
                        low_res,
                        padded_size=input_img_sz,
                        original_size=orig_img_sz,
                        model_img_size=model_img_size,
                    )
                    # The metric helpers re-binarise tensors with
                    # sigmoid(x) > 0.5, which keeps a 0/1 mask unchanged.
                    pred = (pred > thr).float()

                    vals = SegMetrics(pred, gt, metrics_list)  # numpy array
                    sub_scores += torch.tensor(vals, device=device, dtype=torch.float32)
                    sub_cnt += 1

            # A batch contributes one averaged entry, and only if it had annotations.
            if sub_cnt > 0:
                sub_scores /= float(sub_cnt)
                metric_sum += sub_scores
                samp_total += 1

    metric_mean = ddp_weighted_mean(metric_sum, samp_total, accelerator).cpu().numpy()
    accelerator.wait_for_everyone()
    student.train()
    return metric_mean


# ---------- Downstream evaluation (box + iterative multi-point) ----------
def evaluate_downstream(
    student: torch.nn.Module,
    switcher,  # SAMPartSwitcher
    loader,
    args,
    *,
    accelerator,  # Accelerate handles device placement, autocast and reduction
    tag: str = "down",
) -> Dict[int | str, np.ndarray]:
    """Evaluate box prompts and iterative point prompts in a single pass.

    For each batch the image is encoded once; the embeddings are then decoded
    with the box prompt (when the batch has ``"boxes"``) and with point
    prompts over 7 rounds (when the batch has ``"point_coords"``). Point
    metrics are recorded at rounds 1, 3, 5 and 7.

    NOTE(review): the round index is used as the "number of points". That
    holds only if the loader seeds one point and ``generate_point`` adds one
    per round; the default ``point_num`` passed here is 9 — confirm.

    Args:
        student: Model (possibly wrapped); unwrapped via ``accelerator``.
        switcher: Object that may expose ``to_interp_prompt()``.
        loader: Iterable of batches with ``"input_image"`` and ``"label"``,
            plus optional ``"boxes"`` and ``"point_coords"`` /
            ``"point_labels"``.
        args: Run configuration (metrics, image size, ``multimask``,
            optional ``point_num``).
        accelerator: Accelerate object used for device, autocast and reduce.
        tag: Suffix for the progress-bar description.

    Returns:
        ``{1: arr, 3: arr, 5: arr, 7: arr, "final": arr}`` where each array
        holds one mean score per metric and ``"final"`` is the box result.
        Entries with no samples are all zeros.

    Side effects:
        Switches the model to the downstream path when the hooks exist and
        leaves it in ``train()`` mode on return.
    """
    student = accelerator.unwrap_model(student)
    student.eval()

    # Downstream path: enable adapters and use the interpolated prompt, when available.
    if hasattr(student.image_encoder, "cancel_skip_adapter_mod"):
        student.image_encoder.cancel_skip_adapter_mod()
    if hasattr(switcher, "to_interp_prompt"):
        switcher.to_interp_prompt()

    metrics_list = _get_metrics_list(args)
    medsz = _get_medsz(args)

    points_to_evaluate: Sequence[int] = (1, 3, 5, 7)
    max_points = max(points_to_evaluate)
    num_metrics = len(metrics_list)
    # Resolve display positions by name so progress-bar labels always match
    # the metric they show, whatever the order of metrics_list.
    iou_idx = metrics_list.index("iou") if "iou" in metrics_list else None
    dice_idx = metrics_list.index("dice") if "dice" in metrics_list else None

    # Per-process accumulators; reduced across processes at the end.
    pt_metric_sum = {p: torch.zeros(num_metrics, device=accelerator.device, dtype=torch.float32) for p in points_to_evaluate}
    pt_counts = {p: torch.zeros(1, device=accelerator.device, dtype=torch.float32) for p in points_to_evaluate}
    box_metric_sum = torch.zeros(num_metrics, device=accelerator.device, dtype=torch.float32)
    box_counts = torch.zeros(1, device=accelerator.device, dtype=torch.float32)

    def _update_postfix(pbar):
        """Show running local means on the main process's progress bar."""
        if not accelerator.is_main_process:
            return
        post = OrderedDict()
        if box_counts.item() > 0:
            box_mean = (box_metric_sum / box_counts.clamp(min=1)).detach().cpu().numpy()
            if iou_idx is not None:
                post["Box_IoU"] = f"{box_mean[iou_idx]:.4f}"
            if dice_idx is not None:
                post["Box_Dice"] = f"{box_mean[dice_idx]:.4f}"
        if dice_idx is not None:
            pt_dices = []
            for pt in points_to_evaluate:
                if pt_counts[pt].item() > 0:
                    mean_pt = (pt_metric_sum[pt] / pt_counts[pt].clamp(min=1)).detach().cpu().numpy()
                    pt_dices.append(mean_pt[dice_idx])
                    post[f"P{pt}_Dice"] = f"{mean_pt[dice_idx]:.4f}"
            if pt_dices:
                post["DiceAvg"] = f"{float(np.mean(pt_dices)):.4f}"
        if post:
            pbar.set_postfix(post)

    with accelerator.autocast(), torch.no_grad():
        pbar = tqdm(
            loader,
            desc=f"val_{tag} (Box+Point)",
            disable=not accelerator.is_main_process,
        )

        for batch in pbar:
            batch = tensor_to_device(batch, accelerator.device)
            B, M = batch["label"].shape[:2]
            N = B * M
            labels = batch["label"].reshape(N, 1, medsz, medsz)

            # Encode once, decode many times: repeat each image embedding M times.
            emb_single = student.image_encoder(batch["input_image"])[0]  # [B,C,H',W']
            image_embeddings = emb_single.unsqueeze(1).expand(-1, M, -1, -1, -1)  # [B,M,C,H',W']
            image_embeddings = image_embeddings.reshape(N, *emb_single.shape[1:])  # [N,C,H',W']

            # Box evaluation (only when the batch carries boxes).
            if "boxes" in batch and batch["boxes"] is not None:
                box_prompts = {
                    "boxes": batch["boxes"].reshape(N, -1),
                    "point_coords": None,
                    "point_labels": None,
                }
                masks_box, _, _ = prompt_and_decoder_eval(args, box_prompts, student, image_embeddings)
                box_batch_metrics = SegMetrics(masks_box, labels, metrics_list)  # numpy
                # SegMetrics yields one score per batch; weight it by N samples.
                box_metric_sum += torch.tensor(box_batch_metrics, device=accelerator.device, dtype=torch.float32) * float(N)
                box_counts += float(N)

            # Point evaluation (only when the batch carries seed points).
            if "point_coords" in batch and batch["point_coords"] is not None:
                point_prompts = {
                    "point_coords": batch["point_coords"].reshape(N, -1, 2),
                    "point_labels": batch["point_labels"].reshape(N, -1),
                    "boxes": None,
                }
                # History of every round's points, concatenated before each decode.
                pc_hist = [point_prompts["point_coords"]]
                pl_hist = [point_prompts["point_labels"]]

                for it in range(max_points):
                    cur_pts = it + 1
                    masks_pt, low_res, _ = prompt_and_decoder_eval(args, point_prompts, student, image_embeddings)
                    if cur_pts in points_to_evaluate:
                        pt_batch_metrics = SegMetrics(masks_pt, labels, metrics_list)
                        pt_metric_sum[cur_pts] += torch.tensor(
                            pt_batch_metrics,
                            device=accelerator.device,
                            dtype=torch.float32,
                        ) * float(N)
                        pt_counts[cur_pts] += float(N)

                    # No new points are needed after the final round.
                    if it < max_points - 1:
                        point_prompts = generate_point(
                            masks_pt,
                            labels,
                            low_res,
                            point_prompts,
                            getattr(args, "point_num", 9),
                        )
                        point_prompts = tensor_to_device(point_prompts, accelerator.device)
                        pc_hist.append(point_prompts["point_coords"])
                        pl_hist.append(point_prompts["point_labels"])
                        point_prompts["point_coords"] = torch.cat(pc_hist, dim=1)
                        point_prompts["point_labels"] = torch.cat(pl_hist, dim=1)

            _update_postfix(pbar)

    # Cross-process aggregation (sample-weighted mean).
    out: Dict[int | str, np.ndarray] = {}
    for pt in points_to_evaluate:
        out[pt] = ddp_weighted_mean(pt_metric_sum[pt], pt_counts[pt], accelerator).cpu().numpy()
    out["final"] = ddp_weighted_mean(box_metric_sum, box_counts, accelerator).cpu().numpy()

    accelerator.wait_for_everyone()
    student.train()
    return out
