import numpy as np
try:
    from numba import njit
    NUMBA = True
except ImportError:
    NUMBA = False

    def njit(*args, **kwargs):
        """No-op replacement for ``numba.njit`` used when numba is missing.

        Accepts both the bare ``@njit`` form and ``@njit(...)`` with options
        (the options are ignored).  Decorated functions are returned
        unchanged but gain a ``py_func`` attribute pointing at themselves,
        mirroring the attribute numba dispatchers expose, so
        ``f.py_func(...)`` works whether or not numba is installed.
        """
        def wrap(f):
            f.py_func = f
            return f

        # Bare ``@njit``: the decorated function arrives as the only argument.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return wrap(args[0])
        return wrap


@njit(fastmath=True, cache=True)
def _sigmoid_nb(x):
    """Compiled logistic function ``1 / (1 + exp(-x))``, applied elementwise.

    The argument is first clamped to [-500, 500] so that ``exp`` cannot
    overflow for very negative inputs.
    """
    bounded = np.maximum(np.minimum(x, 500.0), -500.0)
    return 1.0 / (1.0 + np.exp(-bounded))

def sigmoid(x):
    """Logistic function, using the compiled kernel when numba is available.

    Without numba the input is clipped to [-500, 500] before ``exp`` to avoid
    overflow, matching the compiled kernel.
    """
    if NUMBA:
        return _sigmoid_nb(x)
    clipped = np.clip(x, -500, 500)
    return 1.0 / (1.0 + np.exp(-clipped))

@njit(fastmath=True, cache=True)
def _dsigmoid_nb(x):
    """Compiled derivative of the logistic function: ``p * (1 - p)`` with ``p = sigmoid(x)``."""
    p = _sigmoid_nb(x)
    complement = 1.0 - p
    return p * complement

def dsigmoid(x):
    """Derivative of the logistic function, ``sigmoid(x) * (1 - sigmoid(x))``.

    Dispatches to the compiled kernel when numba is available.  Otherwise it
    evaluates the sigmoid once and reuses it.
    """
    if NUMBA:
        return _dsigmoid_nb(x)
    s = sigmoid(x)
    return s * (1.0 - s)


@njit(fastmath=True, cache=True)
def _weighted_norm_nb(x, A):
    """Compiled kernel for the weighted norm ``sqrt(x^T A x)``.

    Evaluated as ``(x @ A) @ x``.  A negative quadratic form (``A`` not
    positive semi-definite, or rounding) produces NaN rather than an error.
    """
    row = x @ A
    quad = row @ x
    return np.sqrt(quad)

def weighted_norm(x, A):
    """Return the ``A``-weighted norm ``sqrt(x^T A x)`` of the vector ``x``."""
    if not NUMBA:
        return np.sqrt(x @ A @ x)
    return _weighted_norm_nb(x, A)


@njit(fastmath=True, cache=True)
def _gauss_sample_nb(center, design, radius):
    """Compiled Gaussian perturbation shaped by the ellipsoid ``design``.

    Returns ``center + radius * L^{-T} z`` with ``z ~ N(0, I)`` and
    ``design = L @ L.T`` (lower Cholesky factor).  Then
    ``Cov = radius**2 * L^{-T} L^{-1} = radius**2 * design^{-1}``, and
    ``||sample - center||_design = radius * ||z||_2``.

    NOTE(review): per the numba docs, jitted code draws from numba's own RNG
    state, so ``np.random.seed`` called outside jitted code does not seed
    this path.  Confirm for the numba version in use.
    """
    L = np.linalg.cholesky(design)  # design = L @ L.T
    z = np.random.normal(0.0, 1.0, center.size)
    # Solve with L.T (not L): L^{-1} z would have covariance (L^T L)^{-1},
    # which is not design^{-1} unless design is diagonal.
    delta = np.linalg.solve(np.ascontiguousarray(L.T), z) * radius
    return center + delta

