import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from concurrent.futures import thread
from json.tool import main
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.GCN import GCN, Net
import csv

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from multiprocessing import Pool, cpu_count
from multiprocessing.pool import ThreadPool
import numpy as np
import psutil

import pandas as pd
import matplotlib.pyplot as plt

import glob

import argparse

import time

import os

def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"`` and ``"0"``) would be parsed as
    ``True``. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options for the experiment sweep.
parser = argparse.ArgumentParser()
parser.add_argument("--test_one", type=_str_to_bool, default=True, help="Whether to do experiments for all arguments")
parser.add_argument("--explr", type=int, default=1, help="Exploration steps multiplied by 1000")
parser.add_argument("--explt", type=int, default=1, help="Exploitation steps multiplied by 1000")
parser.add_argument("--use_all_cpus", type=int, default=0, help="Whether to use all free cpus")
parser.add_argument("--thread_pool", type=int, default=0, help="Whether to use ThreadPool instead of Pool")
parser.add_argument("--device", type=str, default="cpu", help="Which device to use")
parser.add_argument("--save", type=_str_to_bool, default=False, help="Whether to save the model")
args = parser.parse_args()

# Module-level aliases read by the rest of the script.
test_one = args.test_one
use_all_cpus = args.use_all_cpus
exploration = args.explr
exploitation = args.explt
thread_pool = args.thread_pool
device = args.device
save = args.save










def _parse_model_path(model_path):
    """Extract ``(graph_type, num_nodes)`` from a model checkpoint path.

    The file name is expected to look like
    ``model_<graph_type>_DQNAgent_10_<num_nodes>_seed...``. Only the final
    path component is inspected, so a ``model_`` substring in a parent
    directory cannot confuse the parser.

    Raises:
        ValueError: if a marker is missing or ``<num_nodes>`` is not an integer.
    """
    file_name = os.path.basename(model_path)
    _, found_prefix, remainder = file_name.partition('model_')
    graph_type, found_agent, remainder = remainder.partition('_DQNAgent_10_')
    nodes_text, found_seed, _ = remainder.partition('_seed')
    if not (found_prefix and found_agent and found_seed):
        raise ValueError(
            f'Cannot parse graph type and node count from model path {model_path!r}'
        )
    try:
        return graph_type, int(nodes_text)
    except ValueError as err:
        raise ValueError(
            f'Node count {nodes_text!r} in model path {model_path!r} is not an integer'
        ) from err


def run_experiment(args):
    """Train and evaluate one GCN_SL agent and write its result row to a CSV file.

    Args:
        args: ``(model_path, hidden_channels)`` pair. ``model_path`` is a
            checkpoint path whose file name encodes the graph type and node
            count; it is also passed unchanged to the agent constructor.
            ``hidden_channels`` is forwarded to the agent as a keyword argument.

    Returns:
        None. Experiments on graphs with more than 30 nodes are skipped, as are
        experiments whose result CSV already exists.

    Raises:
        ValueError: if ``model_path`` does not follow the expected naming scheme.

    Reads the module-level settings ``device``, ``exploration``,
    ``exploitation`` and ``save``.
    """
    import traceback

    hyperparams = args[0]
    graph_type, num_nodes = _parse_model_path(hyperparams)

    # Larger graphs are intentionally left out of this sweep.
    if num_nodes > 30:
        return

    hidden_channels = args[1]
    start_time = time.time()

    agent_name = 'GCN_SL2'
    experiment_tag = f'{graph_type}_{agent_name}_{num_nodes}_hidden_channels{hidden_channels}'
    output_file_name = f'gcn_sl_results3/results_{experiment_tag}.csv'

    # Check for an existing result before building the environment and agent,
    # so re-running a finished sweep does not pay the construction cost.
    if os.path.exists(output_file_name):
        print(f'Experiment_{experiment_tag} has already been done')
        return

    # Create output locations up front so a missing directory cannot discard
    # the results after training has already completed.
    os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
    model_file_name = (
        f'gcn_sl_results3/models/'
        f'model_{agent_name}_{graph_type}__{num_nodes}_hidden_channels{hidden_channels}.pth'
    )
    if save:
        os.makedirs(os.path.dirname(model_file_name), exist_ok=True)

    num_servers = round(num_nodes / 6)
    env = KServerEnv(num_nodes, num_servers, batch_size=256, graph_type=graph_type, device=device)
    agent = GCN_SL(env, hyperparams, hidden_channels=hidden_channels)

    print(f'Started experiment_{experiment_tag}')
    try:
        agent.optimize(exploration)
        agent.optimize(exploitation, epsilon=0)
    except Exception:
        # Best effort: a failed training run is still evaluated below, but the
        # failure is reported instead of silently ignored. KeyboardInterrupt
        # and SystemExit are not caught, so the sweep can still be aborted.
        print(f'Training for experiment_{experiment_tag} failed; evaluating the agent as-is:')
        traceback.print_exc()

    estimate, q1, q3 = agent.estimate(40)

    if save:
        torch.save(agent.q_network.state_dict(), model_file_name)

    with open(output_file_name, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['graph_type', 'agent', 'num_nodes', 'hidden_channels', 'estimate', 'q1', 'q3'])
        writer.writerow([
            graph_type,
            agent_name,
            num_nodes,
            hidden_channels,
            round(estimate.item(), 3),
            round(q1.item(), 3),
            round(q3.item(), 3),
        ])

    elapsed_time = time.time() - start_time
    print(f'Experiment_{experiment_tag} took {round(elapsed_time, 3)} seconds to finish')

    
# Model checkpoint files that define the sweep. glob returns matches in
# arbitrary order, so sort them to make the run order reproducible.
model_list = sorted(glob.glob('exp_results_check2/models/model*.pth'))
hidden_channels_list = [2, 4, 8, 16, 32]

# One experiment per (checkpoint, hidden-channel width) combination.
args_list = [
    (model_path, width)
    for model_path in model_list
    for width in hidden_channels_list
]

print("Starting the experiments")
print("Exploration steps multiplied by 1000:", args.explr)
print("Exploitation steps multiplied by 1000:", args.explt)
print("Number of experiments:", len(args_list))
print("Device:", device)

if __name__ == "__main__":
    # Run every experiment in args_list sequentially and report the total
    # wall-clock time.
    # NOTE: the --test_one flag is parsed but not consulted here; the full
    # list is always executed.
    start_time_total = time.time()

    # The loop variable is deliberately not named `args`, so the module-level
    # argparse namespace of that name is not overwritten by a tuple.
    for experiment_args in args_list:
        run_experiment(experiment_args)

    end_time_total = time.time()
    elapsed_time_total = end_time_total - start_time_total
    print('Total time elapsed in seconds:', round(elapsed_time_total, 3))

    # while True:
    #     cpu_usage_percent = psutil.cpu_percent(interval=1)
    #     free_cpus = cpu_count() - int(cpu_count() * (cpu_usage_percent / 100))
    #     if free_cpus >= num_processes:
    #         break
    #     print(f'CPU usage is {cpu_usage_percent}%, waiting for available CPUs...')

    # with Pool(processes=20) as pool:
    #     pool.map(run_experiment, args_list2)

    # with Pool(processes=(16)) as pool:
    #    pool.map(run_experiment, args_list)

    
 






    