import os
import sys
import linear_relax as LP_relax_file
import ip_model_whole as ip_model_whole_file
from ip_model_whole import IPOfunc
import numpy as np
import random
import pandas as pd
import math, time
import itertools
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import datetime
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import gurobipy as gp
import logging
import copy
from collections import defaultdict
import joblib
import gurobipy as gp
from gurobipy import GRB

# --- Problem dimensions and training hyper-parameters -----------------------
# Dimensions come from the LP-relaxation module.
total_month_num = LP_relax_file.month_num
item_num = LP_relax_file.item_num
featureNum = 4096  # input width of each per-month network (Intopt n_features)
train_set_size = 70  # number of training instances per simulation
test_set_size = 30  # number of test instances per simulation
target_num = 2  # network outputs per row: column 0 = price, column 1 = weight
warm_start_epoch = 15  # epochs of L1 regression before switching to the IPOfunc loss
stop_epoch_criterion = 20  # maximum number of epochs per Intopt.fit call
log_regularizer = 1e-8  # NOTE(review): not referenced in the visible code
warm_start_value = 300  # threshold on the summed warm-start L1 loss (see Intopt.fit)
iteration_num = 1  # number of outer iterations over all per-month networks

# Command line: <capacity> <trans_fee_percent> <startmark> <endmark>.
# Capacity and fee are pushed into LP_relax_file before default_path is built,
# because default_path reads LP_relax_file.capacity.
cap = int(sys.argv[1])
trans_fee_percent = float(sys.argv[2])
LP_relax_file.set_capacity(cap)
LP_relax_file.set_trans_fee_percent(trans_fee_percent)
startmark = int(sys.argv[3])  # first simulation index (inclusive)
endmark = int(sys.argv[4])  # simulation index to stop at (exclusive)

# dataset_path is the parent of the current working directory; default_path is
# a per-configuration folder (LP version, sizes, fee, capacity) under data/.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
default_path = os.path.join(dataset_path, 'data/trans_fee=' + str(trans_fee_percent) +'/v' + str(LP_relax_file.version_num) + '(item_num=' + str(item_num) + ',month_num=' + str(total_month_num) + ',trans_fee=' + str(trans_fee_percent) + ',cap=' + str(LP_relax_file.capacity) + ')/')

# Output sub-folders; presumably LP_relax_file.mkdir creates default_path/<name>
# if missing -- confirm against linear_relax.
LP_relax_file.mkdir(default_path, 'parallel_T_NN')
LP_relax_file.mkdir(default_path, 'prev_cost')
LP_relax_file.mkdir(default_path, 'prev_prof')
LP_relax_file.mkdir(default_path, 'prev_owned_prod')
LP_relax_file.mkdir(default_path, 'prev_y')
LP_relax_file.mkdir(default_path, 'prev_z')

# Per-training-instance decision plans (x/y/z) and accumulated profit/cost.
# Read by Intopt; rebound by the main loop below for cur_NN > 0.
train_future_x = np.zeros((train_set_size, LP_relax_file.x_num))
train_future_y = np.zeros((train_set_size, LP_relax_file.y_num))
train_future_z = np.zeros((train_set_size, LP_relax_file.z_num))
prev_prof = np.zeros(train_set_size)
prev_cost = np.zeros(train_set_size)

# Same bookkeeping for the test set: make_next_plan writes test_future_*
# and reads test_prev_prof / test_prev_cost.
test_future_x = np.zeros((test_set_size, LP_relax_file.x_num))
test_future_y = np.zeros((test_set_size, LP_relax_file.y_num))
test_future_z = np.zeros((test_set_size, LP_relax_file.z_num))
test_prev_prof = np.zeros(test_set_size)
test_prev_cost = np.zeros(test_set_size)
    
