import os
import sys
import linear_relax as LP_relax_file
import ip_model_whole as ip_model_whole_file
from ip_model_whole import IPOfunc
import numpy as np
import random
import pandas as pd
import math, time
import itertools
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import datetime
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import gurobipy as gp
import logging
import copy
from collections import defaultdict
import joblib
import gurobipy as gp
from gurobipy import GRB

# Problem dimensions are owned by linear_relax; mirror them here for convenience.
# NOTE(review): these are snapshots taken at import time.  After
# LP_relax_file.change_month_num(...) the module attribute presumably changes
# while ``month_num`` below does not -- code running after that call should read
# LP_relax_file.month_num directly.
total_month_num = LP_relax_file.total_month_num
month_num = LP_relax_file.month_num
x_num = LP_relax_file.x_num
y_num = LP_relax_file.y_num
var_num = LP_relax_file.var_num

featureNum = 4096                 # input width of every network (columns of the feature files)
train_set_size = 70               # number of training instances per simulation
test_set_size = 30                # number of test instances per simulation
target_num = 1                    # network outputs per row (one demand value)
warm_start_stop_criterion = 10    # epochs of plain L1 regression before the decision-focused stage
stop_epoch_criterion = 20         # maximum epochs per Intopt.fit call
log_regularizer = 1e-8            # not used in this file
warm_start_value = 400            # L1 warm-start loss above which a fit is abandoned and retried
iteration_num = 1                 # outer iterations over the chain of networks

if len(sys.argv) < 4:
    sys.exit("usage: " + sys.argv[0] + " <small_or_large: 0|1> <startmark> <endmark>")
small_or_large = int(sys.argv[1])
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])

# Sub-directory holding results for each price scale selected by argv[1].
_PRICE_SCALE_DIRS = {0: 'small', 1: 'large'}
if small_or_large not in _PRICE_SCALE_DIRS:
    # Without this check store_file_path would be unbound and the script would
    # die later with an unrelated NameError.
    sys.exit("small_or_large (argv[1]) must be 0 (small price) or 1 (large price), got " + str(small_or_large))

# Data lives next to the working directory: <parent of cwd>/data/month_num=<N>/...
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
default_path = os.path.join(dataset_path, 'data/month_num=' + str(month_num) + '/')
store_file_path = os.path.join(default_path, _PRICE_SCALE_DIRS[small_or_large] + '/')

LP_relax_file.mkdir(store_file_path, 'parallel_T_NN (warm_start=' + str(warm_start_stop_criterion) + ')')
LP_relax_file.mkdir(store_file_path, 'prev_profit')
LP_relax_file.mkdir(store_file_path, 'prev_stocking')

# True objective value of every training instance; recomputed per simulation
# by the driver loop and read by Intopt.fit.
train_TOV_np = np.zeros(train_set_size)

# Running profit / stock level of each training instance after the months
# already planned (updated in place by Intopt.make_future_plan).
train_curr_profit = np.zeros(train_set_size)
train_curr_stocking = np.zeros(train_set_size)

# Same bookkeeping for the test instances (updated by make_next_plan).
test_curr_profit = np.zeros(test_set_size)
test_curr_stocking = np.zeros(test_set_size)
    
def demand_init(m):
    """Initialise the parameters of a single layer in place.

    Intended for ``model.apply(demand_init)``; layers of other types are left
    untouched.

    * ``nn.Linear``: Xavier-normal weights, zero bias.
    * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU gain).
    * ``nn.BatchNorm2d``: weight 1, bias 0.

    Layers created without a bias (``bias=False``) or without affine
    parameters (``affine=False``) are handled by skipping the missing tensor.
    """
    if isinstance(m, nn.Linear):
        # The trainable tensor of a layer is ``weight``; there is no ``demand`` attribute.
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        
def make_fc(num_layers, num_features, num_targets=target_num,
            activation_fn=nn.ReLU, intermediate_size=512, regularizers=True):
    """Build a fully connected ``nn.Sequential`` network.

    The network has one input layer, ``num_layers - 2`` hidden layers of
    width ``intermediate_size`` (none if ``num_layers <= 2``) and one output
    layer.  Every ``nn.Linear`` -- including the output layer -- is followed by
    a fresh ``activation_fn()`` instance.

    ``regularizers`` is accepted for interface compatibility and is ignored.
    """
    hidden_count = max(num_layers - 2, 0)

    stack = [nn.Linear(num_features, intermediate_size), activation_fn()]
    for _ in range(hidden_count):
        stack += [nn.Linear(intermediate_size, intermediate_size), activation_fn()]
    stack += [nn.Linear(intermediate_size, num_targets), activation_fn()]

    return nn.Sequential(*stack)


