import numpy as np
from scipy.special import expit

# Memo used by get_idx_matrix: maps a cache key to the int32 pair-index
# lookup matrix built for it. Filled lazily on first use; no code in this
# module removes entries, so it grows with every distinct key requested.
_idx_matrix_cache = {}


def build_theta(n, gap):
    """Return n zero-centred scores spaced ``gap`` apart.

    Entry i equals ((n - 1) / 2 - i) * gap. With gap > 0 the scores
    decrease with the index, so item 0 has the highest score.
    """
    offsets = np.arange(n)
    midpoint = (n - 1) / 2
    return (midpoint - offsets) * gap


def build_uniform_theta(n, k, theta_min, theta_max, rng, min_gap=0.02,
                        max_tries=None):
    """Sample mean-centred scores with a guaranteed gap at the top-k boundary.

    Draws n scores via ``rng.uniform(theta_min, theta_max, n)``, subtracts
    their mean, and accepts the draw only if the k-th largest score exceeds
    the (k+1)-th largest by at least ``min_gap``. Otherwise it redraws.

    Args:
        n: number of items.
        k: boundary position; must satisfy 1 <= k <= n - 1.
        theta_min, theta_max: bounds passed to ``rng.uniform``.
        rng: random generator exposing ``uniform(low, high, size)``
            (e.g. ``np.random.default_rng()``).
        min_gap: required separation between the k-th and (k+1)-th scores.
        max_tries: maximum number of draws before giving up. ``None``
            (default) retries indefinitely, as before.

    Returns:
        1-D array of n mean-centred scores in draw order (not sorted).

    Raises:
        ValueError: if k is out of range, or if a positive ``min_gap`` is at
            least the interval width ``theta_max - theta_min`` (it could
            never be met, so the loop would never terminate).
        RuntimeError: if ``max_tries`` draws all fail the gap test.
    """
    # k == 0 compares sorted[-1] with sorted[0] (never a positive gap) and
    # k >= n indexes past the end, so reject both explicitly.
    if not 1 <= k < n:
        raise ValueError(f"k must satisfy 1 <= k < n, got k={k}, n={n}")
    # Any two draws differ by less than the interval width (up to rounding),
    # so such a gap is unreachable and the retry loop would spin forever.
    if min_gap > 0 and min_gap >= theta_max - theta_min:
        raise ValueError(
            f"min_gap={min_gap} is unreachable for interval "
            f"[{theta_min}, {theta_max})")

    tries = 0
    while max_tries is None or tries < max_tries:
        tries += 1
        theta = rng.uniform(theta_min, theta_max, n)
        theta = theta - theta.mean()
        ranked = np.sort(theta)[::-1]  # descending
        if ranked[k - 1] - ranked[k] >= min_gap:
            return theta
    raise RuntimeError(
        f"no draw met min_gap={min_gap} at k={k} within {max_tries} tries")


