# %% import packages
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum
import pandas as pd
import torch
from collections import defaultdict
from torch_geometric.data import Data, Batch
import pickle
import random
import generate_history as gh

# Default weight applied to edge cost and travel time in the objectives.
# MPC and Benchmark both take ``beta`` as an explicit argument (shadowing this
# name), so this is presumably just the value callers pass in.
beta: float = 0.3

def MPC(env, num, N, cost_ls, price_ls, demandTime_ls, demand_input, beta, h=10):
    """Simulate a receding-horizon (MPC) policy and return its realised objective.

    For each step ``T`` in ``range(num)`` an LP over a look-ahead window of
    ``h`` periods is built and solved with Gurobi. Only the first-period
    decisions are kept: they are recorded, their objective contribution is
    accumulated, and the solved state ``n[1]`` becomes the initial state of the
    next step. Flows committed at earlier steps that arrive inside the current
    window are credited to their destination as constants.

    Args:
        env: object exposing ``edges`` (list of ``(o, d)`` tuples), ``acc``
            (node -> {time: value}; the key-0 entries, in sorted node order,
            give the initial state) and ``demand`` ((o, d) -> {time: value}).
        num: number of MPC steps to run.
        N: number of nodes; only node ids in ``range(N)`` enter the flow
            balance constraints.
        cost_ls: per-period edge costs aligned with ``env.edges``; needs at
            least ``num + h - 1`` rows. Row 0 is also used as the travel time
            of ``f`` flows, so it must be integer-valued (it offsets time
            indices).
        price_ls: per-period edge prices aligned with ``env.edges``; same
            length requirement as ``cost_ls``.
        demandTime_ls: per-period travel times of ``g`` flows, aligned with
            ``env.edges``; only row 0 is used.
        demand_input: ``(o, d) -> {time: value}`` forecast that bounds ``g``
            in look-ahead periods ``t >= 1``. Missing entries count as 0. The
            mapping is NOT modified.
        beta: weight on edge cost and on travel time in the objective.
        h: look-ahead horizon length. The default of 10 was previously
            hard-coded. Must be >= 1.

    Returns:
        ``(obj, f_ls, g_ls, q_ls)``:
        - ``obj`` is the accumulated first-period objective.
        - ``f_ls`` and ``g_ls`` hold the per-step first-period ``f`` and ``g``
          values, each a list aligned with ``env.edges``.
        - ``q_ls`` holds the initial state vector used at each step.

    Raises:
        ValueError: if ``h < 1``.
    """
    if h < 1:
        raise ValueError(f"horizon h must be >= 1, got {h}")

    pairs = env.edges
    nodes = range(N)
    # Precompute neighbour lists once; testing membership in the edge *list*
    # inside every constraint was O(M) per check.
    edge_set = set(pairs)
    out_nbrs = {i: [j for j in nodes if (i, j) in edge_set] for i in nodes}
    in_nbrs = {i: [j for j in nodes if (j, i) in edge_set] for i in nodes}

    def demand_at(table, pair, time):
        """Return ``table[pair][time]``, treating a missing pair or time as 0."""
        if pair in table and time in table[pair]:
            return table[pair][time]
        return 0

    q0 = np.array([inner_dict[0] for _, inner_dict in sorted(env.acc.items())])

    # Travel times are taken from period 0 only and held fixed for all steps.
    # NOTE(review): the travel time of f-flows reuses the period-0 *cost*;
    # confirm this is intended.
    demandTime = {pair: demandTime_ls[0][k] for k, pair in enumerate(pairs)}
    rebTime = {pair: cost_ls[0][k] for k, pair in enumerate(pairs)}

    # First-period decisions of earlier steps: global time -> {edge: value}.
    f_dict = {}
    g_dict = {}
    f_ls, g_ls, q_ls = [], [], []
    obj = 0

    for T in range(num):
        q_ls.append(q0)
        mdl = gp.Model('original_vector' + str(T))
        mdl.setParam('OutputFlag', 0)

        f = mdl.addVars(pairs, h, vtype=GRB.CONTINUOUS, name='f')
        g = mdl.addVars(pairs, h, vtype=GRB.CONTINUOUS, name='g')
        n = mdl.addMVar(shape=(h + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

        mdl.addConstr(n[0, :] == q0)
        for t in range(h):
            for i in nodes:
                outflow = gp.quicksum(f[i, j, t] + g[i, j, t] for j in out_nbrs[i])
                # Arrivals of flows decided inside this window.
                planned = gp.quicksum(
                    f[j, i, t - rebTime[j, i]]
                    for j in in_nbrs[i] if t - rebTime[j, i] >= 0
                ) + gp.quicksum(
                    g[j, i, t - demandTime[j, i]]
                    for j in in_nbrs[i] if t - demandTime[j, i] >= 0
                )
                # Arrivals of flows committed at earlier steps (constants).
                # Only keys < T exist in f_dict/g_dict, so this never overlaps
                # with `planned`.
                committed = sum(
                    f_dict[T + t - rebTime[j, i]][j, i]
                    for j in in_nbrs[i] if T + t - rebTime[j, i] in f_dict
                ) + sum(
                    g_dict[T + t - demandTime[j, i]][j, i]
                    for j in in_nbrs[i] if T + t - demandTime[j, i] in g_dict
                )
                mdl.addConstr(n[t + 1, i] == n[t, i] - outflow + planned + committed)
                # Outflow from a node cannot exceed its current stock.
                mdl.addConstr(n[t, i] - outflow >= 0)

        # Demand bounds, looked up per edge. The old positional list went out
        # of alignment whenever an edge had no demand entry.
        for o, d in pairs:
            mdl.addConstr(g[o, d, 0] <= demand_at(env.demand, (o, d), T))
            for t in range(1, h):
                mdl.addConstr(g[o, d, t] <= demand_at(demand_input, (o, d), T + t))

        mdl.setObjective(
            gp.quicksum(
                price_ls[T + t][k] * g[o, d, t]
                - cost_ls[T + t][k] * f[o, d, t] * beta
                - demandTime[o, d] * beta * g[o, d, t]
                for k, (o, d) in enumerate(pairs)
                for t in range(h)
            ),
            GRB.MAXIMIZE,
        )
        mdl.optimize()

        # Commit only the first period of the window.
        f_now = [f[o, d, 0].X for o, d in pairs]
        g_now = [g[o, d, 0].X for o, d in pairs]
        f_dict[T] = dict(zip(pairs, f_now))
        g_dict[T] = dict(zip(pairs, g_now))
        f_ls.append(f_now)
        g_ls.append(g_now)

        q0 = n.X[1, :]
        obj_b = 0
        for k, (o, d) in enumerate(pairs):
            obj_b += (price_ls[T][k] * g_now[k]
                      - cost_ls[T][k] * f_now[k] * beta
                      - demandTime[o, d] * beta * g_now[k])
        obj += obj_b

    return obj, f_ls, g_ls, q_ls


def Benchmark(env, num, N, cost_ls, price_ls, demandTime_ls, beta):
    """Solve one LP over all ``num`` periods with known demand; return its optimum.

    Unlike ``MPC``, every period is optimised jointly. The bound on ``g`` at
    time ``t`` is ``env.demand[(o, d)][t]``, so the result serves as a
    full-information reference value.

    Args:
        env: object exposing ``edges`` (list of ``(o, d)`` tuples), ``acc``
            (node -> {time: value}; the key-0 entries, in sorted node order,
            give the initial state) and ``demand`` ((o, d) -> {time: value};
            missing entries count as 0).
        num: number of periods.
        N: number of nodes; only node ids in ``range(N)`` enter the flow
            balance constraints.
        cost_ls: per-period edge costs aligned with ``env.edges``; needs at
            least ``num`` rows. Row 0 is also used as the travel time of
            ``f`` flows, so it must be integer-valued (it offsets time
            indices).
        price_ls: per-period edge prices aligned with ``env.edges``; needs at
            least ``num`` rows.
        demandTime_ls: per-period travel times of ``g`` flows, aligned with
            ``env.edges``; only row 0 is used.
        beta: weight on edge cost and on travel time in the objective.

    Returns:
        The optimal objective value (``mdl.ObjVal``).
    """
    pairs = env.edges
    nodes = range(N)
    # Precompute neighbour lists once instead of scanning the edge list inside
    # every constraint.
    edge_set = set(pairs)
    out_nbrs = {i: [j for j in nodes if (i, j) in edge_set] for i in nodes}
    in_nbrs = {i: [j for j in nodes if (j, i) in edge_set] for i in nodes}

    def demand_at(pair, time):
        """Return ``env.demand[pair][time]``, treating a missing entry as 0."""
        if pair in env.demand and time in env.demand[pair]:
            return env.demand[pair][time]
        return 0

    q0 = np.array([inner_dict[0] for _, inner_dict in sorted(env.acc.items())])

    # Travel times are taken from period 0 only and held fixed.
    # NOTE(review): the travel time of f-flows reuses the period-0 *cost*;
    # confirm this is intended.
    demandTime = {pair: demandTime_ls[0][k] for k, pair in enumerate(pairs)}
    rebTime = {pair: cost_ls[0][k] for k, pair in enumerate(pairs)}

    mdl = gp.Model('original_vector')
    mdl.setParam('OutputFlag', 0)

    f = mdl.addVars(pairs, num, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addVars(pairs, num, vtype=GRB.CONTINUOUS, name='g')
    n = mdl.addMVar(shape=(num + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

    if num > 0:
        mdl.addConstr(n[0, :] == q0)

    for t in range(num):
        for i in nodes:
            outflow = gp.quicksum(f[i, j, t] + g[i, j, t] for j in out_nbrs[i])
            inflow = gp.quicksum(
                f[j, i, t - rebTime[j, i]]
                for j in in_nbrs[i] if t - rebTime[j, i] >= 0
            ) + gp.quicksum(
                g[j, i, t - demandTime[j, i]]
                for j in in_nbrs[i] if t - demandTime[j, i] >= 0
            )
            mdl.addConstr(n[t + 1, i] == n[t, i] - outflow + inflow)
            # Outflow from a node cannot exceed its current stock.
            mdl.addConstr(n[t, i] - outflow >= 0)

        # Demand bounds, looked up per edge. The old positional list went out
        # of alignment whenever an edge had no demand entry.
        for o, d in pairs:
            mdl.addConstr(g[o, d, t] <= demand_at((o, d), t))

    mdl.setObjective(
        gp.quicksum(
            price_ls[t][k] * g[o, d, t]
            - cost_ls[t][k] * f[o, d, t] * beta
            - demandTime[o, d] * beta * g[o, d, t]
            for k, (o, d) in enumerate(pairs)
            for t in range(num)
        ),
        GRB.MAXIMIZE,
    )
    mdl.optimize()

    return mdl.ObjVal


