import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# Problem dimensions and model parameters read by the functions below.
# Names assigned at module level are already global, so no ``global``
# statement is needed here; functions that rebind one of them
# (``set_extra_payment``) declare it themselves.
nurse_num = 10  # nurses on the roster
day_num = 7  # days in the planning horizon
shift_num = 4  # shifts per day; the last one (shift_num - 1) is the day-off shift
day_shift_num = day_num * shift_num  # (day, shift) slots in the horizon
day_work_shift_num = day_num * (shift_num - 1)  # working (non day-off) slots
# Decision vector: one x entry per (nurse, day, shift), then one sigma
# (extra staff) entry per (day, shift) slot.
decision_num = nurse_num * day_num * shift_num + day_shift_num
penaltyTerm = 0.01  # NOTE(review): not referenced in the functions shown here
extra_serve_patient_num = 1  # patients covered by one extra (sigma) staff unit
extra_payment = 0  # cost per extra staff unit; change with set_extra_payment
# Re-planning vector: decision_num entries followed by one gamma (change
# indicator) entry per x entry.
t_decision_num = decision_num + nurse_num * day_num * shift_num
minimum_relax_day = 1  # minimum day-off shifts per nurse in the horizon
maximum_relax_day = 2  # maximum day-off shifts per nurse in the horizon


def mkdir(default_path, folder_name):
    """Ensure the directory ``default_path/folder_name`` exists.

    Intermediate directories are created as needed. If the path already
    exists (as a directory or any other filesystem entry), nothing is done.

    Returns:
        The joined path.
    """
    path = os.path.join(default_path, folder_name)
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the existence check and this call.
        os.makedirs(path, exist_ok=True)
    return path


def set_extra_payment(extra_payment_num):
    """Set the module-wide cost of one extra (sigma) staff unit.

    ``gen_obj``, ``get_t_updated_plan``, ``correction_single_obj`` and
    ``check_intOpt`` read ``extra_payment`` when they are called, so the new
    value applies to everything computed afterwards.
    """
    global extra_payment
    extra_payment = extra_payment_num


def gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num):
    """Build the demand-independent constraints of the weekly rostering IP.

    Column layout of the decision vector (length ``decision_num``): one entry
    per (nurse, day, shift) at ``nurse * day_shift_num + day * shift_num +
    shift``, then one extra-staff entry per (day, shift) slot. Shift
    ``shift_num - 1`` of each day is the day-off shift. Shift
    ``shift_num - 2`` is the night shift.

    Args:
        nurse_num, day_num, shift_num: problem dimensions.
        serve_patient_num: patients nurse ``i`` serves in one shift,
            indexed by nurse.
        decision_num: number of columns of every returned matrix.
        day_shift_num: number of (day, shift) slots (day_num * shift_num).

    Returns:
        ``(A, b, G, h2, h3, h4, h5)``: the model is ``A @ x == b`` and
        ``G @ x <= concat([h1, h2, h3, h4, h5])``. ``G`` stacks, in order:
        - the demand rows, whose RHS ``h1`` callers build from patient numbers;
        - the night-then-morning rows;
        - the minimum and maximum day-off rows;
        - the ``x <= 1`` rows.
    """
    assign_num = nurse_num * day_num * shift_num
    rest_rows = nurse_num * (day_num - 1)

    # Equality rows: exactly one shift (possibly the day-off one) per nurse per day.
    A = np.zeros((nurse_num * day_num, decision_num))
    for row in range(nurse_num * day_num):
        nurse, day = divmod(row, day_num)
        first_col = nurse * day_shift_num + day * shift_num
        A[row, first_col + np.arange(shift_num)] = 1
    b = np.ones(nurse_num * day_num)

    # Demand rows in "<=" form: -(nurse coverage) - (extra-staff coverage) <= h1.
    G1 = np.zeros((day_shift_num, decision_num))
    for nurse in range(nurse_num):
        capacity = -serve_patient_num[nurse]
        for slot in range(day_shift_num):
            G1[slot, nurse * day_shift_num + slot] = capacity
    if nurse_num > 0:
        # Extra-staff columns (only written when at least one nurse exists).
        for slot in range(day_shift_num):
            G1[slot, assign_num + slot] = -extra_serve_patient_num

    # A night shift and the next day's morning shift are mutually exclusive.
    G2 = np.zeros((rest_rows, decision_num))
    for row in range(rest_rows):
        nurse, day = divmod(row, day_num - 1)
        night_col = nurse * day_shift_num + day * shift_num + shift_num - 2
        G2[row, night_col] = 1
        G2[row, night_col + 2] = 1  # morning shift of the following day
    h2 = np.ones(rest_rows)

    # Day-off count per nurse: -count <= -minimum_relax_day, count <= maximum_relax_day.
    G3 = np.zeros((nurse_num, decision_num))
    G4 = np.zeros((nurse_num, decision_num))
    day_off_offsets = shift_num * np.arange(1, day_num + 1) - 1
    for nurse in range(nurse_num):
        cols = nurse * day_shift_num + day_off_offsets
        G3[nurse, cols] = -1
        G4[nurse, cols] = 1
    h3 = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h4 = np.full(nurse_num, maximum_relax_day, dtype=int)

    # Upper bound x <= 1 on every assignment entry.
    G5 = np.zeros((assign_num, decision_num))
    diag = np.arange(assign_num)
    G5[diag, diag] = 1
    h5 = np.ones(assign_num)

    G = np.concatenate([G1, G2, G3, G4, G5], axis=0)
    return A, b, G, h2, h3, h4, h5



