import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# Planning-horizon configuration shared by the solver helpers below.
# `global` declarations are meaningless at module scope, so none are used here;
# change_month_num / reset_month_num rebind these names, and every helper
# reads them at call time.
total_month_num = 12
month_num = total_month_num
x_num = month_num          # number of x (cost-weighted) decision variables
y_num = month_num          # number of y (price-weighted) decision variables
var_num = x_num + y_num    # length of the stacked decision vector [x, y]
relax_val = 1e-5           # NOTE(review): not referenced in the code shown here

def mkdir(default_path, folder_name):
    """Create ``default_path/folder_name`` (including missing parents).

    ``os.makedirs(..., exist_ok=True)`` replaces the separate existence check,
    so concurrent callers cannot race between "check" and "create".  Raises
    ``FileExistsError`` if the path exists but is not a directory.

    Returns:
        The joined path (callers that ignore the return value are unaffected).
    """
    path = os.path.join(default_path, folder_name)
    os.makedirs(path, exist_ok=True)
    return path

def change_month_num(cur_month_num):
    """Set the planning horizon used by the module-level solver helpers.

    Rebinds the globals ``month_num``, ``x_num`` and ``y_num`` to
    ``cur_month_num``, and ``var_num`` to ``x_num + y_num``.
    """
    global month_num, x_num, y_num, var_num
    month_num = x_num = y_num = cur_month_num
    var_num = x_num + y_num
    

def reset_month_num():
    """Restore the horizon globals to the full ``total_month_num``.

    ``month_num``, ``x_num`` and ``y_num`` become ``total_month_num``;
    ``var_num`` becomes their combined size.
    """
    global month_num, x_num, y_num, var_num
    month_num = x_num = y_num = total_month_num
    var_num = x_num + y_num


def gen_obj(price, cost):
    """Return the objective coefficients ``[-cost, price]``.

    The first half weights the x variables and the second half the y
    variables, matching the stacked ``[x, y]`` layout of ``gen_constraints``.
    The LP maximises, so this is price*y minus cost*x.
    """
    return np.concatenate((-cost, price), axis=0)


def gen_constraints(cur_stage, demand, prior_x=None, prior_y=None):
    """Assemble the stage-``cur_stage`` constraints in matrix form.

    The decision vector is ``v = [x_0 .. x_{x_num-1}, y_0 .. y_{y_num-1}]``
    (``var_num`` entries; sizes come from the module globals).

    Inequalities ``G @ v <= h`` (``2 * month_num`` rows):
        * demand rows:    y_i <= demand[i]
        * inventory rows: y_i <= sum_{j<i} x_j - sum_{j<i} y_j

    Equalities ``A @ v == b`` exist only when ``cur_stage > 0``. They pin
    ``x_i = prior_x[i]`` and ``y_i = prior_y[i]`` for every ``i < cur_stage``.

    Returns:
        ``(G, h, A, b)``; ``A`` and ``b`` are ``None`` when ``cur_stage <= 0``.
    """
    idx = np.arange(month_num)
    ones = np.ones((month_num, month_num))

    # Demand rows: a single 1 on the y_i column.
    G_demand = np.zeros((month_num, var_num))
    G_demand[idx, x_num + idx] = 1
    h_demand = np.array([demand[i] for i in range(month_num)], dtype=float)

    # Inventory rows: -1 on x_j for j < i, +1 on y_j for j <= i, RHS zero.
    # (`-=` on zeros keeps the untouched entries at +0.0.)
    G_stock = np.zeros((month_num, var_num))
    G_stock[:, :month_num] -= np.tril(ones, k=-1)
    G_stock[:, x_num:x_num + month_num] = np.tril(ones)
    h_stock = np.zeros(month_num)

    G = np.vstack((G_demand, G_stock))
    h = np.concatenate((h_demand, h_stock))

    if cur_stage <= 0:
        return G, h, None, None

    # Pin the already-executed decisions of months 0 .. cur_stage-1.
    A = np.vstack((np.eye(cur_stage, var_num),
                   np.eye(cur_stage, var_num, k=x_num)))
    b = np.concatenate((prior_x[:cur_stage], prior_y[:cur_stage]))
    return G, h, A, b


