import numpy as np

def proj_simplex(w, lam):
    """Return the Euclidean projection of ``w`` onto the simplex of radius ``lam``.

    The target set is ``{x : x >= 0, sum(x) = lam}``. The threshold ``tau`` is
    found with the sort-based rule:
    - sort the entries in decreasing order;
    - for each position ``j`` form ``(sum of the j largest entries - lam) / j``;
    - ``tau`` is that running value at the last position whose sorted entry
      is still strictly above it.

    The result is ``w - tau`` with every entry that is not strictly positive
    set to zero.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    lam : float
        Radius of the simplex.

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``w`` (and the dtype of ``w - tau``).
    """
    descending = w[np.argsort(w)[::-1]]
    positions = np.arange(1, len(descending) + 1)
    candidate_taus = (np.cumsum(descending) - lam) / positions
    above = descending > candidate_taus
    # argmax returns the FIRST True; searching the reversed mask yields the LAST one.
    last = len(above) - 1 - np.argmax(above[::-1])
    tau = candidate_taus[last]
    result = w - tau
    # Keep only strictly positive entries. The negated test also zeroes NaNs.
    result[~(result > 0)] = 0
    return result

def proj_extended_simplex(w, lam):
    """Return the Euclidean projection of ``w`` onto the hyperplane ``sum(x) = lam``.

    The hyperplane's normal is the all-ones vector, so the projection is a
    uniform shift of every entry by ``tau = (sum(w) - lam) / len(w)``. No
    sorting is required.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    lam : float
        Right-hand side of the constraint ``sum(x) = lam``.

    Returns
    -------
    numpy.ndarray
        New array ``w - tau``. It is floating point even for integer ``w``.

    Raises
    ------
    ValueError
        If ``w`` is empty: there is no vector of length 0 to shift.
    """
    p = len(w)
    if p == 0:
        raise ValueError("cannot project an empty vector onto sum(x) = lam")
    # A single division, rather than 1/p * (...), keeps the rounding to one step.
    tau = (np.sum(w) - lam) / p
    return w - tau

def hard_threshold_pos(arr, k):
    """Keep the ``k`` largest entries of ``arr`` by signed value and zero the rest.

    Parameters
    ----------
    arr : numpy.ndarray
        1-D input array.
    k : int
        Number of entries to keep:
        - ``k <= 0`` keeps none;
        - ``k >= len(arr)`` keeps all.

    Returns
    -------
    thresholded_arr : numpy.ndarray
        Same shape and dtype as ``arr``, with every unselected entry set to zero.
    top_k_indices : numpy.ndarray
        Integer indices of the kept entries, in no particular order.
    """
    k = min(k, len(arr))
    thresholded_arr = np.zeros_like(arr)
    if k <= 0:
        # Guard needed: ``idx[-0:]`` is the whole array, so k == 0 would
        # otherwise keep every entry instead of none.
        return thresholded_arr, np.empty(0, dtype=np.intp)
    top_k_indices = np.argpartition(arr, -k)[-k:]
    thresholded_arr[top_k_indices] = arr[top_k_indices]
    return thresholded_arr, top_k_indices

def hard_threshold(arr, k):
    """Keep the ``k`` entries of ``arr`` with largest magnitude and zero the rest.

    Parameters
    ----------
    arr : numpy.ndarray
        1-D input array.
    k : int
        Number of entries to keep:
        - ``k <= 0`` keeps none;
        - ``k >= len(arr)`` keeps all.

    Returns
    -------
    thresholded_arr : numpy.ndarray
        Same shape and dtype as ``arr``, with every unselected entry set to
        zero. Kept entries retain their sign.
    top_k_indices : numpy.ndarray
        Integer indices of the kept entries, in no particular order.
    """
    k = min(k, len(arr))
    thresholded_arr = np.zeros_like(arr)
    if k <= 0:
        # Guard needed: ``idx[-0:]`` is the whole array, so k == 0 would
        # otherwise keep every entry instead of none.
        return thresholded_arr, np.empty(0, dtype=np.intp)
    top_k_indices = np.argpartition(np.abs(arr), -k)[-k:]
    thresholded_arr[top_k_indices] = arr[top_k_indices]
    return thresholded_arr, top_k_indices


