import os
import glob
import numpy as np
import gurobipy as gp
from gurobipy import GRB
import linear_relax as LP_relax_file

# Problem dimensions are owned by the LP-relaxation module; mirror them here so
# the model builders in this file can use them as plain module globals.
facility_num, ERU_num, var_num = (
    LP_relax_file.facility_num,
    LP_relax_file.ERU_num,
    LP_relax_file.var_num,
)


def gen_original_milp(cost, avail_matrices, pred_coverage, mps_path="original_milp.mps"):
    """Build the first-stage covering MILP, write it to an MPS file and return it.

    Model::

        min  sum_i cost[i] * x[i]
        s.t. sum_i avail_matrices[i*facility_num + j] * x[i] >= pred_coverage[j]
                                                   for j in range(facility_num)
             x[i] integer, i in range(ERU_num)   (Gurobi default lower bound 0)

    Parameters
    ----------
    cost : array_like
        Objective coefficients, one per ERU variable (presumably length
        ``ERU_num`` -- TODO confirm against callers).
        Converted with ``np.asarray(...).tolist()``.
    avail_matrices : array_like
        Flattened availability coefficients; entry ``i*facility_num + j``
        multiplies ``x[i]`` in the constraint of facility ``j``.
    pred_coverage : sequence
        Right-hand side per facility, indexed ``0..facility_num-1``.
    mps_path : str, optional
        Where the model is written.  Defaults to the historical
        ``"original_milp.mps"`` in the current working directory.

    Returns
    -------
    gurobipy.Model
        The built model; it is written to disk but NOT optimized.
    """
    cost = np.asarray(cost).tolist()
    avail = np.asarray(avail_matrices).tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(ERU_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(cost), GRB.MINIMIZE)
    for j in range(facility_num):
        m.addConstr(
            gp.quicksum(avail[i * facility_num + j] * x[i] for i in range(ERU_num))
            >= pred_coverage[j]
        )

    m.write(mps_path)
    return m


def gen_stage2_milp(s1_x_sol, cost, avail_matrices, true_coverage, penaltyTerm,
                    mps_path="stage2_milp.mps"):
    """Build the second-stage (recourse) MILP, write it to an MPS file and return it.

    With ``pen_cost = penaltyTerm * cost`` (elementwise if ``penaltyTerm`` is
    an array) the model is::

        min  cost^T y + pen_cost^T y - pen_cost^T s1_x_sol
        s.t. sum_i avail_matrices[i*facility_num + j] * y[i] >= true_coverage[j]
                                                   for j in range(facility_num)
             y[i] >= s1_x_sol[i]                   for i in range(ERU_num)
             y[i] integer                          (Gurobi default lower bound 0)

    i.e. units already chosen in stage 1 are kept, and every extra unit
    ``y - s1_x_sol`` is charged its cost plus the penalty.

    Parameters
    ----------
    s1_x_sol : array_like
        First-stage solution, indexed ``0..ERU_num-1``.
    cost : array_like
        Per-ERU cost vector.  Coerced to an ndarray so that
        ``penaltyTerm * cost`` is arithmetic scaling even if a list is passed.
    avail_matrices : array_like
        Flattened availability coefficients (entry ``i*facility_num + j``).
    true_coverage : sequence
        Realised right-hand side per facility.
    penaltyTerm : float or array_like
        Multiplier applied to ``cost`` for the penalty term.
    mps_path : str, optional
        Where the model is written.  Defaults to the historical
        ``"stage2_milp.mps"``.

    Returns
    -------
    gurobipy.Model
        The built model; it is written to disk but NOT optimized.
    """
    cost = np.asarray(cost)
    pen_cost = penaltyTerm * cost
    # Constant offset so the objective only penalises units beyond stage 1.
    fixed_pen_cost = float(np.dot(s1_x_sol, pen_cost))
    pen_cost = pen_cost.tolist()
    cost = cost.tolist()
    avail = np.asarray(avail_matrices).tolist()

    m2 = gp.Model()
    m2.setParam('OutputFlag', 0)
    y = m2.addVars(ERU_num, vtype=GRB.INTEGER, name='y')    # x_(2)

    m2.setObjective(y.prod(cost) + y.prod(pen_cost) - fixed_pen_cost, GRB.MINIMIZE)

    for j in range(facility_num):
        m2.addConstr(
            gp.quicksum(avail[i * facility_num + j] * y[i] for i in range(ERU_num))
            >= true_coverage[j]
        )
    # Stage 2 may only add units on top of the stage-1 decision.
    for i in range(ERU_num):
        m2.addConstr(y[i] >= s1_x_sol[i])

    m2.write(mps_path)
    return m2