def build_sst_matrix_boundary_min(n, k, sst_spread, boundary_min, rng):
    """Sample a random n x n matrix that is SST for the order 0 > 1 > ... > n-1.

    Construction:
      * the diagonal is 0.5;
      * first super-diagonal entries P[i, i+1] are drawn from
        [0.5, min(0.5 + sst_spread, 1.0)];
      * the boundary entry P[k-1, k] uses the raised lower bound
        0.5 + boundary_min;
      * each farther entry P[i, j] (j > i + 1) is drawn from
        [L, min(L + sst_spread, 1.0)], where L = max(P[i, j-1], P[i+1, j]),
        so the upper triangle never decreases moving right or up;
      * the lower triangle is the complement: P[j, i] = 1 - P[i, j].

    If k - 1 is not a valid super-diagonal index (k < 1 or k >= n), no entry
    gets the raised lower bound.

    Args:
        n: number of items.
        k: boundary position (1-based count of the top set).
        sst_spread: maximum increment of an entry over its lower bound.
        boundary_min: minimum margin above 0.5 for P[k-1, k].
        rng: random generator exposing ``uniform(low, high)``.

    Returns:
        float64 array of shape (n, n) with entries in [0, 1].

    Raises:
        ValueError: unless 0 <= boundary_min <= sst_spread and
            boundary_min <= 0.5. Outside that range the boundary draw has
            low > high, so the requested minimum could not be honoured.
    """
    if not 0.0 <= boundary_min <= sst_spread:
        raise ValueError(
            f"need 0 <= boundary_min <= sst_spread, got "
            f"boundary_min={boundary_min}, sst_spread={sst_spread}")
    if boundary_min > 0.5:
        raise ValueError(
            f"boundary_min must be <= 0.5 to keep P[k-1, k] <= 1, "
            f"got {boundary_min}")

    P = np.full((n, n), 0.5)

    # Cap at 1.0 like the outer diagonals below. When sst_spread <= 0.5 this
    # equals 0.5 + sst_spread, so the draws match the uncapped version.
    first_hi = min(0.5 + sst_spread, 1.0)
    for i in range(n - 1):
        # k, k+1 (1-based): enforce the minimum margin at the top-k boundary.
        lo = 0.5 + boundary_min if i == k - 1 else 0.5
        P[i, i + 1] = rng.uniform(lo, first_hi)

    # Fill diagonal by diagonal so both neighbours of P[i, j] already exist.
    # The draw order must stay fixed for seeded runs to stay reproducible.
    for d in range(2, n):
        for i in range(n - d):
            j = i + d
            lower = max(P[i, j - 1], P[i + 1, j])
            upper = min(lower + sst_spread, 1.0)
            P[i, j] = rng.uniform(lower, upper)

    # Lower triangle is the complement of the upper triangle.
    rows, cols = np.triu_indices(n, 1)
    P[cols, rows] = 1.0 - P[rows, cols]

    return P


def build_bt_matrix(theta):
    """Return the logistic (Bradley-Terry) pairwise matrix for scores theta.

    P[i, j] = expit(theta[i] - theta[j]) = 1 / (1 + exp(theta[j] - theta[i])).
    The diagonal is 0.5, and P[i, j] + P[j, i] == 1 up to rounding.

    Args:
        theta: 1-D sequence or array of scores; converted to float64.

    Returns:
        float64 array of shape (len(theta), len(theta)).
    """
    theta = np.asarray(theta, dtype=np.float64)
    # Broadcast column minus row to get every pairwise difference in one
    # array, then apply expit once instead of n*n scalar calls.
    return expit(theta[:, None] - theta[None, :])


def shuffle_matrix(P, perm):
    """Relabel the rows and columns of square matrix P by perm.

    The result satisfies ``out[a, b] == P[perm[a], perm[b]]``; a new array
    is returned and P is left untouched.
    """
    row_permuted = P[perm, :]
    return row_permuted[:, perm]


def build_P(n, k, mode, seed, **params):
    """Build a seeded n x n preference matrix together with its true top-k set.

    Args:
        n: number of items.
        k: size of the top set.
        mode: one of
            "gap"              -- evenly spaced scores (params: gap), permuted;
            "uniform"          -- random scores (params: theta_min, theta_max);
            "sst_boundary_min" -- SST matrix (params: sst_spread,
                                  boundary_min), rows/columns permuted.
        seed: seed for ``np.random.default_rng``.
        **params: mode-specific parameters listed above.

    Returns:
        (P, true_top_k, perm, theta). theta is None for "sst_boundary_min".
        perm is always drawn, but "uniform" mode does not apply it.

    Raises:
        ValueError: for an unknown mode.
        KeyError: if a parameter required by the chosen mode is missing.
    """
    generator = np.random.default_rng(seed)
    # Drawn before the mode dispatch so every mode consumes the RNG stream
    # identically up to this point.
    perm = generator.permutation(n)

    if mode == "sst_boundary_min":
        base = build_sst_matrix_boundary_min(
            n, k, params['sst_spread'], params['boundary_min'], generator)
        shuffled = shuffle_matrix(base, perm)
        # Item idx sits at canonical rank perm[idx]; ranks 0..k-1 form the top set.
        top = set(idx for idx in range(n) if perm[idx] < k)
        return shuffled, top, perm, None

    if mode == "gap":
        theta = build_theta(n, params['gap'])[perm]
    elif mode == "uniform":
        theta = build_uniform_theta(
            n, k, params['theta_min'], params['theta_max'], generator)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    matrix = build_bt_matrix(theta)
    return matrix, get_top_k(theta, k), perm, theta


