from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

# Reuse the exact PPCL algorithmic components from skill_benchmark.
from skill_benchmark.adapters import AdapterBank, MixtureSpec
from skill_benchmark.task_router import (
    TaskKMeansRouter,
    TaskMeanCosineRouter,
    TaskOracleRouter,
    TaskRandomRouter,
    TaskSubspaceRouter,
    TaskWhitenedCosineRouter,
    TaskWhitenedSubspaceRouter,
    extract_r,
)

def _canonical_router_type(rt: str) -> str:
    """Map a router-type string (or one of its aliases) to its canonical name.

    The input is stringified, stripped and lower-cased first; a falsy input
    (``None`` or ``""``) is treated as ``"subspace"``. Names that are not a
    known alias are returned in their normalised form unchanged.
    """
    alias_to_canonical = {
        "ppcl_random": "random",
        "rand": "random",
        "ppcl_oracle": "oracle",
        "gt": "oracle",
        "mean-cosine": "mean_cosine",
        "mean": "mean_cosine",
        "whitened-cosine": "whitened_cosine",
        "wc": "whitened_cosine",
        "k-means": "kmeans",
        "k_means": "kmeans",
        "whitened-subspace": "whitened_subspace",
        "ws": "whitened_subspace",
    }
    normalised = str(rt or "subspace").strip().lower()
    return alias_to_canonical.get(normalised, normalised)


def ppcl_router_posterior_from_r(
    *,
    router,
    router_type: str,
    r: torch.Tensor,
    gamma: float,
) -> Tuple[list[int], torch.Tensor, torch.Tensor]:
    """Turn router representations into a dense per-task posterior.

    Args:
      router: router object; the method that gets called depends on router_type.
      router_type: router family name or alias (see _canonical_router_type).
      r: (B, D) router representation.
      gamma: sharpness multiplier applied to the scores before the softmax.

    Returns:
      task_ids: list[int] of length T (column ids for posterior)
      e: (B, T) "distance-like" scores (lower is better)
      p: (B, T) posterior probabilities over tasks, softmax(-gamma * e)

    Raises:
      ValueError: if r is not 2-D, or the router type has no dense posterior
        defined here (anything other than subspace / whitened_subspace /
        mean_cosine / whitened_cosine / kmeans).

    Notes:
    - Cosine routers: e = 1 - cosine, so p == softmax(gamma * cos).
    - Subspace routers: e is the normalized residual energy.
    - kmeans: e is the mean L2 distance to per-task centroids.
    """
    if r.dim() != 2:
        raise ValueError(f"ppcl_router_posterior_from_r expects r [B,D], got shape={tuple(r.shape)}")

    kind = _canonical_router_type(router_type)
    device = r.device

    if kind in ("mean_cosine", "whitened_cosine"):
        # Both cosine routers return similarities; flip them into distances.
        if kind == "mean_cosine":
            similarity, columns = router.cosine_scores(r, device=device)
        else:
            similarity, columns = router.whitened_cosine_scores(r, device=device)
        energy = 1.0 - similarity
    elif kind == "subspace":
        energy, columns = router.residuals(r, device=device, normalize=True)
    elif kind == "whitened_subspace":
        energy, columns = router.augmented_residual_scores(r, device=device)
    elif kind == "kmeans":
        energy, columns = router.mean_l2_distances(r, device=device)
    else:
        raise ValueError(f"Unsupported ppcl_router_type={router_type}")

    posterior = torch.softmax(-float(gamma) * energy, dim=1)
    # `columns` is ordered like the columns of energy/posterior.
    return list(map(int, columns)), energy, posterior


def ppcl_router_batch_stats(
    *,
    e: torch.Tensor,
    p: torch.Tensor,
) -> dict:
    """Summarise router confidence over a batch as three scalar means.

    Args:
      e: (B, T) router scores; the per-row minimum is treated as the best score.
      p: (B, T) posterior over the same T columns.

    Returns a dict with keys aligned with skill_benchmark:
      - res_best_mean: mean over the batch of min(e) per row
      - res_gap_mean: mean over the batch of (second smallest e - smallest e);
        the gap is 0 when there is only one column
      - entropy_mean: mean over the batch of H(p) per row

    Raises:
      ValueError: if e or p is not 2-D, or if their shapes differ.
    """
    shapes_ok = e.dim() == 2 and p.dim() == 2 and e.shape == p.shape
    if not shapes_ok:
        raise ValueError(f"ppcl_router_batch_stats expects e,p [B,T] same shape, got e={tuple(e.shape)} p={tuple(p.shape)}")

    ascending = torch.sort(e, dim=1).values
    lowest = ascending[:, 0]
    if ascending.shape[1] >= 2:
        runner_up = ascending[:, 1]
    else:
        runner_up = lowest
    margin = torch.clamp(runner_up - lowest, min=0)

    # The log is taken on a floored copy of p so zero entries contribute 0, not NaN.
    log_p = torch.log(torch.clamp(p, min=1e-12))
    entropy = -torch.sum(p * log_p, dim=1)

    summary = {}
    summary["res_best_mean"] = float(lowest.mean().item())
    summary["res_gap_mean"] = float(margin.mean().item())
    summary["entropy_mean"] = float(entropy.mean().item())
    return summary