def gaussian_sample_ellipsoid(center, design, radius):
    """Sample ``center + radius * L^{-T} z`` where ``design = L L^T`` and ``z ~ N(0, I)``.

    The sample has covariance ``radius**2 * design^{-1}``, so its spread
    follows the shape of the ellipsoid ``{x : ||x - center||_design <= r}``.

    Args:
        center: center of the distribution (converted to float64).
        design: symmetric positive-definite matrix (converted to float64).
        radius: scale factor applied to the perturbation.

    Returns:
        A 1-D float64 array with the same length as ``center``.

    Raises:
        np.linalg.LinAlgError: if ``design`` is not positive definite.
    """
    center = np.asarray(center, dtype=np.float64)
    # Numba's cholesky needs a floating-point matrix; normalize once here.
    design = np.ascontiguousarray(design, dtype=np.float64)
    if NUMBA:
        return _gauss_sample_nb(center, design, radius)
    L = np.linalg.cholesky(design)
    z = np.random.normal(size=center.size)
    # L.T, not L: this gives covariance design^{-1} rather than (L^T L)^{-1}.
    return center + np.linalg.solve(L.T, z) * radius


@njit(fastmath=True, cache=True)
def _proj_nb(x, center, A, r):
    """Euclidean projection of ``x`` onto ``{p : (p - c)^T A^{-1} (p - c) <= r**2}``.

    ``A`` must be symmetric positive definite and ``c`` is ``center``.  If
    ``x`` already lies inside (with a 1e-12 slack), ``x`` itself is returned.
    Otherwise the closest point in the Euclidean norm is
    ``c + A (A + lam I)^{-1} (x - c)``, with ``lam >= 0`` chosen so that the
    result lies on the boundary.

    In the eigenbasis ``A = Q diag(a) Q^T`` with ``yt = Q^T (x - c)``,
    ``lam`` is the root of the decreasing convex function

        f(lam) = sum_i a_i * yt_i**2 / (a_i + lam)**2 - r**2,

    which Newton's method started at ``lam = 0`` approaches monotonically
    (at most 50 iterations).  When ``A`` is a multiple of the identity, this
    reduces to radially rescaling ``x - c`` onto the boundary.

    Raises:
        np.linalg.LinAlgError: if ``A`` has a non-positive eigenvalue.
    """
    y = x - center
    a, Q = np.linalg.eigh(A)  # eigenvalues in ascending order
    if a[0] <= 0.0:
        raise np.linalg.LinAlgError("Matrix is not positive definite.")
    yt = np.ascontiguousarray(Q.T) @ y
    yt2 = yt * yt
    r2 = r * r
    # Mahalanobis distance y^T A^{-1} y, computed in the eigenbasis.
    dist2 = np.sum(yt2 / a)
    if dist2 <= r2 + 1e-12:
        return x

    ay2 = a * yt2
    tol = 1e-12 * (1.0 + r2)  # scale-aware stopping tolerance
    lam = 0.0
    for _ in range(50):
        d = a + lam
        f = np.sum(ay2 / (d * d)) - r2
        if abs(f) < tol:
            break
        fp = -2.0 * np.sum(ay2 / (d * d * d))
        # Convexity keeps Newton iterates below the root, so the clamp is
        # only a safeguard.
        lam = max(lam - f / fp, 0.0)
    return center + Q @ (a * yt / (a + lam))

def project_onto_ellipsoid(x, center, A, r):
    """Map ``x`` into the ellipsoid ``{p : (p - center)^T A^{-1} (p - center) <= r**2}``.

    Points already inside are returned unchanged.  Points outside are mapped
    onto the boundary by ``_proj_nb``.

    Args:
        x: point to project (1-D array).
        center: ellipsoid center (1-D array, same length as ``x``).
        A: symmetric positive-definite shape matrix.
        r: ellipsoid radius.
    """
    # With numba, ``_proj_nb`` is the compiled dispatcher.  Without it, the
    # no-op ``njit`` fallback leaves ``_proj_nb`` as a plain Python function
    # that may lack a ``py_func`` attribute, so call it directly in both cases.
    return _proj_nb(x, center, A, r)


