# use ip_model_whole(logKKT(Gh)_hUn).py
import os
import sys
import numpy as np
import random
import pandas as pd
import math, time
import itertools
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import datetime
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
import gurobipy as gp
import logging
import copy
from collections import defaultdict
import joblib
import gurobipy as gp
from gurobipy import GRB
import time, datetime

import sys
import ip_model_whole as ip_model_wholeFile
from ip_model_whole import IPOfunc
import linear_relax as LP_relax_file

# Problem dimensions and penalty settings copied from linear_relax when this
# script starts.
# NOTE(review): these are one-off copies; code further down that needs the
# current planning horizon reads LP_relax_file.day_num / day_shift_num / x_num
# directly, presumably because change_day_num() updates them -- confirm.
total_day_num = LP_relax_file.total_day_num
nurse_num = LP_relax_file.nurse_num
day_num = LP_relax_file.day_num
shift_num = LP_relax_file.shift_num
day_shift_num = LP_relax_file.day_shift_num
day_work_shift_num = LP_relax_file.day_work_shift_num
decision_num = LP_relax_file.decision_num
t_decision_num = LP_relax_file.t_decision_num
penaltyTerm = LP_relax_file.penaltyTerm
extra_serve_patient_num = LP_relax_file.extra_serve_patient_num
minimum_relax_day = LP_relax_file.minimum_relax_day
maximum_relax_day = LP_relax_file.maximum_relax_day

# Command line: <extra_payment> <startmark> <endmark>.
# startmark/endmark bound the simulation indices iterated by the main loop.
extra_payment, startmark, endmark = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
LP_relax_file.set_extra_payment(extra_payment)

# Data-set sizes.
train_set_num, test_set_num = 70, 30
# Number of feature columns read from each data row.
featureNum = 8
# Number of outer iterations of the main loop.
iteration_num = 5
# Loss threshold compared against the per-epoch record in Intopt.fit.
warm_start_val = 25000
# Epochs below this index use the plain MSE stage in Intopt.fit.
warm_start_epoch_criterion = 12
stop_epoch_criterion = 20
# Set by Intopt.fit when a NaN prediction is seen.
retrain_criterion = False

# Everything is read from / written to the parent of the working directory.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
_scenario_dir = (
    'day_num=' + str(total_day_num)
    + '/penalty=' + str(penaltyTerm)
    + ', extra_serve=' + str(extra_serve_patient_num)
    + ', extra_payment=' + str(extra_payment)
    + '/'
)
default_path = os.path.join(dataset_path, _scenario_dir)

# Output folders for the per-stage networks and intermediate results.
for _subdir in ('sequential_T_NN',
                'sequential_prev_cost',
                'sequential_prev_penalty',
                'sequential_prev_schedule'):
    LP_relax_file.mkdir(default_path, _subdir)

# Per-instance rolling state for the training set: stored schedule, per-nurse
# rest counters, accumulated cost and accumulated penalty.
train_future_x = np.ones((train_set_num, LP_relax_file.x_num))
train_has_rested = np.zeros((train_set_num, nurse_num))
train_curr_cost = np.zeros(train_set_num)
train_curr_penalty = np.zeros(train_set_num)

# Same state for the test set (note the schedule starts at zeros here).
test_future_x = np.zeros((test_set_num, LP_relax_file.x_num))
test_has_rested = np.zeros((test_set_num, nurse_num))
test_curr_cost = np.zeros(test_set_num)
test_curr_penalty = np.zeros(test_set_num)

def weight_init(m):
    """Initialise the parameters of one layer in place.

    Written to be passed to ``Module.apply`` (see the commented-out
    ``self.model.apply(weight_init)`` call in ``Intopt.__init__``).

    * ``nn.Linear``: Xavier-normal weights, zero bias.
    * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU); bias untouched.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.

    Layers of any other type are left unchanged.  Parameters that are
    ``None`` (``bias=False`` or ``affine=False``) are skipped instead of
    raising.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # A Linear layer built with bias=False has m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

def make_fc(num_layers, num_features, num_targets=1,
            activation_fn = nn.ReLU,intermediate_size=2*featureNum, regularizers = True):
    """Assemble a fully connected network as an ``nn.Sequential``.

    The stack is: an input layer ``num_features -> intermediate_size``,
    ``num_layers - 2`` hidden layers of width ``intermediate_size`` (none
    when ``num_layers <= 2``), and an output layer
    ``intermediate_size -> num_targets``.  Every linear layer, including the
    output layer, is followed by a fresh ``activation_fn()`` instance.

    Args:
        num_layers: total number of linear layers requested.
        num_features: width of the input.
        num_targets: width of the output.
        activation_fn: zero-argument factory for the activation module.
        intermediate_size: width of the hidden layers.
        regularizers: accepted but not used by this function.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    hidden_count = max(num_layers - 2, 0)
    layer_shapes = [(num_features, intermediate_size)]
    layer_shapes += [(intermediate_size, intermediate_size)] * hidden_count
    layer_shapes.append((intermediate_size, num_targets))

    modules = []
    for fan_in, fan_out in layer_shapes:
        # Linear first, then its activation, so layers are created in the
        # same order as they appear in the network.
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(activation_fn())
    return nn.Sequential(*modules)


