# %% import packages
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum
import pandas as pd
import torch
from collections import defaultdict
from torch_geometric.data import Data, Batch
import pickle
import random
import generate_history as gh

# Module-level default weight `beta`. It is read by functions that do not bind
# their own local `beta` (e.g. the objective of benchmark_policy); several
# other functions in this file define a local `beta` that shadows this value.
beta = 1

#%%
def generate_data(pairs, num):
    """Create incidence matrices and random demand/cost/profit data for a graph.

    Args:
        pairs: adjacency mapping ``{node: [neighbor, ...]}``. Nodes are used
            directly as row indices and are looked up as ``0..len(pairs)-1``.
        num: number of time steps for which demand is sampled.

    Returns:
        ``(q0, lbd_ls_t, cost_ls, profit_ls, A, G)`` where

        * ``q0`` is a random integer vector (values in ``[5, 13)``), one per node;
        * ``lbd_ls_t[t][node]`` is a dict ``{neighbor: demand}`` with demand
          drawn from ``[0, 10)``;
        * ``cost_ls[node]`` / ``profit_ls[node]`` are dicts ``{neighbor: value}``
          with values drawn from ``[1, 10)``;
        * ``A`` has +1 at the origin row and -1 at the destination row of each
          edge column, and ``G`` has +1 at the origin row only.
    """
    # Example: pairs = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
    n_nodes = len(pairs)
    n_edges = sum(len(nbrs) for nbrs in pairs.values())
    A = np.zeros((n_nodes, n_edges))
    G = np.zeros((n_nodes, n_edges))

    # One column per directed edge, enumerated node by node.
    directed_edges = ((src, dst) for src in pairs for dst in pairs[src])
    for col, (src, dst) in enumerate(directed_edges):
        A[src, col] = 1
        A[dst, col] = -1
        G[src, col] = 1

    # Initial state, then demand / cost / profit (the draw order is kept fixed
    # so that a seeded generator reproduces the same data).
    q0 = np.random.randint(low=5, high=13, size=n_nodes)
    node_ids = range(len(q0))

    lbd_ls_t = [
        [
            {nbr: np.random.randint(low=0, high=10) for nbr in pairs[node]}
            for node in node_ids
        ]
        for _ in range(num)
    ]
    cost_ls = [
        {nbr: np.random.randint(low=1, high=10) for nbr in pairs[node]}
        for node in node_ids
    ]
    profit_ls = [
        {nbr: np.random.randint(low=1, high=10) for nbr in pairs[node]}
        for node in node_ids
    ]

    return q0, lbd_ls_t, cost_ls, profit_ls, A, G


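# Assume the demand of each period is known before the period starts; vehicles are dispatched according to that demand, assuming the demand at the next time point stays the same.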
def naive_policy(q0, lbd_ls_t, cost_ls, profit_ls, pairs, num):
    """Simulate a greedy per-period rebalancing and passenger-serving policy.

    At each time step the demand of the period is known. Nodes whose total
    outgoing demand is at least their current stock ("shortage" nodes) pull
    vehicles from neighbors with surplus, cheapest neighbor first. Each node
    then serves its passenger demand, most profitable destination first.
    Rebalanced and passenger vehicles arrive at their destinations at the end
    of the step.

    Args:
        q0: initial number of vehicles per node (NumPy array). It is copied,
            so the caller's array is not modified.
        lbd_ls_t: ``lbd_ls_t[t][node]`` is a dict ``{neighbor: demand}``.
        cost_ls: ``cost_ls[node]`` is a dict ``{neighbor: cost}``.
        profit_ls: ``profit_ls[node]`` is a dict ``{neighbor: profit}``.
        pairs: adjacency mapping ``{node: [neighbor, ...]}``. A rebalancing
            flow is stored as ``f[neighbor][node]``, so the adjacency is
            presumably symmetric -- TODO confirm against the caller.
        num: number of time steps to simulate.

    Returns:
        ``(f_ls, g_ls, f_list, g_list, q_list, lambda_ls, c, p)`` where
        ``f_ls``/``g_ls`` hold the per-step list of per-node flow dicts,
        ``f_list``/``g_list``/``lambda_ls`` hold the same data flattened to
        one vector per step, ``q_list`` holds the state at the start of each
        step, and ``c``/``p`` are the flattened cost and profit vectors.
    """
    def _flatten(per_node_dicts):
        # Concatenate the dict values of all nodes into one edge-ordered vector.
        return np.concatenate([list(d.values()) for d in per_node_dicts])

    q = np.array(q0)  # private copy: the simulation must not alter the input
    n_nodes = len(q)
    f_ls = []
    g_ls = []
    q_list = []

    for step in range(num):
        q_list.append(q.copy())
        lbd_ls = lbd_ls_t[step]

        # Total outgoing demand of EVERY node (not only node 0).
        sum_lbd = np.array([sum(lbd.values()) for lbd in lbd_ls])
        surplus = q - sum_lbd    # vehicles a node can give away
        shortage = sum_lbd - q   # vehicles a node is missing
        short_nodes = np.flatnonzero(sum_lbd >= q)

        f = [{nbr: 0 for nbr in pairs[node]} for node in range(n_nodes)]
        g = [{nbr: 0 for nbr in pairs[node]} for node in range(n_nodes)]

        # Rebalancing: fill each shortage from the cheapest neighbors first.
        for node in map(int, short_nodes):
            by_cost = sorted(cost_ls[node].items(), key=lambda kv: kv[1])
            for nbr, _ in by_cost:
                need = shortage[node]
                if need <= 0:
                    break
                available = surplus[nbr]
                if available <= 0:
                    continue
                transfer = min(available, need)
                f[nbr][node] = transfer
                # The vehicles leave the neighbor now; they reach `node` only
                # after the passenger flow below.
                q[nbr] = q[nbr] - transfer
                surplus[nbr] = surplus[nbr] - transfer
                shortage[node] = shortage[node] - transfer
        f_ls.append(f)

        # Passenger flow: serve the most profitable destinations first.
        for node in range(n_nodes):
            lbd = lbd_ls[node]
            by_profit = sorted(profit_ls[node].items(), key=lambda kv: kv[1], reverse=True)
            for nbr, _ in by_profit:
                available = q[node]
                if available <= 0:
                    break
                served = min(available, lbd[nbr])
                g[node][nbr] = served
                q[node] = q[node] - served

        # Arrivals: rebalanced vehicles first, then passenger vehicles.
        for node in range(n_nodes):
            for nbr, flow in f[node].items():
                q[nbr] = q[nbr] + flow
        for node in range(n_nodes):
            for nbr, flow in g[node].items():
                q[nbr] = q[nbr] + flow
        g_ls.append(g)

    f_list = [_flatten(f) for f in f_ls]
    g_list = [_flatten(g) for g in g_ls]
    lambda_ls = [_flatten(lbd_ls) for lbd_ls in lbd_ls_t]
    c = _flatten(cost_ls)
    p = _flatten(profit_ls)

    return f_ls, g_ls, f_list, g_list, q_list, lambda_ls, c, p



def invOptimization(f_list, g_list, q_list, lambda_ls, cost_ls, price_ls, tt_ls, N, pairs):
    """Build and solve a non-convex Gurobi model that fits ``theta_f``, ``theta_g`` and ``mu``.

    For every observed sample ``i`` the model adds bilinear
    complementarity-style equalities and stationarity equalities that tie the
    observed flows ``f_list[i]``, ``g_list[i]`` and state ``q_list[i]`` to the
    multipliers ``x, y, w, v1, v2`` and to the parameters ``theta_f``,
    ``theta_g`` and ``mu``. The objective minimizes the squared norm of
    ``theta_f``, ``theta_g`` and ``mu``.

    Args:
        f_list, g_list: per-sample flow vectors; ``len(f_list[0])`` is taken
            as the number of edges ``M``.
        q_list: per-sample state vectors (combined with ``N``-row matrices, so
            presumably of length ``N`` -- TODO confirm).
        lambda_ls: per-sample demand vectors, converted with ``np.array``.
        cost_ls, price_ls, tt_ls: sequences of vectors; each is stacked and
            averaged over axis 0 into ``c``, ``p`` and ``tt``.
        N: number of nodes (row count of the incidence matrices).
        pairs: iterable of ``(i, j)`` edges; edge ``m`` becomes column ``m``.

    Returns:
        ``(theta_f, theta_g, mu)`` as the LAST node-relaxation values recorded
        by the MIPNODE callback.

    Raises:
        IndexError: if the callback never recorded a relaxation (the result
            lists are empty when indexed with ``[-1]``).
    """
    num = len(f_list)
    M = len(f_list[0])

    # A: +1 at origin / -1 at destination; G: +1 at origin; B: +1 at destination.
    # B is built here but not used further in this function.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    B = np.zeros((N,M))
    m = 0
    for i,j in pairs:
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1
        B[j, m] = 1
        m += 1
    # Time-averaged cost, price and travel-time vectors.
    c = np.mean(np.stack(cost_ls), axis=0)
    p = np.mean(np.stack(price_ls), axis=0)
    tt = np.mean(np.stack(tt_ls), axis=0)

    # NOTE: from here on `m` is the Gurobi model, no longer the column counter.
    m = gp.Model('formulation1')
    mu = m.addMVar(shape = num, name = 'mu')
    theta_f = m.addMVar(shape = M, vtype = GRB.CONTINUOUS, name = 'theta_f')
    theta_g = m.addMVar(shape = M, vtype = GRB.CONTINUOUS, name = 'theta_g')

    # Non-negative multipliers, one row per sample.
    x = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'x')
    y = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'y')
    w = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'w')
    v1 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v1')
    v2 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v2')
    # z1 + z2 == 1 (constraint below), both in [0, 1].
    z1 = m.addMVar(shape = num, name = 'z1', lb = 0, ub = 1)
    z2 = m.addMVar(shape = num, name = 'z2', lb = 0, ub = 1)
    t = m.addMVar(shape = num, name = 't', lb = 0)
    # q_hat = m.addMVar(shape = (num,N), vtype = GRB.INTEGER, name = 'q_hat')
    # epsilon_1 = m.addMVar(shape = (num,M), name = 'epsilon_1')
    # epsilon_2 = m.addMVar(shape = (num,M), name = 'epsilon_2')
    
    # q_hat is a plain NumPy array here; it is filled in the loop below but
    # only referenced by commented-out constraints.
    q_hat = np.zeros(shape = (num, N))
    # Local weight; shadows the module-level `beta` inside this function.
    beta = 0.3
    m.addConstr(np.ones(num)-z1-z2 == 0)
    # Coefficients of the 2*l*f and 2*l*g terms in the stationarity equalities.
    l1 = 0.0001
    l2 = 0.0001
    # Filled by the callback with node-relaxation values.
    thetaf_ls = []
    thetag_ls = []
    mu_ls = []

    for i in range(num):
        f = f_list[i]
        g = g_list[i]
        q = q_list[i]
        lbd = np.array(lambda_ls[i])
        q_hat[i, :] = q - np.dot(A,f+g) 

        # c = np.array(cost_ls[i])
        # p = np.array(price_ls[i])

        # Bilinear equalities pairing each multiplier with its observed slack.
        Gfq = np.dot(G,f+g) - q
        m.addConstr(Gfq@y[i,:] == 0)
        # Afq = np.dot(A,f+g) - q + q_hat[i,:]
        # m.addConstr(Afq@x[i,:] == 0)
        m.addConstr(w[i,:]@f == 0)
        m.addConstr(v1[i,:]@g == 0)
        m.addConstr(v2[i,:]@(g-lbd) == 0)
        # With t[i] >= (-theta_f@f + theta_g@g - mu[i]) and t[i] >= 0 (below),
        # these two products force z1[i] = 0 when t[i] is slack and z2[i] = 0
        # when t[i] > 0.
        m.addConstr(z1[i]*(-theta_f@f+theta_g@g-mu[i]-t[i]) == 0)
        m.addConstr(z2[i]*t[i] == 0)
        # m.addConstr(c + x[i,:]@A + y[i,:]@G  + z1[i]*theta_f - w[i,:] +epsilon_1[i,:] == 0)
        # m.addConstr(-p + x[i,:]@A + y[i,:]@G  + z1[i]*theta_g - v1[i,:] + v2[i,:] +epsilon_2[i,:] == 0)
        # Stationarity equalities for f and g (with the small quadratic terms).
        m.addConstr(c*beta + x[i,:]@A + y[i,:]@G  - z1[i]*theta_f - w[i,:] + 2*l1*f == 0)
        m.addConstr(-p +tt*beta + x[i,:]@A + y[i,:]@G  + z1[i]*theta_g - v1[i,:] + v2[i,:] + 2*l2*g == 0)
        # m.addConstr(np.dot(A,f+g) <= q - q_hat[i,:])
        m.addConstr(t[i] >= -theta_f@f + theta_g@g - mu[i])

        # m.addConstr(theta_f == np.zeros(M))
        # m.addConstr(theta_g == np.zeros(M))

        # m.addConstr(qsum(q_hat[i, j] for j in range(N)) == 1500)

    # Minimize ||theta_f||^2 + ||theta_g||^2 + ||mu||^2.
    m.setObjective(qsum(theta_f[i]*theta_f[i] for i in range(M))+qsum(theta_g[j]*theta_g[j] for j in range(M))+qsum(mu[a]*mu[a] for a in range(num)), GRB.MINIMIZE)
    # m.setObjective(qsum(theta_f[i]*theta_f[i] for i in range(M))+qsum(theta_g[j]*theta_g[j] for j in range(M)), GRB.MINIMIZE)
    # m.setObjective(0, GRB.MINIMIZE)
    # Bilinear equality constraints require the non-convex solver mode.
    m.params.NonConvex = 2
    m.setParam("TimeLimit", 180)  # time limit of 180 seconds
    # NOTE(review): the original comment described this as "save the root node
    # solution"; in Gurobi, NodefileStart is the memory threshold (in GB) above
    # which branch-and-bound nodes are written to disk -- confirm the intent.
    m.setParam("NodefileStart", 0.5)

    # Callback that records node-relaxation values of theta_f, theta_g and mu.
    # m._root_sol = None

    # def callback(model, where, theta_f1, theta_g1, mu1, r1, theta_f2, theta_g2, mu2, r2, thetaf_ls, thetag_ls, mu_ls, r_ls, thetaf2_ls, thetag2_ls, mu2_ls, r2_ls):
    def callback(model, where):
        nonlocal theta_f, theta_g, mu
        nonlocal thetaf_ls, thetag_ls, mu_ls
        # Only act at MIP nodes whose relaxation was solved to optimality.
        if where == gp.GRB.Callback.MIPNODE:
            status = model.cbGet(gp.GRB.Callback.MIPNODE_STATUS)
            if status == gp.GRB.Status.OPTIMAL:
                relax_theta_f1 = model.cbGetNodeRel(theta_f)
                relax_theta_g1 = model.cbGetNodeRel(theta_g)
                relax_mu = model.cbGetNodeRel(mu)

                thetaf_ls.append(relax_theta_f1)
                thetag_ls.append(relax_theta_g1)
                mu_ls.append(relax_mu)

    
    # m.optimize(callback)
    # m.optimize(lambda model, where: callback(model, where, theta_f1, theta_g1, mu1, r1, theta_f2, theta_g2, mu2, r2, thetaf_ls, thetag_ls, mu_ls, r_ls, thetaf2_ls, thetag2_ls, mu2_ls, r2_ls))
    m.optimize(lambda model, where: callback(model, where))


    # Last recorded relaxation; raises IndexError if nothing was recorded.
    return thetaf_ls[-1], thetag_ls[-1], mu_ls[-1]


#%%
def invOptimization2(f_list, g_list, q_list, lambda_ls, cost_ls, price_ls, N, pairs):
    """Solve a non-convex Gurobi model for ``beta``, ``theta_f``, ``theta_g`` and ``mu``.

    Unlike ``invOptimization`` this variant uses the per-sample cost and price
    vectors (``cost_ls[i]``, ``price_ls[i]``) and treats the node multipliers
    ``beta`` (shape ``(num, N)``) as decision variables. After optimization
    every non-zero row of ``beta`` is normalized to sum to one.

    Args:
        f_list, g_list: per-sample flow vectors of length ``M``.
        q_list: per-sample state vectors.
        lambda_ls: per-sample demand vectors.
        cost_ls, price_ls: per-sample cost / price vectors;
            ``len(cost_ls[0])`` is taken as the number of edges ``M``.
        N: number of nodes.
        pairs: iterable of ``(i, j)`` edges; edge ``k`` becomes column ``k``.

    Returns:
        ``(beta_value, theta_f, theta_g, mu, row_sums)`` where ``beta_value``
        is the row-normalized ``beta`` solution and ``row_sums`` holds the row
        sums from before the normalization (shape ``(num, 1)``).
    """
    num = len(f_list)
    M = len(cost_ls[0])
    # A: +1 at origin / -1 at destination; G: +1 at origin.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    for col, (i, j) in enumerate(pairs):
        A[i, col] = 1
        A[j, col] = -1
        G[i, col] = 1

    m = gp.Model('formulation2')
    mu = m.addMVar(shape = num, name = 'mu')
    theta_f = m.addMVar(shape = M, vtype = GRB.CONTINUOUS, name = 'theta_f')
    theta_g = m.addMVar(shape = M, vtype = GRB.CONTINUOUS, name = 'theta_g')

    # Non-negative multipliers, one row per sample.
    y = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'y')
    w = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'w')
    v1 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v1')
    v2 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v2')
    # Local `beta` is a decision variable and shadows the module-level constant.
    beta = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'beta')
    z1 = m.addMVar(shape = num, name = 'z1', lb = 0, ub = 1)
    z2 = m.addMVar(shape = num, name = 'z2', lb = 0, ub = 1)
    t = m.addMVar(shape = num, name = 't', lb = 0)

    m.addConstr(np.ones(num)-z1-z2 == 0)
    for i in range(num):
        f = f_list[i]
        g = g_list[i]
        q = q_list[i]
        lbd = np.array(lambda_ls[i])

        c = np.array(cost_ls[i])
        p = np.array(price_ls[i])

        # Bilinear equalities pairing each multiplier with its observed slack.
        Gfq = np.dot(G,f+g) - q
        m.addConstr(Gfq@y[i,:] == 0)

        m.addConstr(w[i,:]@f == 0)
        m.addConstr(v1[i,:]@g == 0)
        m.addConstr(v2[i,:]@(g-lbd) == 0)
        # NOTE(review): `mu[i]` enters with a minus sign in the product below
        # but with a plus sign in the `t[i] >=` inequality; invOptimization
        # uses the same sign in both places -- confirm which one is intended.
        m.addConstr(z1[i]*(theta_f@f+theta_g@g-mu[i]-t[i]) == 0)
        m.addConstr(z2[i]*t[i] == 0)
        m.addConstr(t[i] >= theta_f@f + theta_g@g + mu[i])

        # Stationarity equalities for f and g.
        m.addConstr(c + beta[i,:]@A + y[i,:]@G  - w[i,:] + z1[i]*theta_f  == 0)
        m.addConstr(-p + beta[i,:]@A + y[i,:]@G  - v1[i,:] + v2[i,:] + z1[i]*theta_g == 0)

    # Minimize ||theta_f||^2 + ||theta_g||^2 + ||mu||^2.
    m.setObjective(qsum(theta_f[i]*theta_f[i] for i in range(M))+qsum(theta_g[j]*theta_g[j] for j in range(M))+qsum(mu[a]*mu[a] for a in range(num)), GRB.MINIMIZE)

    # Bilinear equality constraints require the non-convex solver mode.
    m.params.NonConvex = 2
    m.optimize()

    beta_value = beta.X

    print('beta:', beta_value)

    row_sums = beta_value.sum(axis=1, keepdims=True)
    print('row_sums:', row_sums)
    # Normalize each row. `beta_value` has one row per SAMPLE (num rows), so
    # the loop must run over `num`, not over the number of nodes `N`.
    for row in range(num):
        total = row_sums[row][0]
        if total != 0:
            beta_value[row, :] = beta_value[row, :] / total

    return beta_value, theta_f.X, theta_g.X, mu.X, row_sums

def get_action(q, u, lbd, cost, price, A, G):
    """Solve the single-step flow LP for state ``q`` and parameter ``u``.

    The LP minimizes ``cost @ f - price @ g`` over non-negative flows subject
    to ``G (f + g) <= q``, ``g <= lbd`` and ``A (f + g) <= q - u``.

    Args:
        q: current state vector (right-hand side of the capacity constraints).
        u: parameter vector subtracted from ``q`` in the ``A`` constraint.
        lbd: demand upper bound for ``g`` (converted with ``np.array``).
        cost, price: objective coefficients; ``len(cost)`` is the number of
            edges ``M``.
        A, G: matrices with ``M`` columns used in the constraints.

    Returns:
        The optimal ``f`` and ``g`` concatenated into one vector of length
        ``2 * M``.

    Raises:
        Whatever Gurobi raises when ``f.X`` / ``g.X`` are read without a
        solution (e.g. when the LP is infeasible). In that case an IIS is
        written to ``newmodel.ilp`` first, on a best-effort basis.
    """
    M = len(cost)
    m = gp.Model('single_step')
    m.setParam('OutputFlag', 0)
    f = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'f')
    g = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'g')
    m.addConstr(G@(f+g) <= q)
    m.addConstr(g <= np.array(lbd))

    m.addConstr(A@(f+g) <= q - u)

    m.setObjective(np.array(cost)@f-np.array(price)@g, GRB.MINIMIZE)

    m.optimize()

    # Only diagnose models that may be infeasible; an IIS does not exist for a
    # feasible model, so attempting it on every call was wasted work hidden
    # behind a bare `except`.
    if m.Status in (GRB.INFEASIBLE, GRB.INF_OR_UNBD):
        try:
            m.computeIIS()
            m.write("newmodel.ilp")  # optional: keep the IIS for inspection
        except (gp.GurobiError, OSError):
            # Diagnostics are best-effort; the read of `.X` below reports the
            # actual failure to the caller.
            pass

    return np.concatenate((f.X, g.X))

def compute_loss(a_hat, a):
    """Return the sum of squared differences between ``a_hat`` and ``a``."""
    residual = a_hat - a
    return np.sum(np.square(residual))

def cem_update(states, actions, population_size, elite_frac, num_iterations, lambda_ls, cost_ls, price_ls, N, pairs):
    """Fit the parameter vector ``u`` with the cross-entropy method (CEM).

    Each candidate ``u`` is scored by the total squared error between the
    actions produced by ``get_action`` and the observed ``actions``. The best
    ``elite_frac`` share of the population is used to refit an independent
    Gaussian from which the next population is sampled.

    Args:
        states: per-sample state vectors passed to ``get_action``.
        actions: per-sample observed actions compared with ``compute_loss``.
        population_size: number of candidates sampled per iteration.
        elite_frac: fraction of the population kept as elites. At least one
            elite is always kept.
        num_iterations: number of CEM iterations.
        lambda_ls, cost_ls, price_ls: per-sample demand, cost and price
            vectors; ``len(cost_ls)`` is the number of samples and
            ``len(cost_ls[0])`` the number of edges.
        N: number of nodes (dimension of ``u``).
        pairs: iterable of ``(i, j)`` edges; edge ``k`` becomes column ``k``.

    Returns:
        The candidate ``u`` with the lowest loss seen over all iterations, or
        ``None`` when ``num_iterations`` is 0.
    """
    M = len(cost_ls[0])
    # A: +1 at origin / -1 at destination; G: +1 at origin.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    for col, (i, j) in enumerate(pairs):
        A[i, col] = 1
        A[j, col] = -1
        G[i, col] = 1

    # Initial sampling distribution: independent components N(5, 3^2).
    mean = 5
    std_dev = 3
    covariance_matrix = np.diag([std_dev**2] * N)
    u_population = np.random.multivariate_normal([mean] * N, covariance_matrix, population_size)

    best_u = None
    best_loss = float('inf')

    num = len(cost_ls)
    # Guard against an empty elite set (e.g. elite_frac * population_size < 1),
    # which previously crashed on `elite_idxs[0]`.
    n_elite = max(1, int(elite_frac * population_size))

    for iteration in range(num_iterations):
        print('iteration:', iteration)
        losses = []

        for u in u_population:
            L_total = 0
            for i in range(num):
                a_hat = get_action(states[i], u, lambda_ls[i], cost_ls[i], price_ls[i], A, G)
                L_total += compute_loss(a_hat, actions[i])
            losses.append(L_total)

        # Elites are the candidates with the smallest total loss.
        elite_idxs = np.argsort(losses)[:n_elite]
        elite_us = u_population[elite_idxs]

        # Track the best candidate seen so far.
        best_idx = elite_idxs[0]
        if losses[best_idx] < best_loss:
            best_loss = losses[best_idx]
            best_u = elite_us[0]

        # Refit mean / std on the elites and sample the next population.
        mean = np.mean(elite_us, axis=0)
        std = np.std(elite_us, axis=0)
        u_population = np.random.multivariate_normal(mean, np.diag(std**2), population_size)

    return best_u


#%%
def test(f_list, g_list, q_list, lambda_ls, cost_ls, price_ls, N, pairs):
    """Solve a feasibility model for integer ``q_hat`` given observed flows.

    The model has a zero objective. For every sample it adds bilinear
    complementarity-style equalities, two stationarity equalities using the
    averaged cost ``c`` and price ``p``, and the inequality
    ``A (f + g) <= q - q_hat``.

    Args:
        f_list, g_list: per-sample flow vectors; ``len(f_list[0])`` is ``M``.
        q_list: per-sample state vectors.
        lambda_ls: per-sample demand vectors.
        cost_ls, price_ls: sequences of vectors, averaged over axis 0.
        N: number of nodes.
        pairs: iterable of ``(i, j)`` edges; edge ``m`` becomes column ``m``.

    Returns:
        The solution values of ``q_hat`` (an array built from a
        ``(num, N)`` integer variable).
    """
    num = len(f_list)
    M = len(f_list[0])

    # A: +1 at origin / -1 at destination; G: +1 at origin; B: +1 at destination.
    # B is built here but not used further in this function.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    B = np.zeros((N,M))
    m = 0
    for i,j in pairs:
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1
        B[j, m] = 1
        m += 1
    c = np.mean(np.stack(cost_ls), axis=0)
    p = np.mean(np.stack(price_ls), axis=0)

    # NOTE: from here on `m` is the Gurobi model, no longer the column counter.
    m = gp.Model('formulation1')

    # Non-negative multipliers, one row per sample.
    x = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'x')
    y = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.CONTINUOUS, name = 'y')
    w = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'w')
    v1 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v1')
    v2 = m.addMVar(shape = (num,M), lb = 0, vtype = GRB.CONTINUOUS, name = 'v2')
    # Unknown non-negative integer vector, one row per sample.
    q_hat = m.addMVar(shape = (num,N), lb = 0, vtype = GRB.INTEGER, name = 'q_hat')
    # epsilon_1 = m.addMVar(shape = (num,M), name = 'epsilon_1')
    # epsilon_2 = m.addMVar(shape = (num,M), name = 'epsilon_2')

    for i in range(num):
        f = f_list[i]
        g = g_list[i]
        q = q_list[i]
        lbd = np.array(lambda_ls[i])

        # Bilinear equalities pairing each multiplier with its observed slack.
        Gfq = np.dot(G,f+g) - q
        m.addConstr(Gfq@y[i,:] == 0)
        Afq = np.dot(A,f+g) - q + q_hat[i,:]
        m.addConstr(Afq@x[i,:] == 0)
        m.addConstr(w[i,:]@f == 0)
        m.addConstr(v1[i,:]@g == 0)
        m.addConstr(v2[i,:]@(g-lbd) == 0)
        # Stationarity equalities for f and g.
        m.addConstr(c + x[i,:]@A + y[i,:]@G  - w[i,:] == 0)
        m.addConstr(-p + x[i,:]@A + y[i,:]@G  - v1[i,:] + v2[i,:] == 0)
        # The observed flows must satisfy the q_hat constraint.
        m.addConstr(np.dot(A,f+g) <= q - q_hat[i,:])

    # Pure feasibility problem.
    m.setObjective(0, GRB.MINIMIZE)

    # The q_hat * x products are bilinear, so the non-convex mode is required.
    m.params.NonConvex = 2
    # m.params.MIPGap = 0.1
    m.optimize()
    # m.computeIIS()  # compute an IIS
    # m.write("model.ilp")  # save the IIS to a file (optional)
  
    qhat_value = q_hat.X
    return qhat_value


#%%
def App_invOptimization(f_list, g_list, q_list, lambda_ls, cost_ls, price_ls, N, pairs):
    """Estimate ``q_hat`` per sample by fitting model flows to observed flows.

    For every sample a separate Gurobi model is solved. Its variables are the
    flows ``f``/``g``, the multipliers ``x, y, w, v1, v2`` and an integer
    ``q_hat``. The constraints are the capacity constraints, bilinear
    complementarity-style equalities and stationarity equalities; the
    objective is the squared distance between ``(f, g)`` and the observed
    ``(f_list[i], g_list[i])``.

    Args:
        f_list, g_list: per-sample observed flow vectors;
            ``len(f_list[0])`` is ``M``.
        q_list: per-sample state vectors.
        lambda_ls: per-sample demand vectors.
        cost_ls, price_ls: sequences of vectors, averaged over axis 0.
        N: number of nodes.
        pairs: iterable of ``(i, j)`` edges; edge ``m`` becomes column ``m``.

    Returns:
        ``np.array`` stacking the ``q_hat`` solution of every sample.
    """
    num = len(f_list)
    M = len(f_list[0])

    # A: +1 at origin / -1 at destination; G: +1 at origin; B: +1 at destination.
    # B is built here but not used further in this function.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    B = np.zeros((N,M))
    m = 0
    for i,j in pairs:
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1
        B[j, m] = 1
        m += 1
    c = np.mean(np.stack(cost_ls), axis=0)
    p = np.mean(np.stack(price_ls), axis=0)

    qhat_ls = []
    for i in range(num):
        print('i:', i)
        # A fresh model per sample; `m` is rebound from the column counter.
        m = gp.Model('formulation1')

        x = m.addMVar(shape = N, lb = 0, vtype = GRB.CONTINUOUS, name = 'x')
        y = m.addMVar(shape = N, lb = 0, vtype = GRB.CONTINUOUS, name = 'y')
        w = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'w')
        v1 = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'v1')
        v2 = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'v2')
        q_hat = m.addMVar(shape = N, lb = 0, vtype = GRB.INTEGER, name = 'q_hat')
        # Model flows, fitted to the observed flows through the objective.
        f = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'f_opt')
        g = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'g_opt')
        # epsilon_1 = m.addMVar(shape = (num,M), name = 'epsilon_1')
        # epsilon_2 = m.addMVar(shape = (num,M), name = 'epsilon_2')

        f_given = f_list[i]
        g_given = g_list[i]
        q = q_list[i]
        lbd = np.array(lambda_ls[i])

        # Capacity constraint, then the bilinear equalities pairing each
        # multiplier with its slack.
        m.addConstr(G@(f+g) <= q)
        m.addConstr((G@(f+g) - q)@y == 0)
        m.addConstr((A@(f+g) - q + q_hat)@x == 0)
        m.addConstr(w@f == 0)
        m.addConstr(v1@g == 0)
        m.addConstr(v2@(g-lbd) == 0)
        # Stationarity equalities for f and g.
        m.addConstr(c + x@A + y@G  - w == 0)
        m.addConstr(-p + x@A + y@G  - v1 + v2 == 0)
        m.addConstr(A@(f+g) <= q - q_hat)

        # Squared distance to the observed flows. The generator variable `i`
        # is local to each generator and does not clobber the sample index.
        m.setObjective(qsum((f[i]-f_given[i])*(f[i]-f_given[i]) for i in range(M)) + qsum((g[i]-g_given[i])*(g[i]-g_given[i]) for i in range(M)), GRB.MINIMIZE)

        # Bilinear equality constraints require the non-convex solver mode.
        m.params.NonConvex = 2
        # m.params.MIPGap = 0.1
        m.optimize()
        print('objective value is :', m.ObjVal)
        # m.computeIIS()  # compute an IIS
        # m.write("model.ilp")  # save the IIS to a file (optional)

        qhat_ls.append(q_hat.X)

    return np.array(qhat_ls)

#%%
class PairData(Data):
    """
    A single Data object holding two graphs: the source graph (s_t) and the
    target graph (s_t+1), together with a reward and an action.
    """

    def __init__(self, edge_index_s=None, x_s=None, reward=None, action=None, edge_index_t=None, x_t=None):
        super().__init__()
        # Attributes are registered in the same order as the constructor
        # arguments.
        fields = (
            ('edge_index_s', edge_index_s),
            ('x_s', x_s),
            ('reward', reward),
            ('action', action),
            ('edge_index_t', edge_index_t),
            ('x_t', x_t),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)

    def __inc__(self, key, value, *args, **kwargs):
        # Each edge-index tensor is incremented by the node count of its own
        # graph; every other key uses the parent's increment.
        if key == 'edge_index_s':
            increment = self.x_s.size(0)
        elif key == 'edge_index_t':
            increment = self.x_t.size(0)
        else:
            increment = super().__inc__(key, value, *args, **kwargs)
        return increment


class ReplayData:
    """
    Replay buffer for SAC agents: a plain list of samples plus random batching.
    """

    def __init__(self):
        # Stored transitions; filled by the caller.
        self.data_list = []

    def sample_batch(self, batch_size=32):
        """Draw ``batch_size`` distinct samples and collate them into a Batch."""
        print(len(self.data_list), batch_size)
        picked = random.sample(self.data_list, batch_size)
        return Batch.from_data_list(picked, follow_batch=['x_s', 'x_t'])
    

def process_data(f_ls, g_ls, f_list, g_list, q_list, c, p, lambda_ls, q_hat, edge_index, pairs, path):
    """Convert simulated trajectories into a pickled ``ReplayData`` buffer.

    For every time step ``t`` in ``range(num - T)`` (``T = 6``) a ``PairData``
    sample is built. It holds node features for step ``t`` (``x_s``) and for
    step ``t + 1`` (``x_t``), the action ``q_hat[t]`` and the reward tensor.
    Each feature matrix has shape ``(N, 1 + 2 * T)``.

    Args:
        f_ls, g_ls: per-step lists of per-node flow dicts; ``len(f_ls)`` is the
            number of steps and ``len(f_ls[0])`` the number of nodes ``N``.
        f_list, g_list: per-step flat flow vectors (used for the reward).
        q_list: per-step state vectors, indexed as ``q_list[t][i]``.
        c, p: cost and price, indexed per time step (``c[t]``, ``p[t]``).
        lambda_ls: per-step flat demand vectors, ordered like ``pairs``.
        q_hat: per-step actions (converted with ``list``).
        edge_index: edge index stored in every sample.
        pairs: sequence of ``(i, j)`` edges.
        path: output file the replay buffer is pickled to.
    """

    num = len(f_ls)
    N = len(f_ls[0])

    #action
    action = list(q_hat)

    #state (14,13)
    #dacc
    # For every step: regroup the per-node flow dicts by their inner key and
    # sum the values per key, for f and g, then add both.
    # NOTE(review): position k of the result corresponds to the k-th element
    # of the `unique_keys` set; the code below treats position as node index,
    # which assumes the set iterates in node order -- confirm.
    dacc_ls = []
    T = 6
    for data_f, data_g in zip(f_ls, g_ls):
        f_ls_t = []
        unique_keys = set(key for d in data_f for key in d.keys())
        for key in unique_keys:
            transformed_dict = {index: d[key] for index, d in enumerate(data_f) if key in d}
            f_ls_t.append(transformed_dict)
        f_result = np.array([sum(d.values()) for d in f_ls_t])

        g_ls_t = []
        unique_keys = set(key for d in data_g for key in d.keys())
        for key in unique_keys:
            transformed_dict = {index: d[key] for index, d in enumerate(data_g) if key in d}
            g_ls_t.append(transformed_dict)
        g_result = np.array([sum(d.values()) for d in g_ls_t])
        result = f_result+g_result
        dacc_ls.append(result)
    # dacc[i][t]: summed flow value at position i for step t.
    dacc = defaultdict(dict)
    for t in range(num):
        for i in range(N):
            dacc[i][t] = dacc_ls[t][i]

    #demand
    # demand[i, j][t]: demand of edge (i, j) at step t.
    demand = defaultdict(dict)
    for t in range(num):
        lbd = lambda_ls[t]
        m = 0
        for i,j in pairs:
            demand[i,j][t] = lbd[m]
            m += 1

    # for t in range(num):
    #     for i in range(N):
    #         lbd = lbd_ls_t[t][i]
    #         for key, value in lbd.items():
    #             demand[i, key][t] = value
    
    # profit_ls[i, j][t]: price of edge (i, j) at step t.
    profit_ls = defaultdict(dict)
    for t in range(num):
        profit = p[t]
        m = 0
        for i,j in pairs:
            profit_ls[i,j][t] = profit[m]
            m += 1

    # state[i][t]: value of q_list[t] at node i.
    state = defaultdict(dict)
    for t in range(num):
        for i in range(N):
            state[i][t] = q_list[t][i]

    #reward
    # Per-step reward: passenger revenue minus rebalancing cost.
    reward = []
    for t in range(num):
        reward.append(np.dot(p[t], g_list[t]) - np.dot(c[t], f_list[t]))
    reward = np.array(reward)


    #transform dict to torch tensor
    s = 0.01  # scaling factor applied to every feature
    # pair_set = [tuple(edge_index[:, i]) for i in range(edge_index.shape[1])]

    data_list = []
    for t in range(num-T):
        # data = parse_obs(state, dacc, demand, edge_index, t)
        # data2 = parse_obs(state, dacc, demand, edge_index, t+1)
        # Features of step t: current state (1 row), state + dacc over the
        # next T steps (T rows), profit-weighted outgoing demand over the next
        # T steps (T rows); transposed to shape (N, 1 + 2T). Inside the nested
        # comprehensions `t` is re-bound to the look-ahead step.
        x = torch.cat((
            torch.tensor([state[n][t]*s for n in range(N)]
                        ).view(1, 1, N).float(),
            torch.tensor([[(state[n][t] + dacc[n][t])*s for n in range(N)]
                        for t in range(t, t+T)]).view(1, T, N).float(),
            torch.tensor([[sum([(demand[i, j][t])*(profit_ls[i,j][t])*s
                        for j in range(N) if (i,j) in pairs]) for i in range(N)] for t in range(t, t+T)]).view(1, T, N).float()),
                        dim=1).squeeze(0).view(1+T + T, N).T
        data = Data(x, edge_index)

        # Same features shifted by one step.
        # NOTE(review): the last block multiplies demand at t+1 by profit at t
        # (not t+1) -- confirm whether this mismatch is intended.
        x2 = torch.cat((
            torch.tensor([state[n][t+1]*s for n in range(N)]
                        ).view(1, 1, N).float(),
            torch.tensor([[(state[n][t+1] + dacc[n][t+1])*s for n in range(N)]
                        for t in range(t, t+T)]).view(1, T, N).float(),
            torch.tensor([[sum([(demand[i, j][t+1])*(profit_ls[i,j][t])*s
                        for j in range(N) if (i,j) in pairs]) for i in range(N)] for t in range(t, t+T)]).view(1, T, N).float()),
                        dim=1).squeeze(0).view(1+T + T, N).T
            
        data2 = Data(x2, edge_index)

        # NOTE(review): the action is indexed by t but the WHOLE reward array
        # is stored in every sample -- confirm whether reward[t] was intended.
        data_list.append(PairData(edge_index, data.x, torch.as_tensor(
                reward, dtype=torch.float32), torch.as_tensor(action[t], dtype=torch.float32), edge_index, data2.x))

    replay_buffer = ReplayData()
    replay_buffer.data_list = data_list

    # Persist the buffer; `f` is rebound to the file handle here.
    with open(path, 'wb') as f:
        pickle.dump(replay_buffer, f) 

def opt_policy(lambda_ls, cost_ls, profit_ls, demandTime_ls, demandTime, rebTime, pairs, N, q0):
    """Roll out a myopic policy that solves one single-step LP per period.

    At each period ``i`` the LP minimizes
    ``beta * cost @ f - (profit - beta * demandTime_ls[i]) @ g`` subject to
    ``G (f + g) <= q``, ``g <= lambda_ls[i]`` and ``A (f + g) <= q - qhat[i]``
    (``qhat`` is all zeros here). The state is then updated with the flows
    that departed ``rebTime`` / ``demandTime`` periods earlier and with the
    departures of the current period.

    Args:
        lambda_ls: per-period demand vectors, ordered like ``pairs``.
        cost_ls, profit_ls: per-period cost / profit vectors.
        demandTime_ls: per-period vectors used in the objective.
        demandTime, rebTime: mappings indexed as ``[o, d][period]`` giving the
            look-back (in periods) for passenger and rebalancing flows.
        pairs: list of ``(origin, destination)`` tuples.
        N: number of nodes.
        q0: initial state. It is copied (as float), so the caller's array is
            not modified and continuous flows are not truncated.

    Returns:
        ``(f_list, g_list, q_list)`` with the optimal flows of each period and
        the state the LP of each period was solved with.
    """
    f_list = []
    g_list = []
    q_list = []
    M = len(pairs)
    # A: +1 at origin / -1 at destination; G: +1 at origin.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    for col, (i, j) in enumerate(pairs):
        A[i, col] = 1
        A[j, col] = -1
        G[i, col] = 1
    pair_set = set(pairs)  # O(1) membership tests in the arrival loop

    T = len(lambda_ls)
    # Private float copy: the state is updated in place below, which must not
    # leak into the caller's `q0` nor into snapshots stored in `q_list`.
    q = np.array(q0, dtype=float)
    qhat = np.zeros((T, N))

    # Flow history keyed by period index shifted by N; the first N entries are
    # zero so that look-backs shortly before period 0 resolve to "no flow".
    f_dict = defaultdict(dict)
    g_dict = defaultdict(dict)
    beta = 0.3  # local weight; shadows the module-level `beta`
    for i in range(N):
        for (o,d) in pairs:
            f_dict[i][(o,d)] = 0
            g_dict[i][(o,d)] = 0
    for i in range(T):
        m = gp.Model('single_step')
        m.setParam('OutputFlag', 0)
        f = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'f')
        g = m.addMVar(shape = M, lb = 0, vtype = GRB.CONTINUOUS, name = 'g')
        m.addConstr(G@(f+g) <= q)
        m.addConstr(g <= np.array(lambda_ls[i]))

        m.addConstr(A@(f+g) <= q - qhat[i,:])

        m.setObjective(np.array(cost_ls[i])@f*beta-(np.array(profit_ls[i])-beta*np.array(demandTime_ls[i]))@g, GRB.MINIMIZE)

        m.optimize()

        f_value = f.X
        g_value = g.X
        f_list.append(f_value)
        g_list.append(g_value)
        # Snapshot of the state this period's LP was solved with.
        q_list.append(q.copy())

        print('time:', i)
        print('f:', np.sum(f_value))
        print('g:', np.sum(g_value))
        print('q:', np.sum(q))
        print('lbd:', np.sum(np.array(lambda_ls[i])))

        for k, (o, d) in enumerate(pairs):
            f_dict[i+N][(o,d)] = f_value[k]
            g_dict[i+N][(o,d)] = g_value[k]

        # Add flows that departed rebTime / demandTime periods ago.
        # NOTE(review): the flow on edge (n, d) is credited to its ORIGIN n,
        # whereas benchmark_policy credits arrivals to the destination --
        # confirm which is intended. A look-back larger than i + N has no
        # history entry and raises KeyError.
        for n in range(N):
            for d in range(N):
                if (n,d) in pair_set:
                    q[n] = q[n] + f_dict[i+N-rebTime[n,d][i]][(n,d)] + g_dict[i+N-demandTime[n,d][i]][(n,d)]

        # Remove this period's departures from their origin nodes.
        q = q - np.dot(G, f_list[i]+g_list[i])

    return f_list, g_list, q_list

def benchmark_policy(lambda_ls, cost_ls, price_ls, demandTime, rebTime, pairs, N, q0, num):
    """Solve the full-horizon flow model and return its per-period flows.

    One Gurobi model covers all ``num`` periods. Flow variables are indexed by
    ``(i, j, t + L)`` where ``L`` is the largest value found in ``demandTime``;
    the first ``L`` time slots are fixed to zero so that look-backs at early
    periods refer to existing variables. The objective maximizes price revenue
    minus ``beta``-weighted rebalancing cost and ``beta``-weighted
    ``demandTime`` (``beta`` is the module-level constant).

    Args:
        lambda_ls: per-period demand vectors, ordered like ``pairs``.
        cost_ls, price_ls: per-period cost / price vectors.
        demandTime, rebTime: mappings indexed as ``[i, j][t]`` giving the
            look-back (in periods) for passenger and rebalancing flows.
        pairs: list of ``(i, j)`` edges.
        N: number of nodes.
        q0: initial state (fixes ``n[0, :]``).
        num: number of periods.

    Returns:
        ``(f_list, g_list, q_list)``: per-period flow arrays (in the key order
        of the variable dict) and the state at the start of each period.
    """

    M = len(pairs)
    # Incidence matrices; built here but not used by the model below.
    A = np.zeros((N,M))
    G = np.zeros((N,M))
    B = np.zeros((N,M))
    m = 0
    for i,j in pairs:
        A[i, m] = 1
        A[j, m] = -1
        G[i, m] = 1
        B[j, m] = 1
        m += 1

    # Per-period lookup tables keyed by edge.
    cost_dict = defaultdict(dict)
    price_dict = defaultdict(dict)
    for i in range(num):
        for (o,d) in pairs:
            cost_dict[i][(o,d)] = cost_ls[i][pairs.index((o,d))]
            price_dict[i][(o,d)] = price_ls[i][pairs.index((o,d))]

    mdl = gp.Model('original_vector')
    # print('rebTime:', rebTime)
    # print('demandTime:', demandTime)
    # print('cost:', cost_dict)
    # print('price:', price_dict)
    # mdl.setParam('OutputFlag', 0)

    # f = mdl.addMVar(shape = (num, M), lb = 0, vtype = GRB.CONTINUOUS, name = 'f')
    # g = mdl.addMVar(shape = (num, M), lb = 0, vtype = GRB.CONTINUOUS, name = 'g')

    max_value = float('-inf')  # start at -inf so any value is larger
    # Scan the nested dict for its largest value.
    for key in demandTime:
        sub_dict = demandTime[key]  # inner dict of this edge
        current_max = max(sub_dict.values())  # largest value of the inner dict
        if current_max > max_value:
            max_value = current_max  # update the running maximum
    # L is the index shift applied to all flow variables.
    # NOTE(review): L is derived from demandTime only; a rebTime value larger
    # than L would index a variable that does not exist -- confirm.
    L = max_value

    E = pairs
    f = mdl.addVars(E, num+L, vtype = GRB.CONTINUOUS, name = 'f')
    g = mdl.addVars(E, num+L, vtype = GRB.CONTINUOUS, name = 'g')
    n = mdl.addMVar(shape = (num+1, N), lb = 0, vtype = GRB.CONTINUOUS, name = 'n')

    # No flow in the L padding slots before period 0.
    mdl.addConstrs(f[i,j,t] == 0 for t in range(0,L) for (i,j) in E)
    mdl.addConstrs(g[i,j,t] == 0 for t in range(0,L) for (i,j) in E)

    for t in range(0, num):
            
        if t == 0:
            mdl.addConstr(n[t, :] == q0)

        for i in range(N):
            
            # State dynamics: next state = state - departures at t + flows into
            # i that departed rebTime / demandTime periods earlier.
            mdl.addConstr(n[t+1, i] == n[t, i] - gp.quicksum(f[i,j,t+L] for j in range(N) if (i,j) in E) - gp.quicksum(g[i,j,t+L] for j in range(N) if (i,j) in E) + gp.quicksum(f[j,i,t+L-rebTime[j,i][t]] for j in range(N) if (j,i) in E) + gp.quicksum(g[j,i,t+L-demandTime[j,i][t]] for j in range(N) if (j,i) in E))
            # mdl.addConstr(n[t+1, i] == n[t+1, i] + gp.quicksum(f[j,i,t+N-1] for j in range(N) if (j,i) in E) + gp.quicksum(g[j,i,t+N-1] for j in range(N) if (j,i) in E))
            # Departures cannot exceed the vehicles present at the node.
            mdl.addConstr(n[t, i] - gp.quicksum(f[i,j,t+L] for j in range(N) if (i,j) in E) - gp.quicksum(g[i,j,t+L] for j in range(N) if (i,j) in E) >= 0)
        
        # Passenger flow is bounded by demand.
        idx = 0
        for i,j in E:           
            mdl.addConstr(g[i,j,t+L] <= np.array(lambda_ls[t][idx]))
            idx += 1

    # mdl.setObjective(qsum(np.array(price_ls[t])@g[t,:] - np.array(cost_ls[t])@f[t,:] for t in range(0,num)), GRB.MAXIMIZE)

    # obj = 0
    # for t in range(0,num):
    #     for o,d in E:
    #         obj += price_dict[t][o,d]*g[o,d,t+N] - cost_dict[t][o,d]*f[o,d,t+N] -demandTime[o,d][t]*beta*g[o,d,t+N]
    mdl.setObjective(gp.quicksum(price_dict[t][o,d]*g[o,d,t+L] - cost_dict[t][o,d]*f[o,d,t+L]*beta -demandTime[o,d][t]*beta*g[o,d,t+L] for (o,d) in pairs for t in range(0,num)), GRB.MAXIMIZE)
    # mdl.setObjective(qsum(price_dict[t][o,d]*g[o,d,t+N] for (o,d) in pairs for t in range(0,num)), GRB.MAXIMIZE)
    mdl.optimize()

    # Collect the solution of period t0 from the variables in slot t0 + L.
    f_list = [[] for _ in range(num)]
    g_list = [[] for _ in range(num)]
    for t0 in range(num):
        for (i,j,t) in f.keys():
            if t == t0+L:
                # print(i,j,t)
                f_list[t0].append(f[i, j, t].X)
                g_list[t0].append(g[i, j, t].X)
        f_list[t0] = np.array(f_list[t0])
        g_list[t0] = np.array(g_list[t0])
        # print(t0)
        # print('f:', np.sum(f_list[t0]))
        # print('g:', np.sum(g_list[t0]))
        # print('n:', np.sum(n.X[t0]))
        # print('lbd:', np.sum(np.array(lambda_ls[t0])))

    # f_list = [row for row in f.X]
    # g_list = [row for row in g.X]
    # States at the start of periods 0..num-1 (the final state is dropped).
    q_list = [row for row in n.X[:num, :]]
    print('objective value:', mdl.ObjVal)

    return f_list, g_list, q_list


def naive_invOptimization(f_list, g_list, q_list, pairs, N):
    """Compute ``q_hat = q - A (f + g)`` for every sample.

    ``A`` is the node-edge incidence matrix with +1 at the origin row and -1
    at the destination row of each edge column.

    Args:
        f_list, g_list: per-sample flow vectors of length ``len(pairs)``.
        q_list: per-sample state vectors of length ``N``.
        pairs: sequence of ``(i, j)`` edges; edge ``k`` becomes column ``k``.
        N: number of nodes.

    Returns:
        Array with one ``q_hat`` row per sample.
    """
    M = len(pairs)
    A = np.zeros((N,M))
    # One column per edge. The column index previously never advanced, so
    # every edge overwrote column 0.
    for col, (i, j) in enumerate(pairs):
        A[i, col] = 1
        A[j, col] = -1

    qhat_list = [
        q - np.dot(A, f + g)
        for f, g, q in zip(f_list, g_list, q_list)
    ]
    return np.array(qhat_list)

def MPC(env, num, N, cost_ls, price_ls, demandTime_ls, demand_input):
    """Run a receding-horizon dispatch policy and accumulate its objective.

    For every step ``T`` in ``range(num)`` a Gurobi LP over a 10-step horizon
    is solved. Only the first-step flows are kept: they are stored, added to
    the running objective, and used when later steps account for arrivals.
    The next step starts from ``n[1, :]`` of the solved model.

    Args:
        env: object providing ``edges`` (list of ``(i, j)`` tuples),
            ``demand`` (indexed as ``env.demand[i, j][t]``) and ``acc``
            (mapping whose values are indexable by ``0``; values are taken in
            sorted key order as the initial node inventory).
        num: number of steps to simulate.
        N: number of nodes.
        cost_ls: per-step sequences of edge values in ``env.edges`` order.
            Entries ``0 .. num + 9`` are read for the objective, and
            ``cost_ls[0]`` is also used as the per-edge delay of ``f`` flows.
        price_ls: per-step sequences of edge prices, entries ``0 .. num + 9``.
        demandTime_ls: only ``demandTime_ls[0]`` is read; it gives the
            per-edge delay of ``g`` flows and a per-unit penalty on ``g``.
        demand_input: indexed as ``demand_input[i, j]``; upper bound on ``g``
            for horizon steps after the first.

    Returns:
        ``(obj, f_ls, g_ls, q_ls)``: the accumulated first-step objective, the
        first-step ``f`` and ``g`` values per step (in ``env.edges`` order),
        and the starting inventory vector of every step.

    NOTE(review): the delays taken from ``cost_ls[0]`` / ``demandTime_ls[0]``
    are used as offsets into the time index of the variables, so they are
    assumed to be integers -- confirm against the caller.
    """
    pairs = env.edges
    E = pairs
    h = 10  # look-ahead horizon of each LP

    # Set / adjacency lookups replace repeated linear scans of the edge list;
    # neighbours stay in increasing node order, as in the original sums.
    edge_set = set(pairs)
    out_nbrs = [[j for j in range(N) if (i, j) in edge_set] for i in range(N)]
    in_nbrs = [[j for j in range(N) if (j, i) in edge_set] for i in range(N)]

    # Position of the first occurrence of every edge (what list.index returned).
    first_pos = {}
    for pos, (o, d) in enumerate(pairs):
        first_pos.setdefault((o, d), pos)

    # Observed demand per step, keyed by edge. A pair without a demand entry
    # is capped at 0; previously it was skipped, which shifted every later
    # cap onto the wrong edge and ended in an IndexError.
    demand_now = [
        {
            (i, j): (env.demand[i, j][step]
                     if (i, j) in env.demand and step in env.demand[i, j]
                     else 0)
            for (i, j) in pairs
        }
        for step in range(num)
    ]

    q0 = np.array([inner_dict[0] for _, inner_dict in sorted(env.acc.items())])
    obj = 0

    demandTime = {}
    rebTime = {}
    for pos, (i, j) in enumerate(pairs):
        demandTime[i, j] = demandTime_ls[0][pos]
        rebTime[i, j] = cost_ls[0][pos]

    cost_dict = defaultdict(dict)
    price_dict = defaultdict(dict)
    for step in range(num + h):
        for (o, d), pos in first_pos.items():
            cost_dict[step][(o, d)] = cost_ls[step][pos]
            price_dict[step][(o, d)] = price_ls[step][pos]

    # First-step decisions already committed, keyed by the step they were made.
    f_dict = defaultdict(dict)
    g_dict = defaultdict(dict)
    f_ls = [[] for _ in range(num)]
    g_ls = [[] for _ in range(num)]
    q_ls = []
    for T in range(0, num):
        q_ls.append(q0)
        mdl = gp.Model('original_vector' + str(T))
        mdl.setParam('OutputFlag', 0)

        f = mdl.addVars(E, h, vtype=GRB.CONTINUOUS, name='f')
        g = mdl.addVars(E, h, vtype=GRB.CONTINUOUS, name='g')
        n = mdl.addMVar(shape=(h + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

        for t in range(0, h):
            if t == 0:
                mdl.addConstr(n[t, :] == q0)
            for i in range(N):
                dep_f = gp.quicksum(f[i, j, t] for j in out_nbrs[i])
                dep_g = gp.quicksum(g[i, j, t] for j in out_nbrs[i])
                # Arrivals from flows decided inside this horizon.
                summ1 = (
                    gp.quicksum(f[j, i, t - rebTime[j, i]]
                                for j in in_nbrs[i] if t - rebTime[j, i] >= 0)
                    + gp.quicksum(g[j, i, t - demandTime[j, i]]
                                  for j in in_nbrs[i] if t - demandTime[j, i] >= 0)
                )
                # Arrivals from flows committed at earlier MPC steps (constants).
                summ2 = (
                    sum(f_dict[T + t - rebTime[j, i]][(j, i)]
                        for j in in_nbrs[i] if T + t - rebTime[j, i] in f_dict)
                    + sum(g_dict[T + t - demandTime[j, i]][(j, i)]
                          for j in in_nbrs[i] if T + t - demandTime[j, i] in g_dict)
                )

                # Flow conservation at node i.
                mdl.addConstr(n[t + 1, i] == n[t, i] - dep_f - dep_g + summ1 + summ2)
                # Departures cannot exceed the inventory at the start of the step.
                mdl.addConstr(n[t, i] - dep_f - dep_g >= 0)

            for (i, j) in E:
                if t == 0:
                    # Observed demand caps the first step; added once instead
                    # of once per horizon step.
                    mdl.addConstr(g[i, j, 0] <= np.array(demand_now[T][i, j]))
                else:
                    # Later steps are capped by the supplied demand_input.
                    mdl.addConstr(g[i, j, t] <= demand_input[i, j])

        mdl.setObjective(
            gp.quicksum(
                price_dict[T + t][o, d] * g[o, d, t]
                - cost_dict[T + t][o, d] * f[o, d, t] * beta
                - demandTime[o, d] * beta * g[o, d, t]
                for (o, d) in pairs for t in range(0, h)
            ),
            GRB.MAXIMIZE,
        )
        mdl.optimize()

        # Commit the first-step decisions and score them with step-T prices.
        obj_b = 0
        for (o, d) in pairs:
            f_now = f[o, d, 0].X
            g_now = g[o, d, 0].X
            f_dict[T][(o, d)] = f_now
            g_dict[T][(o, d)] = g_now
            f_ls[T].append(f_now)
            g_ls[T].append(g_now)
            obj_b += (price_dict[T][o, d] * g_now
                      - cost_dict[T][o, d] * f_now * beta
                      - demandTime[o, d] * beta * g_now)
        obj += obj_b

        # Inventory after the first step becomes the next starting state.
        q0 = n.X[1, :]

    return obj, f_ls, g_ls, q_ls


# def MPC(lambda_ls, cost_ls, price_ls, demandTime, rebTime, pairs, N, q0, num):

#     M = len(pairs)
#     A = np.zeros((N,M))
#     G = np.zeros((N,M))
#     B = np.zeros((N,M))
#     m = 0
#     for i,j in pairs:
#         A[i, m] = 1
#         A[j, m] = -1
#         G[i, m] = 1
#         B[j, m] = 1
#         m += 1

    # cost_dict = defaultdict(dict)
    # price_dict = defaultdict(dict)
    # h = 10
    # for i in range(num+h):
    #     for (o,d) in pairs:
    #         cost_dict[i][(o,d)] = cost_ls[i][pairs.index((o,d))]
    #         price_dict[i][(o,d)] = price_ls[i][pairs.index((o,d))]

    # f_dict = defaultdict(dict)
    # g_dict = defaultdict(dict)
    # f_ls = [[] for _ in range(num)]
    # g_ls = [[] for _ in range(num)]
    # q_ls = []
    # E = pairs
    # for T in range(0,num):
    #     q_ls.append(q0)
    #     mdl = gp.Model('original_vector'+str(T))
    #     mdl.setParam('OutputFlag', 0)

    #     # f = mdl.addMVar(shape = (h, M), lb = 0, vtype = GRB.CONTINUOUS, name = 'f')
    #     # g = mdl.addMVar(shape = (h, M), lb = 0, vtype = GRB.CONTINUOUS, name = 'g')
    #     # n = mdl.addMVar(shape = (h+1, N), lb = 0, vtype = GRB.CONTINUOUS, name = 'n')

    #     f = mdl.addVars(E, h, vtype = GRB.CONTINUOUS, name = 'f')
    #     g = mdl.addVars(E, h, vtype = GRB.CONTINUOUS, name = 'g')
    #     n = mdl.addMVar(shape = (h+1, N), lb = 0, vtype = GRB.CONTINUOUS, name = 'n')

        # for t in range(0,h):
        #     # Flow conservation constraint
        #     if t == 0:
        #         mdl.addConstr(n[t, :] == q0)
        #     for i in range(N):
        #         summ1 = qsum(f[j,i,t-rebTime[j,i]] for j in range(N) if (j,i) in E and t-rebTime[j,i]>=0) + qsum(g[j,i,t-demandTime[j,i]] for j in range(N) if (j,i) in E and t-demandTime[j,i]>=0)
        #         summ2 = sum(f_dict[T+t-rebTime[j,i]][(j,i)] for j in range(N) if (j,i) in E and T+t-rebTime[j,i] in f_dict) + sum(g_dict[T+t-demandTime[j,i]][(j,i)] for j in range(N) if (j,i) in E and T+t-demandTime[j,i] in g_dict)
                        
        #         mdl.addConstr(n[t+1, i] == n[t, i] - gp.quicksum(f[i,j,t] for j in range(N) if (i,j) in E) - gp.quicksum(g[i,j,t] for j in range(N) if (i,j) in E) + summ1 + summ2) 
                
        #         # mdl.addConstr(n[t+1, i] == n[t+1, i] + gp.quicksum(f[j,i,t+N-1] for j in range(N) if (j,i) in E) + gp.quicksum(g[j,i,t+N-1] for j in range(N) if (j,i) in E))
        #         mdl.addConstr(n[t, i] - gp.quicksum(f[i,j,t] for j in range(N) if (i,j) in E) - gp.quicksum(g[i,j,t] for j in range(N) if (i,j) in E) >= 0)
        
        #     idx = 0
        #     for i,j in E:           
        #         mdl.addConstr(g[i,j,t] <= np.array(lambda_ls[T+t][idx]))
        #         idx += 1

    #     beta = 1.5
    #     mdl.setObjective(gp.quicksum(price_dict[T+t][o,d]*g[o,d,t] - cost_dict[T+t][o,d]*f[o,d,t]*beta -demandTime[o,d]*beta*g[o,d,t] for (o,d) in pairs for t in range(0,h)), GRB.MAXIMIZE)
    #     # mdl.setObjective(gp.quicksum(price_dict[T+t][o,d]*g[o,d,t] - cost_dict[T+t][o,d]*f[o,d,t] for (o,d) in pairs for t in range(0,h)), GRB.MAXIMIZE)
    #     mdl.optimize()
    #     # print('T:', T)
    #     # print('obj:', mdl.getObjective().getValue())
    #     for (o,d) in pairs:
    #         f_dict[T][(o,d)] = f[o,d,0].X
    #         g_dict[T][(o,d)] = g[o,d,0].X
        
    #     q0 = n.X[1, :]

    #     obj_b = 0
    #     for o,d in pairs:
    #         obj_b += price_dict[T][o,d]*g[o,d,0].X - cost_dict[T][o,d]*f[o,d,0].X*beta -demandTime[o,d]*beta*g[o,d,0].X
    #     obj += obj_b

    #     for k in range(len(pairs)):
    #         i, j = pairs[k]
    #         f_ls[T].append(f[i, j, 0].X)
    #         g_ls[T].append(g[i, j, 0].X)


    # return obj, f_ls, g_ls, q_ls


# %%
def Generate_History(tripAttr, rebTime, demandTime, N, acc, edge_index, num):
    """Generate a flow history, fit inverse-optimisation parameters, export data.

    A history is produced by ``gh.Generate_Data`` / ``gh.Generate_History``.
    ``invOptimization`` is run on it, and ``process_data`` is called three
    times with different ``qhat`` estimates and output names.

    Args:
        tripAttr, rebTime, demandTime: not read by this function; kept so
            existing callers keep working.
        N: number of nodes, forwarded to ``gh.Generate_History``.
        acc: forwarded to ``gh.Generate_History``.
        edge_index: forwarded to ``process_data``.
        num: number of periods, forwarded to ``gh.Generate_History`` and used
            to size the ``qhat`` arrays.

    Returns:
        ``(theta_f, theta_g, mu)`` as returned by ``invOptimization``.
    """
    tripAttr_ls, rebTime_all, demandTime_all, demand_input = gh.Generate_Data(1)
    (f_list, g_list, q_list, lambda_ls,
     price_ls, cost_ls, demandTime_ls, pairs) = gh.Generate_History(
        tripAttr_ls, rebTime_all, demandTime_all, demand_input,
        N=N, acc=acc, num=num)

    qhat2 = naive_invOptimization(f_list, g_list, q_list, pairs, N)

    print('inverse optimization 1:')
    theta_f, theta_g, mu = invOptimization(
        f_list, g_list, q_list, lambda_ls, cost_ls, price_ls, demandTime_ls, N, pairs)

    # Random baseline estimate.
    qhat = np.random.randint(70, 110, size=(num, N))

    print('theta_f:', theta_f)
    print('theta_g:', theta_g)

    print('inverse optimziation 2:')
    # The second inverse-optimisation routine is disabled; an all-zero
    # estimate stands in for its qhat.
    qhat3 = np.zeros((num, N))

    # Placeholders for the disabled routine's other outputs. They are unused,
    # but the draws are kept so the global NumPy random stream is unchanged.
    theta_f2 = np.random.rand(len(pairs))
    theta_g2 = np.random.rand(len(pairs))
    mu2 = np.random.rand(num)
    sums = np.random.rand(num)

    # Per-period flows as one {destination: value} dict per origin node.
    # NOTE(review): the last 10 periods are dropped; presumably this matches
    # the horizon used elsewhere in the file -- confirm.
    f_ls = []
    g_ls = []
    for t in range(num - 10):
        f = [{} for _ in range(N)]
        g = [{} for _ in range(N)]
        for d, (i, j) in enumerate(pairs):
            f[i][j] = f_list[t][d]
            g[i][j] = g_list[t][d]
        f_ls.append(f)
        g_ls.append(g)

    process_data(f_ls, g_ls, f_list, g_list, q_list, cost_ls, price_ls, lambda_ls, qhat, edge_index, pairs, 'wyx_replay_buffer.pkl')
    process_data(f_ls, g_ls, f_list, g_list, q_list, cost_ls, price_ls, lambda_ls, qhat2, edge_index, pairs, 'others_replay_buffer.pkl')
    process_data(f_ls, g_ls, f_list, g_list, q_list, cost_ls, price_ls, lambda_ls, qhat3, edge_index, pairs, 'others2_replay_buffer.pkl')

    return theta_f, theta_g, mu

def Benchmark(env, num, N, cost_ls, price_ls, demandTime_ls):
    """Solve the full-horizon dispatch LP once and return its optimal value.

    All ``num`` periods are optimised jointly with Gurobi. ``g`` on each edge
    is capped by ``env.demand`` for that period.

    Args:
        env: object providing ``edges`` (list of ``(i, j)`` tuples),
            ``demand`` (indexed as ``env.demand[i, j][t]``) and ``acc``
            (mapping whose values are indexable by ``0``; values are taken in
            sorted key order as the initial node inventory).
        num: number of periods.
        N: number of nodes.
        cost_ls: per-period sequences of edge values in ``env.edges`` order.
            ``cost_ls[0]`` is also used as the per-edge delay of ``f`` flows.
        price_ls: per-period sequences of edge prices.
        demandTime_ls: only ``demandTime_ls[0]`` is read; it gives the
            per-edge delay of ``g`` flows and a per-unit penalty on ``g``.

    Returns:
        The optimal objective value (``mdl.ObjVal``).

    NOTE(review): the delays are used as offsets into the time index of the
    variables, so they are assumed to be integers -- confirm against the
    caller.
    """
    pairs = env.edges
    E = pairs

    # Set / adjacency lookups replace repeated linear scans of the edge list;
    # neighbours stay in increasing node order, as in the original sums.
    edge_set = set(pairs)
    out_nbrs = [[j for j in range(N) if (i, j) in edge_set] for i in range(N)]
    in_nbrs = [[j for j in range(N) if (j, i) in edge_set] for i in range(N)]

    # Position of the first occurrence of every edge (what list.index returned).
    first_pos = {}
    for pos, (o, d) in enumerate(pairs):
        first_pos.setdefault((o, d), pos)

    # Demand per period, keyed by edge. A pair without a demand entry is
    # capped at 0; previously it was skipped, which shifted every later cap
    # onto the wrong edge and ended in an IndexError.
    demand_cap = [
        {
            (i, j): (env.demand[i, j][t]
                     if (i, j) in env.demand and t in env.demand[i, j]
                     else 0)
            for (i, j) in pairs
        }
        for t in range(num)
    ]

    q0 = np.array([inner_dict[0] for _, inner_dict in sorted(env.acc.items())])

    demandTime = {}
    rebTime = {}
    for pos, (i, j) in enumerate(pairs):
        demandTime[i, j] = demandTime_ls[0][pos]
        rebTime[i, j] = cost_ls[0][pos]

    cost_dict = defaultdict(dict)
    price_dict = defaultdict(dict)
    for t in range(num):
        for (o, d), pos in first_pos.items():
            cost_dict[t][(o, d)] = cost_ls[t][pos]
            price_dict[t][(o, d)] = price_ls[t][pos]

    mdl = gp.Model('original_vector')
    mdl.setParam('OutputFlag', 0)

    f = mdl.addVars(E, num, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addVars(E, num, vtype=GRB.CONTINUOUS, name='g')
    n = mdl.addMVar(shape=(num + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

    for t in range(0, num):
        if t == 0:
            mdl.addConstr(n[t, :] == q0)

        for i in range(N):
            dep_f = gp.quicksum(f[i, j, t] for j in out_nbrs[i])
            dep_g = gp.quicksum(g[i, j, t] for j in out_nbrs[i])
            # A flow sent on (j, i) enters node i's balance after its delay.
            arr_f = gp.quicksum(f[j, i, t - rebTime[j, i]]
                                for j in in_nbrs[i] if t - rebTime[j, i] >= 0)
            arr_g = gp.quicksum(g[j, i, t - demandTime[j, i]]
                                for j in in_nbrs[i] if t - demandTime[j, i] >= 0)

            # Flow conservation at node i.
            mdl.addConstr(n[t + 1, i] == n[t, i] - dep_f - dep_g + arr_f + arr_g)
            # Departures cannot exceed the inventory at the start of the period.
            mdl.addConstr(n[t, i] - dep_f - dep_g >= 0)

        for (i, j) in E:
            mdl.addConstr(g[i, j, t] <= np.array(demand_cap[t][i, j]))

    mdl.setObjective(
        gp.quicksum(
            price_dict[t][o, d] * g[o, d, t]
            - cost_dict[t][o, d] * f[o, d, t] * beta
            - demandTime[o, d] * beta * g[o, d, t]
            for (o, d) in pairs for t in range(0, num)
        ),
        GRB.MAXIMIZE,
    )
    mdl.optimize()

    return mdl.ObjVal


def Verify(env, result, num):
    """Replay ``result`` period by period and print constraint checks and reward.

    The node count is read from the ``i`` column of ``topology.csv``. For each
    period ``t`` the pair ``result[t]`` is read as ``(g, f)``. The inventory
    is advanced by subtracting this period's departures and, from the second
    period on, adding the previous period's flows at their destinations.
    Element-wise checks are printed along the way; nothing is returned.

    Args:
        env: object providing ``edges``, ``demand``, ``price``, ``rebTime``
            (each indexed as ``[i, j][t]``) and ``acc``.
        result: sequence whose ``t``-th item holds ``g`` at index 0 and ``f``
            at index 1 (NOTE(review): assumed to be NumPy vectors in edge
            order so that ``f + g`` adds element-wise -- confirm).
        num: number of periods to replay.
    """
    N = pd.read_csv('topology.csv')['i'].nunique()

    pairs = env.edges
    M = len(pairs)
    # G marks the origin node of each edge column, B the destination node.
    G = np.zeros((N, M))
    B = np.zeros((N, M))
    for col, (i, j) in enumerate(pairs):
        G[i, col] = 1
        B[j, col] = 1

    lambda_ls, price_ls, cost_ls = [], [], []
    for t in range(num):
        prices_t, costs_t, demand_t = [], [], []
        for (i, j) in pairs:
            # Only pairs that have a demand entry at t contribute.
            if (i, j) in env.demand and t in env.demand[i, j]:
                prices_t.append(env.price[i, j][t])
                costs_t.append(env.rebTime[i, j][t])
                demand_t.append(env.demand[i, j][t])
        cost_ls.append(costs_t)
        price_ls.append(prices_t)
        lambda_ls.append(demand_t)

    q0 = np.array([inner_dict[0] for _, inner_dict in sorted(env.acc.items())])

    q_ls = [q0]
    g_ls, f_ls = [], []
    reward = 0
    for t in range(num):
        g = result[t][0]
        f = result[t][1]
        g_ls.append(g)
        f_ls.append(f)

        departures = np.dot(G, f + g)
        if t == 0:
            q_ls.append(q0 - departures)
        else:
            arrivals = np.dot(B, f_ls[t - 1] + g_ls[t - 1])
            q_ls.append(q_ls[t] - departures + arrivals)
            print('n>=0:', q_ls[t] >= 0)
        print('demand limit:', g <= np.array(lambda_ls[t]))
        reward += np.dot(np.array(price_ls[t]), g) - np.dot(np.array(cost_ls[t]), f)
    print('reward:', reward)



    

