import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from KServerEnv import KServerEnv 
from Policies.DQN_10 import DQNAgent_10
from Policies.DQN_Lins import DQNAgent_Lins
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Qtable_WQL import Qtable_WQL
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.Harmonic import HarmonicPolicy
from Policies.Balance import BalancePolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.GCN_RL_SL import GCN_RL_SL
from Policies.GCN_RL_GEN_ALL_SL import GCN_RL_GEN_ALL_SL
from Policies.GCN_RL_GEN_ALL import GCN_RL_GEN_ALL
from experiments.train_agent import train
from generate_requests import generate_requests
import os 
import numpy as np 
import csv


from multiprocessing import cpu_count
import psutil
import torch

from Policies.Random_Greedy import RandomPolicy, GreedyPolicy

# from WFA import HarmonicPolicy

from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from KServerEnv import KServerEnv 
from multiprocessing import cpu_count

import argparse

parser = argparse.ArgumentParser()

parser.add_argument("--save", action='store_true', help="Whether to save the model and results")
parser.add_argument("--num_steps", type=int, default=1, help="Number steps multiplied by 1000")
parser.add_argument("--num_nodes", type=int, default=9, help="Number of nodes in the graph")
parser.add_argument("--num_servers", type=int, default=None, help="Number of servers in the graph")
parser.add_argument("--bs", type=int, default=128, help="Batch Size")
parser.add_argument("--hch", type=int, default=128, help="Hidden Channels")
parser.add_argument("--nl", type=int, default=12, help="Number of Layers")
parser.add_argument("--display", action='store_true', help="Display results in Neptune")
parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate")
parser.add_argument("--gamma", type=float, default=0.99, help="Temporal Discount")
parser.add_argument("--print", action='store_true', help="Whether to print the experiments results during the training")
parser.add_argument("--uniform_random", action='store_true', help="If probabilities of arrival of requests are uniform")
parser.add_argument("--parallel", action='store_true', help="If to run the processes in parallel")
parser.add_argument("--var_distance", action='store_true', help="If to set a varying distance between nodes")
parser.add_argument("--tr_dist", action='store_true', help="If to set a trainable distance")
parser.add_argument("--tr_att", action='store_true', help="If to use an attention mechanism for trainable distance")
parser.add_argument("--agent", type=str, default='greedy', help="Which agent")
parser.add_argument("--graph_type", type=str, default='grid_gre_50', help="Which graph type")
parser.add_argument("--gmgt", type=str, default='grid_gre', help="Which graph type for general model")
parser.add_argument("--device", type = str, default = "cpu", help="Which device to use")
parser.add_argument("--dir_graph", action='store_true', help="If it is a directed graph")
parser.add_argument("--test_seq", action='store_true', help="If test a sequence")
parser.add_argument("--optimize", action='store_true', help="Optimize the algorithm")
parser.add_argument("--print_seq", action='store_true', help="Print where it is going to be tested")
# store_false: batch normalization is ON by default; passing --bn turns it OFF.
parser.add_argument("--bn", action='store_false', help="Don't use batch normalization")
parser.add_argument("--cp", action='store_true', help="Constant Probability")


args = parser.parse_args()

# Unpack the CLI options into the module-level names used throughout the script.
# Each option is bound exactly once (the previous duplicate re-assignments of
# dir_graph, var_distance and gamma were redundant).
device = args.device
num_steps = args.num_steps
save = args.save
display_results = args.display
gamma = args.gamma
print_results = args.print
uniform_random = args.uniform_random
var_distance = args.var_distance
tr_dist = args.tr_dist
tr_att = args.tr_att
num_nodes = args.num_nodes
graph_type = args.graph_type
dir_graph = args.dir_graph
use_batch_norm = args.bn  # True unless --bn was passed (see store_false above)
hidden_channels = args.hch
lr = args.lr
# May later be forced to 1 for agents that do not run batched environments.
batch_size = args.bs



