import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# Planning horizon and LP variable layout. The matrix-based models below use a
# stacked decision vector [x_0 .. x_{month_num-1}, y_0 .. y_{month_num-1}],
# where x_i is charged cost[i] and y_i earns price[i] in the objective.
# (Plain module-level assignments: a `global` declaration at module scope is a
# no-op, so none is needed for functions to read these names.)
month_num = 12              # number of months in the horizon
x_num = month_num           # number of x variables (one per month)
y_num = month_num           # number of y variables (one per month)
var_num = x_num + y_num     # length of the stacked [x, y] vector
relax_val = 1e-5            # not referenced by the functions in this section

def mkdir(default_path, folder_name):
    """Create the directory ``default_path/folder_name`` if it does not exist.

    Missing parent directories are created as well. Calling it on an existing
    directory is a no-op.

    Returns:
        The joined path.
    """
    path = os.path.join(default_path, folder_name)
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() (another process creating the folder in between).
    os.makedirs(path, exist_ok=True)
    return path

def gen_obj(price, cost):
    """Return the objective coefficients for the stacked vector ``[x, y]``.

    The x block gets ``-cost`` and the y block gets ``price``, so maximising
    ``c @ [x, y]`` maximises ``sum(price * y) - sum(cost * x)``.
    """
    negated_cost = -cost
    return np.concatenate((negated_cost, price), axis=0)

def gen_constraints(cur_stage, demand, prior_x=None, prior_y=None):
    """Assemble the LP constraint matrices over the stacked vector v = [x, y].

    Inequalities ``G @ v <= h``:
      * first ``month_num`` rows:  y_i <= d_i
      * next ``month_num`` rows:   y_i <= sum_{j<i} x_j - sum_{j<i} y_j

    Equalities ``A @ v == b`` (only when ``cur_stage > 0``) pin
    x_i = prior_x[i] and y_i = prior_y[i] for i < cur_stage.

    Returns:
        (G, h, A, b); ``A`` and ``b`` are None when ``cur_stage <= 0``.
    """
    # y_i <= d_i: row i has a single 1 in column x_num + i.
    demand_G = np.eye(month_num, var_num, k=x_num)
    demand_h = np.fromiter((demand[i] for i in range(month_num)),
                           dtype=float, count=month_num)

    # y_i + sum_{j<i} y_j - sum_{j<i} x_j <= 0
    stock_G = np.zeros((month_num, var_num))
    stock_G[:, :month_num] -= np.tri(month_num, k=-1)          # -x_j for j < i
    stock_G[:, x_num:x_num + month_num] += np.tri(month_num)   # +y_j for j <= i
    stock_h = np.zeros(month_num)

    G = np.vstack((demand_G, stock_G))
    h = np.concatenate((demand_h, stock_h))

    if cur_stage <= 0:
        return G, h, None, None

    # Already-committed months keep the previous stage's decisions.
    fix_x = np.eye(cur_stage, var_num)             # x_i = prior_x[i]
    fix_y = np.eye(cur_stage, var_num, k=x_num)    # y_i = prior_y[i]
    A = np.vstack((fix_x, fix_y))
    b = np.concatenate((prior_x[:cur_stage], prior_y[:cur_stage]), axis=0)
    return G, h, A, b