def lp_parser(filename):
    """Read a model file and return it as ``min c^T x  s.t.  A x = b,  G x <= h``.

    Constraint rows are split by sense, in model order:

    * ``=`` rows go to ``(A, b)``;
    * ``>=`` rows are negated (``a x >= r``  ->  ``-a x <= -r``) and go to ``(G, h)``;
    * ``<=`` rows go to ``(G, h)`` unchanged (previously they were silently
      dropped, which enlarged the feasible region).

    Variable bounds are then appended to ``(G, h)``:

    * every strictly positive lower bound adds ``-x_i <= -lb_i``;
    * every upper bound below 1e6 adds ``x_i <= ub_i`` (larger values,
      including ``GRB.INFINITY``, are treated as unbounded).

    Lower bounds <= 0 are not encoded; callers are expected to impose
    ``x >= 0`` themselves.

    Parameters
    ----------
    filename : str
        Path to any file ``gurobipy.read`` accepts (e.g. ``.mps``).

    Returns
    -------
    c : list of float
        Objective coefficients, one per variable.
    A : np.ndarray, shape (eq_rows, n)
    b : np.ndarray, shape (eq_rows,)
    G : np.ndarray, shape (ineq_rows + bound_rows, n)
    h : np.ndarray, shape (ineq_rows + bound_rows,)
    """
    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        m = gp.read(filename, env)

        # Fetch the variable / constraint lists once: calling m.getVars() and
        # m.getConstrs() inside the coefficient loop rebuilt them per entry.
        variables = m.getVars()
        constrs = m.getConstrs()
        c = m.getAttr("Obj", variables)
        n_vars = len(c)
        senses = m.getAttr("Sense", constrs)
        rhs = m.getAttr("RHS", constrs)

        eq_rows, eq_rhs = [], []
        ineq_rows, ineq_rhs = [], []
        for constr, sense, rhs_val in zip(constrs, senses, rhs):
            row = [m.getCoeff(constr, var) for var in variables]
            if sense == '=':
                eq_rows.append(row)
                eq_rhs.append(rhs_val)
            elif sense == '>':
                ineq_rows.append([-coef for coef in row])
                ineq_rhs.append(-rhs_val)
            else:  # '<'
                ineq_rows.append(row)
                ineq_rhs.append(rhs_val)

        lb = np.array(m.getAttr("LB", variables))
        ub = np.array(m.getAttr("UB", variables))

    A = np.array(eq_rows, dtype=float).reshape(len(eq_rows), n_vars)
    b = np.array(eq_rhs, dtype=float)
    G = np.array(ineq_rows, dtype=float).reshape(len(ineq_rows), n_vars)
    h = np.array(ineq_rhs, dtype=float)

    # Only strictly positive lower bounds get a row (previously count_nonzero
    # also counted negative bounds, leaving vacuous all-zero rows).
    lb_idx = np.flatnonzero(lb > 0)
    if lb_idx.size > 0:
        lb_G = np.zeros((lb_idx.size, n_vars))
        lb_G[np.arange(lb_idx.size), lb_idx] = -1
        G = np.concatenate([G, lb_G], axis=0)
        h = np.concatenate([h, -lb[lb_idx]], axis=0)

    ub_idx = np.flatnonzero(ub < 1000000)
    if ub_idx.size > 0:
        ub_G = np.zeros((ub_idx.size, n_vars))
        ub_G[np.arange(ub_idx.size), ub_idx] = 1
        G = np.concatenate([G, ub_G], axis=0)
        h = np.concatenate([h, ub[ub_idx]], axis=0)

    return c, A, b, G, h