@njit(fastmath=True, cache=True)
def _pgd_nb(arm, theta0, V, V_inv,
            S, steps, mode, reward):
    """Compiled projected gradient descent for the online logistic update.

    Works in whitened coordinates ``z = V^{1/2} theta`` (symmetric square
    root).  There, ``||z - z0||^2 == (theta - theta0)^T V (theta - theta0)``
    and ``<z, V^{-1/2} arm> == <theta, arm>``.  Each of ``steps``
    iterations does two things:

    - takes a fixed 0.5 step on ``0.5 * ||z - z0||^2 + loss``;
    - projects onto ``{z : z^T V^{-1} z <= S**2}``, i.e. ``||theta||_2 <= S``.

    Args:
        arm: feature vector (1-D float64).
        theta0: current estimate (1-D float64).
        V: symmetric positive-definite matrix of the proximity term.
        V_inv: inverse of ``V``.  Its symmetric square root is used as
            ``V^{-1/2}``; assumed consistent with ``V`` -- TODO confirm
            callers keep it in sync.
        S: radius of the Euclidean-ball constraint on theta.
        steps: number of iterations (<= 0 runs none).
        mode: 0 -> loss-gradient coefficient ``pred - reward``.  Otherwise
            ``2 * pred - 1``, the derivative of the summed losses for labels
            0 and 1; ``reward`` is then unused.
        reward: observed label, used only when ``mode == 0``.

    Returns:
        The updated estimate ``theta = V^{-1/2} z``.

    Raises:
        np.linalg.LinAlgError: if ``V`` or ``V_inv`` has a non-positive
            eigenvalue.
    """
    w, Q = np.linalg.eigh(V)
    wi, Qi = np.linalg.eigh(V_inv)
    # eigh sorts eigenvalues ascending, so index 0 is the smallest.
    if w[0] <= 0.0 or wi[0] <= 0.0:
        raise np.linalg.LinAlgError("Matrix is not positive definite.")
    # The square root must be symmetric.  With a Cholesky factor L
    # (V = L L^T), the V-norm is preserved by L^T @ theta, not L @ theta.
    sqrt_V = (Q * np.sqrt(w)) @ np.ascontiguousarray(Q.T)
    inv_sqrt_V = (Qi * np.sqrt(wi)) @ np.ascontiguousarray(Qi.T)

    z0 = sqrt_V @ theta0
    z = z0.copy()
    inv_z_arm = inv_sqrt_V @ arm
    origin = np.zeros_like(arm)  # constant ellipsoid center, hoisted
    step = 0.5

    for _ in range(steps):
        pred = _sigmoid_nb(np.dot(z, inv_z_arm))
        if mode == 0:
            coef = pred - reward
        else:
            coef = 2.0 * pred - 1.0
        grad = z - z0 + coef * inv_z_arm
        z -= step * grad
        z[:] = _proj_nb(z, origin, V, S)
    return inv_sqrt_V @ z

def _online_logistic_driver(arm, theta0, V, V_inv,
                            S, precision, mode, reward):
    """Shared implementation of the two online logistic estimators.

    Runs projected gradient descent in whitened coordinates
    ``z = V^{1/2} theta`` (symmetric square root) with a fixed 0.5 step.
    The objective is ``0.5 * ||z - z0||^2 + loss``.  After every step the
    iterate is projected onto ``{z : z^T V^{-1} z <= S**2}``, which is
    equivalent to ``||theta||_2 <= S``.  Uses the compiled ``_pgd_nb`` when
    numba is available and the equivalent NumPy loop below otherwise.

    Args:
        arm: feature vector of the played arm.
        theta0: current parameter estimate.
        V: symmetric positive-definite matrix of the proximity term.
        V_inv: inverse of ``V``; its square root supplies ``V^{-1/2}``.
        S: constraint radius, also used as the "diameter" in the iteration
            count.
        precision: target accuracy.  The iteration count is
            ``ceil((9/4 + S/8) * log(S / precision))``, which is <= 0 (no
            iterations) when ``S <= precision``.
        mode: 0 -> gradient coefficient ``pred - reward``; otherwise
            ``2 * pred - 1``.
        reward: observed label (only used when ``mode == 0``).

    Returns:
        Updated estimate ``theta = V^{-1/2} z`` as a 1-D float64 array.

    Raises:
        np.linalg.LinAlgError: if ``V`` or ``V_inv`` is not positive definite.
    """
    # Normalize inputs once so both paths see contiguous float64 data.
    # Numba's matrix products require matching floating-point dtypes.
    arm = np.ascontiguousarray(arm, dtype=np.float64)
    theta0 = np.ascontiguousarray(theta0, dtype=np.float64)
    V = np.ascontiguousarray(V, dtype=np.float64)
    V_inv = np.ascontiguousarray(V_inv, dtype=np.float64)
    reward = float(reward)

    diam = S
    steps = int(np.ceil((9 / 4 + diam / 8) * np.log(diam / precision)))

    if NUMBA:
        return _pgd_nb(arm, theta0, V, V_inv, S, steps, mode, reward)

    w, Q = np.linalg.eigh(V)
    wi, Qi = np.linalg.eigh(V_inv)
    if w[0] <= 0.0 or wi[0] <= 0.0:
        raise np.linalg.LinAlgError("Matrix is not positive definite.")
    # Symmetric square roots keep ||z - z0|| equal to the V-norm of
    # theta - theta0, and <z, V^{-1/2} arm> equal to <theta, arm>.
    sqrt_V = (Q * np.sqrt(w)) @ Q.T
    inv_sqrt_V = (Qi * np.sqrt(wi)) @ Qi.T

    z0 = sqrt_V @ theta0
    z = z0.copy()
    inv_z_arm = inv_sqrt_V @ arm
    origin = np.zeros_like(arm)
    step = 0.5
    for _ in range(steps):
        pred = sigmoid(np.dot(z, inv_z_arm))
        coef = (pred - reward) if mode == 0 else (2 * pred - 1)
        grad = z - z0 + coef * inv_z_arm
        z = z - step * grad
        # Call the kernel directly.  Without numba it is a plain function,
        # so there is no compiled dispatcher (or ``py_func``) to go through.
        z = _proj_nb(z, origin, V, S)
    return inv_sqrt_V @ z

