import os
import numpy as np
import gurobipy as gp
from gurobipy import GRB

# Upper limit on how many leading model variables lp_parser_and_count inspects
# when it counts fixed variables (those whose lower bound equals their upper
# bound).
item_num = 10

def gen_original_milp(c, A=None, b=None, G=None, h=None):
    """Build the integer program ``min c.x`` and write it to ``original_milp.mps``.

    One ``GRB.INTEGER`` variable named ``x[i]`` is created per objective
    coefficient; no explicit bounds are passed to ``addVars``.  Rows
    ``A x == b`` are added when ``A`` is given, rows ``G x <= h`` when ``G``
    is given.  The model is written to disk but not optimized.

    Args:
        c: objective coefficients; must offer ``tolist()`` (e.g. an ndarray).
        A: optional equality matrix exposing ``shape`` and ``tolist()``.
        b: right-hand side for ``A``; only read when ``A`` is not None.
        G: optional inequality matrix exposing ``shape`` and ``tolist()``.
        h: right-hand side for ``G``; only read when ``G`` is not None.

    Returns:
        The unsolved ``gurobipy.Model``.
    """
    costs = c.tolist()
    n_vars = np.size(costs)

    model = gp.Model()
    model.setParam('OutputFlag', 0)
    x = model.addVars(n_vars, vtype=GRB.INTEGER, name='x')
    model.setObjective(x.prod(costs), GRB.MINIMIZE)

    def _add_rows(matrix, rhs, equality):
        # Add one linear constraint per matrix row, in row order.
        n_rows = matrix.shape[0]
        coeff_rows = matrix.tolist()
        rhs_values = rhs.tolist()
        for r in range(n_rows):
            lhs = x.prod(coeff_rows[r])
            if equality:
                model.addConstr(lhs == rhs_values[r])
            else:
                model.addConstr(lhs <= rhs_values[r])

    if A is not None:
        _add_rows(A, b, True)
    if G is not None:
        _add_rows(G, h, False)

    model.write("original_milp.mps")
    return model


def gen_stage2_milp(varNum, purchase_fee, compensation_fee, rowSizeG, c_list, G_list, h_list, x_sol):
    """Build the second-stage model over ``y`` and write it to ``stage2_milp.mps``.

    Variables ``y[i]`` are integer with bounds [0, 1].  The model minimizes
    ``sum_i (compensation_fee - purchase_fee) * c_list[i] * y[i]`` subject to

    * ``sum_i x_sol[i] * w[i] - sum_i w[i] * y[i] <= h_list[rowSizeG - 1]``
      where ``w = G_list[rowSizeG - 1]`` (only the last row of G/h is used);
    * ``y[i] <= x_sol[i]`` for every ``i``.

    The model is written to disk but not optimized.

    Args:
        varNum: number of ``y`` variables.
        purchase_fee: scalar subtracted from ``compensation_fee`` in the objective.
        compensation_fee: scalar multiplier, see the objective above.
        rowSizeG: number of rows in ``G_list`` / ``h_list``; selects the last row.
        c_list: per-variable objective coefficients.
        G_list: inequality coefficient rows.
        h_list: inequality right-hand sides.
        x_sol: first-stage values, one per variable.

    Returns:
        The unsolved ``gurobipy.Model``.
    """
    model = gp.Model()
    model.setParam('OutputFlag', 0)
    y = model.addVars(varNum, vtype=GRB.INTEGER, lb=0, ub=1, name='y')

    last_row = rowSizeG - 1
    weights = G_list[last_row]
    capacity = h_list[last_row]
    fee_gap = compensation_fee - purchase_fee

    # Accumulate the objective and the weight already committed by x_sol in
    # one pass; each running sum is built in index order.
    objective = 0
    committed_weight = 0
    for idx in range(varNum):
        objective = objective + fee_gap * c_list[idx] * y[idx]
        committed_weight = committed_weight + x_sol[idx] * weights[idx]

    model.setObjective(objective, GRB.MINIMIZE)
    model.addConstr(committed_weight - (y.prod(weights)) <= capacity)
    model.addConstrs(y[idx] <= x_sol[idx] for idx in range(varNum))

    model.write("stage2_milp.mps")
    return model


