import numpy as np
from scipy.special import expit
from bt_utils import get_idx_matrix

# Default convergence tolerance for mle_bt / mle_bt_constrained. mle_bt compares
# it against both the gradient norm and the largest component of the applied step.
NEWTON_TOL = 1e-8
# Default cap on the number of Newton iterations.
NEWTON_MAX_ITERS = 50
# Small regulariser added to the Hessian before solving (on the diagonal, plus
# REG_EPS / n on every entry); also the threshold below which the largest node
# degree is treated as zero.
REG_EPS = 1e-8
# Element-wise bound applied to the Newton step.
DELTA_CLIP = 10.0
# Bound applied to pairwise score differences before they reach the sigmoid.
ETA_CLIP = 30.0
# Maximum number of step sizes tried in each backtracking line search.
LS_MAX_STEPS = 10
# Factor by which the step size is multiplied after a rejected trial.
LS_BACKTRACK = 0.5
# Slack allowed when checking that a trial step does not lower the log-likelihood.
LS_ACCEPT_TOL = 1e-10


def mle_bt(n, N, S, pair_i, pair_j, theta_init=None, tol=NEWTON_TOL, max_iters=NEWTON_MAX_ITERS):
    """Fit pairwise-comparison scores by damped Newton ascent.

    The model is ``P(pair_i[k] beats pair_j[k]) = expit(theta[pair_i[k]] -
    theta[pair_j[k]])`` and the objective is the one evaluated by
    ``log_likelihood``. Only score differences enter the model, so the scores
    are kept centred (mean zero) throughout.

    Args:
        n: Number of items being scored.
        N: Per-pair weights aligned with ``pair_i`` / ``pair_j`` (comparison
            counts in the usual Bradley-Terry reading).
        S: Per-pair totals in favour of ``pair_i`` (wins of ``pair_i[k]`` over
            ``pair_j[k]`` in the usual reading).
        pair_i: Integer array with the first item index of every pair.
        pair_j: Integer array with the second item index of every pair.
        theta_init: Optional array-like of ``n`` starting scores. It is never
            modified. Defaults to all zeros.
        tol: Stop when the gradient norm, or the largest component of the
            step that was applied, drops below this value.
        max_iters: Maximum number of Newton iterations.

    Returns:
        A float array of ``n`` scores with mean zero.
    """
    if theta_init is None:
        theta = np.zeros(n)
    else:
        start = np.asarray(theta_init, dtype=np.float64)
        # The subtraction allocates a new array, so the caller's buffer is untouched.
        theta = start - start.mean()

    # NOTE(review): get_idx_matrix comes from bt_utils and is not visible here.
    # It is used below as an (n, n) array of indices into the per-pair arrays;
    # confirm what it holds for item pairs that were never compared.
    idx_matrix = get_idx_matrix(n, pair_i, pair_j)

    for _ in range(max_iters):
        theta -= theta.mean()

        # Clipping keeps p away from exactly 0 or 1 for extreme score gaps.
        eta = np.clip(theta[pair_i] - theta[pair_j], -ETA_CLIP, ETA_CLIP)
        p = expit(eta)
        residual = S - N * p
        grad = (
            np.bincount(pair_i, residual, minlength=n)
            - np.bincount(pair_j, residual, minlength=n)
        )
        if np.linalg.norm(grad) < tol:
            break

        # Negative Hessian assembled as a weighted graph Laplacian.
        hess_weights = N * p * (1.0 - p)
        laplacian = -hess_weights[idx_matrix]
        np.fill_diagonal(laplacian, 0)
        degrees = -laplacian.sum(axis=1)
        if degrees.max() < REG_EPS:
            # No curvature anywhere: a Newton step is meaningless.
            break
        np.fill_diagonal(laplacian, degrees + REG_EPS)

        # A Laplacian is singular along the all-ones direction; the constant
        # shift added to every entry (plus the diagonal jitter) makes the
        # system solvable.
        system = laplacian + REG_EPS / n
        delta = np.linalg.solve(system, grad - grad.mean())
        delta = delta - delta.mean()
        delta = np.clip(delta, -DELTA_CLIP, DELTA_CLIP)

        # Backtracking line search on the log-likelihood.
        ll_old = log_likelihood(theta, N, S, pair_i, pair_j)
        alpha = 1.0
        accepted = False
        for _ in range(LS_MAX_STEPS):
            theta_new = theta + alpha * delta
            theta_new -= theta_new.mean()
            if log_likelihood(theta_new, N, S, pair_i, pair_j) >= ll_old - LS_ACCEPT_TOL:
                accepted = True
                break
            alpha *= LS_BACKTRACK

        if not accepted:
            # Every trial step lowered the likelihood. Keep the current
            # iterate instead of adopting a rejected one, and stop: the next
            # iteration would only recompute the same direction.
            break

        theta = theta_new

        # alpha is the step size that was actually applied.
        if np.max(np.abs(alpha * delta)) < tol:
            break

    theta -= theta.mean()
    return theta