def get_idx_matrix(n, pair_i, pair_j):
    """Return an n x n int32 matrix mapping each listed pair to its position.

    For every position p, ``M[pair_i[p], pair_j[p]]`` and
    ``M[pair_j[p], pair_i[p]]`` both equal p. Cells not named by any pair,
    including the diagonal, hold 0. That is the same value as position 0,
    so only look up pairs that were actually listed.

    The matrix is memoised in the module-level ``_idx_matrix_cache``. The
    cache key covers ``n`` and the dtype, shape and contents of both pair
    arrays, so calls with the same ``n`` but different pair lists never
    share an entry. The returned array is the cached object itself; do not
    modify it in place.
    """
    pair_i = np.asarray(pair_i)
    pair_j = np.asarray(pair_j)
    # dtype and shape are part of the key because equal raw bytes can encode
    # different arrays (e.g. int32 [1, 0] vs int64 [1]).
    key = (
        n,
        pair_i.dtype.str, pair_i.shape, pair_i.tobytes(),
        pair_j.dtype.str, pair_j.shape, pair_j.tobytes(),
    )
    idx_matrix = _idx_matrix_cache.get(key)
    if idx_matrix is None:
        idx_matrix = np.zeros((n, n), dtype=np.int32)
        positions = np.arange(len(pair_i), dtype=np.int32)
        idx_matrix[pair_i, pair_j] = positions
        idx_matrix[pair_j, pair_i] = positions
        _idx_matrix_cache[key] = idx_matrix
    return idx_matrix


def get_pairs(n):
    """List all index pairs (i, j) with 0 <= i < j < n, ordered by i then j."""
    pairs = []
    for lo in range(n):
        pairs.extend((lo, hi) for hi in range(lo + 1, n))
    return pairs


def build_pair_arrays(n):
    """Return parallel index arrays and a reverse lookup for all pairs of n items.

    Returns:
        pair_i, pair_j: int32 arrays; (pair_i[p], pair_j[p]) is the p-th
            pair produced by get_pairs(n).
        pair_to_idx: dict mapping each (i, j) tuple to its position p.
    """
    pairs = get_pairs(n)
    firsts = np.array([a for a, _ in pairs], dtype=np.int32)
    seconds = np.array([b for _, b in pairs], dtype=np.int32)
    lookup = dict(zip(pairs, range(len(pairs))))
    return firsts, seconds, lookup


def get_top_k(theta, k):
    """Return the indices of the k largest entries of theta as a set of ints.

    Ties follow ``np.argsort``'s ordering. k == 0 yields an empty set, and
    k >= len(theta) yields every index.

    Raises:
        ValueError: if k is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    order = np.argsort(theta)
    # Slice from an explicit start: ``order[-k:]`` with k == 0 is
    # ``order[0:]``, which would return every index instead of none.
    start = max(len(order) - k, 0)
    # tolist() gives plain Python ints, matching the sets built elsewhere
    # from range(n).
    return set(order[start:].tolist())


def get_boundary_pairs(theta, k):
    """List every unordered pair that straddles the top-k boundary of theta.

    Returns a list of (i, j, u, v) tuples ordered by (i, j) with i < j.
    Exactly one of i and j belongs to get_top_k(theta, k); u is that member
    and v is the other index.
    """
    top_k = get_top_k(theta, k)
    n_items = len(theta)
    in_top = [idx in top_k for idx in range(n_items)]
    result = []
    for a in range(n_items):
        for b in range(a + 1, n_items):
            if in_top[a] == in_top[b]:
                continue
            inside, outside = (a, b) if in_top[a] else (b, a)
            result.append((a, b, inside, outside))
    return result