def ppcl_router_topk_from_posterior(
    *,
    task_ids: list[int],
    p: torch.Tensor,
    topL: int,
) -> torch.Tensor:
    """Return the ids of the topL most probable tasks for every sample.

    Args:
      task_ids: task id of each column of p.
      p: (B, T) dense posterior.
      topL: number of tasks to keep; capped at T.

    Returns:
      (B, L) long tensor of task ids, most probable first, with L = min(topL, T).

    Raises:
      ValueError: if p is not 2-D.
    """
    if p.dim() != 2:
        raise ValueError(f"ppcl_router_topk_from_posterior expects p [B,T], got {tuple(p.shape)}")
    n_columns = int(p.shape[1])
    keep = int(min(int(topL), n_columns))
    # Only the column positions matter here; the probabilities are not returned.
    best_columns = torch.topk(p, k=keep, dim=1).indices
    id_lookup = torch.tensor(task_ids, device=p.device, dtype=torch.long)
    return id_lookup[best_columns]  # [B,L]


def ppcl_eval_router_grouped(
    *,
    router,
    router_type: str,
    x: torch.Tensor,
    gt_task_ids: torch.Tensor,
    M: int,
    topL: int,
    gamma: float,
) -> Tuple[dict, dict]:
    """Evaluate router hit rates + confidence stats grouped by gt task.

    Args:
      router: router object, forwarded to ppcl_router_posterior_from_r.
      router_type: router family name, forwarded to ppcl_router_posterior_from_r.
      x: [B,T,C] or [B,C] (will be converted to r via extract_r)
      gt_task_ids: [B] int tensor (task id for each sample)
      M: forwarded to extract_r as its ``M`` argument.
      topL: number of top-ranked tasks that count as a hit (capped at the
        number of router tasks).
      gamma: softmax sharpness used to build the posterior.

    Returns:
      router_stats: {task_id: {res_best_mean,res_gap_mean,entropy_mean}}
      router_hits : {task_id: {top1_hit_rate,topL_hit_rate,topL,n_samples,true_task_prob_mean}}
      Both dicts are keyed by ground-truth task id in ascending order. A
      ground-truth task unknown to the router contributes probability 0 to
      true_task_prob_mean.

    Raises:
      ValueError: if gt_task_ids is not 1-D.
    """
    if gt_task_ids.dim() != 1:
        raise ValueError(f"ppcl_eval_router_grouped expects gt_task_ids [B], got {tuple(gt_task_ids.shape)}")
    r = extract_r(x, M=int(M))
    task_ids, e, p = ppcl_router_posterior_from_r(router=router, router_type=router_type, r=r, gamma=gamma)
    top_ids = ppcl_router_topk_from_posterior(task_ids=task_ids, p=p, topL=topL)  # [B,L]

    # Per-sample confidence scalars (same definitions as ppcl_router_batch_stats).
    e_sorted, _ = torch.sort(e, dim=1)
    best = e_sorted[:, 0]
    second = e_sorted[:, 1] if e_sorted.shape[1] >= 2 else best
    gap = (second - best).clamp(min=0)
    ent = -(p * (p.clamp(min=1e-12)).log()).sum(dim=1)

    # Move everything to the host in bulk: one transfer per tensor instead of
    # several `.item()` synchronisations per sample inside the loop below.
    gt_values = gt_task_ids.detach().to(dtype=torch.long).tolist()
    top_rows = top_ids.detach().tolist()
    best_values = best.detach().tolist()
    gap_values = gap.detach().tolist()
    ent_values = ent.detach().tolist()
    p_rows = p.detach().tolist()

    # map gt task id to posterior column index
    col_index = {int(t): int(i) for i, t in enumerate(task_ids)}

    # One accumulator per ground-truth task; samples are visited in batch order
    # so the floating-point sums are accumulated in the same order as before.
    per_task = {}
    for i, gt in enumerate(gt_values):
        acc = per_task.get(gt)
        if acc is None:
            acc = per_task[gt] = {
                "n": 0,
                "top1": 0,
                "topL": 0,
                "prob": 0.0,
                "res_best_mean": 0.0,
                "res_gap_mean": 0.0,
                "entropy_mean": 0.0,
            }
        acc["n"] += 1

        ranked = top_rows[i]
        if ranked[0] == gt:
            acc["top1"] += 1
        if gt in ranked:
            acc["topL"] += 1

        # true task prob (0 if gt not in router tasks)
        col = col_index.get(gt)
        if col is not None:
            acc["prob"] += p_rows[i][col]

        acc["res_best_mean"] += best_values[i]
        acc["res_gap_mean"] += gap_values[i]
        acc["entropy_mean"] += ent_values[i]

    reported_topL = int(min(int(topL), int(len(task_ids))))
    stat_keys = ("res_best_mean", "res_gap_mean", "entropy_mean")
    router_stats = {}
    router_hits = {}
    for gt in sorted(per_task):
        acc = per_task[gt]
        n = float(acc["n"])
        router_stats[int(gt)] = {k: float(acc[k]) / n for k in stat_keys}
        router_hits[int(gt)] = {
            "top1_hit_rate": float(acc["top1"]) / n,
            "topL_hit_rate": float(acc["topL"]) / n,
            "topL": reported_topL,
            "n_samples": int(acc["n"]),
            "true_task_prob_mean": float(acc["prob"]) / n,
        }
    return router_stats, router_hits