def lp_parser(filename):
    """Read a model file and return it as dense matrices ``(c, A, b, G, h)``.

    This is a thin wrapper around ``lp_parser_and_count``: both functions
    build exactly the same matrices, so the work is delegated and the extra
    counters returned by ``lp_parser_and_count`` are dropped.  Keeping a single
    implementation prevents the two parsers from drifting apart.

    Args:
        filename: path of a model file readable by ``gurobipy.read``.

    Returns:
        Tuple ``(c, A, b, G, h)``:
        ``c`` is the list of objective coefficients, ``A``/``b`` hold the
        equality rows, and ``G``/``h`` hold the inequality rows followed by
        rows encoding variable bounds, as built by ``lp_parser_and_count``.
    """
    c, A, b, G, h = lp_parser_and_count(filename)[:5]
    return c, A, b, G, h


def _constraints_to_matrices(m, constrs, variables, sense, rhs):
    """Split linear constraints into dense ``A x == b`` and ``G x <= h`` parts."""
    var_num = len(variables)
    eq_num = sense.count('=')
    ineq_num = sense.count('<') + sense.count('>')
    A = np.zeros((eq_num, var_num))
    b = np.zeros(eq_num)
    G = np.zeros((ineq_num, var_num))
    h = np.zeros(ineq_num)

    eq_cnt = 0
    ineq_cnt = 0
    for constr, kind, rhs_val in zip(constrs, sense, rhs):
        row = np.array([m.getCoeff(constr, var) for var in variables], dtype=float)
        if kind == '=':
            A[eq_cnt] = row
            b[eq_cnt] = rhs_val
            eq_cnt += 1
        elif kind in ('<', '>'):
            # A '>=' row is stored as its negation so that every row of G
            # reads "<=".  Adding 0.0 turns any -0.0 produced by the negation
            # back into 0.0.
            sign = 1.0 if kind == '<' else -1.0
            G[ineq_cnt] = sign * row + 0.0
            h[ineq_cnt] = sign * rhs_val + 0.0
            ineq_cnt += 1
    return A, b, G, h


def _bounds_to_inequalities(lb, ub):
    """Encode variable bounds as ``<=`` rows: lower-bound rows, then upper-bound rows."""
    var_num = lb.shape[0]

    # Only strictly positive lower bounds are encoded, as -x_i <= -lb_i.
    lb_idx = np.flatnonzero(lb > 0)
    lb_G = np.zeros((lb_idx.size, var_num))
    lb_G[np.arange(lb_idx.size), lb_idx] = -1
    lb_h = -lb[lb_idx]

    # Upper bounds of 1000000 or more are treated as "no bound" and skipped;
    # the rest are encoded as x_i <= ub_i.
    ub_idx = np.flatnonzero(ub < 1000000)
    ub_G = np.zeros((ub_idx.size, var_num))
    ub_G[np.arange(ub_idx.size), ub_idx] = 1
    ub_h = ub[ub_idx]

    bound_G = np.concatenate([lb_G, ub_G], axis=0)
    bound_h = np.concatenate([lb_h, ub_h], axis=0)
    return bound_G, bound_h