def actual_obj_check(price, cost, true_demand_total, n_instance):
    """Solve every instance's full-information LP using the matrix model.

    ``true_demand_total`` is read as ``n_instance`` consecutive chunks of
    ``month_num`` values; chunk ``num`` is the true demand of instance ``num``.
    Each LP is assembled from ``gen_obj`` and ``gen_constraints(0, ...)`` and
    maximised with Gurobi.

    Returns:
        np.ndarray of length ``n_instance`` holding each optimal objective, or
        0 for an instance where Gurobi found no solution.  For such an
        instance a message is printed and, if the model is infeasible, the
        IIS is written to ``model.ilp``.  The run is not paused.
    """
    # The objective does not depend on the instance, so build it once.
    c = gen_obj(price, cost).tolist()
    obj_list = []
    for num in range(n_instance):
        start = num * month_num
        demand = np.array([float(true_demand_total[k]) for k in range(start, start + month_num)])
        G, h, _, _ = gen_constraints(0, demand)

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(var_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
        m.setObjective(x.prod(c), GRB.MAXIMIZE)
        for row, rhs in zip(G.tolist(), h.tolist()):
            m.addConstr(x.prod(row) <= rhs)
        m.optimize()

        # Check explicitly for a solution instead of relying on an exception
        # from reading ObjVal.
        if m.SolCount > 0:
            objective = m.ObjVal
        else:
            print("cannot solve")
            # An IIS only exists for an infeasible model; computeIIS raises
            # GurobiError otherwise (e.g. when the LP is unbounded).
            if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
                try:
                    m.computeIIS()
                    m.write('model.ilp')
                except gp.GurobiError as err:
                    print("IIS not available:", err)
            objective = 0
        obj_list.append(objective)

    return np.array(obj_list)


def t_obj_check(t, price, cost, pred_demand, true_demand, prior_x=None, prior_y=None):
    """Re-plan at stage ``t`` with the matrix (G, h, A, b) formulation.

    Months ``0..t`` use the revealed ``true_demand``; later months use
    ``pred_demand``.  For ``t > 0``, ``gen_constraints`` pins the decisions of
    months ``0..t-1`` to ``prior_x`` / ``prior_y``.  At ``t == 0`` no
    equalities are produced (``A is None``), so one code path covers both cases.

    Returns:
        (x_sol, y_sol): arrays of length ``month_num``.  Both are all zeros if
        Gurobi finds no solution; in that case a message is printed and, for
        an infeasible model, the IIS is written to ``model.ilp``.
    """
    demand = np.concatenate([true_demand[:t + 1], pred_demand[t + 1:]], axis=0)
    c = gen_obj(price, cost).tolist()
    G, h, A, b = gen_constraints(t, demand, prior_x=prior_x, prior_y=prior_y)

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(var_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    m.setObjective(x.prod(c), GRB.MAXIMIZE)
    for row, rhs in zip(G.tolist(), h.tolist()):
        m.addConstr(x.prod(row) <= rhs)
    if A is not None:
        for row, rhs in zip(A.tolist(), b.tolist()):
            m.addConstr(x.prod(row) == rhs)
    m.optimize()

    if m.SolCount > 0:
        x_sol = np.array([x[i].X for i in range(month_num)])
        y_sol = np.array([x[x_num + i].X for i in range(month_num)])
    else:
        print("cannot solve")
        # An IIS only exists for an infeasible model; computeIIS raises
        # GurobiError otherwise (e.g. when the LP is unbounded).
        if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
            try:
                m.computeIIS()
                m.write('model.ilp')
            except gp.GurobiError as err:
                print("IIS not available:", err)
        x_sol = np.zeros(month_num)
        y_sol = np.zeros(month_num)

    return x_sol, y_sol


def correction_check(price, cost, pred_demand, true_demand):
    """Realised profit of monthly re-planning with the matrix-form model.

    Starts from the stage-0 plan of ``t_obj_check``.  Then, for each month
    ``1 .. month_num-1``, it re-solves with the previous plan's earlier months
    held fixed.

    Returns:
        ``sum(price * y) - sum(cost * x)`` evaluated on the final plan.
    """
    x_plan, y_plan = t_obj_check(0, price, cost, pred_demand, true_demand)
    for stage in range(1, month_num):
        x_plan, y_plan = t_obj_check(stage, price, cost, pred_demand, true_demand, x_plan, y_plan)

    revenue = np.sum(price * y_plan)
    purchase_cost = np.sum(cost * x_plan)
    return revenue - purchase_cost



def actual_obj(price, cost, true_demand_total, n_instance):
    """Optimal full-information profit of each instance (named x / y variables).

    ``true_demand_total`` is read as ``n_instance`` consecutive chunks of
    ``month_num`` values; chunk ``num`` is the true demand of instance ``num``.
    Each instance solves

        max  sum_i price_i * y_i - cost_i * x_i
        s.t. y_i <= demand_i,  y_i <= sum_{j<i} (x_j - y_j),  x, y >= 0.

    Returns:
        np.ndarray of length ``n_instance``.  An instance without a solution
        scores 0: a message is printed and, if infeasible, its IIS is written
        to ``model.ilp``.  The run is not paused.
    """
    # Objective coefficients are identical for every instance.
    p = price.tolist()
    c = cost.tolist()
    obj_list = []
    for num in range(n_instance):
        start = num * month_num
        demand = [float(true_demand_total[k]) for k in range(start, start + month_num)]

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
        y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
        m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
        for i in range(month_num):
            m.addConstr(y[i] <= demand[i])
        for i in range(month_num):
            # Month-i sales are limited by stock bought in earlier months minus
            # what was already sold (the empty sums force y_0 <= 0).
            m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
        m.optimize()

        if m.SolCount > 0:
            objective = m.ObjVal
        else:
            print("cannot solve")
            # An IIS only exists for an infeasible model; computeIIS raises
            # GurobiError otherwise (e.g. when the LP is unbounded).
            if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
                try:
                    m.computeIIS()
                    m.write('model.ilp')
                except gp.GurobiError as err:
                    print("IIS not available:", err)
            objective = 0
        obj_list.append(objective)

    return np.array(obj_list)


def get_init_plan(price, cost, pred_demand, true_demand):
    """Solve the stage-0 plan with named x / y variables.

    Month 0 uses its revealed ``true_demand``; months ``1..`` use
    ``pred_demand``.  This matches ``get_t_updated_plan`` and
    ``t_obj_check(0, ...)``.  Previously the demand bound used
    ``pred_demand[0]``, so a negative month-0 prediction made the LP
    infeasible.

    Returns:
        (x_sol, y_sol): arrays of length ``month_num``.  Both are zeros if
        Gurobi finds no solution; in that case a message is printed and, if
        infeasible, the IIS is written to ``model.ilp``.
    """
    demand = np.concatenate([true_demand[:1], pred_demand[1:]], axis=0).tolist()
    p = price.tolist()
    c = cost.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
    m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
    for i in range(month_num):
        m.addConstr(y[i] <= demand[i])
    for i in range(month_num):
        # Sales bounded by earlier purchases minus earlier sales (y_0 <= 0).
        m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
    m.optimize()

    if m.SolCount > 0:
        x_sol = np.array([x[i].X for i in range(month_num)])
        y_sol = np.array([y[i].X for i in range(month_num)])
    else:
        print("cannot solve")
        # An IIS only exists for an infeasible model; computeIIS raises
        # GurobiError otherwise (e.g. when the LP is unbounded).
        if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
            try:
                m.computeIIS()
                m.write('model.ilp')
            except gp.GurobiError as err:
                print("IIS not available:", err)
        x_sol = np.zeros(month_num)
        y_sol = np.zeros(month_num)

    return x_sol, y_sol


def get_t_updated_plan(t, price, cost, pred_demand, true_demand, prior_x, prior_y):
    """Re-plan at stage ``t`` with months ``0..t-1`` fixed to the prior plan.

    Demand for months ``0..t`` is the revealed ``true_demand``; later months
    use ``pred_demand``.  Constraints (named x / y variables):
    ``y_i <= demand_i``, ``y_i <= sum_{j<i} (x_j - y_j)``, and
    ``x_i = prior_x[i]``, ``y_i = prior_y[i]`` for ``i < t``.

    Returns:
        (x_sol, y_sol): arrays of length ``month_num``.  Both are zeros if
        Gurobi finds no solution; in that case a message is printed and, if
        infeasible, the IIS is written to ``model.ilp``.
    """
    demand = np.concatenate([true_demand[:t + 1], pred_demand[t + 1:]], axis=0).tolist()
    p = price.tolist()
    c = cost.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
    m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
    for i in range(month_num):
        m.addConstr(y[i] <= demand[i])
    for i in range(month_num):
        m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
    # Decisions already executed in earlier months are fixed.
    for i in range(t):
        m.addConstr(x[i] == prior_x[i])
        m.addConstr(y[i] == prior_y[i])
    m.optimize()

    if m.SolCount > 0:
        x_sol = np.array([x[i].X for i in range(month_num)])
        y_sol = np.array([y[i].X for i in range(month_num)])
    else:
        print("cannot solve")
        # An IIS only exists for an infeasible model; computeIIS raises
        # GurobiError otherwise (e.g. when the LP is unbounded).
        if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
            try:
                m.computeIIS()
                m.write('model.ilp')
            except gp.GurobiError as err:
                print("IIS not available:", err)
        x_sol = np.zeros(month_num)
        y_sol = np.zeros(month_num)

    return x_sol, y_sol


def correction_single_obj(price, cost, pred_demand, true_demand):
    """Realised profit of monthly re-planning with the named-variable models.

    The first plan comes from ``get_init_plan``.  Then, for each month
    ``t = 1 .. month_num-1``, ``get_t_updated_plan`` re-solves with months
    ``< t`` fixed.

    Returns:
        ``sum(price * y) - sum(cost * x)`` of the final plan.
    """
    plan_x, plan_y = get_init_plan(price, cost, pred_demand, true_demand)
    for month in range(1, month_num):
        plan_x, plan_y = get_t_updated_plan(month, price, cost, pred_demand, true_demand, plan_x, plan_y)

    sales = np.sum(price * plan_y)
    spend = np.sum(cost * plan_x)
    return sales - spend

def gen_obj_latter_days(t, price, cost):
    """Objective ``[-cost[t:], price[t:]]`` for the horizon starting at month ``t``.

    ``price`` and ``cost`` are full-horizon arrays indexed by absolute month.
    """
    remaining_cost = cost[t:]
    remaining_price = price[t:]
    return np.concatenate((-remaining_cost, remaining_price), axis=0)

def gen_constraints_latter_days(t, demand, curr_stocking):
    """Build ``G @ v <= h`` for the re-planning LP that starts at month ``t``.

    Uses the same ``v = [x, y]`` layout as ``gen_constraints`` (sizes from the
    current horizon globals), except that every inventory row has the current
    stock on its right-hand side:
    ``y_i <= curr_stocking + sum_{j<i} x_j - sum_{j<i} y_j``.

    Args:
        t: accepted for interface symmetry; not used.
        demand: per-month demand, first ``month_num`` entries used.
        curr_stocking: scalar stock on hand at the start of month ``t``.

    Returns:
        ``(G, h)``.
    """
    idx = np.arange(month_num)
    ones = np.ones((month_num, month_num))

    # Demand rows: y_i <= demand[i].
    G_demand = np.zeros((month_num, var_num))
    G_demand[idx, x_num + idx] = 1
    h_demand = np.array([demand[i] for i in range(month_num)], dtype=float)

    # Inventory rows: -1 on x_j (j < i), +1 on y_j (j <= i).
    G_stock = np.zeros((month_num, var_num))
    G_stock[:, :month_num] -= np.tril(ones, k=-1)
    G_stock[:, x_num:x_num + month_num] = np.tril(ones)
    h_stock = np.ones(month_num) * curr_stocking

    G = np.vstack((G_demand, G_stock))
    h = np.concatenate((h_demand, h_stock))
    return G, h


def get_updated_plan_for_each_day(t, c, G, h):
    """Solve ``max c @ v  s.t.  G @ v <= h,  v >= 0`` for stage ``t``.

    ``v`` has ``var_num`` entries laid out as ``[x, y]`` (module globals).

    Args:
        t: stage index, used only in the failure message.
        c, G, h: NumPy objective vector, constraint matrix and RHS.

    Returns:
        (t_updated_x, t_updated_y): arrays of lengths ``x_num`` and ``y_num``.
        Both are all zeros when Gurobi finds no solution.  In that case a
        message is printed, an IIS is written to ``model.ilp`` if the model is
        infeasible, and c/G/h are dumped to text files.
    """
    m = gp.Model()
    m.setParam('OutputFlag', 0)
    # Gurobi's default variable lower bound is 0, so v >= 0.
    x = m.addVars(var_num, vtype=GRB.CONTINUOUS, name='x')
    m.setObjective(x.prod(c.tolist()), GRB.MAXIMIZE)
    for row, rhs in zip(G.tolist(), h.tolist()):
        m.addConstr(x.prod(row) <= rhs)
    m.optimize()

    t_updated_x = np.zeros(x_num)
    t_updated_y = np.zeros(y_num)
    if m.SolCount > 0:
        for i in range(x_num):
            t_updated_x[i] = x[i].X
            t_updated_y[i] = x[x_num + i].X
        return t_updated_x, t_updated_y

    print("Stage ", t, ", cannot solve")
    # An IIS only exists for an infeasible model; computeIIS raises
    # GurobiError otherwise (e.g. when the LP is unbounded).
    if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            print("IIS not available:", err)
    np.savetxt('c.txt', c, fmt="%.2f")
    np.savetxt('G.txt', G, fmt="%.2f")
    np.savetxt('h.txt', h, fmt="%.2f")
    return t_updated_x, t_updated_y


def correction_single_for_latter_days(cur_NN, price, cost, pred_demand, true_demand, curr_profit, curr_stocking):
    """Roll out the receding-horizon policy from month ``cur_NN`` to the end.

    At each month ``t``, the LP over the remaining ``total_month_num - t``
    months is solved.  Month ``t`` uses its true demand and later months use
    the prediction, starting from the current stock.  Only month ``t``'s
    decision is executed: its profit is added and the stock is updated.

    Args:
        cur_NN: first month to simulate.
        price, cost: full-horizon arrays indexed by absolute month.
        pred_demand, true_demand: arrays of length
            ``total_month_num - cur_NN``; position ``t - cur_NN`` is month ``t``.
        curr_profit, curr_stocking: profit and stock accumulated before
            ``cur_NN``.  They are not modified in place.

    Returns:
        Total profit after the final month.

    Side effects:
        Rebinds the horizon globals via ``change_month_num`` at every month.
        On exit they are always set to ``total_month_num - cur_NN``, even if
        a solve raises.
    """
    total_profit = curr_profit
    total_stocking = curr_stocking
    try:
        for t in range(cur_NN, total_month_num):
            cur_month_num = total_month_num - t
            change_month_num(cur_month_num)
            offset = t - cur_NN

            # Month t's demand is revealed; later months use the prediction.
            demand = np.zeros(cur_month_num)
            demand[0] = true_demand[offset]
            for i in range(1, cur_month_num):
                demand[i] = pred_demand[offset + i]

            G_t, h_t = gen_constraints_latter_days(t, demand, total_stocking)
            c_t = gen_obj_latter_days(t, price, cost)
            t_updated_x, t_updated_y = get_updated_plan_for_each_day(t, c_t, G_t, h_t)

            # Execute only the first month of the plan.  Rebind rather than
            # use `+=` so array/tensor arguments from the caller are not mutated.
            total_profit = total_profit + (price[t] * t_updated_y[0] - cost[t] * t_updated_x[0])
            total_stocking = total_stocking + (t_updated_x[0] - t_updated_y[0])
    finally:
        change_month_num(total_month_num - cur_NN)

    return total_profit


def gen_obj_latter_days_full_intOpt(cur_NN, price, cost):
    """Objective ``[-cost[cur_NN:], price[cur_NN:]]`` for the horizon from ``cur_NN``.

    ``price`` and ``cost`` are full-horizon arrays indexed by absolute month.
    """
    tail = slice(cur_NN, None)
    return np.concatenate([-cost[tail], price[tail]], axis=0)

def gen_constraints_latter_days_full_intOpt(cur_NN, cur_stage, demand, curr_stocking, x_prev_stage=None, y_prev_stage=None):
    """Constraints of the full remaining-horizon LP at stage ``cur_stage``.

    The LP covers months ``cur_NN .. total_month_num-1``.  The original header
    notes that ``cur_month_num = total_month_num - cur_NN``, so the horizon
    globals are presumably set to that value by the caller; confirm before
    reuse.  The decision vector is ``v = [x, y]`` with ``var_num`` entries.

    Args:
        cur_NN: first month of the horizon.
        cur_stage: current month; months ``cur_NN .. cur_stage-1`` are fixed.
        demand: per-month demand, first ``month_num`` entries used.
        curr_stocking: stock at the start of ``cur_NN``; added to every
            inventory row's right-hand side.
        x_prev_stage, y_prev_stage: previous-stage plan, indexed from
            ``cur_NN``.  Only read when ``cur_stage > cur_NN``.

    Returns:
        ``(G, h, A, b)`` with ``G @ v <= h`` and ``A @ v == b``; ``A`` and
        ``b`` are ``None`` when no month is fixed yet.
    """
    idx = np.arange(month_num)
    ones = np.ones((month_num, month_num))

    # Demand rows: y_i <= demand[i].
    G_demand = np.zeros((month_num, var_num))
    G_demand[idx, x_num + idx] = 1
    h_demand = np.array([demand[i] for i in range(month_num)], dtype=float)

    # Inventory rows: y_i <= curr_stocking + sum_{j<i} x_j - sum_{j<i} y_j.
    G_stock = np.zeros((month_num, var_num))
    G_stock[:, :month_num] -= np.tril(ones, k=-1)
    G_stock[:, x_num:x_num + month_num] = np.tril(ones)
    h_stock = np.ones(month_num) * curr_stocking

    G = np.vstack((G_demand, G_stock))
    h = np.concatenate((h_demand, h_stock))

    n_fixed = cur_stage - cur_NN
    if n_fixed <= 0:
        return G, h, None, None

    # Pin x_i and y_i of the already-executed months to the previous plan.
    A = np.vstack((np.eye(n_fixed, var_num),
                   np.eye(n_fixed, var_num, k=x_num)))
    b = np.concatenate((x_prev_stage[:n_fixed], y_prev_stage[:n_fixed]))
    return G, h, A, b


def get_updated_plan_for_each_day_full_intOpt(t, c, G, h, A=None, b=None):
    """Solve ``max c @ v  s.t.  G @ v <= h,  A @ v == b,  v >= 0``.

    Args:
        t: stage index, used only in the failure message.
        c, G, h: NumPy objective vector, inequality matrix and RHS.
        A, b: optional NumPy equality data; ignored when ``A`` is ``None``.

    Returns:
        (t_updated_x, t_updated_y): arrays of lengths ``x_num`` / ``y_num``
        taken from ``v = [x, y]``.  Both are all zeros if Gurobi finds no
        solution.  In that case a message is printed, an IIS is written to
        ``model.ilp`` if the model is infeasible, and c/G/h (plus A/b when
        present) are dumped to text files.
    """
    m = gp.Model()
    m.setParam('OutputFlag', 0)
    # Gurobi's default variable lower bound is 0, so v >= 0.
    x = m.addVars(var_num, vtype=GRB.CONTINUOUS, name='x')
    m.setObjective(x.prod(c.tolist()), GRB.MAXIMIZE)
    for row, rhs in zip(G.tolist(), h.tolist()):
        m.addConstr(x.prod(row) <= rhs)
    if A is not None:
        for row, rhs in zip(A.tolist(), b.tolist()):
            m.addConstr(x.prod(row) == rhs)
    m.optimize()

    t_updated_x = np.zeros(x_num)
    t_updated_y = np.zeros(y_num)
    if m.SolCount > 0:
        for i in range(x_num):
            t_updated_x[i] = x[i].X
            t_updated_y[i] = x[x_num + i].X
        return t_updated_x, t_updated_y

    print("Stage ", t, ", cannot solve")
    # An IIS only exists for an infeasible model; computeIIS raises
    # GurobiError otherwise (e.g. when the LP is unbounded).
    if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            print("IIS not available:", err)
    np.savetxt('c.txt', c, fmt="%.2f")
    np.savetxt('G.txt', G, fmt="%.2f")
    np.savetxt('h.txt', h, fmt="%.2f")
    # At the first stage there are no equalities; np.savetxt cannot write None.
    if A is not None:
        np.savetxt('A.txt', A, fmt="%.2f")
        np.savetxt('b.txt', b, fmt="%.2f")
    return t_updated_x, t_updated_y


def check_grad_compute_used(cur_NN, price, cost, pred_demand, true_demand, curr_profit, curr_stocking):
    """Simulate months ``cur_NN .. total_month_num-1`` with full-horizon re-solves.

    At each stage the LP over the whole horizon ``cur_NN ..`` is solved again.
    Demand is true up to the current stage and predicted afterwards; the
    decisions of earlier stages are pinned to the previous plan (``A``, ``b``).
    The starting stock stays ``curr_stocking``, because the inventory
    constraints already account for every earlier month of the horizon.

    Args:
        cur_NN: first month of the horizon.
        price, cost: full-horizon arrays indexed by absolute month.
        pred_demand, true_demand: arrays indexed from ``cur_NN``.
        curr_profit: profit accumulated before ``cur_NN``.
        curr_stocking: stock at the start of ``cur_NN``.

    Returns:
        ``curr_profit`` plus the profit of each executed month.

    NOTE(review): relies on the horizon globals already being
    ``total_month_num - cur_NN`` (set by the caller) — confirm.
    """
    profit = curr_profit
    x_plan = y_plan = None  # no earlier stage to pin at the first iteration
    for stage in range(cur_NN, total_month_num):
        k = stage - cur_NN
        demand = np.concatenate([true_demand[:k + 1], pred_demand[k + 1:]], axis=0)
        c_t = gen_obj_latter_days(cur_NN, price, cost)

        # With x_plan/y_plan None at the first stage, A_t and b_t come back None.
        G_t, h_t, A_t, b_t = gen_constraints_latter_days_full_intOpt(
            cur_NN, stage, demand, curr_stocking,
            x_prev_stage=x_plan, y_prev_stage=y_plan)
        x_plan, y_plan = get_updated_plan_for_each_day_full_intOpt(stage, c_t, G_t, h_t, A_t, b_t)

        # Only this stage's decision is realised.
        profit += price[stage] * y_plan[k] - cost[stage] * x_plan[k]

    return profit