@dataclass
class PPCLState:
    """Container bundling PPCL settings with the adapter bank and router objects.

    NOTE(review): no code in this module reads these fields, so the per-field
    comments below are assumptions drawn from the field names and from the
    same-named keyword arguments of the helper functions in this file.
    Confirm them against the code that builds and consumes PPCLState.
    """
    # Presumably the master on/off switch for PPCL — TODO confirm.
    enabled: bool
    # AdapterBank from skill_benchmark.adapters; may be None (meaning of None not visible here).
    adapter_bank: Optional[AdapterBank]
    # Router object; presumably one of the Task*Router classes returned by
    # build_ppcl_router. May be None (meaning of None not visible here).
    router: Optional[object]
    # Router family name; presumably one of the names/aliases accepted by build_ppcl_router.
    router_type: str
    # Presumably the `M` value given to the router constructors and to extract_r — TODO confirm.
    router_M: int
    # Presumably the number of top-ranked tasks kept when building a mixture — TODO confirm.
    topL: int
    # Presumably the softmax sharpness applied to router scores — TODO confirm.
    gamma: float
    # Presumably the `eps` value given to the router constructors — TODO confirm.
    eps: float
    # NOTE(review): semantics not visible in this file; verify against callers.
    apply_to_target: bool
    # NOTE(review): semantics not visible in this file; verify against callers.
    train_backbone_after_task1: bool


def build_ppcl_router(
    *,
    router_type: str,
    router_M: int,
    subspace_k: int,
    eps: float,
    kmeans_k: Optional[int] = None,
    kmeans_max_iter: int = 50,
    kmeans_seed: int = 0,
):
    """Instantiate the task router class that matches ``router_type``.

    Args:
      router_type: router family name or alias (case-insensitive, surrounding
        whitespace ignored); a falsy value selects "subspace".
      router_M: passed to every router as ``M``.
      subspace_k: ``k`` for the subspace routers; also the kmeans ``k`` when
        kmeans_k is None.
      eps: passed as ``eps`` to every router except random and oracle.
      kmeans_k: optional override of ``k`` for the kmeans router.
      kmeans_max_iter: ``max_iter`` for the kmeans router.
      kmeans_seed: ``seed`` for the kmeans router.

    Returns:
      A new router instance of the selected family.

    Raises:
      ValueError: if router_type is not a recognised name or alias.
    """
    requested = str(router_type or "subspace").strip().lower()
    alias_groups = {
        "subspace": ("subspace",),
        "whitened_subspace": ("whitened_subspace", "whitened-subspace", "ws"),
        "mean_cosine": ("mean_cosine", "mean-cosine", "mean"),
        "whitened_cosine": ("whitened_cosine", "whitened-cosine", "wc"),
        "kmeans": ("kmeans", "k-means", "k_means"),
        "random": ("random", "ppcl_random", "rand"),
        "oracle": ("oracle", "ppcl_oracle", "gt"),
    }
    family = next(
        (canonical for canonical, names in alias_groups.items() if requested in names),
        None,
    )

    # Constructor arguments are converted inside each branch so that only the
    # values the selected router actually needs are touched.
    if family == "subspace":
        return TaskSubspaceRouter(M=int(router_M), k=int(subspace_k), eps=float(eps))
    if family == "whitened_subspace":
        return TaskWhitenedSubspaceRouter(M=int(router_M), k=int(subspace_k), eps=float(eps))
    if family == "mean_cosine":
        return TaskMeanCosineRouter(M=int(router_M), eps=float(eps), normalize=True)
    if family == "whitened_cosine":
        return TaskWhitenedCosineRouter(M=int(router_M), eps=float(eps))
    if family == "kmeans":
        n_clusters = int(kmeans_k if kmeans_k is not None else subspace_k)
        return TaskKMeansRouter(
            M=int(router_M),
            k=int(n_clusters),
            eps=float(eps),
            max_iter=int(kmeans_max_iter),
            seed=int(kmeans_seed),
        )
    if family == "random":
        return TaskRandomRouter(M=int(router_M))
    if family == "oracle":
        return TaskOracleRouter(M=int(router_M))
    raise ValueError(f"Unsupported ppcl_router_type={router_type}")


