
from collections import defaultdict
import numpy as np
import subprocess
import os
import networkx as nx
from utils import mat2str
from copy import deepcopy
import json
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum
import pandas as pd
import torch
from collections import defaultdict
from torch_geometric.data import Data, Batch
import pickle
import random

class Scenario_History:
    """Trip-demand scenario built from a synthetic grid or from a JSON file.

    After construction the instance exposes, among others:

    * ``G``          - complete directed graph of regions (networkx)
    * ``edges``      - list of ``(origin, destination)`` pairs of ``G``
    * ``demandTime`` - per-pair, per-time-step passenger travel times
    * ``rebTime``    - per-pair, per-time-step rebalancing travel times
    * ``p``          - prices (one per pair in synthetic mode, one per pair
      and time step in JSON mode)
    * ``tripAttr``   - list of ``(origin, destination, t, demand, price)``
    """

    def __init__(self, N1=2, N2=4, tf=60, sd=None, ninit=5, tripAttr=None, demand_input=None, demand_ratio=None,
                 trip_length_preference=0.25, grid_travel_time=1, fix_price=True, alpha=0.2, json_file=None, json_hr=9, json_tstep=2, varying_time=False, json_regions=None, idx=0):
        """Build the scenario.

        Synthetic mode (``json_file is None``):
            trip_length_preference: positive - more shorter trips,
                negative - more longer trips.
            grid_travel_time: travel time between adjacent grid cells.
            demand_input:
                list        - total demand out of each region,
                float/int   - scales a per-region uniform random factor,
                dict        - total demand between pairs of regions
                              (``'default'`` is used for missing pairs).
              It is converted to ``static_demand`` (demand between each pair
              of nodes), which is then sampled from a Poisson distribution.
            demand_ratio: ``None`` (constant 1), a list of knots interpolated
                over ``[0, tf]``, or a dict of such lists keyed by pair with a
                ``'default'`` entry.
            alpha: demand levels are drawn uniformly from
                ``[1-alpha, 1+alpha] * demand_input``.

        JSON mode (``json_file`` given):
            demand_ratio: scalar multiplier applied to every demand record
                (``None`` is treated as 1).
            json_hr: starting hour; ``json_tstep``: minutes per time step.
            json_regions: optional collection of region ids to keep.
            varying_time: if true, rebalancing times follow the hourly
                records; otherwise only the ``json_hr`` record is used.

        ``sd`` seeds NumPy's global RNG; ``idx`` is forwarded to
        ``get_random_demand`` as the episode index.
        """
        self.sd = sd
        if sd is not None:
            np.random.seed(self.sd)
        self.varying_time = varying_time

        if json_file is None:
            self._init_synthetic(
                N1=N1, N2=N2, tf=tf, ninit=ninit, tripAttr=tripAttr,
                demand_input=demand_input, demand_ratio=demand_ratio,
                trip_length_preference=trip_length_preference,
                grid_travel_time=grid_travel_time, fix_price=fix_price,
                alpha=alpha, idx=idx)
        else:
            self._init_from_json(
                json_file=json_file, tf=tf, demand_ratio=demand_ratio,
                json_hr=json_hr, json_tstep=json_tstep,
                varying_time=varying_time, json_regions=json_regions, idx=idx)

    def _init_synthetic(self, N1, N2, tf, ninit, tripAttr, demand_input, demand_ratio,
                        trip_length_preference, grid_travel_time, fix_price, alpha, idx):
        """Set up an N1 x N2 grid scenario with Manhattan-distance travel times."""
        self.is_json = False
        self.alpha = alpha
        self.trip_length_preference = trip_length_preference
        self.grid_travel_time = grid_travel_time
        self.demand_input = demand_input
        self.fix_price = fix_price
        self.N1 = N1
        self.N2 = N2
        self.G = nx.complete_graph(N1 * N2).to_directed()
        self.demandTime = dict()
        self.rebTime = dict()
        self.edges = list(self.G.edges)
        for i, j in self.edges:
            travel_time = (abs(i // N1 - j // N1)
                           + abs(i % N1 - j % N1)) * grid_travel_time
            # Bind the value as a default argument: a plain closure over the
            # loop variables would make every pair report the travel time of
            # whichever pair the enclosing scope touched last.
            self.demandTime[i, j] = defaultdict(lambda tt=travel_time: tt)
            self.rebTime[i, j] = defaultdict(lambda tt=travel_time: tt)
        for n in self.G.nodes:
            self.G.nodes[n]['accInit'] = int(ninit)
        self.tf = tf

        # demand_ratio[i, j] holds 2*tf multipliers: tf interpolated values
        # followed by tf padding values.
        self.demand_ratio = defaultdict(list)
        if demand_ratio is None:
            for edge in self.edges:
                self.demand_ratio[edge] = [1] * (tf + tf)
        elif isinstance(demand_ratio, list):
            profile = list(np.interp(range(0, tf), np.arange(
                0, tf + 1, tf / (len(demand_ratio) - 1)), demand_ratio)) + [demand_ratio[-1]] * tf
            for edge in self.edges:
                self.demand_ratio[edge] = list(profile)
        else:
            for edge in self.edges:
                knots = demand_ratio[edge] if edge in demand_ratio else demand_ratio['default']
                self.demand_ratio[edge] = list(np.interp(range(0, tf), np.arange(
                    0, tf + 1, tf / (len(knots) - 1)), knots)) + [1] * tf

        if self.fix_price:
            # One random price per pair, proportional to (travel time + 1).
            self.p = defaultdict(dict)
            for i, j in self.edges:
                self.p[i, j] = (np.random.rand() * 2 + 1) * \
                    (self.demandTime[i, j][0] + 1)

        if tripAttr is not None:  # demand supplied by the caller
            self.tripAttr = deepcopy(tripAttr)
        else:  # randomly generated demand
            self.tripAttr = self.get_random_demand(idx)

    def _init_from_json(self, json_file, tf, demand_ratio, json_hr, json_tstep,
                        varying_time, json_regions, idx):
        """Load demand, prices, travel times and fleet size from ``json_file``."""
        self.is_json = True
        with open(json_file, "r") as file:
            data = json.load(file)
        self.tstep = json_tstep
        self.N1 = data["nlat"]
        self.N2 = data["nlon"]
        self.demand_input = defaultdict(dict)
        self.json_regions = json_regions

        if json_regions is not None:
            self.G = nx.complete_graph(json_regions)
        elif 'region' in data:
            self.G = nx.complete_graph(data['region'])
        else:
            self.G = nx.complete_graph(self.N1 * self.N2)
        self.G = self.G.to_directed()
        self.p = defaultdict(dict)
        self.alpha = 0
        self.demandTime = defaultdict(dict)
        self.rebTime = defaultdict(dict)
        self.json_start = json_hr * 60
        self.tf = tf
        self.edges = list(self.G.edges)

        def outside_selection(o, d):
            return json_regions is not None and (
                o not in json_regions or d not in json_regions)

        scale = 1 if demand_ratio is None else demand_ratio

        # Accumulate demand, demand-weighted price and demand-weighted travel
        # time per (origin, destination, time slot).
        for item in data["demand"]:
            t, o, d, v, tt, p = item["time_stamp"], item["origin"], item[
                "destination"], item["demand"], item["travel_time"], item["price"]
            if outside_selection(o, d):
                continue
            if (o, d) not in self.demand_input:
                self.demand_input[o, d], self.p[o, d], self.demandTime[o, d] = defaultdict(
                    float), defaultdict(float), defaultdict(float)
            slot = (t - self.json_start) // json_tstep
            self.demand_input[o, d][slot] += v * scale
            self.p[o, d][slot] += p * v * scale
            self.demandTime[o, d][slot] += tt * v * scale / json_tstep

        # Number of distinct time slots that carry any demand record.
        slots = set()
        for per_slot in self.demand_input.values():
            slots.update(per_slot)
        self.num = len(slots)

        # Turn the weighted sums into averages; travel time is at least one
        # step.  (Indexing demand_input here also registers an empty entry for
        # every edge without records, as the original code did.)
        for o, d in self.edges:
            for t in self.demand_input[o, d]:
                rate = self.demand_input[o, d][t]
                if rate:  # a zero rate has nothing to average over
                    self.p[o, d][t] /= rate
                    self.demandTime[o, d][t] /= rate
                self.demandTime[o, d][t] = max(
                    int(round(self.demandTime[o, d][t])), 1)

        for item in data["rebTime"]:
            hr, o, d, rt = item["time_stamp"], item["origin"], item["destination"], item["reb_time"]
            if outside_selection(o, d):
                continue
            steps = max(int(round(rt / json_tstep)), 1)
            if varying_time:
                t0 = int((hr * 60 - self.json_start) // json_tstep)
                t1 = int((hr * 60 + 60 - self.json_start) // json_tstep)
                for t in range(t0, t1):
                    self.rebTime[o, d][t] = steps
            elif hr == json_hr:
                for t in range(0, self.num):
                    self.rebTime[o, d][t] = steps

        # Initial vehicles: the fleet size recorded for the target hour is
        # split evenly over the regions.
        target_hr = json_hr + int(round(json_tstep / 2 * tf / 60))
        for item in data["totalAcc"]:
            hr, acc = item["hour"], item["acc"]
            if hr == target_hr:
                for n in self.G.nodes:
                    self.G.nodes[n]['accInit'] = int(acc / len(self.G))

        self.tripAttr = self.get_random_demand(idx)

    def get_random_demand(self, idx=0, reset=False):
        """Sample demand and prices for every pair and time step.

        Args:
            idx: episode index; every emitted time stamp is shifted by
                ``60 * idx``.
            reset: accepted for interface compatibility; not used here.

        Returns:
            list of ``(origin, destination, t, demand, price)`` tuples.

        Raises:
            TypeError: in synthetic mode, if ``demand_input`` is neither a
                number, an array-like nor a dictionary.
        """
        print('idx:', idx)

        demand = defaultdict(dict)
        price = defaultdict(dict)
        tripAttr = []
        offset = 60 * idx

        if self.is_json:
            for i, j in self.edges:
                for t in range(self.num):
                    if (i, j) in self.demand_input and t in self.demand_input[i, j]:
                        demand[i, j][t] = np.random.poisson(
                            self.demand_input[i, j][t])
                        price[i, j][t] = self.p[i, j][t]
                    else:
                        demand[i, j][t] = 0
                        price[i, j][t] = 0
                    tripAttr.append(
                        (i, j, t + offset, demand[i, j][t], price[i, j][t]))
            return tripAttr

        # Synthetic mode: convert demand_input to static_demand.
        self.static_demand = dict()
        region_rand = (np.random.rand(len(self.G))
                       * self.alpha * 2 + 1 - self.alpha)
        if isinstance(self.demand_input, (float, int, list, np.ndarray)):
            if isinstance(self.demand_input, (float, int)):
                self.region_demand = region_rand * self.demand_input
            else:
                self.region_demand = region_rand * \
                    np.array(self.demand_input)
            for i in self.G.nodes:
                J = [j for _, j in self.G.out_edges(i)]
                # Destination choice decays exponentially with travel time.
                prob = np.array(
                    [np.exp(-self.rebTime[i, j][0] * self.trip_length_preference) for j in J])
                prob = prob / sum(prob)
                for k, j in enumerate(J):
                    self.static_demand[i, j] = self.region_demand[i] * prob[k]
        elif isinstance(self.demand_input, dict):
            for i, j in self.edges:
                self.static_demand[i, j] = self.demand_input[i, j] if (
                    i, j) in self.demand_input else self.demand_input['default']
                self.static_demand[i, j] *= region_rand[i]
        else:
            raise TypeError(
                "demand_input should be number, array-like, or dictionary-like values")

        # Generate demand and prices for the 2*tf steps covered by
        # demand_ratio.
        for i, j in self.edges:
            for t in range(0, self.tf * 2):
                demand[i, j][t] = np.random.poisson(
                    self.static_demand[i, j] * self.demand_ratio[i, j][t])
                if self.fix_price:
                    price[i, j][t] = self.p[i, j]
                else:
                    price[i, j][t] = min(3, np.random.exponential(
                        2) + 1) * self.demandTime[i, j][t]
                tripAttr.append(
                    (i, j, t + offset, demand[i, j][t], price[i, j][t]))

        return tripAttr

#%%
# Active city. Generate_Data uses it to pick the scenario JSON file and the
# matching entries of demand_ratio and json_hr below.
city = 'nyc_brooklyn'

# Per-city multiplier passed to Scenario_History as demand_ratio.
demand_ratio = {
    'san_francisco': 2,
    'washington_dc': 4.2,
    'nyc_brooklyn': 9,
    'rome': 1.8,
    'shenzhen_downtown_west': 2.5,
}

# Per-city starting hour passed to Scenario_History as json_hr.
json_hr = {
    'san_francisco': 19,
    'washington_dc': 19,
    'nyc_brooklyn': 19,
    'rome': 8,
    'shenzhen_downtown_west': 8,
}

# Per-city beta values. Not read by the functions visible in this file
# (MPC and benchmark_policy define their own local beta = 0.3).
beta = {
    'san_francisco': 0.2,
    'washington_dc': 0.5,
    'nyc_brooklyn': 0.5,
    'porto': 0.1,
    'rome': 0.1,
    'shenzhen_downtown_west': 0.5,
}

# Per-city time step; presumably minutes per step for evaluation - TODO confirm.
test_tstep = {
    'san_francisco': 3,
    'nyc_brooklyn': 3,
    'shenzhen_downtown_west': 3,
}
    
def MPC(lambda_ls, cost_ls, price_ls, demandTime, rebTime, demand_input, pairs, N, q0, num, pos):
    """Receding-horizon control: solve one LP per step, keep its first action.

    For every local step ``T`` in ``range(num - h)`` (horizon ``h = 10``) an
    LP over steps ``T .. T+h-1`` is solved with Gurobi. Only the decisions of
    the first step are committed; the vehicle counts after that step become
    the initial state of the next solve.

    Indexing convention: ``lambda_ls``, ``cost_ls`` and ``price_ls`` are
    indexed by the *local* step (0 .. num-1), whereas ``demandTime``,
    ``rebTime`` and ``demand_input`` are indexed by the *global* step
    ``pos + local step``.

    Args:
        lambda_ls: ``lambda_ls[T][m]`` caps the served trips of ``pairs[m]``
            at the first step of solve ``T``.
        cost_ls: ``cost_ls[T][m]`` is the rebalancing cost of ``pairs[m]``.
        price_ls: ``price_ls[T][m]`` is the trip price of ``pairs[m]``.
        demandTime: ``demandTime[o, d][t]`` passenger travel time in steps.
        rebTime: ``rebTime[o, d][t]`` rebalancing travel time in steps.
        demand_input: ``demand_input[o, d][t]`` caps served trips at the
            later steps of each horizon.
        pairs: list of unique ``(origin, destination)`` tuples; node ids are
            assumed to be ``0 .. N-1``.
        N: number of nodes.
        q0: initial vehicle count per node (length ``N``).
        num: number of local steps available in the ``*_ls`` inputs.
        pos: global index of local step 0.

    Returns:
        ``(f_list, g_list, q_list, obj)``: per step, the committed
        rebalancing flows and passenger flows (arrays ordered like
        ``pairs``), the vehicle counts at the start of the step, and the
        accumulated realized objective.
    """
    beta = 0.3  # weight on rebalancing cost and passenger travel time
    h = 10      # look-ahead horizon in steps

    E = pairs
    # Per-node neighbour lists (ascending node id) replace repeated
    # O(len(pairs)) membership tests against the pair list.
    edge_set = set(pairs)
    inbound = [[j for j in range(N) if (j, i) in edge_set] for i in range(N)]
    outbound = [[j for j in range(N) if (i, j) in edge_set] for i in range(N)]

    cost_dict = defaultdict(dict)
    price_dict = defaultdict(dict)
    for step in range(num):
        for m, od in enumerate(pairs):
            cost_dict[step][od] = cost_ls[step][m]
            price_dict[step][od] = price_ls[step][m]

    obj = 0
    f_dict = defaultdict(dict)  # committed rebalancing flows, keyed by local step
    g_dict = defaultdict(dict)  # committed passenger flows, keyed by local step
    q_list = []
    for T in range(0, num - h):
        q_list.append(q0)
        print('T:', T)
        mdl = gp.Model('original_vector' + str(T))
        mdl.setParam('OutputFlag', 0)

        f = mdl.addVars(E, h, vtype=GRB.CONTINUOUS, name='f')
        g = mdl.addVars(E, h, vtype=GRB.CONTINUOUS, name='g')
        n = mdl.addMVar(shape=(h + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

        mdl.addConstr(n[0, :] == q0)

        for t in range(0, h):
            t0 = pos + T + t  # global time of horizon step t
            for i in range(N):
                # Arrivals from flows decided inside this horizon.
                summ1 = qsum(f[j, i, t - rebTime[j, i][t0]] for j in inbound[i] if t - rebTime[j, i][t0] >= 0) + \
                    qsum(g[j, i, t - demandTime[j, i][t0]] for j in inbound[i] if t - demandTime[j, i][t0] >= 0)
                # Arrivals from flows committed by earlier solves.
                summ2 = sum(f_dict[T + t - rebTime[j, i][t0]][(j, i)] for j in inbound[i] if T + t - rebTime[j, i][t0] in f_dict) + \
                    sum(g_dict[T + t - demandTime[j, i][t0]][(j, i)] for j in inbound[i] if T + t - demandTime[j, i][t0] in g_dict)

                out_f = gp.quicksum(f[i, j, t] for j in outbound[i])
                out_g = gp.quicksum(g[i, j, t] for j in outbound[i])

                # Flow conservation.
                mdl.addConstr(n[t + 1, i] == n[t, i] - out_f - out_g + summ1 + summ2)
                # Cannot dispatch more vehicles than are present.
                mdl.addConstr(n[t, i] - out_f - out_g >= 0)

        # Demand caps: the current step uses lambda_ls, later steps use
        # demand_input at the matching global time.
        for m, (i, j) in enumerate(E):
            mdl.addConstr(g[i, j, 0] <= float(lambda_ls[T][m]))
            for t in range(1, h):
                mdl.addConstr(g[i, j, t] <= demand_input[i, j][pos + T + t])

        mdl.setObjective(
            gp.quicksum(
                price_dict[T + t][o, d] * g[o, d, t]
                - cost_dict[T + t][o, d] * f[o, d, t] * beta
                - demandTime[o, d][pos + T + t] * beta * g[o, d, t]
                for (o, d) in pairs for t in range(0, h)),
            GRB.MAXIMIZE)
        mdl.optimize()

        # Commit only the first-step decisions.
        for (o, d) in pairs:
            f_dict[T][(o, d)] = f[o, d, 0].X
            g_dict[T][(o, d)] = g[o, d, 0].X

        q0 = n.X[1, :]

        obj_b = 0
        for o, d in pairs:
            obj_b += price_dict[T][o, d] * g[o, d, 0].X - cost_dict[T][o, d] * f[o, d, 0].X * beta - \
                demandTime[o, d][pos + T] * beta * g[o, d, 0].X
        obj += obj_b

    f_list = [np.array([f_dict[step][od] for od in pairs]) for step in range(num - h)]
    g_list = [np.array([g_dict[step][od] for od in pairs]) for step in range(num - h)]

    return f_list, g_list, q_list, obj


def benchmark_policy(lambda_ls, cost_ls, price_ls, demandTime, rebTime, pairs, N, q0, num, idx):
    """Solve one LP over all ``num`` steps at once.

    ``lambda_ls``, ``cost_ls`` and ``price_ls`` are indexed by the local step
    (0 .. num-1); ``demandTime`` and ``rebTime`` are indexed by the global
    step ``t + idx``.

    Args:
        lambda_ls: ``lambda_ls[t][m]`` caps the served trips of ``pairs[m]``.
        cost_ls: ``cost_ls[t][m]`` is the rebalancing cost of ``pairs[m]``.
        price_ls: ``price_ls[t][m]`` is the trip price of ``pairs[m]``.
        demandTime: ``demandTime[o, d][t]`` passenger travel time in steps.
        rebTime: ``rebTime[o, d][t]`` rebalancing travel time in steps.
        pairs: list of unique ``(origin, destination)`` tuples; node ids are
            assumed to be ``0 .. N-1``.
        N: number of nodes.
        q0: initial vehicle count per node (length ``N``).
        num: number of steps.
        idx: global index of local step 0.

    Returns:
        ``(f_list, g_list, q_list, objective)``: per step, the rebalancing
        and passenger flows (arrays ordered like ``pairs``), the vehicle
        counts at the start of the step, and the optimal objective value.
    """
    beta = 0.3  # weight on rebalancing cost and passenger travel time

    E = pairs
    edge_set = set(pairs)
    inbound = [[j for j in range(N) if (j, i) in edge_set] for i in range(N)]
    outbound = [[j for j in range(N) if (i, j) in edge_set] for i in range(N)]

    cost_dict = defaultdict(dict)
    price_dict = defaultdict(dict)
    for step in range(num):
        for m, od in enumerate(pairs):
            cost_dict[step][od] = cost_ls[step][m]
            price_dict[step][od] = price_ls[step][m]

    mdl = gp.Model('original_vector')

    f = mdl.addVars(E, num, vtype=GRB.CONTINUOUS, name='f')
    g = mdl.addVars(E, num, vtype=GRB.CONTINUOUS, name='g')
    n = mdl.addMVar(shape=(num + 1, N), lb=0, vtype=GRB.CONTINUOUS, name='n')

    mdl.addConstr(n[0, :] == q0)

    for t in range(0, num):
        t0 = t + idx  # global time of local step t

        for i in range(N):
            out_f = gp.quicksum(f[i, j, t] for j in outbound[i])
            out_g = gp.quicksum(g[i, j, t] for j in outbound[i])
            in_f = gp.quicksum(f[j, i, t - rebTime[j, i][t0]]
                               for j in inbound[i] if t - rebTime[j, i][t0] >= 0)
            # The departure step must be non-negative; a bare truthiness test
            # here would drop zero-lag arrivals and index negative steps.
            in_g = gp.quicksum(g[j, i, t - demandTime[j, i][t0]]
                               for j in inbound[i] if t - demandTime[j, i][t0] >= 0)

            # Flow conservation.
            mdl.addConstr(n[t + 1, i] == n[t, i] - out_f - out_g + in_f + in_g)
            # Cannot dispatch more vehicles than are present.
            mdl.addConstr(n[t, i] - out_f - out_g >= 0)

        for m, (i, j) in enumerate(E):
            mdl.addConstr(g[i, j, t] <= float(lambda_ls[t][m]))

    mdl.setObjective(
        gp.quicksum(
            price_dict[t][o, d] * g[o, d, t]
            - cost_dict[t][o, d] * f[o, d, t] * beta
            - demandTime[o, d][t + idx] * beta * g[o, d, t]
            for (o, d) in pairs for t in range(0, num)),
        GRB.MAXIMIZE)
    mdl.optimize()

    f_list = [np.array([f[i, j, step].X for (i, j) in pairs]) for step in range(num)]
    g_list = [np.array([g[i, j, step].X for (i, j) in pairs]) for step in range(num)]
    q_list = [row for row in n.X[:num, :]]
    print('objective value:', mdl.ObjVal)

    return f_list, g_list, q_list, mdl.ObjVal


def Generate_History(tripAttr, rebTime, demandTime, demand_input, N, acc, num):
    """Run MPC episode by episode and collect aligned per-step histories.

    Trip records with ``t < num`` are binned per step. Prices and passenger
    travel times are replaced by their rounded per-pair mean over all steps
    before being handed to ``MPC``. The horizon is cut into episodes of 60
    steps, each started from the same even split of ``acc`` vehicles.

    Note: ``demandTime`` and ``demand_input`` are mutated in place - missing
    ``[pair][t]`` entries are filled with 0.

    Args:
        tripAttr: iterable of ``(origin, destination, t, demand, price)``.
        rebTime: ``rebTime[o, d][t]`` rebalancing time, also used as cost.
        demandTime: ``demandTime[o, d][t]`` passenger travel time.
        demand_input: ``demand_input[o, d][t]`` demand cap passed to ``MPC``.
        N: number of nodes.
        acc: total number of vehicles.
        num: total number of steps.

    Returns:
        ``(F_list, G_list, Q_list, lbd, price, cost, dT, pairs)``. The first
        seven lists have one entry per step for which ``MPC`` returned a
        decision, so equal indices refer to the same step.
    """
    lambda_ls = [[] for _ in range(num)]
    price_ls = [[] for _ in range(num)]
    cost_ls = [[] for _ in range(num)]
    demandTime_ls = [[] for _ in range(num)]

    pairs = []
    seen_pairs = set()  # O(1) membership; `pairs` keeps first-seen order
    for i, j, t, d, p in tripAttr:
        if t >= num:
            continue
        lambda_ls[t].append(d)
        price_ls[t].append(p)
        if (i, j) not in seen_pairs:
            seen_pairs.add((i, j))
            pairs.append((i, j))
        cost_ls[t].append(rebTime[i, j][t])
        if (i, j) not in demandTime or t not in demandTime[i, j]:
            demandTime[i, j][t] = 0
        if (i, j) not in demand_input or t not in demand_input[i, j]:
            demand_input[i, j][t] = 0
        demandTime_ls[t].append(demandTime[i, j][t])

    # Time-invariant price per pair: rounded mean over all steps.
    price_avg = np.round(np.mean(np.stack(price_ls), axis=0))
    price_ls = [price_avg for _ in range(num)]

    # Time-invariant passenger travel time per pair, same construction.
    tt_avg = np.round(np.mean(np.stack(demandTime_ls), axis=0))
    demandTime_avg = defaultdict(dict)
    for ind, (i, j) in enumerate(pairs):
        for t in range(num):
            demandTime_avg[i, j][t] = tt_avg[ind]

    demandTime_ls = [tt_avg for _ in range(num)]
    demandTime = demandTime_avg

    q0 = np.ones(N) * int(acc / N)
    num_steps = 60
    num_episode = num // num_steps
    F_list, G_list, Q_list = [], [], []
    lbd, price, cost, dT = [], [], [], []
    for ep in range(num_episode):
        start = num_steps * ep
        stop = start + num_steps
        f_list, g_list, q_list, obj = MPC(lambda_ls[start:stop], cost_ls[start:stop], price_ls[start:stop],
                                          demandTime, rebTime, demand_input, pairs, N, q0, num_steps, start)
        print('obj of MPC:', obj)
        F_list = F_list + f_list
        G_list = G_list + g_list
        Q_list = Q_list + q_list

        # MPC returns fewer steps than the episode length (its look-ahead
        # horizon is cut from the tail), so keep exactly as many input steps
        # as there are decisions; otherwise the lists drift apart from the
        # second episode on.
        kept = start + len(f_list)
        lbd = lbd + lambda_ls[start:kept]
        price = price + price_ls[start:kept]
        cost = cost + cost_ls[start:kept]
        dT = dT + demandTime_ls[start:kept]

    return F_list, G_list, Q_list, lbd, price, cost, dT, pairs


def _merge_shifted(target, source, offset):
    """Copy ``source[pair][t]`` into ``target[pair][t + offset]`` for every pair."""
    for pair, by_step in source.items():
        shifted = {step + offset: value for step, value in by_step.items()}
        target.setdefault(pair, {}).update(shifted)


def Generate_Data(num_episode):
    """Sample ``num_episode`` scenarios of the module-level ``city`` and chain them.

    Episode ``i`` is built with seed ``i + 10`` and ``idx = i``; its
    per-step tables are shifted by ``60 * i`` steps so that the merged tables
    share one global time axis with the concatenated trip list.

    Args:
        num_episode: number of scenarios to generate; 0 yields empty results.

    Returns:
        ``(tripAttr_ls, rebTime_all, demandTime_all, demand_input_all)``:
        the concatenated trip tuples and the merged ``[pair][step]`` tables.
    """
    episode_len = 60  # step offset between consecutive episodes

    tripAttr_ls = []
    rebTime_all = defaultdict(dict)
    demandTime_all = defaultdict(dict)
    demand_input_all = defaultdict(dict)
    for i in range(num_episode):
        scenario_history = Scenario_History(json_file=f"data/scenario_{city}.json", demand_ratio=demand_ratio[city],
                                            json_hr=json_hr[city], sd=i+10, json_tstep=3, tf=60, idx=i)
        tripAttr_ls.extend(scenario_history.tripAttr)

        if i == 0:
            # The first episode defines the tables (and their container types).
            rebTime_all = deepcopy(scenario_history.rebTime)
            demandTime_all = deepcopy(scenario_history.demandTime)
            demand_input_all = deepcopy(scenario_history.demand_input)
        else:
            offset = episode_len * i
            _merge_shifted(rebTime_all, scenario_history.rebTime, offset)
            _merge_shifted(demandTime_all, scenario_history.demandTime, offset)
            _merge_shifted(demand_input_all, scenario_history.demand_input, offset)

    return tripAttr_ls, rebTime_all, demandTime_all, demand_input_all




