# core/samplers/scope.py
"""
SCOPE (Semantic Cloud-Orchestrated Perception at Edge).

This sampler performs:
- Query decomposition into sub-queries with importance scores.
- Budget allocation across sub-queries.
- One-pass global matching over candidate frames.
"""

from __future__ import annotations

import os
import re
import time
from glob import glob
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast

import numpy as np

from .base import QueryBasedSampler
from .topk import CLIPTopKSampler
from config.settings import CLIP_CONFIG, SCOPE_CONFIG
from services.enhanced_query_planner import QueryGraph, build_enhanced_query_graph
from utils.video_utils import extract_frames_decord


class ScopeSampler(QueryBasedSampler):
    """
    Scope sampler.

    Per (video, query) call:

    1. Obtain a query graph: reuse a matching cached ``meta_*.json`` under
       ``meta_root`` when one is found, otherwise run the query planner
       (``build_enhanced_query_graph``).
    2. Split the keyframe budget across sub-queries according to
       ``frame_allocation_mode`` (or reuse a cached allocation whose total
       equals the budget).
    3. One-pass global matching:
       - extract frames at match_fps
       - encode frame features once (handler-level caching)
       - compute a similarity matrix for all sub-queries
       - select top-k frames per sub-query
    4. Deduplicate, cap at the budget, and fill any shortfall with the
       best-scoring frames that were not selected yet.

    Unexpected errors fall back to plain CLIP top-k sampling.
    """

    mode_prefix = "Scope"
    metadata_mode = "scope"

    # Stand-in for NaN similarity scores so that they always rank last.
    _NAN_SCORE = -1e9
    # Extracts the task index from file names such as ``meta_12_<anything>.json``.
    _META_INDEX_RE = re.compile(r"meta_(\d+)_")

    def __init__(
        self,
        clip_model_path: Optional[str] = None,
        device: Optional[str] = None,
        use_api: Optional[bool] = None,
        api_key: Optional[str] = None,
        planner_api_model: Optional[str] = None,
        meta_root: Optional[str] = None,
        reuse_meta_frame_allocation: bool = True,
        match_fps: Optional[float] = None,
        **kwargs: Any,
    ):
        """
        Configure the sampler.

        Args:
            clip_model_path: CLIP model path; defaults to ``CLIP_CONFIG["default_model_path"]``.
            device: CLIP device; defaults to ``CLIP_CONFIG["default_device"]``.
            use_api, api_key, planner_api_model: Forwarded to ``build_enhanced_query_graph``.
            meta_root: Directory searched (recursively) for cached ``meta_*.json`` files;
                ``None`` disables meta reuse.
            reuse_meta_frame_allocation: Also reuse a cached ``frame_allocation`` when it
                is consistent with the requested budget.
            match_fps: Frame-extraction rate used for matching; missing, non-positive or
                NaN values become 1.0.
            **kwargs: Forwarded to the base class. ``frame_allocation_mode`` is also read
                here; unknown modes fall back to ``"importance"``.
        """
        super().__init__(**kwargs)

        from collections import defaultdict

        self.meta_root = Path(meta_root) if meta_root else None
        self.reuse_meta_frame_allocation = bool(reuse_meta_frame_allocation)

        self.clip_model_path = clip_model_path or CLIP_CONFIG["default_model_path"]
        self.device = device or CLIP_CONFIG["default_device"]

        self.use_api = use_api
        self.api_key = api_key
        self.planner_api_model = planner_api_model

        self.match_fps = 1.0 if match_fps is None else float(match_fps)
        # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
        if not (self.match_fps > 0):
            self.match_fps = 1.0

        self.config = SCOPE_CONFIG

        self.allowed_allocation_modes = {"importance", "uniform", "random", "dirichlet", "winner_take_all"}
        # str() so a non-string value is rejected by the membership check instead of crashing.
        self.frame_allocation_mode = str(kwargs.get("frame_allocation_mode") or "importance").lower()
        if self.frame_allocation_mode not in self.allowed_allocation_modes:
            self._log(
                f"Unknown frame allocation mode: {self.frame_allocation_mode}; falling back to importance"
            )
            self.frame_allocation_mode = "importance"

        self.clip_sampler = CLIPTopKSampler(clip_model_path=self.clip_model_path, device=self.device)

        self.timing_log = defaultdict(float)
        self.intermediate_results: Dict[str, Any] = {}

    def _log(self, message: str) -> None:
        print(f"[{self.mode_prefix}] {message}")

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_empty(seq: Any) -> bool:
        """True for ``None`` or a zero-length sequence; safe for NumPy arrays (unlike ``not seq``)."""
        if seq is None:
            return True
        try:
            return len(seq) == 0
        except TypeError:
            return not seq

    @staticmethod
    def _safe_int(value: Any, default: Optional[int] = 0) -> Optional[int]:
        """``int(value)``, or ``default`` when ``value`` is ``None`` or not convertible."""
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _graph_nodes(query_graph: Any) -> List[Any]:
        """The graph's ``nodes`` as a list (empty when missing or falsy)."""
        return list(getattr(query_graph, "nodes", []) or [])

    @staticmethod
    def _even_shares(count: int, budget: int) -> List[int]:
        """Split ``budget`` into ``count`` near-equal integer shares, earlier slots taking the remainder."""
        base, remainder = divmod(budget, count)
        return [base + (1 if i < remainder else 0) for i in range(count)]

    @staticmethod
    def _importance_weight(node: Any) -> float:
        """Node importance as a non-negative finite float (invalid, negative or non-finite -> 0)."""
        try:
            value = float(getattr(node, "importance", 0) or 0)
        except (TypeError, ValueError):
            return 0.0
        if not np.isfinite(value) or value <= 0:
            return 0.0
        return value

    @staticmethod
    def _empty_details() -> Dict[Any, Any]:
        """Details payload returned when no matching was performed."""
        return {"summary": {"total_parallel_time": 0.0, "total_sequential_time": 0.0, "query_count": 0}}

    def _min_frames_per_query(self, floor: int) -> int:
        """Configured ``min_frames_per_query`` (default ``floor``), clamped to at least ``floor``."""
        return max(floor, int(self.config["frame_allocation"].get("min_frames_per_query", floor)))

    # ------------------------------------------------------------------ #
    # Frame allocation
    # ------------------------------------------------------------------ #

    def _allocate_frames(self, query_graph: QueryGraph, num_keyframes: int) -> Dict[str, int]:
        """Allocate per-query frame budgets according to ``self.frame_allocation_mode``.

        Records the elapsed time and the result in ``timing_log`` / ``intermediate_results``.
        """
        t0 = time.time()
        mode = self.frame_allocation_mode

        allocators = {
            "uniform": self._allocate_uniform_frames,
            "random": self._allocate_random_frames,
            "dirichlet": self._allocate_high_variance_random,
            "winner_take_all": self._allocate_winner_take_all,
        }
        allocation = allocators.get(mode, self._allocate_importance_frames_local)(query_graph, num_keyframes)

        self.timing_log["frame_allocation"] = time.time() - t0
        self.intermediate_results["frame_allocation"] = allocation
        self.intermediate_results["frame_allocation_mode"] = mode
        return allocation

    def _allocate_importance_frames_local(self, query_graph: QueryGraph, total_budget: int) -> Dict[str, int]:
        """
        Allocate frames proportionally to node importance.

        Each node first receives ``min_frames_per_query`` (or an even split of the whole
        budget when the minimum is unaffordable). The rest is apportioned by importance
        with the largest-remainder method, so the result always sums to ``total_budget``.
        Importances are read as non-negative floats, so fractional scores (e.g. 0.4)
        are not truncated to zero. When every weight is zero the rest is split evenly.
        """
        nodes = self._graph_nodes(query_graph)
        n = len(nodes)
        if n == 0 or total_budget <= 0:
            return {}

        min_frames = self._min_frames_per_query(0)
        if min_frames * n > total_budget:
            return {node.id: share for node, share in zip(nodes, self._even_shares(n, total_budget))}

        allocation = {node.id: min_frames for node in nodes}
        remaining = total_budget - min_frames * n
        if remaining <= 0:
            return allocation

        weights = {node.id: self._importance_weight(node) for node in nodes}
        weight_sum = sum(weights.values())
        if weight_sum <= 0:
            for node, share in zip(nodes, self._even_shares(n, remaining)):
                allocation[node.id] += share
            return allocation

        quotas = {node_id: remaining * w / weight_sum for node_id, w in weights.items()}
        floors = {node_id: int(q) for node_id, q in quotas.items()}
        leftover = remaining - sum(floors.values())
        # Largest fractional parts get the leftover frames; the stable sort keeps node
        # order among equal fractions.
        by_fraction = sorted(quotas, key=lambda node_id: quotas[node_id] - floors[node_id], reverse=True)
        for node_id in by_fraction[:leftover]:
            floors[node_id] += 1

        return {node.id: allocation[node.id] + floors[node.id] for node in nodes}

    def _allocate_uniform_frames(self, query_graph: QueryGraph, total_budget: int) -> Dict[str, int]:
        """Uniform allocation with an optional minimum guarantee."""
        nodes = self._graph_nodes(query_graph)
        n = len(nodes)
        if n == 0 or total_budget <= 0:
            return {}

        min_frames = self._min_frames_per_query(0)
        if min_frames * n > total_budget:
            return {node.id: share for node, share in zip(nodes, self._even_shares(n, total_budget))}

        allocation = {node.id: min_frames for node in nodes}
        remaining = total_budget - n * min_frames
        if remaining <= 0:
            return allocation

        for node, share in zip(nodes, self._even_shares(n, remaining)):
            allocation[node.id] += share
        return allocation

    def _allocate_random_frames(self, query_graph: QueryGraph, total_budget: int) -> Dict[str, int]:
        """Random allocation: satisfy the minimum first, then distribute the rest by a uniform multinomial."""
        nodes = self._graph_nodes(query_graph)
        n = len(nodes)
        if n == 0 or total_budget <= 0:
            return {}

        min_frames = self._min_frames_per_query(0)
        if min_frames * n > total_budget:
            return {node.id: share for node, share in zip(nodes, self._even_shares(n, total_budget))}

        counts = np.full(n, min_frames, dtype=np.int64)
        remaining = total_budget - n * min_frames
        if remaining > 0:
            counts += np.random.default_rng().multinomial(remaining, np.ones(n) / n)
        return {node.id: int(c) for node, c in zip(nodes, counts)}

    def _allocate_high_variance_random(self, query_graph: QueryGraph, total_budget: int) -> Dict[str, int]:
        """
        High-variance random allocation: Dirichlet-sampled proportions, then a multinomial draw.

        Every node gets at least one frame (``min_frames_per_query`` is clamped to >= 1).
        The Dirichlet concentration comes from the optional config key
        ``frame_allocation.dirichlet_alpha`` (default 0.1; smaller means more skewed).
        """
        nodes = self._graph_nodes(query_graph)
        n = len(nodes)
        if n == 0 or total_budget <= 0:
            return {}

        min_frames = self._min_frames_per_query(1)
        if min_frames * n > total_budget:
            return {node.id: share for node, share in zip(nodes, self._even_shares(n, total_budget))}

        counts = np.full(n, min_frames, dtype=np.int64)
        remaining = total_budget - n * min_frames
        if remaining > 0:
            alpha = float(self.config["frame_allocation"].get("dirichlet_alpha", 0.1))
            if not (alpha > 0):
                alpha = 0.1
            rng = np.random.default_rng()
            probs = rng.dirichlet(np.ones(n) * alpha)
            counts += rng.multinomial(remaining, probs)
        return {node.id: int(c) for node, c in zip(nodes, counts)}

    def _allocate_winner_take_all(self, query_graph: QueryGraph, total_budget: int) -> Dict[str, int]:
        """Winner-take-all: every node gets the minimum (>= 1); one random node takes the rest."""
        nodes = self._graph_nodes(query_graph)
        n = len(nodes)
        if n == 0 or total_budget <= 0:
            return {}

        min_frames = self._min_frames_per_query(1)
        if min_frames * n > total_budget:
            return {node.id: share for node, share in zip(nodes, self._even_shares(n, total_budget))}

        counts = [min_frames] * n
        remaining = total_budget - n * min_frames
        if remaining > 0:
            winner = int(np.random.default_rng().integers(0, n))
            counts[winner] += remaining
        return {node.id: c for node, c in zip(nodes, counts)}

    # ------------------------------------------------------------------ #
    # Matching
    # ------------------------------------------------------------------ #

    def _topk_global_frames(self, sim_row: np.ndarray, frame_indices: List[int], k: int) -> List[int]:
        """
        Return the frame indices of the ``k`` highest-scoring frames.

        Args:
            sim_row: Similarity of one sub-query to each candidate frame (flattened to 1-D).
            frame_indices: Frame index for each position of ``sim_row``.
            k: Number of frames wanted; ``k <= 0`` yields ``[]``.

        Returns:
            Up to ``k`` frame indices ordered by descending score, ties broken by ascending
            frame index. NaN scores rank last. If the two inputs differ in length, only the
            common prefix is used (a message is logged).
        """
        if k <= 0 or self._is_empty(frame_indices):
            return []

        scores = np.nan_to_num(np.asarray(sim_row).reshape(-1), nan=self._NAN_SCORE)
        frame_ids = np.asarray(frame_indices, dtype=np.int64).reshape(-1)
        if scores.shape[0] != frame_ids.shape[0]:
            # Scores without a matching frame index cannot be mapped back to a frame.
            self._log(
                f"Score/frame length mismatch ({scores.shape[0]} vs {frame_ids.shape[0]}); truncating"
            )
            common = min(scores.shape[0], frame_ids.shape[0])
            scores, frame_ids = scores[:common], frame_ids[:common]

        n = int(scores.shape[0])
        if n == 0:
            return []

        candidates = np.arange(n) if k >= n else np.argpartition(scores, -k)[-k:]
        # lexsort sorts by the last key first: descending score, then ascending frame id.
        order = np.lexsort((frame_ids[candidates], -scores[candidates]))
        return frame_ids[candidates[order]].tolist()

    def _process_queries_global(
        self, video_path: str, query_graph: QueryGraph, frame_allocation: Dict[str, int]
    ) -> Tuple[List[int], float, Dict[Any, Any], float]:
        """
        One-pass global matching for all sub-queries.

        Frames are decoded once at ``match_fps`` and their features are built once.
        All sub-query texts are encoded together and scored against every frame with a
        single matrix product; each sub-query then takes its allocated top-k frames.

        Returns:
            ``(selected, parallel_time, details, sequential_time)``. ``selected`` may contain
            duplicates across sub-queries. On feature or matching failure the error is logged
            and ``([], 0.0, details, 0.0)`` is returned. As a side effect, the max score per
            frame over all sub-queries is cached in ``intermediate_results["_fill_candidates"]``
            for ``_fill_remaining_frames``.
        """
        frames, frame_indices = extract_frames_decord(video_path, fps=self.match_fps)

        def budget_of(node: Any) -> int:
            return cast(int, self._safe_int(frame_allocation.get(node.id, 0), 0))

        active_nodes = [n for n in self._graph_nodes(query_graph) if budget_of(n) > 0]
        if not active_nodes or self._is_empty(frames) or self._is_empty(frame_indices):
            return [], 0.0, self._empty_details(), 0.0

        t_shared = time.time()
        try:
            frame_feats = self.clip_sampler.get_or_build_image_feats(video_path, frames, frame_indices=frame_indices)
        except Exception as e:  # noqa: BLE001
            self._log(f"Failed to build frame features: {e}")
            return [], 0.0, self._empty_details(), 0.0
        shared_time = time.time() - t_shared
        self.timing_log["shared_frame_features_time"] = shared_time

        active_nodes.sort(
            key=lambda n: (cast(int, self._safe_int(getattr(n, "layer", 0), 0)), str(getattr(n, "id", "")))
        )
        texts = [n.text for n in active_nodes]

        t_score = time.time()
        try:
            text_feats = self.clip_sampler.encode_texts(texts)  # (M, D)
            # np.asarray is a no-op for NumPy results; atleast_2d guards a single-query 1-D result.
            sims = np.atleast_2d(np.asarray(text_feats @ frame_feats.T))  # (M, N)
            if sims.shape[0] != len(active_nodes):
                raise ValueError(f"expected {len(active_nodes)} similarity rows, got {sims.shape[0]}")
        except Exception as e:  # noqa: BLE001
            self._log(f"Batch matching failed: {e}")
            return [], 0.0, self._empty_details(), 0.0
        score_time = time.time() - t_score
        self.timing_log["matching_score_time"] = score_time

        fill_scores: Optional[np.ndarray]
        try:
            fill_scores = np.nanmax(sims, axis=0).astype(np.float32, copy=False)
            fill_scores = np.nan_to_num(fill_scores, nan=self._NAN_SCORE)
        except Exception:  # noqa: BLE001
            fill_scores = None

        selected: List[int] = []
        t_select = time.time()
        for row, node in zip(sims, active_nodes):
            allocated = budget_of(node)
            chosen = self._topk_global_frames(row, frame_indices, allocated)
            selected.extend(chosen)

            try:
                top_score = float(np.nanmax(row)) if sims.shape[1] > 0 else None
            except Exception:  # noqa: BLE001
                top_score = None

            self.intermediate_results[f"query_{node.id}"] = {
                "text": node.text,
                "importance": getattr(node, "importance", None),
                "search_space": len(frame_indices),
                "final_candidates": len(frame_indices),
                "matching_top_score": top_score,
                "allocated_frames": allocated,
                "selected_count": len(chosen),
            }
        select_time = time.time() - t_select
        self.timing_log["matching_select_time"] = select_time

        # Matching is a single batched pass, so "parallel" and "sequential" totals coincide.
        parallel_total = score_time + select_time
        sequential_total = parallel_total
        details = {
            "summary": {
                "shared_frame_features_time": shared_time,
                "matching_score_time": score_time,
                "matching_select_time": select_time,
                "total_parallel_time": parallel_total,
                "total_sequential_time": sequential_total,
                "query_count": len(active_nodes),
            }
        }

        if fill_scores is not None and fill_scores.ndim == 1 and len(frame_indices) == int(fill_scores.shape[0]):
            self.intermediate_results["_fill_candidates"] = {
                "video_path": video_path,
                "frame_indices": [int(x) for x in frame_indices],
                "scores": fill_scores,
            }

        return selected, parallel_total, details, sequential_total

    # ------------------------------------------------------------------ #
    # Cached meta reuse
    # ------------------------------------------------------------------ #

    @staticmethod
    def _norm_path(p: str) -> str:
        """Case- and separator-normalized path for comparison (best effort)."""
        try:
            return os.path.normcase(os.path.normpath(p))
        except Exception:  # noqa: BLE001
            return p

    @staticmethod
    def _norm_text(s: Any) -> str:
        """Lower-cased, whitespace-collapsed text; non-strings normalize to ``""``."""
        if not isinstance(s, str):
            return ""
        return re.sub(r"\s+", " ", s.strip().lower())

    @staticmethod
    def _query_match_bonus(needle: str, candidate: str) -> int:
        """20 for an exact normalized match, otherwise 10 x token Jaccard similarity (floored)."""
        if not candidate:
            return 0
        if candidate == needle:
            return 20
        toks_in = set(needle.split())
        toks_meta = set(candidate.split())
        union = len(toks_in | toks_meta) or 1
        return int(10 * (len(toks_in & toks_meta) / union))

    @staticmethod
    def _coerce_importance(value: Any) -> Any:
        """Numeric importance: ``int`` when integral, ``float`` otherwise, ``0`` when invalid or NaN."""
        try:
            number = float(value or 0)
        except (TypeError, ValueError):
            return 0
        if number != number:  # NaN
            return 0
        return int(number) if number.is_integer() else number

    @staticmethod
    def _read_meta_json(path: str) -> Optional[Dict[str, Any]]:
        """Parse one meta file; ``None`` if unreadable, invalid JSON, or not a JSON object."""
        import json

        try:
            with open(path, "r", encoding="utf-8") as f:
                obj = json.load(f)
        except (OSError, ValueError):
            return None
        return obj if isinstance(obj, dict) else None

    def _meta_candidate_paths(self) -> List[str]:
        """All ``meta_*.json`` under ``meta_root``, deduplicated and sorted.

        ``**`` with ``recursive=True`` already matches top-level files, so the two patterns
        overlap. Sorting makes tie-breaking between equally scored files deterministic.
        """
        root = cast(Path, self.meta_root)
        found = set(glob(str(root / "**" / "meta_*.json"), recursive=True))
        found.update(glob(str(root / "meta_*.json")))
        return sorted(found)

    @classmethod
    def _parse_query_decomposition(cls, obj: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Normalize ``obj["query_decomposition"]``; ``None`` if missing, malformed, or without usable nodes."""
        qd = obj.get("query_decomposition")
        if not isinstance(qd, dict):
            return None
        raw_nodes = qd.get("nodes")
        if not isinstance(raw_nodes, list):
            return None

        nodes = [
            {
                "id": str(n.get("id")),
                "text": str(n.get("text")),
                "importance": cls._coerce_importance(n.get("importance", 0)),
                "layer": cls._safe_int(n.get("layer", 0), 0),
            }
            for n in raw_nodes
            if isinstance(n, dict) and "id" in n and "text" in n
        ]
        if not nodes:
            # An empty cached decomposition is useless; let the planner run instead.
            return None

        return {
            "nodes": nodes,
            "edges": qd.get("edges", []),
            "total_importance": qd.get("total_importance", 0),
            "decomposition_time": qd.get("decomposition_time", 0.0),
            "model_diag": qd.get("model_diag", {}),
        }

    def _parse_cached_allocation(
        self, obj: Dict[str, Any], node_ids: set, num_keyframes: Optional[int]
    ) -> Optional[Dict[str, int]]:
        """
        Validated cached ``frame_allocation`` or ``None``.

        Accepted only when reuse is enabled, every count is a non-negative integer, every
        key names a node of the cached decomposition, and (if ``num_keyframes`` is given)
        the counts sum to ``num_keyframes``. ``None`` counts are skipped.
        """
        fa = obj.get("frame_allocation")
        if not self.reuse_meta_frame_allocation or not isinstance(fa, dict):
            return None
        try:
            parsed = {str(k): int(v) for k, v in fa.items() if v is not None}
        except (TypeError, ValueError, OverflowError):
            return None
        if any(v < 0 for v in parsed.values()) or not set(parsed) <= node_ids:
            return None
        if num_keyframes is not None:
            wanted = self._safe_int(num_keyframes, None)
            if wanted is None or sum(parsed.values()) != wanted:
                return None
        return parsed

    def _load_cached_meta(
        self,
        video_path: str,
        query: Optional[str] = None,
        num_keyframes: Optional[int] = None,
        task_index: Optional[int] = None,
        task_id: Optional[Any] = None,
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, int]]]:
        """
        Best-effort loader for per-task ``meta_*.json`` under ``meta_root``.

        Only files whose ``video_path`` normalizes to ``video_path`` are considered. Among
        them, the highest score wins (first in sorted path order on ties):
        +50 for a matching ``task_id``, +20 when the file name's ``meta_<N>_`` equals
        ``task_index``, and up to +20 for query similarity.

        Returns:
            ``(query_decomposition, frame_allocation)``. Either may be ``None``; both are
            ``None`` when nothing usable is found. On success the chosen path is stored in
            ``intermediate_results["__loaded_meta_path__"]``. Malformed files are skipped
            rather than raising.
        """
        if self.meta_root is None:
            return None, None

        needle_video = self._norm_path(video_path)
        needle_query = self._norm_text(query)
        wanted_index = self._safe_int(task_index, None)

        best_score: Optional[int] = None
        best_obj: Optional[Dict[str, Any]] = None
        best_path: Optional[str] = None

        for path in self._meta_candidate_paths():
            obj = self._read_meta_json(path)
            if obj is None:
                continue

            vp = obj.get("video_path")
            if not isinstance(vp, str) or self._norm_path(vp) != needle_video:
                continue

            score = 0
            if task_id is not None and str(obj.get("task_id")) == str(task_id):
                score += 50
            if wanted_index is not None:
                m = self._META_INDEX_RE.match(os.path.basename(path))
                if m and int(m.group(1)) == wanted_index:
                    score += 20
            if needle_query:
                score += self._query_match_bonus(
                    needle_query, self._norm_text(obj.get("query_original") or obj.get("question"))
                )

            if best_score is None or score > best_score:
                best_score, best_obj, best_path = score, obj, path

        if best_obj is None:
            return None, None

        qd_meta = self._parse_query_decomposition(best_obj)
        if qd_meta is None:
            return None, None

        node_ids = {n["id"] for n in qd_meta["nodes"]}
        fa_meta = self._parse_cached_allocation(best_obj, node_ids, num_keyframes)

        self.intermediate_results["__loaded_meta_path__"] = best_path
        return qd_meta, fa_meta

    @staticmethod
    def _query_graph_from_meta(meta: Dict[str, Any]) -> QueryGraph:
        """Attribute-style query graph built from a normalized cached decomposition."""
        from types import SimpleNamespace

        nodes = [
            SimpleNamespace(
                id=n["id"],
                text=n["text"],
                importance=n.get("importance", 0),
                layer=n.get("layer", 0),
            )
            for n in meta.get("nodes", [])
        ]
        graph = SimpleNamespace(
            nodes=nodes,
            edges=meta.get("edges", []),
            total_importance=meta.get("total_importance", 0),
            decomposition_time=meta.get("decomposition_time", 0.0),
            model_diag=meta.get("model_diag", {}),
        )
        return cast(QueryGraph, graph)

    def _resolve_query_graph(
        self,
        video_path: str,
        query: str,
        num_keyframes: int,
        task_index: Optional[Any],
        task_id: Optional[Any],
    ) -> Tuple[Any, Optional[Dict[str, int]]]:
        """
        Reuse a cached decomposition when one matches, otherwise call the planner.

        Returns:
            ``(query_graph, cached_allocation)``; ``cached_allocation`` is ``None`` unless a
            valid cached allocation was found alongside a cached decomposition.
        """
        qd_meta, fa_meta = self._load_cached_meta(
            video_path,
            query,
            num_keyframes=num_keyframes,
            task_index=task_index,
            task_id=task_id,
        )

        if qd_meta is not None:
            query_graph = self._query_graph_from_meta(qd_meta)
            self.timing_log["query_decomposition"] = 0.0
            self.intermediate_results["query_graph"] = query_graph
            self.intermediate_results["query_decomposition_reused"] = True
            self.intermediate_results["query_decomposition_source"] = self.intermediate_results.get(
                "__loaded_meta_path__"
            )
            return query_graph, fa_meta

        t_decomp = time.time()
        query_graph = build_enhanced_query_graph(
            query,
            use_api=self.use_api,
            api_key=self.api_key,
            api_model_name=self.planner_api_model,
        )
        self.timing_log["query_decomposition"] = time.time() - t_decomp
        self.intermediate_results["query_graph"] = query_graph
        return query_graph, None

    # ------------------------------------------------------------------ #
    # Public entry point
    # ------------------------------------------------------------------ #

    def _record_timing_breakdown(self, total_time: float) -> None:
        """Store the total time and a compact per-stage breakdown in ``timing_log``."""
        log = self.timing_log
        log["total_time_actual"] = total_time
        log["total_time"] = total_time

        def stage(key: str) -> float:
            # .get on the defaultdict does not insert missing keys.
            return float(log.get(key, 0.0) or 0.0)

        decom_time = stage("query_decomposition")
        alloc_time = stage("frame_allocation")
        shared_feat_time = stage("shared_frame_features_time")
        match_time = stage("layer_processing")
        score_time = stage("matching_score_time")
        select_time = stage("matching_select_time")
        post_time = stage("postprocess_dedup_budget_time")
        fill_time = stage("fill_remaining_frames_time")

        # match_time already equals score_time + select_time, so those are not subtracted again.
        other_time = max(
            0.0,
            total_time - (decom_time + alloc_time + shared_feat_time + match_time + post_time + fill_time),
        )
        log["post_layer_time"] = other_time
        log["timing_breakdown_compact"] = {
            "query_decomposition": decom_time,
            "frame_allocation": alloc_time,
            "frame_feature_extraction": shared_feat_time,
            "matching": match_time,
            "matching_score": score_time,
            "matching_select": select_time,
            "postprocess_dedup_budget": post_time,
            "fill_remaining_frames": fill_time,
            "other": other_time,
            "total_time_actual": total_time,
        }

    def select_keyframes(self, video_path: str, num_keyframes: int, query: str, **kwargs: Any) -> List[int]:
        """
        Select keyframes for a (video, query) pair.

        Args:
            video_path: Video to sample.
            num_keyframes: Frame budget.
            query: Natural-language query.
            **kwargs: ``task_index`` / ``task_id`` help pick the matching cached meta file.

        Returns:
            Sorted frame indices (at most ``num_keyframes``). If decomposition yields no
            nodes, or any step raises, plain CLIP top-k sampling is used instead.
        """
        self._log(f"Start processing query: {query[:100]}...")
        total_start = time.time()

        self.timing_log.clear()
        self.intermediate_results.clear()

        try:
            budget = int(num_keyframes)
            query_graph, cached_allocation = self._resolve_query_graph(
                video_path, query, budget, kwargs.get("task_index"), kwargs.get("task_id")
            )

            if not getattr(query_graph, "nodes", None):
                self._log("Query decomposition returned no nodes; falling back to TopK sampling")
                return self._fallback_topk_sampling(video_path, num_keyframes, query)

            if cached_allocation is not None:
                frame_allocation = dict(cached_allocation)
                self.timing_log["frame_allocation"] = 0.0
                self.intermediate_results["frame_allocation"] = frame_allocation
                self.intermediate_results["frame_allocation_mode"] = self.frame_allocation_mode
                self.intermediate_results["frame_allocation_reused"] = True
            else:
                frame_allocation = self._allocate_frames(query_graph, budget)

            t_proc = time.time()
            selected, parallel_t, details, sequential_t = self._process_queries_global(
                video_path, query_graph, frame_allocation
            )
            actual_proc = time.time() - t_proc

            self.timing_log["layer_processing"] = float(parallel_t or 0.0)
            self.timing_log["layer_processing_sequential"] = float(sequential_t or 0.0)
            self.timing_log["layer_processing_actual"] = float(actual_proc or 0.0)
            self.intermediate_results["layer_processing_details"] = details

            t_post = time.time()
            # Sub-queries may select the same frame, so deduplicate. The budget cap is a
            # safety net: allocations are expected to sum to the budget.
            final_frames = sorted({int(x) for x in selected})
            if len(final_frames) > budget:
                final_frames = final_frames[:budget]
            self.timing_log["postprocess_dedup_budget_time"] = time.time() - t_post

            if len(final_frames) < budget:
                t_fill = time.time()
                final_frames = self._fill_remaining_frames(video_path, query, final_frames, budget)
                self.timing_log["fill_remaining_frames_time"] = time.time() - t_fill

            total_time = time.time() - total_start
            self._record_timing_breakdown(total_time)

            self._log(
                f"Sampling finished: selected {len(final_frames)} frames, total time {total_time:.2f}s"
            )
            return final_frames
        except Exception as e:  # noqa: BLE001
            error_msg = f"Sampling error: {str(e)}"
            self._log(error_msg)
            self.intermediate_results.setdefault("errors", []).append(error_msg)
            self.timing_log["total_time_actual"] = time.time() - total_start
            self.timing_log["total_time"] = self.timing_log["total_time_actual"]
            return self._fallback_topk_sampling(video_path, num_keyframes, query)

    # ------------------------------------------------------------------ #
    # Budget filling and fallback
    # ------------------------------------------------------------------ #

    def _fill_from_match_scores(
        self, video_path: str, selected_frames: List[int], remaining_needed: int
    ) -> Optional[List[int]]:
        """
        Top up ``selected_frames`` from the per-frame max scores cached by matching.

        Returns:
            ``None`` when no usable cached scores exist for ``video_path``. Otherwise, the
            sorted union of ``selected_frames`` and up to ``remaining_needed`` of the
            best-scoring frames that are not yet selected.
        """
        fc = self.intermediate_results.get("_fill_candidates")
        if not isinstance(fc, dict) or fc.get("video_path") != video_path:
            return None
        frame_indices = fc.get("frame_indices")
        scores = fc.get("scores")
        if not (
            isinstance(frame_indices, list)
            and isinstance(scores, np.ndarray)
            and scores.ndim == 1
            and len(frame_indices) == int(scores.shape[0])
        ):
            return None

        already = {int(x) for x in selected_frames}
        frame_ids = np.asarray(frame_indices, dtype=np.int64)
        taken = np.fromiter(already, dtype=np.int64, count=len(already))
        # Rank only unselected frames. Masking selected ones with a sentinel score could
        # tie with NaN-scored candidates and re-pick them.
        candidates = np.flatnonzero(~np.isin(frame_ids, taken))
        if candidates.size == 0:
            # Every frame at match_fps is already selected; the re-extraction fallback
            # would draw from the same frame pool and add nothing.
            return sorted(already)

        take = min(int(remaining_needed), int(candidates.size))
        cand_scores = np.nan_to_num(scores[candidates].astype(np.float32, copy=False), nan=self._NAN_SCORE)
        cand_frames = frame_ids[candidates]
        best = np.argpartition(cand_scores, -take)[-take:]
        order = np.lexsort((cand_frames[best], -cand_scores[best]))
        extra = cand_frames[best[order]].tolist()
        return sorted(already.union(extra))

    def _fill_remaining_frames(self, video_path: str, query: str, selected_frames: List[int], target_count: int) -> List[int]:
        """
        Fill the selection up to ``target_count`` frames when it is under budget.

        Prefers the per-frame scores cached during matching. Otherwise it re-extracts frames
        and asks the CLIP top-k sampler for more. Errors are logged and the current
        selection is returned unchanged.
        """
        remaining_needed = int(target_count) - len(selected_frames)
        if remaining_needed <= 0:
            return selected_frames

        filled = self._fill_from_match_scores(video_path, selected_frames, remaining_needed)
        if filled is not None:
            return filled

        try:
            frames, frame_indices = extract_frames_decord(video_path, fps=self.match_fps)
            if self._is_empty(frames) or self._is_empty(frame_indices):
                return selected_frames
            extra = self.clip_sampler.select_keyframes(
                frames, frame_indices, query, remaining_needed, video_key=video_path
            )
            already = set(selected_frames)
            extra = [f for f in extra if f not in already]
            return sorted((selected_frames + extra)[: int(target_count)])
        except Exception as e:  # noqa: BLE001
            self._log(f"Error filling remaining frames: {e}")
            return selected_frames

    def _fallback_topk_sampling(self, video_path: str, num_keyframes: int, query: str) -> List[int]:
        """Plain CLIP top-k over frames extracted at ``match_fps``; ``[]`` if that fails too."""
        try:
            self._log("Using TopK as a fallback sampler")
            frames, frame_indices = extract_frames_decord(video_path, fps=self.match_fps)
            if self._is_empty(frames) or self._is_empty(frame_indices):
                return []
            return self.clip_sampler.select_keyframes(
                frames, frame_indices, query, int(num_keyframes), video_key=video_path
            )
        except Exception as e:  # noqa: BLE001
            self._log(f"Fallback TopK sampling also failed: {e}")
            return []

    # ------------------------------------------------------------------ #
    # Metadata
    # ------------------------------------------------------------------ #

    def get_sampling_metadata(self, video_path: str, query: str, num_keyframes: int) -> Dict[str, Any]:
        """Return detailed metadata for the last ``select_keyframes`` call (``video_path`` is unused)."""
        query_graph = self.intermediate_results.get("query_graph")
        frame_allocation = self.intermediate_results.get("frame_allocation") or {}

        errors: List[Dict[str, Any]] = []

        if query_graph is None:
            errors.append(
                {
                    "type": "sampling_failure",
                    "detail": "No intermediate results available - sampling may have failed",
                }
            )
            return {
                "mode": self.metadata_mode,
                "config": self.config,
                "errors": errors,
                "timing_breakdown": dict(self.timing_log),
                "query": query,
                "num_keyframes": num_keyframes,
            }

        model_diag = getattr(query_graph, "model_diag", None)
        if isinstance(model_diag, dict):
            if model_diag.get("fallback_used", False):
                errors.append(
                    {
                        "type": "query_decomposition_fallback",
                        "detail": f"Query decomposition failed; using fallback. Attempts: {model_diag.get('attempts', 1)}",
                    }
                )
            elif model_diag.get("retry_errors") and model_diag.get("final_success", True):
                retry_count = model_diag.get("attempts", 1)
                # Malformed diagnostics must not break metadata export.
                if cast(int, self._safe_int(retry_count, 1)) > 1:
                    errors.append(
                        {
                            "type": "query_decomposition_retry",
                            "detail": f"Query decomposition succeeded after {retry_count} attempts",
                        }
                    )

        for err in self.intermediate_results.get("errors", []) or []:
            errors.append({"type": "processing_error", "detail": str(err)})

        nodes = self._graph_nodes(query_graph)
        return {
            "mode": self.metadata_mode,
            "config": self.config,
            "frame_allocation_mode": self.frame_allocation_mode,
            "match_fps": self.match_fps,
            "planner_api_model": self.planner_api_model,
            "query_original": query,
            "num_keyframes_requested": num_keyframes,
            "query_decomposition": {
                "total_nodes": len(nodes),
                "total_importance": getattr(query_graph, "total_importance", None),
                "decomposition_time": getattr(query_graph, "decomposition_time", None),
                "model_diag": getattr(query_graph, "model_diag", None),
                "nodes": [
                    {
                        "id": getattr(node, "id", None),
                        "text": getattr(node, "text", None),
                        "importance": getattr(node, "importance", None),
                        "layer": getattr(node, "layer", None),
                        "allocated_frames": frame_allocation.get(getattr(node, "id", ""), 0),
                    }
                    for node in nodes
                ],
                "edges": getattr(query_graph, "edges", None),
            },
            "frame_allocation": frame_allocation,
            "timing_breakdown": dict(self.timing_log),
            "intermediate_results": {
                "layer_processing_details": self.intermediate_results.get("layer_processing_details", {}),
            },
            "errors": errors,
        }

    def clear_cache(self) -> None:
        """Clear per-call timing and intermediate results (the CLIP sampler's caches are not touched)."""
        self.timing_log.clear()
        self.intermediate_results.clear()

    def get_metadata(self) -> Dict[str, Any]:
        """Return sampler metadata (base metadata plus Scope-specific fields)."""
        metadata = super().get_metadata()
        metadata.update(
            {
                "description": "Scope sampling with semantic importance and one-pass global matching (SCOPE)",
                "requires_content": True,
                "requires_query": True,
                "clip_model_path": self.clip_model_path,
                "device": self.device,
                "config": self.config,
            }
        )
        return metadata