def lp_parser_and_count(filename):
    """Read a model file into dense matrices and report a few counts.

    The file is loaded with ``gurobipy.read`` in a silent environment.  Its
    linear constraints are split by sense: ``=`` rows go to ``A``/``b``;
    ``<`` rows go to ``G``/``h`` unchanged and ``>`` rows go to ``G``/``h``
    negated (so nothing is dropped and every row of ``G`` means ``<=``).
    Variable bounds are then appended to ``G``/``h``: first one row
    ``-x_i <= -lb_i`` per variable with ``lb_i > 0``, then one row
    ``x_i <= ub_i`` per variable with ``ub_i < 1000000``.

    Args:
        filename: path of a model file readable by ``gurobipy.read``.

    Returns:
        Tuple ``(c, A, b, G, h, eq_num, ineq_num, FX_num)`` where
        ``c`` is the list of objective coefficients,
        ``eq_num`` is the number of rows of ``A``,
        ``ineq_num`` is the number of constraint-derived rows at the top of
        ``G`` (bound rows are not counted), and
        ``FX_num`` is how many of the first ``min(item_num, var_num)``
        variables have equal lower and upper bounds.
    """
    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        m = gp.read(filename, env)

        # Fetch the variable/constraint lists once instead of rebuilding them
        # for every coefficient lookup.
        variables = m.getVars()
        constrs = m.getConstrs()

        c = m.getAttr("Obj", variables)
        var_num = len(c)
        sense = m.getAttr('Sense', constrs)
        rhs = m.getAttr("RHS", constrs)
        lb = np.array(m.getAttr("LB", variables))
        ub = np.array(m.getAttr("UB", variables))

        A, b, G, h = _constraints_to_matrices(m, constrs, variables, sense, rhs)

    eq_num = A.shape[0]
    ineq_num = G.shape[0]

    num = min(item_num, var_num)
    FX_num = int(np.count_nonzero(lb[:num] == ub[:num]))

    bound_G, bound_h = _bounds_to_inequalities(lb, ub)
    G = np.concatenate([G, bound_G], axis=0)
    h = np.concatenate([h, bound_h], axis=0)

    return c, A, b, G, h, eq_num, ineq_num, FX_num


#def count_FX(filename):
#    with gp.Env(empty=True) as env:
#        env.setParam('OutputFlag', 0)
#        env.start()
#    #    m = gp.Model()
#    #    m.setParam('OutputFlag', 0)
#        m = gp.read(filename, env)
##        m.optimize()
##        objective = m.getObjective()
##        print("objective: ", objective.getValue())
##        # directly get c
#        c = m.getAttr("Obj",m.getVars())
#        #print(c)  # c
#        var_num = len(c)
#        
#        # record lower and upper bounds
#        lb = m.getAttr("LB",m.getVars())
#        ub = m.getAttr("UB",m.getVars())
##        print(lb)
##        print(ub)
#        #zero_num = lb.count(0)
#        #inf_num = ub.count(inf)
#        #print(zero_num, inf_num)
#        lb = np.array(lb)
#        ub = np.array(ub)
#        FX_num = 0
#        num = min(item_num, var_num)
#        for i in range(num):
#            if lb[i] == ub[i]:
#                FX_num = FX_num + 1
#    
#    return FX_num


#input_c = np.loadtxt("c.txt")
#input_A = np.loadtxt("A.txt")
#input_b = np.loadtxt("b.txt")
#
#model = gen_original_milp(input_c,input_A,input_b,G=None,h=None)
#
#get_cutting_exe = r'./wrtnode'
#get_cutting_para = r'original_milp.mps'
#r_v = os.system(get_cutting_exe+' '+get_cutting_para)
#
#output_c, output_A, output_b, output_G, output_h = lp_parser("cutting.mps")
#np.savetxt('output_c.txt', output_c, fmt="%.2f")
#np.savetxt('output_A.txt', output_A, fmt="%.2f")
#np.savetxt('output_b.txt', output_b, fmt="%.2f")
#np.savetxt('output_G.txt', output_G, fmt="%.2f")
#np.savetxt('output_h.txt', output_h, fmt="%.2f")
#output_c = np.array(output_c)
#print(output_c.shape, output_A.shape, output_b.shape, output_G.shape, output_h.shape)
#FX_num = count_FX("cutting.mps")
#print(FX_num)

