import numpy as np
import pandas as pd
from typing import List, Tuple, Dict


def _packed_error_bytes(X: pd.DataFrame) -> pd.Series:
    """Return one bytes signature per row of ``X``, indexed like ``X``.

    Every value is cast to ``uint8`` and flipped with ``1 - value``
    (presumably ``X`` holds 0/1 correctness flags, so the result marks
    errors -- confirm against the caller). Each row is then bit-packed along
    the column axis and serialised to ``bytes``, so rows with identical
    patterns produce identical signatures.
    """
    as_uint8 = X.values.astype(np.uint8)
    flipped = 1 - as_uint8
    # packbits pads each row with zero bits up to a whole number of bytes.
    packed_rows = np.packbits(flipped, axis=1)
    signatures = list(map(np.ndarray.tobytes, packed_rows))
    return pd.Series(signatures, index=X.index)


def _assign_difficulty_boxes(diff: pd.Series, box_count: int) -> Tuple[pd.Series, int]:
    """Bin difficulty scores into ordered boxes; return ``(box, box_count_eff)``.

    Equal-frequency binning via ``pd.qcut`` is tried first. If it raises or
    yields no usable bin, percentile-rank binning is used instead and the
    effective box count is ``box_count``. Rows with NaN difficulty keep a NaN
    box on both paths.
    """
    try:
        box = pd.qcut(diff, q=box_count, labels=False, duplicates='drop')
        # int(NaN) raises ValueError, so an all-NaN qcut result also lands
        # in the fallback below.
        box_count_eff = int(box.max() + 1) if len(box) else 0
        if box_count_eff <= 0:
            raise ValueError("qcut produced no usable bins")
    except Exception:
        # Deliberately broad: any qcut failure is handled by the rank-based
        # fallback rather than aborting the whole selection.
        pct_rank = diff.rank(method="average", pct=True)
        scaled = np.floor(pct_rank * box_count).clip(0, box_count - 1)
        # NaN cannot be cast to int; leave such rows unboxed instead of
        # crashing, which matches what the qcut path does.
        box = scaled if scaled.isna().any() else scaled.astype(int)
        box_count_eff = box_count
    return box, box_count_eff


def _round_robin_take(
    ranked: Dict[str, list],
    pools: Tuple[str, ...],
    taken: set,
    limit: int,
) -> list:
    """Pick up to ``limit`` not-yet-taken ids, cycling through ``pools``.

    ``ranked[a]`` is the id list for pool ``a`` in preference order. Each
    round takes the next unseen id from every pool in turn. Picked ids are
    added to ``taken`` in place and returned in pick order. Stops early when
    a full round makes no progress.
    """
    picked: list = []
    cursor = {a: 0 for a in pools}
    while len(picked) < limit:
        progressed = False
        for a in pools:
            lst = ranked[a]
            p = cursor[a]
            # Skip ids already claimed, possibly through another pool.
            while p < len(lst) and lst[p] in taken:
                p += 1
            if p >= len(lst):
                cursor[a] = p
                continue
            qid = lst[p]
            picked.append(qid)
            taken.add(qid)
            cursor[a] = p + 1
            progressed = True
            if len(picked) >= limit:
                break
        if not progressed:
            break
    return picked


