import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
import linear_relax as LP_relax_file

# Problem configuration copied from linear_relax when this script starts.
# NOTE(review): these are import-time snapshots. The rest of this script reads
# LP_relax_file.<name> directly after calling change_day_num()/reset_day_num(),
# presumably because those calls rebind the module attributes -- confirm
# before relying on the local aliases inside a re-planning stage.

# Horizon and roster dimensions.
nurse_num = LP_relax_file.nurse_num
shift_num = LP_relax_file.shift_num
total_day_num = LP_relax_file.total_day_num
day_num = LP_relax_file.day_num
day_shift_num = LP_relax_file.day_shift_num
day_work_shift_num = LP_relax_file.day_work_shift_num

# Decision-variable counts.
decision_num = LP_relax_file.decision_num
t_decision_num = LP_relax_file.t_decision_num

# Cost and constraint parameters.
penaltyTerm = LP_relax_file.penaltyTerm
extra_serve_patient_num = LP_relax_file.extra_serve_patient_num
minimum_relax_day = LP_relax_file.minimum_relax_day
maximum_relax_day = LP_relax_file.maximum_relax_day


# Command line: <extra_payment> <startmark> <endmark>, all parsed as integers.
# startmark/endmark bound the half-open range of data-file indices processed
# by the driver loop at the bottom of this script.
if len(sys.argv) < 4:
    # Fail with a readable message instead of a bare IndexError traceback.
    _prog = sys.argv[0] if sys.argv else "script"
    sys.exit("usage: " + _prog + " <extra_payment> <startmark> <endmark>")

extra_payment = int(sys.argv[1])
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])
# Propagate the payment to linear_relax; later code reads it back through
# LP_relax_file.extra_payment.
LP_relax_file.set_extra_payment(extra_payment)

# Number of instances solved per data file, and the constant used to locate
# the patient-count column (index featureNum + 1) in each train file.
train_set_num = 70
featureNum = 8

# Per-instance result buffers, one row per instance. They are filled in place
# by make_next_plan() and cleared by the driver loop after every data file.
train_future_x = np.zeros((train_set_num, LP_relax_file.x_num))
train_has_rested = np.zeros((train_set_num, nurse_num))
# Running cost / penalty, one column per planning stage (0..total_day_num).
train_curr_cost = np.zeros((train_set_num, total_day_num + 1))
train_curr_penalty = np.zeros_like(train_curr_cost)


simulation_time = 10

# Data is looked up relative to the parent of the current working directory.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
default_path = os.path.join(
    dataset_path,
    ''.join([
        'day_num=', str(total_day_num),
        '/penalty=', str(penaltyTerm),
        ', extra_serve=', str(extra_serve_patient_num),
        ', extra_payment=', str(LP_relax_file.extra_payment),
        '/',
    ]),
)

# Output folders for the three result files written per data file.
LP_relax_file.mkdir(default_path, 'true_schedule')
LP_relax_file.mkdir(default_path, 'true_cost')
LP_relax_file.mkdir(default_path, 'true_penalty')