def make_next_plan(test_num, cur_NN, true_demand, pred_demand, price, cost):
    """Advance test instance ``test_num`` by one month of the rolling plan.

    For ``cur_NN == 0`` the initial plan from ``LP_relax_file.get_init_plan``
    fixes the first month: its profit is ``-cost[0] * x[0]`` and its stock is
    ``x[0]``.  For later months the remaining horizon is re-planned with the
    current month's demand set to ``true_demand[0]`` and the other months
    taken from ``pred_demand``.  The first-month decision is then applied,
    adding ``price * y - cost * x`` to the profit and ``x - y`` to the stock.

    Results are written into the module-level ``test_curr_profit`` /
    ``test_curr_stocking`` arrays, rounded to two decimals.
    """
    global test_curr_profit
    global test_curr_stocking

    if cur_NN == 0:
        first_x, _ = LP_relax_file.get_init_plan(price, cost, pred_demand, true_demand)
        test_curr_profit[test_num] = -cost[0] * first_x[0]
        test_curr_stocking[test_num] = first_x[0]
    else:
        horizon = LP_relax_file.month_num
        # The current month's demand is now known; the rest stays predicted.
        planning_demand = np.zeros(horizon)
        planning_demand[0] = true_demand[0]
        for month in range(1, horizon):
            planning_demand[month] = pred_demand[month]

        constraint_G, constraint_h = LP_relax_file.gen_constraints_latter_days(
            cur_NN, planning_demand, test_curr_stocking[test_num])
        objective_c = LP_relax_file.gen_obj_latter_days(cur_NN, price, cost)
        plan_x, plan_y = LP_relax_file.get_updated_plan_for_each_day(
            cur_NN, objective_c, constraint_G, constraint_h)

        bought, sold = plan_x[0], plan_y[0]
        test_curr_profit[test_num] += price[cur_NN] * sold - cost[cur_NN] * bought
        test_curr_stocking[test_num] += bought - sold

    test_curr_profit[test_num] = round(test_curr_profit[test_num], 2)
    test_curr_stocking[test_num] = round(test_curr_stocking[test_num], 2)



class MyCustomDataset():
    """Minimal map-style dataset pairing ``feature[idx]`` with ``value[idx]``.

    Usable with ``torch.utils.data.DataLoader``.  The length is taken from
    ``value``; ``feature`` is indexed with the same positions.
    """

    def __init__(self, feature, value):
        self.feature, self.value = feature, value

    def __len__(self):
        return len(self.value)

    def __getitem__(self, idx):
        sample = (self.feature[idx], self.value[idx])
        return sample


