import numpy as np
import os, sys
import subprocess
from collections import defaultdict
import codecs
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum

sys.path.append(os.getcwd())
from src.misc.utils import mat2str

def solveLCP(env, params, theta, mu, G, B, K, desiredDistrib=None, desiredProd=None):
    """Solve the single-step production/shipping program for ``env.time``.

    A Gurobi model is built whose decisions are the edge flows ``f`` and the
    production quantities ``w``.  The objective trades off linear costs
    (``params['mO'] @ w + params['mT'] @ f``), the absolute deviation from the
    requested targets, heavily penalised capacity violations, a small
    quadratic regulariser on ``f`` and ``K`` non-negative auxiliary variables
    bounded below by ``mu[j] - theta[j] @ f``.

    Args:
        env: Environment object.  Reads ``time``, ``acc``, ``arrival_prod``,
            ``arrival_flow``, ``demand``, ``random_graph.edges`` and
            ``scenario.factory`` / ``scenario.warehouse`` /
            ``scenario.storage_capacities``.
        params: Mapping with the cost vectors ``'mT'`` (applied to ``f``) and
            ``'mO'`` (applied to ``w``).
        theta: Indexable of length ``K``; ``theta[j] @ f`` must be valid.
        mu: Indexable of length ``K`` with the matching offsets.
        G: Matrix such that ``G @ f`` is bounded by the factory stock
            (presumably the factory-outflow incidence matrix; confirm
            against the caller).
        B: Matrix such that ``B @ f`` is matched to the warehouse targets
            (presumably the warehouse-inflow incidence matrix; confirm
            against the caller).
        K: Number of ``(theta, mu)`` pairs.
        desiredDistrib: Indexable by warehouse *position*
            ``0 .. len(warehouse) - 1``; each entry is multiplied by the total
            available factory stock and truncated with ``int``.  Required in
            practice although the default is ``None``.
        desiredProd: Indexable by factory id; each entry must expose
            ``.item()``.  Required in practice although the default is
            ``None``.

    Returns:
        ``(action, z)`` where ``action = (prod, ship)``; ``prod`` maps each
        factory id ``i`` to ``w.X[i]``, ``ship`` maps each edge of
        ``env.random_graph.edges`` (in iteration order) to its flow, and ``z``
        is the solution array of the demand-bounded variable ``z``.
    """
    quad_weight = 10e-5  # weight of the f @ f regulariser (numerically 1e-4)
    step = env.time
    factories = env.scenario.factory
    warehouses = env.scenario.warehouse
    capacities = env.scenario.storage_capacities

    # Non-negative stock on hand at every factory once this step's production
    # arrivals are counted; computed once and reused for targets and bounds.
    available = [
        max(env.acc[step - 1][i] + env.arrival_prod[step][i], 0) for i in factories
    ]
    total_available = sum(available)

    # Targets requested by the caller, truncated to integers.
    qhat_s = np.array(
        [int(desiredDistrib[k] * total_available) for k in range(len(warehouses))]
    )
    qhat_w = np.array([max(int(desiredProd[i].item()), 0) for i in factories])

    mT = np.array(params['mT'])
    mO = np.array(params['mO'])

    qw = np.array(available)
    qs = np.array(
        [max(env.acc[step - 1][i] + env.arrival_flow[step][i], 0) for i in warehouses]
    )

    cs = np.array([capacities[i] for i in warehouses])
    cw = np.array([capacities[i] for i in factories])

    lbd = np.array([env.demand[step][i] for i in warehouses])

    m = gp.Model('single_step')
    m.setParam('OutputFlag', 0)

    f = m.addMVar(shape = len(env.random_graph.edges), lb = 0, vtype = GRB.CONTINUOUS, name = 'f')
    w = m.addMVar(shape = len(qw), lb = 0, vtype = GRB.CONTINUOUS, name = 'w')
    # Signed deviations from the targets (bounded below by -10000).
    eps_s = m.addMVar(shape = len(qs), lb = -10000, vtype = GRB.CONTINUOUS, name = 'eps_s')
    eps_w = m.addMVar(shape = len(qw), lb = -10000, vtype = GRB.CONTINUOUS, name = 'eps_w')
    # One auxiliary variable per (theta, mu) pair; kept separate from the
    # time index instead of rebinding the same local name.
    t_aux = [m.addVar(name=f't{i}', lb=0) for i in range(K)]
    # Upper bounds on |eps_s| and |eps_w|.
    eps_st = m.addMVar(shape = len(qs), lb = 0, vtype = GRB.CONTINUOUS, name = 'eps_st')
    eps_wt = m.addMVar(shape = len(qw), lb = 0, vtype = GRB.CONTINUOUS, name = 'eps_wt')
    # Capacity slacks; each family gets its own name so the variables can be
    # told apart in the model.
    error = m.addMVar(shape = len(qs), lb = 0, vtype = GRB.CONTINUOUS, name = 'error')
    error_w = m.addMVar(shape = len(qw), lb = 0, vtype = GRB.CONTINUOUS, name = 'error_w')
    z = m.addMVar(shape = len(qs), lb = 0, vtype = GRB.CONTINUOUS, name = 'z')

    m.addConstr(B@f == qhat_s + eps_s)
    m.addConstr(qs+B@f-z<= cs+error)
    m.addConstr(z<=lbd)
    m.addConstr(G@f <= qw)
    m.addConstr(qw+w-G@f <= cw+error_w)
    m.addConstr(w == qhat_w + eps_w)

    for j in range(K):
        m.addConstr(-theta[j]@f + mu[j] <= t_aux[j])

    # eps_st >= |eps_s| and eps_wt >= |eps_w|.
    m.addConstr(eps_st >= eps_s)
    m.addConstr(eps_st >= -eps_s)
    m.addConstr(eps_wt >= eps_w)
    m.addConstr(eps_wt >= -eps_w)

    m.setObjective(mO@w + mT@f + 10*qsum(eps_st) + 10*qsum(eps_wt) + quad_weight*f@f + 10000*qsum(error) + 10000*qsum(error_w) + sum(t_aux[j] for j in range(K)), GRB.MINIMIZE)
    m.optimize()

    # Debugging hint: if the model is reported infeasible, m.computeIIS()
    # followed by m.write("single inventory.ilp") dumps the conflicting set.

    f_value = f.X
    w_value = w.X

    # Flows are stored in the iteration order of env.random_graph.edges.
    ship = dict(zip(env.random_graph.edges, f_value))
    # w is indexed with the factory id itself.
    prod = {i: w_value[i] for i in factories}
    action = (prod, ship)

    return action, z.X