def lp_parser_and_count(filename, decision_num=None):
    """Like ``lp_parser`` but also report row counts and fixed decision variables.

    The model is returned as ``min c^T x  s.t.  A x = b,  G x <= h``:

    * ``=`` rows go to ``(A, b)``;
    * ``<=`` rows go to ``(G, h)`` unchanged;
    * ``>=`` rows are negated and go to ``(G, h)`` (previously they were
      silently dropped).

    After the constraint rows, every strictly positive lower bound adds
    ``-x_i <= -lb_i`` and every upper bound below 1e6 adds ``x_i <= ub_i``.
    Lower bounds <= 0 are not encoded.

    Parameters
    ----------
    filename : str
        Path to any file ``gurobipy.read`` accepts.
    decision_num : int, optional
        How many leading variables count as decision variables for
        ``FX_num``.  When omitted, a module-level ``NRP_decision_num`` is used
        if one has been set; otherwise ``ERU_num``.  The original code
        referenced the never-defined ``NRP_decision_num`` and always raised
        NameError.

    Returns
    -------
    c : list of float
    A : np.ndarray, shape (eq_num, n)
    b : np.ndarray, shape (eq_num,)
    G : np.ndarray, shape (ineq_num + bound_rows, n)
    h : np.ndarray, shape (ineq_num + bound_rows,)
    eq_num : int
        Number of equality constraints.
    ineq_num : int
        Number of inequality constraints (``<`` and ``>``), i.e. the leading
        rows of ``G`` before the bound rows.
    FX_num : int
        Among the first ``decision_num`` variables, how many have ``lb == ub``.
    """
    if decision_num is None:
        decision_num = globals().get("NRP_decision_num", ERU_num)

    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        m = gp.read(filename, env)

        # Hoisted: the original re-fetched these lists for every coefficient.
        variables = m.getVars()
        constrs = m.getConstrs()
        c = m.getAttr("Obj", variables)
        n_vars = len(c)
        senses = m.getAttr("Sense", constrs)
        rhs = m.getAttr("RHS", constrs)

        eq_rows, eq_rhs = [], []
        ineq_rows, ineq_rhs = [], []
        for constr, sense, rhs_val in zip(constrs, senses, rhs):
            row = [m.getCoeff(constr, var) for var in variables]
            if sense == '=':
                eq_rows.append(row)
                eq_rhs.append(rhs_val)
            elif sense == '<':
                ineq_rows.append(row)
                ineq_rhs.append(rhs_val)
            else:  # '>':  a x >= r  <=>  -a x <= -r
                ineq_rows.append([-coef for coef in row])
                ineq_rhs.append(-rhs_val)

        lb = np.array(m.getAttr("LB", variables))
        ub = np.array(m.getAttr("UB", variables))

    eq_num = len(eq_rows)
    ineq_num = len(ineq_rows)
    A = np.array(eq_rows, dtype=float).reshape(eq_num, n_vars)
    b = np.array(eq_rhs, dtype=float)
    G = np.array(ineq_rows, dtype=float).reshape(ineq_num, n_vars)
    h = np.array(ineq_rhs, dtype=float)

    # Fixed variables among the leading decision variables.
    num = max(0, min(decision_num, n_vars))
    FX_num = int(np.count_nonzero(lb[:num] == ub[:num]))

    # Only strictly positive lower bounds get a row (count_nonzero used to
    # count negative bounds too, leaving vacuous all-zero rows).
    lb_idx = np.flatnonzero(lb > 0)
    if lb_idx.size > 0:
        lb_G = np.zeros((lb_idx.size, n_vars))
        lb_G[np.arange(lb_idx.size), lb_idx] = -1
        G = np.concatenate([G, lb_G], axis=0)
        h = np.concatenate([h, -lb[lb_idx]], axis=0)

    ub_idx = np.flatnonzero(ub < 1000000)
    if ub_idx.size > 0:
        ub_G = np.zeros((ub_idx.size, n_vars))
        ub_G[np.arange(ub_idx.size), ub_idx] = 1
        G = np.concatenate([G, ub_G], axis=0)
        h = np.concatenate([h, ub[ub_idx]], axis=0)

    return c, A, b, G, h, eq_num, ineq_num, FX_num


