import cvxpy as cp
from COCO import *

class Solver:
    """Base class for the online constrained optimization solvers in this file.

    Subclasses provide ``run_one_step`` as a *property* that returns the
    decision for the current round. ``run`` plays the rounds, reveals each
    round's loss vector through ``problem.observe(t)``, and tracks the
    cumulative linear loss and cumulative clipped constraint violation.
    """

    def __init__(self, problem):
        assert isinstance(problem, COCO)
        self.problem = problem
        self.x_dim = problem.x_dim
        self.constraint_dim = problem.constraint_dim
        # Running totals and their per-round history (one entry per round played).
        self.loss = 0
        self.constraint_violation = 0
        self.loss_list = []
        self.constraint_violation_list = []
        # Most recently observed loss vector; None until the first observation.
        self.theta_estimate = None

    def update_all(self, theta, x):
        """Accumulate the round's loss and constraint violation, then record both totals.

        The loss is ``theta . x``. The violation is the sum over constraints
        ``i < constraint_dim`` of ``max(0, A_i . x - b_i)``.
        """
        self.loss += np.dot(theta, x)
        for i in range(self.constraint_dim):
            self.constraint_violation += max(0, np.dot(self.problem.A[i], x) - self.problem.b[i])
        self.loss_list.append(self.loss)
        self.constraint_violation_list.append(self.constraint_violation)

    def update_observe(self, theta):
        """Store the just-revealed loss vector for use in the next decision."""
        self.theta_estimate = theta

    @property
    def run_one_step(self):
        """Decision for the current round; subclasses must override (read, not called)."""
        raise NotImplementedError

    def run(self, T=5000):
        """Play ``T`` rounds (default 5000).

        Each round does the following, in order:
        1. Obtain the decision from ``run_one_step``.
        2. Reveal the loss vector via ``problem.observe(t)``.
        3. Update the stored estimate.
        4. Accumulate the metrics.
        """
        for t in range(T):
            x = self.run_one_step  # property access: computes and returns this round's decision
            theta = self.problem.observe(t)
            self.update_observe(theta)
            self.update_all(theta, x)

class RECOO(Solver):
    """Rectified solver: proximal step with a queue-weighted hinge penalty.

    From the second round on, each step minimizes the following over the box [0, 1]^d:
    the last observed linear loss ``theta . x``, plus
    ``gamma * vq . max(0, A x - b)``, plus ``alpha * ||x - x_prev||^2``.
    It then grows the per-constraint queue ``vq`` by ``gamma`` times the
    clipped violation and floors it at ``eta``. The round-dependent
    schedules ``gamma``, ``eta`` and ``alpha`` depend on the round index.
    """

    def __init__(self, problem):
        super().__init__(problem)
        self.count = 1                           # index of the round about to be played
        self.vq = np.zeros(self.constraint_dim)  # rectified queue, one entry per constraint
        self.choice = []                         # decisions played so far

    @property
    def run_one_step(self):
        """Play one round and return the decision (read as an attribute, not called)."""
        if self.count == 1:
            # Nothing has been observed yet: record and play the origin.
            self.count += 1
            self.choice.append(np.zeros(self.x_dim))
            return np.zeros(self.x_dim)

        t = self.count
        A = self.problem.A
        b = self.problem.b

        # Round-dependent schedules.
        gamma = 1.5 * math.pow(t, 1 / 2 + 0.01)
        eta = math.pow(t, 1 / 2)
        alpha = 3 * math.pow(t, 1 / 2)

        # Rectified decision. (|r| + r) / 2 is the elementwise hinge max(0, r).
        x = cp.Variable(shape=(self.x_dim,))
        box = [x >= 0, x <= 1]
        linear_term = self.theta_estimate @ x
        penalty_term = gamma * self.vq @ (cp.abs(A @ x - b) + A @ x - b) / 2
        prox_term = alpha * cp.square(cp.norm(x - self.choice[-1]))
        subproblem = cp.Problem(cp.Minimize(linear_term + penalty_term + prox_term), box)
        subproblem.solve()
        self.choice.append(x.value)

        # Rectified penalty update: add gamma * clipped violation, floor at eta.
        for i in range(self.constraint_dim):
            clipped = max(0, np.dot(A[i], x.value) - b[i])
            self.vq[i] = max(eta, self.vq[i] + gamma * clipped)
        self.count += 1
        return x.value