def gen_t_matrix(t, prev_sol, real_h1, pre_h1, serve_patient_num):
    """Build the matrix form of the day-``t`` re-planning IP.

    Column layout (length ``t_decision_num``):
    - nurse assignments ``x``: one per (nurse, slot), ``nurse * day_shift_num + slot``;
    - extra staff ``sigma``: one per slot;
    - change indicators ``gamma``: one per ``x`` entry, starting at column ``decision_num``.

    Args:
        t: re-planning day; the first ``t`` days use ``real_h1`` and have
            their assignments frozen. Extra staff is frozen for the first ``t - 1`` days.
        prev_sol: previous plan, ``x`` entries followed by ``sigma`` entries.
        real_h1, pre_h1: demand right-hand sides per slot for realised and
            predicted demand. ``check_intOpt`` passes the negated patient counts.
        serve_patient_num: patients nurse ``i`` serves per shift.

    Returns:
        ``(A, b, G, h)`` with the model ``A @ z == b`` and ``G @ z <= h``.
        ``G`` stacks, in order:
        - demand rows;
        - night-then-morning rows;
        - min/max day-off rows;
        - ``x <= 1`` rows;
        - ``x - gamma <= prev_x`` rows;
        - ``gamma <= 1`` rows.
    """
    width = t_decision_num
    x_num = nurse_num * day_shift_num
    rest_rows = nurse_num * (day_num - 1)
    split = t * shift_num
    x_diag = np.arange(x_num)

    # Equality rows: exactly one shift per nurse per day.
    A_day = np.zeros((nurse_num * day_num, width))
    for row in range(nurse_num * day_num):
        nurse, day = divmod(row, day_num)
        first_col = nurse * day_shift_num + day * shift_num
        A_day[row, first_col + np.arange(shift_num)] = 1
    b_day = np.ones(nurse_num * day_num)

    # Demand rows; realised demand for the first t days, predicted afterwards.
    G_demand = np.zeros((day_shift_num, width))
    for nurse in range(nurse_num):
        capacity = -serve_patient_num[nurse]
        for slot in range(day_shift_num):
            G_demand[slot, nurse * day_shift_num + slot] = capacity
    if nurse_num > 0:
        # Extra-staff columns (only written when at least one nurse exists).
        for slot in range(day_shift_num):
            G_demand[slot, x_num + slot] = -extra_serve_patient_num
    h_demand = np.zeros(day_shift_num)
    for slot in range(day_shift_num):
        h_demand[slot] = real_h1[slot] if slot < split else pre_h1[slot]

    # A night shift and the next day's morning shift are mutually exclusive.
    G_rest = np.zeros((rest_rows, width))
    for row in range(rest_rows):
        nurse, day = divmod(row, day_num - 1)
        night_col = nurse * day_shift_num + day * shift_num + shift_num - 2
        G_rest[row, night_col] = 1
        G_rest[row, night_col + 2] = 1  # morning shift of the following day
    h_rest = np.ones(rest_rows)

    # Day-off count per nurse between minimum_relax_day and maximum_relax_day.
    G_off_min = np.zeros((nurse_num, width))
    G_off_max = np.zeros((nurse_num, width))
    day_off_offsets = shift_num * np.arange(1, day_num + 1) - 1
    for nurse in range(nurse_num):
        cols = nurse * day_shift_num + day_off_offsets
        G_off_min[nurse, cols] = -1
        G_off_max[nurse, cols] = 1
    h_off_min = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h_off_max = np.full(nurse_num, maximum_relax_day, dtype=int)

    # x <= 1.
    G_x_ub = np.zeros((x_num, width))
    G_x_ub[x_diag, x_diag] = 1
    h_x_ub = np.ones(x_num)

    # x - gamma <= previous x, so gamma flags every newly added assignment.
    G_change = np.zeros((x_num, width))
    G_change[x_diag, x_diag] = 1
    G_change[x_diag, decision_num + x_diag] = -1
    h_change = prev_sol[:x_num]

    # gamma <= 1.
    G_gamma_ub = np.zeros((x_num, width))
    G_gamma_ub[x_diag, decision_num + x_diag] = 1
    h_gamma_ub = np.ones(x_num)

    # Hard commitments: x of the first t days, sigma of the first t - 1 days.
    fixed_cols = [nurse * day_shift_num + day * shift_num + shift
                  for nurse in range(nurse_num)
                  for day in range(t)
                  for shift in range(shift_num)]
    fixed_cols += [x_num + day * shift_num + shift
                   for day in range(t - 1)
                   for shift in range(shift_num)]
    n_fixed = nurse_num * t * shift_num + (t - 1) * shift_num
    A_fixed = np.zeros((n_fixed, width))
    b_fixed = np.zeros(n_fixed)
    for row, col in enumerate(fixed_cols):
        A_fixed[row, col] = 1
        b_fixed[row] = prev_sol[col]

    G = np.concatenate([G_demand, G_rest, G_off_min, G_off_max, G_x_ub, G_change, G_gamma_ub], axis=0)
    h = np.concatenate([h_demand, h_rest, h_off_min, h_off_max, h_x_ub, h_change, h_gamma_ub], axis=0)
    A = np.concatenate([A_day, A_fixed], axis=0)
    b = np.concatenate([b_day, b_fixed], axis=0)
    return A, b, G, h
    