def _ppcl_mean_over_views(score_fn, r1, r2, device):
    """Score r1 (and r2 when given) with score_fn; return (mean score matrix, task ids)."""
    first, tids = score_fn(r1, device=device)
    if r2 is None:
        return first, tids
    second, _ = score_fn(r2, device=device)
    return 0.5 * (first + second), tids


def _ppcl_top_tasks(scores, tids, topL, device):
    """Keep the L = min(topL, T) largest columns per row; return (values, task-id tensor, L)."""
    L = int(min(int(topL), int(scores.shape[1])))
    vals, idx = torch.topk(scores, k=L, dim=1)
    tid_tensor = torch.tensor(tids, device=device, dtype=torch.long)
    return vals, tid_tensor[idx], L


def _ppcl_mix_from_posterior(p, tids, topL, device):
    """Build a mixture from a posterior: top-L probabilities renormalised to sum to 1."""
    vals, task_ids, _ = _ppcl_top_tasks(p, tids, topL, device)
    weights = vals / (vals.sum(dim=1, keepdim=True).clamp(min=1e-12))
    return MixtureSpec(task_ids=task_ids, weights=weights)


def _ppcl_mix_from_similarity(s, tids, topL, gamma, device):
    """Build a mixture from similarities: softmax(gamma * top-L scores), or 1.0 when L == 1."""
    vals, task_ids, L = _ppcl_top_tasks(s, tids, topL, device)
    if L == 1:
        weights = torch.ones_like(vals)
    else:
        weights = torch.softmax((float(gamma) * vals), dim=1)
    return MixtureSpec(task_ids=task_ids, weights=weights)


def _ppcl_hard_mix(task_ids, device):
    """Wrap a [B,1] task-id tensor into a single-task mixture with float32 weight 1.0."""
    weights = torch.ones((int(task_ids.shape[0]), 1), device=device, dtype=torch.float32)
    return MixtureSpec(task_ids=task_ids, weights=weights)


