import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
from Policies.DQN_Lins import DQNAgent_Lins
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable

from Policies.GCN_RL import GCN_RL
from Policies.GCN_SL import GCN_SL
import csv
import os 

import torch
from multiprocessing import Pool, cpu_count
import argparse
import time
import os
import psutil



# ---------------------------------------------------------------------------
# Command-line interface. Each entry is (flag, keyword arguments forwarded to
# ArgumentParser.add_argument); options are registered in this order.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

_CLI_OPTIONS = (
    ("--num_steps", dict(type=int, default=1, help="Number steps multiplied by 1000")),
    ("--use_all_cpus", dict(type=int, default=0, help="Whether to use all free cpus")),
    ("--thread_pool", dict(type=int, default=0, help="Whether to use ThreadPool instead of Pool")),
    ("--device", dict(type=str, default="cpu", help="Which device to use")),
    ("--save", dict(action='store_true', help="Whether to save the model")),
    ("--display", dict(action='store_true', help="Display results in Neptune")),
    ("--gamma", dict(type=float, default=0.99, help="Temporal Discount")),
    ("--print", dict(action='store_true', help="Whether to print the experiments results during the training")),
    ("--uniform_random", dict(action='store_true', help="If probabilities of arrival of requests are uniform")),
    ("--parallel", dict(action='store_true', help="If to run the processes in parallel")),
    ("--var_distance", dict(action='store_true', help="If to set a varying distance between nodes")),
    ("--tr_dist", dict(action='store_true', help="If we train the distance between nodes")),
    # NOTE(review): args.agent is parsed but not read anywhere in this script.
    ("--agent", dict(type=str, default='greedy', help="Which agent")),
    ("--dir_graph", dict(action='store_true', help="If it is a directed graph architecture")),
    # store_false: args.nbn defaults to True and becomes False when --nbn is given.
    ("--nbn", dict(action='store_false', help="Don't Use batch normalization")),
    ("--bs", dict(type=int, default=128, help="Batch Size")),
    ("--lr", dict(type=float, default=0.001, help="Learning Rate")),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)

args = parser.parse_args()

# Run control.
num_steps = args.num_steps
device = args.device
use_all_cpus = args.use_all_cpus
thread_pool = args.thread_pool

# Output / logging switches.
save = args.save
display_results = args.display
print_results = args.print

# Environment options (forwarded to KServerEnv).
uniform_random = args.uniform_random
var_distance = args.var_distance

# Agent / training hyper-parameters.
gamma = args.gamma
tr_dist = args.tr_dist
dir_graph = args.dir_graph
use_batch_norm = args.nbn  # True unless --nbn was passed
batch_size = args.bs
lr = args.lr






# ---------------------------------------------------------------------------
# Experiment grid: the lists below define the sweep that `args_list` is built
# from. Previously used presets are kept as comments for quick swapping.
# ---------------------------------------------------------------------------

# Policies to train.
# Other presets: [DQNAgent_Lins], [Qtable], [Qtable_WQL], [GreedyPolicy],
#   [GCN_RL, GreedyPolicy]
#   (Qtable_WQL and GreedyPolicy are not imported in this script).
methods = [GCN_RL]

# Number of graph nodes; currently the perfect squares 25, 49, 64, 81.
# Other presets: [9], [24], [49], [49, 81], [81, 100], [9, 25, 36, 64],
#   [9, 16, 25, 36, 49, 64, 81, 100]
number_nodes = [side * side for side in (5, 7, 8, 9)]

# Hidden-channel sizes swept for GCN agents.
# Other preset: [2, 4, 8, 16, 32]
hidden_channels_list = [128]

# Graph topologies passed to KServerEnv as `graph_type`;
# currently grid_dir_60 ... grid_dir_64.
# Other presets:
#   ['grid_gre_50'], ['EM'], ['grid_gre_50', 'grid_tree_51'],
#   ['tree_50', 'tree_51', 'tree_52', 'tree_53', 'tree_54'],
#   ['grid_dir_50', 'grid_dir_51', 'grid_dir_52', 'grid_dir_53', 'grid_dir_54'],
#   ['grid_gre_51', 'grid_gre_50', 'grid_gre_52', 'grid_gre_53', 'grid_gre_54',
#    'tree_50', 'tree_51', 'tree_52', 'tree_53', 'tree_54'],
#   ['bn_grid_gre_51', 'bn_grid_gre_50', 'bn_grid_gre_52', 'bn_grid_gre_53', 'bn_grid_gre_54',
#    'psn_grid_gre_51', 'psn_grid_gre_50', 'psn_grid_gre_52', 'psn_grid_gre_53', 'psn_grid_gre_54',
#    'lgnm_grid_gre_51', 'lgnm_grid_gre_50', 'lgnm_grid_gre_52', 'lgnm_grid_gre_53', 'lgnm_grid_gre_54']
graph_types = [f'grid_dir_{idx}' for idx in range(60, 65)]

