from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch


def _as_3d(x: torch.Tensor) -> torch.Tensor:
    """Ensure x is [B,T,C]."""
    if x.dim() == 2:
        return x.unsqueeze(0)
    if x.dim() == 3:
        return x
    raise ValueError(f"Expected tensor with dim 2 or 3, got shape={tuple(x.shape)}")


def extract_r(x: torch.Tensor, *, M: int = 1) -> torch.Tensor:
    """Build the router representation r(x) by mean-pooling over chunks of dim 1.

    Args:
      x: rank-2 tensor ([B,C]), returned unchanged, or rank-3 tensor ([B,T,C]).
      M: number of chunks along dim 1. Must be 1 for rank-2 input.

    Returns:
      For rank-3 input, the per-chunk means concatenated along dim 1
      ([B, M*C]); for rank-2 input, ``x`` itself.

    Raises:
      TypeError: if ``x`` is not a tensor.
      ValueError: if ``M`` is invalid for the given input, or ``x`` has an
        unsupported rank.
    """
    if not torch.is_tensor(x):
        raise TypeError(f"extract_r expects torch.Tensor, got {type(x)}")

    # Vector inputs have no time axis to chunk, so only M == 1 makes sense.
    if x.dim() == 2:
        if int(M) == 1:
            return x
        raise ValueError(f"extract_r: M must be 1 for vector inputs, got M={M}")

    feats = _as_3d(x)
    n_steps = feats.shape[1]
    M = int(M)
    if M < 1:
        raise ValueError(f"extract_r: M must be >= 1, got {M}")
    if M == 1:
        return feats.mean(dim=1)

    def _pool(lo: int, hi: int) -> torch.Tensor:
        # An empty chunk (possible when there are fewer steps than chunks)
        # falls back to the mean over the whole time axis.
        window = feats if hi <= lo else feats[:, lo:hi, :]
        return window.mean(dim=1)

    # Floor-division boundaries give M deterministic, near-equal chunks.
    bounds = [(i * n_steps) // M for i in range(M + 1)]
    pooled = [_pool(lo, hi) for lo, hi in zip(bounds[:-1], bounds[1:])]
    return torch.cat(pooled, dim=1)  # [B, M*C]


def _kmeans_pp_init(X: np.ndarray, *, k: int, rng: np.random.RandomState) -> np.ndarray:
    """KMeans++ init for centers.

    Args:
      X: [N, D] float32
      k: number of centers (k >= 1)
    """
    N, D = X.shape
    if N <= 0:
        raise ValueError("kmeans init: empty X")
    k = int(min(int(k), int(N)))
    # pick first center uniformly
    centers = np.empty((k, D), dtype=np.float32)
    first = int(rng.randint(0, N))
    centers[0] = X[first]

    # track squared distance to closest chosen center
    d2 = np.sum((X - centers[0]) ** 2, axis=1).astype(np.float64, copy=False)  # [N]
    for i in range(1, k):
        tot = float(d2.sum())
        if not np.isfinite(tot) or tot <= 0.0:
            # degenerate; fall back to uniform random
            centers[i] = X[int(rng.randint(0, N))]
            d2 = np.minimum(d2, np.sum((X - centers[i]) ** 2, axis=1).astype(np.float64, copy=False))
            continue
        probs = d2 / tot
        idx = int(rng.choice(N, p=probs))
        centers[i] = X[idx]
        d2 = np.minimum(d2, np.sum((X - centers[i]) ** 2, axis=1).astype(np.float64, copy=False))
    return centers


def _kmeans_fit(
    X: np.ndarray,
    *,
    k: int,
    seed: int = 0,
    max_iter: int = 50,
    tol: float = 1e-4,
) -> np.ndarray:
    """Simple numpy KMeans with kmeans++ init.

    Args:
      X: [N, D] data matrix.
      k: requested number of clusters; clamped to the range [1, N].
      seed: seed for the ``RandomState`` driving init and empty-cluster re-seeding.
      max_iter: maximum number of assign/update iterations.
      tol: stop when the relative inertia change or the total center shift
        drops to this value or below.

    Returns:
      centers: [k, D] float32

    Raises:
      ValueError: if ``X`` is not 2-D or has no rows.
    """
    if X.ndim != 2:
        raise ValueError(f"kmeans_fit expects X [N,D], got {X.shape}")
    N = X.shape[0]
    if N <= 0:
        raise ValueError("kmeans_fit: empty X")
    k = int(max(1, min(int(k), int(N))))
    rng = np.random.RandomState(int(seed))
    tol = float(tol)

    centers = _kmeans_pp_init(X, k=k, rng=rng)  # [k,D]

    # Loop invariants: these depend only on X, so compute them once.
    X64 = X.astype(np.float64, copy=False)
    x2 = np.sum(X * X, axis=1, keepdims=True).astype(np.float64, copy=False)  # [N,1]

    prev_inertia = None
    for _ in range(int(max_iter)):
        # squared distances [N,k] via (x-c)^2 = x^2 + c^2 - 2 x c
        c2 = np.sum(centers * centers, axis=1, keepdims=True).T.astype(np.float64, copy=False)  # [1,k]
        xc = X64 @ centers.astype(np.float64, copy=False).T  # [N,k]
        # The expanded form can dip slightly below zero numerically; clip it.
        d2 = np.maximum(x2 + c2 - 2.0 * xc, 0.0)
        assign = np.argmin(d2, axis=1).astype(np.int64, copy=False)  # [N]
        inertia = float(np.min(d2, axis=1).sum())

        # update centers
        new_centers = np.zeros_like(centers)
        counts = np.bincount(assign, minlength=k).astype(np.int64, copy=False)
        for j in range(k):
            if counts[j] <= 0:
                # re-init empty cluster to a random point
                new_centers[j] = X[int(rng.randint(0, N))]
            else:
                new_centers[j] = X[assign == j].mean(axis=0).astype(np.float32, copy=False)
        shift = float(np.sqrt(np.sum((new_centers - centers) ** 2)))
        centers = new_centers

        # Converged when inertia stops changing (relative) ...
        if prev_inertia is not None and abs(prev_inertia - inertia) <= tol * (prev_inertia + 1e-12):
            break
        prev_inertia = inertia
        # ... or when the centers have effectively stopped moving.
        if shift <= tol:
            break

    return centers.astype(np.float32, copy=False)


@dataclass
class Subspace:
    """Per-task subspace description: a mean vector and a basis matrix."""

    mu: np.ndarray  # [d]
    U: np.ndarray  # [d,k] orthonormal columns (PCA components)

    def to_torch(self, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, U)`` converted to float32 tensors on ``device``."""
        placement = {"device": device, "dtype": torch.float32}
        mean_t = torch.from_numpy(self.mu).to(**placement)
        basis_t = torch.from_numpy(self.U).to(**placement)
        return mean_t, basis_t


class TaskSubspaceRouter:
    """Non-parametric task router using per-task PCA subspaces.

    This router does NOT use task_id at inference time; it predicts weights from residual energies.
    Each registered task is a ``Subspace`` (mean ``mu`` plus basis ``U``); a sample is scored by
    the energy of ``r - mu`` that is left after projecting onto the columns of ``U``.
    """

    def __init__(self, *, M: int = 1, k: int = 32, eps: float = 1e-6) -> None:
        """Create an empty router.

        Args:
          M: segment count passed to ``extract_r`` while fitting.
          k: maximum number of principal directions kept per task.
          eps: added to the denominator of normalized residuals.
        """
        self.M = int(M)
        self.k = int(k)
        self.eps = float(eps)
        self._spaces: Dict[int, Subspace] = {}

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return list(sorted(int(t) for t in self._spaces.keys()))

    def num_tasks(self) -> int:
        """Return how many tasks are registered."""
        return int(len(self._spaces))

    def add_task_space(self, task_id: int, *, mu: np.ndarray, U: np.ndarray) -> None:
        """Register (or overwrite) the subspace of ``task_id`` from precomputed arrays.

        Raises:
          ValueError: if ``U`` is not 2-D or its row count differs from ``len(mu)``.
        """
        tid = int(task_id)
        mu = np.asarray(mu, dtype=np.float32).reshape(-1)
        U = np.asarray(U, dtype=np.float32)
        if U.ndim != 2:
            raise ValueError(f"U must be 2D, got shape={U.shape}")
        if int(U.shape[0]) != int(mu.shape[0]):
            raise ValueError(f"U rows must match mu dim, got U={U.shape} mu={mu.shape}")
        self._spaces[tid] = Subspace(mu=mu, U=U)

    # -------------------------
    # Fitting (task-end, offline)
    # -------------------------
    @staticmethod
    def _batch_features(batch, *, dev: torch.device, M: int) -> torch.Tensor:
        """Pool the first two tensors of ``batch`` and stack them along dim 0 (float32)."""
        if not isinstance(batch, (tuple, list)):
            raise TypeError(f"Unsupported batch type: {type(batch)}")
        input1 = batch[0]
        input2 = batch[1]
        x1 = input1.to(dev, dtype=torch.float32, non_blocking=True)
        x2 = input2.to(dev, dtype=torch.float32, non_blocking=True)
        r1 = extract_r(x1, M=M)  # [B,d]
        r2 = extract_r(x2, M=M)
        return torch.cat([r1, r2], dim=0)  # [2B,d]

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> Subspace:
        """Fit (mu,U) from a task-specific loader. Uses only ego vid1/vid2 tensors.

        The loader is iterated twice (mean, then covariance), so it must be
        re-iterable; a one-shot iterator is rejected on the second pass.

        Expected batch formats:
          - (input1, input2)
          - (input1, input2, input_exo)  # exo ignored for routing statistics

        Returns:
          The fitted ``Subspace``, also stored under ``task_id``.

        Raises:
          TypeError: if a batch is not a tuple or list.
          RuntimeError: if fewer than two samples are seen, or the second pass
            over ``loader`` yields no samples.
        """
        dev = torch.device(device)
        tid = int(task_id)
        M = int(self.M)
        k = int(self.k)

        # Pass 1: mean
        n = 0
        sum_r: Optional[torch.Tensor] = None
        for batch in loader:
            r = self._batch_features(batch, dev=dev, M=M)
            batch_sum = r.sum(dim=0)
            sum_r = batch_sum if sum_r is None else sum_r + batch_sum
            n += int(r.shape[0])
        if n <= 1 or sum_r is None:
            raise RuntimeError(f"fit_from_loader: empty/too-small loader for task_id={tid} (n={n})")

        mu = (sum_r / float(n)).detach()
        d = int(mu.numel())
        if verbose:
            print(f"[router fit] task={tid} n={n} d={d} M={M} k={k}")

        # Pass 2: covariance accumulator S = Z^T Z (float64 for accuracy)
        mu64 = mu.to(torch.float64)
        S = torch.zeros((d, d), dtype=torch.float64, device=dev)
        n_second = 0
        for batch in loader:
            r = self._batch_features(batch, dev=dev, M=M).to(torch.float64)
            z = r - mu64
            S = S + (z.transpose(0, 1) @ z)
            n_second += int(r.shape[0])
        if n_second <= 0:
            # A one-shot iterator is exhausted by pass 1; without this check the
            # covariance would silently stay all-zero and the basis be arbitrary.
            raise RuntimeError(
                f"fit_from_loader: loader yielded no samples on the second pass for task_id={tid}; "
                "a re-iterable loader is required"
            )

        # Covariance
        cov = (S / float(max(n - 1, 1))).to(dtype=torch.float32)
        cov_np = cov.detach().cpu().numpy()

        # PCA via eigh (cov symmetric)
        # Get top-k eigenvectors by descending eigenvalues.
        w, V = np.linalg.eigh(cov_np)  # ascending
        idx = np.argsort(w)[::-1]
        idx = idx[: int(min(k, V.shape[1]))]
        U = V[:, idx].astype(np.float32, copy=False)
        # Orthonormalize (numerical safety)
        # QR returns Q with orthonormal columns.
        U, _ = np.linalg.qr(U)

        space = Subspace(mu=mu.detach().cpu().numpy().astype(np.float32, copy=False), U=U.astype(np.float32, copy=False))
        self._spaces[tid] = space
        return space

    # -------------------------
    # Inference
    # -------------------------
    def residuals(self, r: torch.Tensor, *, device: torch.device, normalize: bool = True) -> Tuple[torch.Tensor, List[int]]:
        """Return residual scores e_t for each task: [B, Ttasks]. Lower is better.

        Args:
          r: [B, d] representations.
          device: device on which scores are computed.
          normalize: if True, divide the residual energy by ``||r - mu||^2 + eps``.

        Returns:
          ``(scores, task_ids)`` where column ``j`` of ``scores`` belongs to ``task_ids[j]``.

        Raises:
          RuntimeError: if no task is registered or a stored space does not match ``d``.
        """
        if self.num_tasks() <= 0:
            raise RuntimeError("residuals called with empty router (no task spaces).")
        r = r.to(device=device, dtype=torch.float32)
        tids = self.task_ids()
        B, d = int(r.shape[0]), int(r.shape[1])
        e_all = torch.empty((B, len(tids)), device=device, dtype=torch.float32)
        eps = float(self.eps)
        for j, tid in enumerate(tids):
            sp = self._spaces[int(tid)]
            mu, U = sp.to_torch(device)
            if int(mu.numel()) != d or int(U.shape[0]) != d:
                raise RuntimeError(f"Router dim mismatch for task_id={tid}: r_dim={d}, mu_dim={mu.numel()}, U={tuple(U.shape)}")
            z = r - mu[None, :]
            z_norm2 = (z * z).sum(dim=1)  # [B]
            # proj = U^T z : [B,k]
            proj = z @ U  # U is [d,k]
            proj_norm2 = (proj * proj).sum(dim=1)
            # residual energy
            res = z_norm2 - proj_norm2
            if normalize:
                e = res / (z_norm2 + eps)
            else:
                e = res
            e_all[:, j] = e
        return e_all, tids

    def infer_weights(
        self,
        r: torch.Tensor,
        *,
        topL: int = 2,
        gamma: float = 10.0,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
        """Infer per-sample soft weights over top-L tasks.

        Args:
          r: [B, d] representations.
          topL: number of lowest-residual tasks to keep (capped at the number of tasks).
          gamma: softmax sharpness applied to the negated residuals.
          device: compute device; defaults to ``r.device``.

        Returns:
          - task_ids: [B, L] ids of the selected tasks, best first
          - task_weights: [B, L] (softmax weights over those L tasks)
          - stats: dict with some summary scalars (for gamma tuning)

        Raises:
          ValueError: if ``topL`` < 1.
        """
        if device is None:
            device = r.device if torch.is_tensor(r) else torch.device("cpu")
        topL = int(topL)
        if topL <= 0:
            raise ValueError(f"topL must be >= 1, got {topL}")
        e_all, tids = self.residuals(r, device=device, normalize=True)
        # pick L smallest residuals
        L = int(min(topL, e_all.shape[1]))
        neg_e = -e_all
        vals, idx = torch.topk(neg_e, k=L, dim=1)  # largest -e == smallest e
        e_sel = -vals  # back to e
        # weights on selected tasks
        w = torch.softmax((-float(gamma) * e_sel), dim=1)
        # map column indices to task ids
        tid_tensor = torch.tensor(tids, device=device, dtype=torch.long)
        task_ids = tid_tensor[idx]  # [B,L]

        # stats for gamma tuning (no task labels used)
        with torch.no_grad():
            # margin between best and 2nd best residual (zero when only one task exists)
            e_sorted, _ = torch.sort(e_all, dim=1)
            best = e_sorted[:, 0]
            second = e_sorted[:, 1] if e_sorted.shape[1] >= 2 else best
            gap = (second - best).clamp(min=0)
            # entropy of full softmax over all tasks (proxy for ambiguity)
            p_full = torch.softmax((-float(gamma) * e_all), dim=1)
            ent = -(p_full * (p_full.clamp(min=1e-12)).log()).sum(dim=1)
            stats = {
                "res_best_mean": float(best.mean().item()),
                "res_gap_mean": float(gap.mean().item()),
                "entropy_mean": float(ent.mean().item()),
                "num_tasks": float(e_all.shape[1]),
            }
        return task_ids, w, stats

    # -------------------------
    # IO
    # -------------------------
    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Write one task's arrays and the router settings to ``router_task_XX.npz``.

        Raises:
          KeyError: if ``task_id`` is not registered.
        """
        os.makedirs(output_dir, exist_ok=True)
        tid = int(task_id)
        if tid not in self._spaces:
            raise KeyError(f"save_task: task_id={tid} not found in router.")
        sp = self._spaces[tid]
        np.savez(
            os.path.join(output_dir, f"router_task_{tid:02d}.npz"),
            mu=sp.mu,
            U=sp.U,
            M=self.M,
            k=self.k,
            eps=self.eps,
        )

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` listing the task ids and router settings."""
        os.makedirs(output_dir, exist_ok=True)
        payload = {"tasks": self.task_ids(), "M": int(self.M), "k": int(self.k), "eps": float(self.eps)}
        with open(os.path.join(output_dir, "router_index.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)


@dataclass
class WhitenedSubspaceStats:
    """Per-task stats for the whitened-subspace router."""

    mu: np.ndarray  # [d]
    var: np.ndarray  # [d] (diagonal variance)
    Bw: np.ndarray  # [d, k+1] orthonormal basis in whitened space

    def to_torch(self, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mu, var, Bw)`` converted to float32 tensors on ``device``."""
        placement = {"device": device, "dtype": torch.float32}
        mean_t = torch.from_numpy(self.mu).to(**placement)
        var_t = torch.from_numpy(self.var).to(**placement)
        basis_t = torch.from_numpy(self.Bw).to(**placement)
        return mean_t, var_t, basis_t


class TaskWhitenedSubspaceRouter:
    """Whitened-subspace router: diagonal whitening + augmented whitened subspace residual.

    For each task t:
      1) Estimate mean mu_t and diagonal variance var_t from r(x) on the task train split.
      2) Define whitening weights w_t = 1/sqrt(var_t + eps).
      3) In whitened space, build an augmented subspace basis:
         - mean direction: m_w = normalize(mu_t ⊙ w_t)
         - variation directions: U_w = top-k PCA directions of (r - mu_t) ⊙ w_t
         - basis: B_w = orth([m_w, U_w])  (QR)
      4) Score sample r by residual ratio to this task basis (lower is better):
         e_t(r) = 1 - ||B_w^T (r ⊙ w_t)||^2 / (||r ⊙ w_t||^2 + eps)
    """

    def __init__(self, *, M: int = 1, k: int = 32, eps: float = 1e-6) -> None:
        """Create an empty router.

        Args:
          M: segment count passed to ``extract_r`` while fitting.
          k: maximum number of PCA directions kept per task.
          eps: stabilizer used in the whitening weights and score denominators.
        """
        self.M = int(M)
        self.k = int(k)
        self.eps = float(eps)
        self._stats: Dict[int, WhitenedSubspaceStats] = {}

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return list(sorted(int(t) for t in self._stats.keys()))

    def num_tasks(self) -> int:
        """Return how many tasks are registered."""
        return int(len(self._stats))

    def add_task_stats(self, task_id: int, *, mu: np.ndarray, var: np.ndarray, Bw: np.ndarray) -> None:
        """Register (or overwrite) the stats of ``task_id`` from precomputed arrays.

        Raises:
          ValueError: if ``Bw`` is not 2-D, its row count differs from ``len(mu)``,
            or ``mu`` and ``var`` have different shapes.
        """
        tid = int(task_id)
        mu = np.asarray(mu, dtype=np.float32).reshape(-1)
        var = np.asarray(var, dtype=np.float32).reshape(-1)
        Bw = np.asarray(Bw, dtype=np.float32)
        if Bw.ndim != 2:
            raise ValueError(f"Bw must be 2D, got shape={Bw.shape}")
        if int(Bw.shape[0]) != int(mu.shape[0]):
            raise ValueError(f"Bw rows must match mu dim, got Bw={Bw.shape} mu={mu.shape}")
        if mu.shape != var.shape:
            raise ValueError(f"mu/var shape mismatch: mu={mu.shape} var={var.shape}")
        self._stats[tid] = WhitenedSubspaceStats(mu=mu, var=var, Bw=Bw)

    @staticmethod
    def _batch_features(batch, *, dev: torch.device, M: int) -> torch.Tensor:
        """Pool the first two tensors of ``batch`` and stack them along dim 0 (float32)."""
        if not isinstance(batch, (tuple, list)):
            raise TypeError(f"Unsupported batch type: {type(batch)}")
        input1 = batch[0]
        input2 = batch[1]
        x1 = input1.to(dev, dtype=torch.float32, non_blocking=True)
        x2 = input2.to(dev, dtype=torch.float32, non_blocking=True)
        r1 = extract_r(x1, M=M)  # [B,d]
        r2 = extract_r(x2, M=M)
        return torch.cat([r1, r2], dim=0)  # [2B,d]

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> WhitenedSubspaceStats:
        """Fit (mu, var, Bw) from a task-specific loader. Uses only ego vid1/vid2 tensors.

        The loader is iterated twice (moments, then whitened covariance), so it
        must be re-iterable; a one-shot iterator is rejected on the second pass.

        Returns:
          The fitted ``WhitenedSubspaceStats``, also stored under ``task_id``.

        Raises:
          TypeError: if a batch is not a tuple or list.
          RuntimeError: if fewer than two samples are seen, or the second pass
            over ``loader`` yields no samples.
        """
        dev = torch.device(device)
        tid = int(task_id)
        M = int(self.M)
        k = int(self.k)

        # Pass 1: mean + diag variance
        n = 0
        sum_r: Optional[torch.Tensor] = None
        sum_r2: Optional[torch.Tensor] = None
        for batch in loader:
            r = self._batch_features(batch, dev=dev, M=M)
            if sum_r is None:
                sum_r = r.sum(dim=0)
                sum_r2 = (r * r).sum(dim=0)
            else:
                sum_r = sum_r + r.sum(dim=0)
                sum_r2 = sum_r2 + (r * r).sum(dim=0)  # type: ignore[operator]
            n += int(r.shape[0])
        if n <= 1 or sum_r is None or sum_r2 is None:
            raise RuntimeError(f"fit_from_loader: empty/too-small loader for task_id={tid} (n={n})")

        mu = (sum_r / float(n)).detach()
        ex2 = (sum_r2 / float(n)).detach()
        # E[x^2] - mu^2 can dip below zero from rounding; clamp it.
        var = (ex2 - mu * mu).clamp(min=0.0)
        d = int(mu.numel())
        if verbose:
            print(f"[router fit whitened-subspace] task={tid} n={n} d={d} M={M} k={k}")

        # Pass 2: covariance in whitened space (centered) for U_w
        w = 1.0 / torch.sqrt(var + float(self.eps))  # [d]
        mu64 = mu.to(torch.float64)
        w64 = w.to(torch.float64)
        S = torch.zeros((d, d), dtype=torch.float64, device=dev)
        n_second = 0
        for batch in loader:
            r = self._batch_features(batch, dev=dev, M=M).to(torch.float64)
            z = (r - mu64) * w64  # whitened centered
            S = S + (z.transpose(0, 1) @ z)
            n_second += int(r.shape[0])
        if n_second <= 0:
            # A one-shot iterator is exhausted by pass 1; without this check the
            # covariance would silently stay all-zero and U_w be arbitrary.
            raise RuntimeError(
                f"fit_from_loader: loader yielded no samples on the second pass for task_id={tid}; "
                "a re-iterable loader is required"
            )
        cov = (S / float(max(n - 1, 1))).to(dtype=torch.float32)
        cov_np = cov.detach().cpu().numpy()
        eig, V = np.linalg.eigh(cov_np)  # ascending, V columns orthonormal
        order = np.argsort(eig)[::-1]
        order = order[: int(min(k, V.shape[1]))]
        Uw = V[:, order].astype(np.float32, copy=False)  # [d,k]
        mw = (mu * w).to(torch.float32)
        mw = mw / (mw.norm(p=2).clamp(min=float(self.eps)))
        # Mean direction first, so QR keeps it as the leading basis vector.
        A = np.concatenate([mw.detach().cpu().numpy().astype(np.float32, copy=False)[:, None], Uw], axis=1)  # [d,1+k]
        Bw, _ = np.linalg.qr(A)  # [d,1+k]

        st = WhitenedSubspaceStats(
            mu=mu.detach().cpu().numpy().astype(np.float32, copy=False),
            var=var.detach().cpu().numpy().astype(np.float32, copy=False),
            Bw=Bw.astype(np.float32, copy=False),
        )
        self._stats[tid] = st
        return st

    def augmented_residual_scores(self, r: torch.Tensor, *, device: torch.device) -> Tuple[torch.Tensor, List[int]]:
        """Return residual ratio to augmented whitened subspace (mw + Uw): [B, Ttasks]. Lower is better.

        Returns:
          ``(scores, task_ids)`` where column ``j`` of ``scores`` belongs to ``task_ids[j]``.

        Raises:
          RuntimeError: if no task is registered or stored stats do not match ``r``'s width.
        """
        if self.num_tasks() <= 0:
            raise RuntimeError("augmented_residual_scores called with empty router (no task stats).")
        r = r.to(device=device, dtype=torch.float32)
        tids = self.task_ids()
        B, d = int(r.shape[0]), int(r.shape[1])
        out = torch.empty((B, len(tids)), device=device, dtype=torch.float32)
        eps = float(self.eps)
        for j, tid in enumerate(tids):
            sp = self._stats[int(tid)]
            mu, var, Bw = sp.to_torch(device)
            if int(mu.numel()) != d or int(var.numel()) != d or int(Bw.shape[0]) != d:
                raise RuntimeError(f"Router dim mismatch for task_id={tid}: r_dim={d}, mu_dim={mu.numel()}, var_dim={var.numel()}, Bw={tuple(Bw.shape)}")
            w = 1.0 / torch.sqrt(var.clamp(min=0.0) + eps)
            x = r * w[None, :]  # uncentered whitened
            x2 = (x * x).sum(dim=1)
            proj = x @ Bw  # [B,1+k]
            proj2 = (proj * proj).sum(dim=1)
            out[:, j] = 1.0 - (proj2 / (x2 + eps))
        return out, tids

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Write one task's arrays and the router settings to ``router_task_XX.npz``.

        Raises:
          KeyError: if ``task_id`` is not registered.
        """
        os.makedirs(output_dir, exist_ok=True)
        tid = int(task_id)
        if tid not in self._stats:
            raise KeyError(f"save_task: task_id={tid} not found in router.")
        st = self._stats[tid]
        np.savez(
            os.path.join(output_dir, f"router_task_{tid:02d}.npz"),
            mu=st.mu,
            var=st.var,
            Bw=st.Bw,
            M=self.M,
            k=self.k,
            eps=self.eps,
            type="whitened_subspace",
        )

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` listing the task ids and router settings."""
        os.makedirs(output_dir, exist_ok=True)
        payload = {"tasks": self.task_ids(), "M": int(self.M), "k": int(self.k), "eps": float(self.eps), "type": "whitened_subspace"}
        with open(os.path.join(output_dir, "router_index.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)


@dataclass
class Prototype:
    """Mean-feature prototype of a single task."""

    mu: np.ndarray  # [d]

    def to_torch(self, device: torch.device) -> torch.Tensor:
        """Return ``mu`` as a float32 tensor on ``device``."""
        as_tensor = torch.from_numpy(self.mu)
        return as_tensor.to(device=device, dtype=torch.float32)


class TaskMeanCosineRouter:
    """Non-parametric task router built on one mean prototype per task.

    - Fit: store the mean of r(x) over a task's loader (both ego inputs of every batch).
    - Infer: score a sample by the dot product of its (optionally L2-normalized)
      representation with each (optionally L2-normalized) prototype. Higher is better.
    """

    def __init__(self, *, M: int = 1, eps: float = 1e-6, normalize: bool = True) -> None:
        """Create an empty router.

        Args:
          M: segment count passed to ``extract_r`` while fitting.
          eps: lower bound applied to norms before dividing.
          normalize: whether ``cosine_scores`` L2-normalizes both operands.
        """
        self._protos: Dict[int, Prototype] = {}
        self.M = int(M)
        self.eps = float(eps)
        self.normalize = bool(normalize)

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return sorted(int(key) for key in self._protos)

    def num_tasks(self) -> int:
        """Return how many tasks are registered."""
        return len(self._protos)

    def add_task_proto(self, task_id: int, *, mu: np.ndarray) -> None:
        """Register (or overwrite) the prototype of ``task_id`` from a precomputed mean."""
        tid = int(task_id)
        flat_mu = np.asarray(mu, dtype=np.float32).reshape(-1)
        self._protos[tid] = Prototype(mu=flat_mu)

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> Prototype:
        """Compute and store the prototype mean of one task.

        Only the first two tensors of each batch are used.

        Returns:
          The fitted ``Prototype``, also stored under ``task_id``.

        Raises:
          TypeError: if a batch is not a tuple or list.
          RuntimeError: if the loader yields no samples.
        """
        dev = torch.device(device)
        tid = int(task_id)
        M = int(self.M)

        sum_r: Optional[torch.Tensor] = None
        n = 0
        for batch in loader:
            if not isinstance(batch, (tuple, list)):
                raise TypeError(f"Unsupported batch type: {type(batch)}")
            pair = (batch[0], batch[1])
            moved = [item.to(dev, dtype=torch.float32, non_blocking=True) for item in pair]
            feats = torch.cat([extract_r(item, M=M) for item in moved], dim=0)  # [2B,d]
            batch_sum = feats.sum(dim=0)
            sum_r = batch_sum if sum_r is None else sum_r + batch_sum
            n += int(feats.shape[0])
        if n <= 0 or sum_r is None:
            raise RuntimeError(f"fit_from_loader: empty loader for task_id={tid} (n={n})")

        mu = (sum_r / float(n)).detach()
        if verbose:
            print(f"[router fit mean] task={tid} n={n} d={int(mu.numel())} M={M} normalize={self.normalize}")

        mu_np = mu.detach().cpu().numpy().astype(np.float32, copy=False)
        proto = Prototype(mu=mu_np)
        self._protos[tid] = proto
        return proto

    def cosine_scores(self, r: torch.Tensor, *, device: torch.device) -> Tuple[torch.Tensor, List[int]]:
        """Score ``r`` ([B, d]) against every prototype: [B, Ttasks]. Higher is better.

        Returns:
          ``(scores, task_ids)`` where column ``j`` of ``scores`` belongs to ``task_ids[j]``.

        Raises:
          RuntimeError: if no task is registered or a prototype does not match ``d``.
        """
        if self.num_tasks() < 1:
            raise RuntimeError("cosine_scores called with empty router (no prototypes).")
        r = r.to(device=device, dtype=torch.float32)
        tids = self.task_ids()
        n_rows = int(r.shape[0])
        d = int(r.shape[1])
        s_all = torch.empty((n_rows, len(tids)), device=device, dtype=torch.float32)
        eps = float(self.eps)

        r_norm = r / (r.norm(p=2, dim=1, keepdim=True).clamp(min=eps)) if self.normalize else r

        for col, tid in enumerate(tids):
            mu = self._protos[int(tid)].to_torch(device)
            if int(mu.numel()) != d:
                raise RuntimeError(f"Router dim mismatch for task_id={tid}: r_dim={d}, mu_dim={mu.numel()}")
            mu_norm = mu / (mu.norm(p=2).clamp(min=eps)) if self.normalize else mu
            # Row-wise dot product of each sample with this prototype: [B]
            s_all[:, col] = (r_norm * mu_norm[None, :]).sum(dim=1)
        return s_all, tids

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Write one task's prototype and the router settings to ``router_task_XX.npz``.

        Raises:
          KeyError: if ``task_id`` is not registered.
        """
        os.makedirs(output_dir, exist_ok=True)
        tid = int(task_id)
        if tid not in self._protos:
            raise KeyError(f"save_task: task_id={tid} not found in router.")
        proto = self._protos[tid]
        target = os.path.join(output_dir, f"router_task_{tid:02d}.npz")
        np.savez(target, mu=proto.mu, M=self.M, eps=self.eps, normalize=int(self.normalize))

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` listing the task ids and router settings."""
        os.makedirs(output_dir, exist_ok=True)
        payload = {
            "tasks": self.task_ids(),
            "M": int(self.M),
            "eps": float(self.eps),
            "normalize": bool(self.normalize),
            "type": "mean_cosine",
        }
        index_path = os.path.join(output_dir, "router_index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)


@dataclass
class PrototypeDiagVar:
    """Per-task mean and diagonal variance of the router representation."""

    mu: np.ndarray  # [d]
    var: np.ndarray  # [d] (diagonal variance)

    def to_torch(self, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, var)`` converted to float32 tensors on ``device``."""
        placement = {"device": device, "dtype": torch.float32}
        mean_t = torch.from_numpy(self.mu).to(**placement)
        var_t = torch.from_numpy(self.var).to(**placement)
        return mean_t, var_t


class TaskWhitenedCosineRouter:
    """Non-parametric task router using task-wise (mean, diag-variance) and weighted cosine similarity.

    Fit (task end): store mean mu_t and diagonal variance var_t of r(x) over the task loader.
    Infer: score_t = cos( normalize(r * w_t), normalize(mu_t * w_t) ), where w_t = 1/sqrt(var_t + eps).
    """

    def __init__(self, *, M: int = 1, eps: float = 1e-6) -> None:
        """Create an empty router.

        Args:
          M: segment count passed to ``extract_r`` while fitting.
          eps: stabilizer used in the whitening weights and as a lower bound on norms.
        """
        self._stats: Dict[int, PrototypeDiagVar] = {}
        self.M = int(M)
        self.eps = float(eps)

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return sorted(int(key) for key in self._stats)

    def num_tasks(self) -> int:
        """Return how many tasks are registered."""
        return len(self._stats)

    def add_task_stats(self, task_id: int, *, mu: np.ndarray, var: np.ndarray) -> None:
        """Register (or overwrite) the stats of ``task_id`` from precomputed arrays.

        Raises:
          ValueError: if ``mu`` and ``var`` have different flattened shapes.
        """
        tid = int(task_id)
        mu = np.asarray(mu, dtype=np.float32).reshape(-1)
        var = np.asarray(var, dtype=np.float32).reshape(-1)
        if var.shape != mu.shape:
            raise ValueError(f"mu/var shape mismatch: mu={mu.shape} var={var.shape}")
        self._stats[tid] = PrototypeDiagVar(mu=mu, var=var)

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> PrototypeDiagVar:
        """Compute and store the mean and diagonal variance of one task.

        Only the first two tensors of each batch are used.

        Returns:
          The fitted ``PrototypeDiagVar``, also stored under ``task_id``.

        Raises:
          TypeError: if a batch is not a tuple or list.
          RuntimeError: if the loader yields no samples.
        """
        dev = torch.device(device)
        tid = int(task_id)
        M = int(self.M)

        sum_r: Optional[torch.Tensor] = None
        sum_r2: Optional[torch.Tensor] = None
        n = 0
        for batch in loader:
            if not isinstance(batch, (tuple, list)):
                raise TypeError(f"Unsupported batch type: {type(batch)}")
            pair = (batch[0], batch[1])
            moved = [item.to(dev, dtype=torch.float32, non_blocking=True) for item in pair]
            feats = torch.cat([extract_r(item, M=M) for item in moved], dim=0)  # [2B,d]
            if sum_r is None:
                sum_r = feats.sum(dim=0)
                sum_r2 = (feats * feats).sum(dim=0)
            else:
                sum_r = sum_r + feats.sum(dim=0)
                sum_r2 = sum_r2 + (feats * feats).sum(dim=0)  # type: ignore[operator]
            n += int(feats.shape[0])
        if n <= 0 or sum_r is None or sum_r2 is None:
            raise RuntimeError(f"fit_from_loader: empty loader for task_id={tid} (n={n})")

        mu = (sum_r / float(n)).detach()
        ex2 = (sum_r2 / float(n)).detach()
        # E[x^2] - mu^2, clamped because rounding can push it below zero.
        var = (ex2 - mu * mu).clamp(min=0.0)

        if verbose:
            print(f"[router fit whitened] task={tid} n={n} d={int(mu.numel())} M={M} eps={self.eps}")

        mu_np = mu.detach().cpu().numpy().astype(np.float32, copy=False)
        var_np = var.detach().cpu().numpy().astype(np.float32, copy=False)
        st = PrototypeDiagVar(mu=mu_np, var=var_np)
        self._stats[tid] = st
        return st

    def whitened_cosine_scores(self, r: torch.Tensor, *, device: torch.device) -> Tuple[torch.Tensor, List[int]]:
        """Return weighted cosine similarity of ``r`` ([B, d]) to each task: [B, Ttasks]. Higher is better.

        Returns:
          ``(scores, task_ids)`` where column ``j`` of ``scores`` belongs to ``task_ids[j]``.

        Raises:
          RuntimeError: if no task is registered or stored stats do not match ``d``.
        """
        if self.num_tasks() < 1:
            raise RuntimeError("whitened_cosine_scores called with empty router (no task stats).")
        r = r.to(device=device, dtype=torch.float32)
        tids = self.task_ids()
        n_rows = int(r.shape[0])
        d = int(r.shape[1])
        s_all = torch.empty((n_rows, len(tids)), device=device, dtype=torch.float32)
        eps = float(self.eps)

        for col, tid in enumerate(tids):
            mu, var = self._stats[int(tid)].to_torch(device)
            if int(mu.numel()) != d or int(var.numel()) != d:
                raise RuntimeError(f"Router dim mismatch for task_id={tid}: r_dim={d}, mu_dim={mu.numel()}, var_dim={var.numel()}")
            w = 1.0 / torch.sqrt(var.clamp(min=0.0) + eps)  # [d]
            scaled_r = r * w[None, :]
            scaled_mu = mu * w
            # Unit-normalize both operands, then take the row-wise dot product.
            unit_r = scaled_r / (scaled_r.norm(p=2, dim=1, keepdim=True).clamp(min=eps))
            unit_mu = scaled_mu / (scaled_mu.norm(p=2).clamp(min=eps))
            s_all[:, col] = (unit_r * unit_mu[None, :]).sum(dim=1)
        return s_all, tids

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Write one task's stats and the router settings to ``router_task_XX.npz``.

        Raises:
          KeyError: if ``task_id`` is not registered.
        """
        os.makedirs(output_dir, exist_ok=True)
        tid = int(task_id)
        if tid not in self._stats:
            raise KeyError(f"save_task: task_id={tid} not found in router.")
        st = self._stats[tid]
        target = os.path.join(output_dir, f"router_task_{tid:02d}.npz")
        np.savez(target, mu=st.mu, var=st.var, M=self.M, eps=self.eps, type="whitened_cosine")

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` listing the task ids and router settings."""
        os.makedirs(output_dir, exist_ok=True)
        payload = {
            "tasks": self.task_ids(),
            "M": int(self.M),
            "eps": float(self.eps),
            "type": "whitened_cosine",
        }
        index_path = os.path.join(output_dir, "router_index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)


@dataclass
class KMeansCenters:
    """Container for the KMeans centers of a single task."""

    centers: np.ndarray  # [K, d]

    def to_torch(self, device: torch.device) -> torch.Tensor:
        """Return the centers as a float32 tensor on ``device``."""
        # from_numpy wraps the array; .to() then handles dtype/device conversion.
        wrapped = torch.from_numpy(self.centers)
        return wrapped.to(dtype=torch.float32, device=device)


class TaskKMeansRouter:
    """Non-parametric task router using per-task KMeans centroids.

    - Fit: run KMeans on r(x) for each task, keep K centers.
    - Infer: for a sample r, compute mean L2 distance to the K centers of each task; choose smallest task.

    Centers are stored per task id in ``self._centers``. Picking the task with
    the smallest distance is left to the caller of ``mean_l2_distances``.
    """

    def __init__(self, *, M: int = 1, k: int = 32, eps: float = 1e-6, max_iter: int = 50, seed: int = 0) -> None:
        """Store hyper-parameters (coerced to int/float) and start with no tasks.

        Args:
          M: segment count forwarded to ``extract_r`` during fitting.
          k: number of centers requested from ``_kmeans_fit``.
          eps: constant added under the square root in ``mean_l2_distances``.
          max_iter: iteration limit forwarded to ``_kmeans_fit``.
          seed: base seed; task ``t`` is fitted with ``seed + t``.
        """
        self.M = int(M)
        self.k = int(k)
        self.eps = float(eps)
        self.max_iter = int(max_iter)
        self.seed = int(seed)
        self._centers: Dict[int, KMeansCenters] = {}

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return sorted(int(t) for t in self._centers)

    def num_tasks(self) -> int:
        """Return how many tasks currently have centers."""
        return len(self._centers)

    def add_task_centers(self, task_id: int, *, centers: np.ndarray) -> None:
        """Register (or overwrite) the centers of a task.

        Args:
          task_id: task identifier; coerced with ``int``.
          centers: array-like convertible to a float32 ``[K, d]`` array.

        Raises:
          ValueError: if ``centers`` is not 2D, or has no elements. An empty
            set of centers would make ``mean_l2_distances`` average over
            nothing and return NaN.
        """
        tid = int(task_id)
        c = np.asarray(centers, dtype=np.float32)
        if c.ndim != 2:
            raise ValueError(f"centers must be 2D [K,d], got shape={c.shape}")
        if c.size == 0:
            raise ValueError(f"centers must be non-empty [K,d], got shape={c.shape}")
        self._centers[tid] = KMeansCenters(centers=c)

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> KMeansCenters:
        """Fit task centroids from a task-specific loader. Uses only ego vid1/vid2 tensors.

        Every batch must be a tuple or list whose first two items are tensors.
        Both are passed through ``extract_r`` and stacked along the batch axis,
        and the accumulated rows are clustered by ``_kmeans_fit`` seeded with
        ``self.seed + task_id``.

        Args:
          task_id: task identifier; coerced with ``int``.
          loader: iterable of batches.
          device: device on which ``extract_r`` is evaluated.
          verbose: print a one-line summary before clustering.

        Returns:
          The fitted ``KMeansCenters``, also stored under the task id.

        Raises:
          TypeError: if a batch is not a tuple or list.
          RuntimeError: if the loader produced no rows.
        """
        dev = torch.device(device)
        tid = int(task_id)
        M = int(self.M)
        K = int(self.k)

        chunks: List[np.ndarray] = []
        n = 0
        for batch in loader:
            if not isinstance(batch, (tuple, list)):
                raise TypeError(f"Unsupported batch type: {type(batch)}")
            input1, input2 = batch[0], batch[1]
            x1 = input1.to(dev, dtype=torch.float32, non_blocking=True)
            x2 = input2.to(dev, dtype=torch.float32, non_blocking=True)
            r = torch.cat([extract_r(x1, M=M), extract_r(x2, M=M)], dim=0)  # [2B,d]
            rr = r.detach().cpu().numpy().astype(np.float32, copy=False)
            chunks.append(rr)
            n += int(rr.shape[0])

        # n (not len(chunks)) is checked so that a loader of zero-row batches is also rejected.
        if n <= 0:
            raise RuntimeError(f"fit_from_loader: empty loader for task_id={tid} (n={n})")
        X = np.concatenate(chunks, axis=0)  # [N,d]
        if verbose:
            print(f"[router fit kmeans] task={tid} N={int(X.shape[0])} d={int(X.shape[1])} M={M} K={K} max_iter={self.max_iter}")

        centers = _kmeans_fit(X, k=K, seed=self.seed + tid, max_iter=self.max_iter)
        km = KMeansCenters(centers=centers)
        self._centers[tid] = km
        return km

    def mean_l2_distances(self, r: torch.Tensor, *, device: torch.device) -> Tuple[torch.Tensor, List[int]]:
        """Return mean L2 distance to K centers for each task: [B, Ttasks]. Lower is better.

        Args:
          r: ``[B, d]`` tensor of router representations.
          device: device on which the distances are computed.

        Returns:
          ``(distances, task_ids)`` where column ``j`` of ``distances``
          corresponds to ``task_ids[j]`` (ascending task id order).

        Raises:
          RuntimeError: if the router has no tasks, if a task's center
            dimension differs from ``d``, or if a task has zero centers.
          ValueError: if ``r`` is not 2D.
        """
        if self.num_tasks() <= 0:
            raise RuntimeError("mean_l2_distances called with empty router (no centers).")
        if r.dim() != 2:
            raise ValueError(f"mean_l2_distances expects r with shape [B,d], got shape={tuple(r.shape)}")
        r = r.to(device=device, dtype=torch.float32)
        tids = self.task_ids()
        B, d = int(r.shape[0]), int(r.shape[1])
        out = torch.empty((B, len(tids)), device=device, dtype=torch.float32)
        eps = float(self.eps)

        # Precompute r^2 term for squared distances
        r2 = (r * r).sum(dim=1, keepdim=True)  # [B,1]
        for j, tid in enumerate(tids):
            C = self._centers[int(tid)].to_torch(device)  # [K,d]
            if int(C.shape[0]) == 0:
                # Averaging over zero centers would silently yield NaN.
                raise RuntimeError(f"Router has no centers for task_id={tid}")
            if int(C.shape[1]) != d:
                raise RuntimeError(f"Router dim mismatch for task_id={tid}: r_dim={d}, centers_dim={int(C.shape[1])}")
            # squared distances: ||r-c||^2 = r2 + c2 - 2 r c^T
            c2 = (C * C).sum(dim=1).view(1, -1)  # [1,K]
            rc = r @ C.t()  # [B,K]
            # Round-off can push the expansion slightly below zero; clamp before sqrt.
            d2 = (r2 + c2 - 2.0 * rc).clamp(min=0.0)
            d_l2 = torch.sqrt(d2 + eps)  # [B,K]
            out[:, j] = d_l2.mean(dim=1)  # strict mean over K centers
        return out, tids

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Write one task's centers and the router hyper-parameters to ``router_task_XX.npz``.

        Args:
          output_dir: destination directory; created if missing.
          task_id: task whose centers are saved; coerced with ``int``.

        Raises:
          KeyError: if the task has no centers in this router.
        """
        os.makedirs(output_dir, exist_ok=True)
        tid = int(task_id)
        if tid not in self._centers:
            raise KeyError(f"save_task: task_id={tid} not found in router.")
        km = self._centers[tid]
        np.savez(
            os.path.join(output_dir, f"router_task_{tid:02d}.npz"),
            centers=km.centers,
            M=int(self.M),
            k=int(self.k),
            eps=float(self.eps),
            max_iter=int(self.max_iter),
            seed=int(self.seed),
            type="kmeans",
        )

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` with the task ids and hyper-parameters.

        Args:
          output_dir: destination directory; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)
        payload = {
            "tasks": self.task_ids(),
            "M": int(self.M),
            "k": int(self.k),
            "eps": float(self.eps),
            "max_iter": int(self.max_iter),
            "seed": int(self.seed),
            "type": "kmeans",
        }
        with open(os.path.join(output_dir, "router_index.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)


class TaskRandomRouter:
    """Random router for PPCL ablation.

    No task statistics are modelled here: the router only remembers which
    task ids have been registered, so that inference can uniformly sample a
    task adapter.
    """

    def __init__(self, *, M: int = 1) -> None:
        """Store ``M`` (coerced to int) and start with no registered tasks."""
        self.M = int(M)
        self._task_ids: List[int] = []

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return sorted(map(int, self._task_ids))

    def num_tasks(self) -> int:
        """Return the number of registered task ids."""
        return len(self._task_ids)

    def add_task_id(self, task_id: int) -> None:
        """Register ``task_id`` (coerced to int); a repeated id is ignored."""
        new_id = int(task_id)
        if new_id in self._task_ids:
            return
        self._task_ids.append(new_id)

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> None:
        """Register ``task_id``; ``loader``, ``device`` and ``verbose`` are unused."""
        del loader, device, verbose  # nothing is fitted for this router
        self.add_task_id(int(task_id))

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Register ``task_id`` and write a small JSON marker file for it.

        There are no learned statistics to store; the marker only records the
        router type and the task id.
        """
        os.makedirs(output_dir, exist_ok=True)
        new_id = int(task_id)
        self.add_task_id(new_id)
        marker_path = os.path.join(output_dir, f"router_task_{new_id:02d}.json")
        with open(marker_path, "w", encoding="utf-8") as fh:
            marker = {"type": "random", "task_id": int(new_id)}
            json.dump(marker, fh, indent=2, ensure_ascii=False)

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` with the router type, task ids and ``M``."""
        os.makedirs(output_dir, exist_ok=True)
        index_path = os.path.join(output_dir, "router_index.json")
        with open(index_path, "w", encoding="utf-8") as fh:
            index = {"type": "random", "tasks": self.task_ids(), "M": int(self.M)}
            json.dump(index, fh, indent=2, ensure_ascii=False)


class TaskOracleRouter:
    """Oracle router for PPCL ablation.

    No task statistics are modelled here: the router only remembers which
    task ids have been registered. The GT task-id must be provided at
    inference time (cheating oracle).
    """

    def __init__(self, *, M: int = 1) -> None:
        """Store ``M`` (coerced to int) and start with no registered tasks."""
        self.M = int(M)
        self._task_ids: List[int] = []

    def task_ids(self) -> List[int]:
        """Return the registered task ids in ascending order."""
        return sorted(map(int, self._task_ids))

    def num_tasks(self) -> int:
        """Return the number of registered task ids."""
        return len(self._task_ids)

    def add_task_id(self, task_id: int) -> None:
        """Register ``task_id`` (coerced to int); a repeated id is ignored."""
        new_id = int(task_id)
        if new_id in self._task_ids:
            return
        self._task_ids.append(new_id)

    def fit_from_loader(
        self,
        *,
        task_id: int,
        loader: Iterable,
        device: Union[torch.device, str] = "cpu",
        verbose: bool = False,
    ) -> None:
        """Register ``task_id``; ``loader``, ``device`` and ``verbose`` are unused."""
        del loader, device, verbose  # nothing is fitted for this router
        self.add_task_id(int(task_id))

    def save_task(self, *, output_dir: str, task_id: int) -> None:
        """Register ``task_id`` and write a small JSON marker file for it.

        There are no learned statistics to store; the marker only records the
        router type and the task id.
        """
        os.makedirs(output_dir, exist_ok=True)
        new_id = int(task_id)
        self.add_task_id(new_id)
        marker_path = os.path.join(output_dir, f"router_task_{new_id:02d}.json")
        with open(marker_path, "w", encoding="utf-8") as fh:
            marker = {"type": "oracle", "task_id": int(new_id)}
            json.dump(marker, fh, indent=2, ensure_ascii=False)

    def save_index(self, *, output_dir: str) -> None:
        """Write ``router_index.json`` with the router type, task ids and ``M``."""
        os.makedirs(output_dir, exist_ok=True)
        index_path = os.path.join(output_dir, "router_index.json")
        with open(index_path, "w", encoding="utf-8") as fh:
            index = {"type": "oracle", "tasks": self.task_ids(), "M": int(self.M)}
            json.dump(index, fh, indent=2, ensure_ascii=False)