#def count_FX(filename):
#    with gp.Env(empty=True) as env:
#        env.setParam('OutputFlag', 0)
#        env.start()
#    #    m = gp.Model()
#    #    m.setParam('OutputFlag', 0)
#        m = gp.read(filename, env)
##        m.optimize()
##        objective = m.getObjective()
##        print("objective: ", objective.getValue())
##        # directly get c
#        c = m.getAttr("Obj",m.getVars())
#        #print(c)  # c
#        var_num = len(c)
#        
#        # record lower and upper bounds
#        lb = m.getAttr("LB",m.getVars())
#        ub = m.getAttr("UB",m.getVars())
##        print(lb)
##        print(ub)
#        #zero_num = lb.count(0)
#        #inf_num = ub.count(inf)
#        #print(zero_num, inf_num)
#        lb = np.array(lb)
#        ub = np.array(ub)
#        FX_num = 0
#        num = min(item_num, var_num)
#        for i in range(num):
#            if lb[i] == ub[i]:
#                FX_num = FX_num + 1
#    
#    return FX_num


#input_c = np.loadtxt("c.txt")
#input_A = np.loadtxt("A.txt")
#input_b = np.loadtxt("b.txt")
#
#model = gen_original_milp(input_c,input_A,input_b,G=None,h=None)
#
#get_cutting_exe = r'./wrtnode'
#get_cutting_para = r'original_milp.mps'
#r_v = os.system(get_cutting_exe+' '+get_cutting_para)
#
#output_c, output_A, output_b, output_G, output_h = lp_parser("cutting.mps")
##np.savetxt('output_c.txt', output_c, fmt="%.2f")
##np.savetxt('output_A.txt', output_A, fmt="%.2f")
##np.savetxt('output_b.txt', output_b, fmt="%.2f")
##np.savetxt('output_G.txt', output_G, fmt="%.2f")
##np.savetxt('output_h.txt', output_h, fmt="%.2f")
#output_c = np.array(output_c)
#print(output_c.shape, output_A.shape, output_b.shape, output_G.shape, output_h.shape)
##FX_num = count_FX("cutting.mps")
##print(FX_num)

def solve_LP(filename):
    """Read a model file, optimize it with Gurobi and return the objective value.

    Parameters
    ----------
    filename : str
        Path to any file ``gurobipy.read`` accepts.

    Returns
    -------
    float
        ``ObjVal`` of the best solution found, or ``0`` when Gurobi has no
        solution.  When Gurobi reports the model infeasible (or infeasible-or-
        unbounded), an IIS is attempted and written to ``model.ilp`` for
        debugging.  That step is best-effort: it is skipped if Gurobi cannot
        compute an IIS, e.g. when the model is really unbounded.  The original
        called ``computeIIS`` unconditionally, which crashed in that case.
    """
    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        m = gp.read(filename, env)
        m.optimize()

        if m.SolCount > 0:
            objective = m.ObjVal
        else:
            objective = 0
            if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
                try:
                    m.computeIIS()
                    m.write('model.ilp')
                except gp.GurobiError:
                    pass  # diagnostics only; no IIS exists for a feasible model

    return objective
    