def make_next_plan(test_num, cur_NN, real_patient, pre_patient, cost, serve_patient_num, penalty=None):
    """Solve one planning stage for instance ``test_num`` and record the result.

    Stage 0 (``cur_NN == 0``) builds the initial plan and stores it in
    ``train_future_x[test_num]``. Every later stage re-optimises days
    ``cur_NN - 1 .. total_day_num - 1`` of the stored plan, then updates the
    running cost and penalty columns, the stored plan and the rest counter.

    The stage-dependent sizes are read from ``LP_relax_file`` at call time, so
    the caller must configure it for the stage first (the driver loop does this
    with ``reset_day_num()`` / ``change_day_num()``).

    Args:
        test_num: Row index into the module-level ``train_*`` arrays.
        cur_NN: Stage index; 0 selects the initial-plan branch.
        real_patient: Per-work-shift values; forwarded to ``get_init_plan`` at
            stage 0, negated into the "real" demand vector afterwards.
        pre_patient: Same layout as ``real_patient``, used for the "pre"
            demand vector.
        cost: Per-nurse cost, indexed as ``cost[i]``.
        serve_patient_num: Forwarded unchanged to the constraint builders.
        penalty: Forwarded to ``gen_obj_latter_days`` and
            ``get_updated_plan_for_each_day``; unused at stage 0.

    Returns:
        None. All results are written into the module-level arrays
        ``train_future_x``, ``train_has_rested``, ``train_curr_cost`` and
        ``train_curr_penalty`` (in place, so no ``global`` statement is needed).
    """
    if cur_NN == 0:
        c = LP_relax_file.gen_obj(0, cost)
        A, b, G, h2, h3, h4, h5 = LP_relax_file.gen_matrix(
            nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num)
        init_x, _init_sigma = LP_relax_file.get_init_plan(
            c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)
        train_future_x[test_num] = init_x
        return

    remaining_schedule = _extract_remaining_schedule(test_num, cur_NN)
    real_h1, pre_h1 = _build_demand_vectors(real_patient, pre_patient)

    c = LP_relax_file.gen_obj_latter_days(cur_NN, cost, penalty)
    A, b, G, h = LP_relax_file.gen_constraints_latter_days(
        cur_NN, remaining_schedule, train_has_rested[test_num], real_h1, pre_h1, serve_patient_num)
    x_sol, sigma_sol, incur_penalty = LP_relax_file.get_updated_plan_for_each_day(
        cur_NN, c, A, b, G, h, penalty)

    # Running totals: start from the previous stage and add this stage's share.
    train_curr_penalty[test_num][cur_NN] = train_curr_penalty[test_num][cur_NN - 1] + incur_penalty
    train_curr_cost[test_num][cur_NN] = train_curr_cost[test_num][cur_NN - 1]
    stage_block = LP_relax_file.day_shift_num
    work_shifts = shift_num - 1
    # Only the first day of the stage-local solution is charged: the work
    # shifts of each nurse, plus extra_payment per unit of sigma on each shift.
    for i in range(nurse_num):
        for j in range(work_shifts):
            train_curr_cost[test_num][cur_NN] += cost[i] * x_sol[i * stage_block + j]
    for j in range(work_shifts):
        train_curr_cost[test_num][cur_NN] += LP_relax_file.extra_payment * sigma_sol[j]

    _store_updated_plan(test_num, cur_NN, x_sol)

    # Count a rest when the first-day rest variable is 1. The comparison is
    # tolerance-based because a solver may report 0.9999999 for a binary 1.
    # NOTE(review): assumes x_sol can hold raw solver values -- if the helper
    # already rounds them this is equivalent to the exact test.
    for i in range(nurse_num):
        if np.isclose(x_sol[i * stage_block + shift_num - 1], 1):
            train_has_rested[test_num][i] += 1


def _extract_remaining_schedule(test_num, cur_NN):
    """Copy days ``cur_NN - 1`` onward of the stored plan into stage-local layout."""
    stage_block = LP_relax_file.day_shift_num
    full_block = total_day_num * shift_num
    stored_plan = train_future_x[test_num]
    remaining = np.zeros(nurse_num * LP_relax_file.day_num * shift_num)
    for i in range(nurse_num):
        for j in range(cur_NN - 1, total_day_num):
            for k in range(shift_num):
                remaining[i * stage_block + (j - cur_NN + 1) * shift_num + k] = \
                    stored_plan[i * full_block + j * shift_num + k]
    return remaining


def _build_demand_vectors(real_patient, pre_patient):
    """Return ``(real_h1, pre_h1)``: negated demand per shift slot, 0 on rest slots."""
    slot_count = LP_relax_file.day_shift_num
    # The rest shift is the last shift of each day, matching the
    # ``shift_num - 1`` convention used for costs and the rest counter.
    rest_shift = shift_num - 1
    pre_h1 = np.zeros(slot_count)
    real_h1 = np.zeros(slot_count)
    work_idx = 0
    for slot in range(slot_count):
        if slot % shift_num == rest_shift:
            continue  # rest slots carry no demand and stay at 0
        pre_h1[slot] = -pre_patient[work_idx]
        real_h1[slot] = -real_patient[work_idx]
        work_idx += 1
    return real_h1, pre_h1


