import math
from typing import Dict, List, Optional, Sequence
import torch
import torch.nn.functional as F
from typing import Dict, List, Any
import re
from collections import defaultdict, Counter
import math

# Matches retrieval metric keys of the form "top<N>" (e.g. "top1", "top3");
# group(1) captures the digits of N.
_TOP_RE = re.compile(r"^top(\d+)$")


def _moving_average_time(x: torch.Tensor, K: int) -> torch.Tensor:
    """Box-filter a ``[T, W]`` matrix along the time axis (dim 0).

    Each column is smoothed independently with a length-``K`` mean; the edges
    use replicate padding so the output keeps shape ``[T, W]``. For even ``K``
    the window covers one more frame before ``t`` than after it. ``K <= 1``
    returns ``x`` unchanged.
    """
    if K <= 1:
        return x
    left = K // 2
    right = K - 1 - left
    x_cwt = x.t().unsqueeze(0)  # [1, W, T]: one conv channel per window column
    n_cols = x_cwt.size(1)
    # The kernel must share the input's dtype: a float32 kernel against e.g.
    # float64 distances makes conv1d raise a dtype-mismatch error.
    kernel = torch.ones((n_cols, 1, K), dtype=x.dtype, device=x.device) / K
    smoothed = F.conv1d(
        F.pad(x_cwt, (left, right), mode="replicate"), kernel, groups=n_cols
    )
    return smoothed.squeeze(0).t()  # [T, W]


def _diagonal_topk_hits(dist_mat: torch.Tensor, topk: int) -> List[float]:
    """Top-1..Top-``topk`` hit rates for a ``[B, B]`` distance matrix.

    Row ``i``'s correct match is column ``i``; smaller distance ranks higher and
    ties are broken by the lower column index (stable sort).
    """
    order = torch.argsort(dist_mat, dim=1, stable=True)  # ascending distance
    target = torch.arange(order.size(0), device=order.device).unsqueeze(1)
    ranks = (order == target).nonzero()[:, 1]  # rank of the true match per row
    return [(ranks < m).float().mean().item() for m in range(1, topk + 1)]


def _pairwise_dist(
    Vb: torch.Tensor, Ab: torch.Tensor, distance: str, p: float
) -> torch.Tensor:
    """``[B, B]`` distances between rows of ``Vb`` and rows of ``Ab`` (for ranking)."""
    if distance == "euclidean_normed":
        # Squared chord distance; monotonically related to the sqrt form used for
        # the window matrix, so it is only used for ordering.
        return 2.0 - 2.0 * (Vb @ Ab.t())
    if distance == "cosine":
        return 1.0 - Vb @ Ab.t()
    if distance == "euclidean":
        return torch.cdist(Vb, Ab, p=2.0)
    if distance == "p":
        return torch.cdist(Vb, Ab, p=float(p))
    raise ValueError(f"unknown distance='{distance}'")


