import numpy as np
from scipy.special import expit
from bt_utils import get_idx_matrix

# Gradient-norm threshold below which the Newton loops stop; default for `tol`.
NEWTON_TOL = 1e-3
# Default cap on the number of outer Newton iterations (`max_iters`).
NEWTON_MAX_ITERS = 50
# Small ridge term: added to the diagonal of the weighted Laplacian and, divided
# by the system size, to every entry before solving. Also the threshold under
# which the largest degree is treated as degenerate and iteration stops.
REG_EPS = 1e-8
# Element-wise bound applied to each Newton step before the line search.
DELTA_CLIP = 10.0
# Bound applied to pairwise score differences before evaluating the sigmoid.
ETA_CLIP = 30.0
# Maximum number of step sizes tried by the backtracking line search.
LS_MAX_STEPS = 10
# Factor by which the step size is multiplied after a rejected trial step.
LS_BACKTRACK = 0.5
# Slack allowed in the line-search acceptance test; kl_projection_bt also uses
# it as the tolerance when checking theta[v] >= theta[u].
LS_ACCEPT_TOL = 1e-10


def kl_divergence_bt_vec(eta, eta_prime):
    """Element-wise Bregman divergence of A(x) = log(1 + exp(x)).

    Evaluates ``A(eta_prime) - A(eta) - A'(eta) * (eta_prime - eta)`` where
    ``A'`` is the logistic sigmoid. Inputs are combined with NumPy
    broadcasting, so scalars and arrays are both accepted.

    Args:
        eta: Reference natural parameter(s); the linearisation point.
        eta_prime: Natural parameter(s) the divergence is measured to.

    Returns:
        The divergence for each broadcast pair ``(eta, eta_prime)``.
    """
    shift = eta_prime - eta
    # log(1 + e^x) is computed as logaddexp(0, x) to stay stable for large |x|.
    return np.logaddexp(0, eta_prime) - np.logaddexp(0, eta) - expit(eta) * shift


def compute_Dw(theta, theta_prime, pair_i, pair_j, w):
    """Weighted sum of pairwise divergences between two score vectors.

    For every pair ``k`` the score difference ``x[pair_i[k]] - x[pair_j[k]]``
    is formed for both ``theta`` and ``theta_prime``; the per-pair divergence
    from ``kl_divergence_bt_vec`` is then weighted by ``w[k]`` and summed.

    Args:
        theta: Reference score vector, indexed by ``pair_i`` / ``pair_j``.
        theta_prime: Score vector being compared against ``theta``.
        pair_i: Index array giving the first item of each pair.
        pair_j: Index array giving the second item of each pair.
        w: Per-pair weights.

    Returns:
        The scalar weighted divergence.
    """
    eta_ref = theta[pair_i] - theta[pair_j]
    eta_alt = theta_prime[pair_i] - theta_prime[pair_j]
    per_pair = kl_divergence_bt_vec(eta_ref, eta_alt)
    return np.sum(w * per_pair)


def kl_projection_bt(theta, pair_i, pair_j, w, u, v, n, tol=NEWTON_TOL, max_iters=NEWTON_MAX_ITERS):
    """Project ``theta`` so that entry ``v`` is not below entry ``u``.

    The unconstrained Newton solve runs first. If its result already satisfies
    ``theta_star[v] >= theta_star[u]`` (within ``LS_ACCEPT_TOL``) it is
    returned; otherwise the solve is repeated with entries ``u`` and ``v``
    tied together.

    Args:
        theta: Score vector to project (indexed by ``pair_i`` / ``pair_j``).
        pair_i: Index array giving the first item of each pair.
        pair_j: Index array giving the second item of each pair.
        w: Per-pair weights.
        u: Index of the entry that must not exceed entry ``v``.
        v: Index of the entry that must not be below entry ``u``.
        n: Number of items; passed to ``get_idx_matrix`` and the solvers.
        tol: Gradient-norm stopping threshold for the Newton loops.
        max_iters: Maximum number of Newton iterations per solve.

    Returns:
        Tuple ``(theta_star, gamma)`` where ``gamma`` is the weighted
        divergence from ``theta`` to ``theta_star``, floored at zero.
    """
    # Quantities that depend only on the fixed point; computed once and shared
    # by both solvers and by the divergence evaluation.
    eta_fixed = theta[pair_i] - theta[pair_j]
    sig_fixed = expit(np.clip(eta_fixed, -ETA_CLIP, ETA_CLIP))
    A_eta_fixed = np.logaddexp(0, eta_fixed)
    idx_matrix = get_idx_matrix(n, pair_i, pair_j)

    theta_star = _kl_projection_unconstrained(
        theta, pair_i, pair_j, w, n, tol, max_iters,
        eta_fixed, sig_fixed, A_eta_fixed, idx_matrix
    )

    constraint_ok = theta_star[v] >= theta_star[u] - LS_ACCEPT_TOL
    if not constraint_ok:
        # The free optimum violates the ordering, so re-solve with u and v tied.
        theta_star = _kl_projection_constrained(
            theta, pair_i, pair_j, w, u, v, n, tol, max_iters,
            eta_fixed, sig_fixed, A_eta_fixed, idx_matrix
        )

    gamma = _compute_Dw_fast(theta_star, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed)
    # Round-off can push the divergence marginally below zero; floor it.
    return theta_star, max(0, gamma)


