import itertools, math
from typing import Dict, List, Tuple

# ---------- core builders ----------
def build_config_lut(R: int, C: int) -> Dict:
    """Enumerate every (row subset, column permutation) configuration.

    A configuration picks ``C`` of the ``R`` rows (as the increasing tuples
    produced by ``itertools.combinations``) together with one permutation of
    the column indices ``0..C-1``. Config ids are numbered ``k = s * P + p``,
    where ``s`` indexes the row subset and ``p`` the permutation, so
    ``divmod(k, P)`` recovers ``(s, p)``.

    Returns a dict with keys ``R``, ``C``, ``S`` (number of row subsets),
    ``P`` (number of permutations), ``row_sets``, ``perms`` and ``lut``.
    ``lut`` is a list with one record per config id, in id order.
    """
    subsets: List[Tuple[int, ...]] = list(itertools.combinations(range(R), C))
    orderings: List[Tuple[int, ...]] = list(itertools.permutations(range(C)))
    n_orderings = len(orderings)

    # product() varies the permutation fastest, matching the s * P + p numbering.
    records = [
        {
            "config_id": subset_idx * n_orderings + order_idx,
            "rowset_id": subset_idx,
            "perm_id":   order_idx,
            "rows":      subset,
            "perm":      ordering,
        }
        for (subset_idx, subset), (order_idx, ordering) in itertools.product(
            enumerate(subsets), enumerate(orderings)
        )
    ]
    return {
        "R": R,
        "C": C,
        "S": len(subsets),
        "P": n_orderings,
        "row_sets": subsets,
        "perms": orderings,
        "lut": records,
    }

def decode_config_id(lut: Dict, k: int) -> Tuple[Tuple[int, ...], Tuple[int, ...], int, int]:
    """Map a config id back to its row subset and column permutation.

    This inverts the numbering ``k = s * P + p`` used by ``build_config_lut``.

    Args:
        lut: mapping as produced by ``build_config_lut``; only its ``P``,
            ``row_sets`` and ``perms`` entries are read.
        k: config id; must lie in ``[0, len(row_sets) * P)``.

    Returns:
        ``(rows, perm, rowset_id, perm_id)``.

    Raises:
        IndexError: if ``k`` is outside the valid id range.
    """
    row_sets, perms, P = lut["row_sets"], lut["perms"], lut["P"]
    k = int(k)
    total = len(row_sets) * P
    # Check explicitly: Python's negative indexing would otherwise turn a
    # negative id into a silent lookup of an unrelated config.
    if not 0 <= k < total:
        raise IndexError(f"config id {k} out of range [0, {total})")
    s, p = divmod(k, P)
    return row_sets[s], perms[p], s, p

# ---------- pretty print ----------
def print_config_lut(lut: Dict, max_lines: int | None = None) -> None:
    R, C, S, P = lut["R"], lut["C"], lut["S"], lut["P"]
    total = S * P
    print(f"Config ID mapping for (R={R}, C={C})  S=comb(R,C)={S}, P=C!={P}, total={total}")
    print(" id | rowset_id rows        | perm_id perm")
    print("----+------------------------+----------------")
    for i, rec in enumerate(lut["lut"]):
        if max_lines is not None and i >= max_lines:
            print(f"... ({total - max_lines} more)")
            break
        k = rec["config_id"]; s = rec["rowset_id"]; p = rec["perm_id"]
        rows = rec["rows"]; perm = rec["perm"]
        print(f"{k:3d} | {s:9d} {str(rows):12s} | {p:7d} {str(perm)}")

# ---------- small cache convenience (optional) ----------
class ConfigIdLUTCache:
    def __init__(self, R_fixed: int):
        self.R = int(R_fixed)
        self._cache: Dict[int, Dict] = {}

    def get(self, C: int) -> Dict:
        if C not in self._cache:
            self._cache[C] = build_config_lut(self.R, C)
        return self._cache[C]

    def print_once(self, C_list: List[int], max_lines: int | None = None) -> None:
        for C in C_list:
            print_config_lut(self.get(C), max_lines=max_lines)
            print()  # spacer

