from docplex.mp.model import Model
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum

# Weight multiplying the cost vector `c` and the time vector `tt` in the
# objectives built by solveOpt and RegsolveOpt.
beta = 1

def solveOpt(env, action_rl, cost_ls, price_ls, demandTime_ls):
    """Solve the edge-flow linear program for the current time step of ``env``.

    Two non-negative continuous vectors ``f`` and ``g`` are optimised, each
    with one entry per edge of ``env.edges`` (in that order)::

        minimise    beta * c @ f - (p - beta * tt) @ g
        subject to  g <= demand
                    A @ (f + g) <= q - action_rl
                    G @ (f + g) <= q

    Column ``m`` of ``A`` holds +1 in the origin row and -1 in the destination
    row of edge ``m``; column ``m`` of ``G`` holds +1 in the origin row only.

    Args:
        env: Object exposing ``time``, ``edges`` (sequence of ``(i, j)`` pairs
            used as row indices), ``demand[(i, j)][t]`` and ``acc[node][t]``.
        action_rl: Per-node vector subtracted from ``q`` in the first node
            constraint; its length fixes the number of rows of ``A``/``G``.
        cost_ls: Indexable by time step; ``cost_ls[t]`` is the coefficient
            vector ``c`` applied to ``f``.
        price_ls: Indexable by time step; ``price_ls[t]`` is the vector ``p``
            applied to ``g``.
        demandTime_ls: Indexable by time step; ``demandTime_ls[t]`` is the
            vector ``tt`` applied (scaled by ``beta``) to ``g``.

    Returns:
        Tuple ``(g.X, f.X)`` with the solution values of ``g`` and ``f``.
        Reading ``.X`` fails inside Gurobi if the solve produced no solution.
    """
    t = env.time
    pairs = env.edges
    N = len(action_rl)
    M = len(pairs)

    mdl = gp.Model(name="bi_level_problem")
    mdl.setParam('OutputFlag', 0)

    # Edge/node matrices. The -1 is written after the +1, so for a self-loop
    # (i == j) the entry of A ends up as -1.
    # NOTE(review): assumes edge endpoints are integer indices in [0, N) - confirm.
    A = np.zeros((N, M))
    G = np.zeros((N, M))
    for m, (i, j) in enumerate(pairs):
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1

    # Keep `demand` aligned with `pairs`: an edge with no demand entry at time
    # t gets an upper bound of 0. Skipping such edges would make the vector
    # shorter than `g` and either fail or broadcast a single value onto every
    # edge.
    demand = np.zeros(M)
    for m, (i, j) in enumerate(pairs):
        if (i, j) in env.demand and t in env.demand[i, j]:
            demand[m] = env.demand[i, j][t]

    c = np.array(cost_ls[t])
    p = np.array(price_ls[t])
    tt = np.array(demandTime_ls[t])

    # Per-node values of env.acc at time t, ordered by sorted key.
    # NOTE(review): assumes that sorted key order matches the row order of A/G - confirm.
    q = np.array([env.acc[node][t] for node in sorted(env.acc)])

    f = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='g')

    mdl.addConstr(g <= demand)
    mdl.addConstr(A@(f+g) <= q - action_rl)
    mdl.addConstr(G@(f+g) <= q)
    mdl.setObjective(beta*c@f - (p-beta*tt)@g, GRB.MINIMIZE)
    mdl.optimize()

    return g.X, f.X


def RegsolveOpt(env, action_rl, theta_f, theta_g, mu, cost_ls, price_ls, demandTime_ls):
    """Solve the regularised edge-flow program for the current time step of ``env``.

    Extends the program solved by ``solveOpt`` with ``K = len(theta_f)``
    non-negative auxiliary variables ``t_j`` and a small quadratic penalty::

        minimise    beta * c @ f - (p - beta * tt) @ g + sum_j t_j
                    + l1 * f @ f + l2 * g @ g
        subject to  g <= demand
                    A @ (f + g) <= q - action_rl
                    G @ (f + g) <= q
                    -theta_f[j] @ f + theta_g[j] @ g - mu[j] <= t_j   for each j

    ``f`` and ``g`` are non-negative continuous vectors with one entry per
    edge of ``env.edges``. Column ``m`` of ``A`` holds +1 in the origin row
    and -1 in the destination row of edge ``m``; ``G`` holds the +1 only.

    Args:
        env: Object exposing ``time``, ``edges`` (sequence of ``(i, j)`` pairs
            used as row indices), ``demand[(i, j)][t]`` and ``acc[node][t]``.
        action_rl: Per-node vector subtracted from ``q`` in the first node
            constraint; its length fixes the number of rows of ``A``/``G``.
        theta_f: Sequence of K coefficient vectors applied (negated) to ``f``.
        theta_g: Sequence of K coefficient vectors applied to ``g``.
        mu: Sequence of K offsets subtracted in the auxiliary constraints.
        cost_ls: Indexable by time step; ``cost_ls[t]`` is the vector ``c``.
        price_ls: Indexable by time step; ``price_ls[t]`` is the vector ``p``.
        demandTime_ls: Indexable by time step; ``demandTime_ls[t]`` is ``tt``.

    Returns:
        Tuple ``(g.X, f.X)`` with the solution values of ``g`` and ``f``.
        Reading ``.X`` fails inside Gurobi if the solve produced no solution.
    """
    t0 = env.time
    K = len(theta_f)
    pairs = env.edges
    N = len(action_rl)
    M = len(pairs)

    # Weights of the quadratic penalty on f and g.
    l1 = 0.0001
    l2 = 0.0001

    mdl = gp.Model(name="bi_level_problem")
    mdl.setParam('OutputFlag', 0)

    # Edge/node matrices. The -1 is written after the +1, so for a self-loop
    # (i == j) the entry of A ends up as -1.
    # NOTE(review): assumes edge endpoints are integer indices in [0, N) - confirm.
    A = np.zeros((N, M))
    G = np.zeros((N, M))
    for m, (i, j) in enumerate(pairs):
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1

    # Keep `demand` aligned with `pairs`: an edge with no demand entry at time
    # t0 gets an upper bound of 0. Skipping such edges would make the vector
    # shorter than `g` and either fail or broadcast a single value onto every
    # edge.
    demand = np.zeros(M)
    for m, (i, j) in enumerate(pairs):
        if (i, j) in env.demand and t0 in env.demand[i, j]:
            demand[m] = env.demand[i, j][t0]

    c = np.array(cost_ls[t0])
    p = np.array(price_ls[t0])
    tt = np.array(demandTime_ls[t0])

    # Per-node values of env.acc at time t0, ordered by sorted key.
    # NOTE(review): assumes that sorted key order matches the row order of A/G - confirm.
    q = np.array([env.acc[node][t0] for node in sorted(env.acc)])

    f = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='g')
    t = [mdl.addVar(name=f't{i}', lb=0) for i in range(K)]

    mdl.addConstr(g <= demand)
    mdl.addConstr(A@(f+g) <= q - action_rl)
    mdl.addConstr(G@(f+g) <= q)
    # NOTE(review): theta_f enters with a minus sign here but with a plus sign
    # in ValuesolveOpt - confirm which convention is intended.
    for j in range(K):
        mdl.addConstr(-theta_f[j]@f + theta_g[j]@g - mu[j] <= t[j])
    mdl.setObjective(beta*c@f - (p-beta*tt)@g + sum(t[j] for j in range(K)) + l1*f@f + l2*g@g, GRB.MINIMIZE)
    mdl.optimize()

    return g.X, f.X