@torch.no_grad()
def evaluate_avsync(
    visual: Optional[torch.Tensor] = None,  # [T, C] or None
    audio: Optional[torch.Tensor] = None,  # [T, C] or None
    *,
    # Bypass features -> distances and pass the window distance matrix directly
    # (smaller = more similar)
    distances: Optional[torch.Tensor] = None,  # [T, win_size]
    valid_mask: Optional[
        torch.Tensor
    ] = None,  # [T, win_size], True = valid; optional together with distances
    # k-smooth (window matching) evaluation only
    win_size: int = 31,
    k_list: Sequence[int] = (5, 7, 9, 11, 13, 15),
    accept_radius: int = 1,
    # Settings for building D from features when distances is not given
    distance: str = "euclidean_normed",  # "euclidean_normed" | "cosine" | "euclidean" | "p"
    p: float = 2.0,
    normalize: bool = True,
    # Retrieval metric settings (in-batch Top-k)
    batch_size: int = 32,
    topk: int = 3,
    num_batches: Optional[int] = None,
    seed: int = 0,
    # Extra: a full [T, T] retrieval distance matrix (smaller = more similar)
    retrieval_dist: Optional[torch.Tensor] = None,  # [T, T]
    device: Optional[torch.device] = None,
) -> Dict[str, object]:
    """Evaluate audio-visual synchronization from features or precomputed distances.

    k-smooth offset accuracy:
        A window distance matrix ``D`` of shape ``[T, W]`` is used (``W = win_size``,
        ``vshift = W // 2``, smaller = more similar). When built from features,
        column ``j`` compares video frame ``t`` with audio frame
        ``t + j - vshift``; out-of-range shifts are set to ``+inf``. When
        ``distances`` is given it is presumably laid out the same way (caller
        contract). For every ``K`` in ``k_list``, ``D`` is averaged over ``K``
        neighbouring frames. The per-frame argmin gives
        ``pred_offset = vshift - argmin``, i.e. the negated shift of the
        best-matching audio frame. A frame is a hit when
        ``|pred_offset| <= accept_radius``.

    In-batch retrieval Top-1..Top-``topk``:
        Computed when ``retrieval_dist`` is given or when features are given.
        ``num_batches`` random batches are drawn without replacement; each holds
        ``min(T, batch_size)`` frames. Each video frame ranks the batch's audio
        frames, and the correct match is the same frame index.

    Args:
        visual, audio: ``[T, C]`` features of equal shape (required unless
            ``distances`` is given).
        distances: ``[T, W]`` window distances; overrides ``win_size`` with ``W``.
            Non-float tensors are converted to float32.
        valid_mask: bool mask of ``distances``' shape; False entries become ``+inf``.
        win_size: odd window width (>= 3) for the feature path.
        k_list: smoothing lengths to evaluate.
        accept_radius: max ``|offset|`` counted as correct.
        distance, p, normalize: how feature distances are computed.
        batch_size: requested retrieval batch size (>= 1).
        topk: largest N reported as ``"top<N>"``.
        num_batches: number of retrieval batches (>= 1); defaults to
            ``ceil(T / batch_size)``.
        seed: seed of the batch-sampling generator.
        retrieval_dist: ``[T, T]`` distances, video rows vs audio columns.
        device: computation device; defaults to the input tensor's device.

    Returns:
        ``{"k_smooth": {K: {"acc", "mean_offset", "mean_abs_offset"}},
        "retrieval": {"top1", ..., "top<topk>", "k", "batch_size", "num_batches"}}``.
        ``"retrieval"`` is ``{}`` when neither features nor ``retrieval_dist``
        are available. ``"batch_size"`` echoes the requested value.

    Raises:
        ValueError: unknown ``distance``, or ``batch_size``/``num_batches`` < 1
            when retrieval metrics are computed.
        AssertionError: malformed input shapes.
    """
    # --------------------------
    # 1) Window distance matrix D: [T, win_size] (for k-smooth)
    # --------------------------
    if distances is not None:
        assert distances.dim() == 2, "distances 必须是 [T, win_size]"
        T, W = distances.shape
        win_size = W  # the matrix width is authoritative
        vshift = win_size // 2
        if device is None:
            device = distances.device
        D = distances.to(device)
        if not D.is_floating_point():
            # Integer/bool tensors can neither hold +inf nor be box-filtered.
            D = D.float()

        if valid_mask is not None:
            assert valid_mask.shape == D.shape and valid_mask.dtype == torch.bool
            D = torch.where(valid_mask.to(device), D, torch.full_like(D, float("inf")))

        # No feature path available
        v = a = None
        can_build_retrieval_from_feats = False

    else:
        # Build D from visual/audio features
        assert (
            visual is not None and audio is not None
        ), "需提供 (visual, audio) 或 distances"
        assert visual.dim() == 2 and audio.dim() == 2 and visual.shape == audio.shape
        T, C = visual.shape
        assert win_size % 2 == 1 and win_size >= 3
        vshift = win_size // 2
        if device is None:
            device = visual.device

        v = visual.to(device)
        a = audio.to(device)

        if normalize and distance in ("euclidean_normed", "cosine"):
            v = F.normalize(v, p=2, dim=-1)
            a = F.normalize(a, p=2, dim=-1)

        # a_win[t, j] = a_pad[t + j] = audio[t + j - vshift] (zeros outside [0, T)).
        # unfold yields a strided view instead of a Python-level stack over T.
        a_pad = F.pad(a, (0, 0, vshift, vshift))  # [T + 2*vshift, C]
        a_win = a_pad.unfold(0, win_size, 1).transpose(1, 2)  # [T, W, C]

        if distance == "euclidean_normed":
            sims = (v.unsqueeze(1) * a_win).sum(dim=-1)  # [T, W]
            # Chord distance of unit vectors; clamp guards tiny negative round-off.
            D = torch.sqrt(torch.clamp(2.0 - 2.0 * sims, min=0.0) + 1e-12)
        elif distance == "cosine":
            sims = (v.unsqueeze(1) * a_win).sum(dim=-1)
            D = 1.0 - sims
        elif distance == "euclidean":
            D = torch.cdist(v.unsqueeze(1), a_win, p=2.0).squeeze(1)
        elif distance == "p":
            D = torch.cdist(v.unsqueeze(1), a_win, p=float(p)).squeeze(1)
        else:
            raise ValueError(f"unknown distance='{distance}'")

        # Shifts that fall outside [0, T) compare against zero padding -> +inf
        deltas = torch.arange(-vshift, vshift + 1, device=device)  # [W]
        t_idx = torch.arange(T, device=device).unsqueeze(1)  # [T, 1]
        valid = (t_idx + deltas).clamp(min=0, max=T - 1).eq(t_idx + deltas)
        D = torch.where(valid, D, torch.full_like(D, float("inf")))

        can_build_retrieval_from_feats = True

    # --------------------------
    # 2) K-frame smoothing + offset / accuracy
    # --------------------------
    k_smooth_out: Dict[int, Dict[str, float]] = {}
    for K in k_list:
        Dk = _moving_average_time(D, int(K))
        min_idx = torch.argmin(Dk, dim=1)  # [T]
        # Negated shift of the best-matching audio frame.
        pred_offset = (vshift - min_idx).to(torch.int32)
        hit = pred_offset.abs() <= accept_radius
        k_smooth_out[int(K)] = {
            "acc": float(hit.float().mean().item()),
            "mean_offset": float(pred_offset.float().mean().item()),
            "mean_abs_offset": float(pred_offset.float().abs().mean().item()),
        }

    # --------------------------
    # 3) In-batch retrieval Top-1..Top-K
    # --------------------------
    retrieval_out: Dict[str, float] = {}
    have_retrieval_matrix = retrieval_dist is not None

    if have_retrieval_matrix or can_build_retrieval_from_feats:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if num_batches is None:
            num_batches = max(1, math.ceil(T / batch_size))
        if num_batches < 1:
            raise ValueError(f"num_batches must be >= 1, got {num_batches}")

        g = torch.Generator(device=device).manual_seed(seed)

        # (A) slice a sub-matrix of retrieval_dist, or (B) compute it from features.
        if have_retrieval_matrix:
            RD = retrieval_dist.to(device)  # expected [T, T], smaller = more similar
            assert (
                RD.dim() == 2 and RD.shape[0] == RD.shape[1] == T
            ), "retrieval_dist 必须是 [T, T] 距离矩阵"

            def batch_dist(idx: torch.Tensor) -> torch.Tensor:
                # Rows = video indices, columns = audio indices; [B, B]
                return RD.index_select(0, idx).index_select(1, idx)

        else:

            def batch_dist(idx: torch.Tensor) -> torch.Tensor:
                return _pairwise_dist(
                    v.index_select(0, idx), a.index_select(0, idx), distance, p
                )

        agg = torch.zeros(topk, dtype=torch.float64)
        for _ in range(num_batches):
            # Always sample without replacement. If T < batch_size this takes all
            # T items; sampling with replacement would add duplicate candidates
            # that tie with the true match and, under the stable tie-break,
            # count as misses.
            idx = torch.randperm(T, device=device, generator=g)[:batch_size]
            agg += torch.tensor(
                _diagonal_topk_hits(batch_dist(idx), topk), dtype=torch.float64
            )
        agg /= num_batches

        retrieval_out = {
            f"top{m}": float(agg[m - 1].item()) for m in range(1, topk + 1)
        }
        retrieval_out.update(
            {
                "k": int(topk),
                "batch_size": int(batch_size),
                "num_batches": int(num_batches),
            }
        )

    return {"k_smooth": k_smooth_out, "retrieval": retrieval_out}


