
import sys
sys.path.append('/data/home/ifb5104/K_server_RL')

from KServerEnv import KServerEnv 
import torch
from multiprocessing import Pool, cpu_count
import psutil
import argparse
import time
import os
from generate_requests import generate_requests
from Policies.opt_off import opt_off
import csv

# Command-line interface. Options are registered in table order, which is also
# the order they appear in --help.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ("--use_all_cpus", dict(action='store_true', help="Whether to use all free cpus")),
    ("--thread_pool", dict(action='store_true', help="Whether to use ThreadPool instead of Pool")),
    ("--device", dict(type=str, default="cpu", help="Which device to use")),
    ("--save", dict(action='store_true', help="Whether to save the model")),
    ("--display_results", dict(action='store_true', help="Display results in Neptune")),
    ("--gamma", dict(type=float, default=0.99, help="Temporal Discount")),
    ("--print_results", dict(action='store_true', help="Whether to print the experiments results during the training")),
    ("--uniform_random", dict(action='store_true', help="If probabilities of arrival of requests are uniform")),
    ("--parallel", dict(action='store_true', help="If to run the processes in parallel")),
    ("--var_distance", dict(action='store_true', help="If to set a varying distance between nodes")),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Module-level settings, read as globals by the rest of the script.
use_all_cpus = args.use_all_cpus
thread_pool = args.thread_pool
device = args.device
save = args.save
display_results = args.display_results
gamma = args.gamma
print_results = args.print_results
# NOTE(review): --uniform_random is parsed above but not applied here;
# uniform_random is always False regardless of the flag.
uniform_random = False
# Hard-coded: there is no CLI flag to disable the burn-in.
burn_in = True
var_distance = args.var_distance


# number_nodes = [10, 20, 30, 10, 50]
# graph_types = []
# hidden_channels = 128
# if args.grid:
#     for num_nodes in number_nodes:
#         for rows in range(2, num_nodes+1):
#             columns = num_nodes // rows
#             if rows * columns == num_nodes and rows > columns > 1:
#                     graph_types.append([num_nodes, 128, f'grid_col{columns},'])



# print(graph_types)




# Experiment grid: run_experiment is called once per (num_nodes, graph_type) pair.
methods = ['Opt_Off']

# Node-count sets used in earlier runs:
# number_nodes = [9, 16, 25, 36, 49, 64, 81, 100]
# number_nodes = [9, 25, 36, 64]
# number_nodes = [36, 64, 100]
# number_nodes = [9, 16, 25]
# number_nodes = [24]
number_nodes = [25, 49, 64, 81]

# Graph families used in earlier runs:
# graph_types = [ 'grid_gre_51', 'grid_gre_50','grid_gre_52', 'grid_gre_53', 'grid_gre_54']
# graph_types = ['tree_50', 'tree_51', 'tree_52', 'tree_53',  'tree_54',  'grid_gre_51', 'grid_gre_50','grid_gre_52', 'grid_gre_53', 'grid_gre_54']
# graph_types = ['grid_dir_50', 'grid_dir_51', 'grid_dir_52', 'grid_dir_53', 'grid_dir_54']
# graph_types = ['plane_50', 'plane_51', 'plane_52', 'plane_53', 'plane_54']
# graph_types = ['grid_gre_50']
# graph_types = ['SF']
graph_types = ['grid_dir_60', 'grid_dir_61', 'grid_dir_62', 'grid_dir_63', 'grid_dir_64']

# Cartesian product with node counts in the outer position and graph types in
# the inner one, so all graph types for a given size are consecutive.
args_list = [
    (size, kind)
    for size in number_nodes
    for kind in graph_types
]


