import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
import csv
import os 
from Policies.GCN_RL_GEN import GCN_RL_GEN
from Policies.Balance import BalancePolicy
from Policies.Harmonic import HarmonicPolicy
from Policies.WFA import WorkFunction
import torch
from multiprocessing import Pool, cpu_count
import psutil
import argparse
import time
import os



def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for every
    non-empty string (including ``"False"``), so boolean flags need this parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--test_all", type=int, default=0, choices=[0, 1], help="Whether to do experiments for all arguments")
parser.add_argument("--num_steps", type=int, default=1, help="Number steps multiplied by 1000")
parser.add_argument("--use_all_cpus", type=int, default=0, help="Whether to use all free cpus")
parser.add_argument("--thread_pool", type=int, default=0, help="Whether to use ThreadPool instead of Pool")
parser.add_argument("--device", type=str, default="cpu", help="Which device to use")
parser.add_argument("--save", type=_str2bool, default=False, help="Whether to save the model")
parser.add_argument("--display_results", type=_str2bool, default=False, help="Display results in Neptune")
parser.add_argument("--gamma", type=float, default=0.99, help="Temporal Discount")
parser.add_argument("--print_results", type=_str2bool, default=False, help="Whether to print the experiments results during the training")
# Default is False to match the value that was previously hard-coded after parsing
# (the flag used to be ignored because `type=bool` could not express False).
parser.add_argument("--uniform_random", type=_str2bool, default=False, help="If probabilities of arrival of requests are uniform")
parser.add_argument("--parallel", type=_str2bool, default=False, help="If to run the processes in parallel")
# parser.add_argument("--grid", type = bool, default = False, help="If graph type is grid")
args = parser.parse_args()

# Module-level settings read by run_experiment (also in Pool worker processes).
test_all = args.test_all
use_all_cpus = args.use_all_cpus
num_steps = args.num_steps
thread_pool = args.thread_pool
device = args.device
save = args.save
display_results = args.display_results
gamma = args.gamma
print_results = args.print_results
uniform_random = args.uniform_random



# number_nodes = [10, 20, 30, 10, 50]
# graph_types = []
# hidden_channels = 128
# if args.grid:
#     for num_nodes in number_nodes:
#         for rows in range(2, num_nodes+1):
#             columns = num_nodes // rows
#             if rows * columns == num_nodes and rows > columns > 1:
#                     graph_types.append([num_nodes, 128, f'grid_col{columns},'])



# print(graph_types)



# Experiment grid selected by --test_all: 0 is the configuration currently under
# study, 1 is the full sweep over tree/cycle/line graphs.
if test_all == 0:
    methods = [GreedyPolicy, HarmonicPolicy, BalancePolicy]
    # NOTE(review): number_nodes was commented out in this branch, so the default run
    # crashed with NameError while building args_list. Restored from the last list
    # used here — TODO confirm these sizes are intended for the 'EM' graph type.
    number_nodes = [9, 16, 25, 36, 49, 64, 81, 100]
    hidden_channels_list = [128]
    # Other graph sets used previously:
    # graph_types = ['grid_gre_50', 'grid_gre_51', 'grid_gre_52', 'grid_gre_53', 'grid_gre_54']
    # graph_types = ['tree_50', 'tree_51', 'tree_52', 'tree_53', 'tree_54']
    # graph_types = ['plane_50', 'plane_51', 'plane_52', 'plane_53', 'plane_54']
    graph_types = ['EM']
    seeds = [42]
elif test_all == 1:
    methods = [RandomPolicy, GreedyPolicy, GCN_RL, Qtable, DQNAgent_10]
    number_nodes = [9, 25, 49]
    hidden_channels_list = [128]
    graph_types = ['tree_1', 'tree_2', 'tree_3', 'tree_4', 'tree_5', 'cycle', 'line']
    seeds = [42]  # List of different seeds
else:
    # Without this, `methods` etc. would be undefined and fail later with NameError.
    raise SystemExit(f'Unsupported --test_all value {test_all!r}; expected 0 or 1.')