# Seeds swept for the seeded (DQN / GCN) configurations.
# Other preset: [42, 43, 44, 45, 46]
seeds = [42]
     
    




# Build the flat list of experiment configurations consumed by run_experiment.
# The tuple layout has to match how run_experiment unpacks `hyperparams`:
#   DQNAgent / DQNAgent_10 / DQNAgent_Lins -> (Agent, num_nodes, graph_type, seed)
#   GCN_RL / GCN_SL                        -> (Agent, num_nodes, graph_type, hidden_channels, seed)
#   any other agent                        -> (Agent, num_nodes, graph_type)
_SEEDED_AGENTS = (DQNAgent, DQNAgent_10, DQNAgent_Lins)
_GCN_AGENTS = (GCN_RL, GCN_SL)

args_list = []
for method in methods:
    if method in _SEEDED_AGENTS:
        args_list.extend(
            (method, num_nodes, graph_type, seed)
            for num_nodes in number_nodes
            for graph_type in graph_types
            for seed in seeds
        )
    elif method in _GCN_AGENTS:
        args_list.extend(
            (method, num_nodes, graph_type, hidden_channels, seed)
            for num_nodes in number_nodes
            for graph_type in graph_types
            for hidden_channels in hidden_channels_list
            for seed in seeds
        )
    else:
        # Qtable is only run on the first two entries of number_nodes.
        node_counts = number_nodes[:2] if method == Qtable else number_nodes
        args_list.extend(
            (method, num_nodes, graph_type)
            for num_nodes in node_counts
            for graph_type in graph_types
        )







# Agents whose training budget is scaled up with graph size.
_STEP_SCALED_AGENTS = (Qtable, DQNAgent_10, DQNAgent_Lins)


def _steps_for(agent_cls, num_nodes):
    """Return the ``optimize`` step budget (thousands of steps) for one run.

    Qtable / DQNAgent_10 / DQNAgent_Lins use ``--num_steps`` scaled by
    ``1.4 ** number_nodes.index(num_nodes)`` (truncated to int); every other
    agent uses ``--num_steps`` unchanged. Computing it here, rather than only
    in the sequential driver loop, gives parallel runs the same budget.
    """
    if agent_cls in _STEP_SCALED_AGENTS:
        return int(args.num_steps * pow(1.4, number_nodes.index(num_nodes)))
    return args.num_steps


