import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
import sys
import linear_relax as LP_relax_file
from sklearn.metrics import mean_squared_error

# Problem-size and penalty settings are owned by the linear_relax module;
# re-export them here under the same names so the evaluation code below can
# refer to them directly.
(
    nurse_num,
    day_num,
    shift_num,
    day_shift_num,
    day_work_shift_num,
    decision_num,
    t_decision_num,
    penaltyTerm,
    extra_serve_patient_num,
    minimum_relax_day,
    maximum_relax_day,
) = (
    LP_relax_file.nurse_num,
    LP_relax_file.day_num,
    LP_relax_file.shift_num,
    LP_relax_file.day_shift_num,
    LP_relax_file.day_work_shift_num,
    LP_relax_file.decision_num,
    LP_relax_file.t_decision_num,
    LP_relax_file.penaltyTerm,
    LP_relax_file.extra_serve_patient_num,
    LP_relax_file.minimum_relax_day,
    LP_relax_file.maximum_relax_day,
)

# Command-line arguments:
#   argv[1] extra_payment  - forwarded to LP_relax_file.set_extra_payment
#   argv[2] startmark      - first dataset instance index to evaluate
#   argv[3] endmark        - one past the last instance index (half-open range)
# Fail with a usage message instead of a bare IndexError when arguments are missing.
if len(sys.argv) < 4:
    sys.exit("usage: python " + os.path.basename(sys.argv[0])
             + " EXTRA_PAYMENT STARTMARK ENDMARK")
extra_payment = int(sys.argv[1])
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])
LP_relax_file.set_extra_payment(extra_payment)

train_case_num = 70  # NOTE(review): not referenced in this file
test_case_num = 30   # number of test cases evaluated per instance
# Prediction methods; each has its own sub-folder of prediction files.
methodList = ['Ridge', 'knn', 'CART', 'RF']
# Data root is the PARENT of the current working directory (not of this
# script's location), so the script must be launched from the expected folder.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))

print("nurse_num: ", nurse_num, "day_num: ", day_num, "penalty_for_reschedule: ", penaltyTerm, "extra_serve_patient_num: ", extra_serve_patient_num, "extra_payment: ", extra_payment)

# Evaluate every prediction method on every dataset instance in
# [startmark, endmark). Per instance this prints:
#   MSE     - mean squared error between real and predicted patient counts
#   TOV     - sum of LP_relax_file.actual_obj(...) results (computed from the
#             real patient counts) divided by test_case_num
#   EOV     - sum of LP_relax_file.correction_single_obj(...) results (computed
#             from real and predicted counts) divided by test_case_num
#   avgPReg - sum over test cases of |EOV_case - TOV_case| / test_case_num
# NOTE(review): the data folder is hard-coded to 'day_num=7' even though
# day_num is taken from linear_relax; confirm the two always agree.
for methodName in methodList:
    print(methodName)
    for testmark in range(startmark, endmark):
        start_time = time.time()
        data_dir = os.path.join(dataset_path, 'day_num=7')
        cost = np.loadtxt(os.path.join(data_dir, 'payment', 'payment(' + str(testmark) + ').txt'))
        serve_patient_num = np.loadtxt(os.path.join(data_dir, 'serve_patient_num', 'serve_patient_num(' + str(testmark) + ').txt'))

        pred_path = os.path.join(data_dir, methodName, methodName + '(' + str(testmark) + ').txt')
        patient_num_temp = np.loadtxt(pred_path)

        # Each test case consumes day_work_shift_num consecutive rows; column 1
        # holds the real patient counts and column 2 the predicted counts.
        # Check up front so a short/malformed file fails with a clear message.
        required_rows = test_case_num * day_work_shift_num
        if (patient_num_temp.ndim != 2 or patient_num_temp.shape[1] < 3
                or patient_num_temp.shape[0] < required_rows):
            raise ValueError("expected at least " + str(required_rows)
                             + " rows and 3 columns in " + pred_path
                             + ", got shape " + str(patient_num_temp.shape))

        real_patient_num = patient_num_temp[:, 1]
        pre_patient_num = patient_num_temp[:, 2]

        c = LP_relax_file.gen_obj(0, cost)
        A, b, G, h2, h3, h4, h5 = LP_relax_file.gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num)

        # Objective values obtained with the real patient counts (IP version).
        test_real_obj_IP = LP_relax_file.actual_obj(c, A, b, G, real_patient_num, h2, h3, h4, h5, n_instance=test_case_num)

        test_pred_obj_list_IP = []
        for testNum in range(test_case_num):
            lo = testNum * day_work_shift_num
            hi = lo + day_work_shift_num
            # Per-case counts keyed 0..day_work_shift_num-1, as the callee expects dicts.
            real_patient = dict(enumerate(real_patient_num[lo:hi]))
            pre_patient = dict(enumerate(pre_patient_num[lo:hi]))

            # Uniform rescheduling penalty per decision variable; rebuilt for each
            # case in case correction_single_obj modifies it in place.
            penalty = np.full(decision_num, penaltyTerm, dtype=float)
            total_cost = LP_relax_file.correction_single_obj(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5, penalty)

            test_pred_obj_list_IP.append(total_cost)

        end_time = time.time()
        real_objs = np.array(test_real_obj_IP)
        pred_objs = np.array(test_pred_obj_list_IP)
        print(testmark, "MSE: ", mean_squared_error(real_patient_num, pre_patient_num), "TOV: ", real_objs.sum() / test_case_num, "EOV: ", pred_objs.sum() / test_case_num, "avgPReg: ", np.abs(pred_objs - real_objs).sum() / test_case_num, "runtime: ", str(end_time - start_time))
