import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# ---------------------------------------------------------------------------
# Problem configuration, read as module globals by the functions below.
# ---------------------------------------------------------------------------
nurse_num = 10   # nurses to schedule
day_num = 7      # days in the planning horizon
shift_num = 4    # shifts per day; the last one of each day is the relax (day-off) shift

day_shift_num = day_num * shift_num              # (day, shift) slots in the horizon
day_work_shift_num = day_num * (shift_num - 1)   # slots that carry patient demand

# Base model columns: one assignment variable per (nurse, day, shift),
# followed by one extra-staff variable per (day, shift) slot.
decision_num = nurse_num * day_num * shift_num + day_shift_num
# Re-planning model appends one more variable per (nurse, day, shift).
t_decision_num = decision_num + nurse_num * day_num * shift_num

penaltyTerm = 0.1             # NOTE(review): not referenced in this file's visible code
extra_serve_patient_num = 1   # demand covered by each extra-staff unit in the demand rows
extra_payment = 0             # cost per extra-staff unit; changed via set_extra_payment

# Per-nurse bounds on relax shifts over the horizon.
# NOTE(review): for any other day_num these names stay undefined.
if day_num == 7:
    minimum_relax_day, maximum_relax_day = 1, 2
elif day_num == 2:
    minimum_relax_day, maximum_relax_day = 0, 1