def solve_LP_inputMatrix(c, A_temp, b_temp, G, h):
    """Solve the continuous LP  ``min c^T x  s.t.  A x = b,  G x <= h,  x >= 0``.

    ``x >= 0`` comes from Gurobi's default lower bound for ``addVars``.

    Parameters
    ----------
    c : array_like, shape (n,)
        Objective coefficients (flattened).
    A_temp : array_like or None
        Equality matrix, shape (k, n).  A 1-D array is treated as a single
        row; ``None`` or an empty array means "no equality constraints"
        (``None`` used to crash here, unlike in ``solve_IP_inputMatrix``).
    b_temp : array_like, scalar or None
        Equality right-hand side, flattened to one value per row of A.
    G : array_like or None
        Inequality matrix, shape (p, n); a 1-D array is a single row.
    h : array_like or None
        Inequality right-hand side, flattened.

    Returns
    -------
    sol : np.ndarray, shape (n,)
        Solution values, or all zeros if Gurobi found no solution.
    objective : float
        ``c^T sol``, or ``0`` if no solution was found.

    When no solution is found, the inputs are dumped to ``c.txt``, ``A.txt``,
    ``b.txt``, ``G.txt`` and ``h.txt``.  If Gurobi reports infeasibility, an
    IIS is additionally written to ``model.ilp`` (best-effort).
    """
    def as_rows(mat):
        """Return *mat* as a list of rows: None/empty -> [], 1-D -> one row."""
        if mat is None or np.size(mat) == 0:
            return []
        return np.atleast_2d(np.asarray(mat, dtype=float)).tolist()

    def as_vector(vec):
        """Return *vec* flattened to a list; None -> []."""
        if vec is None:
            return []
        return np.asarray(vec, dtype=float).ravel().tolist()

    c_list = as_vector(c)
    decision_num = len(c_list)
    A_rows, b_vals = as_rows(A_temp), as_vector(b_temp)
    G_rows, h_vals = as_rows(G), as_vector(h)

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.CONTINUOUS, name='x')

    m.setObjective(x.prod(c_list), GRB.MINIMIZE)
    for i, row in enumerate(A_rows):
        m.addConstr(x.prod(row) == b_vals[i])
    for j, row in enumerate(G_rows):
        m.addConstr(x.prod(row) <= h_vals[j])

    m.optimize()
    sol = np.zeros(decision_num)
    if m.SolCount > 0:
        for i in range(decision_num):
            sol[i] = x[i].X
        return sol, m.getObjective().getValue()

    # No solution: leave diagnostics behind for offline debugging.
    if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError:
            pass  # IIS is best-effort (e.g. the model is actually unbounded)
    np.savetxt('c.txt', c_list, fmt="%.2f")
    np.savetxt('A.txt', A_rows, fmt="%.2f")
    np.savetxt('b.txt', b_vals, fmt="%.2f")
    np.savetxt('G.txt', G_rows, fmt="%.2f")
    np.savetxt('h.txt', h_vals, fmt="%.2f")
    return sol, 0