def twostepsprojsimplex(w, k, lam):
    """Two-step sparse projection onto the simplex of radius ``lam``.

    Step 1 selects the support of the ``k`` largest-magnitude entries of ``w``
    (``hard_threshold``). Step 2 projects those entries onto
    ``{x : x >= 0, sum(x) = lam}`` (``proj_simplex``). Entries outside the
    support are zero.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    k : int
        Support size passed to ``hard_threshold``.
    lam : float
        Radius of the simplex.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``w``. Its dtype is promoted (e.g. to float for
        integer ``w``) so the projected values are not truncated.
    """
    _, support = hard_threshold(w, k)
    supp_proj = proj_simplex(w[support], lam)
    # ``np.zeros_like(w)`` would inherit an integer dtype and truncate supp_proj.
    final = np.zeros(w.shape, dtype=np.result_type(w, supp_proj))
    final[support] = supp_proj
    return final

def twostepsprojextendedsimplex(w, k, lam):
    """Two-step sparse projection onto the hyperplane ``sum(x) = lam``.

    Step 1 selects the support of the ``k`` largest-magnitude entries of ``w``
    (``hard_threshold``). Step 2 projects those entries onto
    ``{x : sum(x) = lam}`` (``proj_extended_simplex``). Entries outside the
    support are zero.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    k : int
        Support size passed to ``hard_threshold``.
    lam : float
        Right-hand side of the hyperplane constraint.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``w``. Its dtype is promoted (e.g. to float for
        integer ``w``) so the projected values are not truncated.
    """
    _, support = hard_threshold(w, k)
    supp_proj = proj_extended_simplex(w[support], lam)
    # ``np.zeros_like(w)`` would inherit an integer dtype and truncate supp_proj.
    final = np.zeros(w.shape, dtype=np.result_type(w, supp_proj))
    final[support] = supp_proj
    return final


def fullprojsimplex(w, k, lam):
    """Sparse projection onto the simplex: keep the ``k`` largest entries, then project.

    Step 1 selects the support of the ``k`` largest entries of ``w`` by signed
    value (``hard_threshold_pos``). Step 2 projects those entries onto
    ``{x : x >= 0, sum(x) = lam}`` (``proj_simplex``). Entries outside the
    support are zero. This appears to correspond to the greedy selector
    (GSSP) of arXiv:1206.1529 — NOTE(review): confirm against the paper.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    k : int
        Support size passed to ``hard_threshold_pos``.
    lam : float
        Radius of the simplex.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``w``. Its dtype is promoted (e.g. to float for
        integer ``w``) so the projected values are not truncated.
    """
    _, support = hard_threshold_pos(w, k)
    supp_proj = proj_simplex(w[support], lam)
    # ``np.zeros_like(w)`` would inherit an integer dtype and truncate supp_proj.
    final = np.zeros(w.shape, dtype=np.result_type(w, supp_proj))
    final[support] = supp_proj
    return final

def fullprojextendedsimplex(w, k, lam):
    """Greedy k-sparse projection onto the hyperplane ``sum(x) = lam``.

    Follows the GSHP pseudocode from https://arxiv.org/pdf/1206.1529.pdf:
    1. Start the support ``S`` with the index maximizing ``lam * w_i``.
    2. While ``|S| < k``, add the index outside ``S`` whose entry is farthest
       (in absolute value) from ``(sum(w_S) - lam) / |S|``.
    3. Project the entries on ``S`` onto ``{x : sum(x) = lam}`` and zero
       all others.

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    k : int
        Target support size; values larger than ``len(w)`` are clipped to it.
    lam : float
        Right-hand side of the hyperplane constraint.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``w`` with at most ``k`` nonzeros. Its dtype is
        promoted (e.g. to float for integer ``w``) so the projected values are
        not truncated.

    Raises
    ------
    ValueError
        If ``k < 1`` or ``w`` is empty.
    """
    k = min(k, len(w))
    if k < 1:
        # Step 1 always selects one index, so a support of size < 1 (or an
        # empty ``w``) cannot be produced.
        raise ValueError("k must be at least 1 and w must be non-empty")
    support = [np.argmax(lam * w)]
    support_sum = w[support[0]]
    # Bounded by k <= len(w), so every iteration can add an unused index.
    while len(support) < k:
        size = len(support)
        scores = np.abs(w - (support_sum - lam) / size)
        scores[support] = -np.inf  # never re-select an index already in S
        j_to_add = np.argmax(scores)
        support_sum += w[j_to_add]
        support.append(j_to_add)
    supp_proj = proj_extended_simplex(w[support], lam)
    # ``np.zeros_like(w)`` would inherit an integer dtype and truncate supp_proj.
    final = np.zeros(w.shape, dtype=np.result_type(w, supp_proj))
    final[support] = supp_proj
    return final