# CLI --agent name -> policy class.
agent_mapping = {
    "dqn": DQNAgent_10,
    "dqn_lins": DQNAgent_Lins,
    "qtable": Qtable,
    "random": RandomPolicy, 
    "greedy": GreedyPolicy,
    "gcn_sl": GCN_SL,
    "gcn_rl": GCN_RL,
    "qtable_wql": Qtable_WQL,
    "gcn_rl_sl": GCN_RL_SL,
    "gcn_rl_gen": GCN_RL_GEN_ALL,
    "gcn_sl_gen": GCN_RL_GEN_ALL_SL,
    "balance": BalancePolicy,
    "harmonic": HarmonicPolicy
}

# Agents that keep the user-supplied --bs; every other agent runs with batch_size = 1.
# (Previously a missing comma fused "gcn_sl_gen" and "gcn_rl_sl" into one string,
# so those two agents were silently forced to batch_size = 1.)
BATCHED_AGENTS = {"dqn", "dqn_lins", "gcn_sl", "gcn_rl", "gcn_rl_gen", "gcn_sl_gen", "gcn_rl_sl"}


max_cpu_usage_percent = 50
# Samples system-wide CPU usage over 1 second (blocking) to estimate idle cores.
cpu_usage_percent = psutil.cpu_percent(interval=1)
free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
print("Number of free CPUs:", free_cpus)
print(f"Device Type: {device}")
print(f"Gamma: {gamma}")
print(f"Number of hidden channels: {hidden_channels}")

if args.agent not in BATCHED_AGENTS:
     batch_size = 1

if args.agent not in agent_mapping:
     # Fail with a usage message listing valid names instead of a bare KeyError.
     parser.error(f"unknown --agent {args.agent!r}; choose from {sorted(agent_mapping)}")
Agent = agent_mapping[args.agent]

if args.num_servers is not None:
     num_servers = args.num_servers
else:
     # Default: roughly one server per six nodes. round() uses banker's rounding
     # (e.g. 3/6 = 0.5 -> 0), so clamp to at least one server.
     num_servers = max(1, round(num_nodes/6))

# if  graph_type =='toy':
#      num_nodes = 15
#      num_servers =  2





# Main environment used for training / evaluation, sized by the CLI options.
env = KServerEnv(
     num_nodes,
     num_servers,
     batch_size=batch_size,
     graph_type=graph_type,
     device=device,
     uniform_random=uniform_random,
     request_same_node=True,
     arrival_rates=True,
     var_distance=var_distance,
)

# Report the configuration as seen by the environment object itself.
print(f"Graph Type: {env.graph_type}")
print(f"Number of nodes: {env.num_nodes}")
print(f"Num Servers: {env.num_servers}")
print(f"Batch Size: {batch_size}")

if Agent in (GCN_RL_GEN_ALL, GCN_RL_GEN_ALL_SL):
     # Named general-model graph types are used verbatim; otherwise the general
     # model's graph type is graph_type without its last three characters
     # (e.g. 'grid_gre_50' -> 'grid_gre').
     general_model_gt = args.gmgt if args.gmgt in ('dir_check', 'dir_check_1', 'SF', 'EM') else graph_type[:-3]


# if  graph_type =='toy
#      general_model_gt = 'toy'