def print_schedule(sol_file, nurse_num=None, day_shift_num=None, shift_num=None,
                   change_offset=308):
    """Pretty-print a nurse-rostering solution vector stored in a text file.

    The file is loaded with ``np.loadtxt``.  The output has four parts:

    1. Assignments ``0 .. nurse_num*day_shift_num - 1``, one line per nurse.
       Each entry equal to 1 prints its shift code from ``i % shift_num``:
       ``M`` (0), ``E`` (1), ``N`` (2) or ``-`` (anything else).
    2. The raw values of entries ``nurse_num*day_shift_num .. change_offset-1``.
    3. If the vector is longer than ``change_offset``, the shape of the tail
       ``sol[change_offset:]``.
    4. The first ``nurse_num*day_shift_num`` tail values, laid out per nurse.

    Parameters
    ----------
    sol_file : str
        Path of the solution file; it is echoed first.
    nurse_num, day_shift_num, shift_num : int, optional
        Roster dimensions.  The original read them from module globals that
        this file never defines, so every call raised NameError.  They are
        now parameters.  When omitted, a module-level value of the same name
        is used if one has been set.
    change_offset : int, optional
        Index where the "change" tail of the vector starts.  Defaults to the
        previously hard-coded 308.

    Raises
    ------
    TypeError
        If a roster dimension is neither passed nor available as a module
        global.
    """
    module_globals = globals()
    if nurse_num is None:
        nurse_num = module_globals.get('nurse_num')
    if day_shift_num is None:
        day_shift_num = module_globals.get('day_shift_num')
    if shift_num is None:
        shift_num = module_globals.get('shift_num')
    missing = [name for name, value in (('nurse_num', nurse_num),
                                        ('day_shift_num', day_shift_num),
                                        ('shift_num', shift_num))
               if value is None]
    if missing:
        raise TypeError("print_schedule() requires " + ", ".join(missing))

    print(sol_file)
    sol = np.loadtxt(sol_file)
    assign_num = nurse_num * day_shift_num
    shift_codes = {0: "M", 1: "E", 2: "N"}

    for i in range(assign_num):
        if i != 0 and i % day_shift_num == 0:
            print("")  # next nurse on a new line
        if sol[i] == 1:
            print(shift_codes.get(i % shift_num, "-"), end=" ")
    print("\n")

    for i in range(assign_num, change_offset):
        print(sol[i])

    if sol.shape[0] > change_offset:
        change = sol[change_offset:]
        print(change.shape)
        for i in range(assign_num):
            if i != 0 and i % day_shift_num == 0:
                print("")
            print(change[i], end=" ")
        print("\n")


def solve_IP_inputMatrix(c, A_temp, b_temp, G, h):
    """Solve the integer program  ``min c^T x  s.t.  A x = b,  G x <= h``, x integer.

    Variables use Gurobi's default lower bound of 0 for ``addVars``.

    Parameters
    ----------
    c : array_like, shape (n,)
        Objective coefficients (flattened).
    A_temp : array_like or None
        Equality matrix, shape (k, n).  A 1-D array is treated as a single
        row; ``None`` or an empty array means "no equality constraints".
    b_temp : array_like, scalar or None
        Equality right-hand side, flattened to one value per row of A.
    G : array_like or None
        Inequality matrix, shape (p, n); a 1-D array is a single row.
    h : array_like or None
        Inequality right-hand side, flattened.

    Returns
    -------
    sol : np.ndarray, shape (n,)
        Solution values, or all zeros if Gurobi found no solution.
    objective : float
        ``c^T sol``, or ``0`` if no solution was found.

    When no solution is found, the inputs are dumped to ``c.txt``, ``A.txt``,
    ``b.txt``, ``G.txt`` and ``h.txt``.  Previously this dump raised
    ValueError whenever ``A_temp``/``b_temp`` were ``None``.  If Gurobi
    reports infeasibility, an IIS is additionally written to ``model.ilp``
    (best-effort).
    """
    def as_rows(mat):
        """Return *mat* as a list of rows: None/empty -> [], 1-D -> one row."""
        if mat is None or np.size(mat) == 0:
            return []
        return np.atleast_2d(np.asarray(mat, dtype=float)).tolist()

    def as_vector(vec):
        """Return *vec* flattened to a list; None -> []."""
        if vec is None:
            return []
        return np.asarray(vec, dtype=float).ravel().tolist()

    c_list = as_vector(c)
    decision_num = len(c_list)
    A_rows, b_vals = as_rows(A_temp), as_vector(b_temp)
    G_rows, h_vals = as_rows(G), as_vector(h)

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(c_list), GRB.MINIMIZE)
    for i, row in enumerate(A_rows):
        m.addConstr(x.prod(row) == b_vals[i])
    for j, row in enumerate(G_rows):
        m.addConstr(x.prod(row) <= h_vals[j])

    m.optimize()
    sol = np.zeros(decision_num)
    if m.SolCount > 0:
        for i in range(decision_num):
            sol[i] = x[i].X
        return sol, m.getObjective().getValue()

    # No solution: leave diagnostics behind for offline debugging.
    if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError:
            pass  # IIS is best-effort (e.g. the model is actually unbounded)
    np.savetxt('c.txt', c_list, fmt="%.2f")
    np.savetxt('A.txt', A_rows, fmt="%.2f")
    np.savetxt('b.txt', b_vals, fmt="%.2f")
    np.savetxt('G.txt', G_rows, fmt="%.2f")
    np.savetxt('h.txt', h_vals, fmt="%.2f")
    return sol, 0