def mkdir(default_path, folder_name):
    """Create ``default_path/folder_name`` (with any missing parents) if absent.

    ``exist_ok=True`` makes the call idempotent. It also avoids the race in the
    former exists-then-create check, where another process creating the
    directory in between would make ``os.makedirs`` raise.

    Args:
        default_path: parent directory.
        folder_name: name (or relative sub-path) of the directory to create.

    Returns:
        The joined directory path.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    path = os.path.join(default_path, folder_name)
    os.makedirs(path, exist_ok=True)
    return path


def set_extra_payment(ep):
    """Replace the module-level ``extra_payment`` (cost per extra-staff unit) with ``ep``."""
    globals()["extra_payment"] = ep


def gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num):
    """Build the static constraint system of the nurse-scheduling model.

    Columns follow the layout ``[x | sigma]``. ``x`` has one entry per
    (nurse, day, shift) at ``nurse*day_shift_num + day*shift_num + shift``.
    ``sigma`` has one entry per (day, shift) slot, starting at column
    ``nurse_num*day_num*shift_num``. The last shift of each day is the relax
    (day-off) shift.

    Rows produced:
      * ``A x == b``: each nurse takes exactly one shift (working or relax) per day.
      * ``G1``: demand rows, with coefficients ``-serve_patient_num[nurse]`` on x
        and ``-extra_serve_patient_num`` on sigma. The caller supplies the
        matching right-hand side ``h1``.
      * ``G2 <= h2``: no night shift immediately followed by the next morning shift.
      * ``G3 <= h3`` and ``G4 <= h4``: each nurse gets between
        ``minimum_relax_day`` and ``maximum_relax_day`` relax shifts.
      * ``G5 <= h5``: each x entry is at most 1.

    Args:
        nurse_num, day_num, shift_num: problem dimensions.
        serve_patient_num: per-nurse demand coverage, indexed by nurse.
        decision_num: total number of columns.
        day_shift_num: ``day_num * shift_num``.

    Returns:
        ``(A, b, G, h2, h3, h4, h5)`` with ``G`` the row stack of ``G1..G5``.
    """
    x_count = nurse_num * day_num * shift_num

    A = np.zeros((nurse_num * day_num, decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            first_col = nurse * day_shift_num + shift_num * day
            for offset in range(shift_num):
                A[nurse * day_num + day, first_col + offset] = 1
    b = np.ones(nurse_num * day_num)

    G1 = np.zeros((day_shift_num, decision_num))
    for nurse in range(nurse_num):
        for slot in range(day_shift_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
            G1[slot, x_count + slot] = -extra_serve_patient_num

    G2 = np.zeros((nurse_num * (day_num - 1), decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_day_start = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_day_start - 2] = 1   # night shift of `day`
            G2[row, next_day_start] = 1       # morning shift of `day + 1`
    h2 = np.ones(nurse_num * (day_num - 1))

    G3 = np.zeros((nurse_num, decision_num))
    G4 = np.zeros((nurse_num, decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            relax_col = nurse * day_shift_num + shift_num * (day + 1) - 1
            G3[nurse, relax_col] = -1
            G4[nurse, relax_col] = 1
    h3 = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h4 = np.full(nurse_num, maximum_relax_day, dtype=int)

    G5 = np.eye(x_count, decision_num)
    h5 = np.ones(x_count)

    G = np.concatenate([G1, G2, G3, G4, G5], axis=0)
    return A, b, G, h2, h3, h4, h5


def gen_t_matrix(prev_sol, real_h1, pre_h1, serve_patient_num):
    """Build the constraint system of the re-planning model.

    The columns (``t_decision_num`` of them) are laid out as
    ``[x | sigma | gamma]``. ``x`` and ``sigma`` are placed as in
    ``gen_matrix``, and ``gamma`` has one column per x entry, starting at
    ``decision_num``.

    Rows are those of ``gen_matrix`` plus two more groups:
      * ``G6``: ``x - gamma <= prev_sol[:nurse_num*day_shift_num]``.
      * ``G7``: ``gamma <= 1``.

    Args:
        prev_sol: previous solution; only its first ``nurse_num*day_shift_num``
            entries are used, as the right-hand side of the G6 rows.
        real_h1: right-hand side of the demand rows, used as-is.
        pre_h1: unused; kept for interface compatibility.
        serve_patient_num: per-nurse demand coverage, indexed by nurse.

    Returns:
        ``(A, b, G, h)`` such that ``A z == b`` and ``G z <= h``.
    """
    x_count = nurse_num * day_num * shift_num
    gamma_count = nurse_num * day_shift_num

    # One shift (working or relax) per nurse per day.
    A = np.zeros((nurse_num * day_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            first_col = nurse * day_shift_num + shift_num * day
            for offset in range(shift_num):
                A[nurse * day_num + day, first_col + offset] = 1
    b = np.ones(nurse_num * day_num)

    # Demand rows (nurses plus extra staff).
    G1 = np.zeros((day_shift_num, t_decision_num))
    for nurse in range(nurse_num):
        for slot in range(day_shift_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
            G1[slot, x_count + slot] = -extra_serve_patient_num
    h1 = real_h1

    # No night shift immediately followed by the next morning shift.
    G2 = np.zeros((nurse_num * (day_num - 1), t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_day_start = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_day_start - 2] = 1
            G2[row, next_day_start] = 1
    h2 = np.ones(nurse_num * (day_num - 1))

    # Relax-shift count per nurse within [minimum_relax_day, maximum_relax_day].
    G3 = np.zeros((nurse_num, t_decision_num))
    G4 = np.zeros((nurse_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            relax_col = nurse * day_shift_num + shift_num * (day + 1) - 1
            G3[nurse, relax_col] = -1
            G4[nurse, relax_col] = 1
    h3 = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h4 = np.full(nurse_num, maximum_relax_day, dtype=int)

    # x <= 1.
    G5 = np.eye(x_count, t_decision_num)
    h5 = np.ones(x_count)

    # x - gamma <= previous x, and gamma <= 1.
    G6 = np.zeros((gamma_count, t_decision_num))
    G7 = np.zeros((gamma_count, t_decision_num))
    for k in range(gamma_count):
        G6[k, k] = 1
        G6[k, decision_num + k] = -1
        G7[k, decision_num + k] = 1
    h6 = prev_sol[:gamma_count]
    h7 = np.ones(gamma_count)

    G = np.vstack([G1, G2, G3, G4, G5, G6, G7])
    h = np.concatenate([h1, h2, h3, h4, h5, h6, h7], axis=0)
    return A, b, G, h
    

def gen_obj(t, cost, penalty=None):
    """Build the linear objective coefficient vector.

    Slot ``k = i*day_shift_num + s`` is nurse ``i`` on day ``s // shift_num`` in
    shift ``s % shift_num``. A working shift costs ``cost[i]``. The relax shift
    costs nothing; it is the last shift of each day, the same slot
    ``gen_matrix`` counts as a day off. Previously it was hard-coded as index 3,
    which only agreed with ``gen_matrix`` when ``shift_num == 4``.

    Args:
        t: offset added to the day weight of the change penalty (used only when
            ``penalty`` is given).
        cost: per-nurse cost of one working shift, indexed by nurse.
        penalty: optional per-slot penalty factors indexed like the x block.

    Returns:
        Without ``penalty``: ``[x costs | extra_payment per slot]``, of length
        ``nurse_num*day_shift_num + day_shift_num``.
        With ``penalty``: that vector followed by a gamma block whose entry ``k``
        is ``(day_num - day + t) * penalty[k] * x_cost[k]``.
    """
    relax_shift = shift_num - 1

    c_for_x = np.zeros(nurse_num * day_shift_num)
    for i in range(nurse_num):
        for s in range(day_shift_num):
            if s % shift_num != relax_shift:
                c_for_x[i * day_shift_num + s] = cost[i]

    c_for_sigma = np.full(day_shift_num, extra_payment, dtype=float)

    if penalty is None:
        return np.concatenate([c_for_x, c_for_sigma], axis=0)

    c_for_gamma = np.zeros(nurse_num * day_shift_num)
    for i in range(nurse_num):
        for s in range(day_shift_num):
            k = i * day_shift_num + s
            # Changes to earlier days weigh more: weight falls by 1 per day.
            c_for_gamma[k] = (day_num - s // shift_num + t) * penalty[k] * c_for_x[k]

    return np.concatenate([c_for_x, c_for_sigma, c_for_gamma], axis=0)


def actual_obj(c, A, b, G, real_patient_num, h2, h3, h4, h5, n_instance):
    """Solve the scheduling model on realised demand for each instance.

    For instance ``num``, the demand right-hand side ``h1`` is rebuilt from
    ``real_patient_num``, which holds ``day_work_shift_num`` consecutive
    entries per instance. ``h1`` is stacked with ``h2..h5``, and Gurobi solves
    ``min c^T x  s.t.  A x == b,  G x <= h,  x integer``.

    Args:
        c: objective vector with ``decision_num`` entries (an ndarray).
        A, b: equality constraints; presumably as returned by ``gen_matrix``.
        G: stacked inequality matrix whose first ``day_shift_num`` rows are the
            demand rows.
        real_patient_num: flat sequence of realised demand per working shift.
        h2, h3, h4, h5: right-hand sides of the non-demand inequality rows.
        n_instance: number of instances to solve.

    Returns:
        ``np.ndarray`` with one optimal objective per instance. An instance
        whose model yields no solution is recorded as 0.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    relax_shift = shift_num - 1  # last shift of each day; it carries no demand
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    obj_list = []
    for num in range(n_instance):
        # Demand rows are negated (G1 x <= -demand), so demand enters as -patients.
        h1 = np.zeros(day_shift_num)
        cnt = num * day_work_shift_num
        for i in range(day_shift_num):
            if i % shift_num != relax_shift:
                h1[i] = -real_patient_num[cnt]
                cnt += 1
        h = np.concatenate([h1, h2, h3, h4, h5], axis=0).tolist()

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')
        m.setObjective(x.prod(c), GRB.MINIMIZE)
        for i in range(rowSizeA):
            m.addConstr(x.prod(A[i]) == b[i])
        for j in range(rowSizeG):
            m.addConstr(x.prod(G[j]) <= h[j])
        m.optimize()

        # Read the objective only when Gurobi found a solution. This replaces a
        # bare `except:` that also swallowed unrelated errors such as
        # KeyboardInterrupt. The 0 fallback for unsolved instances is kept.
        obj_list.append(m.objVal if m.SolCount > 0 else 0)

    return np.array(obj_list)