if args.optimize: 
     # Train (or, for heuristic / tabular policies, evaluate) the selected agent.
     if Agent in (GCN_RL, GCN_RL_SL):
          print(f"Directed Graph Model Used: {dir_graph}")
          agent = Agent(env, hidden_channels = hidden_channels, num_layers = args.nl, lr = lr, shared_weights = False, gamma = gamma, tr_dist = tr_dist, tr_att = tr_att, dir_graph = dir_graph, use_batch_norm = use_batch_norm)
          print(f"Learning Rate: {agent.lr}")
          print(f"Total Parameters: {agent.total_params()}")
          agent.optimize(num_steps, print_results = print_results, save_results = save, display_results = display_results)

     elif Agent in (GCN_RL_GEN_ALL, GCN_RL_GEN_ALL_SL):
          # General models build their own environments, so no `env` is passed.
          agent = Agent(
               hidden_channels=hidden_channels,
               num_layers = args.nl,
               lr = lr,
               shared_weights = False,
               general_model_gt = general_model_gt,
               batch_size = batch_size,
               uniform_random = uniform_random,
               constant_probability = args.cp,
               arrival_rates = True,
               request_same_node = True,
               gamma = gamma,
               var_pr_ep=True,
               var_pr_ep_steps=30,
               use_batch_norm = use_batch_norm,
               var_distance=var_distance,
               dir_graph = dir_graph,
               tr_dist = tr_dist,
               tr_att = tr_att,
               num_nodes = num_nodes,
               num_servers = num_servers,
          )
          print(f"Total Parameters: {agent.total_params()}")
          agent.optimize(num_steps, print_results = print_results, save_model = save, save_results= save, display_results = display_results)

     elif Agent in (DQNAgent_10, DQNAgent_Lins):
          # NOTE(review): uses a fixed lr = 0.01, not --lr.
          agent = Agent(env, seed = 42, lr = 0.01, gamma = gamma)
          print(f"Total Parameters: {agent.total_params()}")
          agent.optimize(num_steps, print_results = print_results, save_results = save, display_results = display_results)

     elif Agent in (GreedyPolicy, RandomPolicy):
          agent = Agent(env)
          print(f'{args.agent} estimate is {agent.estimate(100, print_results=print_results)[0].item()}') 

     elif Agent is Qtable:
          agent = Qtable(env, lr = 0.1, gamma = gamma)
          agent.optimize(num_steps, print_results = print_results, save_results = save, display_results = display_results, explr= 0.8)

     elif Agent is Qtable_WQL:
          agent = Qtable_WQL(env, lr = 0.1, gamma = gamma)
          agent.compare_policies()

     else:
          # e.g. 'balance' / 'harmonic': previously this fell through silently.
          print(f'--optimize is not supported for agent {args.agent!r}; nothing was run.')
          

def ep_estimates_gr_opton(Agent, env, num_sequences, save = True, *, states = None, request_seqs = None, seq_len = None):
     """Compute per-step cost estimates of a baseline policy on fixed request sequences.

     Builds ``Agent(env)``. For each of the first ``num_sequences`` sequences, it calls
     ``agent.estimate_seq(states[i].unsqueeze(0), request_seqs[i].unsqueeze(0))``. It takes
     element 0 of the result and divides it by ``seq_len`` to obtain a per-step value.

     Args:
          Agent: policy class, constructed as ``Agent(env)`` (e.g. GreedyPolicy, Qtable_WQL).
          env: environment passed to the policy constructor; its ``num_nodes`` and
               ``num_servers`` appear in the output file name.
          num_sequences: number of sequences to evaluate.
          save: if True, write the estimates (one value per line) to
               ``results/gen_testing/VD{var_distance}/{agent_name}/results_{graph_type}_{agent_name}_{num_nodes}_{num_servers}.csv``
               using the module-level ``var_distance`` and ``graph_type``.
          states: per-sequence initial states; defaults to the module-level ``state``.
          request_seqs: per-sequence requests; defaults to the module-level ``requests``.
          seq_len: normalizing sequence length; defaults to the module-level ``sequence_len``.

     Returns:
          list of float: the per-step estimate of each evaluated sequence.
     """
     # Fall back to the globals created by the --test_seq branch so existing
     # call sites keep working unchanged.
     if states is None:
          states = state
     if request_seqs is None:
          request_seqs = requests
     if seq_len is None:
          seq_len = sequence_len

     agent = Agent(env)
     # "<class 'pkg.module.Name'>" -> "Name"
     agent_name = str(Agent).split("'")[1].split(".")[-1]
     print(f'Estimates for {agent_name} are starting to be computed')

     estimates = []
     for i in range(num_sequences):
          print(f'sequence{i} started')
          estimate = agent.estimate_seq(states[i].unsqueeze(0), request_seqs[i].unsqueeze(0))[0]
          estimates.append((estimate/seq_len).item())

     if save:
          out_dir = f'results/gen_testing/VD{var_distance}/{agent_name}'
          os.makedirs(out_dir, exist_ok=True)
          output_file_name = f'{out_dir}/results_{graph_type}_{agent_name}_{env.num_nodes}_{env.num_servers}.csv'
          np.savetxt(output_file_name, np.array(estimates), delimiter=',')

     return estimates

def _baseline_csv(baseline_name, env):
     """Return the CSV path where per-sequence estimates of a baseline policy are cached."""
     return f'results/gen_testing/VD{var_distance}/{baseline_name}/results_{graph_type}_{baseline_name}_{env.num_nodes}_{env.num_servers}.csv'