def build_subset_by_attributes(
    X: pd.DataFrame,
    attrs: pd.DataFrame,
    k: int = 100,
    box_count: int = 10,
    attr_pools: Tuple[str, ...] = ('surprise','risk','bridge','uniqueness','typicality'),
    rep_keep_power: float = 0.5,
    rep_keep_min: int = 1,
    box_order: str = 'asc',
) -> Tuple[List[str], pd.DataFrame, pd.DataFrame]:
    """Select up to ``k`` rows of ``X``, balanced over difficulty boxes.

    Pipeline:
      1. Align ``attrs`` to ``X.index`` and validate the required columns.
      2. Drop rows whose packed signature (``_packed_error_bytes``) repeats an
         earlier row.
      3. Per cluster, keep the ``ceil(size ** rep_keep_power)`` rows with the
         highest ``typicality`` (at least ``rep_keep_min``).
      4. Bin the remaining rows by ``difficulty`` and split ``k`` across the
         boxes; any remainder goes to the lowest-numbered boxes.
      5. Inside each box, pick rows round-robin over ``attr_pools``, taking
         the highest not-yet-selected value of each attribute in turn.
      6. If fewer than ``k`` rows were selected, repeat the round-robin over
         all remaining rows regardless of box.

    Args:
        X: Row-per-item frame; the index labels are the ids that get
            selected (presumably 0/1 results with one column per evaluated
            system -- confirm against the caller).
        attrs: Per-item attributes. Must contain ``difficulty``,
            ``uniqueness``, ``risk``, ``surprise``, ``typicality``,
            ``bridge``, ``cluster`` and every name in ``attr_pools``.
        k: Target number of selected items.
        box_count: Requested number of difficulty boxes; the effective count
            may be lower when ``pd.qcut`` drops duplicate bin edges.
        attr_pools: Attribute columns cycled through during selection.
        rep_keep_power: Exponent applied to cluster size for the quota.
        rep_keep_min: Lower bound for every cluster quota.
        box_order: ``'desc'`` visits boxes from highest to lowest; any other
            value visits them in ascending order.

    Returns:
        ``(selected, manifest, acc_table)``: the selected ids in pick order,
        their attribute rows plus a ``difficulty_box`` column, and the
        per-column mean of ``X`` over the selection with ``n_items``.

    Raises:
        KeyError: If a required column is missing from ``attrs``.
    """
    # check before launch
    attrs = attrs.reindex(X.index).copy()
    need = {'difficulty','uniqueness','risk','surprise','typicality','bridge','cluster'}
    # Validate attr_pools too, so a bad pool name fails here with a clear
    # message instead of deep inside the ranking step.
    missing = sorted((need | set(attr_pools)) - set(attrs.columns))
    if missing:
        raise KeyError(f"Missing columns: {missing}")

    # remove duplicates (first occurrence of each signature wins)
    sig = _packed_error_bytes(X)
    keep_mask = ~sig.duplicated(keep='first')
    X = X.loc[keep_mask]
    attrs = attrs.loc[keep_mask]

    # per-cluster quota: keep only the most typical rows of each cluster
    cluster = attrs['cluster']
    quota = (cluster.value_counts() ** rep_keep_power).apply(np.ceil).astype(int).clip(lower=rep_keep_min)
    rep_rank = attrs.groupby(cluster)['typicality'].rank(method='first', ascending=False)
    # Rows with a NaN rank compare False here and are dropped.
    within_quota = rep_rank <= cluster.map(quota)
    X = X.loc[within_quota]
    attrs = attrs.loc[within_quota]

    # difficulty bins allocation
    diff = attrs['difficulty'].astype(float)
    box, box_count_eff = _assign_difficulty_boxes(diff, box_count)

    counts = box.value_counts().reindex(range(box_count_eff), fill_value=0)
    base, rem = divmod(k, box_count_eff if box_count_eff > 0 else 1)
    box_quota = pd.Series(base, index=range(box_count_eff))
    if rem:
        box_quota.iloc[:rem] += 1
    # A box cannot contribute more rows than it holds.
    box_quota = np.minimum(box_quota, counts).astype(int)

    selected: List[str] = []
    in_selected: set = set()
    box_seq = list(range(box_count_eff))
    if box_order == 'desc':
        box_seq.reverse()

    for b in box_seq:
        target = int(box_quota.get(b, 0))
        if target <= 0:
            continue
        in_box = box == b
        # ids of this box per attribute, highest value first
        ranked = {
            a: attrs.loc[in_box, a].sort_values(ascending=False).index.tolist()
            for a in attr_pools
        }
        selected.extend(_round_robin_take(ranked, attr_pools, in_selected, target))

    if len(selected) < k:
        print("less than target, now trying global fallback")
        global_lists = {
            a: attrs[a].sort_values(ascending=False).index.tolist()
            for a in attr_pools
        }
        selected.extend(
            _round_robin_take(global_lists, attr_pools, in_selected, k - len(selected))
        )

    # output
    sel = pd.Index(selected, name='question_id')
    manifest = attrs.loc[sel, ['cluster','difficulty','uniqueness','risk','surprise',
                               'typicality','bridge']].copy()
    manifest['difficulty_box'] = box.loc[sel].values

    if len(sel) > 0:
        subset_acc = X.loc[sel].mean(axis=0).astype(float)
        acc_table = (pd.DataFrame({'subset_acc': subset_acc, 'n_items': len(sel)})
                       .sort_values('subset_acc', ascending=False))
    else:
        acc_table = pd.DataFrame({'subset_acc': X.mean(axis=0).astype(float), 'n_items': 0})

    return selected, manifest, acc_table
