import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
import linear_relax as LP_relax_file

# ---------------------------------------------------------------------------
# Run configuration and result buffers.
#
# Problem sizes are read from the linear_relax module; the buffers below hold
# one row per training instance (train_set_size rows).
# ---------------------------------------------------------------------------
item_num = LP_relax_file.item_num
train_set_size = 70
total_month_num = LP_relax_file.month_num

# Decision-variable buffers, sized from the module's x_num / y_num / z_num as
# they are at import time (before any change_month_num call).
train_TOV_x, train_TOV_y, train_TOV_z = (
    np.zeros((train_set_size, width))
    for width in (LP_relax_file.x_num, LP_relax_file.y_num, LP_relax_file.z_num)
)
# Per-instance running profit / cost, one column per month.
train_TOV_prev_prof, train_TOV_prev_cost = (
    np.zeros((train_set_size, total_month_num)) for _ in range(2)
)
simulation_time = 10

# Command line: <capacity> <trans_fee_percent> <startmark> <endmark>.
# The capacity and fee are pushed into linear_relax before the index range
# is parsed, exactly in that order.
cap, trans_fee_percent = int(sys.argv[1]), float(sys.argv[2])
LP_relax_file.set_capacity(cap)
LP_relax_file.set_trans_fee_percent(trans_fee_percent)
startmark, endmark = int(sys.argv[3]), int(sys.argv[4])

# Data lives under the parent of the current working directory.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
# `!s` forces str() on every field so the directory name is built from the
# same text a plain str() concatenation would produce.
default_path = os.path.join(
    dataset_path,
    f"data/trans_fee={trans_fee_percent!s}"
    f"/v{LP_relax_file.version_num!s}"
    f"(item_num={item_num!s}"
    f",month_num={total_month_num!s}"
    f",trans_fee={trans_fee_percent!s}"
    f",cap={LP_relax_file.capacity!s})/",
)

# One output sub-directory per saved table (helper provided by linear_relax).
for _out_dir in ('true_owned_prod', 'true_y', 'true_z', 'true_cost', 'true_prof'):
    LP_relax_file.mkdir(default_path, _out_dir)

def make_next_plan(test_num, NN_cnt, pred_price, pred_weight, true_price, true_weight):
    """Solve month ``NN_cnt`` for instance ``test_num`` and record the result.

    The solution is written in place into the module-level buffers
    ``train_TOV_x`` / ``train_TOV_y`` / ``train_TOV_z``; nothing is returned.

    Args:
        test_num: Row index of the instance in the ``train_TOV_*`` buffers.
        NN_cnt: Zero-based month index. ``0`` selects the initial-plan solver;
            any other value selects the per-month update solver.
        pred_price: Price vector forwarded to the solver as the prediction.
        pred_weight: Weight vector forwarded to the solver as the prediction.
        true_price: Price vector forwarded to the update solver
            (not used when ``NN_cnt == 0``).
        true_weight: Weight vector forwarded to the solver.

    Storage layout (per row):
        * x for month ``NN_cnt`` occupies ``[NN_cnt*item_num, (NN_cnt+1)*item_num)``.
        * y and z produced at month ``NN_cnt >= 1`` occupy
          ``[(NN_cnt-1)*item_num, NN_cnt*item_num)``; month 0 stores no y/z.
    """
    if NN_cnt == 0:
        # First month: constraints for index 0, then the initial plan.
        con_A, con_b, con_G, con_h = LP_relax_file.gen_constraints(0)
        x_sol, _y_sol, _z_sol = LP_relax_file.get_init_plan(
            pred_price, pred_weight, true_weight, con_A, con_b, con_G, con_h
        )
        x_row = train_TOV_x[test_num]
        for k in range(item_num):
            x_row[k] = x_sol[k]
        return

    prev_offset = (NN_cnt - 1) * item_num
    cur_offset = NN_cnt * item_num

    # Later months: the previous month's x block feeds the constraint builder.
    x_row = train_TOV_x[test_num]
    x_prev_sol = x_row[prev_offset:cur_offset]
    con_A, con_b, con_G, con_h = LP_relax_file.gen_constraints_latter_months(
        LP_relax_file.month_num, x_prev_sol
    )
    x_sol, y_sol, z_sol = LP_relax_file.get_updated_plan_for_each_month(
        LP_relax_file.month_num,
        train_TOV_prev_prof[test_num][NN_cnt],
        train_TOV_prev_cost[test_num][NN_cnt],
        pred_price,
        pred_weight,
        true_price,
        true_weight,
        con_A,
        con_b,
        con_G,
        con_h,
    )

    # Only the first item_num entries of each solution are kept.
    for k in range(item_num):
        x_row[cur_offset + k] = x_sol[k]
        train_TOV_y[test_num][prev_offset + k] = y_sol[k]
        train_TOV_z[test_num][prev_offset + k] = z_sol[k]



# ---------------------------------------------------------------------------
# Driver: for every instance file in [startmark, endmark) solve the plan
# month by month for all training instances using the true prices/weights,
# save the resulting tables, and print the average objective.
# ---------------------------------------------------------------------------
print("item_num: ", item_num, " month_num: ", LP_relax_file.month_num, " trans_fee_percent: ", trans_fee_percent, " capacity: ", LP_relax_file.capacity)