class Intopt:
    """Trainer for the demand-prediction network of month ``cur_NN``.

    ``fit`` runs two stages:

    * warm start (epochs ``< warm_start_stop_criterion``): L1 regression of the
      network output onto the demands;
    * decision-focused stage: predictions are fed through ``IPOfunc`` and the
      negated profit of the resulting plan is minimised.  For ``cur_NN > 0``
      that profit includes the already-realised ``train_curr_profit``.

    Besides its own attributes the class reads and writes module-level globals
    shared with the driver loop of this script: ``train_curr_profit``,
    ``train_curr_stocking``, ``train_TOV_np``, ``recordNow``, ``stop_epoch``
    and ``warm_start_value``.
    """

    def __init__(self, price, cost, n_features, batch_size, cur_NN, num_layers=5, smoothing=False, thr=0.1, max_iter=None, method=1, mu0=None, damping=0.5, target_size=target_num, epochs=8, optimizer=optim.Adam, **hyperparams):
        """Store the problem data and build the network and its optimiser.

        Args:
            price, cost: per-month NumPy arrays (converted with
                ``torch.from_numpy`` during training).
            n_features: input width of the network.
            batch_size: rows per DataLoader batch.  The driver passes
                ``LP_relax_file.month_num``, so one batch is one instance.
            cur_NN: month index of this network.
            num_layers: number of linear layers passed to ``make_fc``.
            smoothing, thr, max_iter, damping: forwarded to ``IPOfunc``.
            method, mu0: stored, not used in this class.
            target_size: stored.  ``make_fc`` is called with its own default.
            epochs: maximum number of epochs run by ``fit``.
            optimizer: optimiser class, instantiated with ``**hyperparams``.
        """
        self.price = price
        self.cost = cost
        self.target_size = target_size
        self.n_features = n_features
        self.damping = damping
        self.num_layers = num_layers
        self.cur_NN = cur_NN

        self.smoothing = smoothing
        self.thr = thr
        self.max_iter = max_iter
        self.method = method
        self.mu0 = mu0

        self.batch_size = batch_size
        self.hyperparams = hyperparams
        self.epochs = epochs

        self.model = make_fc(num_layers=self.num_layers, num_features=n_features)
        # Keep the class so the optimiser can be rebuilt if the model is re-initialised.
        self._optimizer_cls = optimizer
        self.optimizer = optimizer(self.model.parameters(), **hyperparams)

    def _reset_model(self):
        """Re-initialise the network and bind a fresh optimiser to it.

        The optimiser must be rebuilt.  The old one holds references to the
        discarded parameters, so stepping it would never update the new
        network.  The learning rate currently in effect (possibly lowered for
        the decision-focused stage) is carried over.
        """
        current_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.model = make_fc(num_layers=self.num_layers, num_features=self.n_features)
        self.optimizer = self._optimizer_cls(self.model.parameters(), **self.hyperparams)
        for group, lr in zip(self.optimizer.param_groups, current_lrs):
            group['lr'] = lr

    @staticmethod
    def _batch_demands(batch_value, op):
        """Copy the first ``LP_relax_file.month_num`` true and predicted demands into float64 arrays."""
        month_count = LP_relax_file.month_num
        true_demand = np.zeros(month_count)
        pred_demand = np.zeros(month_count)
        for i in range(month_count):
            true_demand[i] = batch_value[i]
            pred_demand[i] = op[i]
        return true_demand, pred_demand

    def _corrected_objectives(self, dataset):
        """Return the corrected objective of every instance (one batch each) in *dataset*.

        Month 0 uses ``LP_relax_file.correction_single_obj``.  Later months use
        ``correction_single_for_latter_days`` with the instance's current
        profit and stock from the ``train_curr_*`` globals.
        """
        loader = data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        corr_obj_list = []
        for num, (batch_feature, batch_value) in enumerate(loader):
            with torch.no_grad():
                op = self.model(batch_feature).squeeze()
            true_demand, pred_demand = self._batch_demands(batch_value, op)

            if self.cur_NN == 0:
                corrrlst = LP_relax_file.correction_single_obj(self.price, self.cost, pred_demand, true_demand)
            else:
                corrrlst = LP_relax_file.correction_single_for_latter_days(self.cur_NN, self.price, self.cost, pred_demand, true_demand, train_curr_profit[num], train_curr_stocking[num])
            corr_obj_list.append(corrrlst)
        return np.array(corr_obj_list)

    def fit(self, feature, value):
        """Train ``self.model`` on ``(feature, value)``.

        Side effects on module globals:

        * ``stop_epoch`` is set to the last epoch run.  The driver retrains
          while it is below ``warm_start_stop_criterion``.
        * ``recordNow[self.cur_NN]`` receives the mean corrected objective
          after each decision-focused epoch.
        """
        global stop_epoch
        global warm_start_value

        logging.info("Intopt")
        train_df = MyCustomDataset(feature, value)

        criterion = nn.L1Loss(reduction='mean')
        grad_list = np.zeros(self.epochs)
        IP_grad_list = np.full(self.epochs, float("inf"))
        # Epoch whose warm-start loss decides whether this initialisation is hopeless.
        early_check_epoch = 6

        for e in range(self.epochs):
            total_loss = 0
            if e < warm_start_stop_criterion:
                # ---- Stage 1: plain L1 regression on the demands. ----
                train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
                for batch_feature, batch_value in train_dl:
                    self.optimizer.zero_grad()
                    op = self.model(batch_feature).squeeze()
                    loss = criterion(op, batch_value)
                    total_loss += loss.item()
                    loss.backward()
                    self.optimizer.step()

                grad_list[e] = total_loss
                stop_epoch = e
                if e < warm_start_stop_criterion - 1:
                    print("{} ->".format(total_loss), end=" ")
                else:
                    print("{} ->".format(total_loss))

                # At the end of a successful warm start, report the decision quality.
                if e == warm_start_stop_criterion - 1 and grad_list[e] < warm_start_value:
                    self.model.eval()
                    corr_obj_np = self._corrected_objectives(train_df)
                    IP_grad_list[e] = np.mean(corr_obj_np)
                    true_obj_np = train_TOV_np
                    print("TOV: ", np.mean(true_obj_np), "EOV: ", np.mean(corr_obj_np), "warm_start_PReg: ", np.mean(true_obj_np - corr_obj_np))

            else:
                # ---- Stage 2: decision-focused training through IPOfunc. ----
                if e == warm_start_stop_criterion:
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = 1e-7
                train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
                price_torch = torch.from_numpy(self.price).float()
                cost_torch = torch.from_numpy(self.cost).float()

                num_batches = 0
                for batch_cnt, (batch_feature, batch_value) in enumerate(train_dl):
                    self.optimizer.zero_grad()
                    op = self.model(batch_feature).squeeze()
                    # IPOfunc needs finite, non-negative demands; start over from a fresh network otherwise.
                    while torch.min(op) < 0 or torch.isnan(op).any() or torch.isinf(op).any():
                        self._reset_model()
                        op = self.model(batch_feature).squeeze()

                    sol_cur = IPOfunc(cur_NN=self.cur_NN, price=self.price, cost=self.cost, true_demand=batch_value,
                                      curr_profit=train_curr_profit[batch_cnt], curr_stocking=train_curr_stocking[batch_cnt],
                                      max_iter=self.max_iter, thr=self.thr, damping=self.damping,
                                      smoothing=self.smoothing)(op)

                    # Solution vector layout: first x_num entries are x, the rest are y.
                    x_sol_cur = sol_cur[:LP_relax_file.x_num]
                    y_sol_cur = sol_cur[LP_relax_file.x_num:]

                    if self.cur_NN == 0:
                        newLoss = - ((price_torch * y_sol_cur).sum() - (cost_torch * x_sol_cur).sum())
                    else:
                        newLoss = - (train_curr_profit[batch_cnt] + (price_torch[self.cur_NN:] * y_sol_cur).sum() - (cost_torch[self.cur_NN:] * x_sol_cur).sum())

                    total_loss += newLoss.item()
                    newLoss.backward()
                    self.optimizer.step()
                    num_batches += 1

                grad_list[e] = total_loss / num_batches
                stop_epoch = e

                corr_obj_np = self._corrected_objectives(train_df)
                IP_grad_list[e] = np.mean(corr_obj_np)
                recordNow[self.cur_NN] = np.mean(corr_obj_np)

            logging.info("EPOCH Ends")
            if e >= warm_start_stop_criterion:
                print("Epoch{} ::EOV {} ->".format(e, IP_grad_list[e]))

            # Abandon a bad initialisation early.  stop_epoch then stays below
            # warm_start_stop_criterion, so the driver retrains from scratch.
            if self.epochs > early_check_epoch and grad_list[early_check_epoch] > warm_start_value:
                break
            if e >= 1 and abs(grad_list[e] - grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_stop_criterion and abs(IP_grad_list[e] - IP_grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_stop_criterion and abs(IP_grad_list[e]) < abs(IP_grad_list[e-1]):
                break

    def val_loss(self, feature, value):
        """Evaluate month-0 planning quality on ``(feature, value)``.

        Uses ``LP_relax_file.correction_single_obj``, so it is meant for the
        first network.  Prints mean TOV, EOV and regret per instance.

        Returns:
            (regret, predVal): ``abs(EOV - TOV)`` per instance as a NumPy array,
            and a flat tensor with all predictions.
        """
        month_count = LP_relax_file.month_num
        valueTemp = value.numpy()
        instance_num = np.size(valueTemp, 0) / self.batch_size
        true_obj = LP_relax_file.actual_obj(self.price, self.cost, valueTemp, n_instance=int(instance_num))

        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)

        corr_obj_list = []
        predVal = torch.zeros(np.size(valueTemp))
        for num, (batch_feature, batch_value) in enumerate(valid_dl):
            with torch.no_grad():
                op = self.model(batch_feature).squeeze()
            true_demand, pred_demand = self._batch_demands(batch_value, op)
            predVal[num * month_count:(num + 1) * month_count] = op[:month_count]

            corrrlst = LP_relax_file.correction_single_obj(self.price, self.cost, pred_demand, true_demand)
            corr_obj_list.append(corrrlst)

        self.model.train()
        print("TOV: ", sum(true_obj)/instance_num, "EOV: ", sum(corr_obj_list)/instance_num, "PReg: ", sum(abs(true_obj - np.array(corr_obj_list)))/instance_num)
        return abs(np.array(corr_obj_list) - true_obj), predVal

    def get_pred_val(self, feature, value):
        """Return the model's predictions for every row of *feature* as a flat tensor.

        Each batch of ``self.batch_size`` rows contributes its first
        ``LP_relax_file.month_num`` predictions.  The result has
        ``value.numel()`` entries.
        """
        month_count = LP_relax_file.month_num
        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)

        predVal = torch.zeros(np.size(value.numpy()))
        with torch.no_grad():
            for num, (batch_feature, _) in enumerate(valid_dl):
                op = self.model(batch_feature).squeeze()
                predVal[num * month_count:(num + 1) * month_count] = op[:month_count]
        return predVal

    def make_future_plan(self, feature, value):
        """Apply this network's first-month decision to every training instance.

        Updates ``train_curr_profit`` / ``train_curr_stocking`` in place, rounded
        to two decimals.  Month 0 starts from the initial plan; later months
        re-plan the remaining horizon with the current month's true demand and
        predicted later demands, then apply the first decision.

        Returns:
            Flat tensor of all predictions (same layout as ``get_pred_val``).
        """
        global train_curr_profit
        global train_curr_stocking

        month_count = LP_relax_file.month_num
        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)

        predVal = torch.zeros(np.size(value.numpy()))
        for num, (batch_feature, batch_value) in enumerate(valid_dl):
            with torch.no_grad():
                op = self.model(batch_feature).squeeze()
            true_demand, pred_demand = self._batch_demands(batch_value, op)
            predVal[num * month_count:(num + 1) * month_count] = op[:month_count]

            if self.cur_NN == 0:
                init_x, init_y = LP_relax_file.get_init_plan(self.price, self.cost, pred_demand, true_demand)
                train_curr_profit[num] = - self.cost[0] * init_x[0]
                train_curr_stocking[num] = init_x[0]
            else:
                # Current month's demand is observed; later months stay predicted.
                demand = np.zeros(month_count)
                demand[0] = true_demand[0]
                demand[1:] = pred_demand[1:]
                G_t, h_t = LP_relax_file.gen_constraints_latter_days(self.cur_NN, demand, train_curr_stocking[num])
                c_t = LP_relax_file.gen_obj_latter_days(self.cur_NN, self.price, self.cost)
                t_updated_x, t_updated_y = LP_relax_file.get_updated_plan_for_each_day(self.cur_NN, c_t, G_t, h_t)

                # Use this trainer's own cost vector, not the module-level ``cost``.
                new_profit = self.price[self.cur_NN] * t_updated_y[0] - self.cost[self.cur_NN] * t_updated_x[0]
                train_curr_profit[num] += new_profit
                train_curr_stocking[num] += t_updated_x[0] - t_updated_y[0]

            train_curr_profit[num] = round(train_curr_profit[num], 2)
            train_curr_stocking[num] = round(train_curr_stocking[num], 2)

        return predVal