def _store_updated_plan(test_num, cur_NN, x_sol):
    """Write the stage-local solution back over days ``cur_NN - 1`` onward."""
    stage_block = LP_relax_file.day_shift_num
    full_block = total_day_num * shift_num
    for i in range(nurse_num):
        for j in range(cur_NN - 1, total_day_num):
            for k in range(shift_num):
                full_index = i * full_block + j * shift_num + k
                curr_index = i * stage_block + (j - cur_NN + 1) * shift_num + k
                train_future_x[test_num][full_index] = x_sol[curr_index]



print("nurse_num: ", nurse_num, "day_num: ", day_num, "penalty_for_reschedule: ", penaltyTerm, "extra_serve_patient_num: ", extra_serve_patient_num, "extra_payment: ", LP_relax_file.extra_payment)

# All input files live under <dataset_path>/day_num=<total_day_num>/...
day_prefix = 'day_num=' + str(total_day_num)

for testi in range(startmark, endmark):
    print(testi, end=" ")
    file_idx = str(testi)
    cost = np.loadtxt(os.path.join(dataset_path, day_prefix + '/payment/payment(' + file_idx + ').txt'))
    serve_patient_num = np.loadtxt(os.path.join(dataset_path, day_prefix + '/serve_patient_num/serve_patient_num(' + file_idx + ').txt'))
    trainData = np.loadtxt(os.path.join(dataset_path, day_prefix + '/train/train(' + file_idx + ').txt'))
    # Column featureNum + 1 of the train file holds the patient counts.
    true_patient_full = trainData[:, featureNum+1]

    # Stage 0 plans the full horizon; stage k > 0 re-plans the remaining
    # total_day_num - k + 1 days.
    for NN_cnt in range(total_day_num+1):
        if NN_cnt == 0:
            cur_day_num = total_day_num
            LP_relax_file.reset_day_num()
        else:
            cur_day_num = total_day_num - NN_cnt + 1
            LP_relax_file.change_day_num(cur_day_num)

        # Solve this stage for every instance of the current data file.
        for test_num in range(train_set_num):
            width = LP_relax_file.day_work_shift_num
            # NOTE(review): the offset is scaled by the *current*
            # day_work_shift_num; if change_day_num() shrinks that value, later
            # stages read a different slice than stage 0 for the same
            # test_num -- confirm this is intended.
            offset = test_num * width
            real_patient = np.zeros(width)
            penalty = np.full(LP_relax_file.x_num, penaltyTerm)
            for slot in range(width):
                real_patient[slot] = true_patient_full[offset + slot]
            # The true demand is passed as both the "real" and the "predicted"
            # series.
            make_next_plan(test_num, NN_cnt, real_patient, real_patient, cost, serve_patient_num, penalty)

    # Objective per instance = accumulated cost + accumulated penalty at the
    # last stage.
    train_obj = train_curr_cost[:, total_day_num] + train_curr_penalty[:, total_day_num]

    np.savetxt(os.path.join(default_path, 'true_schedule/true_schedule(' + file_idx + ').txt'), train_future_x, fmt="%.0f")
    np.savetxt(os.path.join(default_path, 'true_cost/true_cost(' + file_idx + ').txt'), train_curr_cost, fmt="%.2f")
    np.savetxt(os.path.join(default_path, 'true_penalty/true_penalty(' + file_idx + ').txt'), train_curr_penalty, fmt="%.2f")
    LP_relax_file.reset_day_num()
    print("TOV: ", np.sum(train_obj)/train_set_num)

    # Clear the result buffers before the next data file (rebinds each name to
    # a fresh array of the same shape, as multiplying by zero did before).
    train_future_x = np.multiply(train_future_x, 0)
    train_has_rested = np.multiply(train_has_rested, 0)
    train_curr_cost = np.multiply(train_curr_cost, 0)
    train_curr_penalty = np.multiply(train_curr_penalty, 0)