def gen_obj(t, cost, penalty=None):
    if penalty is None:
        c = np.zeros(decision_num)
        for i in range(nurse_num):
            for j in range(day_shift_num):
                if j % shift_num != 3:
                    c[i*day_shift_num+j] = cost[i]
                elif j % shift_num == 3:
                    c[i*day_shift_num+j] = 0
        for i in range(nurse_num*day_shift_num, decision_num):
            c[i] = extra_payment
    else:
        c_for_x = np.zeros(nurse_num*day_shift_num)
        for i in range(nurse_num):
            for j in range(day_shift_num):
                if j % shift_num != 3:
                    c_for_x[i*day_shift_num+j] = cost[i]
                elif j % shift_num == 3:
                    c_for_x[i*day_shift_num+j] = 0
        
        c_for_sigma = np.zeros(day_shift_num)
        for i in range(day_shift_num):
            c_for_sigma[i] = extra_payment
        
        c_for_gamma = np.zeros(nurse_num*day_shift_num)
        for i in range(nurse_num):
            for j in range(day_num):
                for k in range(shift_num):
                    c_for_gamma[i*day_shift_num+j*shift_num+k] = (day_num - j + t) * penalty[i*day_shift_num+j*shift_num+k] * c_for_x[i*day_shift_num+j*shift_num+k]
                
        c = np.concatenate([c_for_x, c_for_sigma, c_for_gamma], axis=0)
        
    return c