# Entries per training instance in the flat data files (all items, all months).
instance_len = item_num * total_month_num

for testi in range(startmark, endmark):
    print(testi, end=" ")
    # NOTE(review): the input directory hard-codes "item_num=5" while the
    # output directory uses the configured item_num -- confirm this is intended.
    data_prefix = 'data/item_num=5, month_num=' + str(LP_relax_file.month_num)
    train_prices_full = np.loadtxt(os.path.join(dataset_path, data_prefix + '/rescale_train_prices/rescale_train_prices(' + str(testi) + ').txt'))
    train_weights_full = np.loadtxt(os.path.join(dataset_path, data_prefix + '/train_weights/train_weights(' + str(testi) + ').txt'))

    # The slicing below would silently truncate on short input, so fail loudly
    # here if a file is not a flat vector covering every training instance.
    required_len = train_set_size * instance_len
    for label, flat in (("prices", train_prices_full), ("weights", train_weights_full)):
        if flat.ndim != 1 or flat.size < required_len:
            raise ValueError(
                "train " + label + " file for index " + str(testi)
                + " must be a flat vector with at least " + str(required_len)
                + " values, got shape " + str(flat.shape)
            )

    train_obj = 0

    for NN_cnt in range(total_month_num):
        if NN_cnt == 0:
            LP_relax_file.reset_month_num()
        else:
            # Shrink the horizon to the months that are still ahead.
            LP_relax_file.change_month_num(total_month_num - NN_cnt)

            # Compute current states: profit and cost realised over the months
            # already planned, for every training instance.
            past_len = NN_cnt * item_num
            trans_len = past_len - item_num  # y/z hold one month less than x
            for train_num in range(train_set_size):
                base = train_num * instance_len
                # Read-only views are enough here; they are only fed to np.dot.
                true_weight = train_weights_full[base:base + instance_len]
                true_value = train_prices_full[base:base + instance_len]
                x_prev_sol = train_TOV_x[train_num]
                y_prev_sol = train_TOV_y[train_num]
                z_prev_sol = train_TOV_z[train_num]
                later_weight = true_weight[item_num:past_len]
                train_TOV_prev_prof[train_num][NN_cnt] = np.dot(true_value[:past_len], x_prev_sol[:past_len])
                train_TOV_prev_cost[train_num][NN_cnt] = (
                    np.dot(true_weight[:item_num], x_prev_sol[:item_num])
                    + trans_fee_percent * np.dot(later_weight, z_prev_sol[:trans_len])
                    + np.dot(later_weight, y_prev_sol[:trans_len])
                )

        # Compute the month-NN_cnt plan of every training instance.
        for train_num in range(train_set_size):
            x_num = LP_relax_file.x_num
            start = train_num * instance_len + NN_cnt * item_num
            # Copies, not views: the arrays are handed to the solver, which
            # must not be able to write back into the loaded data.
            true_price = train_prices_full[start:start + x_num].copy()
            true_weight = train_weights_full[start:start + x_num].copy()
            # The true values are passed in the prediction slots as well.
            make_next_plan(train_num, NN_cnt, true_price, true_weight, true_price, true_weight)

            if NN_cnt == total_month_num - 1:
                # Last month: close the books for this instance.
                last_x = train_TOV_x[train_num][(total_month_num - 1) * item_num:]
                last_trans = (total_month_num - 2) * item_num
                # NOTE(review): profit adds weight . x on top of price . x for
                # the final month -- kept as in the original; confirm intent.
                curr_prof = train_TOV_prev_prof[train_num][NN_cnt] + np.dot(true_price, last_x) + np.dot(true_weight, last_x)
                curr_cost = (
                    train_TOV_prev_cost[train_num][NN_cnt]
                    + trans_fee_percent * np.dot(true_weight, train_TOV_z[train_num][last_trans:])
                    + np.dot(true_weight, train_TOV_y[train_num][last_trans:])
                )
                train_obj = train_obj + curr_prof - curr_cost

    # Persist every table for this instance file.
    saved_tables = (
        ('true_owned_prod', train_TOV_x, "%.0f"),
        ('true_y', train_TOV_y, "%.0f"),
        ('true_z', train_TOV_z, "%.0f"),
        ('true_cost', train_TOV_prev_cost, "%.2f"),
        ('true_prof', train_TOV_prev_prof, "%.2f"),
    )
    for stem, table, number_fmt in saved_tables:
        np.savetxt(os.path.join(default_path, stem + '/' + stem + '(' + str(testi) + ').txt'), table, fmt=number_fmt)

    LP_relax_file.reset_month_num()
    print("TOV: ", np.sum(train_obj)/train_set_size)

    # Reset the buffers in place for the next file. fill() also clears any
    # NaN/inf entries, which multiplying by zero would leave behind.
    for table in (train_TOV_x, train_TOV_y, train_TOV_z, train_TOV_prev_prof, train_TOV_prev_cost):
        table.fill(0.0)