def make_next_plan(test_num, cur_NN, real_patient, pre_patient, cost, serve_patient_num, penalty=None):
    """Advance the stored schedule of one test instance by one stage.

    For ``cur_NN == 0`` the initial plan is obtained from
    ``LP_relax_file.get_init_plan`` and stored in ``test_future_x[test_num]``.

    For ``cur_NN > 0`` the part of the stored schedule from day
    ``cur_NN - 1`` onwards is re-optimised with
    ``LP_relax_file.get_updated_plan_for_each_day``.  Afterwards:

    * ``test_curr_penalty[test_num]`` grows by the returned penalty,
    * ``test_curr_cost[test_num]`` grows by ``cost[i] * x`` over the first
      ``shift_num - 1`` entries of each nurse's slice plus
      ``extra_payment * sigma`` over the first ``shift_num - 1`` sigmas,
    * the new solution is written back into ``test_future_x[test_num]``,
    * ``test_has_rested[test_num][i]`` is incremented when entry
      ``shift_num - 1`` of nurse ``i``'s slice equals 1.

    Args:
        test_num: row index into the module-level ``test_*`` state arrays.
        cur_NN: stage index; 0 means the initial plan.
        real_patient: realised demand values, indexed per working shift.
        pre_patient: predicted demand values, indexed like ``real_patient``.
        cost: per-nurse cost vector.
        serve_patient_num: forwarded unchanged to the LP_relax_file builders.
        penalty: forwarded unchanged to the latter-day objective and solver.

    Returns:
        None.  All results are written to the module-level state arrays.

    Raises:
        Exception: whatever ``get_updated_plan_for_each_day`` raises.  ``G``
            and ``h`` are dumped to ``G.txt`` / ``h.txt`` first.
    """
    global test_future_x
    global test_has_rested
    global test_curr_cost
    global test_curr_penalty

    if cur_NN == 0:
        c = LP_relax_file.gen_obj(0, cost)
        A,b,G,h2,h3,h4,h5 = LP_relax_file.gen_matrix(nurse_num,day_num,shift_num,serve_patient_num,decision_num,day_shift_num)

        init_x, init_sigma = LP_relax_file.get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)
        test_future_x[test_num] = init_x
        return

    # Copy the not-yet-executed days (cur_NN-1 .. end) of the stored full
    # schedule into a compact per-nurse layout.
    remaining_schedule = np.zeros(nurse_num * LP_relax_file.day_num * shift_num)
    for i in range(nurse_num):
        for j in range(cur_NN-1, total_day_num):
            for k in range(shift_num):
                remaining_schedule[i*LP_relax_file.day_shift_num+(j-cur_NN+1)*shift_num+k] = test_future_x[test_num][i*total_day_num*shift_num+j*shift_num+k]

    # Negated demand per (day, shift) slot.  Slots with i % shift_num == 3
    # take no demand value and stay at zero.
    # NOTE(review): 3 looks like shift_num - 1 for shift_num == 4 -- confirm.
    pre_h1 = np.zeros(LP_relax_file.day_shift_num)
    real_h1 = np.zeros(LP_relax_file.day_shift_num)
    cnt = 0
    for i in range(LP_relax_file.day_shift_num):
        if i % shift_num != 3:
            pre_h1[i] = -pre_patient[cnt]
            real_h1[i] = -real_patient[cnt]
            cnt = cnt + 1

    c = LP_relax_file.gen_obj_latter_days(cur_NN, cost, penalty)
    A, b, G, h = LP_relax_file.gen_constraints_latter_days(cur_NN, remaining_schedule, test_has_rested[test_num], real_h1, pre_h1, serve_patient_num)
    try:
        x_sol, sigma_sol, incur_penalty = LP_relax_file.get_updated_plan_for_each_day(cur_NN, c, A, b, G, h, penalty)
    except Exception:
        # Keep the constraint data for debugging, then surface the real
        # error.  Continuing would only fail later on the unbound
        # x_sol / sigma_sol / incur_penalty.
        np.savetxt('G.txt', G, fmt="%.2f")
        np.savetxt('h.txt', h, fmt="%.2f")
        raise

    test_curr_penalty[test_num] += incur_penalty
    for i in range(nurse_num):
        for j in range(shift_num-1):
            test_curr_cost[test_num] += cost[i] * x_sol[i*LP_relax_file.day_shift_num+j]
    for i in range(shift_num-1):
        test_curr_cost[test_num] += extra_payment * sigma_sol[i]

    # Write the re-optimised tail back into the full-horizon schedule.
    for i in range(nurse_num):
        for j in range(cur_NN-1, total_day_num):
            for k in range(shift_num):
                full_index = i*total_day_num*shift_num+j*shift_num+k
                curr_index = i*LP_relax_file.day_shift_num+(j-cur_NN+1)*shift_num+k
                test_future_x[test_num][full_index] = x_sol[curr_index]

    # Entry shift_num-1 of a nurse's slice is the one counted as a rest.
    for i in range(nurse_num):
        if x_sol[i*LP_relax_file.day_shift_num+shift_num-1] == 1:
            test_has_rested[test_num][i] += 1


class MyCustomDataset():
    """Pair an indexable collection of features with matching targets.

    Provides ``__len__`` and ``__getitem__`` so instances can be handed to
    ``torch.utils.data.DataLoader``, as the methods of ``Intopt`` do.
    """

    def __init__(self, feature, value):
        # Both collections are stored as given; no copy is made.
        self.feature, self.value = feature, value

    def __len__(self):
        """Return the number of targets."""
        targets = self.value
        return len(targets)

    def __getitem__(self, idx):
        """Return the ``(feature, value)`` pair stored at ``idx``."""
        sample = self.feature[idx]
        target = self.value[idx]
        return sample, target
        