def actual_obj(c, A, b, G, real_patient_num, h2, h3, h4, h5, n_instance):
    """Optimal objective under realised demand for each of ``n_instance`` instances.

    For instance ``num`` the demand of the working (non day-off) slots is read
    from ``real_patient_num[num * day_work_shift_num:]``. The problem
    ``min c @ x  s.t.  A x == b, G x <= h`` is then solved over non-negative
    integer ``x`` with Gurobi. Here ``h`` stacks the negated demand with
    ``h2``..``h5``.

    Args:
        c, A, b, G: objective and constraint matrices (``decision_num`` columns).
        real_patient_num: flat realised demand for all instances.
        h2, h3, h4, h5: right-hand sides returned by ``gen_matrix``.
        n_instance: number of instances to solve.

    Returns:
        np.ndarray of optimal objective values. An instance without a
        solution contributes 0.
    """
    relax_shift = shift_num - 1  # the day-off shift has no demand
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    obj_list = []
    for num in range(n_instance):
        h1 = np.zeros(day_shift_num)
        cnt = num * day_work_shift_num
        for i in range(day_shift_num):
            if i % shift_num != relax_shift:
                h1[i] = -real_patient_num[cnt]
                cnt += 1
        h = np.concatenate([h1, h2, h3, h4, h5], axis=0).tolist()

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')
        m.setObjective(x.prod(c), GRB.MINIMIZE)
        for i in range(rowSizeA):
            m.addConstr(x.prod(A[i]) == b[i])
        for j in range(rowSizeG):
            m.addConstr(x.prod(G[j]) <= h[j])
        m.optimize()

        # Infeasible / unsolved instances are reported as 0 rather than raising.
        obj_list.append(m.objVal if m.SolCount > 0 else 0)

    return np.array(obj_list)


def get_init_plan(c, A, b, G, real_patient_num, pre_patient_num, h2, h3, h4, h5):
    """Solve the initial schedule against predicted demand.

    Minimises ``c @ x`` subject to ``A x == b`` and ``G x <= h``. Here ``x`` is a
    non-negative integer vector of length ``decision_num``, and ``h`` stacks the
    negated predicted demand with ``h2``..``h5``.

    Args:
        c, A, b, G: objective and constraint matrices from ``gen_obj`` /
            ``gen_matrix``.
        real_patient_num: realised demand. Not used for the initial plan; the
            parameter is kept for interface compatibility.
        pre_patient_num: predicted demand, one entry per working (non day-off)
            slot.
        h2, h3, h4, h5: right-hand sides returned by ``gen_matrix``.

    Returns:
        ``(predSol, sigmaSol)``: the nurse assignments (length
        ``nurse_num * day_shift_num``) and extra-staff counts (length
        ``day_shift_num``), rounded to integers. Both are all-zero (and
        "cannot solve" is printed) when Gurobi finds no solution.
    """
    relax_shift = shift_num - 1  # the day-off shift has no demand
    x_num = nurse_num * day_shift_num
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    pre_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient_num[cnt]
            cnt += 1
    pre_h = np.concatenate([pre_h1, h2, h3, h4, h5], axis=0).tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')
    m.setObjective(x.prod(c), GRB.MINIMIZE)
    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= pre_h[j])
    m.optimize()

    predSol = np.zeros(x_num)
    sigmaSol = np.zeros(day_shift_num)
    if m.SolCount > 0:
        # Integer variables can come back as e.g. 0.9999999; round them so the
        # plan can be used as exact equality targets when re-planning.
        for i in range(x_num):
            predSol[i] = round(x[i].X)
        for i in range(x_num, decision_num):
            sigmaSol[i - x_num] = round(x[i].X)
    else:
        print("cannot solve")

    return predSol, sigmaSol