def _load_baseline_estimates(env, num_sequences):
     """Load the Optimal Online (Qtable_WQL) and Greedy per-sequence estimates.

     If either cached CSV is missing or unreadable, both baselines are recomputed
     with ``ep_estimates_gr_opton`` on a fresh batch_size=1 environment and saved,
     then re-read.

     Returns:
          (opt_estimate, greedy_estimate): tensors on ``device`` reshaped to (1, -1).
     """
     try:
          opt_estimate = np.loadtxt(_baseline_csv('Qtable_WQL', env), delimiter=',')
          greedy_estimate = np.loadtxt(_baseline_csv('GreedyPolicy', env), delimiter=',')
     except (OSError, ValueError):
          # Missing file (OSError) or malformed content (ValueError); a bare
          # `except:` here previously also swallowed KeyboardInterrupt.
          print('Estimates for Greedy and Optimal Online have to be computed...')
          env = KServerEnv(num_nodes, num_servers, batch_size=1, graph_type=graph_type, device = device, uniform_random = uniform_random,
               request_same_node=True, arrival_rates=True, var_distance = var_distance)
          # Distinct loop name: the old `for Agent in ...` overwrote the
          # module-level `Agent` class selected from the CLI.
          for baseline in (GreedyPolicy, Qtable_WQL):
               ep_estimates_gr_opton(baseline, env, num_sequences)
          opt_estimate = np.loadtxt(_baseline_csv('Qtable_WQL', env), delimiter=',')
          greedy_estimate = np.loadtxt(_baseline_csv('GreedyPolicy', env), delimiter=',')

     opt_estimate = torch.tensor(opt_estimate).to(device).view(1, -1)
     greedy_estimate = torch.tensor(greedy_estimate).to(device).view(1, -1)
     return opt_estimate, greedy_estimate