def log_likelihood(theta, N, S, pair_i, pair_j):
    """Return the pairwise-comparison log-likelihood of the scores ``theta``.

    For every pair ``k`` the contribution is ``S[k] * d - N[k] * log(1 + exp(d))``
    with ``d = theta[pair_i[k]] - theta[pair_j[k]]``; the contributions are
    summed. Terms that do not depend on ``theta`` are not included.
    """
    score_gap = theta[pair_i] - theta[pair_j]
    # logaddexp(0, d) evaluates log(1 + exp(d)) without overflowing for large d.
    log_normaliser = np.logaddexp(0, score_gap)
    per_pair = S * score_gap - N * log_normaliser
    return np.sum(per_pair)


def mle_bt_constrained(n, N, S, pair_i, pair_j, u, v, tol=NEWTON_TOL, max_iters=NEWTON_MAX_ITERS):
    """Fit scores under the equality constraint ``theta[u] == theta[v]``.

    Items ``u`` and ``v`` are merged into one item, the pair data is
    re-aggregated on the reduced index set, ``mle_bt`` is run on the reduced
    problem, and the solution is expanded back to ``n`` entries.

    Args:
        n: Number of items in the original problem.
        N: Per-pair weights (NumPy array) aligned with ``pair_i`` / ``pair_j``.
        S: Per-pair totals in favour of ``pair_i`` (NumPy array).
        pair_i: Integer array with the first item index of every pair.
        pair_j: Integer array with the second item index of every pair.
        u: Index of one of the two items forced to share a score.
        v: Index of the other item forced to share a score.
        tol: Convergence tolerance forwarded to ``mle_bt``.
        max_iters: Iteration cap forwarded to ``mle_bt``.

    Returns:
        A float array of ``n`` centred scores in which entries ``u`` and
        ``v`` are equal. The inputs are not modified.
    """
    if u == v:
        # The constraint is vacuous. Running the relabelling below would
        # instead merge item v + 1 into v, so solve the unconstrained problem.
        return mle_bt(n, N, S, pair_i, pair_j, tol=tol, max_iters=max_iters)

    if u > v:
        u, v = v, u

    # Relabel: v collapses onto u and everything above v shifts down by one.
    old_to_new = np.arange(n, dtype=np.int32)
    old_to_new[v] = u
    old_to_new[v + 1:] -= 1
    n_new = n - 1

    merged_i = old_to_new[pair_i]
    merged_j = old_to_new[pair_j]

    # Comparisons between u and v become self-pairs; their score difference is
    # zero under the constraint, so they are dropped.
    keep = merged_i != merged_j
    merged_i, merged_j = merged_i[keep], merged_j[keep]
    # Boolean indexing returns fresh arrays, so the caller's N and S stay intact.
    counts, wins = N[keep], S[keep]

    # Orient every pair as (low, high); a swapped pair counts the other side's wins.
    swapped = merged_i > merged_j
    wins[swapped] = counts[swapped] - wins[swapped]
    low = np.minimum(merged_i, merged_j)
    high = np.maximum(merged_i, merged_j)

    # Merging can create duplicate pairs; sum their data. The key is built in
    # int64 so that low * n_new cannot wrap around for large n.
    pair_id = low.astype(np.int64) * n_new + high
    unique_ids, inverse = np.unique(pair_id, return_inverse=True)
    N_new = np.bincount(inverse, counts).astype(np.float64)
    S_new = np.bincount(inverse, wins).astype(np.float64)
    pair_i_new = (unique_ids // n_new).astype(np.int32)
    pair_j_new = (unique_ids % n_new).astype(np.int32)

    theta_merged = mle_bt(n_new, N_new, S_new, pair_i_new, pair_j_new, tol=tol, max_iters=max_iters)

    # Expand to the original indexing (u and v read the same merged entry) and
    # re-centre, because that entry is now counted twice in the mean.
    theta = theta_merged[old_to_new]
    return theta - theta.mean()


def glr_statistic(n, N, S, pair_i, pair_j, theta_mle, u, v):
    """Return the log-likelihood gap for the one-sided comparison of ``u`` and ``v``.

    If the fitted scores do not rank ``u`` strictly above ``v`` the result is
    ``0.0`` and no further fitting is done. Otherwise the result is
    ``log_likelihood(theta_mle) - log_likelihood(theta_constrained)``, where
    ``theta_constrained`` is the fit with ``theta[u] == theta[v]``, floored at
    zero. The difference is returned as is (it is not multiplied by 2).
    """
    # One-sided: nothing to report unless u is strictly ahead of v.
    if theta_mle[v] >= theta_mle[u]:
        return 0.0

    unconstrained_ll = log_likelihood(theta_mle, N, S, pair_i, pair_j)
    tied_theta = mle_bt_constrained(n, N, S, pair_i, pair_j, u, v)
    tied_ll = log_likelihood(tied_theta, N, S, pair_i, pair_j)

    # The constrained optimum can come out marginally higher because both fits
    # are iterative; never report a negative statistic.
    gap = unconstrained_ll - tied_ll
    return gap if gap > 0.0 else 0.0