def get_t_updated_plan(t, pre_sol, pre_sigmaSol, c, A, b, G, real_patient_num, pre_patient_num, h2, h3, h4, h5, penalty):
    """Re-plan the schedule on day ``t`` with the demand of the first ``t`` days known.

    The model treats the time horizon as follows:
    - Demand rows of the first ``t`` days use ``real_patient_num``; later days use ``pre_patient_num``.
    - Assignments of the first ``t`` days are frozen to ``pre_sol``.
    - Extra staff of the first ``t - 1`` days is frozen to ``pre_sigmaSol``.
    Every assignment switched on relative to ``pre_sol`` (binary ``gamma``)
    costs ``(day_num - day + t) * penalty * cost``.

    Args:
        t: re-planning day.
        pre_sol, pre_sigmaSol: previous assignments and extra-staff counts.
        c, A, b, G: matrices from ``gen_obj`` / ``gen_matrix``. Only their
            assignment columns are used; extra staff is modelled separately.
        real_patient_num, pre_patient_num: realised / predicted demand, one
            entry per working (non day-off) slot.
        h2, h3, h4, h5: right-hand sides returned by ``gen_matrix``.
        penalty: per-assignment penalty factors, indexed like ``pre_sol``.

    Returns:
        ``(t_updated_sol, t_sigma_sol, t_incur_penalty)``: the assignments and
        extra-staff counts, both rounded, plus the change penalty incurred. If
        Gurobi finds no solution, all-zero plans and a zero penalty are
        returned.
    """
    relax_shift = shift_num - 1  # the day-off shift has no demand
    x_sol_size = nurse_num * day_shift_num
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    # Extra staff gets its own variables below, so keep only the assignment columns.
    c = c[:x_sol_size].tolist()
    A = A[:, :x_sol_size].tolist()
    b = b.tolist()
    G = G[:, :x_sol_size].tolist()

    pre_h1 = np.zeros(day_shift_num)
    real_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient_num[cnt]
            real_h1[i] = -real_patient_num[cnt]
            cnt += 1
    pre_h = np.concatenate([pre_h1, h2, h3, h4, h5], axis=0).tolist()
    real_h = real_h1.tolist()  # only demand rows ever read the realised RHS

    # Cost of switching on assignment idx; also used for the reported penalty.
    gamma_weight = {}
    for i in range(nurse_num):
        for j in range(day_num):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                gamma_weight[idx] = (day_num - j + t) * penalty[idx] * c[idx]

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(x_sol_size, vtype=GRB.BINARY, name='x')
    gamma = m.addVars(x_sol_size, vtype=GRB.BINARY, name='gamma')
    sigma = m.addVars(day_shift_num, vtype=GRB.INTEGER, name='sigma')

    OBJ = (x.prod(c)
           + gp.quicksum(w * gamma[idx] for idx, w in gamma_weight.items())
           + gp.quicksum(extra_payment * sigma[i] for i in range(day_shift_num)))
    m.setObjective(OBJ, GRB.MINIMIZE)

    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    # Demand rows: realised demand for the first t days, predicted afterwards.
    for j in range(day_shift_num):
        rhs = real_h[j] if j < t * shift_num else pre_h[j]
        m.addConstr(x.prod(G[j]) - extra_serve_patient_num * sigma[j] <= rhs)
    for j in range(day_shift_num, rowSizeG):
        m.addConstr(x.prod(G[j]) <= pre_h[j])
    # gamma[i] must be 1 whenever x[i] is switched on relative to pre_sol.
    for i in range(x_sol_size):
        m.addConstr(gamma[i] >= x[i] - pre_sol[i])
    # Freeze assignments of the first t days and extra staff of the first t - 1 days.
    for i in range(nurse_num):
        for j in range(t):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                m.addConstr(x[idx] == pre_sol[idx])
    for j in range(t - 1):
        for k in range(shift_num):
            m.addConstr(sigma[j * shift_num + k] == pre_sigmaSol[j * shift_num + k])

    m.optimize()
    t_updated_sol = np.zeros(x_sol_size)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(x_sol_size)
    if m.SolCount > 0:
        # Round integer variables so near-integral solver output does not leak
        # into the next day's equality constraints or the penalty.
        for i in range(x_sol_size):
            t_updated_sol[i] = round(x[i].X)
            t_gamma[i] = round(gamma[i].X)
        for i in range(day_shift_num):
            t_sigma_sol[i] = round(sigma[i].X)

    t_incur_penalty = 0
    for idx, w in gamma_weight.items():
        t_incur_penalty = t_incur_penalty + w * t_gamma[idx]

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def correction_single_obj(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5, penalty):
    """Evaluate the day-by-day re-planning policy for one demand scenario.

    The initial plan is solved on predicted demand with ``get_init_plan``.
    It is then re-planned once per day ``1..day_num`` with
    ``get_t_updated_plan``.

    Returns:
        Final assignment cost + extra-staff payment + accumulated change
        penalties.
    """
    plan, extra_staff = get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)
    penalty_sum = 0
    for day in range(1, day_num + 1):
        plan, extra_staff, day_penalty = get_t_updated_plan(
            day, plan, extra_staff, c, A, b, G,
            real_patient, pre_patient, h2, h3, h4, h5, penalty)
        penalty_sum = penalty_sum + day_penalty

    assignment_cost = np.dot(plan, c[:nurse_num * day_shift_num])
    staffing_cost = np.sum(extra_staff * extra_payment)
    return assignment_cost + staffing_cost + penalty_sum