def weight_init(m):
    """Initialise the parameters of a single module in place.

    Meant to be used with ``nn.Module.apply``:

    * ``nn.Linear``      -> Xavier-normal weight, zero bias.
    * ``nn.Conv2d``      -> Kaiming-normal weight (``fan_out``, ReLU gain);
                            the bias is left untouched.
    * ``nn.BatchNorm2d`` -> weight set to 1, bias set to 0.

    Modules of any other type are left unchanged.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        return

    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
    else:
        return

    # Linear and BatchNorm2d layers both start from a zero bias.
    nn.init.constant_(m.bias, 0)
        
def make_fc(num_layers, num_features, num_targets=target_num,
            activation_fn = nn.ReLU,intermediate_size=512, regularizers = True):
    """Build a fully connected network as an ``nn.Sequential``.

    The network has ``max(num_layers, 2)`` linear layers: an input layer
    ``num_features -> intermediate_size``, ``num_layers - 2`` hidden layers of
    width ``intermediate_size`` and an output layer
    ``intermediate_size -> num_targets``. Every linear layer, including the
    output one, is followed by a fresh ``activation_fn()`` instance, so with
    the default ``nn.ReLU`` all outputs are non-negative.

    ``regularizers`` is accepted but not used.
    """
    # Layers are created strictly front-to-back so that parameter
    # initialisation consumes the RNG in the same order as before.
    layers = [nn.Linear(num_features, intermediate_size), activation_fn()]
    for _ in range(num_layers - 2):
        layers.extend((nn.Linear(intermediate_size, intermediate_size), activation_fn()))
    layers += [nn.Linear(intermediate_size, num_targets), activation_fn()]
    return nn.Sequential(*layers)


def make_next_plan(test_num, cur_NN, pred_price, pred_weight, true_price, true_weight):
    """Solve the planning LP for one test instance and store its decisions.

    For the first network (``cur_NN == 0``) the initial plan is solved and
    the first ``item_num`` entries of ``x`` are recorded. For later networks
    the LP is re-solved given the previous month's ``x`` decisions and the
    instance's ``test_prev_prof`` / ``test_prev_cost``, and that month's
    ``x``, ``y`` and ``z`` decisions are recorded.

    Args:
        test_num: row of the instance in the ``test_future_*`` arrays.
        cur_NN: index of the current per-month network (0 = first month).
        pred_price, pred_weight: predicted values handed to the LP solver.
        true_price, true_weight: ground-truth values handed to the LP solver.

    Side effects:
        Writes in place into the module-level ``test_future_x``,
        ``test_future_y`` and ``test_future_z`` arrays (item assignment only,
        so no ``global`` rebinding is needed).
    """
    if cur_NN == 0:
        A_0, b_0, G_0, h_0 = LP_relax_file.gen_constraints(0)
        x_sol, y_sol, z_sol = LP_relax_file.get_init_plan(pred_price, pred_weight, true_weight, A_0, b_0, G_0, h_0)
        x_row = test_future_x[test_num]
        for k in range(item_num):
            x_row[k] = x_sol[k]
        return

    prev_start = (cur_NN - 1) * item_num
    cur_start = cur_NN * item_num
    x_prev_sol = test_future_x[test_num][prev_start:cur_start]
    A_0, b_0, G_0, h_0 = LP_relax_file.gen_constraints_latter_months(LP_relax_file.month_num, x_prev_sol)
    x_sol, y_sol, z_sol = LP_relax_file.get_updated_plan_for_each_month(
        LP_relax_file.month_num, test_prev_prof[test_num], test_prev_cost[test_num],
        pred_price, pred_weight, true_price, true_weight, A_0, b_0, G_0, h_0)

    # x is stored at this month's offset; y and z start one month earlier.
    x_row = test_future_x[test_num]
    y_row = test_future_y[test_num]
    z_row = test_future_z[test_num]
    for k in range(item_num):
        x_row[cur_start + k] = x_sol[k]
        y_row[prev_start + k] = y_sol[k]
        z_row[prev_start + k] = z_sol[k]



class MyCustomDataset():
    """Minimal map-style dataset pairing ``feature[i]`` with ``value[i]``.

    Works with ``torch.utils.data.DataLoader`` because it implements
    ``__len__`` and ``__getitem__``; its length is the length of ``value``.
    """

    def __init__(self, feature, value):
        self.feature, self.value = feature, value

    def __len__(self):
        return len(self.value)

    def __getitem__(self, idx):
        x_item = self.feature[idx]
        y_item = self.value[idx]
        return x_item, y_item


class Intopt:
    """Decision-focused trainer for one per-month price/weight predictor.

    ``fit`` first warm-starts the network with an L1 regression loss for
    ``warm_start_epoch`` epochs and then trains it through the optimisation
    layer ``IPOfunc``, using a loss computed from the returned plan and the
    true prices/weights. Evaluation helpers report the mean of
    ``LP_relax_file.correction_single_*`` values (printed as "EOV").

    Every DataLoader batch is treated as one problem instance: batch ``num``
    is paired with row ``num`` of the module-level ``train_future_*`` /
    ``prev_*`` arrays, so ``batch_size`` is expected to equal
    ``LP_relax_file.x_num`` (which is what the calling script passes).
    """

    def __init__(self, n_features, batch_size, cur_NN, num_layers=5, smoothing=False, thr=0.1, max_iter=None, method=1, mu0=None,
                 damping=0.5, target_size=target_num, epochs=8, optimizer=optim.Adam, **hyperparams):
        """Build the network and its optimizer.

        Args:
            n_features: input width of the network.
            batch_size: rows per DataLoader batch (one instance per batch).
            cur_NN: index of the month this network is trained for (0 = first).
            num_layers: number of linear layers passed to ``make_fc``.
            smoothing, thr, max_iter, damping: forwarded to ``IPOfunc``.
            method, mu0: stored but not used by this class.
            target_size: number of network outputs (price, weight).
            epochs: maximum number of training epochs in ``fit``.
            optimizer: optimizer class, instantiated with ``**hyperparams``.
        """
        self.target_size = target_size
        self.n_features = n_features
        self.damping = damping
        self.num_layers = num_layers
        self.cur_NN = cur_NN

        self.smoothing = smoothing
        self.thr = thr
        self.max_iter = max_iter
        self.method = method
        self.mu0 = mu0

        self.batch_size = batch_size
        self.hyperparams = hyperparams
        self.epochs = epochs

        # Keep the optimizer class so the optimizer can be rebuilt whenever
        # the network is re-initialised (see _rebuild_optimizer).
        self.optimizer_cls = optimizer
        self.model = make_fc(num_layers=self.num_layers, num_features=n_features, num_targets=target_size)
        self.optimizer = optimizer(self.model.parameters(), **hyperparams)

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _rebuild_optimizer(self):
        """Create a fresh optimizer for ``self.model``, keeping the current learning rate."""
        current_lr = self.optimizer.param_groups[0]['lr']
        self.optimizer = self.optimizer_cls(self.model.parameters(), **dict(self.hyperparams, lr=current_lr))

    @staticmethod
    def _split_price_weight(op, value):
        """Return float64 arrays (true_price, pred_price, true_weight, pred_weight).

        Column 0 holds prices and column 1 weights. Each array has length
        ``LP_relax_file.x_num`` and is filled from the first ``x_num`` rows;
        the batch is expected to contain exactly that many rows.
        """
        n = LP_relax_file.x_num
        true_price = np.zeros(n)
        pred_price = np.zeros(n)
        true_weight = np.zeros(n)
        pred_weight = np.zeros(n)
        true_price[:] = value[:n, 0].detach().numpy()
        pred_price[:] = op[:n, 0].detach().numpy()
        true_weight[:] = value[:n, 1].detach().numpy()
        pred_weight[:] = op[:n, 1].detach().numpy()
        return true_price, pred_price, true_weight, pred_weight

    def _predict(self, batch_x):
        """Forward pass without building an autograd graph (evaluation only)."""
        with torch.no_grad():
            return self.model(batch_x).squeeze()

    def _mean_eov(self, dataset):
        """Return the mean EOV of the current model over ``dataset``.

        Each batch is one instance; its predictions go to
        ``LP_relax_file.correction_single_obj`` for the first month or to
        ``correction_single_for_latter_month`` (with that instance's stored
        plan and prev_cost/prev_prof) for later months.
        """
        valid_dl = data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        corr_obj_list = []
        for num, (batch_x, batch_y) in enumerate(valid_dl):
            op = self._predict(batch_x)
            true_price, pred_price, true_weight, pred_weight = self._split_price_weight(op, batch_y)
            if self.cur_NN == 0:
                corrrlst = LP_relax_file.correction_single_obj(pred_price, pred_weight, true_price, true_weight)
            else:
                corrrlst = LP_relax_file.correction_single_for_latter_month(self.cur_NN, pred_price, pred_weight, true_price, true_weight, train_future_x[num], train_future_y[num], train_future_z[num], prev_cost[num], prev_prof[num])
            corr_obj_list.append(corrrlst)
        return np.mean(np.array(corr_obj_list))

    def _l1_epoch(self, train_df, criterion):
        """Run one epoch of plain regression on ``train_df``; return the summed batch loss."""
        total_loss = 0
        train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
        for batch_x, batch_y in train_dl:
            self.optimizer.zero_grad()
            op = self.model(batch_x).squeeze()
            loss = criterion(op, batch_y)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return total_loss

    def _ip_loss(self, op, value, num):
        """Return the IPOfunc-based loss of prediction ``op`` for training instance ``num``."""
        x_num = LP_relax_file.x_num
        y_num = LP_relax_file.y_num
        month_num = LP_relax_file.month_num
        true_price = value[:, 0]
        true_weight = value[:, 1]

        sol_cur = IPOfunc(cur_NN=self.cur_NN, true_price=true_price, true_weight=true_weight, fixed_x=train_future_x[num], fixed_y=train_future_y[num], fixed_z=train_future_z[num], prev_prof=prev_prof[num], prev_cost=prev_cost[num], max_iter=self.max_iter, thr=self.thr, damping=self.damping, smoothing=self.smoothing)(op)

        # The solver returns the concatenation [x | y | z].
        x_sol_cur = sol_cur[:x_num]
        y_sol_cur = sol_cur[x_num:x_num + y_num]
        z_sol_cur = sol_cur[x_num + y_num:]

        if self.cur_NN == 0:
            return - (((true_price * x_sol_cur).sum()
                       + (true_weight[item_num*(month_num-1):] * x_sol_cur[item_num*(month_num-1):]).sum()
                       - (true_weight[:item_num] * x_sol_cur[:item_num]).sum()
                       - (true_weight[item_num:] * y_sol_cur).sum()
                       - (trans_fee_percent * true_weight[item_num:] * z_sol_cur).sum()))

        final_prof = prev_prof[num] + (true_price*x_sol_cur).sum() + (true_weight[(month_num-1)*item_num:]*x_sol_cur[(x_num-item_num):]).sum()
        final_cost = prev_cost[num] + trans_fee_percent*(true_weight*z_sol_cur).sum() + (true_weight*y_sol_cur).sum()
        return final_cost - final_prof

    def _ip_epoch(self, train_df):
        """Run one epoch with the IPOfunc loss; return the mean per-batch loss."""
        total_loss = 0
        num = 0
        train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
        for batch_x, batch_y in train_dl:
            self.optimizer.zero_grad()
            op = self.model(batch_x).squeeze()
            while torch.min(op) < 0 or torch.isnan(op).any() or torch.isinf(op).any():
                # Degenerate output: re-initialise the network. The optimizer
                # must be rebuilt as well, otherwise it keeps stepping the
                # parameters of the discarded model and the new one never trains.
                self.model = make_fc(num_layers=self.num_layers, num_features=self.n_features, num_targets=self.target_size)
                self._rebuild_optimizer()
                op = self.model(batch_x).squeeze()

            newLoss = self._ip_loss(op, batch_y, num)
            total_loss += newLoss.item()
            newLoss.backward()
            self.optimizer.step()
            num = num + 1
        return total_loss / num

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def fit(self, feature, value):
        """Train ``self.model`` on ``(feature, value)``.

        Epochs ``< warm_start_epoch`` minimise the L1 loss. If the summed loss
        of the last warm-start epoch is below ``warm_start_value``, the EOV is
        evaluated once. Later epochs use the IPOfunc loss with learning rate
        5e-7 and evaluate the EOV every epoch.

        Training stops early when:
          * the epoch-6 regression loss exceeds ``warm_start_value``;
          * the training loss changes by at most 0.001 between epochs;
          * (IP phase) the EOV changes by at most 0.001, its magnitude
            decreases, or it jumps by 50 or more (then ``stop_epoch`` is set
            to 1 so the caller retrains).

        Side effects: sets the module-level ``stop_epoch`` (last epoch run)
        and ``recordNow[self.cur_NN]`` (latest EOV in the IP phase).
        """
        global stop_epoch
        logging.info("Intopt")
        train_df = MyCustomDataset(feature, value)

        criterion = nn.L1Loss(reduction='mean')
        grad_list = np.zeros(self.epochs)  # per-epoch training loss
        IP_grad_list = np.full(self.epochs, float("inf"))  # per-epoch EOV; inf = not evaluated
        for e in range(self.epochs):
            if e < warm_start_epoch:
                total_loss = self._l1_epoch(train_df, criterion)
                grad_list[e] = total_loss
                stop_epoch = e
                # Warm-start losses share one line, closed on the last warm-start epoch.
                if e < warm_start_epoch - 1:
                    print("{} ->".format(total_loss), end=" ")
                else:
                    print("{} ->".format(total_loss))

                if e == warm_start_epoch - 1 and grad_list[e] < warm_start_value:
                    self.model.eval()
                    eov = self._mean_eov(train_df)
                    IP_grad_list[e] = eov
                    print("EOV: ", eov)
            else:
                if e == warm_start_epoch:
                    # Much smaller learning rate for the decision-focused phase.
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = 5e-7
                grad_list[e] = self._ip_epoch(train_df)
                stop_epoch = e

                eov = self._mean_eov(train_df)
                IP_grad_list[e] = eov
                recordNow[self.cur_NN] = eov

            logging.info("EPOCH Ends")
            if e >= warm_start_epoch:
                print("Epoch{} ::EOV {} ".format(e, IP_grad_list[e]))
            # Guarded so runs with epochs <= 6 do not raise IndexError.
            if self.epochs > 6 and grad_list[6] > warm_start_value:
                break
            if e >= 1 and abs(grad_list[e] - grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_epoch and abs(IP_grad_list[e] - IP_grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_epoch and abs(IP_grad_list[e]) < abs(IP_grad_list[e-1]):
                break
            if e >= warm_start_epoch and abs(IP_grad_list[e] - IP_grad_list[e-1]) >= 50:
                stop_epoch = 1
                break

    def val_loss(self, feature, value):
        """Evaluate the model on a held-out set using the first-month objective.

        Prints the mean true objective (TOV), the mean EOV and the mean
        regret, one instance per batch.

        Returns:
            (abs(EOV - TOV) per instance as a numpy array,
             predictions as a ``(rows, 2)`` tensor without autograd history).
        """
        valueTemp = value.numpy()
        instance_num = np.size(valueTemp, 0) / self.batch_size
        true_obj = LP_relax_file.actual_obj(valueTemp[:, 0], valueTemp[:, 1], n_instance=int(instance_num))

        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)

        n = LP_relax_file.x_num
        predVal = torch.zeros((np.size(valueTemp, 0), 2))
        corr_obj_list = []
        num = 0
        for batch_x, batch_y in valid_dl:
            op = self._predict(batch_x)
            true_price, pred_price, true_weight, pred_weight = self._split_price_weight(op, batch_y)
            predVal[num * n:(num + 1) * n] = op[:n, :2]
            corr_obj_list.append(LP_relax_file.correction_single_obj(pred_price, pred_weight, true_price, true_weight))
            num = num + 1

        print("TOV: ", sum(true_obj)/num, "EOV: ", sum(corr_obj_list)/num, "PReg: ", sum(abs(true_obj) - np.array(corr_obj_list))/num)
        return abs(np.array(corr_obj_list) - true_obj), predVal

    def get_pred_val(self, feature, value):
        """Return the model's (price, weight) predictions as a ``(rows, 2)`` tensor.

        ``value`` is only used for its row count; each batch of ``x_num`` rows
        is written to the matching slice of the result.
        """
        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)

        n = LP_relax_file.x_num
        predVal = torch.zeros((value.shape[0], 2))
        for num, (batch_x, _) in enumerate(valid_dl):
            op = self._predict(batch_x)
            predVal[num * n:(num + 1) * n] = op[:n, :2]
        return predVal

    def make_future_plan(self, feature, value):
        """Solve each training instance's LP with the model's predictions.

        For ``cur_NN == 0`` the initial plan is solved; otherwise the LP is
        re-solved given the previous month's ``x`` and the instance's
        ``prev_prof`` / ``prev_cost``. The decisions are written in place
        into the module-level ``train_future_x/y/z`` arrays.

        Returns:
            ``(instance_num, item_num)`` array holding the first ``item_num``
            entries of each instance's ``x`` solution.
        """
        instance_num = value.shape[0] / self.batch_size
        self.model.eval()
        valid_dl = data_utils.DataLoader(MyCustomDataset(feature, value), batch_size=self.batch_size, shuffle=False)
        future_plan = np.zeros((int(instance_num), item_num))

        for num, (batch_x, batch_y) in enumerate(valid_dl):
            op = self._predict(batch_x)
            true_price, pred_price, true_weight, pred_weight = self._split_price_weight(op, batch_y)

            if self.cur_NN == 0:
                A_0, b_0, G_0, h_0 = LP_relax_file.gen_constraints(0)
                x_sol, y_sol, z_sol = LP_relax_file.get_init_plan(pred_price, pred_weight, true_weight, A_0, b_0, G_0, h_0)
                for i in range(item_num):
                    train_future_x[num][i] = x_sol[i]
            else:
                x_prev_sol = train_future_x[num][(self.cur_NN-1)*item_num:self.cur_NN*item_num]
                A_0, b_0, G_0, h_0 = LP_relax_file.gen_constraints_latter_months(LP_relax_file.month_num, x_prev_sol)
                x_sol, y_sol, z_sol = LP_relax_file.get_updated_plan_for_each_month(LP_relax_file.month_num, prev_prof[num], prev_cost[num], pred_price, pred_weight, true_price, true_weight, A_0, b_0, G_0, h_0)
                # x is stored at this month's offset; y and z start one month earlier.
                for i in range(item_num):
                    train_future_x[num][self.cur_NN*item_num+i] = x_sol[i]
                    train_future_y[num][(self.cur_NN-1)*item_num+i] = y_sol[i]
                    train_future_z[num][(self.cur_NN-1)*item_num+i] = z_sol[i]
            future_plan[num] = x_sol[:item_num]

        return future_plan


print("*** PCD ****")

simulation_time = 30  # NOTE(review): not referenced in the visible code
# Per-network EOV bookkeeping indexed by cur_NN: Intopt.fit writes the latest
# EOV into recordNow, and the main loop copies it into recordBest (and saves a
# checkpoint) when it improves.
recordBest = np.zeros(total_month_num+1)
recordNow = np.zeros(total_month_num+1)
train_each_NN_time = np.zeros(total_month_num)  # wall-clock training time of each per-month network
each_iter_time = np.zeros(iteration_num)  # per outer iteration: max of train_each_NN_time
print("item_num: ", item_num, " month_num: ", LP_relax_file.month_num, " trans_fee_percent: ", trans_fee_percent, " capacity: ", LP_relax_file.capacity, " warm_start_epoch： ", warm_start_epoch)


for testi in range(startmark, endmark):
    print("-------------------------------------------------------------")
    print("Simulation ", testi)
    x_train_temp = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/train_features/train_features(' + str(testi) + ').txt'))
    y_train1 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/rescale_train_prices/rescale_train_prices(' + str(testi) + ').txt'))
    y_train2 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/train_weights/train_weights(' + str(testi) + ').txt'))
    train_TOV_x = np.loadtxt(os.path.join(default_path, 'true_owned_prod/true_owned_prod(' + str(testi) + ').txt'))
    train_TOV_y = np.loadtxt(os.path.join(default_path, 'true_y/true_y(' + str(testi) + ').txt'))
    train_TOV_z = np.loadtxt(os.path.join(default_path, 'true_z/true_z(' + str(testi) + ').txt'))
    train_TOV_prev_cost = np.loadtxt(os.path.join(default_path, 'true_cost/true_cost(' + str(testi) + ').txt'))
    train_TOV_prev_prof = np.loadtxt(os.path.join(default_path, 'true_prof/true_prof(' + str(testi) + ').txt'))
    
    x_test_temp = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/test_features/test_features(' + str(testi) + ').txt'))
    y_test1 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/rescale_test_prices/rescale_test_prices(' + str(testi) + ').txt'))
    y_test2 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/test_weights/test_weights(' + str(testi) + ').txt'))
    

    damping = 1e-2
    thr = 1e-3
    lr = 1e-5
    bestTrainCorrReg = float("inf")
    
    iter_loss = np.zeros(iteration_num)
    for iter_cnt in range(iteration_num):
        cur_NN_start = 0
        if testi == 22:
          cur_NN_start = 5
        for cur_NN in range(cur_NN_start, total_month_num):
            start_time = time.time()
            cur_month_num = total_month_num - cur_NN
            if cur_NN == 0:
                LP_relax_file.reset_month_num()
            if cur_NN > 0:
                LP_relax_file.change_month_num(cur_month_num)
    #        cur_month_num = LP_relax_file.month_num
            print(cur_month_num)
            x_train = np.zeros((cur_month_num*item_num*train_set_size, featureNum))
            y_train = np.zeros((cur_month_num*item_num*train_set_size, 2))
            for i in range(train_set_size):
                k = 0
                for j in range(cur_NN*item_num, total_month_num*item_num):
                    x_train[i*cur_month_num*item_num+k] = x_train_temp[i*total_month_num*item_num+j]
                    y_train[i*cur_month_num*item_num+k][0] = y_train1[i*total_month_num*item_num+j]
                    y_train[i*cur_month_num*item_num+k][1] = y_train2[i*total_month_num*item_num+j]
                    k = k + 1
            feature_train = torch.from_numpy(x_train).float()
            value_train = torch.from_numpy(y_train).float()
            
            x_test = np.zeros((cur_month_num*item_num*test_set_size, featureNum))
            y_test = np.zeros((cur_month_num*item_num*test_set_size, 2))
            for i in range(test_set_size):
                k = 0
                for j in range(cur_NN*item_num, total_month_num*item_num):
                    x_test[i*cur_month_num*item_num+k] = x_test_temp[i*total_month_num*item_num+j]
                    y_test[i*cur_month_num*item_num+k][0] = y_test1[i*total_month_num*item_num+j]
                    y_test[i*cur_month_num*item_num+k][1] = y_test2[i*total_month_num*item_num+j]
                    k = k + 1
            feature_test = torch.from_numpy(x_test).float()
            value_test = torch.from_numpy(y_test).float()
            
            if cur_NN > 0:
                if iter_cnt == 0:
                    # get current states
                    train_future_x = train_TOV_x
                    train_future_y = train_TOV_y
                    train_future_z = train_TOV_z
                    prev_prof = train_TOV_prev_prof[:, cur_NN]
                    prev_cost = train_TOV_prev_cost[:, cur_NN]
                else:
                    # get current states
                    train_future_x = np.loadtxt(os.path.join(default_path, 'prev_owned_prod/prev_owned_prod_iter' + str(iter_cnt-1) + '_NN' + str(cur_NN-1) + '(' + str(testi) + ').txt'))
                    train_future_y = np.loadtxt(os.path.join(default_path, 'prev_y/prev_y_iter' + str(iter_cnt-1) + '_NN' + str(cur_NN-1) + '(' + str(testi) + ').txt'))
                    train_future_z = np.loadtxt(os.path.join(default_path, 'prev_z/prev_z_iter' + str(iter_cnt-1) + '_NN' + str(cur_NN-1) + '(' + str(testi) + ').txt'))
                    prev_cost = np.loadtxt(os.path.join(default_path, 'prev_cost/prev_cost_iter' + str(iter_cnt-1) + '_NN' + str(cur_NN-1) + '(' + str(testi) + ').txt'))
                    prev_prof = np.loadtxt(os.path.join(default_path, 'prev_prof/prev_prof_iter' + str(iter_cnt-1) + '_NN' + str(cur_NN-1) + '(' + str(testi) + ').txt'))

            
            stop_epoch = 0
            while stop_epoch < warm_start_epoch:
                clf = Intopt(damping=damping, lr=lr, n_features=featureNum, batch_size=LP_relax_file.x_num, cur_NN=cur_NN, thr=thr, epochs=stop_epoch_criterion)
                clf.fit(feature_train, value_train)

                if stop_epoch >= warm_start_epoch:
                    end_time = time.time()
                    cur_decision = clf.make_future_plan(feature_train, value_train)
                    if recordNow[cur_NN] > recordBest[cur_NN]:
                        recordBest[cur_NN] = recordNow[cur_NN]
                        torch.save(clf.model.state_dict(), 'parallel_cap' + str(LP_relax_file.capacity) + '_trans' + str(LP_relax_file.trans_fee_percent) + '_NN' + str(cur_NN) + '_model.pkl')
                    if cur_NN > 0 and iter_cnt == 1:
                        recordBest[cur_NN] = recordNow[cur_NN]
                        torch.save(clf.model.state_dict(), 'parallel_cap' + str(LP_relax_file.capacity) + '_trans' + str(LP_relax_file.trans_fee_percent) + '_NN' + str(cur_NN) + '_model.pkl')
        #            print(trainHSD_rslt)

            clfBest = Intopt(damping=damping, lr=lr, n_features=featureNum, cur_NN=cur_NN, batch_size=LP_relax_file.x_num, thr=thr, epochs=stop_epoch_criterion)
            clfBest.model.load_state_dict(torch.load('parallel_cap' + str(LP_relax_file.capacity) + '_trans' + str(LP_relax_file.trans_fee_percent) + '_NN' + str(cur_NN) + '_model.pkl'))
            print("Simulation " + str(testi) + " Training NN " + str(cur_NN) + " time: ", end_time - start_time)
            train_each_NN_time[cur_NN] = end_time - start_time

            predTestVal = clfBest.get_pred_val(feature_test, value_test)
            predTestVal = predTestVal.detach().numpy()
            predTestVal1 = predTestVal[:, 0]
            predTestVal2 = predTestVal[:, 1]
            predValuePrice = np.zeros((predTestVal1.size, 2))
            for i in range(predTestVal1.size):
        #        predValue[i][0] = int(i/itemNum)
                predValuePrice[i][0] = y_test1[i]
                predValuePrice[i][1] = predTestVal1[i]
            np.savetxt(os.path.join(default_path, 'parallel_T_NN/parallel_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + ', warm_start=' + str(warm_start_epoch) + ', MS_prices(' + str(testi) + ').txt'), predValuePrice, fmt="%.2f")
            
            predValueWeight = np.zeros((predTestVal2.size, 2))
            for i in range(predTestVal2.size):
        #        predValue[i][0] = int(i/itemNum)
                predValueWeight[i][0] = y_test2[i]
                predValueWeight[i][1] = predTestVal2[i]
            np.savetxt(os.path.join(default_path, 'parallel_T_NN/parallel_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + ', warm_start=' + str(warm_start_epoch) + ', MS_weights(' + str(testi) + ').txt'), predValueWeight, fmt="%.2f")

            # compute current states
            for num in range(train_set_size):
                true_weight = np.zeros(item_num * total_month_num)
                true_value = np.zeros(item_num * total_month_num)
                cnt = num * item_num * total_month_num
                for i in range(item_num * total_month_num):
                    true_weight[i] = y_train2[cnt]
                    true_value[i] = y_train1[cnt]
                    cnt = cnt + 1
                x_prev_sol = train_future_x[num]
                y_prev_sol = train_future_y[num]
                z_prev_sol = train_future_z[num]
                prev_prof[num] = np.dot(true_value[:(cur_NN+1)*item_num], x_prev_sol[:(cur_NN+1)*item_num])
                prev_cost[num] = np.dot(true_weight[:item_num], x_prev_sol[:item_num]) + trans_fee_percent*np.dot(true_weight[item_num:(cur_NN+1)*item_num], z_prev_sol[:cur_NN*item_num]) + np.dot(true_weight[item_num:(cur_NN+1)*item_num], y_prev_sol[:cur_NN*item_num])
            np.savetxt(os.path.join(default_path, 'prev_cost/prev_cost_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testi) + ').txt'), prev_cost, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'prev_prof/prev_prof_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testi) + ').txt'), prev_prof, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'prev_owned_prod/prev_owned_prod_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testi) + ').txt'), train_future_x, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'prev_y/prev_y_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testi) + ').txt'), train_future_y, fmt="%.2f")
            np.savetxt(os.path.join(default_path, 'prev_z/prev_z_iter' + str(iter_cnt) + '_NN' + str(cur_NN) + '(' + str(testi) + ').txt'), train_future_z, fmt="%.2f")
            
            
            if cur_NN == total_month_num - 1:
                # The last NN of this iteration has been trained. Iteration time is
                # the max over the per-NN training times (presumably the NNs are
                # trained in parallel -- see the 'parallel_T_NN' outputs).
                each_iter_time[iter_cnt] = np.max(train_each_NN_time)
                # Evaluate all NNs on the test set by rolling them forward month
                # by month: NN_cnt plans month NN_cnt given the plan committed so far.
                for NN_cnt in range(total_month_num):
                    # Months remaining, including month NN_cnt.
                    cur_month_num = total_month_num - NN_cnt
                    # Size the LP helper for the remaining horizon.
                    if NN_cnt == 0:
                        LP_relax_file.reset_month_num()
                    else:
                        LP_relax_file.change_month_num(cur_month_num)

                    # Predictions written for NN_cnt in this iteration; column 1 is
                    # used as the prediction (file layout is produced elsewhere --
                    # confirm).
                    test_price_full = np.loadtxt(os.path.join(default_path, 'parallel_T_NN/parallel_iter' + str(iter_cnt) + '_NN' + str(NN_cnt) + ', warm_start=' + str(warm_start_epoch) + ', MS_prices(' + str(testi) + ').txt'))
                    test_weight_full = np.loadtxt(os.path.join(default_path, 'parallel_T_NN/parallel_iter' + str(iter_cnt) + '_NN' + str(NN_cnt) + ', warm_start=' + str(warm_start_epoch) + ', MS_weights(' + str(testi) + ').txt'))
                    pred_price_full = test_price_full[:, 1]
                    pred_weight_full = test_weight_full[:, 1]
                    # Ground truth restricted to months NN_cnt..total_month_num-1,
                    # flattened instance by instance (cur_month_num*item_num entries
                    # each); read below with the same offsets as the predictions.
                    true_price_full = np.zeros(cur_month_num*item_num*test_set_size)
                    true_weight_full = np.zeros(cur_month_num*item_num*test_set_size)
                    for i in range(test_set_size):
                        k = 0
                        for j in range(NN_cnt*item_num, total_month_num*item_num):
                            true_price_full[i*cur_month_num*item_num+k] = y_test1[i*total_month_num*item_num+j]
                            true_weight_full[i*cur_month_num*item_num+k] = y_test2[i*total_month_num*item_num+j]
                            k = k + 1

                    # Compute the current state of each test instance: profit and
                    # cost realised over months 0..NN_cnt-1 by the committed plan.
                    if NN_cnt > 0:
                        for test_num in range(test_set_size):
                            true_weight = np.zeros(item_num * total_month_num)
                            true_value = np.zeros(item_num * total_month_num)
                            cnt = test_num * item_num * total_month_num
                            for i in range(item_num * total_month_num):
                                true_weight[i] = y_test2[cnt]
                                true_value[i] = y_test1[cnt]
                                cnt = cnt + 1
                            x_prev_sol = test_future_x[test_num]
                            y_prev_sol = test_future_y[test_num]
                            z_prev_sol = test_future_z[test_num]
                            test_prev_prof[test_num] = np.dot(true_value[:NN_cnt*item_num], x_prev_sol[:NN_cnt*item_num])
                            # Use NN_cnt, not the outer cur_NN: here cur_NN is fixed at
                            # total_month_num - 1, which would count y/z entries for
                            # months not yet realised. This matches the profit line
                            # above and the training-side cost. At
                            # NN_cnt == total_month_num - 1 the two coincide.
                            test_prev_cost[test_num] = np.dot(true_weight[:item_num], x_prev_sol[:item_num]) + trans_fee_percent*np.dot(true_weight[item_num:NN_cnt*item_num], z_prev_sol[:(NN_cnt-1)*item_num]) + np.dot(true_weight[item_num:NN_cnt*item_num], y_prev_sol[:(NN_cnt-1)*item_num])

                    # Compute the NN_cnt plans, one test instance at a time.
                    for test_num in range(test_set_size):
                        cnt = test_num * LP_relax_file.x_num
                        true_price = np.zeros(LP_relax_file.x_num)
                        true_weight = np.zeros(LP_relax_file.x_num)
                        pred_price = np.zeros(LP_relax_file.x_num)
                        pred_weight = np.zeros(LP_relax_file.x_num)
                        for i in range(LP_relax_file.x_num):
                            true_price[i] = true_price_full[cnt]
                            pred_price[i] = pred_price_full[cnt]
                            true_weight[i] = true_weight_full[cnt]
                            pred_weight[i] = pred_weight_full[cnt]
                            cnt = cnt + 1
                        # Defined elsewhere; presumably records this month's decision
                        # in test_future_x/y/z -- confirm.
                        make_next_plan(test_num, NN_cnt, pred_price, pred_weight, true_price, true_weight)

                        # Final month: add its realised profit and cost.
                        if NN_cnt == total_month_num - 1:
                            # NOTE(review): final-month profit adds true_weight . x on
                            # top of true_price . x, whereas earlier months count only
                            # value . x as profit. Confirm this terminal term is intended.
                            test_prev_prof[test_num] +=  np.dot(true_price, test_future_x[test_num][(total_month_num-1)*item_num:]) + np.dot(true_weight, test_future_x[test_num][(total_month_num-1)*item_num:])
                            test_prev_cost[test_num] += trans_fee_percent*np.dot(true_weight, test_future_z[test_num][(total_month_num-2)*item_num:]) + np.dot(true_weight, test_future_y[test_num][(total_month_num-2)*item_num:])


                # Score the rolled-out plans against LP_relax_file.actual_obj on the
                # true test data; PReg is the per-instance regret.
                test_obj = test_prev_prof - test_prev_cost
                LP_relax_file.reset_month_num()
                true_obj = LP_relax_file.actual_obj(y_test1, y_test2, n_instance=test_set_size)
                PReg = true_obj - test_obj
                print("Simulation ", testi, " Iteration ", iter_cnt, end=" ")
                print("Test: TOV: ", np.sum(true_obj)/test_set_size, "EOV: ", np.sum(test_obj)/test_set_size, "PReg: ", np.sum(PReg)/test_set_size, end=" ")
                print("Total training time: ", np.sum(each_iter_time))
                iter_loss[iter_cnt] = np.sum(PReg)/test_set_size

                # Reset all per-iteration accumulators before the next iteration.
                train_each_NN_time = train_each_NN_time * 0
                train_future_x = train_future_x * 0
                train_future_y = train_future_y * 0
                train_future_z = train_future_z * 0
                prev_prof = prev_prof * 0
                prev_cost = prev_cost * 0

                test_future_x = test_future_x * 0
                test_future_y = test_future_y * 0
                test_future_z = test_future_z * 0
                test_prev_prof = test_prev_prof * 0
                test_prev_cost = test_prev_cost * 0

        if iter_cnt > 0 and abs(iter_loss[iter_cnt] - iter_loss[iter_cnt-1]) < 0.1:
            break
        if iter_cnt > 1 and iter_loss[iter_cnt] >= iter_loss[iter_cnt-1] and iter_loss[iter_cnt] >= iter_loss[iter_cnt-2]:
            break
    print(recordNow, recordBest)