print("*** PCD ****")

simulation_time = 30

# One slot for each trained network (months 0 .. total_month_num - 2).
_network_count = total_month_num - 1
recordBest = np.zeros(_network_count)
recordNow = np.zeros(_network_count)
train_each_NN_time = np.zeros(_network_count)
each_iter_time = np.zeros(iteration_num)

_price_banner = {0: "small price,", 1: "large price,"}.get(small_or_large)
if _price_banner is not None:
    print(_price_banner, end=' ')
print("month_num: ", month_num, "warm_start_stop_criterion: ", warm_start_stop_criterion, "stop_epoch_criterion: ", stop_epoch_criterion)


def _model_file(nn_index):
    """Checkpoint file (relative to the working directory) for network *nn_index*."""
    return 'MS_' + str(small_or_large) + '_' + str(total_month_num) + 'month_NN' + str(nn_index) + '_model.pkl'


def _prediction_file(iter_index, nn_index, sim_index):
    """Path of the (true, predicted) test-demand table written for one network."""
    return os.path.join(store_file_path, 'parallel_T_NN (warm_start=' + str(warm_start_stop_criterion) + ')/parallel_iter' + str(iter_index) + '_NN' + str(nn_index) + '(' + str(sim_index) + ').txt')


def _state_file(kind, iter_index, nn_index, sim_index):
    """Path of a stored per-instance state vector; *kind* is 'profit' or 'stocking'."""
    return os.path.join(store_file_path, 'prev_' + kind + '/prev_' + kind + '_iter' + str(iter_index) + '_NN' + str(nn_index) + '(' + str(sim_index) + ').txt')