if args.test_seq: 
     # Evaluate on a fixed, seeded set of request sequences and compare with the
     # Greedy and Optimal Online (Qtable_WQL) baselines.
     num_sequences = 10
     sequence_len = 4000
     requests, state = generate_requests(env, 42, sequence_len, num_sequences=num_sequences)

     if args.print_seq: 
          # Most frequent requested value in each sequence (row).
          most_frequent_elements = [torch.mode(row).values.item() for row in requests]
          print("The most frequent elements for each row are:", most_frequent_elements)
          index_of_largest_value = env.probabilities.index(max(env.probabilities))
          print("The index of the largest value is:", index_of_largest_value)
          print(state)

     # "<class 'pkg.module.Name'>" -> "Name"
     agent_name = str(Agent).split("'")[1].split(".")[-1]
     if Agent in (GCN_RL, GCN_RL_SL, DQNAgent_10, DQNAgent_Lins, GCN_RL_GEN_ALL, GCN_RL_GEN_ALL_SL):
          env = KServerEnv(num_nodes, num_servers, batch_size=1, graph_type=graph_type, device = device, uniform_random = uniform_random,
               request_same_node=True, arrival_rates=True, var_distance = var_distance)
          if Agent in (DQNAgent_10, DQNAgent_Lins):
               agent = Agent(env, seed = 42, lr = 0.01, gamma = gamma)
          elif Agent in (GCN_RL_GEN_ALL, GCN_RL_GEN_ALL_SL): 
               # Parameter count is taken from a GCN_RL_GEN_ALL instance; the saved
               # weights are evaluated through a GCN_RL on this environment.
               # NOTE(review): the count also uses GCN_RL_GEN_ALL for the SL variant — confirm intended.
               tot_params = GCN_RL_GEN_ALL(hidden_channels = hidden_channels,num_layers = args.nl, lr = 0.001, shared_weights = False, general_model_gt = general_model_gt, gamma = gamma, var_distance=var_distance, tr_dist = tr_dist, tr_att = tr_att, dir_graph = dir_graph, use_batch_norm = use_batch_norm, constant_probability = args.cp).total_params()
               agent = GCN_RL(env, hidden_channels = hidden_channels, num_layers = args.nl, lr = 0.001, shared_weights = True, gamma = gamma, tr_dist = tr_dist, tr_att = tr_att, dir_graph = dir_graph, use_batch_norm = use_batch_norm)
          else: 
               agent = Agent(env, hidden_channels = hidden_channels, num_layers = args.nl,lr = 0.001, shared_weights = False, gamma = gamma, tr_dist = tr_dist, tr_att = tr_att, dir_graph = dir_graph, use_batch_norm = use_batch_norm)

          # Load trained weights.
          if Agent in (GCN_RL, GCN_RL_SL):
               agent.q_network.load_state_dict(torch.load(f'results/gen_testing/VD{var_distance}/{agent_name}/train_results/models/model_{graph_type}_{agent_name}_{env.num_nodes}_{env.num_servers}_hidden_channels{hidden_channels}__gamma{gamma}_vd{var_distance}_td{tr_dist}_ta{tr_att}_seed42_DG{dir_graph}_bs{batch_size}_lr{lr}_nl{args.nl}.pth'))
          elif Agent == GCN_RL_GEN_ALL:
               # TODO: hard-coded checkpoint. Templated alternative:
               # f'results/gen_testing/VD{var_distance}/{agent_name}/train_results/models/model_{general_model_gt}_CP{args.cp}_{agent_name}_{env.num_nodes}_{env.num_servers}_hidden_channels{hidden_channels}__gamma{gamma}_vd{var_distance}_td{tr_dist}_ta{tr_att}_seed42_DG{dir_graph}_bs{batch_size}_lr{lr}_nl{args.nl}.pth'
               agent.q_network.load_state_dict(torch.load(f'results/single_model_results/uniformFalse/models/model_GenModelAll_hch128_nl4_grid_dir_URFalse_CPFalse_lr0.001_nl4_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGTrue_bs128_lr0.001_nl4.pth'))
          elif Agent == GCN_RL_GEN_ALL_SL:
               # TODO: hard-coded checkpoint (a bs128 variant also exists:
               # ..._DGTrue_bs128_lr0.001_nl4.pth).
               agent.q_network.load_state_dict(torch.load(f'results/single_model_results/uniformFalse/models/model_GenModelAllSL_hch128_nl4_grid_dir_sl_URFalse_CPTrue_lr0.001_nl4_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGTrue_bs2_lr0.001_nl4.pth'))
          else: 
               agent.q_network.load_state_dict(torch.load(f'results/gen_testing/VD{var_distance}/{agent_name}/train_results/models/model_{graph_type}_{agent_name}_{env.num_nodes}_{env.num_servers}_gamma{gamma}_vd{var_distance}.pth'))

          # Load (or compute and cache) the Greedy and Optimal Online baselines.
          opt_estimate, greedy_estimate = _load_baseline_estimates(env, num_sequences)

          print(f'Estimates for {args.agent} are starting to be computed')

          if args.agent in ['dqn', 'dqn_lins']: 
               # DQN agents evaluate all sequences in one batched call.
               estimates = agent.estimate_seq(state, requests)[3]
               mean_estimate = torch.mean(estimates, dim = 1)
               transposed_mean_estimate = mean_estimate.unsqueeze(1).transpose(0, 1)
          else:
               estimates = []
               for i in range(num_sequences):
                    print(f'sequence{i} started')
                    estimates.append(agent.estimate_seq(state[i].unsqueeze(0), requests[i].unsqueeze(0))[0])
               estimates = torch.stack(estimates).to(device)
               # Normalize total cost to a per-step cost.
               mean_estimate = (estimates/sequence_len).unsqueeze(1)
               transposed_mean_estimate = mean_estimate.transpose(0, 1)

          # Competitive-ratio-style comparison against Optimal Online.
          diff = transposed_mean_estimate/opt_estimate 
          diff_greedy = greedy_estimate/opt_estimate 

          mean_estimate_div_optimal = round(diff.mean().item(), 3)
          mean_greedy_div_optimal = round(diff_greedy.mean().item(), 3)
          sd_estimate_div_optimal = round(diff.std().item(), 3)
          sd_greedy_div_optimal = round(diff_greedy.std().item(), 3)

          print('Mean of Estimate Divided by Optimal Online:', mean_estimate_div_optimal)
          print('Mean of Greedy Divided by Optimal Online:', mean_greedy_div_optimal)
          print('SD of Estimate Divided by Optimal Online:', sd_estimate_div_optimal)
          print('SD of Greedy Divided by Optimal Online:', sd_greedy_div_optimal)

          diff = torch.round(diff, decimals = 3)
          print('Estimate:', transposed_mean_estimate)
          print('Estimate Divided by Optimal Online:', diff.tolist())

          if save: 
               if args.agent not in ['gcn_rl_gen', 'gcn_sl_gen']:
                    tot_params = agent.total_params()
               print(tot_params)
               out_dir = f'results/gen_testing/VD{var_distance}/estimate_seq'
               os.makedirs(out_dir, exist_ok=True)
               output_file_name = f'{out_dir}/{agent.env.graph_type}_{agent_name}_{agent.env.num_nodes}_{agent.env.num_servers}_hidden_channels{hidden_channels}__gamma{agent.gamma}_vd{agent.var_distance}_td{args.tr_dist}_seed42_DG{args.dir_graph}_bs{batch_size}_lr{lr}_nl{args.nl}_CP{args.cp}.csv'
               if os.path.exists(output_file_name):
                    # Existing results for this exact configuration are never overwritten;
                    # say so instead of silently discarding the new numbers.
                    print(f'{output_file_name} already exists; results were not saved.')
               else:
                    with open(output_file_name, 'w', newline='') as f:
                         writer = csv.writer(f)
                         writer.writerow(['graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'hidden_channels', 'num_layers', 'batch_size','CP', 'dir_graph', 'var_distance', 'tr_dist', 'tr_att', 'num_params', 'mean_estimate_div_optimal', 'mean_greedy_div_optimal', 'sd_estimate_div_optimal', 'sd_greedy_div_optimal'])
                         writer.writerow([agent.env.graph_type, agent_name, agent.gamma, agent.lr, agent.env.num_nodes, agent.env.num_servers, hidden_channels, agent.num_layers, batch_size, args.cp, dir_graph, var_distance, tr_dist, tr_att, tot_params, mean_estimate_div_optimal, mean_greedy_div_optimal, sd_estimate_div_optimal, sd_greedy_div_optimal])

     else: 
          # Non-learned agents: only report Greedy relative to Optimal Online.
          opt_estimate, greedy_estimate = _load_baseline_estimates(env, num_sequences)

          diff_greedy = greedy_estimate/opt_estimate 

          mean_greedy_div_optimal = round(diff_greedy.mean().item(), 3)
          sd_greedy_div_optimal = round(diff_greedy.std().item(), 3)

          print('Mean of Greedy Divided by Optimal Online:', mean_greedy_div_optimal)
          print('SD of Greedy Divided by Optimal Online:', sd_greedy_div_optimal)

          diff_greedy = torch.round(diff_greedy, decimals = 3)
          print('Greedy Estimate:', greedy_estimate)
          print('Greedy Estimate Divided by Optimal Online:', diff_greedy.tolist())



     
     
    
        
     # for i in range(num_sequences):
     #      print('started')
     #      estimate, q1, q3, raw_result = agent.estimate_seq(state[i].unsqueeze(0), requests[i].unsqueeze(0))
     



# env = KServerEnv(num_nodes, num_servers, batch_size=512, graph_type='tree_50', device='cuda', uniform_random = False)


# parser.add_argument("--save", type = bool, default = False, help="Whether to save the model and results")
# parser.add_argument("--num_steps", type=int, default=1, help="Number steps multiplied by 1000")
# parser.add_argument("--num_nodes", type=int, default=9, help="Number of nodes in the graph")
# parser.add_argument("--display", type = bool, default = False, help="Display results in Neptune")
# parser.add_argument("--gamma", type = float, default = 0.99, help="Temporal Discount")
# # parser.add_argument("--gamma", type = float, default = 0.99, help="Learning Rate")
# parser.add_argument("--print", type = bool, default = False, help="Whether to print the experiments results during the training")
# parser.add_argument("--uniform_random", type = bool, default = False, help="If probabilities of arrival of requests are uniform")
# parser.add_argument("--parallel", type = bool, default = False, help="If to run the processes in parallel")
# parser.add_argument("--var_distance", type = bool, default = False, help="If to set a varying distance between nodes")
# # parser.add_argument("--var_distance", action='store_true', help="If to set a varying distance between nodes")
# parser.add_argument("--agent", type = str, default = 'greedy', help="Which agent")
# args = parser.parse_args()