def ValuesolveOpt(env, action_rl, theta_f, theta_g, mu):
    """Solve the value-augmented edge-flow program for the current time step.

    ``f`` and ``g`` are non-negative continuous vectors with one entry per
    edge of ``env.edges``; ``t`` is a single non-negative variable::

        minimise    c @ f - p @ g + (action_rl @ A) @ (f + g) + t
        subject to  g <= demand
                    G @ (f + g) <= q
                    theta_f @ f + theta_g @ g - mu <= t

    Here ``c`` is read from ``env.rebTime`` and ``p`` from ``env.price``.
    Column ``m`` of ``A`` holds +1 in the origin row and -1 in the destination
    row of edge ``m``; ``G`` holds the +1 only. Unlike ``solveOpt``, the
    ``A``-based node constraint is absent and ``action_rl`` appears in the
    objective instead.

    Args:
        env: Object exposing ``time``, ``edges`` (sequence of ``(i, j)`` pairs
            used as row indices), ``demand``, ``price`` and ``rebTime`` (each
            indexed as ``[(i, j)][t]``) and ``acc[node][t]``.
        action_rl: Per-node vector; its length fixes the number of rows of
            ``A``/``G`` and it weights ``A @ (f + g)`` in the objective.
        theta_f: Coefficient vector applied to ``f`` in the ``t`` constraint.
        theta_g: Coefficient vector applied to ``g`` in the ``t`` constraint.
        mu: Offset subtracted in the ``t`` constraint.

    Returns:
        Tuple ``(g.X, f.X)`` with the solution values of ``g`` and ``f``.
        Reading ``.X`` fails inside Gurobi if the solve produced no solution.
    """
    t0 = env.time
    pairs = env.edges
    N = len(action_rl)
    M = len(pairs)

    mdl = gp.Model(name="bi_level_problem2")
    mdl.setParam('OutputFlag', 0)

    # Edge/node matrices. The -1 is written after the +1, so for a self-loop
    # (i == j) the entry of A ends up as -1.
    # NOTE(review): assumes edge endpoints are integer indices in [0, N) - confirm.
    A = np.zeros((N, M))
    G = np.zeros((N, M))
    for m, (i, j) in enumerate(pairs):
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1

    # Keep every per-edge vector aligned with `pairs`. An edge with no demand
    # entry at time t0 gets demand 0, which pins its g to 0, so its price is
    # irrelevant and left at 0. Skipping such edges would make p, c and
    # demand shorter than f/g and break the matrix products below.
    demand = np.zeros(M)
    p = np.zeros(M)
    for m, (i, j) in enumerate(pairs):
        if (i, j) in env.demand and t0 in env.demand[i, j]:
            demand[m] = env.demand[i, j][t0]
            p[m] = env.price[i, j][t0]

    # f is not pinned on edges without demand, so its cost is read for every edge.
    # NOTE(review): assumes env.rebTime has an entry for every edge at t0 - confirm.
    c = np.array([env.rebTime[i, j][t0] for i, j in pairs])

    # Per-node values of env.acc at time t0, ordered by sorted key.
    # NOTE(review): assumes that sorted key order matches the row order of A/G - confirm.
    q = np.array([env.acc[node][t0] for node in sorted(env.acc)])

    f = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addMVar(shape=M, lb=0, vtype=GRB.CONTINUOUS, name='g')
    t = mdl.addVar(lb=0, vtype=GRB.CONTINUOUS, name='t')

    mdl.addConstr(g <= demand)
    mdl.addConstr(G@(f+g) <= q)
    # NOTE(review): theta_f enters with a plus sign here but with a minus sign
    # in RegsolveOpt - confirm which convention is intended.
    mdl.addConstr(theta_f@f + theta_g@g - mu <= t)

    mdl.setObjective(c@f - p@g + np.dot(action_rl, A)@(f+g) + t, GRB.MINIMIZE)
    mdl.optimize()

    return g.X, f.X