def _infer_mix_from_r(
    *,
    router,
    router_type: str,
    r1: torch.Tensor,
    r2: Optional[torch.Tensor],
    topL: int,
    gamma: float,
    gt_task_ids: Optional[torch.Tensor] = None,
) -> MixtureSpec:
    """Route a batch of router representations to a per-sample adapter mixture.

    Args:
      router: router object; the method that gets called depends on router_type.
      router_type: router family name or alias (case-insensitive); a falsy
        value selects "subspace".
      r1: router representation of the first view; its device is used for all
        tensors created here.
      r2: optional representation of a second view, combined with r1 as
        described below.
      topL: number of tasks kept per sample (capped at the number of router
        tasks); must be 1 for kmeans, random and oracle.
      gamma: softmax sharpness.
      gt_task_ids: [B] ground-truth task ids, required only for oracle.

    Returns:
      MixtureSpec whose task_ids and weights both have shape [B, L].

    How the two views are combined and weighted, per router family:
      - subspace: each view gets its own posterior softmax(-gamma * residual);
        the posteriors are averaged, then the top-L are renormalised.
      - whitened_subspace: the scores are averaged first, then
        softmax(-gamma * score), then the top-L are renormalised.
      - mean_cosine / whitened_cosine: similarities are averaged, the top-L are
        taken and weighted by softmax(gamma * similarity) (weight 1 if L == 1).
      - kmeans: distances are averaged; the argmin task gets weight 1.
      - random: one router task drawn uniformly per sample; weight 1.
      - oracle: gt_task_ids is used directly; weight 1. Validity of those ids
        is the caller's responsibility.

    Raises:
      ValueError: unsupported router_type; topL != 1 for kmeans/random/oracle;
        missing or non-1-D gt_task_ids for oracle.
      RuntimeError: random routing requested on a router with no tasks.
    """
    rt = str(router_type or "subspace").strip().lower()
    device = r1.device

    if rt == "subspace":
        e1, tids = router.residuals(r1, device=device, normalize=True)
        if r2 is None:
            p = torch.softmax((-float(gamma) * e1), dim=1)
        else:
            e2, _ = router.residuals(r2, device=device, normalize=True)
            p1 = torch.softmax((-float(gamma) * e1), dim=1)
            p2 = torch.softmax((-float(gamma) * e2), dim=1)
            # Unlike the other families, the views are averaged in probability space.
            p = 0.5 * (p1 + p2)
        return _ppcl_mix_from_posterior(p, tids, topL, device)

    if rt in ("whitened_subspace", "whitened-subspace", "ws"):
        e, tids = _ppcl_mean_over_views(router.augmented_residual_scores, r1, r2, device)
        p = torch.softmax((-float(gamma) * e), dim=1)
        return _ppcl_mix_from_posterior(p, tids, topL, device)

    if rt in ("mean_cosine", "mean-cosine", "mean"):
        s, tids = _ppcl_mean_over_views(router.cosine_scores, r1, r2, device)
        return _ppcl_mix_from_similarity(s, tids, topL, gamma, device)

    if rt in ("whitened_cosine", "whitened-cosine", "wc"):
        s, tids = _ppcl_mean_over_views(router.whitened_cosine_scores, r1, r2, device)
        return _ppcl_mix_from_similarity(s, tids, topL, gamma, device)

    if rt in ("kmeans", "k-means", "k_means"):
        if int(topL) != 1:
            raise ValueError(f"ppcl_router_type=kmeans requires ppcl_topL=1, got {topL}")
        d, tids = _ppcl_mean_over_views(router.mean_l2_distances, r1, r2, device)
        idx = torch.argmin(d, dim=1, keepdim=True)
        tid_tensor = torch.tensor(tids, device=device, dtype=torch.long)
        return _ppcl_hard_mix(tid_tensor[idx], device)

    if rt in ("random", "ppcl_random", "rand"):
        if int(topL) != 1:
            raise ValueError(f"ppcl_router_type=random requires ppcl_topL=1, got {topL}")
        tids = router.task_ids()
        if len(tids) <= 0:
            raise RuntimeError("ppcl_router_type=random called with empty router (no tasks).")
        B = int(r1.shape[0])
        tid_tensor = torch.tensor(tids, device=device, dtype=torch.long)
        idx = torch.randint(low=0, high=int(tid_tensor.shape[0]), size=(B, 1), device=device)
        return _ppcl_hard_mix(tid_tensor[idx], device)

    if rt in ("oracle", "ppcl_oracle", "gt"):
        if int(topL) != 1:
            raise ValueError(f"ppcl_router_type=oracle requires ppcl_topL=1, got {topL}")
        if gt_task_ids is None:
            raise ValueError("ppcl_router_type=oracle requires gt_task_ids to be provided.")
        if gt_task_ids.dim() != 1:
            raise ValueError(f"ppcl_router_type=oracle expects gt_task_ids [B], got {tuple(gt_task_ids.shape)}")
        # NOTE: validity checks (seen task/adapters) are handled at call sites to avoid producing all-zero outputs.
        task_ids = gt_task_ids.to(device=device, dtype=torch.long).unsqueeze(1)
        return _ppcl_hard_mix(task_ids, device)

    raise ValueError(f"Unsupported ppcl_router_type={router_type}")


def infer_ppcl_mix_from_inputs(
    *,
    router,
    router_type: str,
    x1: torch.Tensor,
    x2: Optional[torch.Tensor],
    M: int,
    topL: int,
    gamma: float,
    gt_task_ids: Optional[torch.Tensor] = None,
) -> MixtureSpec:
    """Infer an adapter mixture directly from raw inputs.

    Each input is first reduced to a router representation with
    ``extract_r(..., M=int(M))``; the representations are then routed by
    ``_infer_mix_from_r``. ``x2`` is an optional second view: when it is None
    only ``x1`` is used. All remaining arguments are forwarded unchanged.
    """
    window = int(M)
    rep_first = extract_r(x1, M=window)
    rep_second = None
    if x2 is not None:
        rep_second = extract_r(x2, M=window)
    routing_kwargs = dict(
        router=router,
        router_type=router_type,
        r1=rep_first,
        r2=rep_second,
        topL=topL,
        gamma=gamma,
        gt_task_ids=gt_task_ids,
    )
    return _infer_mix_from_r(**routing_kwargs)