# One hyperparameter tuple per experiment, in nested-loop order (node count, then graph
# type, then seed / hidden size). Every tuple starts with (method, num_nodes, graph_type);
# run_experiment reads a fourth element as the seed for DQNAgent / DQNAgent_10 and as
# hidden_channels for GCN_RL / GCN_SL, so those methods must get 4-tuples here.
args_list = []
for method in methods:
    if method in (DQNAgent, DQNAgent_10):
        args_list.extend(
            (method, num_nodes, graph_type, seed)
            for num_nodes in number_nodes
            for graph_type in graph_types
            for seed in seeds
        )
    elif method in (GCN_RL, GCN_SL):
        args_list.extend(
            (method, num_nodes, graph_type, hidden_channels)
            for num_nodes in number_nodes
            for graph_type in graph_types
            for hidden_channels in hidden_channels_list
        )
    elif method == Qtable:
        # Qtable only runs on the first two entries of number_nodes (the original
        # selection; presumably larger instances are too big for a table — TODO confirm).
        args_list.extend(
            (method, num_nodes, graph_type)
            for num_nodes in number_nodes[:2]
            for graph_type in graph_types
        )
    else:
        args_list.extend(
            (method, num_nodes, graph_type)
            for num_nodes in number_nodes
            for graph_type in graph_types
        )







# Heuristic policies: evaluated by repeated `agent.estimate` calls and saved as CSV rows.
_HEURISTIC_POLICIES = (GreedyPolicy, RandomPolicy, HarmonicPolicy, BalancePolicy, WorkFunction)


def _unpack_hyperparams(hyperparams):
    """Return ``(Agent, num_nodes, graph_type, seed, hidden_channels)`` for an args_list tuple.

    The optional fourth tuple element is the seed for DQNAgent / DQNAgent_10 and
    ``hidden_channels`` for GCN_RL / GCN_SL; for every other agent both are None.
    """
    Agent, num_nodes, graph_type = hyperparams[0], hyperparams[1], hyperparams[2]
    seed = None
    hidden_channels = None
    if Agent in (DQNAgent, DQNAgent_10):
        seed = hyperparams[3]
    elif Agent in (GCN_RL, GCN_SL):
        hidden_channels = hyperparams[3]
    return Agent, num_nodes, graph_type, seed, hidden_channels


def _build_agent(Agent, num_nodes, graph_type, seed, hidden_channels):
    """Create the KServerEnv for one experiment and return an ``Agent`` instance on it."""
    # Number of servers: num_nodes / 6, rounded to the nearest integer.
    num_servers = round(num_nodes / 6)
    env_kwargs = dict(graph_type=graph_type, device=device, uniform_random=uniform_random)

    if Agent is DQNAgent_10:
        env = KServerEnv(num_nodes, num_servers, batch_size=512, **env_kwargs)
        return Agent(env, seed=seed)
    if Agent in (GCN_RL, GCN_SL):
        env = KServerEnv(num_nodes, num_servers, batch_size=512, **env_kwargs)
        return Agent(env, hidden_channels=hidden_channels, shared_weights=False, gamma=gamma)
    if Agent is GCN_RL_GEN:
        env = KServerEnv(num_nodes, num_servers, batch_size=512, general_model=True, **env_kwargs)
        return Agent(env, hidden_channels=hidden_channels, shared_weights=False, gamma=gamma)
    if Agent is BalancePolicy:
        env = KServerEnv(num_nodes, num_servers, batch_size=1, balanced_algorithm=True, **env_kwargs)
        return Agent(env)
    env = KServerEnv(num_nodes, num_servers, batch_size=1, **env_kwargs)
    return Agent(env)