def actual_obj_check(price, cost, true_demand_total, n_instance):
    """Solve the full-information LP per instance using the matrix formulation.

    The model is built from ``gen_obj`` / ``gen_constraints`` (stage 0, no
    fixed variables) over the stacked vector ``[x, y]`` and maximised with
    Gurobi.

    Args:
        price: per-month prices (objective coefficients of y).
        cost: per-month costs (objective coefficients of x).
        true_demand_total: flat sequence of true demands; instance ``num``
            uses entries ``num*month_num`` .. ``(num+1)*month_num - 1``.
        n_instance: number of instances to solve.

    Returns:
        np.ndarray with one optimal objective per instance (0 where Gurobi
        found no solution).
    """
    # The objective does not depend on the instance.
    c = gen_obj(price, cost).tolist()
    obj_list = []
    for num in range(n_instance):
        start = num * month_num
        demand = np.array([float(true_demand_total[start + i])
                           for i in range(month_num)])

        G, h, _, _ = gen_constraints(0, demand)

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(var_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
        m.setObjective(x.prod(c), GRB.MAXIMIZE)
        for G_row, h_i in zip(G.tolist(), h.tolist()):
            m.addConstr(x.prod(G_row) <= h_i)
        m.optimize()

        # ObjVal is only available when a solution exists; test that directly
        # instead of catching every exception (bare except also swallowed
        # KeyboardInterrupt).
        if m.SolCount > 0:
            objective = m.objVal
        else:
            print("cannot solve")
            # Best-effort diagnostics: an IIS exists only for infeasible
            # models; computeIIS raises GurobiError otherwise (e.g. unbounded).
            try:
                m.computeIIS()
                m.write('model.ilp')
            except gp.GurobiError as err:
                print("IIS not available:", err)
            objective = 0
            # NOTE(review): 100 s pause kept from the original; consider removing.
            time.sleep(100)

        obj_list.append(objective)

    return np.array(obj_list)


def t_obj_check(t, price, cost, pred_demand, true_demand, prior_x=None, prior_y=None):
    """Solve the stage-``t`` plan using the matrix formulation.

    Demand for months ``0..t`` comes from ``true_demand`` and for months
    ``t+1..`` from ``pred_demand``. For ``t > 0`` the first ``t`` entries of
    x and y are fixed to ``prior_x`` / ``prior_y``.

    Returns:
        (x_sol, y_sol): length-``month_num`` arrays of the stage-``t`` plan;
        both are all zeros if Gurobi found no solution.
    """
    c = gen_obj(price, cost).tolist()
    demand = np.concatenate([true_demand[:t + 1], pred_demand[t + 1:]], axis=0)
    # gen_constraints returns A = b = None at stage 0, so one code path covers
    # both the initial plan and the corrected plans.
    G, h, A, b = gen_constraints(t, demand, prior_x=prior_x, prior_y=prior_y)

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(var_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    m.setObjective(x.prod(c), GRB.MAXIMIZE)
    for G_row, h_i in zip(G.tolist(), h.tolist()):
        m.addConstr(x.prod(G_row) <= h_i)
    if A is not None:
        for A_row, b_i in zip(A.tolist(), b.tolist()):
            m.addConstr(x.prod(A_row) == b_i)
    m.optimize()

    x_sol = np.zeros(month_num)
    y_sol = np.zeros(month_num)
    if m.SolCount > 0:
        for i in range(month_num):
            x_sol[i] = x[i].x
            y_sol[i] = x[x_num + i].x
    else:
        print("cannot solve")
        # Best-effort diagnostics: an IIS exists only for infeasible models;
        # computeIIS raises GurobiError otherwise (e.g. unbounded).
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            print("IIS not available:", err)
        # NOTE(review): 100 s pause kept from the original; consider removing.
        time.sleep(100)

    return x_sol, y_sol


def correction_check(price, cost, pred_demand, true_demand):
    """Run the rolling re-planning loop with ``t_obj_check`` and score it.

    Stage 0 yields an initial plan. Each stage ``1 .. month_num-1`` re-solves
    with the previous plan passed as ``prior_x`` / ``prior_y``.

    Returns:
        ``sum(price * y) - sum(cost * x)`` for the final stage's plan.
    """
    plan_x, plan_y = t_obj_check(0, price, cost, pred_demand, true_demand)
    for stage in range(1, month_num):
        plan_x, plan_y = t_obj_check(
            stage, price, cost, pred_demand, true_demand,
            prior_x=plan_x, prior_y=plan_y,
        )

    revenue = np.sum(price * plan_y)
    spend = np.sum(cost * plan_x)
    return revenue - spend



def actual_obj(price, cost, true_demand_total, n_instance):
    """Solve the full-information LP for each instance and return the optima.

    Each model maximises ``sum(price*y) - sum(cost*x)`` over non-negative
    continuous x, y subject to ``y_i <= d_i`` and
    ``y_i <= sum_{j<i} x_j - sum_{j<i} y_j``, where ``d`` is the instance's
    true demand.

    Args:
        price: array of per-month prices (objective coefficients of y).
        cost: array of per-month costs (objective coefficients of x).
        true_demand_total: flat sequence of true demands; instance ``num``
            uses entries ``num*month_num`` .. ``(num+1)*month_num - 1``.
        n_instance: number of instances to solve.

    Returns:
        np.ndarray with one optimal objective per instance (0 where Gurobi
        found no solution).
    """
    # Objective coefficients are shared by every instance.
    p = price.tolist()
    c = cost.tolist()
    obj_list = []
    for num in range(n_instance):
        start = num * month_num
        demand = [float(true_demand_total[start + i]) for i in range(month_num)]

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
        y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
        m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
        for i in range(month_num):
            m.addConstr(y[i] <= demand[i])
        for i in range(month_num):
            m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
        m.optimize()

        # ObjVal exists only when a solution was found; check that instead of
        # a bare except (which also swallowed KeyboardInterrupt).
        if m.SolCount > 0:
            objective = m.objVal
        else:
            print("cannot solve")
            # Best-effort diagnostics: an IIS exists only for infeasible
            # models; computeIIS raises GurobiError otherwise (e.g. unbounded).
            try:
                m.computeIIS()
                m.write('model.ilp')
            except gp.GurobiError as err:
                print("IIS not available:", err)
            objective = 0
            # NOTE(review): 100 s pause kept from the original; consider removing.
            time.sleep(100)

        obj_list.append(objective)

    return np.array(obj_list)


def get_init_plan(price, cost, pred_demand, true_demand):
    """Solve the stage-0 plan.

    Month 0's demand is taken from ``true_demand`` (already revealed) and the
    remaining months from ``pred_demand``. This matches ``get_t_updated_plan``
    with t = 0.

    Returns:
        (x_sol, y_sol): length-``month_num`` arrays; all zeros if Gurobi found
        no solution.
    """
    # Previously this array was built but the constraints used pred_demand[i],
    # silently ignoring the revealed month-0 demand.
    demand = np.concatenate([true_demand[:1], pred_demand[1:]], axis=0).tolist()
    p = price.tolist()
    c = cost.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
    m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
    for i in range(month_num):
        m.addConstr(y[i] <= demand[i])
    for i in range(month_num):
        m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
    m.optimize()

    x_sol = np.zeros(month_num)
    y_sol = np.zeros(month_num)
    if m.SolCount > 0:
        for i in range(month_num):
            x_sol[i] = x[i].x
            y_sol[i] = y[i].x
    else:
        print("cannot solve")
        # Best-effort diagnostics: an IIS exists only for infeasible models;
        # computeIIS raises GurobiError otherwise (e.g. unbounded).
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            print("IIS not available:", err)

    return x_sol, y_sol


def get_t_updated_plan(t, price, cost, pred_demand, true_demand, prior_x, prior_y):
    """Re-solve the plan at stage ``t`` with earlier months committed.

    Demand for months ``0..t`` comes from ``true_demand`` and for months
    ``t+1..`` from ``pred_demand``. For every month ``i < t``, x_i and y_i
    are fixed to ``prior_x[i]`` and ``prior_y[i]``.

    Returns:
        (x_sol, y_sol): length-``month_num`` arrays; all zeros if Gurobi found
        no solution.
    """
    demand = np.concatenate([true_demand[:t + 1], pred_demand[t + 1:]], axis=0).tolist()
    p = price.tolist()
    c = cost.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='x')
    y = m.addVars(month_num, vtype=GRB.CONTINUOUS, lb=0, name='y')
    m.setObjective(y.prod(p) - x.prod(c), GRB.MAXIMIZE)
    for i in range(month_num):
        m.addConstr(y[i] <= demand[i])
    for i in range(month_num):
        m.addConstr(y[i] <= gp.quicksum(x[j] for j in range(i)) - gp.quicksum(y[j] for j in range(i)))
    # Fix the already-committed decisions of months 0..t-1.
    for i in range(t):
        m.addConstr(x[i] == prior_x[i])
        m.addConstr(y[i] == prior_y[i])
    m.optimize()

    x_sol = np.zeros(month_num)
    y_sol = np.zeros(month_num)
    if m.SolCount > 0:
        for i in range(month_num):
            x_sol[i] = x[i].x
            y_sol[i] = y[i].x
    else:
        print("cannot solve")
        # Best-effort diagnostics: an IIS exists only for infeasible models;
        # computeIIS raises GurobiError otherwise (e.g. unbounded).
        try:
            m.computeIIS()
            m.write('model.ilp')
        except gp.GurobiError as err:
            print("IIS not available:", err)
        # NOTE(review): 100 s pause kept from the original; consider removing.
        time.sleep(100)

    return x_sol, y_sol


def correction_single_obj(price, cost, pred_demand, true_demand):
    """Run the rolling re-planning loop and return the final plan's profit.

    ``get_init_plan`` gives the stage-0 plan. Each stage
    ``1 .. month_num-1`` then calls ``get_t_updated_plan`` with the previous
    plan as ``prior_x`` / ``prior_y``.

    Returns:
        ``sum(price * y) - sum(cost * x)`` for the final stage's plan.
    """
    plan_x, plan_y = get_init_plan(price, cost, pred_demand, true_demand)
    for stage in range(1, month_num):
        plan_x, plan_y = get_t_updated_plan(
            stage, price, cost, pred_demand, true_demand,
            prior_x=plan_x, prior_y=plan_y,
        )

    revenue = np.sum(price * plan_y)
    spend = np.sum(cost * plan_x)
    return revenue - spend