def MPC(env, current_num, q_current, f_dict, w_dict, horizon=30):
    """Solve a finite-horizon inventory LP and return the first-step decision.

    The model covers the periods ``current_num .. current_num + horizon - 1``
    and maximises sales revenue minus storage, production and transport
    costs, subject to inventory balance, storage capacity and
    ``z <= min(demand, stock)`` at the warehouses.

    Args:
        env: Environment object.  Reads ``scenario.warehouse``,
            ``scenario.factory``, ``scenario.G`` (nodes, edges and the node
            attributes ``'storage_cost'`` / ``'production_cost'``),
            ``scenario.production_time``, ``scenario.storage_capacities``,
            ``scenario.product_prices[0]``, ``random_graph.edges`` (edge
            attributes ``'time'`` and ``'cost'``) and ``demand[t][i]`` for
            every period ``t`` of the horizon.
        current_num: First period of the horizon.
        q_current: Mapping node -> inventory fixed at ``current_num``.
        f_dict: Mapping past period -> ``{(j, i): flow}`` of shipments that
            were already decided and may still arrive inside the horizon.
        w_dict: Mapping past period -> ``{factory: quantity}`` of production
            orders that were already decided.
        horizon: Number of periods to look ahead (default 30, the value that
            used to be hard-coded).  Must be at least 2 because the inventory
            of period ``current_num + 1`` is returned.

    Returns:
        A 10-tuple ``(f_list, qs_list, qw_list, w_list, f0, w0, q1, r0, z0,
        rest)``.  The four lists hold one array per period and are filled
        only when the solver status is ``OPTIMAL``.  ``f0``, ``w0`` and
        ``z0`` are the first-period flows, orders and sales, ``q1`` the
        inventory at ``current_num + 1``, ``r0`` the first-period profit and
        ``rest`` is ``model.ObjVal - r0``.  The first-period values are read
        from the solver regardless of the status.

    Raises:
        ValueError: If ``horizon`` is smaller than 2.
    """
    if horizon < 2:
        raise ValueError("horizon must be at least 2")

    V_S = env.scenario.warehouse
    V_W = env.scenario.factory
    E = env.scenario.G.edges
    V = env.scenario.G.nodes
    edge_data = env.random_graph.edges
    node_data = env.scenario.G.nodes
    capacities = env.scenario.storage_capacities
    prod_delay = env.scenario.production_time
    price = env.scenario.product_prices[0]

    # Create model
    model = gp.Model("multi_time_inventory")
    model.setParam('OutputFlag', 0)

    # Variables
    T = range(current_num, current_num + horizon)
    last_t = T[-1]
    f = model.addVars(T, E, name="f", lb=0)        # flow from i to j at time t
    w = model.addVars(T, V_W, name="w", lb=0)      # order at time t
    q = model.addVars(T, V, name="q", lb=0)        # inventory at time t
    z = model.addVars(T, V_S, name="z", lb=0)      # min(d, q)

    # The inventory of the first period is given.
    for i in V:
        model.addConstr(q[current_num, i] == q_current[i])

    for t in T:
        for i in V_S:
            # Arrivals at warehouse i in period t: shipments decided inside
            # the horizon (variables) plus shipments decided earlier
            # (constants taken from f_dict), both shifted by the edge delay.
            planned_arrivals = []
            committed_arrivals = 0
            for (j, dest) in E:
                if dest != i:
                    continue
                t_dep = t - edge_data[(j, i)]['time']
                if t_dep in T:
                    planned_arrivals.append(f[t_dep, j, i])
                if t_dep in f_dict:
                    committed_arrivals += f_dict[t_dep][j, i]
            incoming = gp.quicksum(planned_arrivals) + committed_arrivals

            stock_after = q[t, i] + incoming - z[t, i]
            if t < last_t:  # balance q_{t+1}; the last period has no successor
                model.addConstr(stock_after == q[t + 1, i], name=f"balance_{t}_{i}")
            model.addConstr(stock_after <= capacities[i])

        for i in V_W:
            # Production ordered prod_delay periods ago becomes available now.
            t_order = t - prod_delay
            planned_order, committed_order = 0, 0
            if t_order in T:
                planned_order = w[t_order, i]
            if t_order in w_dict:
                committed_order = w_dict[t_order][i]
            incoming_order = planned_order + committed_order

            outgoing = gp.quicksum(f[t, i, j] for (i_, j) in E if i_ == i)
            stock_after = q[t, i] + incoming_order - outgoing
            if t < last_t:  # balance q_{t+1}
                model.addConstr(stock_after == q[t + 1, i])
            model.addConstr(stock_after <= capacities[i])
            # A factory cannot ship more than it currently holds.
            model.addConstr(outgoing <= q[t, i])

    # min(d, q) linearization
    for t in T:
        for i in V_S:
            model.addConstr(z[t, i] <= env.demand[t][i])
            model.addConstr(z[t, i] <= q[t, i])

    # Objective
    profit = gp.quicksum(price * z[t, i] for t in T for i in V_S)
    storage_cost = gp.quicksum(node_data[i]['storage_cost'] * q[t, i] for t in T for i in V)
    order_cost = gp.quicksum(node_data[i]['production_cost'] * w[t, i] for t in T for i in V_W)
    transport_cost = gp.quicksum(edge_data[(i, j)]['cost'] * f[t, i, j] for t in T for (i, j) in E)

    model.setObjective(profit - (storage_cost + order_cost + transport_cost), GRB.MAXIMIZE)

    # Solve
    model.optimize()

    # Debugging hint: if the model is reported infeasible, model.computeIIS()
    # followed by model.write("inventory validation.ilp") dumps the
    # conflicting constraint set.

    # Full trajectories are only collected for an optimal solve.
    f_list = []
    qs_list = []
    qw_list = []
    w_list = []
    if model.Status == gp.GRB.OPTIMAL:
        for t in T:
            f_list.append(np.array([f[t, i, j].X for (i, j) in E]))
            qs_list.append(np.array([q[t, i].X for i in V_S]))
            qw_list.append(np.array([q[t, i].X for i in V_W]))
            w_list.append(np.array([w[t, i].X for i in V_W]))

    # First-period decision and the resulting next-period inventory.
    f0 = {(i, j): f[current_num, i, j].X for (i, j) in E}
    w0 = {i: w[current_num, i].X for i in V_W}
    q1 = {i: q[current_num + 1, i].X for i in V}
    z0 = {i: z[current_num, i].X for i in V_S}

    # Profit realised in the first period alone.
    r0 = (
        price * sum(z[current_num, i].X for i in V_S)
        - sum(node_data[i]['storage_cost'] * q[current_num, i].X for i in V)
        - sum(node_data[i]['production_cost'] * w[current_num, i].X for i in V_W)
        - sum(edge_data[(i, j)]['cost'] * f[current_num, i, j].X for (i, j) in E)
    )

    return f_list, qs_list, qw_list, w_list, f0, w0, q1, r0, z0, model.ObjVal-r0

