import numpy as np


class Projection_kkt(object):
    """Euclidean projection onto a box-and-budget set, solved via its KKT conditions.

    Projects ``v`` onto ``{x : 0 <= x <= 1, sum(x) <= alpha * B1}``.  The
    solution has the form ``x = clip(v - theta, 0, 1)`` where ``theta >= 0`` is
    the multiplier of the budget (cardinality) constraint; ``theta == 0`` when
    plain clipping to the box already satisfies the budget.
    """

    def __init__(self, dim, B1, v, alpha=0.):
        """
        Args:
            dim: length of ``v``.
            B1: unscaled budget; the effective budget is ``alpha * B1``.
            v: point to project (array-like of length ``dim``).
            alpha: scale applied to ``B1``.  With the default ``0.`` the
                effective budget is zero, so the projection is the zero vector.
        """
        self.dim = dim
        self.B1 = alpha * B1  # effective budget used by solution()
        self.v = v
        self.alpha = alpha
        self.lam = np.zeros(dim)  # for the 0/1 box constraints; not updated by solution()
        self.theta = 0  # multiplier of the 1^T x <= B1 (cardinality) constraint

    def solution(self):
        """Return the projection of ``self.v``; store the budget multiplier in ``self.theta``.

        Returns:
            ndarray of length ``dim`` equal to ``clip(v - theta, 0, 1)``.

        Raises:
            ValueError: if the effective budget ``self.B1`` is negative (the
                feasible set is then empty).
        """
        v = np.asarray(self.v, dtype=float)
        budget = self.B1
        if budget < 0:
            raise ValueError("budget alpha*B1 must be non-negative, got %r" % (budget,))

        x = np.clip(v, 0.0, 1.0)
        # Some entries hit the upper bound of the box (original "case 3/4" branch).
        upper_active = bool(np.any(v > 1))

        if x.sum() <= budget:
            # Budget constraint inactive: theta = 0, box clipping is optimal.
            self.theta = 0
            print("case 3" if upper_active else "case 1")
            return x

        theta = self._find_theta(v, budget)
        self.theta = theta
        print("case 4" if upper_active else "case 2")
        return np.clip(v - theta, 0.0, 1.0)

    @staticmethod
    def _find_theta(v, budget):
        """Return ``theta >= 0`` such that ``sum(clip(v - theta, 0, 1)) == budget``.

        Requires ``sum(clip(v, 0, 1)) > budget >= 0``.  The map
        ``g(t) = sum(clip(v - t, 0, 1))`` is continuous, non-increasing and
        piecewise linear with kinks at ``v_k`` and ``v_k - 1``, so the root lies
        between two consecutive kinks and is found exactly by interpolation.
        """
        def g(t):
            return np.clip(v - t, 0.0, 1.0).sum()

        # vmax > 0 here because g(0) > budget >= 0, so 0 < vmax gives a bracket.
        vmax = v.max()
        kinks = np.concatenate(([0.0, vmax], v, v - 1.0))
        kinks = np.unique(kinks[(kinks >= 0.0) & (kinks <= vmax)])  # sorted

        # Invariant: g(kinks[lo]) > budget >= g(kinks[hi]); holds initially since
        # g(0) > budget and g(vmax) == 0.
        lo, hi = 0, len(kinks) - 1
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if g(kinks[mid]) > budget:
                lo = mid
            else:
                hi = mid

        t_lo, t_hi = kinks[lo], kinks[hi]
        g_lo, g_hi = g(t_lo), g(t_hi)
        # g is linear on [t_lo, t_hi] and g_lo > g_hi, so this is exact and safe.
        return float(t_lo + (g_lo - budget) * (t_hi - t_lo) / (g_lo - g_hi))