class Intopt:
    def __init__(self, cost, penalty, serve_patient_num, n_features, batch_size, cur_NN, num_layers=5, smoothing=False, thr=0.1, max_iter=None, method=1, mu0=None, damping=1e-7, target_size=1, epochs=8, optimizer=optim.Adam, **hyperparams):
        """Store the problem data and build the network and its optimizer.

        Args:
            cost: cost vector handed to the LP_relax_file builders.
            penalty: penalty values indexed per decision variable.
            serve_patient_num: forwarded unchanged to the LP_relax_file
                builders.
            n_features: input width of the network.
            batch_size: batch size used by every DataLoader in this class.
            cur_NN: stage index this predictor is trained for.
            num_layers: number of linear layers passed to ``make_fc``.
            smoothing, thr, max_iter, damping: forwarded to ``IPOfunc`` in
                ``fit``.
            method, mu0: stored only.
            target_size: stored only; the network is built with
                ``make_fc``'s default number of targets.
            epochs: maximum number of epochs run by ``fit``.
            optimizer: optimizer factory, called with the model parameters
                and ``**hyperparams``.
            **hyperparams: extra keyword arguments for the optimizer.
        """
        # Problem data.
        self.cost, self.penalty = cost, penalty
        self.serve_patient_num = serve_patient_num
        self.cur_NN = cur_NN

        # Network shape.
        self.n_features = n_features
        self.target_size = target_size
        self.num_layers = num_layers

        # Settings for the differentiable solver layer.
        self.smoothing = smoothing
        self.thr = thr
        self.max_iter = max_iter
        self.method = method
        self.mu0 = mu0
        self.damping = damping

        # Training settings.
        self.batch_size = batch_size
        self.epochs = epochs
        self.hyperparams = hyperparams

        # The optimizer is bound to the parameters of the model built here.
        self.model = make_fc(num_layers=self.num_layers,num_features=n_features)
        self.optimizer = optimizer(self.model.parameters(), **hyperparams)

    def fit(self, feature, value):
        """Train the predictor in two stages.

        Stage 1 (epochs below ``warm_start_epoch_criterion``) minimises the
        MSE between prediction and target.  On the last stage-1 epoch, if the
        summed MSE is at most ``warm_start_val``, one extra pass computes the
        average corrected objective ("EOV") and records it for that epoch.

        Stage 2 (remaining epochs) sets the learning rate to 5e-7 and
        back-propagates through ``IPOfunc``.  The corrected objective is
        evaluated per batch for monitoring, and its average is stored in
        ``recordNow[self.cur_NN]``.

        Training stops early when ``retrain_criterion`` is set, when the
        per-epoch record stops changing, when it fails to decrease in stage
        2, or when the record at index ``warm_start_epoch_criterion - 2``
        exceeds ``warm_start_val``.

        Side effects: updates the module globals ``stop_epoch``,
        ``retrain_criterion``, ``train_loss`` and ``recordNow``.

        Args:
            feature: indexable collection of feature rows.
            value: indexable collection of targets aligned with ``feature``.
        """
        global stop_epoch, retrain_criterion, train_loss

        logging.info("Intopt")
        train_df = MyCustomDataset(feature, value)
        criterion = nn.MSELoss(reduction='mean')
        # Per-epoch record used by the stopping rules at the end of the loop.
        grad_list = np.zeros(self.epochs)

        for e in range(self.epochs):
            cur_loss_IP = 0
            total_loss = 0

            if e < warm_start_epoch_criterion:
                # ---- Stage 1: plain regression on the targets. ----
                train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
                for batch_x, batch_y in train_dl:
                    self.optimizer.zero_grad()
                    op = self._forward_positive(batch_x)

                    loss = criterion(op, batch_y)
                    total_loss += loss.item()
                    grad_list[e] = total_loss
                    loss.backward()
                    self.optimizer.step()

                if e < warm_start_epoch_criterion - 1:
                    print("{} ->".format(total_loss), end=" ")
                else:
                    print("{} ->".format(total_loss))

                stop_epoch = e
                if grad_list[e] <= warm_start_val and e == warm_start_epoch_criterion - 1:
                    # Last warm-up epoch: replace the MSE record by the
                    # average corrected objective of the current model.
                    train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
                    batchCnt = 0
                    for batch_x, batch_y in train_dl:
                        self.optimizer.zero_grad()
                        op = self._forward_positive(batch_x)
                        penaltyVector = self._penalty_vector()
                        loss_IP = self._corrected_objective(batchCnt, op, batch_y, penaltyVector)
                        cur_loss_IP = cur_loss_IP + loss_IP
                        batchCnt = batchCnt + 1
                    cur_loss_IP = cur_loss_IP / train_set_num
                    print("EOV: ", cur_loss_IP)
                    grad_list[e] = cur_loss_IP

            else:
                # ---- Stage 2: train through the optimisation layer. ----
                lr = 5e-7
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
                batchCnt = 0
                retrain_criterion = False
                for batch_x, batch_y in train_dl:
                    self.optimizer.zero_grad()
                    op = self._forward_positive(batch_x)

                    penaltyVector = self._penalty_vector()
                    penalty_torch = torch.from_numpy(penaltyVector).float()

                    if self.cur_NN == 0:
                        remaining_schedule = train_future_x[batchCnt]
                    else:
                        remaining_schedule = self._remaining_schedule(batchCnt)
                    x_sol = IPOfunc(cur_NN=self.cur_NN, real_patient=batch_y, penalty=penalty_torch, cost=self.cost, serve_patient_num=self.serve_patient_num, remaining_schedule=remaining_schedule, has_rested=train_has_rested[batchCnt], curr_cost=train_curr_cost[batchCnt], curr_penalty=train_curr_penalty[batchCnt],max_iter=self.max_iter, thr=self.thr, damping=self.damping,smoothing=self.smoothing)(op)
                    # IPOfunc publishes the penalty of its last solve as a
                    # module attribute.
                    incur_penalty = ip_model_wholeFile.incur_penalty

                    c = LP_relax_file.gen_obj(0, self.cost)
                    c_torch = torch.from_numpy(c).float()
                    if self.cur_NN == 0:
                        loss = (x_sol[:decision_num] * c_torch[:decision_num]).sum() + incur_penalty
                    else:
                        var_count = LP_relax_file.x_num + LP_relax_file.sigma_num
                        loss = (x_sol[:var_count] * c_torch[:var_count]).sum() + incur_penalty + train_curr_cost[batchCnt]

                    # Whole-tensor check: the previous element-wise loop
                    # indexed op[i] for i < batch_size and could run past
                    # the end of a short final batch.
                    if torch.isnan(op).any():
                        retrain_criterion = True
                        print(op)

                    if batchCnt % 30 == 0:
                        print("LP_loss: ", loss, end=" ")

                    loss_IP = self._corrected_objective(batchCnt, op, batch_y, penaltyVector)
                    cur_loss_IP = cur_loss_IP + loss_IP
                    if batchCnt % 30 == 0:
                        print("IP_loss: ", cur_loss_IP)

                    batchCnt += 1
                    total_loss += loss.item()
                    loss = loss / decision_num
                    loss.backward()
                    self.optimizer.step()

                total_loss = total_loss / train_set_num
                logging.info("EPOCH Ends")
                print("Epoch{} ::loss {} ->".format(e,total_loss), end=" ")
                cur_loss_IP = cur_loss_IP / train_set_num
                print("cur_loss_IP: ", cur_loss_IP)
                stop_epoch = e
                grad_list[e] = cur_loss_IP
                recordNow[self.cur_NN] = cur_loss_IP

            train_loss = total_loss
            if retrain_criterion:
                break
            if e > 0 and grad_list[e] == grad_list[e-1]:
                break
            if e > warm_start_epoch_criterion - 1 and grad_list[e] >= grad_list[e-1]:
                break
            # Guarded lookup: with few epochs (e.g. the default of 8) the
            # warm-up index lies outside grad_list.
            warm_idx = warm_start_epoch_criterion - 2
            if -self.epochs <= warm_idx < self.epochs and grad_list[warm_idx] > warm_start_val:
                break

    def _reinit_model(self):
        """Swap in a freshly initialised network and rebind the optimizer.

        The optimizer only updates the parameters it was constructed with,
        so it is rebuilt for the new model (same class and hyperparameters)
        and the current learning rate of each parameter group is kept.
        """
        current_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.model = make_fc(num_layers=self.num_layers,num_features=self.n_features)
        self.optimizer = type(self.optimizer)(self.model.parameters(), **self.hyperparams)
        for group, lr in zip(self.optimizer.param_groups, current_lrs):
            group['lr'] = lr

    def _forward_positive(self, feature):
        """Return the squeezed model output, re-initialising until it is valid.

        The output is accepted only when every entry is strictly positive
        and finite; otherwise the model is replaced and evaluated again.
        NOTE(review): there is no retry limit, as in the original loop.
        """
        op = self.model(feature).squeeze()
        while torch.min(op) <= 0 or torch.isnan(op).any() or torch.isinf(op).any():
            self._reinit_model()
            op = self.model(feature).squeeze()
        return op

    def _penalty_vector(self):
        """Return the first ``decision_num`` penalties as a float array.

        NOTE(review): the original code offset the index by
        ``instance_num * decision_num`` but never advanced ``instance_num``
        from 0, so every batch used this same slice; kept as is.
        """
        penaltyVector = np.zeros(decision_num)
        for i in range(decision_num):
            penaltyVector[i] = self.penalty[i]
        return penaltyVector

    def _remaining_schedule(self, sample_idx):
        """Copy days ``cur_NN-1 .. end`` of a stored training schedule.

        The result uses the compact per-nurse layout expected by the
        latter-day helpers of LP_relax_file.
        """
        remaining_schedule = np.zeros(nurse_num * LP_relax_file.day_num * shift_num)
        for i in range(nurse_num):
            for j in range(self.cur_NN-1, total_day_num):
                for k in range(shift_num):
                    remaining_schedule[i*LP_relax_file.day_shift_num+(j-self.cur_NN+1)*shift_num+k] = train_future_x[sample_idx][i*total_day_num*shift_num+j*shift_num+k]
        return remaining_schedule

    def _corrected_objective(self, sample_idx, op, value, penaltyVector):
        """Evaluate the corrected objective for one training instance.

        Uses ``correction_single_obj`` for ``cur_NN == 0`` and
        ``correction_single_for_latter_days`` otherwise.
        """
        if self.cur_NN == 0:
            c = LP_relax_file.gen_obj(0, self.cost)
            A,b,G,h2,h3,h4,h5 = LP_relax_file.gen_matrix(nurse_num,day_num,shift_num,self.serve_patient_num,decision_num,day_shift_num)
            return LP_relax_file.correction_single_obj(c, A, b, G, value, op, h2, h3, h4, h5, penaltyVector)

        # Negated demand per (day, shift) slot.  Slots with
        # i % shift_num == 3 take no demand value and stay at zero.
        # NOTE(review): 3 looks like shift_num - 1 for shift_num == 4 -- confirm.
        pre_h1 = np.zeros(LP_relax_file.day_shift_num)
        real_h1 = np.zeros(LP_relax_file.day_shift_num)
        cnt = 0
        for i in range(LP_relax_file.day_shift_num):
            if i % shift_num != 3:
                pre_h1[i] = -op[cnt]
                real_h1[i] = -value[cnt]
                cnt = cnt + 1

        remaining_schedule = self._remaining_schedule(sample_idx)
        return LP_relax_file.correction_single_for_latter_days(self.cur_NN, pre_h1, real_h1, self.cost, penaltyVector, self.serve_patient_num, remaining_schedule, train_curr_cost[sample_idx], train_curr_penalty[sample_idx], train_has_rested[sample_idx])


    def val_loss(self, feature, value):
        """Compare the corrected objective of the predictions with the true one.

        For each batch the model output is passed, together with the real
        targets, to ``LP_relax_file.correction_single_obj``.  The result is
        compared against ``LP_relax_file.actual_obj`` computed from the real
        targets.

        The problem matrices are taken from ``self.c, self.A, self.b,
        self.G, self.h2 .. self.h5`` when all of them have been set on the
        instance.  ``__init__`` does not set them, so otherwise they are
        built with ``gen_obj`` / ``gen_matrix`` exactly as ``fit`` does for
        ``cur_NN == 0``.
        NOTE(review): the fallback is the first-stage formulation; confirm
        before relying on it for ``cur_NN > 0``.

        Args:
            feature: tensor of feature rows.
            value: tensor of targets aligned with ``feature``.

        Returns:
            A tuple ``(abs_gap, pred_val)``: the absolute difference between
            the true objective and the corrected objective per batch, and a
            flat array of the predictions.
        """
        valueTemp = value.numpy()
        test_instance = len(valueTemp) / self.batch_size
        n_instance = int(test_instance)

        matrix_attrs = ('c', 'A', 'b', 'G', 'h2', 'h3', 'h4', 'h5')
        if all(hasattr(self, attr) for attr in matrix_attrs):
            c, A, b, G = self.c, self.A, self.b, self.G
            h2, h3, h4, h5 = self.h2, self.h3, self.h4, self.h5
        else:
            # These attributes are never assigned in this class; build the
            # matrices instead of failing with AttributeError.
            c = LP_relax_file.gen_obj(0, self.cost)
            A,b,G,h2,h3,h4,h5 = LP_relax_file.gen_matrix(nurse_num,day_num,shift_num,self.serve_patient_num,decision_num,day_shift_num)

        real_obj = LP_relax_file.actual_obj(c, A, b, G, value, h2, h3, h4, h5, n_instance=n_instance)

        self.model.eval()
        criterion = nn.L1Loss(reduction='mean')
        valid_df = MyCustomDataset(feature, value)
        valid_dl = data_utils.DataLoader(valid_df, batch_size=self.batch_size, shuffle=False)
        prediction_loss = 0
        corr_obj_list = []
        num = 0
        pred_val = np.zeros(day_work_shift_num*n_instance)

        for feature, value in valid_dl:
            op = self.model(feature).squeeze()
            loss = criterion(op, value)
            # Accumulated for reference only; not part of the return value.
            prediction_loss += loss.item()

            real_patient = {}
            pre_patient = {}
            for i in range(day_work_shift_num):
                real_patient[i] = value[i]
                pre_patient[i] = op[i]
                pred_val[num*day_work_shift_num+i] = op[i]

            penaltyVector = np.zeros(decision_num)
            for i in range(decision_num):
                penaltyVector[i] = self.penalty[i+num*decision_num]

            corrrlst = LP_relax_file.correction_single_obj(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5, penaltyVector)
            corr_obj_list.append(corrrlst)
            num = num + 1

        # Restore training mode for subsequent calls to fit().
        self.model.train()
        return abs(real_obj - np.array(corr_obj_list)), pred_val
    
    
    def get_pred_val(self, feature, value):
        """Collect the model's predictions for every batch into one tensor.

        The model is switched to eval mode and left there.  From each batch
        the first ``LP_relax_file.day_work_shift_num`` outputs are copied
        into consecutive positions of the result.

        Args:
            feature: tensor of feature rows.
            value: tensor of targets; its first dimension sets the length
                of the returned tensor.

        Returns:
            A 1-D tensor holding the collected predictions.
        """
        target_count = np.size(value.numpy(), 0)
        self.model.eval()
        l1_loss = nn.L1Loss(reduction='mean')
        loader = data_utils.DataLoader(
            MyCustomDataset(feature, value),
            batch_size=self.batch_size,
            shuffle=False,
        )
        predVal = torch.zeros(target_count)

        for batch_idx, (batch_x, batch_y) in enumerate(loader):
            op = self.model(batch_x).squeeze()
            # Evaluated as in the original implementation; the value is not used.
            _ = l1_loss(op, batch_y)

            offset = batch_idx * LP_relax_file.day_work_shift_num
            for i in range(LP_relax_file.day_work_shift_num):
                predVal[offset + i] = op[i]

        return predVal
        

    def make_future_plan(self, feature, value):
        """Advance the stored schedule of every training instance by one stage.

        Each batch is treated as one instance.  The model's predictions (NaN
        replaced by 0) and the real targets are turned into demand vectors.

        For ``self.cur_NN == 0`` the initial plan from
        ``LP_relax_file.get_init_plan`` is stored in ``train_future_x``.

        For ``self.cur_NN > 0`` the remaining days are re-optimised with
        ``LP_relax_file.get_updated_plan_for_each_day``.  The result updates
        ``train_curr_penalty``, ``train_curr_cost``, ``train_future_x`` and
        ``train_has_rested`` for that instance and is copied into the
        returned array.

        The model is switched to eval mode and left there.

        Args:
            feature: tensor of feature rows.
            value: tensor of targets aligned with ``feature``.

        Returns:
            Array with one row per instance holding the re-optimised
            solution.  Rows stay zero when ``self.cur_NN == 0``.

        Raises:
            Exception: whatever ``get_updated_plan_for_each_day`` raises.
                ``G`` and ``h`` are dumped to ``G.txt`` / ``h.txt`` first.
        """
        global train_future_x
        global train_has_rested
        global train_curr_cost
        global train_curr_penalty

        valueTemp = value.numpy()
        instance_num = np.size(valueTemp, 0) / self.batch_size
        self.model.eval()
        valid_df = MyCustomDataset(feature, value)
        valid_dl = data_utils.DataLoader(valid_df, batch_size=self.batch_size, shuffle=False)

        # Reused for every instance; overwritten in full on each iteration.
        real_patient = np.zeros(LP_relax_file.day_work_shift_num)
        pre_patient = np.zeros(LP_relax_file.day_work_shift_num)
        future_plan = np.zeros((int(instance_num), LP_relax_file.x_num))

        num = 0
        for batch_x, batch_y in valid_dl:
            op = self.model(batch_x).squeeze()

            for i in range(LP_relax_file.day_work_shift_num):
                real_patient[i] = batch_y[i]
                # A NaN prediction is treated as zero demand.
                if math.isnan(op[i]):
                    pre_patient[i] = 0
                else:
                    pre_patient[i] = op[i]

            if self.cur_NN == 0:
                c = LP_relax_file.gen_obj(0, self.cost)
                A,b,G,h2,h3,h4,h5 = LP_relax_file.gen_matrix(nurse_num,day_num,shift_num,self.serve_patient_num,decision_num,day_shift_num)

                init_x, init_sigma = LP_relax_file.get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)
                train_future_x[num] = init_x

            else:
                # Compact copy of the not-yet-executed days of the schedule.
                remaining_schedule = np.zeros(nurse_num * LP_relax_file.day_num * shift_num)
                for i in range(nurse_num):
                    for j in range(self.cur_NN-1, total_day_num):
                        for k in range(shift_num):
                            remaining_schedule[i*LP_relax_file.day_shift_num+(j-self.cur_NN+1)*shift_num+k] = train_future_x[num][i*total_day_num*shift_num+j*shift_num+k]

                # Negated demand per (day, shift) slot.  Slots with
                # i % shift_num == 3 take no demand value and stay at zero.
                # NOTE(review): 3 looks like shift_num - 1 for shift_num == 4 -- confirm.
                pre_h1 = np.zeros(LP_relax_file.day_shift_num)
                real_h1 = np.zeros(LP_relax_file.day_shift_num)
                cnt = 0
                for i in range(LP_relax_file.day_shift_num):
                    if i % shift_num != 3:
                        pre_h1[i] = -pre_patient[cnt]
                        real_h1[i] = -real_patient[cnt]
                        cnt = cnt + 1

                c = LP_relax_file.gen_obj_latter_days(self.cur_NN, self.cost, self.penalty)
                A, b, G, h = LP_relax_file.gen_constraints_latter_days(self.cur_NN, remaining_schedule, train_has_rested[num], real_h1, pre_h1, self.serve_patient_num)
                try:
                    x_sol, sigma_sol, incur_penalty = LP_relax_file.get_updated_plan_for_each_day(self.cur_NN, c, A, b, G, h, self.penalty)
                except Exception:
                    # Keep the constraint data for debugging, then surface
                    # the real error.  Continuing would only fail later on
                    # the unbound x_sol / sigma_sol / incur_penalty.
                    np.savetxt('G.txt', G, fmt="%.2f")
                    np.savetxt('h.txt', h, fmt="%.2f")
                    raise

                train_curr_penalty[num] += incur_penalty
                for i in range(nurse_num):
                    for j in range(shift_num-1):
                        train_curr_cost[num] += self.cost[i] * x_sol[i*LP_relax_file.day_shift_num+j]
                for i in range(shift_num-1):
                    train_curr_cost[num] += extra_payment * sigma_sol[i]

                # Write the re-optimised tail back into the full schedule.
                # Uses this instance's own stage index rather than the
                # module-level loop variable of the same name.
                for i in range(nurse_num):
                    for j in range(self.cur_NN-1, total_day_num):
                        for k in range(shift_num):
                            full_index = i*total_day_num*shift_num+j*shift_num+k
                            curr_index = i*LP_relax_file.day_shift_num+(j-self.cur_NN+1)*shift_num+k
                            train_future_x[num][full_index] = x_sol[curr_index]

                # Entry shift_num-1 of a nurse's slice is counted as a rest.
                for i in range(nurse_num):
                    if x_sol[i*LP_relax_file.day_shift_num+shift_num-1] == 1:
                        train_has_rested[num][i] += 1

                future_plan[num] = x_sol
            num = num + 1

        return future_plan