def fit_online_logistic_estimate(arm, reward, current_estimate,
                                 vtilde_matrix, vtilde_inv_matrix,
                                 constraint_set_radius, precision=1e-1):
    """Update the estimate after observing ``reward`` for ``arm``.

    Thin wrapper over ``_online_logistic_driver`` in mode 0, where each
    gradient step uses the coefficient ``pred - reward`` (logistic log-loss
    of the observed label).

    Args:
        arm: feature vector of the played arm.
        reward: observed label.
        current_estimate: parameter estimate before the update.
        vtilde_matrix: matrix of the proximity term (passed as ``V``).
        vtilde_inv_matrix: its inverse (passed as ``V_inv``).
        constraint_set_radius: radius ``S`` of the constraint set.
        precision: accuracy that sets the number of gradient steps.

    Returns:
        The updated estimate returned by the driver.
    """
    return _online_logistic_driver(
        arm=arm,
        theta0=current_estimate,
        V=vtilde_matrix,
        V_inv=vtilde_inv_matrix,
        S=constraint_set_radius,
        precision=precision,
        mode=0,
        reward=reward,
    )

def fit_online_logistic_estimate_bar(arm, current_estimate,
                                     vtilde_matrix, vtilde_inv_matrix,
                                     constraint_set_radius, precision=1e-1):
    """Reward-free variant of the online logistic update.

    Thin wrapper over ``_online_logistic_driver`` in mode 1.  Each gradient
    step uses ``2 * pred - 1``, the derivative of the logistic losses for
    labels 0 and 1 added together.  No reward is needed (0.0 is passed as a
    placeholder).

    Args:
        arm: feature vector of the played arm.
        current_estimate: parameter estimate before the update.
        vtilde_matrix: matrix of the proximity term (passed as ``V``).
        vtilde_inv_matrix: its inverse (passed as ``V_inv``).
        constraint_set_radius: radius ``S`` of the constraint set.
        precision: accuracy that sets the number of gradient steps.

    Returns:
        The updated estimate returned by the driver.
    """
    return _online_logistic_driver(
        arm=arm,
        theta0=current_estimate,
        V=vtilde_matrix,
        V_inv=vtilde_inv_matrix,
        S=constraint_set_radius,
        precision=precision,
        mode=1,
        reward=0.0,
    )

def mu_dot(scalar):
    """Return the logistic derivative at ``scalar`` as a Python float.

    The scalar is boxed into a 1x1 array for ``dsigmoid`` and unboxed with
    ``.item()``.
    """
    boxed = np.array([[scalar]])
    derivative = dsigmoid(boxed)
    return derivative.item()