class Alg1_Yi(Solver):
    """Queue-penalized proximal solver with step ``alpha`` and penalty weight ``gamma``.

    From the second round on, each step minimizes the following over the box [0, 1]^d:
    ``alpha * theta . x``, plus ``alpha * gamma * vq_hat . max(0, A x - b)``,
    plus ``||x - x_prev||^2``. It then accumulates the clipped constraint
    violations into the virtual queues.
    """

    def __init__(self, problem, alpha, gamma):
        super().__init__(problem)
        self.count = 1                               # first-round flag (only ever compared with 1)
        self.vq_x = np.zeros(self.constraint_dim)    # cumulative clipped violation per constraint
        self.vq_hat = np.zeros(self.constraint_dim)  # queue used in the next round's penalty
        self.choice = []                             # decisions played so far
        self.alpha = alpha
        self.gamma = gamma

    @property
    def run_one_step(self):
        """Play one round and return the decision (read as an attribute, not called)."""
        if self.count == 1:
            # Nothing has been observed yet: record and play the origin.
            self.count += 1
            self.choice.append(np.zeros(self.x_dim))
            return np.zeros(self.x_dim)

        A = self.problem.A
        b = self.problem.b
        x = cp.Variable(shape=(self.x_dim,))
        box = [x >= 0, x <= 1]
        # (|r| + r) / 2 is the elementwise hinge max(0, r).
        loss_term = self.alpha * self.theta_estimate @ x
        penalty_term = self.alpha * self.gamma * self.vq_hat @ (cp.abs(A @ x - b) + A @ x - b) / 2
        prox_term = cp.square(cp.norm(x - self.choice[-1]))
        subproblem = cp.Problem(cp.Minimize(loss_term + penalty_term + prox_term), box)
        subproblem.solve()
        self.choice.append(x.value)

        for i in range(self.constraint_dim):
            clipped = max(0, np.dot(A[i], x.value) - b[i])
            self.vq_x[i] = self.vq_x[i] + clipped
            # vq_hat is the already-updated vq_x plus the latest violation.
            # NOTE(review): this counts the latest violation twice relative to the
            # old queue; confirm it matches the intended (optimistic) queue.
            self.vq_hat[i] = self.vq_x[i] + clipped
        return x.value

class Alg_Yuan(Solver):
    """Primal-dual proximal-gradient solver with step size ``eta`` and dual scale ``sigma``.

    From the second round on, each step linearizes at the previous decision
    ``x_prev``. The gradient is the last observed loss vector plus
    ``dual_i * A_i`` for every constraint violated at ``x_prev``, where
    ``dual_i = max(0, A_i . x_prev - b_i) / (sigma * eta)``. The new decision
    is ``argmin gradient . x + ||x - x_prev||^2 / (2 * eta)`` over the box [-1, 1]^d.
    """

    def __init__(self, problem, eta, sigma):
        super().__init__(problem)
        self.choice = []   # decisions played so far
        self.eta = eta
        self.sigma = sigma
        self.count = 1     # first-round flag (only ever compared with 1)

    @property
    def run_one_step(self):
        """Play one round and return the decision (read as an attribute, not called)."""
        if self.count == 1:
            # Nothing has been observed yet: record and play the origin.
            self.count += 1
            self.choice.append(np.zeros(self.x_dim))
            return np.zeros(self.x_dim)

        A = self.problem.A
        b = self.problem.b
        last_choice = self.choice[-1]

        # For each constraint, compute two things at last_choice:
        # - the subgradient of max(0, A_i x - b_i);
        # - the closed-form dual value.
        self.congra = []
        self.dual = []
        for i in range(self.constraint_dim):
            slack = np.dot(A[i], last_choice) - b[i]
            self.congra.append(np.zeros(self.x_dim) if slack <= 0 else A[i])
            self.dual.append(max(0, slack) / (self.sigma * self.eta))

        # Work on a float copy. theta_estimate is the object returned by
        # problem.observe (possibly shared with the problem), so accumulating
        # into it in place would corrupt it and truncate on integer dtypes.
        self.gradient = np.array(self.theta_estimate, dtype=float)
        for dual_j, congra_j in zip(self.dual, self.congra):
            self.gradient += dual_j * congra_j

        x = cp.Variable(shape=(self.x_dim,))
        # NOTE(review): this box is [-1, 1] while the other solvers here use
        # [0, 1]; confirm it is intended before comparing results.
        constraints = [x >= -1, x <= 1]
        # Proximal step: scale only the quadratic term by 1/(2*eta). Dividing the
        # whole objective would make the minimizer independent of eta.
        obj = cp.Minimize(
            self.gradient @ x + cp.square(cp.norm(x - last_choice)) / (2 * self.eta))
        prob = cp.Problem(obj, constraints)
        prob.solve()
        self.choice.append(x.value)
        return x.value