def twostepsprojbox(x, l, u, k):
    """Two-step projection onto k-sparse vectors inside the box ``[l, u]``.

    Step 1 keeps the ``k`` entries of ``x`` with largest magnitude and zeroes
    the rest. Step 2 clips every entry to ``[l, u]``.

    The post-condition assertions require ``l <= 0 <= u``. Otherwise clipping
    moves the zeroed entries off zero (or outside the box) and an assertion
    fails.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array to project.
    l, u : float
        Lower and upper box bounds.
    k : int
        Number of entries kept in step 1; ``k <= 0`` keeps none.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(x)``.
    """
    proj = np.zeros(x.shape[0])
    if k > 0:
        # ``idx[-k:]`` with k == 0 would select every index, hence the guard.
        keep = np.argsort(np.abs(x))[-k:]
        proj[keep] = x[keep]
    proj = np.clip(proj, l, u)
    # Sanity checks only (stripped under ``python -O``).
    assert np.max(proj) <= u
    assert np.min(proj) >= l
    assert np.count_nonzero(proj) <= k
    return proj


def fullprojbox(x, l, u, k):
    """Projection onto k-sparse vectors inside the box ``[l, u]`` (select-then-clip).

    Each coordinate is first clipped to ``[l, u]``, giving ``p_i``. Keeping
    coordinate ``i``, instead of zeroing it, changes the squared distance to
    ``x`` by ``(x_i - p_i)**2 - x_i**2 = p_i**2 - 2 * p_i * x_i``. The ``k``
    coordinates with the smallest such change are kept and all others are
    zeroed.

    The post-condition assertions require ``l <= 0 <= u``.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array to project.
    l, u : float
        Lower and upper box bounds.
    k : int
        Number of coordinates kept.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(x)``.
    """
    xf = np.asarray(x, dtype=float)
    clipped = np.clip(xf, l, u)
    # Change in squared distance incurred by keeping each clipped coordinate.
    gain = clipped ** 2 - 2 * clipped * xf
    keep = np.argsort(gain)[:k]
    proj = np.zeros(x.shape[0])
    proj[keep] = clipped[keep]
    # Sanity checks only (stripped under ``python -O``).
    assert np.max(proj) <= u
    assert np.min(proj) >= l
    assert np.count_nonzero(proj) <= k
    return proj


def proj_pos(w):
    """Project ``w`` onto the nonnegative orthant.

    Returns a new array with the same shape and dtype as ``w``. Entries that
    are strictly positive are copied; every other entry (zero, negative or
    NaN) is zero.
    """
    out = np.zeros_like(w)
    np.copyto(out, w, where=w > 0)
    return out


def tsp_pos(w, k=None):
    """Two-step projection onto k-sparse nonnegative vectors.

    Step 1 keeps the ``k`` largest-magnitude entries of ``w``
    (``hard_threshold``). Step 2 zeroes every entry that is not strictly
    positive (``proj_pos``).

    Parameters
    ----------
    w : numpy.ndarray
        1-D array to project.
    k : int or None
        Number of entries kept by the thresholding step. ``None`` means no
        cardinality constraint, so only the nonnegativity projection is applied.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``w``.
    """
    if k is None:
        # Without a sparsity level, the hard-thresholding step is the identity.
        return proj_pos(w)
    thresholded, _ = hard_threshold(w, k=k)
    return proj_pos(thresholded)
