from scipy import linalg
import numpy as np
rng = np.random.default_rng()

from numba import njit, float64, int32


@njit
def topK_prox(x, gamma, k, copy=True):
    """Keep the ``k`` largest-magnitude entries of ``x`` and zero the rest.

    This is hard thresholding: the Euclidean projection onto vectors with
    at most ``k`` nonzeros.

    Parameters
    ----------
    x : ndarray
        Input vector. It is assumed to be 1-D, because ``len`` and fancy
        indexing are applied along its only axis.
    gamma : float
        Unused. The projection does not depend on a step size. Presumably
        it is kept so the signature matches other prox operators
        (TODO confirm with callers).
    k : int
        Number of entries to keep. ``k <= 0`` zeroes everything, and
        ``k >= len(x)`` keeps everything.
    copy : bool, optional
        If True (the default), work on a copy of ``x``. If False, modify
        ``x`` in place and return it.

    Returns
    -------
    ndarray
        The thresholded vector. It is ``x`` itself when ``copy`` is False.
        When several entries tie in magnitude at the cutoff, the one kept
        is whichever ``np.argpartition`` places last.
    """
    y = x.copy() if copy else x
    n = len(y)
    if k <= 0:
        # Zero in place so copy=False behaves the same as for every other k.
        y[:] = 0
        return y
    if k >= n:
        # Nothing to drop. Without this guard, n - k would be negative and
        # argpartition would count it from the end, zeroing entries that
        # should be kept.
        return y
    # argpartition puts the (n - k) smallest magnitudes first. Those are
    # the entries to discard.
    drop = np.argpartition(np.abs(y), n - k)[:n - k]
    y[drop] = 0
    return y

def obj_fun(A, b, x, mu):
    """Evaluate the regularised least-squares objective at ``x``.

    Computes ``0.5 * ||A @ x - b||^2 + (mu / 4) * ||x||^2``, using
    ``scipy.linalg.norm``. That is the Frobenius norm for 2-D input and
    the 2-norm for 1-D input.

    Parameters
    ----------
    A : array_like
        Matrix applied to ``x`` with ``@``.
    b : array_like
        Target that ``A @ x`` is compared against.
    x : array_like
        Point at which the objective is evaluated.
    mu : float
        Regularisation weight. Note that the squared norm of ``x`` is
        scaled by ``mu / 4``.

    Returns
    -------
    float
        Value of the objective.
    """
    residual = A @ x - b
    data_term = 0.5 * linalg.norm(residual) ** 2
    reg_term = mu / 4 * linalg.norm(x) ** 2
    return data_term + reg_term