#print_schedule('cplex_sol.txt')
#
#original_c, original_A, original_b, original_G, original_h = lp_parser("cutting.mps")
#original_c = np.array(original_c)
#
##input_c = np.loadtxt("c.txt")
##input_A = None
##input_b = None
##input_G = np.loadtxt("G.txt")
##input_h = np.loadtxt("h.txt")
#IP_sol, objective = solve_IP_inputMatrix(original_c, original_A, original_b, original_G, original_h)
#np.savetxt('cutting.txt', IP_sol, fmt="%.2f")
#print(IP_sol.shape, objective)
##print_schedule("cutting.txt")
###print(original_c)


##output_c, output_A, output_b, output_G, output_h = lp_parser("./cut/nodelp14_0.mps")
#output_c, output_A, output_b, output_G, output_h = lp_parser("stage2_milp.mps")
#output_c = np.array(output_c)
##np.savetxt('output_c.txt', output_c, fmt="%.2f")
##print(output_c.shape)
#output_c = np.loadtxt("stage1_c.txt")
#output_A = np.loadtxt("stage1_A.txt")
#output_b = np.loadtxt("stage1_b.txt")
#output_G = np.loadtxt("stage1_G.txt")
#output_h = np.loadtxt("stage1_h.txt")
#IP_sol, objective = solve_LP_inputMatrix(output_c, output_A, output_b, output_G, output_h)
#print(objective)
#np.savetxt('IP_sol.txt', IP_sol, fmt="%.2f")
###print(IP_sol)

#objective = solve_LP("stage2_milp.mps")
#print(objective)
#np.savetxt('IP_sol.txt', IP_sol, fmt="%.2f")
#objective = solve_LP("cutting.mps")
#print(objective)

#for i in range(51):
#    objective = solve_LP("./cut/nodelp" + str(i) + "_0.mps")
#    print(i, "Gurobi solves: ", objective, end=" ")
#    output_c, output_A, output_b, output_G, output_h = lp_parser("./cut/nodelp" + str(i) + "_0.mps")
#    output_c = np.array(output_c)
#    IP_sol, objective = solve_LP_inputMatrix(output_c, output_A, output_b, output_G, output_h)
#    print("solve_LP_inputMatrix: ", objective, end=" ")
#    print(output_A.shape, output_G.shape)

## Path to the directory where
## the files reside
#path = r"/Users/huxinyi/Desktop/NSP/cutting/cut/"
#
## Getting the list of files/directories
## in the specified path Filtering the
## list to exclude the directory names
#files = list(filter(os.path.isfile, glob.glob(path + "*")))
#
## Sorting file list based on the
## creation time of the files
#files.sort(key=os.path.getctime)
#
#for filename in files:
##    print(filename, end=" ")
#    objective = solve_LP(filename)
#    print(objective)
##    print("Gurobi solves: ", objective)
##    output_c, output_A, output_b, output_G, output_h = lp_parser(filename)
##    output_c = np.array(output_c)
###    IP_sol, objective = solve_LP_inputMatrix(output_c, output_A, output_b, output_G, output_h)
###    print("solve_LP_inputMatrix: ", objective)
##    print(output_A.shape, output_G.shape)
##    np.savetxt('output_h.txt', output_h, fmt="%.2f")
##    input("Continue?")
#
###print(files)