# ---------- merging logic ----------
def merge_configs_until_unique_max(
    lut: Dict,
    values: List[int],
    *,
    verbose_every: int = 0  # set >0 to print every k merges
):
    """
    Merge row->col pairs from configs in descending value order.
    Stop at the first time exactly one row has a strictly largest union size.
    Returns a dict with the current state at stop (or end if never unique).
    """
    R, C, S, P = lut["R"], lut["C"], lut["S"], lut["P"]
    total = S * P
    if len(values) != total:
        raise ValueError(f"values length {len(values)} != number of configs {total}")

    # Sort config ids by value desc, tiebreak by smaller id for determinism
    order = sorted(range(total), key=lambda k: (values[k], -k), reverse=True)

    # Union storage: per row, which unique column ids have appeared
    unions = [set() for _ in range(R)]
    selected = []  # sequence of (config_id, value)

    prev_counts = [0]*R

    for t, k in enumerate(order, start=1):
        rows, perm, s_id, p_id = decode_config_id(lut, k)
        # apply this config's row->col matches
        for r_local, c in zip(rows, perm):
            unions[r_local].add(c)
        selected.append((k, values[k]))

        # compute counts and check for unique max
        counts = [len(s) for s in unions]
        max_count = max(counts)
        num_with_max = counts.count(max_count)

        if verbose_every and t % verbose_every == 0:
            print(f"[step {t}] added config {k} (value={values[k]}), "
                  f"counts={counts}")

        # Break when exactly one row strictly leads
        if max_count > 0 and num_with_max == 1:
            leader_row = counts.index(max_count)
            print("\n=== Unique leader detected ===")
            print(f"After adding config {k} (value={values[k]}):")
            print(f"rows={rows}, perm={perm}  (rowset_id={s_id}, perm_id={p_id})")
            print("Union sizes per row:", counts)
            print("Current unions (row -> sorted cols):")
            for r in range(R):
                print(f"  row {r}: {sorted(unions[r])}")
            print("Selected so far (config_id, value) [top->current]:")
            print(selected, "... (total:", len(selected), ")")

            # drop into debugger here so you can inspect state
            breakpoint()

            # If you continue from the debugger, the loop will go on merging.
            # Remove 'breakpoint()' and uncomment the next line to *stop* automatically:
            # break

        prev_counts = counts

    # If we never hit a unique leader, return the final state
    return {
        "selected": selected,
        "unions": {r: sorted(list(s)) for r, s in enumerate(unions)},
        "counts": [len(s) for s in unions],
    }

if __name__ == "__main__":
    # Usage: run this script to build the R=6, C=4 config LUT and merge the
    # configs in order of descending count. Whenever a single row uniquely
    # leads, the merge prints its state and drops into the debugger. Use the
    # printed "Current unions" (row -> sorted cols) to construct the learned
    # matchings.
    lut_for_c4 = ConfigIdLUTCache(R_fixed=6).get(4)

    # NOTE: paste the accumulated per-config counts extracted from the CSV file
    # here, indexed by config id. The merge checks that the length equals the
    # number of configs (comb(6, 4) * 4! = 360).
    counts_by_config = [0, 6, 26, 13, 46, 11, 22, 8, 6, 22, 10, 17, 27, 8, 1, 9, 4, 48, 35, 6, 2, 6, 7, 24, 3, 7, 7, 6, 42, 1, 10, 16, 0, 2, 20, 0, 17, 7, 3, 2, 2, 1, 56, 4, 2, 0, 5, 0, 2, 6, 8, 6, 45, 2, 13, 32, 5, 1, 9, 1, 16, 3, 4, 0, 4, 2, 79, 1, 2, 0, 5, 1, 2, 8, 3, 15, 36, 15, 4, 99, 4, 0, 109, 1, 10, 5, 4, 0, 14, 0, 36, 6, 22, 0, 7, 1, 4, 20, 7, 20, 33, 11, 15, 92, 10, 0, 90, 1, 20, 7, 8, 0, 15, 0, 41, 13, 13, 0, 13, 1, 13, 32, 14, 4, 7, 17, 108, 139, 3, 4, 3, 2, 8, 7, 2, 1, 5, 6, 4, 24, 0, 0, 8, 20, 10, 70, 4, 4, 11, 14, 1, 28, 18, 0, 34, 0, 0, 8, 12, 0, 10, 0, 22, 15, 90, 0, 6, 0, 14, 108, 6, 3, 14, 11, 2, 14, 16, 0, 34, 0, 12, 2, 14, 1, 3, 1, 7, 17, 116, 0, 1, 0, 180, 256, 9, 4, 18, 73, 57, 76, 1, 2, 2, 11, 29, 6, 3, 3, 5, 29, 27, 79, 1, 0, 5, 14, 27, 56, 16, 6, 13, 33, 113, 135, 0, 2, 0, 5, 17, 8, 0, 3, 6, 5, 13, 46, 0, 0, 2, 5, 4, 61, 7, 0, 15, 5, 0, 7, 5, 1, 7, 1, 1, 17, 3, 0, 4, 6, 3, 6, 64, 0, 2, 0, 11, 72, 9, 1, 18, 9, 1, 7, 6, 1, 12, 0, 2, 8, 8, 2, 6, 1, 5, 0, 72, 0, 0, 0, 165, 163, 3, 3, 10, 9, 14, 5, 1, 2, 1, 2, 4, 2, 0, 2, 19, 25, 12, 19, 2, 2, 1, 4, 55, 114, 15, 3, 21, 24, 22, 29, 0, 0, 1, 0, 26, 6, 4, 1, 3, 7, 12, 6, 1, 0, 2, 17, 6, 13, 14, 2, 14, 16, 157, 193, 1, 1, 1, 0, 4, 2, 1, 0, 0, 2, 11, 32, 0, 0, 6, 19]

    result = merge_configs_until_unique_max(lut_for_c4, counts_by_config, verbose_every=0)

    # Reached after continuing past the debugger, or when no unique leader
    # ever appears: dump the final merged state.
    print("\nFinal (or post-continue) unions:")
    for row, cols in result["unions"].items():
        print(f"  row {row}: {cols}")
    print("Counts per row:", result["counts"])