def solve_LP(filename):
    """Read a model file, optimize it and print the variable values.

    On success one line ``sol:  v0, v1, ...`` is printed.  If optimization
    fails or no solution values are available, ``cannot solve`` is printed
    instead.

    Args:
        filename: path of a model file readable by ``gurobipy.read``.

    Returns:
        The expression returned by ``Model.getObjective()``.
        NOTE(review): it is returned after the ``with`` block has closed the
        Gurobi environment; confirm callers can still evaluate it.
    """
    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        m = gp.read(filename, env)

        # Fetched before optimizing so the return value is always bound, even
        # when optimize() itself raises.
        objective = m.getObjective()
        try:
            m.optimize()
            values = m.getAttr("X", m.getVars())
        except (gp.GurobiError, AttributeError):
            # Raised when the solve fails or the model has no solution to read.
            print("cannot solve")
        else:
            print("sol: ", end=" ")
            for value in values:
                print(f"{value}", end=", ")
            print()

    return objective

#objective = solve_LP("original_milp.mps")

def solve_LP_inputMatrix(c, A_temp, b_temp, G, h):
    """Solve ``min c.x  s.t.  A x == b, G x <= h`` with continuous variables.

    Variables are created with ``vtype=GRB.CONTINUOUS`` and no explicit
    bounds passed to ``addVars``.

    Args:
        c: 1-D ndarray of objective coefficients.
        A_temp: equality matrix.  A non-empty 1-D array is treated as a single
            row (paired with a single right-hand side ``b_temp``); anything
            else is used as given.
        b_temp: equality right-hand side(s).
        G: 2-D ndarray of inequality coefficients.
        h: ndarray of inequality right-hand sides.

    Returns:
        ndarray of length ``len(c)`` with the solution values, or all zeros
        when no solution could be read.  In that case ``cannot solve`` is
        printed, an IIS is written to ``model.ilp`` when one can be computed,
        the inputs are dumped to ``c.txt``/``A.txt``/``b.txt``/``G.txt``/
        ``h.txt`` and the function blocks on ``input("Continue?")``.
    """
    decision_num = c.shape[0]

    # Decide "single row" by dimensionality.  Comparing size with shape[0]
    # would also match a 2-D column matrix (k, 1) and then fail for k > 1.
    if A_temp.ndim == 1 and A_temp.size > 0:
        A = np.zeros((1, decision_num))
        A[0] = A_temp
        # reshape(1) accepts a scalar, 0-d or one-element array and rejects
        # anything longer.
        b = np.asarray(b_temp, dtype=float).reshape(1)
    else:
        A = A_temp
        b = b_temp

    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.CONTINUOUS, name='x')

    m.setObjective(x.prod(c), GRB.MINIMIZE)
    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= h[j])

    m.optimize()
    sol = np.zeros(decision_num)
    try:
        sol[:] = [x[i].X for i in range(decision_num)]
    except (gp.GurobiError, AttributeError):
        # No solution values are available: emit diagnostics.
        print("cannot solve")
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            # An IIS only exists for infeasible models; keep dumping the
            # inputs instead of crashing inside the error handler.
            print("IIS unavailable:", err)
        np.savetxt('c.txt', c, fmt="%.2f")
        np.savetxt('A.txt', A, fmt="%.2f")
        np.savetxt('b.txt', b, fmt="%.2f")
        np.savetxt('G.txt', G, fmt="%.2f")
        np.savetxt('h.txt', h, fmt="%.2f")
        # Deliberate pause so the dumped files can be inspected.
        input("Continue?")

    return sol


#original_c, original_A, original_b, original_G, original_h = lp_parser("original_milp.mps")
#original_c = np.array(original_c)
#LP_sol = solve_LP_inputMatrix(original_c, original_A, original_b, original_G, original_h)
#print(LP_sol)
#output_c, output_A, output_b, output_G, output_h = lp_parser("cutting.mps")
#output_c = np.array(output_c)
#IP_sol = solve_LP_inputMatrix(output_c, output_A, output_b, output_G, output_h)
#print(IP_sol)
