from gurobipy import GRB
import gurobipy as gp
import numpy as np
import time

'''
Defines and solves the column-generation pricing problem as a mixed-integer
nonlinear program (binary selection variables plus a log constraint).
Used for the exact MINLP formulation in column generation.
'''

class pricing_IP:
    """Pricing subproblem for column generation, solved as a Gurobi MINLP.

    For a target column ``i`` of ``d.data`` the model selects a subset of the
    columns (binary ``yj``) and regression coefficients ``beta`` and minimises

        ndata/2 * log(sigma^2) + ndata/2 * (log(2*pi) + 1)
        + regu_Lambda * sum_j yj[j] + sum_c lambda_c[c] * yc[c] - lambda_i

    subject to ``RSS(beta) <= ndata * sigma^2`` (RSS of regressing column
    ``i`` on the columns), ``beta[j] == 0`` unless ``yj[j] == 1``,
    ``yj[i] == 0`` and ``yc[c] >= yj[j]`` for every ``j`` in ``C_set[c]``.
    """

    def __init__(self, d, i, duals, C_set, regu_Lambda):
        """Store the inputs of one pricing call.

        Args:
            d: data holder exposing ``data`` (2-D array, indexed as
                ``data[:, j]``), ``n`` (number of candidate columns) and
                ``ndata`` (presumably the number of rows — confirm with caller).
            i: index of the target column being priced.
            duals: master-problem duals; ``duals[i]`` is subtracted from the
                objective and ``duals[n:]`` supplies one coefficient per set
                in ``C_set``.
            C_set: sequence of index collections; ``C_set[c]`` lists the
                columns whose selection forces ``yc[c]`` to 1.
            regu_Lambda: objective penalty per selected column.
        """
        self.d = d
        self.i = i
        self.duals = duals
        self.c_set = C_set
        self.regu_Lambda = regu_Lambda

    def define_pricing_problem(self):
        """Build and return the (unsolved) Gurobi pricing model.

        Variables are added in the order ``yj``, ``beta``, ``yc``,
        ``sigmasq``, ``log_sigmasq``, ``spare1``; ``solve_pricing`` relies on
        the first ``n`` model variables being ``yj``.

        Returns:
            gurobipy.Model: minimisation model described in the class docstring.
        """
        n = self.d.n
        ndata = self.d.ndata
        i = self.i
        C_set = self.c_set
        regu_Lambda = self.regu_Lambda

        X_mtx = np.asarray(self.d.data)
        xi = X_mtx[:, i]
        C_len = len(C_set)
        lambda_i = self.duals[i]
        lambda_c = self.duals[n:]

        pricing_problem = gp.Model('Pricing Problem')

        # --- Variables (order matters: see docstring) ---
        yj = pricing_problem.addVars(n, vtype=GRB.BINARY, name='yj')
        beta = pricing_problem.addVars(n, lb=-float("inf"), vtype=GRB.CONTINUOUS, name='beta')
        yc = pricing_problem.addVars(C_len, vtype=GRB.BINARY, name='yc')
        # ndata * sigma^2 >= RSS >= 0, so lb=0 removes no feasible point but
        # keeps the argument of the log constraint inside its domain.
        # An upper bound (e.g. ub=10) can mitigate numerical issues in the
        # nonlinear approximation.
        x_sigma_sq = pricing_problem.addVar(lb=0.0, vtype=GRB.CONTINUOUS, name="sigmasq")
        x_log_sigma_sq = pricing_problem.addVar(lb=-float("inf"), vtype=GRB.CONTINUOUS, name="log_sigmasq")

        # --- Objective ---
        obj = ndata * x_log_sigma_sq / 2
        obj += ndata / 2 * (np.log(2 * np.pi) + 1)
        obj += gp.quicksum(regu_Lambda * yj[j] for j in range(n))
        obj += gp.quicksum(lambda_c[c] * yc[c] for c in range(C_len))
        obj -= lambda_i
        pricing_problem.setObjective(obj, GRB.MINIMIZE)

        # --- Constraints ---
        # Residual sum of squares of regressing column i on the first n
        # columns. Built row by row with LinExpr/quicksum: builtin sum() over
        # Gurobi expressions copies the running total at every addition.
        betas = [beta[j] for j in range(n)]
        rss = gp.quicksum(
            (gp.LinExpr(row[:n].tolist(), betas) - float(target)) ** 2
            for row, target in zip(X_mtx, xi)
        )
        pricing_problem.addConstr(rss <= ndata * x_sigma_sq)

        # log_sigmasq == log(sigmasq); FuncNonlinear=1 asks Gurobi to treat it
        # as a nonlinear constraint instead of a piecewise-linear approximation.
        # NOTE: tuning FuncPieces=-1 / FuncPieceError=1e-3 was tried and noted
        # as problematic, so the defaults are kept.
        con = pricing_problem.addGenConstrLog(x_sigma_sq, x_log_sigma_sq)
        pricing_problem.update()
        con.FuncNonlinear = 1

        # Column i may not be used to explain itself.
        pricing_problem.addConstr(yj[i] == 0)

        # beta[j] may be nonzero only if column j is selected:
        # spare1[j] == 1 - yj[j], and SOS1 allows at most one of
        # (beta[j], spare1[j]) to be nonzero.
        y_spare_1 = pricing_problem.addVars(n, name='spare1')
        for j in range(n):
            pricing_problem.addConstr(y_spare_1[j] == 1 - yj[j])
            pricing_problem.addSOS(GRB.SOS_TYPE1, [beta[j], y_spare_1[j]])

        # yc[c] is forced to 1 whenever any column of C_set[c] is selected.
        for c in range(C_len):
            members = set(C_set[c])  # build once: O(1) membership per j
            for j in range(n):
                if j in members:
                    pricing_problem.addConstr(yc[c] >= yj[j])

        return pricing_problem

    def solve_pricing(self):
        """Build, optimize and release the pricing model.

        Returns:
            tuple: ``(new_pattern, pricing_obj, pricing_time)``.
            ``new_pattern`` holds the ``n`` solution values of ``yj``; a value
            near 1.0 means column j is in the pattern (no rounding is applied).
            ``pricing_obj`` is the model's ``ObjVal``.
            ``pricing_time`` is the elapsed seconds for building and solving.

        Gurobi's error propagates unchanged when no solution is available
        (e.g. an infeasible model); the model is disposed in every case.
        """
        n = self.d.n
        start_time = time.perf_counter()
        pricing_problem = self.define_pricing_problem()
        try:
            pricing_problem.Params.OutputFlag = 1
            pricing_problem.optimize()
            pricing_obj = pricing_problem.ObjVal
            # yj is added first in define_pricing_problem, so the first n
            # model variables are yj[0..n-1].
            model_vars = pricing_problem.getVars()
            new_pattern = [var.X for var in model_vars[:n]]
            pricing_time = time.perf_counter() - start_time
        finally:
            # A fresh model is built on every call; free its solver resources.
            pricing_problem.dispose()
        return new_pattern, pricing_obj, pricing_time