def run_experiment(hyperparams):
    """Build the environment and agent for one configuration and train it.

    Args:
        hyperparams: one tuple from ``args_list``:
            (Agent, num_nodes, graph_type, seed) for DQNAgent / DQNAgent_10 /
            DQNAgent_Lins; (Agent, num_nodes, graph_type, hidden_channels, seed)
            for GCN_RL / GCN_SL; (Agent, num_nodes, graph_type) otherwise.

    Training is skipped (with a message) when the results CSV expected for
    this configuration already exists. All other settings come from the
    module-level CLI options.
    """
    start_time = time.time()
    Agent, num_nodes, graph_type = hyperparams[:3]

    if Agent in [DQNAgent, DQNAgent_10, DQNAgent_Lins]:
        seed, hidden_channels = hyperparams[3], None
    elif Agent in [GCN_RL, GCN_SL]:
        hidden_channels, seed = hyperparams[3], hyperparams[4]
    else:
        seed = hidden_channels = None
    # NOTE(review): `seed` is only used in the log lines below; it is not
    # passed to the agent, while the skip check uses `agent.seed`. Confirm this
    # is intended when sweeping more than one seed.

    num_servers = 4  # alternative formerly used: round(num_nodes / 6)
    is_dqn = Agent in [DQNAgent_10, DQNAgent_Lins]

    # Both agent families use the same environment configuration.
    env = KServerEnv(num_nodes, num_servers, batch_size=batch_size, graph_type=graph_type,
                     device=device, uniform_random=uniform_random, request_same_node=True,
                     arrival_rates=True, var_distance=var_distance)
    if is_dqn:
        agent = Agent(env)
    else:
        agent = Agent(env, hidden_channels=hidden_channels, lr=lr, shared_weights=False,
                      gamma=gamma, tr_dist=tr_dist, dir_graph=dir_graph,
                      use_batch_norm=use_batch_norm)

    agent_name = Agent.__name__

    # Result file(s) whose presence marks this configuration as already done.
    if is_dqn:
        file_paths = [f'results/gen_testing/VD{agent.var_distance}/{agent.class_name}/train_results/results_{agent.env.graph_type}_{agent.class_name}_{agent.env.num_nodes}_{agent.env.num_servers}_gamma{agent.gamma}_vd{agent.var_distance}.csv']
    else:
        file_paths = [
            f'results/gen_testing/VD{agent.var_distance}/train_curves/train_curve_{env.graph_type}_{agent.class_name}_{env.num_nodes}_{env.num_servers}_hidden_channels{agent.hidden_channels}__gamma{agent.gamma}_vd{agent.var_distance}_td{agent.tr_dist}_ta{agent.tr_att}_seed{agent.seed}_DG{agent.dir_graph}_bs{agent.batch_size}_lr{agent.lr}_nl{agent.num_layers}_ar.csv',
        ]

    if any(os.path.exists(file_path) for file_path in file_paths):
        print(f'Skipping experiment for {hyperparams}, as one or more result files exist.')
        return

    print(f'Experiment_{graph_type}_{agent_name}_{num_nodes}_seed{seed}_hidden_channels{hidden_channels} started')

    agent.optimize(_steps_for(Agent, num_nodes), print_results=print_results,
                   save_results=save, display_results=display_results)

    elapsed_time = time.time() - start_time
    print(f'Experiment_{graph_type}_{agent_name}_{gamma}_{num_nodes}_seed{seed}_hidden_channels{hidden_channels} took {round(elapsed_time, 3)} seconds to finish')

        

# Echo the run configuration once, before any experiment starts.
print("Starting the experiments")
_run_summary = (
    ("Number of steps multiplied by 1000:", args.num_steps),
    ("Number of experiments:", len(args_list)),
    ("Device:", device),
    ("Save:", save),
    ("Uniform Random:", uniform_random),
    ("Display Results:", display_results),
)
for _label, _value in _run_summary:
    print(_label, _value)

if __name__ == "__main__":

    # Worker-pool sizing: never more workers than experiments, capped by the
    # CPUs that were idle at startup when --use_all_cpus is set, otherwise by
    # the machine's CPU count; always at least one worker.
    total_cpus = cpu_count()
    cpu_usage_percent = psutil.cpu_percent(interval=1)
    free_cpus = total_cpus - int(total_cpus * (cpu_usage_percent / 100))
    cpu_budget = free_cpus if use_all_cpus else total_cpus
    num_processes = max(1, min(cpu_budget, len(args_list)))

    print("Number of free CPUs:", free_cpus)
    print("Using all free CPUs:", bool(use_all_cpus))

    start_time_total = time.time()

    if args.parallel:
        if thread_pool:
            # --thread_pool: run the experiments in threads of this process.
            from multiprocessing.pool import ThreadPool as pool_cls
        else:
            pool_cls = Pool
        print("Number of parallel workers:", num_processes)
        with pool_cls(processes=num_processes) as pool:
            pool.map(run_experiment, args_list)

    else:
        for hyperparams in args_list:
            # Qtable / DQN agents train 1.4x longer per position of their graph
            # size in number_nodes; other agents use --num_steps as given.
            if hyperparams[0] in [Qtable, DQNAgent_10, DQNAgent_Lins]:
                num_steps = int(args.num_steps * pow(1.4, (number_nodes.index(hyperparams[1]))))
            else:
                num_steps = args.num_steps

            start_time = time.time()
            run_experiment(hyperparams)
            elapsed_time = time.time() - start_time
            print('Time elapsed for this experiment in seconds:', round(elapsed_time, 3))

    elapsed_time_total = time.time() - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))

    



    