import numpy as np
import gurobipy as gp
from gurobipy import Model, GRB
import continuous_scores
import discrete_scores
import time
import math

'''
Implements the DC Algorithm (DCA) solver for pricing problems
Includes convex programming using Kelley's algorithm and pruning
'''

class DCASolver:
    def __init__(self, data, data_type, K, i, init_pattern, duals, C_set, maxit1, maxit2, err1, err2, regu_Lambda):
        '''Store the inputs of one pricing problem and reset the solver's bookkeeping.

        Args:
            data: dataset object handed to the score functions (its ``n``,
                ``ndata`` and ``data`` attributes are read by the other methods).
            data_type: ``'C'`` selects ``continuous_scores``; any other value
                selects ``discrete_scores``.
            K: per-node pattern pools; ``K[i]`` is extended by ``pricing_DCA``.
            i: index of the node whose pricing problem is solved.
            init_pattern: starting pattern used once ``K[i]`` is large enough.
            duals: dual values; ``duals[i]`` and ``duals[n:]`` are read.
            C_set: forwarded unchanged to the score function ``g``.
            maxit1, err1: iteration cap and tolerance of the outer DCA loop.
            maxit2, err2: iteration cap and tolerance of the cutting-plane loop.
            regu_Lambda: regularisation weight forwarded to ``g`` and used by
                ``cost_obj``.
        '''
        # Problem description.
        self.data, self.data_type = data, data_type
        self.K, self.i = K, i
        self.init_pattern, self.duals, self.C_set = init_pattern, duals, C_set

        # Stopping rules: (maxit1, err1) for the outer loop, (maxit2, err2) for the inner one.
        self.maxit1, self.maxit2 = maxit1, maxit2
        self.err1, self.err2 = err1, err2
        self.regu_Lambda = regu_Lambda

        # Bookkeeping.
        self.evaluation_time = 0  # wall time accumulated inside score evaluations
        self.prune_dict = {}  # memo of costs so prune() never recomputes one

    def convex_program(self, d, y, x, i, C_set, lambda_c, maxit2, err2, regu_Lambda, reserved_data):
        '''Solve the convex subproblem of DCA with Kelley's cutting-plane algorithm.

        Minimises ``g(x) - <y, x>`` over the box ``[0, 1]^(n-1)``. Each pass
        solves an LP in ``(x, z)`` whose constraints are the cuts collected so
        far, evaluates ``g`` at the LP solution, and adds one more cut until
        the gap between the LP value and the true value drops below ``err2``.

        Args:
            d: dataset object; ``d.n`` gives the dimension and, for continuous
                data, ``d.data[:, i]`` is used for the constant term.
            y: linear term subtracted from ``g`` in the objective.
            x: warm-start point, used to seed the first two cuts of a new model.
            i: node index forwarded to the score function.
            C_set, lambda_c, regu_Lambda: forwarded unchanged to ``g``.
            maxit2: maximum number of LP solves.
            err2: tolerance on ``|upper bound - lower bound|``.
            reserved_data: ``None`` to build a fresh model, or the
                ``(model, x_vars, z)`` triple returned by an earlier call so
                that previously generated cuts are reused.

        Returns:
            ``(x_new, (model, x_vars, z))``: the last LP solution as a NumPy
            array and the model state for reuse. If ``maxit2 <= 0`` no LP is
            solved and a float copy of the warm start is returned.

        Raises:
            RuntimeError: if Gurobi does not report an optimal LP solution.
        '''
        n = d.n
        dim = n - 1
        if self.data_type == 'C':
            g = continuous_scores.g
            constant_g = np.log(np.var(d.data[:, i]))
        else:
            g = discrete_scores.g
            constant_g = discrete_scores.entropy_G(d, [], i)

        def evaluate_g(point):
            # Evaluate g and charge the wall time to self.evaluation_time.
            start = time.time()
            value, slope = g(d, point, i, C_set, lambda_c, regu_Lambda)
            self.evaluation_time += time.time() - start
            return value, slope

        def add_cut(slope):
            # Cut of the form z >= <slope, x> + constant_g.
            # NOTE(review): presumably valid because g has the form
            # "support function + constant"; confirm against the score modules.
            model.addConstr(z >= gp.quicksum(slope[k] * x_vars[k] for k in range(dim)) + constant_g)

        if reserved_data is None:
            # The seed cuts are only needed for a fresh model, so g is only
            # evaluated here (a reused model already carries its cuts).
            _, s1 = evaluate_g(x)  # warm start
            x2 = (s1 < 0) * 1
            _, s2 = evaluate_g(x2)

            model = Model()
            model.Params.LogToConsole = 0
            x_vars = model.addVars(dim, lb=0, ub=1, name="x")
            z = model.addVar(lb=-GRB.INFINITY, ub=GRB.INFINITY, name="z")  # free epigraph variable
            add_cut(s1)
            add_cut(s2)
        else:
            model, x_vars, z = reserved_data

        # Parameters persist on the model, so one assignment covers every solve below.
        model.Params.Threads = 1

        # y changes between calls, so the objective is (re)set every time.
        model.setObjective(z - gp.quicksum(y[k] * x_vars[k] for k in range(dim)), GRB.MINIMIZE)

        # Fallback result so the final return is well defined even when maxit2 <= 0.
        x_new = np.array(x, dtype=float)

        iteration = 0
        while iteration < maxit2:
            model.update()
            model.optimize()
            if model.status != GRB.OPTIMAL:
                raise RuntimeError("Optimization failed")

            LB = model.objVal  # LP optimum: lower bound on the convex program
            x_new = np.array([x_vars[k].x for k in range(dim)])

            g_new, s_new = evaluate_g(x_new)
            UB = g_new - np.dot(x_new, y)  # true objective at x_new: upper bound
            if abs(UB - LB) < err2:
                return x_new, (model, x_vars, z)

            add_cut(s_new)
            iteration += 1

        return x_new, (model, x_vars, z)

    def pricing_DCA(self, K, d, i, init_pattern, duals, C_set, maxit1, maxit2, err1, err2, regu_Lambda):
        '''Run the DC algorithm on ``g - h`` for node ``i`` and price the result.

        Every outer iteration takes a subgradient ``y`` of ``h`` at the current
        point, solves the convex subproblem with ``convex_program``, rounds the
        answer at 0.5 and appends it (with a 0 re-inserted at position ``i``)
        to ``K[i]``. The loop stops when ``g - h`` changes by less than
        ``err1`` or after ``maxit1`` iterations.

        Args:
            K: pattern pools; ``K[i]`` is replaced by a stacked copy that
                includes the new patterns.
            d: dataset object providing ``n`` and ``ndata``.
            i: node index.
            init_pattern: starting point used when ``K[i]`` already holds at
                least 50 rows; otherwise a random point is drawn.
            duals: dual values; ``duals[n:]`` is forwarded to ``g`` and
                ``duals[i]`` is subtracted from the objective.
            C_set, regu_Lambda: forwarded unchanged to ``g``.
            maxit1, err1: outer-loop iteration cap and tolerance.
            maxit2, err2: forwarded to ``convex_program``.

        Returns:
            ``(K, x, obj_x)``: the updated pools, the final 0/1 pattern of
            length ``n`` (entry ``i`` is 0) and its objective value.
        '''
        n = d.n
        ndata = d.ndata
        lambda_c = duals[n:]
        scores = continuous_scores if self.data_type == 'C' else discrete_scores
        g, h = scores.g, scores.h

        def timed(func, *args):
            # Call a score function and charge its wall time to evaluation_time.
            start = time.time()
            result = func(*args)
            self.evaluation_time += time.time() - start
            return result

        def snap_to_binary(vec):
            # In-place rounding at the 0.5 threshold.
            vec[vec > 0.5] = 1
            vec[vec <= 0.5] = 0

        # Use the K that was passed in (not self.K) so the method is consistent
        # when called directly.
        if len(K[i]) < 50:  # can tune this to control #candidates
            x = np.random.rand(n)
        else:
            x = init_pattern
        x = np.delete(x, i)  # returns a copy, so init_pattern is never mutated

        # g and h values at the current x, kept separately so the convergence
        # expression below is evaluated exactly as g(x) - h(x) - g(new) + h(new).
        # Assumes g and h return the same value for the same arguments, so a
        # value obtained once can be reused instead of recomputed.
        g_x = h_x = None

        reserved_data = None
        iteration = 0
        while iteration < maxit1:
            # Step 1: subgradient of h at x (the call also yields h's value at x).
            h_val, y = timed(h, d, x, i)
            if g_x is None:
                g_x, h_x = timed(g, d, x, i, C_set, lambda_c, regu_Lambda)[0], h_val

            # Step 2: fix y and solve for a better x through convex programming.
            new_x, reserved_data = self.convex_program(d, y, x, i, C_set, lambda_c, maxit2, err2, regu_Lambda, reserved_data)
            snap_to_binary(new_x)

            # Every rounded iterate is kept as an additional candidate pattern.
            K[i] = np.vstack((K[i], np.insert(new_x, i, 0)))

            g_new = timed(g, d, new_x, i, C_set, lambda_c, regu_Lambda)[0]
            h_new = timed(h, d, new_x, i)[0]
            converged = abs(g_x - h_x - g_new + h_new) < err1

            x, g_x, h_x = new_x, g_new, h_new
            if converged:
                break
            iteration += 1

        if g_x is None:
            # The loop never ran (maxit1 <= 0): price the rounded starting point.
            snap_to_binary(x)
            g_x = timed(g, d, x, i, C_set, lambda_c, regu_Lambda)[0]
            h_x = timed(h, d, x, i)[0]
        cost_x = g_x - h_x

        if self.data_type == 'C':
            obj_x = cost_x * ndata/2 + (np.log(2*np.pi) + 1) * ndata/2 - duals[i]
        else:
            obj_x = cost_x * ndata - duals[i]

        x = np.insert(x, i, 0)
        # NOTE(review): when the loop ran at least once this row repeats the last
        # row appended inside the loop; kept so K[i] has the same rows as before.
        K[i] = np.vstack((K[i], x))
        return K, x, obj_x
    
    def solve(self):
        '''Run ``pricing_DCA`` on the stored inputs and keep its outputs on the instance.

        Returns:
            ``(K, x, obj, evaluation_time)`` where the first three are the
            values returned by ``pricing_DCA`` (also stored as ``self.K``,
            ``self.x`` and ``self.obj``) and the last is the accumulated
            ``self.evaluation_time``.
        '''
        outcome = self.pricing_DCA(
            self.K,
            self.data,
            self.i,
            self.init_pattern,
            self.duals,
            self.C_set,
            self.maxit1,
            self.maxit2,
            self.err1,
            self.err2,
            self.regu_Lambda,
        )
        self.K, self.x, self.obj = outcome
        # Read the timer only after the run so it includes this call's evaluations.
        return self.K, self.x, self.obj, self.evaluation_time
    
    def prune(self, i, admited_J, current_best_cost):
        '''Recursively drop single elements of ``admited_J`` while that lowers the cost.

        For each element of the given set, the set without that element is
        costed (through ``cost_obj``, memoised in ``self.prune_dict``). Whenever
        such a reduced set beats the best cost seen so far, the search recurses
        into it and adopts whatever that recursion returns.

        Args:
            i: node index; part of the memo key and forwarded to ``cost_obj``.
            admited_J: current candidate set; it is copied, never modified.
            current_best_cost: cost that a reduced set must strictly beat.

        Returns:
            ``(best_J, best_cost)`` found by the search; the inputs are returned
            unchanged when no reduction improves on ``current_best_cost``.
        '''
        dataset = self.data
        candidates = admited_J
        if not candidates:
            return admited_J, current_best_cost

        best_J, best_cost = admited_J, current_best_cost
        for dropped in candidates:
            reduced = candidates.copy()
            reduced.remove(dropped)  # prune locally

            # Memo key: node index followed by the printed reduced set.
            memo_key = str(i) + str(reduced)
            if memo_key not in self.prune_dict:
                self.prune_dict[memo_key] = self.cost_obj(dataset, reduced, i)
            reduced_cost = self.prune_dict[memo_key]

            if reduced_cost < best_cost:
                # Improvement found: continue the search from the reduced set.
                best_J, best_cost = self.prune(i, reduced, reduced_cost)
        return best_J, best_cost
    
    def cost_obj(self, d, J, i): # local cost
        '''Return the regularised local cost of giving node ``i`` the set ``J``.

        Continuous data (``d.data_type == 'C'``):
            ``log(cost) * ndata / 2 + (1 + log(2*pi)) * ndata / 2 + regu_Lambda * len(J)``
        Otherwise:
            ``cost * ndata + regu_Lambda * (arity[i] - 1) * prod(arity[j] for j in J)``

        where ``cost`` comes from ``continuous_scores.cost`` or
        ``discrete_scores.cost`` respectively.

        Args:
            d: dataset object; every dataset quantity (``ndata``,
                ``data_type``, ``arity`` and the data passed to ``cost``) is
                taken from this argument.
            J: collection of indices forming the candidate set.
            i: node index.

        Returns:
            The local cost as a number.
        '''
        ndata = d.ndata
        if d.data_type == 'C':
            fit = np.log(continuous_scores.cost(d, J, i)) * ndata / 2 + (1 + np.log(2 * np.pi)) * ndata / 2
            penalty = self.regu_Lambda * len(J)
        else:
            fit = discrete_scores.cost(d, J, i) * ndata
            # The penalty scales with (arity[i] - 1) times the product of the arities in J.
            penalty = self.regu_Lambda * (d.arity[i] - 1) * math.prod(d.arity[j] for j in J)
        return fit + penalty