import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from asyncio.proactor_events import constants
from random import uniform
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.NET import Net
from Policies.GCN_RL import GCN_RL
from Policies.GCN_RL_GEN import GCN_RL_GEN
import csv

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from multiprocessing import Pool, cpu_count
import numpy as np
import psutil

import pandas as pd
import matplotlib.pyplot as plt

import argparse

def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit converter.

    Accepts (case-insensitive) true/t/yes/y/1/on and false/f/no/n/0/off.
    Raises argparse.ArgumentTypeError for anything else so argparse reports
    a clean usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--num_nodes", type=int, default=10, help="Number of nodes")
parser.add_argument("--num_steps", type=int, default=1, help="Number steps multiplied by 1000")
parser.add_argument("--agent", type = str, default= "random", help="Agent")
parser.add_argument("--graph_type", type = str, default = "tree_1", help="Type of graph used for solving the problem")
parser.add_argument("--device", type = str, default = "cuda", help="Which device to use")
# Boolean options: `--save` alone means True; `--save True/False` is also accepted.
parser.add_argument("--save", type = _str2bool, nargs = "?", const = True, default = False, help="Whether to save the model")
parser.add_argument("--uniform_random", type = _str2bool, nargs = "?", const = True, default = False, help="If probabilities of arrival of requests are uniform")
parser.add_argument("--gamma", type = float, default = 0.99, help="Temporal Discount")
parser.add_argument("--hidden_channels", type = int, default = 128, help="Hidden channels for GCN")


args = parser.parse_args()

# Registry of selectable policies, keyed by the value passed to --agent.
agent_mapping = dict(
    dqn=DQNAgent,
    dqn_10=DQNAgent_10,
    qtable=Qtable,
    random=RandomPolicy,
    greedy=GreedyPolicy,
    gcn_sl=GCN_SL,
    gcn_rl=GCN_RL,
    gcn_rl_gen=GCN_RL_GEN,
)

# Reject unknown agent names before any environment or agent is built.
if args.agent not in agent_mapping.keys():
    raise ValueError("Invalid agent name provided")


# These agents run with a batch of 512; every other agent uses a batch of 1.
# batch_size is forwarded to KServerEnv when the environment is created.
_BATCHED_AGENTS = ("dqn", "dqn_10", "gcn_sl", "gcn_rl", "gcn_rl_gen")
batch_size = 512 if args.agent in _BATCHED_AGENTS else 1

# Graphs with fewer than 10 nodes get 2 servers; larger graphs get round(n / 6).
num_nodes = args.num_nodes
num_servers = 2 if num_nodes < 10 else round(num_nodes / 6)

# NOTE(review): request-distribution settings are pinned here. The
# --uniform_random CLI flag is parsed but not used for this value.
uniform_random = False
constant_probability = False

# Echo the effective run configuration.
for _label, _value in (
    ("Number of nodes:", num_nodes),
    ("Number of servers:", num_servers),
    ("Number of steps multiplied by 1000:", args.num_steps),
    ("Agent:", args.agent),
    ("Graph type:", args.graph_type),
    ("Device:", args.device),
    ("Save:", args.save),
    ("Batch Size:", batch_size),
    ("Uniform random:", uniform_random),
):
    print(_label, _value)

# NOTE(review): max_cpu_usage_percent is not referenced anywhere in this script.
max_cpu_usage_percent = 50
# Sample system-wide CPU utilisation over one second (blocking), then report
# how many cores that utilisation leaves idle.
cpu_usage_percent = psutil.cpu_percent(interval=1)
_total_cpus = cpu_count()
free_cpus = _total_cpus - int(_total_cpus * (cpu_usage_percent / 100))

print("Number of free CPUs:", free_cpus)




# Remaining CLI options, bound to module-level names used by later sections.
device = args.device
graph_type = args.graph_type
save = args.save
gamma = args.gamma
hidden_channels = args.hidden_channels
num_steps = args.num_steps

# Build the environment. Only the gcn_rl_gen agent passes general_model=True;
# all other agents leave that argument at KServerEnv's default.
env_kwargs = dict(
    batch_size=batch_size,
    graph_type=graph_type,
    device=device,
    uniform_random=uniform_random,
    constant_probability=constant_probability,
)
if args.agent == "gcn_rl_gen":
    env_kwargs["general_model"] = True
env = KServerEnv(num_nodes, num_servers, **env_kwargs)

# Instantiate the chosen policy; constructor arguments differ per family.
agent_cls = agent_mapping[args.agent]
if args.agent in ("random", "greedy"):
    agent = agent_cls(env)
elif args.agent in ("gcn_rl", "gcn_rl_gen"):
    agent = agent_cls(env, shared_weights=False, gamma=gamma, hidden_channels=hidden_channels)
else:
    agent = agent_cls(env, gamma=gamma, hidden_channels=hidden_channels)

# For the SF graph, take the node/server counts reported by the environment,
# replacing the values derived from the CLI above.
if args.graph_type == "SF":
    num_nodes = agent.env.num_nodes
    num_servers = agent.env.num_servers
    print("Since it is SioxFalls graph type:")
    print("Number of nodes:", num_nodes)
    print("Number of servers:", num_servers)

# Baseline policies are not trained; every other agent is optimized.
if args.agent not in ("random", "greedy"):
    agent.optimize(num_steps, print_results=True)

# estimating
# Evaluate the agent. The estimator is called with 40; its meaning (e.g. number
# of evaluation episodes/batches) is defined by the agent class — TODO confirm.
# gcn_rl_gen additionally produces output_data, which the CSV writer iterates
# with .items().
_is_general_agent = args.agent == "gcn_rl_gen"
_estimator = agent.estimate_all if _is_general_agent else agent.estimate
estimate, q1, q3, _ = _estimator(40)
if _is_general_agent:
    output_data = agent.estimate(40)


# Persist the trained weights and a CSV summary of the evaluation results.
if save:
    import os  # only needed when saving

    results_dir = 'results/single_model_results/nonuniform'
    model_dir = f'{results_dir}/models'
    # Create the output directories up front; otherwise a fresh checkout fails
    # with FileNotFoundError only after the whole training run has finished.
    os.makedirs(model_dir, exist_ok=True)

    run_tag = f'{args.agent}_{num_nodes}_{graph_type}__gamma{gamma}_hidch{hidden_channels}_ur{uniform_random}'

    # Agents without a `q_network` attribute (presumably the random/greedy
    # baselines — verify) have no weights to checkpoint. Skip the checkpoint
    # instead of raising AttributeError before the CSV is written.
    q_network = getattr(agent, 'q_network', None)
    if q_network is not None:
        torch.save(q_network.state_dict(), f'{model_dir}/model_{run_tag}_CP{constant_probability}.pth')

    # NOTE(review): unlike the checkpoint, the CSV name carries no _CP tag, so
    # runs that differ only in constant_probability would overwrite each other.
    output_file_name = f'{results_dir}/results_{run_tag}.csv'

    is_general = args.agent == "gcn_rl_gen"
    # The summary row is labelled 'general' for the generalising agent, and
    # with the concrete graph type otherwise.
    summary_label = 'general' if is_general else graph_type
    with open(output_file_name, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['graph_type', 'agent', 'gamma', 'num_nodes', 'estimate', 'q1', 'q3'])
        writer.writerow([summary_label, args.agent, gamma, num_nodes, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])
        if is_general:
            # One row per entry of output_data; the quartile columns are left blank.
            for tree, value in output_data.items():
                writer.writerow([tree, args.agent, gamma, num_nodes, round(value.item(), 3), '', ''])