def _evaluate_heuristic(agent, Agent, num_runs=10):
    """Call ``agent.estimate`` ``num_runs`` times.

    Returns four lists (estimates, q1s, q3s, raw_results), one entry per call.
    """
    # WorkFunction is evaluated with estimate(1), the other heuristics with estimate(4).
    estimate_arg = 1 if Agent is WorkFunction else 4
    estimates, q1s, q3s, raw_results = [], [], [], []
    for _ in range(num_runs):
        estimate, q1, q3, raw_result = agent.estimate(estimate_arg)
        estimates.append(estimate)
        q1s.append(q1)
        q3s.append(q3)
        raw_results.append(raw_result)
    return estimates, q1s, q3s, raw_results


def _ensure_result_dirs(result_dir):
    """Create ``result_dir`` and its ``models`` / ``raw_results`` subdirectories if missing.

    ``exist_ok=True`` avoids the check-then-create race when several Pool workers
    create the same directory at the same time.
    """
    for path in (result_dir, f'{result_dir}/models', f'{result_dir}/raw_results'):
        os.makedirs(path, exist_ok=True)


def _save_gen_results(agent, agent_name, result_dir, graph_type, num_nodes, seed,
                      estimate, q1, q3, output_data):
    """Save GCN_RL_GEN weights and write its summary row plus one row per ``output_data`` entry."""
    stem = f'{graph_type}_{agent_name}_{num_nodes}_seed{seed}_gamma{gamma}'
    torch.save(agent.q_network.state_dict(), f'{result_dir}/models/model_{stem}.pth')
    with open(f'{result_dir}/results_{stem}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['graph_type', 'agent', 'gamma', 'num_nodes', 'estimate', 'q1', 'q3'])
        writer.writerow(['general', agent_name, gamma, num_nodes,
                         round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])
        # output_data comes from agent.estimate(40); q1/q3 columns are left blank per entry.
        for tree, value in output_data.items():
            writer.writerow([tree, agent_name, gamma, num_nodes, round(value.item(), 3), '', ''])


def _save_heuristic_results(agent_name, result_dir, graph_type, num_nodes, seed, hidden_channels,
                            estimates, q1s, q3s, raw_results):
    """Write one summary row per evaluation run, and the raw per-run results, to CSV."""
    stem = f'{graph_type}_{agent_name}_{num_nodes}__gamma{gamma}'
    with open(f'{result_dir}/results_{stem}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['graph_type', 'agent', 'gamma', 'num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
        for estimate, q1, q3 in zip(estimates, q1s, q3s):
            writer.writerow([graph_type, agent_name, gamma, num_nodes, seed, hidden_channels,
                             round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])
    with open(f'{result_dir}/raw_results/results_{stem}_raw.csv', 'w', newline='') as f:
        csv.writer(f).writerows(raw_result.tolist() for raw_result in raw_results)


def run_experiment(hyperparams):
    """Run one experiment from ``args_list``, skipping it if a result file already exists.

    Args:
        hyperparams: ``(Agent, num_nodes, graph_type)`` optionally followed by the seed
            (DQNAgent / DQNAgent_10) or ``hidden_channels`` (GCN_RL / GCN_SL).

    Uses the module-level CLI settings (num_steps, gamma, save, device, ...). Trains the
    agent when it has an ``optimize`` method, evaluates GCN_RL_GEN and the heuristic
    policies, and when ``save`` is set writes their results under
    ``results/gen_testing/<agent_name>/``. Returns None.
    """
    start_time = time.time()
    Agent, num_nodes, graph_type, seed, hidden_channels = _unpack_hyperparams(hyperparams)
    agent_name = Agent.__name__
    result_dir = f'results/gen_testing/{agent_name}'

    file_paths = [
        f'{result_dir}/results_{graph_type}_{agent_name}_{num_nodes}_hidden_channels{hidden_channels}__gamma{gamma}.csv',
        f'{result_dir}/results_{graph_type}_{agent_name}_{num_nodes}__gamma{gamma}.csv',
        f'{result_dir}/results_{graph_type}_{agent_name}_{num_nodes}_gamma{gamma}.csv',
        # Name GCN_RL_GEN results are saved under; without it those runs never got skipped.
        f'{result_dir}/results_{graph_type}_{agent_name}_{num_nodes}_seed{seed}_gamma{gamma}.csv',
    ]
    if any(os.path.exists(file_path) for file_path in file_paths):
        print(f'Skipping experiment for {hyperparams}, as one or more result files exist.')
        return

    # Built only after the skip check so skipped experiments don't construct an env.
    agent = _build_agent(Agent, num_nodes, graph_type, seed, hidden_channels)

    print(f'Experiment_{graph_type}_{agent_name}_{num_nodes}_seed{seed}_hidden_channels{hidden_channels} started')

    # Some agents (e.g. heuristic policies) may have no optimize(). Training failures are
    # reported but do not abort the run, keeping the previous best-effort behaviour
    # without also swallowing KeyboardInterrupt/SystemExit.
    optimize = getattr(agent, 'optimize', None)
    if optimize is not None:
        try:
            optimize(num_steps, print_results=print_results, save_results=save, display_results=display_results)
        except Exception as exc:
            print(f'Warning: {agent_name}.optimize failed for {hyperparams}: {exc!r}')

    if Agent is GCN_RL_GEN:
        estimate, q1, q3, _ = agent.estimate_all(40)
        output_data = agent.estimate(40)
    elif Agent in _HEURISTIC_POLICIES:
        estimates, q1s, q3s, raw_results = _evaluate_heuristic(agent, Agent)

    if save:
        _ensure_result_dirs(result_dir)
        if Agent is GCN_RL_GEN:
            _save_gen_results(agent, agent_name, result_dir, graph_type, num_nodes, seed,
                              estimate, q1, q3, output_data)
        elif Agent in _HEURISTIC_POLICIES:
            _save_heuristic_results(agent_name, result_dir, graph_type, num_nodes, seed,
                                    hidden_channels, estimates, q1s, q3s, raw_results)

    elapsed_time = time.time() - start_time
    print(f'Experiment_{graph_type}_{agent_name}_{gamma}_{num_nodes}_seed{seed}_hidden_channels{hidden_channels} took {round(elapsed_time, 3)} seconds to finish')

        

if __name__ == "__main__":
    # Printed only in the parent process (worker processes started with the "spawn"
    # method re-import this module and would otherwise repeat these lines).
    print("Starting the experiments")
    print("Number of steps multiplied by 1000:", args.num_steps)
    print("Number of experiments:", len(args_list))
    print("Device:", device)
    print("Save:", save)
    print("Uniform Random:", uniform_random)
    print("Display Results:", display_results)

    # Idle CPUs estimated from a 1-second utilisation sample.
    cpu_usage_percent = psutil.cpu_percent(interval=1)
    free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))

    print("Number of free CPUs:", free_cpus)
    print("Using all free CPUs:", bool(use_all_cpus))

    start_time_total = time.time()

    if args.parallel:
        # --use_all_cpus sizes the pool to the idle CPUs; otherwise keep the previous
        # fixed cap of 20 workers. Never start more workers than experiments, and at
        # least one (Pool rejects processes=0).
        max_workers = free_cpus if use_all_cpus else 20
        num_processes = max(1, min(max_workers, len(args_list)))
        if thread_pool:
            from multiprocessing.pool import ThreadPool as pool_cls
        else:
            pool_cls = Pool
        with pool_cls(processes=num_processes) as pool:
            pool.map(run_experiment, args_list)
    else:
        for hyperparams in args_list:
            run_experiment(hyperparams)

    elapsed_time_total = time.time() - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))

    

    # while True:
    #     cpu_usage_percent = psutil.cpu_percent(interval=1)
    #     free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
    #     if free_cpus >= num_processes:
    #         break
    #     print(f'CPU usage is {cpu_usage_percent}%, waiting for available CPUs...')

    # with Pool(processes=20) as pool:
    #     pool.map(run_experiment, args_list2)

    # with Pool(processes=(16)) as pool:
    #    pool.map(run_experiment, args_list)

    
 






    