class Alg2_Yi(Solver):
    """Five Alg1_Yi-style experts combined by exponential weights.

    Expert k (k = 1..5) solves a proximal, queue-penalized subproblem over the
    box [0, 1]^d. It uses coefficient ``alpha[k-1]`` and its own virtual queues
    ``vq_x_k`` / ``vq_hat_k``. The played decision is the ``weight``-averaged
    expert output. ``update_observe`` re-weights the experts with rate ``beta``.
    """

    def __init__(self, problem, alpha, gamma, beta):
        super().__init__(problem)
        self.beta = beta
        self.count = 1  # first-round flag (only ever compared with 1)
        # Per-expert virtual queues; expert k reads vq_hat_k and updates both.
        self.vq_x_1 = np.zeros(self.constraint_dim)
        self.vq_hat_1 = np.zeros(self.constraint_dim)
        self.vq_x_2 = np.zeros(self.constraint_dim)
        self.vq_hat_2 = np.zeros(self.constraint_dim)
        self.vq_x_3 = np.zeros(self.constraint_dim)
        self.vq_hat_3 = np.zeros(self.constraint_dim)
        self.vq_x_4 = np.zeros(self.constraint_dim)
        self.vq_hat_4 = np.zeros(self.constraint_dim)
        self.vq_x_5 = np.zeros(self.constraint_dim)
        self.vq_hat_5 = np.zeros(self.constraint_dim)
        self.choice = []    # decisions played so far
        self.alpha = alpha  # per-expert coefficients, indexed 0..4
        self.gamma = gamma
        # Initial expert weights (they sum to 1).
        self.weight = [3 / 5, 1 / 5, 1 / 10, 3 / 50, 1 / 25]

    def _queue_pairs(self):
        """Return [(vq_x_k, vq_hat_k)] for experts k = 1..5 (the live arrays, not copies)."""
        return [
            (self.vq_x_1, self.vq_hat_1),
            (self.vq_x_2, self.vq_hat_2),
            (self.vq_x_3, self.vq_hat_3),
            (self.vq_x_4, self.vq_hat_4),
            (self.vq_x_5, self.vq_hat_5),
        ]

    def update_observe(self, theta):
        """Store ``theta`` and re-weight the experts against the played decision.

        Expert k's weight is multiplied by ``exp(-beta * theta . (x_k - x_played))``
        and the weights are renormalized. The update is done in log space and
        shifted by the maximum, so large exponents cannot overflow to inf/nan.
        The shift cancels in the normalization.
        """
        self.theta_estimate = theta
        played = self.choice[-1]
        exponents = np.array([-self.beta * np.dot(theta, x_k - played) for x_k in self.x_list])
        with np.errstate(divide="ignore"):  # a zero weight maps to log(0) = -inf -> exp gives 0
            log_w = np.log(np.asarray(self.weight, dtype=float)) + exponents
        log_w -= np.max(log_w)
        w = np.exp(log_w)
        self.weight = w / np.sum(w)

    def _solve_expert(self, alpha_k, vq_hat_k):
        """Solve one expert's box-constrained proximal subproblem and return its minimizer."""
        A = self.problem.A
        b = self.problem.b
        x = cp.Variable(shape=(self.x_dim,))
        constraints = [x >= 0, x <= 1]
        # (|r| + r) / 2 is the elementwise hinge max(0, r).
        obj = cp.Minimize(
            alpha_k * self.theta_estimate @ x
            + alpha_k * self.gamma * vq_hat_k @ (cp.abs(A @ x - b) + A @ x - b) / 2
            + cp.square(cp.norm(x - self.choice[-1])))
        cp.Problem(obj, constraints).solve()
        return x.value

    def _update_queues(self, vq_x_k, vq_hat_k, x_k):
        """Update one expert's queues in place with the clipped violations of its decision."""
        for i in range(self.constraint_dim):
            clipped = max(0, np.dot(self.problem.A[i], x_k) - self.problem.b[i])
            vq_x_k[i] = vq_x_k[i] + clipped
            # NOTE(review): adds the latest violation on top of the updated vq_x;
            # confirm this double count is intended.
            vq_hat_k[i] = vq_x_k[i] + clipped

    @property
    def run_one_step(self):
        """Play one round and return the decision (read as an attribute, not called)."""
        if self.count == 1:
            # Nothing has been observed yet: all experts and the played point start at the origin.
            self.count += 1
            self.choice.append(np.zeros(self.x_dim))
            self.x_list = [np.zeros(self.x_dim) for _ in range(5)]
            return np.zeros(self.x_dim)

        queues = self._queue_pairs()
        # Solve every expert before touching any queue (each expert reads only its own).
        self.x_list = [
            self._solve_expert(self.alpha[k], vq_hat_k)
            for k, (_, vq_hat_k) in enumerate(queues)
        ]
        for (vq_x_k, vq_hat_k), x_k in zip(queues, self.x_list):
            self._update_queues(vq_x_k, vq_hat_k, x_k)

        # Play the weight-averaged expert decision (sized by x_dim, not a fixed 2).
        x = np.zeros(self.x_dim)
        for x_k, w_k in zip(self.x_list, self.weight):
            x += x_k * w_k
        self.choice.append(x)
        return self.choice[-1]

        # # experiment K times
        # def experiment(K):




