# Run-level constants.
stopCriterion = 200
testTime = 30
print("*** SCD ****")

# One slot per stage network (indices 0 .. total_day_num).
_stage_count = total_day_num + 1
# recordBest starts at a large sentinel value in every slot; recordNow is
# filled in by Intopt.fit.
recordBest = np.ones(_stage_count) * 1000000
recordNow = np.zeros(_stage_count)
# Timing accumulators per stage network and per outer iteration.
train_each_NN_time = np.zeros(_stage_count)
each_iter_time = np.zeros(iteration_num)
print("nurse_num: ", nurse_num, "day_num: ", day_num, "penalty_for_reschedule: ", penaltyTerm, "extra_serve_patient_num: ", extra_serve_patient_num, "extra_payment: ", extra_payment)
    
    
for testmark in range(startmark, endmark):
    print("-------------------------------------------------------------")
    print("Simulation ", testmark)
    cost = np.loadtxt(os.path.join(dataset_path, 'day_num=' + str(total_day_num) + '/payment/payment(' + str(testmark) + ').txt'))
    serve_patient_num = np.loadtxt(os.path.join(dataset_path, 'day_num=' + str(total_day_num) + '/serve_patient_num/serve_patient_num(' + str(testmark) + ').txt'))

    trainData = np.loadtxt(os.path.join(dataset_path, 'day_num=' + str(total_day_num) + '/train/train(' + str(testmark) + ').txt'))
    testData = np.loadtxt(os.path.join(dataset_path, 'day_num=' + str(total_day_num) + '/test/test(' + str(testmark) + ').txt'))
    
    iter_loss = np.zeros(iteration_num)
    for iter_cnt in range(iteration_num):
        cur_NN_start = 0
        for cur_NN in range(cur_NN_start, total_day_num+1):
            if cur_NN == 4:
                warm_start_val = 30000
            elif cur_NN == 5:
                warm_start_val = 40000
            elif cur_NN == 6:
                warm_start_val = 50000
            elif cur_NN == 7:
                warm_start_val = 60000
            start_time = time.time()
            cur_day_num = total_day_num
            if cur_NN == 0:
                LP_relax_file.reset_day_num()
            else:
                cur_day_num = total_day_num - cur_NN + 1
                LP_relax_file.change_day_num(cur_day_num)
            print(cur_NN)

            # Constant penalty vectors: one entry per x variable per instance.
            penalty_train = np.full(LP_relax_file.x_num*train_set_num, penaltyTerm)
            penalty_test = np.full(LP_relax_file.x_num*test_set_num, penaltyTerm)
            if cur_NN == 0:
                # Network 0 consumes every row as-is: columns 1..featureNum are
                # the features and column featureNum+1 is the target.
                x_train, y_train = trainData[:, 1:featureNum+1], trainData[:, featureNum+1]
                x_test, y_test = testData[:, 1:featureNum+1], testData[:, featureNum+1]
            else:
                # Later networks keep only the tail of each instance: the first
                # (cur_NN-1)*(shift_num-1) rows of every instance are dropped
                # and the remainder is packed contiguously.
                rows_per_set = total_day_num*(shift_num-1)
                first_row = (cur_NN-1)*(shift_num-1)
                kept_rows = cur_day_num*(shift_num-1)
                gathered = []
                for data, set_count in ((trainData, train_set_num), (testData, test_set_num)):
                    feats = np.zeros((LP_relax_file.day_work_shift_num*set_count, featureNum))
                    targets = np.zeros(LP_relax_file.day_work_shift_num*set_count)
                    for s in range(set_count):
                        src = data[s*rows_per_set+first_row:(s+1)*rows_per_set]
                        dst = slice(s*kept_rows, s*kept_rows + (rows_per_set-first_row))
                        feats[dst] = src[:, 1:featureNum+1]
                        targets[dst] = src[:, featureNum+1]
                    gathered.append((feats, targets))
                (x_train, y_train), (x_test, y_test) = gathered

            # Float32 tensors for the network.
            feature_train = torch.from_numpy(x_train).float()
            value_train = torch.from_numpy(y_train).float()
            feature_test = torch.from_numpy(x_test).float()
            value_test = torch.from_numpy(y_test).float()

            if cur_NN > 0:
                # Reload the schedule, cost and penalty files written for the
                # preceding network (cur_NN-1) of this iteration and simulation.
                train_future_x = np.loadtxt(os.path.join(default_path, 'sequential_prev_schedule/sequential_prev_schedule_iter' + str(iter_cnt) + '_NN' + str(cur_NN-1) + '(' + str(testmark) + ').txt'))
                train_curr_cost = np.loadtxt(os.path.join(default_path, 'sequential_prev_cost/sequential_prev_cost_iter' + str(iter_cnt) + '_NN' + str(cur_NN-1) + '(' + str(testmark) + ').txt'))
                train_curr_penalty = np.loadtxt(os.path.join(default_path, 'sequential_prev_penalty/sequential_prev_penalty_iter' + str(iter_cnt) + '_NN' + str(cur_NN-1) + '(' + str(testmark) + ').txt'))

                # Rebuild the has_rested table: for every training instance and
                # nurse, count the days among the first cur_NN-1 whose last
                # shift slot is set to 1 in the reloaded schedule.
                if cur_NN > 1:
                    train_has_rested = train_has_rested * 0
                    last_shift = shift_num - 1
                    for train_cnt in range(train_set_num):
                        schedule_row = train_future_x[train_cnt]
                        for i in range(nurse_num):
                            nurse_offset = i*total_day_num*shift_num
                            for j in range(cur_NN-1):
                                if schedule_row[nurse_offset + j*shift_num + last_shift] == 1:
                                    train_has_rested[train_cnt][i] += 1

            # Hyper-parameters handed to Intopt for this network.
            damping = 1e-7
            thr = 1e-7
            lr = 1e-3
            bestTrainCorrReg = float("inf")
            # Upper bound on the number of from-scratch training attempts.
            max_retrain_time = 10
            # Checkpoint file for this (extra_payment, network) pair.
            model_path = 'MS_SMS_ep' + str(extra_payment) + 'NN' + str(cur_NN) + '_model.pkl'
            start_time = time.time()
            stop_epoch = 0
            # Train a fresh Intopt model until stop_epoch reaches
            # warm_start_epoch_criterion - 1 or the attempts run out.
            # NOTE(review): stop_epoch is not assigned inside this loop in the
            # visible code; presumably Intopt.fit updates it -- confirm.
            while stop_epoch < warm_start_epoch_criterion - 1 and max_retrain_time > 0:
                max_retrain_time = max_retrain_time - 1
                # Intopt(cost, penalty, serve_patient_num, ..., n_features, batch_size, cur_NN, ...)
                clf = Intopt(cost, penalty_train, serve_patient_num, damping=damping, lr=lr, n_features=featureNum, batch_size=LP_relax_file.day_work_shift_num, cur_NN=cur_NN, thr=thr, epochs=stop_epoch_criterion)
                clf.fit(feature_train, value_train)
                # Finalise on exactly the pass that ends the loop (the negation
                # of the while condition). Testing against a different
                # threshold here would let the loop exit with no end_time, no
                # future plan and no checkpoint when stop_epoch lands on
                # warm_start_epoch_criterion - 1.
                if stop_epoch >= warm_start_epoch_criterion - 1 or max_retrain_time <= 0:
                    end_time = time.time()
                    cur_decision = clf.make_future_plan(feature_train, value_train)
                    # Keep the checkpoint when this run beats the stored best,
                    # and unconditionally for networks > 0 in iteration 1.
                    if recordNow[cur_NN] < recordBest[cur_NN] or (cur_NN > 0 and iter_cnt == 1):
                        recordBest[cur_NN] = recordNow[cur_NN]
                        torch.save(clf.model.state_dict(), model_path)

            # Evaluate with the best checkpoint on disk, which may come from an
            # earlier iteration rather than the model just trained.
            clfBest = Intopt(cost, penalty_train, serve_patient_num, damping=damping, lr=lr, n_features=featureNum, batch_size=LP_relax_file.day_work_shift_num, cur_NN=cur_NN, thr=thr, epochs=stop_epoch_criterion)
            clfBest.model.load_state_dict(torch.load(model_path))
            print("Simulation " + str(testmark) + " Training NN " + str(cur_NN) + " time: ", end_time - start_time)
            train_each_NN_time[cur_NN] = end_time - start_time


            # Run the best model on the test features. In this block `value`
            # is only used for its size, which fixes the number of rows in the
            # prediction table.
            value = clfBest.model(feature_test).squeeze()
            value = value.detach().numpy()
            predValue = np.zeros((value.size, 3))

            # Predictions that are actually written to disk come from get_pred_val.
            pred_val = clfBest.get_pred_val(feature_test, value_test)

            # Table columns: 0 = row index divided by day_work_shift_num
            # (truncated), 1 = true target, 2 = predicted value.
            # NOTE(review): column 0 uses the module-level day_work_shift_num,
            # while the arrays above are sized with
            # LP_relax_file.day_work_shift_num -- confirm the two agree after
            # the horizon has been changed for cur_NN > 0.
            for i in range(value.size):
                predValue[i][0] = int(i/day_work_shift_num)
                predValue[i][1] = value_test[i]
                predValue[i][2] = pred_val[i]
            np.savetxt(os.path.join(default_path, 'sequential_T_NN/sequential_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + ', warm_start=' + str(warm_start_epoch_criterion) + '(' + str(testmark) + ').txt'), predValue, fmt="%.2f")

            # Store the current training cost, penalty and schedule under this
            # network's index so they can be read back from disk later.
            np.savetxt(os.path.join(default_path, 'sequential_prev_cost/sequential_prev_cost_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testmark) + ').txt'), train_curr_cost, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'sequential_prev_penalty/sequential_prev_penalty_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testmark) + ').txt'), train_curr_penalty, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'sequential_prev_schedule/sequential_prev_schedule_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testmark) + ').txt'), train_future_x, fmt="%.0f")