def get_init_plan(c, A, b, G, pre_patient_num, h2, h3, h4, h5):
    """Solve the scheduling model on predicted demand to obtain an initial plan.

    The demand right-hand side is built from ``pre_patient_num`` (one entry per
    working shift) and stacked with ``h2..h5``. Gurobi then solves
    ``min c^T x  s.t.  A x == b,  G x <= h`` over ``decision_num`` integer
    variables.

    Args:
        c: objective vector with ``decision_num`` entries (an ndarray).
        A, b: equality constraints.
        G: stacked inequality matrix whose first ``day_shift_num`` rows are the
            demand rows.
        pre_patient_num: predicted demand per working shift
            (``day_work_shift_num`` entries).
        h2, h3, h4, h5: right-hand sides of the non-demand inequality rows.

    Returns:
        ``(predSol, sigmaSol)``. ``predSol`` is the nurse-assignment part
        (``nurse_num*day_shift_num`` entries) and ``sigmaSol`` is the extra-staff
        part (``day_shift_num`` entries). Both are all zeros, and
        "cannot solve" is printed, when Gurobi finds no solution.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    x_sol_size = nurse_num * day_shift_num
    relax_shift = shift_num - 1  # last shift of each day; it carries no demand
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    # Demand rows are negated (G1 x <= -demand).
    pre_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient_num[cnt]
            cnt += 1
    pre_h = np.concatenate([pre_h1, h2, h3, h4, h5], axis=0).tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')
    m.setObjective(x.prod(c), GRB.MINIMIZE)
    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= pre_h[j])
    m.optimize()

    predSol = np.zeros(x_sol_size)
    sigmaSol = np.zeros(day_shift_num)
    # Check for a solution explicitly instead of relying on a bare `except:`,
    # which would also hide unrelated errors.
    if m.SolCount > 0:
        for i in range(x_sol_size):
            predSol[i] = x[i].x
        for i in range(x_sol_size, decision_num):
            sigmaSol[i - x_sol_size] = x[i].x
    else:
        print("cannot solve")

    return predSol, sigmaSol


def get_t_updated_plan(pre_sol, pre_sigmaSol, c, A, b, G, real_patient_num, h2, h3, h4, h5, penalty):
    """Re-plan against realised demand, penalising deviations from ``pre_sol``.

    Solves::

        min  c_x . x + sum_k w_k * gamma_k + extra_payment * sum(sigma)
        s.t. A x == b
             G_demand x - extra_serve_patient_num * sigma <= -realised demand
             remaining G rows (x part only) <= h2..h5
             gamma_k >= x_k - pre_sol_k

    Here x and gamma are binary, sigma is integer, and
    ``w_k = (day_num - day(k)) * penalty[k] * c[k]``. A shift newly assigned
    (x=1 where pre_sol=0) earlier in the horizon is penalised more.

    Args:
        pre_sol: previous assignment, indexed like x
            (``nurse_num*day_shift_num`` entries).
        pre_sigmaSol: previous extra-staff vector; currently unused.
        c: objective vector (ndarray); only its x part is used.
        A, b, G: constraints over the ``[x | sigma]`` layout. The sigma columns
            are dropped and sigma is modelled as separate variables.
        real_patient_num: realised demand per working shift
            (``day_work_shift_num`` entries).
        h2, h3, h4, h5: right-hand sides of the non-demand inequality rows.
        penalty: per-slot change-penalty factors, indexed like x.

    Returns:
        ``(t_updated_sol, t_sigma_sol, t_incur_penalty)``: the new assignment,
        the extra-staff counts (both rounded to integers), and the change
        penalty actually incurred. If Gurobi finds no solution, the arrays are
        all zeros and the penalty is 0.
    """
    x_sol_size = nurse_num * day_shift_num
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    # Keep only the x columns; sigma gets its own variables below.
    c = c[:x_sol_size].tolist()
    A = A[:, :x_sol_size].tolist()
    b = b.tolist()
    G = G[:, :x_sol_size].tolist()

    relax_shift = shift_num - 1  # last shift of each day; it carries no demand
    real_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            real_h1[i] = -real_patient_num[cnt]
            cnt += 1
    real_h = np.concatenate([real_h1, h2, h3, h4, h5], axis=0).tolist()

    # Change-penalty weight per x slot. Slot k falls on day (k % day_shift_num) // shift_num.
    # It is shared by the objective and by the incurred-penalty bookkeeping.
    gamma_coef = [
        (day_num - (k % day_shift_num) // shift_num) * penalty[k] * c[k]
        for k in range(x_sol_size)
    ]

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(x_sol_size, vtype=GRB.BINARY, name='x')
    gamma = m.addVars(x_sol_size, vtype=GRB.BINARY, name='gamma')
    sigma = m.addVars(day_shift_num, vtype=GRB.INTEGER, name='sigma')

    # quicksum builds each expression in one pass instead of re-copying it per term.
    OBJ = (
        x.prod(c)
        + gp.quicksum(gamma_coef[k] * gamma[k] for k in range(x_sol_size))
        + gp.quicksum(extra_payment * sigma[i] for i in range(day_shift_num))
    )
    m.setObjective(OBJ, GRB.MINIMIZE)

    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    # Demand rows: extra staff (sigma) can cover shortfalls.
    for j in range(day_shift_num):
        m.addConstr(x.prod(G[j]) - extra_serve_patient_num * sigma[j] <= real_h[j])
    for j in range(day_shift_num, rowSizeG):
        m.addConstr(x.prod(G[j]) <= real_h[j])
    # gamma_i is forced to 1 when slot i is newly assigned relative to pre_sol.
    for i in range(x_sol_size):
        m.addConstr(gamma[i] >= x[i] - pre_sol[i])

    m.optimize()

    t_updated_sol = np.zeros(x_sol_size)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(x_sol_size)
    # Read values only if a solution exists (replaces a bare `except:`).
    # Integer variables are rounded so solver tolerance noise does not leak into
    # the cost bookkeeping; previously only x was rounded.
    if m.SolCount > 0:
        for i in range(x_sol_size):
            t_updated_sol[i] = round(x[i].x)
            t_gamma[i] = round(gamma[i].x)
        for i in range(day_shift_num):
            t_sigma_sol[i] = round(sigma[i].x)

    t_incur_penalty = 0
    for k in range(x_sol_size):
        t_incur_penalty = t_incur_penalty + gamma_coef[k] * t_gamma[k]

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def correction_single_obj(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5, penalty):
    """Return the realised cost of planning on predicted demand, then correcting.

    First, an initial plan is solved on ``pre_patient`` via ``get_init_plan``.
    That plan is then revised against ``real_patient`` via
    ``get_t_updated_plan``. The result is the revised plan's staffing cost plus
    its extra-staff cost plus the change penalty reported by the re-plan.
    """
    plan, plan_sigma = get_init_plan(c, A, b, G, pre_patient, h2, h3, h4, h5)
    revised, revised_sigma, change_penalty = get_t_updated_plan(
        plan, plan_sigma, c, A, b, G, real_patient, h2, h3, h4, h5, penalty
    )

    staffing_cost = np.dot(revised, c[:nurse_num * day_shift_num])
    hiring_cost = np.sum(revised_sigma * extra_payment)
    return staffing_cost + hiring_cost + change_penalty