def run_experiment(hyperparams):
    """Compute Opt_Off costs for one ``(num_nodes, graph_type)`` setting.

    Builds a ``KServerEnv`` for the requested graph and generates
    ``env.batch_size`` request sequences with a fixed seed. It then calls
    ``opt_off`` once per sequence. When ``--save`` is given, the list of
    per-sequence costs is written as a single CSV row. The whole experiment
    is skipped if its result file already exists.

    Reads the module-level settings ``var_distance``, ``burn_in``, ``device``,
    ``save`` and ``gamma``.

    Args:
        hyperparams: indexable whose element 0 is the number of nodes and
            element 1 is the graph type string (as built in ``args_list``).

    Returns:
        None.
    """
    start_time = time.time()
    num_nodes = hyperparams[0]
    graph_type = hyperparams[1]
    # The server count is fixed at 4 for every graph size. A size-dependent
    # rule (round(num_nodes / 6)) used to be computed here, but it was always
    # overwritten by this constant.
    num_servers = 4
    seed = 42
    agent_name = 'Opt_Off'
    burn_in_period = 100

    result_dir = f'results/gen_testing/VD{var_distance}/{agent_name}'
    if burn_in:
        result_dir = f'{result_dir}/burn_in'
    # exist_ok=True: several Pool workers may create this directory at the same
    # time; a check-then-create sequence would race and raise FileExistsError.
    os.makedirs(result_dir, exist_ok=True)
    output_file_name = f'{result_dir}/results_{graph_type}_{agent_name}_{num_nodes}_{num_servers}.csv'

    if os.path.exists(output_file_name):
        print(f'Skipping experiment for {hyperparams}, as one or more result files exist.')
        return

    start_time_total = time.time()
    print('Starting to create a graph')
    env = KServerEnv(num_nodes, num_servers, batch_size=10, graph_type=graph_type, device=device,
                     uniform_random=False, request_same_node=True, var_distance=var_distance,
                     arrival_rates=True)
    elapsed_time_total = time.time() - start_time_total
    print(f'Total time for generating {env.graph_type} of size {env.num_nodes} elapsed in seconds:', round(elapsed_time_total, 3))

    # 4000 is presumably the number of requests per sequence -- TODO confirm
    # against generate_requests.
    requests, state_batch = generate_requests(env, seed, 4000, num_sequences=env.batch_size)
    if burn_in:
        # Discard the first burn_in_period requests of every sequence.
        requests = requests[:, burn_in_period:]
    # Prepend the last column of state_batch to each request sequence.
    # Presumably that column is the request pending in the initial state;
    # verify against KServerEnv / generate_requests.
    requests_off = torch.cat((state_batch[:, -1].reshape(state_batch.shape[0], 1), requests), dim=1)

    # Pairwise costs as a nested dict, sp[i][j] = int(cost_matrix[i][j]),
    # passed to opt_off as `sp`.
    sp = {
        i: {j: int(row[j].item()) for j in range(len(row))}
        for i, row in enumerate(env.cost_matrix)
    }

    # The first num_servers entries of each state row are used as the initial
    # server configuration.
    s_init = state_batch[:, :env.num_servers]
    opt_off_costs = [
        opt_off(servers.tolist(), sequence.tolist(), sp=sp)
        for servers, sequence in zip(s_init, requests_off)
    ]

    if save:
        with open(output_file_name, 'w', newline='') as f:
            csv.writer(f).writerow(opt_off_costs)

        elapsed_time = time.time() - start_time
        print(f'Experiment_{graph_type}_{agent_name}_{gamma}_{num_nodes} took {round(elapsed_time, 3)} seconds to finish')

            

# Startup banner. It sits outside the __main__ guard, so it runs whenever this
# module is executed or imported.
print("Starting the experiments")
for _label, _value in (
    ("Device:", device),
    ("Save:", save),
    ("Number of experiments:", len(args_list)),
):
    print(_label, _value)


if __name__ == "__main__":

    # Never start more workers than there are experiments or CPU cores, and
    # always at least one (Pool rejects processes=0).
    num_processes = max(1, min(cpu_count(), len(args_list)))

    # Snapshot of the current CPU load. It is only reported and does not
    # affect scheduling.
    cpu_usage_percent = psutil.cpu_percent(interval=1)
    free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))

    print("Number of free CPUs:", free_cpus)
    # NOTE(review): --use_all_cpus is only echoed here and --thread_pool is
    # never read. Neither flag changes how the experiments are run.
    print("Using all free CPUs:", bool(use_all_cpus))

    start_time_total = time.time()

    if args.parallel:
        # Size the pool from num_processes instead of a fixed count, so it
        # neither oversubscribes small machines nor spawns idle workers.
        with Pool(processes=num_processes) as pool:
            pool.map(run_experiment, args_list)
    else:
        for hyperparams in args_list:
            run_experiment(hyperparams)

    elapsed_time_total = time.time() - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))

    

    # while True:
    #     cpu_usage_percent = psutil.cpu_percent(interval=1)
    #     free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
    #     if free_cpus >= num_processes:
    #         break
    #     print(f'CPU usage is {cpu_usage_percent}%, waiting for available CPUs...')

    # with Pool(processes=20) as pool:
    #     pool.map(run_experiment, args_list2)

    # with Pool(processes=(16)) as pool:
    #    pool.map(run_experiment, args_list)

    
 






    