#            cur_NN = total_day_num
            # Once the network with index total_day_num has been handled,
            # evaluate the whole iteration on the test set and reset the
            # accumulated train/test state.
            if cur_NN == total_day_num:
                # The iteration time is recorded as the maximum (not the sum)
                # of the per-network training times.
                each_iter_time[iter_cnt] = np.max(train_each_NN_time)
                for NN_cnt in range(total_day_num+1):
                    # Same horizon selection as during training: network 0 uses
                    # the reset horizon, network n > 0 gets total_day_num - n + 1 days.
                    cur_day_num = total_day_num
                    if NN_cnt == 0:
                        LP_relax_file.reset_day_num()
                    else:
                        cur_day_num = total_day_num - NN_cnt + 1
                        LP_relax_file.change_day_num(cur_day_num)
                        
                    # get true para and pred para
                    # (prediction table of this iteration for network NN_cnt;
                    # column 1 holds the true values, column 2 the predictions)
#                    if NN_cnt > 0:
                    test_patient_full = np.loadtxt(os.path.join(default_path, 'sequential_T_NN/sequential_iter' + str(iter_cnt) + '_NN' + str(NN_cnt) + ', warm_start=' + str(warm_start_epoch_criterion) + '(' + str(testmark) + ').txt'))
#                    else:
#                        test_patient_full = np.loadtxt(os.path.join(default_path, 'sequential_T_NN/sequential_iter0_NN0, warm_start=' + str(warm_start_epoch_criterion) + 'train(' + str(testmark) + ').txt'))
                    
                #            print(pred_price_full.shape, pred_weight_full.shape)
                    true_patient_full = test_patient_full[:, 1]
                    pred_patient_full = test_patient_full[:, 2]

                    # Compute the NN_cnt plans
                    # Each test instance owns LP_relax_file.day_work_shift_num
                    # consecutive rows of the table; copy them out and hand
                    # them to make_next_plan together with a constant penalty.
                    for test_num in range(test_set_num):
                        cnt = test_num * LP_relax_file.day_work_shift_num
                        real_patient = np.zeros(LP_relax_file.day_work_shift_num)
                        pre_patient = np.zeros(LP_relax_file.day_work_shift_num)
                        penalty = np.full(LP_relax_file.x_num, penaltyTerm)
                        for i in range(LP_relax_file.day_work_shift_num):
                            real_patient[i] = true_patient_full[cnt]
                            pre_patient[i] = pred_patient_full[cnt]
                            cnt = cnt + 1
                        make_next_plan(test_num, NN_cnt, real_patient, pre_patient, cost, serve_patient_num, penalty)


                # Achieved objective per test instance = cost + penalty.
                # Presumably both arrays are filled in by make_next_plan -- TODO confirm.
                test_obj = test_curr_cost + test_curr_penalty
                #            np.savetxt('test_curr_cost.txt', test_curr_cost, fmt="%.2f")
                #            np.savetxt('test_curr_penalty.txt', test_curr_penalty, fmt="%.2f")
                LP_relax_file.reset_day_num()

                # The reference objective is computed from the true values
                # (column 1) of the iteration-0 / network-0 table.
                test_patient_full = np.loadtxt(os.path.join(default_path, 'sequential_T_NN/sequential_iter0_NN0, warm_start=' + str(warm_start_epoch_criterion) + '(' + str(testmark) + ').txt'))
                #            print(pred_price_full.shape, pred_weight_full.shape)
                true_patient_full = test_patient_full[:, 1]
                c = LP_relax_file.gen_obj(0, cost)
                A,b,G,h2,h3,h4,h5 = LP_relax_file.gen_matrix(nurse_num,day_num,shift_num,serve_patient_num,decision_num,day_shift_num)
                true_obj = LP_relax_file.actual_obj(c, A, b, G, true_patient_full, h2, h3, h4, h5, test_set_num)
                # Regret: achieved objective minus the objective from actual_obj.
                PReg = test_obj - true_obj

                print("Simulation ", testmark, " Iteration ", iter_cnt, end=" ")
                print("Test: TOV: ", np.sum(true_obj)/test_set_num, "EOV: ", np.sum(test_obj)/test_set_num, "PReg: ", np.sum(PReg)/test_set_num, end=" ")
