from __future__ import annotations

import heapq
import logging
import time
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple

import numpy as np
import torch

try:
    from tqdm import tqdm  # type: ignore
except Exception:  # pragma: no cover
    tqdm = None

from .base import BaseBatchSampler

# Module-scoped logger, registered in the standard ``logging`` hierarchy under
# this module's import name.
logger = logging.getLogger(name=__name__)


@dataclass(frozen=True, eq=False)
class _Candidate:
    """A greedy candidate: a query index plus its cached per-seed exp-scores.

    ``eq=False`` is deliberate. With the default ``eq=True``, a frozen
    dataclass generates ``__eq__``/``__hash__`` over *all* fields, including
    the ``exp_w`` ndarray. Hashing would then raise (ndarray is unhashable),
    and equality would raise (ambiguous array truth value). Identity
    semantics are used instead.
    """

    # Index of the candidate query in the per-epoch query arrays.
    idx: int
    # Cached exp(w_ix / tau_h) for each seed i in the current batch.
    exp_w: np.ndarray  # shape (m,)


class HOBITBatchSampler(BaseBatchSampler):
    r"""HOBIT: submodular, seed-based hardness-optimized batching.

    This sampler implements the guaranteed greedy maximization of

        F_C(S) = sum_{i in C} log( sum_{j in S \ {i}} exp(w_ij / tau_h) )

    for a fixed client/seed set C per batch, using lazy greedy on a restricted
    candidate pool.

    Pairwise score (single-active-positive per epoch):

        w_ij = s(q_i, d_j^+) - alpha * s(d_i^+, d_j^+)

    Notes:
    - The (1-1/e) guarantee holds for the fixed C and the candidate universe used
      by greedy. If we restrict candidates to a pool P, the guarantee is relative
      to that P.
    - Multi-positive queries are handled by sampling one active positive per query
      per epoch to match the single-positive theory
      (``multi_positive_strategy="random"``, the default).
    - Optionally, multi-positive queries can use the L2-normalized mean of all
      positive doc embeddings as the per-query positive embedding used by the
      batching objective (``multi_positive_strategy="mean"``).
    - Queries without any positive (relevance > 0) are not batched. ``sample``
      appends them after all batches.
    """

    def __init__(
        self,
        seed: int = 42,
        batch_size: int = 128,
        topk: int = 32,
        alpha: float = 0.1,
        tau_h: float = 0.05,
        num_batch_seeds: int = 8,
        random_epochs: int = 0,
        seed_selection: str = "weighted_random",
        multi_positive_strategy: str = "random",
        max_positives_per_query: Optional[int] = None,
        # compute knobs
        use_gpu: bool = True,
        topk_batch_size: int = 10000,
        exp_clip: float = 50.0,
        **kwargs,
    ):
        """Configure the sampler.

        Args:
            seed: Passed to ``BaseBatchSampler``. ``self.seed`` also seeds the
                per-epoch RNG that picks positives.
            batch_size: Target number of queries per batch.
            topk: Neighbours per query (query-vs-positive similarity) used for
                the greedy candidate pool and the seed-hardness statistic.
            alpha: Weight of the positive-positive similarity penalty in w_ij.
            tau_h: Temperature in exp(w_ij / tau_h); must be > 0.
            num_batch_seeds: Size of the seed set C that starts each batch.
            random_epochs: Epochs with ``epoch < random_epochs`` use a plain
                shuffle instead of HOBIT.
            seed_selection: "hardest_first", "random" or "weighted_random".
            multi_positive_strategy: "random" (one active positive per query)
                or "mean" (normalized mean of the positives).
            max_positives_per_query: Optional cap (>= 1) on the positives
                considered per query. The subset is drawn from the per-epoch RNG.
            use_gpu: Compute similarities on CUDA when available.
            topk_batch_size: Query rows per top-k similarity chunk (>= 1).
            exp_clip: Clamp applied to w / tau_h before exponentiation.
        """
        super().__init__(seed=seed, **kwargs)
        self.batch_size = int(batch_size)
        self.topk = int(topk)
        self.alpha = float(alpha)
        self.tau_h = float(tau_h)
        self.num_batch_seeds = int(num_batch_seeds)
        self.random_epochs = int(random_epochs)
        self.seed_selection = seed_selection
        self.multi_positive_strategy = str(multi_positive_strategy).strip().lower()
        self.max_positives_per_query = (
            None if max_positives_per_query is None else int(max_positives_per_query)
        )

        assert self.batch_size > 0
        assert self.num_batch_seeds > 0
        assert self.batch_size >= self.num_batch_seeds
        assert self.tau_h > 0.0
        assert self.topk > 0
        # A cap of 0 would leave queries with no usable positive at all.
        assert self.max_positives_per_query is None or self.max_positives_per_query > 0, (
            f"max_positives_per_query must be None or >= 1, got {max_positives_per_query}"
        )

        assert self.seed_selection in [
            "hardest_first",
            "random",
            "weighted_random",
        ], (
            "seed_selection must be one of ['hardest_first','random','weighted_random'], "
            f"got {self.seed_selection}"
        )

        assert self.multi_positive_strategy in [
            "random",
            "mean",
        ], (
            "multi_positive_strategy must be one of ['random','mean'], "
            f"got {self.multi_positive_strategy}"
        )

        self.use_gpu = bool(use_gpu)
        self.topk_batch_size = int(topk_batch_size)
        self.exp_clip = float(exp_clip)
        # range(0, n, 0) would raise deep inside the top-k computation.
        assert self.topk_batch_size > 0
        self._cpu_fallback_logged = False

    def _device(self) -> torch.device:
        """Return CUDA when requested and available, otherwise CPU.

        The CPU-fallback notice is logged only once per sampler. This method
        runs for every batch, so logging each time would flood the log.
        """
        if self.use_gpu and torch.cuda.is_available():
            return torch.device("cuda")
        if self.use_gpu and not getattr(self, "_cpu_fallback_logged", False):
            logger.info(
                "[HOBIT] use_gpu=True but CUDA not available; falling back to CPU"
            )
            self._cpu_fallback_logged = True
        return torch.device("cpu")

    @staticmethod
    def _l2_normalize(x: np.ndarray) -> np.ndarray:
        """Row-wise L2 normalization of a 2-D array (epsilon-guarded)."""
        n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
        return x / n

    @staticmethod
    def _l2_normalize_vec(v: np.ndarray) -> np.ndarray:
        """L2-normalize a single vector and return it as float32."""
        denom = float(np.linalg.norm(v) + 1e-12)
        return (v / denom).astype(np.float32)

    def _compute_topk_similarity(
        self, query_norm: np.ndarray, pos_emb: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Top-k similarity between queries and per-query active positives.

        Queries are processed in chunks of ``topk_batch_size`` rows, so only a
        (chunk x n_positives) similarity matrix exists at any time. ``k`` is
        ``topk + 1``, capped at the number of positives. The extra slot leaves
        room for a query's own positive, which callers skip.

        Returns (D, I):
        - D: (n_queries, k) similarities
        - I: (n_queries, k) indices into pos_emb (i.e., query indices)
        """
        t_start = time.perf_counter()
        device = self._device()

        logger.info(
            "[HOBIT] Top-k similarity start | device=%s | q=%s p=%s topk_batch_size=%d",
            str(device),
            str(query_norm.shape),
            str(pos_emb.shape),
            self.topk_batch_size,
        )

        q_t = torch.from_numpy(query_norm.astype(np.float32)).to(device)
        p_t = torch.from_numpy(pos_emb.astype(np.float32)).to(device)
        n_rows = q_t.shape[0]

        score_chunks: List[torch.Tensor] = []
        index_chunks: List[torch.Tensor] = []
        k = min(self.topk + 1, p_t.shape[0])
        num_chunks = int(np.ceil(n_rows / max(1, self.topk_batch_size)))
        starts = list(range(0, n_rows, self.topk_batch_size))
        chunk_iter = starts
        if tqdm is not None and n_rows >= self.topk_batch_size:
            chunk_iter = tqdm(starts, total=num_chunks, desc="[HOBIT] topk chunks")

        log_every = max(1, num_chunks // 10)
        for chunk_idx, start in enumerate(chunk_iter):
            end = min(start + self.topk_batch_size, n_rows)
            t_chunk = time.perf_counter()
            # Without tqdm, every chunk is logged; with it, only ~10% of them.
            should_log = tqdm is None or ((chunk_idx + 1) % log_every == 0) or (chunk_idx == 0)
            if should_log:
                logger.info(
                    "[HOBIT] Top-k chunk %d/%d | rows=%d:%d",
                    chunk_idx + 1,
                    num_chunks,
                    start,
                    end,
                )
            q_chunk = q_t[start:end]
            sim = torch.matmul(q_chunk, p_t.T)
            top_scores, top_indices = torch.topk(sim, k=k, dim=1)
            score_chunks.append(top_scores.cpu())
            index_chunks.append(top_indices.cpu())
            del sim, q_chunk, top_scores, top_indices
            if device.type == "cuda":
                torch.cuda.empty_cache()
            if should_log:
                logger.info(
                    "[HOBIT] Top-k chunk %d/%d done in %.3fs",
                    chunk_idx + 1,
                    num_chunks,
                    time.perf_counter() - t_chunk,
                )

        top_sims = torch.cat(score_chunks, dim=0).numpy()
        top_idx = torch.cat(index_chunks, dim=0).numpy()
        logger.info(
            "[HOBIT] Top-k similarity done in %.3fs | D=%s I=%s k=%d",
            time.perf_counter() - t_start,
            top_sims.shape,
            top_idx.shape,
            k,
        )
        return top_sims, top_idx

    def _epoch_rng(self, epoch: Optional[int]) -> np.random.RandomState:
        """Return a fresh RNG seeded from ``self.seed`` and ``epoch`` (None -> 0).

        A new RNG is created per call, so every call for the same epoch makes
        the same draws. ``get_required_positive_doc_ids`` and ``sample`` rely
        on this to agree on which positives are needed.
        """
        return np.random.RandomState(self.seed + 1337 * int(epoch or 0))

    @staticmethod
    def _positive_doc_ids(qrels: Dict[str, Dict[str, int]], qid: str) -> List[str]:
        """Return the doc ids (as str) with relevance > 0 for ``qid``, in qrels order."""
        return [str(doc_id) for doc_id, rel in qrels.get(qid, {}).items() if rel > 0]

    def _cap_positives(
        self, pos_doc_ids: List[str], rng: np.random.RandomState
    ) -> List[str]:
        """Subsample to at most ``max_positives_per_query`` ids using ``rng``.

        The RNG is consumed only when the list actually exceeds the cap.
        """
        cap = self.max_positives_per_query
        if cap is None or len(pos_doc_ids) <= int(cap):
            return pos_doc_ids
        idxs = rng.choice(len(pos_doc_ids), size=int(cap), replace=False)
        return [pos_doc_ids[int(i)] for i in idxs.tolist()]

    def _sample_active_positive_ids(
        self,
        query_ids: Sequence[str],
        qrels: Dict[str, Dict[str, int]],
        epoch: Optional[int],
    ) -> Tuple[List[str], Dict[str, str], List[str]]:
        """Return (query_ids_with_pos, qid->active_pos_doc_id, query_ids_no_pos).

        Every random draw (the optional cap and the active pick) comes from
        the epoch-seeded RNG. A process-global RNG would make two calls for the
        same epoch pick different positives.
        """
        rng = self._epoch_rng(epoch)
        query_ids_with_pos: List[str] = []
        qid_to_active_pos: Dict[str, str] = {}
        query_ids_no_pos: List[str] = []

        for qid in query_ids:
            pos_doc_ids = self._positive_doc_ids(qrels, qid)
            if not pos_doc_ids:
                query_ids_no_pos.append(qid)
                continue
            pos_doc_ids = self._cap_positives(pos_doc_ids, rng)
            qid_to_active_pos[qid] = pos_doc_ids[int(rng.randint(len(pos_doc_ids)))]
            query_ids_with_pos.append(qid)

        return query_ids_with_pos, qid_to_active_pos, query_ids_no_pos

    def _collect_positive_doc_ids(
        self,
        query_ids: Sequence[str],
        qrels: Dict[str, Dict[str, int]],
        epoch: Optional[int],
    ) -> Tuple[List[str], Dict[str, List[str]], List[str]]:
        """Return (query_ids_with_pos, qid->positive_doc_ids, query_ids_no_pos).

        The optional ``max_positives_per_query`` cap is deterministic per epoch.
        """
        rng = self._epoch_rng(epoch)
        query_ids_with_pos: List[str] = []
        qid_to_pos_doc_ids: Dict[str, List[str]] = {}
        query_ids_no_pos: List[str] = []

        for qid in query_ids:
            pos_doc_ids = self._positive_doc_ids(qrels, qid)
            if not pos_doc_ids:
                query_ids_no_pos.append(qid)
                continue
            qid_to_pos_doc_ids[qid] = self._cap_positives(pos_doc_ids, rng)
            query_ids_with_pos.append(qid)

        return query_ids_with_pos, qid_to_pos_doc_ids, query_ids_no_pos

    def get_required_positive_doc_ids(
        self,
        query_ids: Sequence[str],
        qrels: Dict[str, Dict[str, int]],
        epoch: Optional[int],
    ) -> Set[str]:
        """Return the set of positive doc IDs HOBIT will require for this epoch.

        Used by the trainer to explicitly encode only missing positive doc IDs
        when forward-pass document embedding caching is enabled. The selection
        is epoch-deterministic, so it matches what ``sample`` uses for the same
        epoch.
        """
        if self.multi_positive_strategy == "random":
            _, qid_to_active_pos, _ = self._sample_active_positive_ids(
                query_ids=query_ids,
                qrels=qrels,
                epoch=epoch,
            )
            return set(qid_to_active_pos.values())

        _, qid_to_pos_doc_ids, _ = self._collect_positive_doc_ids(
            query_ids=query_ids,
            qrels=qrels,
            epoch=epoch,
        )
        return {str(did) for dids in qid_to_pos_doc_ids.values() for did in dids}

    def _compute_w_exp_matrix(
        self,
        seed_qidxs: np.ndarray,
        cand_qidxs: np.ndarray,
        query_norm: np.ndarray,
        pos_emb: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute w and exp(w/tau_h) for seed-by-candidate.

        w_ij = <q_i, p_j> - alpha * <p_i, p_j>. The exponent w / tau_h is
        clamped to [-exp_clip, exp_clip] so exp() stays finite.

        Returns:
            w: (m, P)
            exp_w: (m, P)
            where m = len(seed_qidxs) and P = len(cand_qidxs).
        """
        device = self._device()
        q_seed = torch.from_numpy(query_norm[seed_qidxs].astype(np.float32)).to(device)
        p_seed = torch.from_numpy(pos_emb[seed_qidxs].astype(np.float32)).to(device)
        p_cand = torch.from_numpy(pos_emb[cand_qidxs].astype(np.float32)).to(device)

        sim_q = torch.matmul(q_seed, p_cand.T)
        sim_p = torch.matmul(p_seed, p_cand.T)
        w = sim_q - (self.alpha * sim_p)

        scaled = torch.clamp(w / self.tau_h, min=-self.exp_clip, max=self.exp_clip)
        exp_w = torch.exp(scaled)

        return w.detach().cpu().numpy(), exp_w.detach().cpu().numpy()

    def _select_seeds(
        self,
        available: np.ndarray,
        total_hardness: np.ndarray,
    ) -> np.ndarray:
        """Return ``available`` reordered according to ``seed_selection``.

        - hardest_first: descending ``total_hardness``.
        - random: uniform permutation.
        - weighted_random: sequential sampling without replacement with weights
          softmax(total_hardness).
        """
        if available.size == 0:
            return available

        if self.seed_selection == "hardest_first":
            return available[np.argsort(-total_hardness[available])]

        if self.seed_selection == "random":
            return available[self.rng.permutation(available.size)]

        # weighted_random: shifting by the max keeps exp() from overflowing.
        scores = total_hardness[available].astype(np.float64)
        probs = np.exp(scores - np.max(scores))
        probs_sum = probs.sum()
        if probs_sum <= 0 or not np.isfinite(probs_sum):
            return available[self.rng.permutation(available.size)]

        positive = probs > 0
        if positive.all():
            return self.rng.choice(
                available, size=available.size, replace=False, p=probs / probs_sum
            )

        # Very low scores can underflow to exactly 0, and numpy's choice() then
        # raises "Fewer non-zero entries in p than size". Draw the weighted
        # items first. The zero-weight items come last, in random order, which
        # is where vanishing weights would place them anyway.
        head_items = available[positive]
        head = self.rng.choice(
            head_items,
            size=head_items.size,
            replace=False,
            p=probs[positive] / probs_sum,
        )
        tail_items = available[~positive]
        tail = tail_items[self.rng.permutation(tail_items.size)]
        return np.concatenate([np.asarray(head), tail])

    def _build_batch_lazy_greedy(
        self,
        remaining: Set[int],
        seed_qidxs: List[int],
        query_norm: np.ndarray,
        pos_emb: np.ndarray,
        topk_indices: np.ndarray,
        *,
        excluded: Optional[Set[int]] = None,
    ) -> List[int]:
        """Build one batch S using lazy greedy for F_C on a restricted candidate pool.

        Args:
            remaining: Query indices still available (seeds already removed).
                Only read, never mutated.
            seed_qidxs: The seed set C. Seeds always start the returned batch.
            query_norm: Normalized query embeddings.
            pos_emb: Normalized positive embeddings, row-aligned with query_norm.
            topk_indices: Neighbour indices from ``_compute_topk_similarity``.
                The pool P is the union of the seeds' neighbours still in
                ``remaining``.
            excluded: Indices that must never be added as candidates.

        Returns:
            Seeds followed by candidates in greedy acceptance order, at most
            ``batch_size`` indices. The list is shorter when P is exhausted;
            the caller fills the rest.
        """
        m = len(seed_qidxs)
        seed_arr = np.array(seed_qidxs, dtype=int)
        selected: List[int] = list(seed_qidxs)
        in_batch: Set[int] = set(seed_qidxs)
        excluded_set = excluded or set()

        # Candidate pool P = union of topk neighbours of seeds (excluding assigned).
        cand_set: Set[int] = set()
        for seed in seed_qidxs:
            for nbr in topk_indices[seed].tolist():
                if nbr < 0 or nbr == seed or nbr in excluded_set:
                    continue
                if nbr in remaining and nbr not in in_batch:
                    cand_set.add(int(nbr))

        # DEBUG level: this runs once per batch (thousands of times per epoch).
        logger.debug(
            "[HOBIT] Greedy init | m=%d | remaining=%d | excluded=%d | cand_pool=%d",
            m,
            len(remaining),
            len(excluded_set),
            len(cand_set),
        )

        cand_qidxs = np.array(sorted(cand_set), dtype=int)
        if cand_qidxs.size == 0:
            return selected

        _, exp_w = self._compute_w_exp_matrix(seed_arr, cand_qidxs, query_norm, pos_emb)

        # Z_i = sum over j in S minus {i} of exp(w_ij / tau_h), from the seeds alone.
        Z = np.zeros((m,), dtype=np.float64)
        if m > 1:
            _, exp_w_seeds = self._compute_w_exp_matrix(seed_arr, seed_arr, query_norm, pos_emb)
            Z = (exp_w_seeds.sum(axis=1) - np.diagonal(exp_w_seeds)).astype(np.float64)
        # With a single seed Z starts at 0 (log 0); the epsilon keeps the gains finite.
        Z = np.maximum(Z, 1e-12)

        candidates: List[_Candidate] = [
            _Candidate(idx=int(qidx), exp_w=exp_w[:, col].astype(np.float64))
            for col, qidx in enumerate(cand_qidxs.tolist())
        ]

        def marginal_gain(exp_vec: np.ndarray, z: np.ndarray) -> float:
            # F_C(S + j) - F_C(S) = sum_i log(1 + e_ij / Z_i)
            return float(np.log1p(exp_vec / z).sum())

        # Max-heap via negated gains: entries are (-gain_upper_bound, position).
        # `live` holds candidates still pending. Each has at most one heap entry.
        heap: List[Tuple[float, int]] = [
            (-marginal_gain(cand.exp_w, Z), pos) for pos, cand in enumerate(candidates)
        ]
        heapq.heapify(heap)
        live: Set[int] = set(range(len(candidates)))

        while len(selected) < self.batch_size and heap:
            _, pos = heapq.heappop(heap)
            if pos not in live:
                continue
            cand = candidates[pos]
            if cand.idx not in remaining or cand.idx in in_batch:
                live.discard(pos)
                continue

            fresh_gain = marginal_gain(cand.exp_w, Z)
            while heap and heap[0][1] not in live:
                heapq.heappop(heap)
            # Z only grows, so stored gains are upper bounds (submodularity).
            next_bound = -heap[0][0] if heap else -np.inf

            if fresh_gain >= next_bound:
                selected.append(cand.idx)
                in_batch.add(cand.idx)
                Z = Z + cand.exp_w
                live.discard(pos)
            else:
                heapq.heappush(heap, (-fresh_gain, pos))

        logger.debug(
            "[HOBIT] Greedy done | seeds=%d | selected=%d (target=%d)",
            m,
            len(selected),
            self.batch_size,
        )
        return selected

    def _warmup_order(
        self,
        query_ids: List[str],
        qrels: Dict[str, Dict[str, int]],
        epoch: Optional[int],
    ) -> List[str]:
        """Warm-up ordering: queries with a positive, shuffled, then the others."""
        valid_qids = [
            qid
            for qid in query_ids
            if qid in qrels and any(rel > 0 for rel in qrels[qid].values())
        ]
        shuffled = valid_qids.copy()
        self.rng.shuffle(shuffled)
        valid_set = set(valid_qids)  # built once, not once per query
        no_pos = [qid for qid in query_ids if qid not in valid_set]
        logger.info(
            "[HOBIT] Warmup epoch=%s (< random_epochs=%d): shuffled=%d no_pos=%d",
            str(epoch),
            self.random_epochs,
            len(shuffled),
            len(no_pos),
        )
        return shuffled + no_pos

    def _gather_query_embeddings(
        self,
        query_ids: List[str],
        query_ids_with_pos: List[str],
        query_embeddings: Optional[np.ndarray],
        query_embeddings_by_id: Optional[Dict[str, np.ndarray]],
    ) -> np.ndarray:
        """Return query embeddings row-aligned with ``query_ids_with_pos``.

        Dense ``query_embeddings`` are indexed through
        ``create_id_to_index_mapping(query_ids)``. Otherwise the by-id cache is
        used, and a ValueError is raised for any missing id.
        """
        if query_embeddings is not None:
            query_id_to_idx = self.create_id_to_index_mapping(query_ids)
            qidxs = np.array(
                [query_id_to_idx[qid] for qid in query_ids_with_pos], dtype=int
            )
            return query_embeddings[qidxs]

        assert query_embeddings_by_id is not None
        missing_q = [qid for qid in query_ids_with_pos if qid not in query_embeddings_by_id]
        if missing_q:
            raise ValueError(
                f"Missing {len(missing_q)} query embeddings in query_embeddings_by_id; "
                "trainer should fill these explicitly."
            )
        return np.stack(
            [query_embeddings_by_id[qid] for qid in query_ids_with_pos], axis=0
        )

    def _active_positive_embeddings(
        self,
        query_ids_with_pos: List[str],
        qid_to_active_pos: Dict[str, str],
        doc_embeddings: Optional[np.ndarray],
        doc_ids: Optional[List[str]],
        doc_embeddings_by_id: Optional[Dict[str, np.ndarray]],
    ) -> np.ndarray:
        """Return the normalized active-positive embedding per query ("random" strategy)."""
        if doc_embeddings is not None:
            assert doc_ids is not None
            doc_id_to_idx = self.create_id_to_index_mapping(doc_ids)
            pos_indices = np.array(
                [doc_id_to_idx[qid_to_active_pos[qid]] for qid in query_ids_with_pos],
                dtype=int,
            )
            # Normalization is row-wise, so normalizing only the gathered rows
            # gives the same values as normalizing the whole doc matrix.
            return self._l2_normalize(doc_embeddings[pos_indices])

        assert doc_embeddings_by_id is not None
        missing = [
            qid_to_active_pos[qid]
            for qid in query_ids_with_pos
            if qid_to_active_pos[qid] not in doc_embeddings_by_id
        ]
        if missing:
            raise ValueError(
                f"Missing {len(missing)} positive doc embeddings in doc_embeddings_by_id; "
                "trainer should fill these explicitly."
            )
        return np.stack(
            [
                self._l2_normalize_vec(doc_embeddings_by_id[qid_to_active_pos[qid]])
                for qid in query_ids_with_pos
            ],
            axis=0,
        )

    def _mean_positive_embeddings(
        self,
        query_ids_with_pos: List[str],
        qid_to_pos_doc_ids: Dict[str, List[str]],
        doc_embeddings: Optional[np.ndarray],
        doc_ids: Optional[List[str]],
        doc_embeddings_by_id: Optional[Dict[str, np.ndarray]],
        dim: int,
    ) -> np.ndarray:
        """Return the per-query normalized mean of normalized positive embeddings.

        Positives without an embedding are skipped and counted. A query with
        no embedded positive gets a zero vector of length ``dim``; it still
        participates, with zero similarity.
        """
        doc_id_to_idx = None
        if doc_embeddings is not None:
            assert doc_ids is not None
            # Built once for all queries, not once per query.
            doc_id_to_idx = self.create_id_to_index_mapping(doc_ids)
        else:
            assert doc_embeddings_by_id is not None

        pos_emb_list: List[np.ndarray] = []
        missing_docs = 0
        missing_queries = 0
        t_pos = time.perf_counter()
        for qid in query_ids_with_pos:
            pos_doc_ids = qid_to_pos_doc_ids.get(qid, [])
            if doc_id_to_idx is not None:
                row_idxs = [int(doc_id_to_idx[d]) for d in pos_doc_ids if d in doc_id_to_idx]
                n_found = len(row_idxs)
                rows = self._l2_normalize(doc_embeddings[row_idxs]) if row_idxs else None
            else:
                found = [doc_embeddings_by_id[d] for d in pos_doc_ids if d in doc_embeddings_by_id]
                n_found = len(found)
                rows = (
                    np.stack([self._l2_normalize_vec(v) for v in found], axis=0)
                    if found
                    else None
                )
            missing_docs += len(pos_doc_ids) - n_found

            if rows is None:
                missing_queries += 1
                pos_emb_list.append(np.zeros((dim,), dtype=np.float32))
            else:
                pos_emb_list.append(self._l2_normalize_vec(rows.mean(axis=0)))

            # Periodic progress so long runs aren't silent
            if len(pos_emb_list) % 5000 == 0:
                logger.info(
                    "[HOBIT] Building mean pos_emb | done=%d/%d | elapsed=%.1fs",
                    len(pos_emb_list),
                    len(query_ids_with_pos),
                    time.perf_counter() - t_pos,
                )

        pos_emb = np.stack(pos_emb_list, axis=0)
        if missing_docs > 0 or missing_queries > 0:
            logger.warning(
                "[HOBIT] multi_positive_strategy=mean: missing_doc_ids=%d missing_queries=%d (pos_emb uses zero fallback)",
                missing_docs,
                missing_queries,
            )
        return pos_emb

    def _compute_total_hardness(
        self,
        pos_emb: np.ndarray,
        top_sims: np.ndarray,
        top_idx: np.ndarray,
    ) -> np.ndarray:
        """Approximate per-query hardness: the sum of w_ij over its top-k neighbours.

        The query's own positive (neighbour index == i) is excluded.
        """
        n_queries = pos_emb.shape[0]
        total_hardness = np.zeros((n_queries,), dtype=np.float64)
        alpha = float(self.alpha)
        for i in range(n_queries):
            nbrs = top_idx[i]
            sims_p = pos_emb[nbrs] @ pos_emb[i]
            h = top_sims[i] - alpha * sims_p
            total_hardness[i] = float(h[nbrs != i].sum())
        return total_hardness

    def _assemble_batches(
        self,
        seed_order: np.ndarray,
        query_norm: np.ndarray,
        pos_emb: np.ndarray,
        topk_indices: np.ndarray,
    ) -> Tuple[List[List[int]], int, int]:
        """Partition all query indices into batches.

        Each batch starts from up to ``num_batch_seeds`` unassigned seeds taken
        in ``seed_order`` and is grown by lazy greedy. If greedy runs out of
        candidates, the batch is topped up with uniformly random unassigned
        queries. Every batch except possibly the last has ``batch_size``
        entries.

        Returns:
            (batches, num_random_additions, num_batches_with_random_additions)
        """
        n_queries = int(query_norm.shape[0])
        # Maintained incrementally instead of rebuilding
        # set(range(n_queries)) - assigned for every batch (O(n) per batch).
        # The set serves greedy's membership tests; the mask lets the random
        # fill list the free indices in one vectorized, deterministic call.
        unassigned: Set[int] = set(range(n_queries))
        assigned_mask = np.zeros((n_queries,), dtype=bool)
        batches: List[List[int]] = []

        approx_num_batches = int(np.ceil(n_queries / max(1, self.batch_size)))
        log_every_batches = max(1, approx_num_batches // 10)
        num_random_additions = 0
        num_batches_with_random_additions = 0

        def mark_assigned(qidx: int) -> None:
            unassigned.discard(qidx)
            assigned_mask[qidx] = True

        seed_ptr = 0
        while unassigned:
            seeds: List[int] = []
            while len(seeds) < self.num_batch_seeds and seed_ptr < seed_order.size:
                next_seed = int(seed_order[seed_ptr])
                seed_ptr += 1
                if next_seed in unassigned:
                    seeds.append(next_seed)
            if not seeds:
                break
            for s in seeds:
                mark_assigned(s)

            batch_idx = len(batches)
            if batch_idx % log_every_batches == 0:
                logger.info(
                    "[HOBIT] Batch %d | seeds=%d | assigned=%d/%d",
                    batch_idx + 1,
                    len(seeds),
                    n_queries - len(unassigned),
                    n_queries,
                )

            batch = self._build_batch_lazy_greedy(
                remaining=unassigned,
                seed_qidxs=seeds,
                query_norm=query_norm,
                pos_emb=pos_emb,
                topk_indices=topk_indices,
            )
            for qidx in batch:
                mark_assigned(int(qidx))

            shortfall = self.batch_size - len(batch)
            if shortfall > 0 and unassigned:
                pool = np.flatnonzero(~assigned_mask)
                fill_n = min(shortfall, int(pool.size))
                to_add = self.rng.choice(pool, size=fill_n, replace=False)
                for qidx in np.asarray(to_add).tolist():
                    batch.append(int(qidx))
                    mark_assigned(int(qidx))
                num_random_additions += int(fill_n)
                num_batches_with_random_additions += 1

            batches.append(batch)

        return batches, num_random_additions, num_batches_with_random_additions

    def sample(
        self,
        query_ids: List[str],
        query_embeddings: Optional[np.ndarray] = None,
        doc_embeddings: Optional[np.ndarray] = None,
        doc_ids: Optional[List[str]] = None,
        qrels: Optional[Dict[str, Dict[str, int]]] = None,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """Return query ids ordered so that consecutive chunks form HOBIT batches.

        Args:
            query_ids: Query ids to order.
            query_embeddings: Dense query embeddings, indexed through
                ``create_id_to_index_mapping(query_ids)``. The alternative is
                the ``query_embeddings_by_id`` kwarg.
            doc_embeddings: Dense doc embeddings, indexed through
                ``create_id_to_index_mapping(doc_ids)``. The alternative is the
                ``doc_embeddings_by_id`` kwarg.
            doc_ids: Ids for ``doc_embeddings``. Required with dense docs.
            qrels: ``qid -> {doc_id: relevance}``; relevance > 0 is positive.
            epoch: Selects warm-up shuffling (``epoch < random_epochs``) and
                seeds the positive-selection RNG.
            **kwargs: Optional ``query_embeddings_by_id``,
                ``doc_embeddings_by_id``, and ``writer`` (anything with
                ``add_scalar(tag, value, step)``).

        Returns:
            Query ids that have positives, concatenated batch by batch (every
            batch except possibly the last has ``batch_size`` entries),
            followed by the query ids without positives.

        Raises:
            ValueError: If required inputs are missing, or if cached embeddings
                lack entries this epoch needs.
        """
        query_embeddings_by_id: Optional[Dict[str, np.ndarray]] = kwargs.get(
            "query_embeddings_by_id", None
        )
        doc_embeddings_by_id: Optional[Dict[str, np.ndarray]] = kwargs.get(
            "doc_embeddings_by_id", None
        )

        if query_embeddings_by_id is not None:
            logger.info(
                "[HOBIT] Using cached query embeddings by id (n=%d)",
                len(query_embeddings_by_id),
            )
        if doc_embeddings_by_id is not None:
            logger.info(
                "[HOBIT] Using cached doc embeddings by id (n=%d)",
                len(doc_embeddings_by_id),
            )

        logger.info(
            "[HOBIT] sample() start | epoch=%s | n_query_ids=%d | multi_positive_strategy=%s",
            str(epoch),
            len(query_ids),
            self.multi_positive_strategy,
        )

        if query_embeddings is None and query_embeddings_by_id is None:
            raise ValueError(
                "Either query_embeddings or query_embeddings_by_id must be provided"
            )
        if doc_embeddings is None and doc_embeddings_by_id is None:
            raise ValueError(
                "Either doc_embeddings or doc_embeddings_by_id must be provided"
            )
        if doc_embeddings is not None and doc_ids is None:
            raise ValueError("doc_ids must be provided when using dense doc_embeddings")
        if qrels is None:
            raise ValueError("qrels must be provided")

        writer = kwargs.get("writer", None)

        # Warm-up: random shuffle
        if epoch is not None and epoch < self.random_epochs:
            return self._warmup_order(query_ids, qrels, epoch)

        # Active positives (multi-positive handling)
        qid_to_pos_doc_ids: Optional[Dict[str, List[str]]] = None
        qid_to_active_pos: Dict[str, str] = {}
        if self.multi_positive_strategy == "random":
            query_ids_with_pos, qid_to_active_pos, query_ids_no_pos = (
                self._sample_active_positive_ids(
                    query_ids=query_ids,
                    qrels=qrels,
                    epoch=epoch,
                )
            )
        else:
            query_ids_with_pos, qid_to_pos_doc_ids, query_ids_no_pos = (
                self._collect_positive_doc_ids(
                    query_ids=query_ids,
                    qrels=qrels,
                    epoch=epoch,
                )
            )

        logger.info(
            "[HOBIT] Positives selected | with_pos=%d no_pos=%d",
            len(query_ids_with_pos),
            len(query_ids_no_pos),
        )

        if not query_ids_with_pos:
            return query_ids_no_pos

        query_emb = self._gather_query_embeddings(
            query_ids, query_ids_with_pos, query_embeddings, query_embeddings_by_id
        )
        query_norm = self._l2_normalize(query_emb)

        if self.multi_positive_strategy == "random":
            pos_emb = self._active_positive_embeddings(
                query_ids_with_pos,
                qid_to_active_pos,
                doc_embeddings,
                doc_ids,
                doc_embeddings_by_id,
            )
        else:
            assert qid_to_pos_doc_ids is not None
            pos_emb = self._mean_positive_embeddings(
                query_ids_with_pos,
                qid_to_pos_doc_ids,
                doc_embeddings,
                doc_ids,
                doc_embeddings_by_id,
                dim=query_emb.shape[1],
            )

        n_queries = query_norm.shape[0]
        logger.info(
            "[HOBIT] Epoch=%s | n_queries=%d batch_size=%d num_batch_seeds=%d topk=%d alpha=%.4f tau_h=%.4f",
            str(epoch),
            n_queries,
            self.batch_size,
            self.num_batch_seeds,
            self.topk,
            self.alpha,
            self.tau_h,
        )
        logger.info(
            "[HOBIT] Active positives | with_pos=%d no_pos=%d (multi_positive_strategy=%s max_positives_per_query=%s)",
            len(query_ids_with_pos),
            len(query_ids_no_pos),
            self.multi_positive_strategy,
            str(self.max_positives_per_query),
        )

        # Top-k neighbours drive both the candidate pools and the seed hardness.
        top_sims, top_idx = self._compute_topk_similarity(query_norm, pos_emb)
        total_hardness = self._compute_total_hardness(pos_emb, top_sims, top_idx)

        hardness_mean = float(np.mean(total_hardness))
        hardness_std = float(np.std(total_hardness))
        hardness_max = float(np.max(total_hardness))
        hardness_min = float(np.min(total_hardness))
        logger.info(
            "[HOBIT] Total hardness stats | mean=%.4f std=%.4f max=%.4f min=%.4f",
            hardness_mean,
            hardness_std,
            hardness_max,
            hardness_min,
        )
        if writer is not None and epoch is not None:
            writer.add_scalar("batching/hardness_mean", hardness_mean, epoch)
            writer.add_scalar("batching/hardness_std", hardness_std, epoch)
            writer.add_scalar("batching/hardness_max", hardness_max, epoch)
            writer.add_scalar("batching/hardness_min", hardness_min, epoch)

        # Seed order over query indices [0..n_queries-1]
        seed_order = self._select_seeds(np.arange(n_queries, dtype=int), total_hardness)

        t_start = time.perf_counter()
        batches, num_random_additions, num_batches_with_random_additions = (
            self._assemble_batches(seed_order, query_norm, pos_emb, top_idx)
        )
        n_assigned = sum(len(batch) for batch in batches)

        logger.info(
            "[HOBIT] Constructed %d batches in %.3fs (coverage %.2f%%)",
            len(batches),
            time.perf_counter() - t_start,
            100.0 * n_assigned / n_queries,
        )
        logger.info(
            "[HOBIT] Random fill additions: %d (in %d/%d batches)",
            num_random_additions,
            num_batches_with_random_additions,
            len(batches),
        )

        if writer is not None and epoch is not None:
            writer.add_scalar("batching/random_fill_additions", float(num_random_additions), epoch)
            writer.add_scalar(
                "batching/batches_with_random_fill",
                float(num_batches_with_random_additions),
                epoch,
            )

        # Flatten back to query_ids; queries without positives go last.
        flat_qids = [query_ids_with_pos[qidx] for batch in batches for qidx in batch]
        flat_qids.extend(query_ids_no_pos)
        return flat_qids
