# utils/mmr.py
import numpy as np
from typing import List, Dict, Iterable, Optional, Tuple, Set

def _build_pools_with_mask(S: np.ndarray,
                           allowed_lists: List[Iterable[int]],
                           quotas: List[int],
                           topL_factor: int) -> List[List[int]]:
    M, N = S.shape
    pools: List[List[int]] = []
    for i in range(M):
        allowed = list(allowed_lists[i]) if allowed_lists and allowed_lists[i] is not None else list(range(N))
        if len(allowed)==0 or quotas[i]<=0:
            pools.append([]); continue
        Li = int(min(max(quotas[i]*topL_factor, quotas[i]), len(allowed)))
        scores = S[i, allowed]
        if Li < len(allowed):
            idx_part = np.argpartition(scores, -Li)[-Li:]
            cand = [allowed[k] for k in idx_part]
        else:
            cand = allowed[:]
        cand.sort(key=lambda j: float(S[i, j]), reverse=True)
        pools.append(cand)
    return pools

def mmr_unique_assignment_masked(
    frame_feats: np.ndarray,              # (N, D)
    query_feats: np.ndarray,              # (M, D)
    quotas: List[int],
    allowed_lists: Optional[List[Iterable[int]]] = None,
    lambda_div: float = 0.7,
    topL_factor: int = 3,
    pre_used: Optional[Set[int]] = None,
    pre_selected_feats: Optional[List[np.ndarray]] = None
):
    """Greedily assign frames to queries with MMR, never reusing a frame.

    Each step scans every query that still has quota left. It picks the
    single (query, frame) pair that maximizes the maximal-marginal-relevance
    score

        lambda_div * S[i, f] - (1 - lambda_div) * max_g <frame_feats[f], g>

    Here ``S = query_feats @ frame_feats.T``, and ``g`` ranges over every
    feature selected so far, including ``pre_selected_feats``; the
    redundancy term is 0.0 while nothing is selected.

    Candidates come from per-query shortlists built once by
    ``_build_pools_with_mask``. When no shortlist has an unused frame left,
    each query offers its most relevant unused frame from its full allowed
    set, and the best MMR score among those offers is taken.

    Relevance and redundancy are raw dot products. Presumably the features
    are L2-normalized so these are cosine similarities (confirm with callers).

    Args:
        frame_feats: Frame features, one row per frame, shape (N, D).
        query_feats: Query features, one row per query, shape (M, D).
        quotas: Number of frames wanted per query (length M). Non-positive
            entries request nothing. The caller's sequence is not modified.
        allowed_lists: Optional per-query iterables of frame indices a query
            may pick from. ``None`` (for the whole argument or one entry)
            means every frame is allowed. Each entry is iterated exactly once,
            so one-shot iterators are accepted.
        lambda_div: Trade-off weight between relevance and diversity.
        topL_factor: Shortlist size multiplier passed to the pool builder.
        pre_used: Frame indices that are already taken and must be skipped.
        pre_selected_feats: Features already selected elsewhere. They count
            toward the redundancy term from the first step.

    Returns:
        Tuple ``(selected, selected_global, used, selected_feats)``:

        - ``selected`` maps each query index to the frames picked for it, in
          pick order.
        - ``selected_global`` is the sorted list of all frames picked in this
          call.
        - ``used`` is ``pre_used`` plus the new picks.
        - ``selected_feats`` is ``pre_selected_feats`` followed by the
          features of the new picks.

        A query may receive fewer frames than its quota when no unused
        allowed frame remains for it.
    """
    M, N = query_feats.shape[0], frame_feats.shape[0]
    # Copy into a fresh list. Slicing a NumPy array would return a view, and
    # the decrements below would then leak into the caller's quotas.
    quotas = list(quotas)
    S = query_feats @ frame_feats.T  # (M, N)

    # Materialize each allowed set once, so one-shot iterables (generators,
    # map objects, ...) are still usable in the fallback path below.
    allowed: List[List[int]] = [
        list(allowed_lists[i]) if allowed_lists and allowed_lists[i] is not None
        else list(range(N))
        for i in range(M)
    ]

    pools = _build_pools_with_mask(S, allowed, quotas, topL_factor)
    # `is not None` rather than truthiness: a 2-D ndarray of features has an
    # ambiguous truth value and would raise.
    used = set(pre_used) if pre_used is not None else set()
    selected: Dict[int, List[int]] = {i: [] for i in range(M)}
    selected_feats = list(pre_selected_feats) if pre_selected_feats is not None else []

    # Per-frame cache of (number of selected features already compared,
    # running max similarity). selected_feats only ever grows, so each
    # frame/feature dot product is computed at most once.
    redundancy_cache: Dict[int, Tuple[int, float]] = {}

    def redundancy(f: int) -> float:
        """Return the max dot product between frame f and any selected feature, or 0.0 if none."""
        if not selected_feats:
            return 0.0
        seen, best_sim = redundancy_cache.get(f, (0, None))
        for k in range(seen, len(selected_feats)):
            sim = float(frame_feats[f] @ selected_feats[k])
            # Same left fold as builtin max(): keep the first value and
            # replace it only on a strict '>'.
            if best_sim is None or sim > best_sim:
                best_sim = sim
        redundancy_cache[f] = (len(selected_feats), best_sim)
        return best_sim

    def mmr_score(i: int, f: int) -> float:
        """Return the MMR score of assigning frame f to query i."""
        return lambda_div * float(S[i, f]) - (1.0 - lambda_div) * redundancy(f)

    # Count only positive quotas; a negative quota must not shrink the
    # budget available to the other queries.
    remaining = sum(q for q in quotas if q > 0)
    while remaining > 0:
        best: Optional[Tuple[float, int, int]] = None  # (score, i, f)
        for i in range(M):
            if quotas[i] <= 0:
                continue
            for f in pools[i]:
                if f in used:
                    continue
                score = mmr_score(i, f)
                if best is None or score > best[0]:
                    best = (score, i, f)

        if best is None:
            # Every shortlist is exhausted. Each query offers its most
            # relevant allowed-but-unused frame, and the offer with the best
            # MMR score wins.
            for i in range(M):
                if quotas[i] <= 0:
                    continue
                cand = [f for f in allowed[i] if f not in used]
                if not cand:
                    continue
                f = max(cand, key=lambda x: float(S[i, x]))
                score = mmr_score(i, f)
                if best is None or score > best[0]:
                    best = (score, i, f)
            if best is None:
                break  # no query can take any further frame

        _, i, f = best
        used.add(f)
        quotas[i] -= 1
        remaining -= 1
        selected[i].append(f)
        selected_feats.append(frame_feats[f])

    selected_global = sorted({f for picks in selected.values() for f in picks})
    return selected, selected_global, used, selected_feats