def _months_from(full, n_instances, first_month):
    """Keep months ``first_month .. total_month_num-1`` of every instance.

    *full* stores ``total_month_num`` consecutive rows per instance (1-D
    demands or 2-D features).  The result keeps the trailing
    ``total_month_num - first_month`` rows of each instance, in order.
    """
    rows = full[:n_instances * total_month_num]
    per_instance = rows.reshape((n_instances, total_month_num) + rows.shape[1:])
    return per_instance[:, first_month:].reshape((-1,) + rows.shape[1:])


for testi in range(startmark, endmark):
    print("-------------------------------------------------------------")
    print("Simulation ", testi)

    # Each simulation has its own price/cost/data and the checkpoint names do
    # not include ``testi``, so model selection must start afresh here.
    # Otherwise a later simulation could reload an earlier simulation's model.
    # -inf means "nothing recorded yet".
    recordBest = np.full(total_month_num - 1, -np.inf)
    recordNow = np.full(total_month_num - 1, -np.inf)
    saved_models = set()

    cost = np.loadtxt(os.path.join(dataset_path, 'data/month_num='+str(total_month_num)+'/cost/cost(' + str(testi) + ').txt'))
    if small_or_large == 0:
        price = np.loadtxt(os.path.join(dataset_path, 'data/month_num='+str(total_month_num)+'/small_price/price(' + str(testi) + ').txt'))
    elif small_or_large == 1:
        price = np.loadtxt(os.path.join(dataset_path, 'data/month_num='+str(total_month_num)+'/large_price/price(' + str(testi) + ').txt'))

    x_train_full = np.loadtxt(os.path.join(default_path, 'train_features/train_features(' + str(testi) + ').txt'))
    y_train_full = np.loadtxt(os.path.join(default_path, 'train_demands/train_demands(' + str(testi) + ').txt'))
    # Per-instance profit / stock of the true (full-information) plan, one column per month.
    train_TOV_prev_profit = np.loadtxt(os.path.join(store_file_path, 'true_profit/true_profit(' + str(testi) + ').txt'))
    train_TOV_prev_stocking = np.loadtxt(os.path.join(store_file_path, 'true_stocking/true_stocking(' + str(testi) + ').txt'))

    x_test_full = np.loadtxt(os.path.join(default_path, 'test_features/test_features(' + str(testi) + ').txt'))
    y_test_full = np.loadtxt(os.path.join(default_path, 'test_demands/test_demands(' + str(testi) + ').txt'))

    train_TOV_np = LP_relax_file.actual_obj(price, cost, y_train_full, train_set_size)

    iter_loss = np.zeros(iteration_num)
    for iter_cnt in range(iteration_num):
        for cur_NN in range(total_month_num - 1):
            if cur_NN == 0:
                LP_relax_file.reset_month_num()
            else:
                LP_relax_file.change_month_num(total_month_num - cur_NN)
            print(cur_NN)

            # Network cur_NN only sees months cur_NN .. total_month_num-1 of each instance.
            if cur_NN == 0:
                x_train, y_train = x_train_full, y_train_full
                x_test, y_test = x_test_full, y_test_full
            else:
                x_train = _months_from(x_train_full, train_set_size, cur_NN)
                y_train = _months_from(y_train_full, train_set_size, cur_NN)
                x_test = _months_from(x_test_full, test_set_size, cur_NN)
                y_test = _months_from(y_test_full, test_set_size, cur_NN)
            feature_train = torch.from_numpy(x_train).float()
            value_train = torch.from_numpy(y_train).float()
            feature_test = torch.from_numpy(x_test).float()
            value_test = torch.from_numpy(y_test).float()

            if cur_NN > 0:
                if iter_cnt == 0:
                    # Start from the true-plan states after the previous month.
                    # Copy: make_future_plan updates these arrays in place.
                    train_curr_profit = train_TOV_prev_profit[:, cur_NN-1].copy()
                    train_curr_stocking = train_TOV_prev_stocking[:, cur_NN-1].copy()
                else:
                    train_curr_profit = np.loadtxt(_state_file('profit', iter_cnt - 1, cur_NN - 1, testi))
                    train_curr_stocking = np.loadtxt(_state_file('stocking', iter_cnt - 1, cur_NN - 1, testi))

            damping = 1e-2
            thr = 1e-3
            lr = 1e-5
            stop_epoch = 0
            start_time = time.time()
            max_retrain_time = 10
            # Only the network trained in this round may set recordNow[cur_NN].
            recordNow[cur_NN] = -np.inf
            # Retrain from scratch until a warm start succeeds (Intopt.fit sets the
            # global stop_epoch) or the retry budget is exhausted.
            while stop_epoch < warm_start_stop_criterion and max_retrain_time > 0:
                max_retrain_time -= 1
                clf = Intopt(price, cost, batch_size=LP_relax_file.month_num, cur_NN=cur_NN, damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=stop_epoch_criterion)
                clf.fit(feature_train, value_train)
            end_time = time.time()

            # Advance the training instances' states by this network's decision.
            clf.make_future_plan(feature_train, value_train)

            improved = recordNow[cur_NN] > recordBest[cur_NN]
            forced = cur_NN > 0 and iter_cnt == 1
            # Always keep at least one checkpoint per network for this simulation,
            # so the load below never falls back to a missing or stale file.
            if improved or forced or cur_NN not in saved_models:
                recordBest[cur_NN] = recordNow[cur_NN]
                torch.save(clf.model.state_dict(), _model_file(cur_NN))
                saved_models.add(cur_NN)

            clfBest = Intopt(price, cost, batch_size=LP_relax_file.month_num, cur_NN=cur_NN, damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=stop_epoch_criterion)
            clfBest.model.load_state_dict(torch.load(_model_file(cur_NN)))
            print("Simulation " + str(testi) + " Training NN " + str(cur_NN) + " time: ", end_time - start_time)
            train_each_NN_time[cur_NN] = end_time - start_time

            predTestVal = clfBest.get_pred_val(feature_test, value_test).detach().numpy()
            # Two columns per row: true demand, predicted demand.
            predTestDemand = np.column_stack((y_test, predTestVal))
            np.savetxt(_prediction_file(iter_cnt, cur_NN, testi), predTestDemand, fmt="%.2f")

            # Store current states for the next iteration.
            np.savetxt(_state_file('profit', iter_cnt, cur_NN, testi), train_curr_profit, fmt="%.2f")
            np.savetxt(_state_file('stocking', iter_cnt, cur_NN, testi), train_curr_stocking, fmt="%.2f")

            if cur_NN == total_month_num - 2:
                # All networks are trained: roll the plan forward on the test set.
                each_iter_time[iter_cnt] = np.max(train_each_NN_time)
                for NN_cnt in range(total_month_num):
                    if NN_cnt == 0:
                        LP_relax_file.reset_month_num()
                    else:
                        LP_relax_file.change_month_num(total_month_num - NN_cnt)
                    horizon = LP_relax_file.month_num

                    if NN_cnt < total_month_num - 1:
                        test_demand_full = np.loadtxt(_prediction_file(iter_cnt, NN_cnt, testi))
                    else:
                        # The last month has no network of its own.  Take the second
                        # row (the final month) of each instance in the last
                        # network's two-month table.
                        test_demand_full = np.loadtxt(_prediction_file(iter_cnt, NN_cnt - 1, testi))[1:2 * test_set_size:2]
                    true_demand_full = test_demand_full[:, 0]
                    pred_demand_full = test_demand_full[:, 1]

                    for test_num in range(test_set_size):
                        if NN_cnt < total_month_num - 1:
                            start = test_num * horizon
                            true_demand = np.array(true_demand_full[start:start + horizon])
                            pred_demand = np.array(pred_demand_full[start:start + horizon])
                        else:
                            true_demand = np.zeros(horizon)
                            pred_demand = np.zeros(horizon)
                            true_demand[0] = true_demand_full[test_num]
                            pred_demand[0] = pred_demand_full[test_num]
                        make_next_plan(test_num, NN_cnt, true_demand, pred_demand, price, cost)

                test_obj = test_curr_profit
                LP_relax_file.reset_month_num()
                true_obj = LP_relax_file.actual_obj(price, cost, y_test_full, test_set_size)
                PReg = true_obj - test_obj

                print("Simulation ", testi, " Iteration ", iter_cnt, end=" ")
                print("Test: TOV: ", np.sum(true_obj)/test_set_size, "EOV: ", np.sum(test_obj)/test_set_size, "PReg: ", np.sum(PReg)/test_set_size)
                print("Total training time: ", np.sum(each_iter_time))
                iter_loss[iter_cnt] = np.sum(PReg)/test_set_size

                # Reset the running states for the next iteration.
                train_curr_profit = np.zeros_like(train_curr_profit)
                train_curr_stocking = np.zeros_like(train_curr_stocking)
                test_curr_profit = np.zeros_like(test_curr_profit)
                test_curr_stocking = np.zeros_like(test_curr_stocking)

        if iter_cnt > 1 and abs(iter_loss[iter_cnt] - iter_loss[iter_cnt-1]) < 0.1:
            break
        if iter_cnt > 1 and iter_loss[iter_cnt] >= iter_loss[iter_cnt-1] and iter_loss[iter_cnt] >= iter_loss[iter_cnt-2]:
            break

    print(recordNow, recordBest)