def merge_eval_result(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Average a list of evaluation results into one result with the same layout.

    Each element is expected to look like::

        {
            "k_smooth": {
                K: {"acc": float, "mean_offset": float, "mean_abs_offset": float},
                ...
            },
            "retrieval": {
                "top1": float, "top2": float, ..., "k": int,
                "batch_size": int, "num_batches": int,
            },
        }

    Rules:
    - k_smooth: each metric of a given K is the arithmetic mean over the results
      that contain that K. K keys are normalized with ``int()``, so ``5`` and
      ``"5"`` (e.g. after a JSON round-trip) are the same K. Output keys are ints
      in ascending order, and a metric missing from an entry counts as 0.0.
    - retrieval: each "topN" key is averaged over the results that contain it.
      "k" is the largest N seen and is omitted when there is none. "batch_size"
      and "num_batches", when present, take the unique mode; if the mode is not
      unique they take the mean rounded half up.

    Args:
        results: per-run result dicts; missing or ``None`` sections count as empty.

    Returns:
        ``{"k_smooth": {...}, "retrieval": {...}}``; both empty for empty input.
    """
    if not results:
        return {"k_smooth": {}, "retrieval": {}}

    # ---- k_smooth aggregation ----
    # Re-key every section by int(K) up front. Iterating int(K) but looking up
    # the raw key would silently skip sections whose keys are strings ("5").
    ks_sections: List[Dict[int, Dict[str, Any]]] = [
        {int(K): metrics for K, metrics in (r.get("k_smooth") or {}).items()}
        for r in results
    ]

    k_smooth_out: Dict[int, Dict[str, float]] = {}
    for K in sorted(set().union(*ks_sections)):
        entries = [sec[K] for sec in ks_sections if K in sec]  # never empty
        n = len(entries)
        k_smooth_out[K] = {
            name: sum(float(m.get(name, 0.0)) for m in entries) / n
            for name in ("acc", "mean_offset", "mean_abs_offset")
        }

    # ---- retrieval aggregation ----
    ret_sections = [r.get("retrieval") or {} for r in results]

    def _top_idx(key: str) -> int:
        """N of a "topN" key."""
        return int(_TOP_RE.match(key).group(1))

    top_keys = {key for sec in ret_sections for key in sec if _TOP_RE.match(key)}

    retrieval_out: Dict[str, Any] = {}
    # Emit topN keys in ascending N order, each averaged where present.
    for key in sorted(top_keys, key=_top_idx):
        vals = [float(sec[key]) for sec in ret_sections if key in sec]
        retrieval_out[key] = sum(vals) / len(vals)

    # k = largest N among the topN keys
    k_max = max((_top_idx(key) for key in top_keys), default=0)
    if k_max > 0:
        retrieval_out["k"] = k_max

    def _mode_or_mean(xs: List[Any]) -> int:
        """Unique most frequent value, otherwise the mean rounded half up."""
        if not xs:
            return 0
        ranked = Counter(xs).most_common()
        if len(ranked) == 1 or ranked[0][1] > ranked[1][1]:
            return int(ranked[0][0])
        mean_val = sum(float(x) for x in xs) / len(xs)
        # Round half up as documented; built-in round() rounds half to even
        # (round(2.5) == 2), which would disagree with the contract.
        return int(math.floor(mean_val + 0.5))

    batch_sizes = [sec["batch_size"] for sec in ret_sections if "batch_size" in sec]
    num_batches_list = [
        sec["num_batches"] for sec in ret_sections if "num_batches" in sec
    ]
    if batch_sizes:
        retrieval_out["batch_size"] = _mode_or_mean(batch_sizes)
    if num_batches_list:
        retrieval_out["num_batches"] = _mode_or_mean(num_batches_list)

    return {"k_smooth": k_smooth_out, "retrieval": retrieval_out}


# --- Minimal usage example: merge a single result (append more to average) ---
if __name__ == "__main__":
    example_result = {
        "k_smooth": {
            1: {"acc": 0.5, "mean_offset": 0.1, "mean_abs_offset": 3.0},
            5: {"acc": 0.68, "mean_offset": 0.38, "mean_abs_offset": 1.88},
            15: {"acc": 0.84, "mean_offset": 0.76, "mean_abs_offset": 1.01},
        },
        "retrieval": dict(
            top1=0.45625,
            top2=0.6229,
            top3=0.7458,
            k=3,
            batch_size=32,
            num_batches=15,
        ),
    }
    # More dicts with the same structure can be appended to this list.
    example_results = [example_result]
    merged = merge_eval_result(example_results)
    print(merged)