def _compute_Dw_fast(theta_prime, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed):
    """Weighted divergence to ``theta_prime`` using precomputed fixed-point terms.

    Only the ``theta_prime``-dependent quantities are recomputed; everything
    that depends on the fixed point is supplied by the caller.

    Args:
        theta_prime: Candidate score vector.
        pair_i: Index array giving the first item of each pair.
        pair_j: Index array giving the second item of each pair.
        w: Per-pair weights.
        eta_fixed: Pairwise score differences at the fixed point.
        sig_fixed: Sigmoid values supplied for the fixed point.
        A_eta_fixed: ``log(1 + exp(eta_fixed))`` supplied for the fixed point.

    Returns:
        The scalar weighted sum of per-pair divergences.
    """
    eta_candidate = theta_prime[pair_i] - theta_prime[pair_j]
    per_pair = (
        np.logaddexp(0, eta_candidate)
        - A_eta_fixed
        - sig_fixed * (eta_candidate - eta_fixed)
    )
    return np.sum(w * per_pair)


def _kl_projection_unconstrained(theta, pair_i, pair_j, w, n, tol, max_iters,
                                  eta_fixed, sig_fixed, A_eta_fixed, idx_matrix):
    """Minimise the weighted divergence from the fixed point with damped Newton steps.

    Starts from a mean-centred copy of ``theta`` and repeats: build the
    gradient and a regularised weighted Laplacian, solve for a Newton step,
    clip it, then backtrack on the step size until the divergence does not
    increase (within ``LS_ACCEPT_TOL``). Iterates are kept mean-centred.

    Iteration stops when the gradient norm drops below ``tol``, when the
    largest degree is below ``REG_EPS``, when ``max_iters`` is reached, or
    when the line search cannot find an acceptable step. In the last case the
    current iterate is kept rather than adopting a rejected trial point.

    Args:
        theta: Starting score vector (not modified).
        pair_i: Index array giving the first item of each pair.
        pair_j: Index array giving the second item of each pair.
        w: Per-pair weights.
        n: Number of items (length of the gradient).
        tol: Gradient-norm stopping threshold.
        max_iters: Maximum number of Newton iterations.
        eta_fixed: Pairwise score differences at the fixed point.
        sig_fixed: Sigmoid values at the fixed point.
        A_eta_fixed: ``log(1 + exp(eta_fixed))`` at the fixed point.
        idx_matrix: Index array used to scatter per-pair Hessian weights into
            a square matrix. NOTE(review): presumably an (n, n) array of pair
            indices from ``get_idx_matrix`` - confirm how absent pairs are
            encoded.

    Returns:
        The mean-centred minimiser estimate.
    """
    theta_prime = theta.copy()
    theta_prime -= theta_prime.mean()
    for _ in range(max_iters):
        eta_prime = theta_prime[pair_i] - theta_prime[pair_j]
        sig_prime = expit(np.clip(eta_prime, -ETA_CLIP, ETA_CLIP))

        # Gradient: each pair adds +diff to its first item and -diff to its second.
        diff = w * (sig_prime - sig_fixed)
        grad = np.bincount(pair_i, diff, minlength=n) - np.bincount(pair_j, diff, minlength=n)

        if np.linalg.norm(grad) < tol:
            break

        # Weighted Laplacian: negative weights off the diagonal, degrees
        # (plus a small ridge) on the diagonal.
        hess_weights = w * sig_prime * (1 - sig_prime)
        L = -hess_weights[idx_matrix]
        np.fill_diagonal(L, 0)
        degrees = -L.sum(axis=1)
        np.fill_diagonal(L, degrees + REG_EPS)

        if degrees.max() < REG_EPS:
            # No usable curvature left; a Newton step would be meaningless.
            break

        # A uniform offset on every entry removes the Laplacian's singular
        # constant direction so the system can be solved directly.
        L_reg = L + REG_EPS / n
        b = -grad
        delta = np.linalg.solve(L_reg, b - b.mean())
        delta = delta - delta.mean()
        delta = np.clip(delta, -DELTA_CLIP, DELTA_CLIP)

        alpha = 1.0
        Dw_old = _compute_Dw_fast(theta_prime, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed)
        for _trial in range(LS_MAX_STEPS):
            theta_new = theta_prime + alpha * delta
            theta_new -= theta_new.mean()
            Dw_new = _compute_Dw_fast(theta_new, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed)
            if Dw_new <= Dw_old + LS_ACCEPT_TOL:
                theta_prime = theta_new
                break
            alpha *= LS_BACKTRACK
        else:
            # Every trial step increased the divergence. Keep the current
            # iterate and stop: retrying from the same point would reproduce
            # the same rejected direction.
            break

    return theta_prime


