import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from random import uniform
from KServerEnv import KServerEnv 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Qtable_WQL import Qtable_WQL
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.GCN_RL_SL import GCN_RL_SL


from multiprocessing import cpu_count
import psutil
import torch
import time
from multiprocessing import Pool, cpu_count




from Policies.Random_Greedy import RandomPolicy, GreedyPolicy

# from WFA import HarmonicPolicy


from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from KServerEnv import KServerEnv 
from multiprocessing import cpu_count

from generate_requests import generate_requests

import argparse
def _build_parser():
    """Create the command-line parser for the experiment runner.

    Boolean switches are declared with ``action="store_true"`` (absent ->
    False, present -> True).  ``type=bool`` is deliberately not used: argparse
    would call ``bool()`` on the raw string, so any non-empty value, including
    the text "False", would evaluate to True.
    """
    switch = {"action": "store_true"}
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        ("--save", {**switch, "help": "Whether to save the model and results"}),
        ("--num_steps", {"type": int, "default": 1, "help": "Number steps multiplied by 1000"}),
        ("--num_nodes", {"type": int, "default": 9, "help": "Number of nodes in the graph"}),
        ("--display", {**switch, "help": "Display results in Neptune"}),
        ("--gamma", {"type": float, "default": 0.99, "help": "Temporal Discount"}),
        ("--print", {**switch, "help": "Whether to print the experiments results during the training"}),
        ("--uniform_random", {**switch, "help": "If probabilities of arrival of requests are uniform"}),
        ("--parallel", {**switch, "help": "If to run the processes in parallel"}),
        ("--var_distance", {**switch, "help": "If to set a varying distance between nodes"}),
        ("--tr_dist", {**switch, "help": "If to set a trainable distance"}),
        ("--agent", {"type": str, "default": "greedy", "help": "Which agent"}),
        ("--device", {"type": str, "default": "cpu", "help": "Which device to use"}),
        ("--graph_type", {"type": str, "default": "grid_gre_50", "help": "Which graph type"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
# Parsed once at import time; the rest of the module reads settings from `args`.
args = parser.parse_args()
# Copy the parsed command-line values into module-level names, grouped by role.
device, graph_type = args.device, args.graph_type
num_steps, num_nodes = args.num_steps, args.num_nodes
gamma = args.gamma

# Reporting / persistence switches.
save, display_results, print_results = args.save, args.display, args.print

# Environment switches taken from the command line.
uniform_random, var_distance, tr_dist = (
    args.uniform_random,
    args.var_distance,
    args.tr_dist,
)

# Environment switches that are fixed rather than exposed on the CLI.
request_same_node = arrival_rates = True

print(f"Varying Distance: {var_distance}")


# Lookup table from the value of --agent to the class implementing that agent.
agent_mapping = dict(
    dqn=DQNAgent_10,
    qtable=Qtable,
    random=RandomPolicy,
    greedy=GreedyPolicy,
    gcn_sl=GCN_SL,
    gcn_rl=GCN_RL,
    qtable_wql=Qtable_WQL,
    gcn_rl_sl=GCN_RL_SL,
)


# Informational CPU report.  This runs at import time and the call to
# psutil.cpu_percent(interval=1) measures utilisation over a one-second window.
max_cpu_usage_percent = 50  # not referenced by any code visible in this file
cpu_usage_percent = psutil.cpu_percent(interval=1)

# Estimate idle cores by scaling the core count with the measured utilisation.
_cpu_total = cpu_count()
_cpu_busy = int(_cpu_total * (cpu_usage_percent / 100))
free_cpus = _cpu_total - _cpu_busy

print("Number of free CPUs:", free_cpus)
print("Device Type:", device)

# Agent names that are run with batch_size 5; every other agent uses 1.
_BATCHED_AGENT_NAMES = ("dqn", "dqn_10", "gcn_sl", "gcn_rl", "gcn_rl_gen", "gcn_rl_sl")
batch_size = 5 if args.agent in _BATCHED_AGENT_NAMES else 1

# Resolve the requested agent class; an unknown name raises KeyError here.
agent = agent_mapping[args.agent]
# graph_types = ['tree_54', 'tree_55','grid_gre_54','tree_55'] #'grid_gre_50','tree_52',
# graph_types = ['tree_54', 'tree_55'] #'grid_gre_50','tree_52',
# graph_types = [ 'grid_dir_50', 'grid_dir_51', 'grid_dir_52', 'grid_dir_53', 'grid_dir_54']
# graph_types = ['tree_50', 'tree_51', 'tree_52', 'tree_53', 'tree_54']
# graph_types = ['grid_gre_50', 'grid_gre_51', 'grid_gre_52', 'grid_gre_53', 'grid_gre_54']
# Experiment grid: every graph type is combined with every node count.
graph_types = ['grid_dir_60', 'grid_dir_61', 'grid_dir_62', 'grid_dir_63', 'grid_dir_64']
number_nodes = [49, 64, 81]

# One (graph_type, num_nodes) tuple per experiment, graph type varying slowest.
# A comprehension is used instead of nested module-level `for` loops so that the
# iteration variables do not leak into the module namespace and overwrite the
# `graph_type` / `num_nodes` values read from the command line.
args_list = [
    (graph_name, node_count)
    for graph_name in graph_types
    for node_count in number_nodes
]

def run_experiment(hyperparams):
    """Build one K-server environment and run the agent selected by --agent.

    Args:
        hyperparams: Indexable pair; item 0 is the graph type and item 1 is
            the number of nodes of the environment to build.

    Returns:
        None.  Results are only printed or handled by the agent itself.

    The number of servers is fixed at 4.  All other settings (batch size,
    device, gamma, reporting switches, ...) are read from module-level names.

    Per-agent behaviour:
        * GCN_RL / GCN_RL_SL: the agent is only constructed; no training call
          is made here.
        * DQNAgent_10 and Qtable: constructed and trained via ``optimize``.
        * "greedy": a GreedyPolicy estimate is computed and printed.
        * "qtable_wql": ``compare_policies`` is called.
        * Any other agent in ``agent_mapping`` (e.g. "random", "gcn_sl") has
          no branch, so only the environment is built and described.
    """
    agent_cls = agent_mapping[args.agent]
    num_servers = 4
    graph_type, num_nodes = hyperparams[0], hyperparams[1]

    env = KServerEnv(
        num_nodes,
        num_servers,
        batch_size=batch_size,
        graph_type=graph_type,
        device=device,
        uniform_random=uniform_random,
        request_same_node=True,
        arrival_rates=True,
        var_distance=var_distance,
    )
    print(f"Graph Type: {env.graph_type}")
    print('Number of Nodes:', num_nodes, 'Number of Servers:', num_servers)

    # Reporting options shared by every optimize() call below.
    report_options = dict(
        print_results=print_results,
        save_results=save,
        display_results=display_results,
    )

    if agent_cls in (GCN_RL, GCN_RL_SL):
        gcn_agent = agent_cls(
            env,
            hidden_channels=128,
            lr=0.001,
            shared_weights=False,
            gamma=gamma,
            tr_dist=tr_dist,
        )
        return

    if agent_cls == DQNAgent_10:
        dqn_agent = DQNAgent_10(env, seed=42, lr=0.01, gamma=gamma)
        dqn_agent.optimize(num_steps, **report_options)
        return

    if args.agent == 'greedy':
        greedy_policy = GreedyPolicy(env)
        # The step argument is fixed at 40 here, independent of --num_steps.
        estimate = greedy_policy.estimate(40, print_results=print_results)[0].item()
        print(f'Greedy Policy Estimate for 40,000 steps is {estimate}')
        return

    if agent_cls == Qtable:
        q_agent = Qtable(env, lr=0.1, gamma=gamma)
        q_agent.optimize(num_steps, explr=0.8, **report_options)
        return

    if args.agent == "qtable_wql":
        wql_agent = Qtable_WQL(env, lr=0.1, gamma=gamma)
        wql_agent.compare_policies()
               

# Start-up banner.  It sits outside the __main__ guard, so it is printed
# whenever this module is executed or imported.
print("Starting the experiments")
for _label, _value in (
    ("Device:", device),
    ("Save:", save),
    ("Number of experiments:", len(args_list)),
):
    print(_label, _value)


if __name__ == "__main__":
    # Script entry point: run every experiment in `args_list`, either in a
    # process pool (--parallel) or one after another, and report the total
    # wall-clock time.

    max_cpu_usage_percent = 50  # not referenced by any code visible in this file

    # One worker per experiment, but never more workers than CPUs on this
    # machine (and at least one, so Pool() never receives 0).
    num_processes = max(1, min(cpu_count(), len(number_nodes) * len(graph_types)))

    # Informational only: estimate idle cores from a one-second utilisation sample.
    cpu_usage_percent = psutil.cpu_percent(interval=1)
    free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
    print("Number of free CPUs:", free_cpus)

    start_time_total = time.time()

    if args.parallel:
        # Size the pool from the computed worker count instead of a hard-coded
        # 40, which over-subscribed small machines and spawned idle workers
        # whenever there were fewer than 40 experiments.
        with Pool(processes=num_processes) as pool:
            pool.map(run_experiment, args_list)
    else:
        for hyperparams in args_list:
            run_experiment(hyperparams)

    end_time_total = time.time()
    elapsed_time_total = end_time_total - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))


          # env = KServerEnv(num_nodes, num_servers, batch_size=512, graph_type='tree_50', device='cuda', uniform_random = False)