import os
import sys
import linear_relax as LP_relax_file

from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
from numpy import linalg as LA

# Evaluation driver: for every prediction method and every test set in
# [startmark, endmark), load cost/price/demand files, compute the objective
# under the true demands and the corrected objective under the predicted
# demands (both via LP_relax_file), and print summary metrics.

# Length of each instance's demand vector; the demand files are consumed in
# consecutive chunks of this size (one chunk per instance).
month_num = LP_relax_file.month_num

# Command-line interface:
#   argv[1]  price selector: 0 -> files under small_price/, 1 -> large_price/
#   argv[2]  first test-set index (inclusive)
#   argv[3]  end test-set index (exclusive)
if len(sys.argv) < 4:
    sys.exit('usage: ' + sys.argv[0] + ' <small_or_large: 0|1> <startmark> <endmark>')

small_or_large = int(sys.argv[1])
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])

# Fail fast on an unknown selector. Without this check `price` would never be
# assigned and the script would die later with an unrelated NameError.
price_dir_by_selector = {0: 'small_price', 1: 'large_price'}
if small_or_large not in price_dir_by_selector:
    sys.exit('small_or_large must be 0 (small price) or 1 (large price), got ' + str(small_or_large))
price_dir = price_dir_by_selector[small_or_large]

# Number of instances stored back-to-back in each demand file.
instance_num = 30
# Prediction methods to evaluate; each has its own sub-directory of demand files.
methodList = ['Ridge', 'knn', 'CART', 'RF']

# Data lives in <parent of cwd>/data/month_num=<month_num>/.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
default_path = os.path.join(dataset_path, 'data/month_num='+str(month_num)+'/')

if small_or_large == 0:
    print("small price,", end=' ')
else:
    print("large price,", end=' ')

print("month_num: ", month_num)

# Total number of demand entries each file must provide.
required_len = instance_num * month_num

for methodName in methodList:
    print(methodName)
    for testmark in range(startmark, endmark):
        # The reported runtime covers file loading as well as the solves.
        start_time = time.time()

        file_suffix = '(' + str(testmark) + ').txt'
        cost = np.loadtxt(os.path.join(default_path, 'cost/cost' + file_suffix))
        price = np.loadtxt(os.path.join(default_path, price_dir + '/price' + file_suffix))
        demand_temp = np.loadtxt(os.path.join(default_path, methodName + '/' + methodName + '_demands' + file_suffix))

        # Column 0 is used as the true demand, column 1 as the predicted demand.
        true_demand_total = demand_temp[:, 0]
        pred_demand_total = demand_temp[:, 1]

        # Slicing would silently return short chunks, so check the length
        # explicitly to keep failing loudly on truncated input.
        if true_demand_total.shape[0] < required_len:
            raise ValueError(
                'demand file for ' + methodName + ' test set ' + str(testmark)
                + ' has ' + str(true_demand_total.shape[0]) + ' rows, expected at least '
                + str(required_len))

        # NOTE(review): presumably returns one objective value per instance
        # (it is summed and subtracted from corr_obj_list below) - confirm.
        true_obj = LP_relax_file.actual_obj(price, cost, true_demand_total, n_instance=instance_num)

        corr_obj_list = []
        for testNum in range(instance_num):
            chunk_start = testNum * month_num
            chunk_end = chunk_start + month_num
            # Copy so the helper receives its own arrays rather than views
            # into the totals that are still needed for the MSE below.
            true_demand = true_demand_total[chunk_start:chunk_end].copy()
            pred_demand = pred_demand_total[chunk_start:chunk_end].copy()

            corr_obj_list.append(
                LP_relax_file.correction_single_obj(price, cost, pred_demand, true_demand))

        end_time = time.time()
        runtime = end_time - start_time

        print("MSE: ", mean_squared_error(true_demand_total, pred_demand_total), end=" ")

        # Per-instance means of: the true objective, the corrected objective,
        # and the absolute gap between the two.
        tov = sum(true_obj)/instance_num
        eov = sum(corr_obj_list)/instance_num
        preg = sum(abs(true_obj - np.array(corr_obj_list)))/instance_num
        print("TOV: ", tov, "EOV: ", eov, "PReg: ", preg, "runtime: ", runtime)

