from extract_data import get_sbj_data
import numpy as np
from collections import defaultdict
from itertools import product
import json
from pathlib import Path

def balanced_two_splits(
    x: np.ndarray,
    k: int,
    cols=(0, 1),
    categories=([0, 1], [0, 1, 2]),
    random_state: int | None = None,
    require_all_combos: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Return two disjoint index lists (idx_split_1, idx_split_2), each of length k,
    such that:
      • both splits are as balanced as possible over combos of (x[:, cols[0]], x[:, cols[1]])
      • the combo distributions are also balanced *between* the two splits
      • no overlap between the two sets of indices

    If some combos are missing or too small, the function distributes as evenly as possible
    without replacement and within the available capacities.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (N, D).
    k : int
        Target size of each split (total of 2k indices will be drawn).
    cols : tuple[int, int]
        Columns to use for the combo grid (default (0, 1)).
    categories : tuple[list, list]
        Allowed categories per column, defining the target grid.
    random_state : int | None
        Seed for reproducibility.
    require_all_combos : bool
        If True, error if any combo in the grid has zero available rows.

    Returns
    -------
    (np.ndarray, np.ndarray)
        idx_split_1, idx_split_2 — disjoint arrays of length k each.
    """
    rng = np.random.default_rng(random_state)
    N = x.shape[0]
    if k < 0:
        raise ValueError("k must be non-negative.")
    if 2 * k > N:
        raise ValueError(f"Need 2k <= N, got 2k={2*k} > N={N}.")

    c0, c1 = cols
    cats0, cats1 = categories
    a_col, b_col = x[:, c0], x[:, c1]

    # Build pools (indices) per combo over the specified grid
    combo_to_pool = defaultdict(list)
    all_combos = list(product(cats0, cats1))
    for (a, b) in all_combos:
        pool = np.flatnonzero((a_col == a) & (b_col == b))
        combo_to_pool[(a, b)] = pool.tolist()

    if require_all_combos:
        missing = [(a, b) for (a, b) in all_combos if len(combo_to_pool[(a, b)]) == 0]
        if missing:
            raise ValueError(f"Missing combos in data: {missing}")

    # Keep only non-empty combos (can't sample from empty ones)
    existing = [(g, combo_to_pool[g]) for g in all_combos if len(combo_to_pool[g]) > 0]
    if not existing:
        raise ValueError("No non-empty combos available.")
    groups, pools = zip(*existing)
    capacities = np.array([len(p) for p in pools], dtype=int)
    total_available = capacities.sum()
    need = 2 * k
    if need > total_available:
        raise ValueError(f"Requested 2k={need} but only {total_available} rows available across combos.")

    # -------- Stage A: allocate totals per combo (a_g) via water-filling under capacities --------
    lo, hi = 0, capacities.max()
    while lo < hi:
        mid = (lo + hi) // 2
        if np.minimum(capacities, mid).sum() >= need:
            hi = mid
        else:
            lo = mid + 1
    t = lo
    alloc_total = np.minimum(capacities, t).astype(int)
    excess = alloc_total.sum() - need
    if excess > 0:
        # Prefer trimming from non-cap-limited groups (alloc == t)
        candidates = np.where((alloc_total > 0) & (alloc_total == t))[0].tolist()
        rng.shuffle(candidates)
        take = min(excess, len(candidates))
        for i in candidates[:take]:
            alloc_total[i] -= 1
        excess -= take
        if excess > 0:
            # Trim any remaining randomly from groups with alloc > 0
            candidates2 = np.where(alloc_total > 0)[0].tolist()
            rng.shuffle(candidates2)
            for i in candidates2[:excess]:
                alloc_total[i] -= 1

    assert alloc_total.sum() == need

    # -------- Stage B: split each combo's total between split1 and split2 as evenly as possible --------
    a1 = alloc_total // 2
    a2 = alloc_total - a1
    s1 = int(a1.sum())
    if s1 < k:
        # Distribute +1 to split1 for exactly (k - s1) combos with odd totals
        odd_idx = np.where(alloc_total % 2 == 1)[0].tolist()
        rng.shuffle(odd_idx)
        for i in odd_idx[: (k - s1)]:
            a1[i] += 1
            a2[i] -= 1
    elif s1 > k:
        # (Unlikely with floor, but handle anyway)
        # Move 1 from split1 to split2 for some combos with a1 > 0
        candidates = np.where(a1 > 0)[0].tolist()
        rng.shuffle(candidates)
        for i in candidates[: (s1 - k)]:
            a1[i] -= 1
            a2[i] += 1

    assert a1.sum() == k and a2.sum() == k
    assert np.all(a1 >= 0) and np.all(a2 >= 0)

    # -------- Stage C: sample without replacement and assign to splits (no overlap) --------
    idx1, idx2 = [], []
    for pool, n_tot, n1, n2 in zip(pools, alloc_total, a1, a2):
        if n_tot == 0:
            continue
        chosen = rng.choice(pool, size=n_tot, replace=False)
        rng.shuffle(chosen)  # random order before splitting
        if n1 > 0:
            idx1.extend(chosen[:n1].tolist())
        if n2 > 0:
            idx2.extend(chosen[n1:n1+n2].tolist())

    # Final shuffle within each split
    rng.shuffle(idx1)
    rng.shuffle(idx2)

    # Safety checks
    assert len(idx1) == k and len(idx2) == k
    if set(idx1).intersection(idx2):
        raise RuntimeError("Internal error: overlap detected between splits.")
    return np.array(idx1, dtype=int), np.array(idx2, dtype=int)

def verify_no_overlap(old_idx, new_idx, return_overlap: bool = False):
    """
    Report whether two collections of integer indices are disjoint.

    Both inputs are converted to flat integer arrays before being compared,
    so nested lists or multi-dimensional arrays are accepted.

    Parameters
    ----------
    old_idx, new_idx : array-like of ints
        The previously used indices and the freshly drawn ones.
    return_overlap : bool
        When True, additionally return the shared indices.

    Returns
    -------
    bool
        True when no index appears in both inputs, False otherwise.
    (optional) np.ndarray
        Sorted, unique indices present in both inputs (only when return_overlap=True).
    """
    shared = np.intersect1d(
        np.asarray(old_idx, dtype=int).ravel(),
        np.asarray(new_idx, dtype=int).ravel(),
    )
    disjoint = not shared.size
    return (disjoint, shared) if return_overlap else disjoint


def print_combo_counts(
    x: np.ndarray,
    cols=(0, 1),
    categories=([0, 1], [0, 1, 2]),
    return_counts: bool = False,
):
    """
    Print how many rows of ``x`` fall into each (first, second) category pair.

    Pairs come from the Cartesian product of ``categories[0]`` and
    ``categories[1]``. They are matched against ``x[:, cols[0]]`` and
    ``x[:, cols[1]]`` respectively, and pairs with no matching row are shown as 0.
    One line is printed per pair, followed by the total over the listed pairs.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (N, D).
    cols : tuple[int, int]
        The two columns to pair up (default: (0, 1)).
    categories : tuple[list, list]
        Values to enumerate for each column (default: [0,1] × [0,1,2]).
    return_counts : bool
        When True, also return the per-pair counts.

    Returns
    -------
    dict | None
        Mapping (a, b) -> count if ``return_counts`` is True, otherwise None.
    """
    col_a, col_b = cols
    values_a, values_b = categories
    first, second = x[:, col_a], x[:, col_b]

    tally = {
        (a, b): int(np.count_nonzero((first == a) & (second == b)))
        for a, b in product(values_a, values_b)
    }

    report = [f"Combo counts for x[:, {cols}]:"]
    report += [f"  ({a}, {b}): {tally[(a, b)]}" for a in values_a for b in values_b]
    report.append(f"Total across listed combos: {sum(tally.values())}")
    print("\n".join(report))

    return tally if return_counts else None