#                print("Train: TOV: ", np.sum(train_TOV)/train_set_num, "EOV: ", np.sum(train_EOV)/train_set_num, "PReg: ", np.sum(PReg)/train_set_num, end=" ")
                print("Total training time: ", np.sum(each_iter_time))
                # Mean test regret of this iteration; read by the early-stop checks.
                iter_loss[iter_cnt] = np.sum(PReg)/test_set_num
    #            print("Training time: ", end_time - start_time)

                # reset
                # (multiplying by 0 rebinds each name to a zeroed array of the same shape)
                train_future_x = train_future_x * 0
                train_has_rested = train_has_rested * 0
                train_curr_cost = train_curr_cost * 0
                train_curr_penalty = train_curr_penalty * 0
                
                test_future_x = test_future_x * 0
                test_has_rested = test_has_rested * 0
                test_curr_cost = test_curr_cost * 0
                test_curr_penalty = test_curr_penalty * 0


#        early_quit = False
        if iter_cnt > 1 and abs(iter_loss[iter_cnt] - iter_loss[iter_cnt-1]) < 0.1:
#            early_quit = True
            break
        if iter_cnt > 1 and iter_loss[iter_cnt] >= iter_loss[iter_cnt-1] and iter_loss[iter_cnt] >= iter_loss[iter_cnt-2]:
#            early_quit = True
            break
    # Print the recordNow and recordBest tables once the iteration loop has ended.
    print(recordNow, recordBest)