def _kl_projection_constrained(theta, pair_i, pair_j, w, u, v, n, tol, max_iters,
                                eta_fixed, sig_fixed, A_eta_fixed, idx_matrix):
    """Minimise the weighted divergence with entries ``u`` and ``v`` held equal.

    Starts from a copy of ``theta`` in which entries ``u`` and ``v`` are
    replaced by their average, then runs damped Newton iterations on the
    system obtained by merging the two entries into one. Each step is clipped
    and followed by a backtracking line search; iterates are kept mean-centred
    with ``theta'[u] == theta'[v]``.

    Iteration stops when the (tied) gradient norm drops below ``tol``, when
    the largest degree is below ``REG_EPS``, when ``max_iters`` is reached, or
    when the line search cannot find an acceptable step. In the last case the
    current iterate is kept rather than adopting a rejected trial point.

    Args:
        theta: Starting score vector (not modified).
        pair_i: Index array giving the first item of each pair.
        pair_j: Index array giving the second item of each pair.
        w: Per-pair weights.
        u: Index of the first tied entry.
        v: Index of the second tied entry.
        n: Number of items (length of the gradient).
        tol: Gradient-norm stopping threshold.
        max_iters: Maximum number of Newton iterations.
        eta_fixed: Pairwise score differences at the fixed point.
        sig_fixed: Sigmoid values at the fixed point.
        A_eta_fixed: ``log(1 + exp(eta_fixed))`` at the fixed point.
        idx_matrix: Index array used to scatter per-pair Hessian weights into
            a square matrix. NOTE(review): presumably an (n, n) array of pair
            indices from ``get_idx_matrix`` - confirm how absent pairs are
            encoded.

    Returns:
        The mean-centred estimate with entries ``u`` and ``v`` equal.
    """
    theta_prime = theta.copy()
    avg = (theta_prime[u] + theta_prime[v]) / 2
    theta_prime[u] = theta_prime[v] = avg
    theta_prime -= theta_prime.mean()

    # Keep the smaller index and delete the larger one, so removing the
    # row/column does not shift the position of the kept entry.
    uu, vv = (u, v) if u < v else (v, u)

    for _ in range(max_iters):
        eta_prime = theta_prime[pair_i] - theta_prime[pair_j]
        sig_prime = expit(np.clip(eta_prime, -ETA_CLIP, ETA_CLIP))

        diff = w * (sig_prime - sig_fixed)
        grad = np.bincount(pair_i, diff, minlength=n) - np.bincount(pair_j, diff, minlength=n)
        # Tie the gradient entries of u and v by averaging them.
        avg_grad = (grad[u] + grad[v]) / 2
        grad[u] = grad[v] = avg_grad

        if np.linalg.norm(grad) < tol:
            break

        # Weighted Laplacian: negative weights off the diagonal, degrees
        # (plus a small ridge) on the diagonal.
        hess_weights = w * sig_prime * (1 - sig_prime)
        L = -hess_weights[idx_matrix]
        np.fill_diagonal(L, 0)
        degrees = -L.sum(axis=1)
        np.fill_diagonal(L, degrees + REG_EPS)

        if degrees.max() < REG_EPS:
            # No usable curvature left; a Newton step would be meaningless.
            break

        # Merge entry vv into uu: add its row, then its column, then drop it.
        # The row update must come first so the column update also folds the
        # (vv, vv) diagonal term into (uu, uu).
        L[uu, :] += L[vv, :]
        L[:, uu] += L[:, vv]
        L_reduced = np.delete(np.delete(L, vv, axis=0), vv, axis=1)

        n_new = n - 1
        L_reduced += REG_EPS / n_new

        # Right-hand side of the merged system; negation already yields a
        # fresh array, so the in-place update below leaves grad untouched.
        b = -grad
        b[uu] += b[vv]
        b_reduced = np.delete(b, vv)
        b_centered = b_reduced - b_reduced.mean()
        delta_reduced = np.linalg.solve(L_reduced, b_centered)
        delta_reduced = delta_reduced - delta_reduced.mean()

        # Expand back to length n, giving the removed entry the merged step.
        delta = np.insert(delta_reduced, vv, 0.0)
        delta[vv] = delta[uu]
        delta = np.clip(delta, -DELTA_CLIP, DELTA_CLIP)

        alpha = 1.0
        Dw_old = _compute_Dw_fast(theta_prime, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed)
        for _trial in range(LS_MAX_STEPS):
            theta_new = theta_prime + alpha * delta
            avg = (theta_new[u] + theta_new[v]) / 2
            theta_new[u] = theta_new[v] = avg
            theta_new -= theta_new.mean()
            Dw_new = _compute_Dw_fast(theta_new, pair_i, pair_j, w, eta_fixed, sig_fixed, A_eta_fixed)
            if Dw_new <= Dw_old + LS_ACCEPT_TOL:
                theta_prime = theta_new
                break
            alpha *= LS_BACKTRACK
        else:
            # Every trial step increased the divergence. Keep the current
            # iterate and stop: retrying from the same point would reproduce
            # the same rejected direction.
            break

    # Re-impose the tie and the centring on the returned vector.
    avg = (theta_prime[u] + theta_prime[v]) / 2
    theta_prime[u] = theta_prime[v] = avg
    theta_prime -= theta_prime.mean()
    return theta_prime
