import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from concurrent.futures import thread
from json.tool import main
import multiprocessing
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.NET import Net
import csv
import os 
from Policies.GCN_RL_GEN import GCN_RL_GEN
from Policies.Balance import BalancePolicy
from Policies.Harmonic import HarmonicPolicy
from Policies.WFA import WorkFunction
from generate_requests import generate_requests

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from multiprocessing import Pool, cpu_count
from multiprocessing.pool import ThreadPool
import numpy as np
import psutil

import pandas as pd
import matplotlib.pyplot as plt

import glob

import argparse

import time

import os
import re


def _build_parser():
    """Return the command-line parser for this generalization-testing script.

    Options are registered in a fixed order so that ``--help`` lists them in
    that order; every flag without an explicit ``type`` is a boolean switch.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add("--use_all_cpus", action="store_true", help="Whether to use all free cpus")
    add("--thread_pool", action="store_true", help="Whether to use ThreadPool instead of Pool")
    add("--device", type=str, default="cpu", help="Which device to use")
    add("--save", action="store_true", help="Whether to save the model")
    add("--display_results", action="store_true", help="Display results in Neptune")
    add("--gamma", type=float, default=0.99, help="Temporal Discount")
    add("--print_results", action="store_true", help="Whether to print the experiments results during the training")
    add("--uniform_random", action="store_true", help="If probabilities of arrival of requests are uniform")
    add("--parallel", action="store_true", help="If to run the processes in parallel")
    add("--var_distance", action="store_true", help="If to set a varying distance between nodes")
    add("--tr_dist", action="store_true", help="If we train the distance between nodes")
    add("--nl", type=int, default=12, help="Number of Layers")
    add("--dir_graph", action="store_true", help="If it is a directed graph architecture")
    add("--test_by_one", action="store_true", help="If to test one by one")
    return cli


parser = _build_parser()
args = parser.parse_args()


# Run configuration: command-line values copied into module globals, plus a
# few fixed settings used by run_experiment().
use_all_cpus = args.use_all_cpus
thread_pool = args.thread_pool
device = args.device
save = args.save
display_results = args.display_results
gamma = args.gamma
print_results = args.print_results
# Honour --uniform_random (it was previously parsed but hard-coded to False);
# the default remains False when the flag is not given.
uniform_random = args.uniform_random
# Fixed KServerEnv options passed when building the GCN agents' environment.
request_same_node = True
arrival_rates = True
seq_req = True
# Burn-in settings; not referenced elsewhere in this script.
burn_in = False
burn_in_period = 100
var_distance = args.var_distance
tr_dist = args.tr_dist



# Agent classes to evaluate in this run.
methods = [GCN_RL]
# Graph sizes (node counts). Other sets used previously include [74],
# [9, 25, 36, 64], [25, 49, 64, 81], [1024] and [9, 16, ..., 100, 1024].
number_nodes = [25]
hidden_channels_list = [128]
# Graph families to test. Other sets used previously: grid_gre_50..54,
# bn_grid_gre_50..54, grid_dir_60..64, tree_50..54, plane_50..54, ['EM'],
# and ['bn_grid_gre_51', 'psn_grid_gre_51', 'lgnm_grid_gre_54'].
graph_types = ['grid_dir_50', 'grid_dir_51', 'grid_dir_52', 'grid_dir_53', 'grid_dir_54']
seeds = [42]

# One hyperparameter tuple per experiment, consumed by run_experiment().
# The tuple layout depends on the agent class:
#   DQNAgent_10 -> (method, num_nodes, graph_type, seed)
#   GCN_RL      -> (method, num_nodes, graph_type, hidden_channels)
#   any other   -> (method, num_nodes, graph_type)
# Qtable is only run on the first two entries of number_nodes.
args_list = []
for method in methods:
    if method in (DQNAgent_10,):
        args_list.extend(
            (method, nodes, gtype, run_seed)
            for nodes in number_nodes
            for gtype in graph_types
            for run_seed in seeds
        )
    elif method in (GCN_RL,):
        args_list.extend(
            (method, nodes, gtype, width)
            for nodes in number_nodes
            for gtype in graph_types
            for width in hidden_channels_list
        )
    elif method == Qtable:
        args_list.extend(
            (method, nodes, gtype)
            for nodes in number_nodes[:2]
            for gtype in graph_types
        )
    else:
        args_list.extend(
            (method, nodes, gtype)
            for nodes in number_nodes
            for gtype in graph_types
        )







# Pretrained GCN checkpoints, matched against the environment's graph_type.
# The table is scanned in order and the FIRST key that is a substring of the
# graph_type wins, so keep more specific keys ahead of shorter ones that could
# also match.
# NOTE(review): 'tree' loads a grid_gre checkpoint and 'SF'/'EM' load a tree
# checkpoint -- kept exactly as in the original script; confirm intended.
# NOTE(review): the grid_dir checkpoint name says nl4 while --nl defaults to 12;
# presumably the script is run with --nl 4 for grid_dir graphs -- confirm.
_HOME_MODELS_DIR = '/home/ifb5104/K_server_RL/results/single_model_results/uniformFalse/models/'
_PRETRAINED_CHECKPOINTS = (
    ('bn_grid_gre', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_bn_grid_gre_URFalse_CPFalse_lr0.001_nl12_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGFalse_bs512_lr0.001_nl12.pth'),
    ('psn_grid_gre', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_psn_grid_gre_URFalse_CPFalse_lr0.001_nl12_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGFalse_bs512_lr0.001_nl12.pth'),
    ('lgnm_grid_gre', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_lgnm_grid_gre_URFalse_CPFalse_lr0.001_nl12_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGFalse_bs512_lr0.001_nl12.pth'),
    ('tree', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_grid_gre_URFalse_CPFalse_lr0.001_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99.pth'),
    ('plane', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_plane_URFalse_CPFalse_lr0.001_varprepTrue_rqsmndTrue_ARTrue.pth'),
    ('SF', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_tree_URFalse_CPFalse_lr0.001_varprepTrue30_rqsmndTrue_ARTrue_gamma0.95.pth'),
    ('EM', _HOME_MODELS_DIR + 'model_GenModelAll_hch128_nl12_tree_URFalse_CPFalse_lr0.001_varprepTrue30_rqsmndTrue_ARTrue_gamma0.95.pth'),
    ('grid_dir', 'results/single_model_results/uniformFalse/models/model_GenModelAllSL_hch128_nl4_grid_dir_sl_URFalse_CPTrue_lr0.001_nl4_varprepTrue30_rqsmndTrue_ARTrue_gamma0.99_BNTrue_VDFalse_DGTrue_bs128_lr0.001_nl4.pth'),
)


def _agent_name_gen(agent_name, graph_type):
    """Return the results sub-directory name for ``agent_name`` on ``graph_type``."""
    if 'bn_grid_gre' in graph_type:
        return f'{agent_name}_gen_bn'
    if 'psn_grid_gre' in graph_type:
        return f'{agent_name}_gen_psn'
    if 'lgnm_grid_gre' in graph_type:
        return f'{agent_name}_gen_lgnm'
    return f'{agent_name}_gen_sl'


def _load_pretrained_weights(agent):
    """Load the first checkpoint whose key occurs in ``agent.env.graph_type``.

    Prints a warning instead of silently evaluating an untrained network when
    no key matches.
    """
    env_graph_type = agent.env.graph_type
    for key, path in _PRETRAINED_CHECKPOINTS:
        if key in env_graph_type:
            # map_location lets a checkpoint saved on another device (e.g. a GPU)
            # be loaded when running with --device cpu; load_state_dict copies
            # the tensors into the network's own parameters.
            agent.q_network.load_state_dict(torch.load(path, map_location=device))
            return
    print(f'WARNING: no pretrained checkpoint matches graph_type {env_graph_type!r}; '
          'evaluating an untrained network.')


def _build_env_and_agent(Agent, num_nodes, num_servers, graph_type, seed,
                         hidden_channels, num_sequences):
    """Construct the KServerEnv and the agent under test for one experiment."""
    if Agent in [DQNAgent_10]:
        env = KServerEnv(num_nodes, num_servers, batch_size=512, graph_type=graph_type,
                         device=device, uniform_random=uniform_random)
        agent = Agent(env, seed=seed)
    elif Agent in [GCN_RL, GCN_SL]:
        print('start')
        # --test_by_one evaluates the sequences one at a time; otherwise all at once.
        test_bs = 1 if args.test_by_one else num_sequences
        env = KServerEnv(num_nodes, num_servers, batch_size=test_bs, graph_type=graph_type,
                         device=device, uniform_random=uniform_random,
                         request_same_node=request_same_node, arrival_rates=arrival_rates,
                         seq_req=seq_req)
        agent = Agent(env, hidden_channels=hidden_channels, shared_weights=True, gamma=gamma,
                      gen=True, use_batch_norm=True, num_layers=args.nl,
                      dir_graph=args.dir_graph)
        _load_pretrained_weights(agent)
    elif Agent == GCN_RL_GEN:
        env = KServerEnv(num_nodes, num_servers, batch_size=512, graph_type=graph_type,
                         general_model=True, device=device, uniform_random=uniform_random)
        agent = Agent(env, hidden_channels=hidden_channels, shared_weights=False, gamma=gamma)
    elif Agent == BalancePolicy:
        env = KServerEnv(num_nodes, num_servers, batch_size=1, graph_type=graph_type,
                         device=device, uniform_random=uniform_random, balanced_algorithm=True)
        agent = Agent(env)
    else:
        env = KServerEnv(num_nodes, num_servers, batch_size=1, graph_type=graph_type,
                         device=device, uniform_random=uniform_random)
        agent = Agent(env)
    return env, agent


def run_experiment(hyperparams):
    """Evaluate one agent on one graph configuration and optionally save results.

    Args:
        hyperparams: tuple ``(Agent, num_nodes, graph_type)``, optionally
            followed by a 4th element. That element is the seed for
            ``DQNAgent_10`` and ``hidden_channels`` for the other agents.

    The run is skipped if its summary CSV already exists. With ``--save`` a
    raw CSV and a summary CSV (one row per request sequence) are written under
    ``results/gen_testing/VD{var_distance}/{agent_name_gen}/``.
    """
    start_time = time.time()
    Agent, num_nodes, graph_type = hyperparams[0], hyperparams[1], hyperparams[2]
    num_sequences = 10
    num_servers = 4

    # Slot 3 is overloaded by the args_list builder: seed for DQNAgent_10,
    # hidden_channels for GCN_RL; other agents receive 3-tuples.
    if Agent in [DQNAgent_10]:
        seed = hyperparams[3]
        hidden_channels = None
    else:
        seed = 42
        hidden_channels = hyperparams[3] if len(hyperparams) > 3 else None

    # Bare class name, e.g. 'GCN_RL'.
    agent_name = Agent.__name__
    agent_name_gen = _agent_name_gen(agent_name, graph_type)

    results_dir = f'results/gen_testing/VD{var_distance}/{agent_name_gen}'
    file_stem = f'results_{graph_type}_{agent_name}_{num_nodes}_{num_servers}_gamma{gamma}'
    output_file_name = f'{results_dir}/{file_stem}.csv'
    output_file_name_raw = f'{results_dir}/raw_results/{file_stem}_raw.csv'

    # The summary CSV doubles as the "already done" marker.
    if os.path.exists(output_file_name):
        print(f'Skipping experiment for {hyperparams}, as one or more result files exist.')
        return

    env, agent = _build_env_and_agent(Agent, num_nodes, num_servers, graph_type, seed,
                                      hidden_channels, num_sequences)

    print(f'Experiment {graph_type} {num_nodes}_{num_servers} {agent_name} {num_nodes} seed{seed} hidden_channels{hidden_channels} started')

    requests, state = generate_requests(env, seed, 4000, num_sequences=num_sequences)
    if args.test_by_one:
        estimates, q1s, q3s, raw_results = [], [], [], []
        for i in range(num_sequences):
            print(f'sequence{i} started')
            estimate, q1, q3, raw_result = agent.estimate_seq(state[i].unsqueeze(0), requests[i].unsqueeze(0))
            estimates.append(estimate)
            q1s.append(q1)
            q3s.append(q3)
            raw_results.append(raw_result)
            print(f'sequence{i} finished')
    else:
        # Element 3 of estimate_seq's result holds per-step values; summing over
        # dim 1 yields one total per sequence (presumably shaped
        # (num_sequences, steps) -- TODO confirm).
        estimates = agent.estimate_seq(state, requests)[3]
        estimate_sums = torch.sum(estimates, dim=1)

    if save:
        # exist_ok avoids a FileExistsError race when several pool workers
        # create the same directory concurrently.
        os.makedirs(f'{results_dir}/models', exist_ok=True)
        os.makedirs(f'{results_dir}/raw_results', exist_ok=True)

        header = ['graph_type', 'agent', 'gamma', 'num_nodes', 'seed', 'hidden_channels', 'estimate']
        # Raw results are written first and the summary (skip marker) last, so
        # an interrupted run is not mistaken for a finished one.
        if not args.test_by_one:
            with open(output_file_name_raw, 'w', newline='') as f:
                writer = csv.writer(f)
                for i in range(num_sequences):
                    # The whole per-step list is stored as one quoted cell.
                    writer.writerow([str(estimates[i].tolist())])
            with open(output_file_name, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(header)
                for i in range(num_sequences):
                    writer.writerow([graph_type, agent_name, gamma, num_nodes, seed, hidden_channels,
                                     round(estimate_sums[i].item(), 3)])
        else:
            with open(output_file_name_raw, 'w', newline='') as f:
                writer = csv.writer(f)
                for raw_result in raw_results:
                    writer.writerow(raw_result.tolist())
            with open(output_file_name, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(header + ['q1', 'q3'])
                for i in range(num_sequences):
                    writer.writerow([graph_type, agent_name, gamma, num_nodes, seed, hidden_channels,
                                     round(estimates[i].item(), 3), round(q1s[i].item(), 3),
                                     round(q3s[i].item(), 3)])

    elapsed_time = time.time() - start_time
    print(f'Experiment_{graph_type}_{agent_name}_{gamma}_{num_nodes}_seed{seed}_hidden_channels{hidden_channels} took {round(elapsed_time, 3)} seconds to finish')

    

# Startup banner summarising the run configuration (one setting per line).
print(
    "Starting the experiments",
    f"Number of experiments: {len(args_list)}",
    f"Device: {device}",
    f"Save: {save}",
    f"Uniform Random: {uniform_random}",
    f"Display Results: {display_results}",
    sep="\n",
)

if __name__ == "__main__":
    # Upper bound on useful workers: one per queued experiment, at most one per core.
    num_processes = min(cpu_count(), len(args_list))

    # Estimate idle cores from overall CPU utilisation sampled over 1 second.
    cpu_usage_percent = psutil.cpu_percent(interval=1)
    free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
    print("Number of free CPUs:", free_cpus)

    start_time_total = time.time()

    if args.parallel:
        # --use_all_cpus sizes the pool to the idle cores (at least 1, at most one
        # per experiment); otherwise keep the previous fixed size of 2.
        pool_size = max(1, min(free_cpus, num_processes)) if use_all_cpus else 2
        # --thread_pool runs experiments on threads of this process instead of
        # separate worker processes.
        pool_cls = ThreadPool if thread_pool else Pool
        with pool_cls(processes=pool_size) as pool:
            pool.map(run_experiment, args_list)
    else:
        for hyperparams in args_list:
            run_experiment(hyperparams)

    elapsed_time_total = time.time() - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))