def check_IP_t_updated(t, c, A, b, G, h, penalty):
    """Solve the day-``t`` re-planning IP given in matrix form.

    Minimises ``c @ z`` subject to ``A z == b`` and ``G z <= h``. Here ``z`` is a
    non-negative integer vector of length ``t_decision_num``, laid out as
    assignments ``x``, extra staff ``sigma``, then change indicators ``gamma``.
    The matrices are produced by ``gen_t_matrix``; ``c`` comes from
    ``gen_obj(t, cost, penalty)``.

    Args:
        t: re-planning day; weights the reported change penalty by
            ``day_num - day + t``, matching the gamma coefficients of gen_obj.
        c, A, b, G, h: objective and constraint system.
        penalty: per-assignment penalty factors, indexed like ``x``.

    Returns:
        ``(t_updated_sol, t_sigma_sol, t_incur_penalty)``: the assignments and
        extra-staff counts, both rounded, plus the incurred change penalty.
        If Gurobi finds no solution, all-zero plans and a zero penalty are
        returned.
    """
    x_num = nurse_num * day_shift_num
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(t_decision_num, vtype=GRB.INTEGER, name='x')
    m.setObjective(x.prod(c), GRB.MINIMIZE)
    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= h[j])
    m.optimize()

    t_updated_sol = np.zeros(x_num)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(x_num)
    if m.SolCount > 0:
        for i in range(x_num):
            t_updated_sol[i] = round(x[i].X)
        for i in range(x_num, decision_num):
            t_sigma_sol[i - x_num] = round(x[i].X)
        for i in range(decision_num, t_decision_num):
            t_gamma[i - decision_num] = round(x[i].X)

    # Report the same time-weighted penalty that gen_obj charges for gamma
    # (and that get_t_updated_plan reports). The unweighted sum disagreed
    # with the optimised objective.
    t_incur_penalty = 0
    for i in range(nurse_num):
        for j in range(day_num):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                t_incur_penalty = t_incur_penalty + (day_num - j + t) * penalty[idx] * c[idx] * t_gamma[idx]

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def check_intOpt(cost, serve_patient_num, real_patient, pre_patient, penalty):
    """Evaluate the day-by-day re-planning policy using the matrix-form IPs.

    The initial plan is solved on predicted demand. The plan is then
    re-planned for each day ``1..day_num`` via ``gen_t_matrix`` and
    ``check_IP_t_updated``. Each day's plan (assignments followed by extra
    staff) is the previous solution for the next day.

    Args:
        cost: per-nurse cost of one working shift.
        serve_patient_num: patients nurse ``i`` serves per shift.
        real_patient, pre_patient: realised / predicted demand, one entry per
            working (non day-off) slot.
        penalty: per-assignment change-penalty factors.

    Returns:
        Final assignment cost + extra-staff payment + accumulated change
        penalties.
    """
    relax_shift = shift_num - 1  # the day-off shift has no demand
    x_num = nurse_num * day_shift_num
    c = gen_obj(0, cost)
    A, b, G, h2, h3, h4, h5 = gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num)
    init_plan, init_sigma = get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)

    t_updated_x = init_plan
    t_sigma_sol = init_sigma
    t_updated_sol = np.concatenate([init_plan, init_sigma], axis=0)
    total_penalty = 0

    # Negated demand per slot (predicted and realised); day-off slots need nobody.
    pre_h1 = np.zeros(day_shift_num)
    real_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient[cnt]
            real_h1[i] = -real_patient[cnt]
            cnt += 1

    for t in range(1, day_num + 1):
        c_t = gen_obj(t, cost, penalty)
        A_t, b_t, G_t, h_t = gen_t_matrix(t, t_updated_sol, real_h1, pre_h1, serve_patient_num)
        t_updated_x, t_sigma_sol, t_incur_penalty = check_IP_t_updated(t, c_t, A_t, b_t, G_t, h_t, penalty)
        t_updated_sol = np.concatenate([t_updated_x, t_sigma_sol], axis=0)
        total_penalty = total_penalty + t_incur_penalty

    # Computed once after the loop so total_cost is always defined (previously
    # it was recomputed every iteration and unbound when day_num == 0).
    total_cost = np.dot(t_updated_x, c[:x_num]) + np.sum(t_sigma_sol * extra_payment) + total_penalty
    return total_cost
