import sys
import os
import numpy as np
import csv
import torch
import psutil
from multiprocessing import cpu_count
import itertools
import argparse
import pandas as pd

# Make the project root importable so the local modules imported below resolve.
sys.path.append('/data/home/ifb5104/K_server_RL')

# Import the necessary classes and functions
from KServerEnv import KServerEnv 
from Policies.DQN_10 import DQNAgent_10
from Policies.DQN_Lins import DQNAgent_Lins
from Policies.Qtable_MP import Qtable
from Policies.Qtable_WQL import Qtable_WQL
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.Harmonic import HarmonicPolicy
from Policies.Balance import BalancePolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.GCN_RL_SL import GCN_RL_SL
from Policies.GCN_RL_GEN_ALL import GCN_RL_GEN_ALL
from experiments.train_agent import train
from generate_requests import generate_requests

# Command-line interface. Every (flag, options) pair below is registered via
# ``parser.add_argument(flag, **options)`` in the listed order. Several of these
# values (agent, graph_type, num_nodes, gamma, lr, ...) are overwritten per
# experiment from the parameter-combination CSV inside the experiment loop.
parser = argparse.ArgumentParser()

for _flag, _options in (
    ("--save", dict(action='store_true', help="Whether to save the model and results")),
    ("--num_steps", dict(type=int, default=1, help="Number steps multiplied by 1000")),
    ("--num_nodes", dict(type=int, default=9, help="Number of nodes in the graph")),
    ("--num_servers", dict(type=int, default=None, help="Number of servers in the graph")),
    # --bs and --hch are not read in this file; batch_size / hidden_channels
    # are taken from the combination CSV instead.
    ("--bs", dict(type=int, default=128, help="Batch Size")),
    ("--hch", dict(type=int, default=128, help="Hidden Channels")),
    ("--nl", dict(type=int, default=12, help="Number of Layers")),
    ("--display", dict(action='store_true', help="Display results in Neptune")),
    ("--lr", dict(type=float, default=0.001, help="Learning Rate")),
    ("--gamma", dict(type=float, default=0.99, help="Temporal Discount")),
    ("--print", dict(action='store_true', help="Whether to print the experiments results during the training")),
    ("--uniform_random", dict(action='store_true', help="If probabilities of arrival of requests are uniform")),
    ("--parallel", dict(action='store_true', help="If to run the processes in parallel")),
    ("--var_distance", dict(action='store_true', help="If to set a varying distance between nodes")),
    ("--tr_dist", dict(action='store_true', help="If to set a trainable distance")),
    ("--tr_att", dict(action='store_true', help="If to use an attention mechanism for trainable distance")),
    ("--agent", dict(type=str, default='greedy', help="Which agent")),
    ("--graph_type", dict(type=str, default='grid_gre_50', help="Which graph type")),
    ("--gmgt", dict(type=str, default='grid_gre', help="Which graph type for general model")),
    ("--device", dict(type=str, default="cpu", help="Which device to use")),
    ("--dir_graph", dict(action='store_true', help="If it is a directed graph")),
    ("--test_seq", dict(action='store_true', help="If test a sequence")),
    ("--optimize", dict(action='store_true', help="Optimize the algorithm")),
    ("--print_seq", dict(action='store_true', help="Print where it is going to be tested")),
    # store_false: args.bn is True unless --bn is passed, which disables it.
    ("--bn", dict(action='store_false', help="Don't use batch normalization")),
    ("--cp", dict(action='store_true', help="Constant Probability")),
):
    parser.add_argument(_flag, **_options)
del _flag, _options

args = parser.parse_args()

# Registry of selectable policies: the value of --agent (or of the 'agent'
# column in the combination CSV) -> policy class. Keys are case-sensitive.
agent_mapping = dict(
    dqn=DQNAgent_10,
    dqn_lins=DQNAgent_Lins,
    qtable=Qtable,
    random=RandomPolicy,
    greedy=GreedyPolicy,
    gcn_sl=GCN_SL,
    GCN_RL=GCN_RL,
    qtable_wql=Qtable_WQL,
    gcn_rl_sl=GCN_RL_SL,
    GCN_RL_GEN_ALL=GCN_RL_GEN_ALL,
    balance=BalancePolicy,
    harmonic=HarmonicPolicy,
)

# Metric / bookkeeping columns of the result CSVs (they match the header written
# by the --save branch of the experiment loop) that are not experiment
# parameters.
# NOTE(review): this list is not referenced anywhere else in this file.
columns_to_exclude = [
    'mean_estimate_div_optimal', 'mean_greedy_div_optimal', 'tr_dist',
    'sd_estimate_div_optimal', 'num_params', 'sd_greedy_div_optimal',
]

# Read the experiment grid: one row per parameter combination. The rows are used
# as-is (no de-duplication is performed here) and turned into a list of
# {column_name: value} dicts for the experiment loop.
unique_combinations = pd.read_csv(
    '/home/ifb5104/K_server_RL/experiments/parameter_combinations.csv'
)
combinations = unique_combinations.to_dict(orient='records')




def ep_estimates_gr_opton(Agent, env, num_sequences, save=True, *,
                          state=None, requests=None, sequence_len=None):
    """Evaluate a (baseline) policy one request sequence at a time.

    Instantiates ``Agent(env)``. For each of the first ``num_sequences``
    sequences, it calls ``agent.estimate_seq`` on that sequence alone, as a
    batch of one. It records element ``[0]`` of the result divided by
    ``sequence_len``.

    Args:
        Agent: Policy class. It is instantiated as ``Agent(env)`` and must
            provide ``estimate_seq(state, requests)``.
        env: Environment handed to ``Agent``. Its ``num_nodes`` and
            ``num_servers`` appear in the output file name.
        num_sequences: Number of leading entries of ``state`` / ``requests``
            to evaluate.
        save: When True, the per-sequence values are written with
            ``np.savetxt`` to
            ``results/gen_testing/VD<var_distance>/<AgentName>/results_<graph_type>_<AgentName>_<num_nodes>_<num_servers>.csv``,
            using the module-level ``args`` for ``var_distance`` and
            ``graph_type``.
        state, requests, sequence_len: Evaluation inputs. When omitted, the
            module-level variables of the same names are used (the experiment
            loop sets them before calling this function).

    Returns:
        list[float]: one value (estimate / sequence_len) per sequence.
    """
    # Backward-compatible fallback to the globals the script has always used.
    module_globals = globals()
    if state is None:
        state = module_globals['state']
    if requests is None:
        requests = module_globals['requests']
    if sequence_len is None:
        sequence_len = module_globals['sequence_len']

    agent = Agent(env)
    # Bare class name, e.g. 'GreedyPolicy' (what parsing str(Agent) yielded).
    agent_name = Agent.__name__
    estimates = []
    for i in range(num_sequences):
        print(f'sequence{i} started')
        # unsqueeze(0) turns sequence i into a batch of size 1.
        estimate = agent.estimate_seq(state[i].unsqueeze(0), requests[i].unsqueeze(0))[0]
        estimates.append((estimate / sequence_len).item())

    if save:
        out_dir = f'results/gen_testing/VD{args.var_distance}/{agent_name}'
        # exist_ok replaces the racy "check, then create" sequence.
        os.makedirs(out_dir, exist_ok=True)
        output_file_name = f'{out_dir}/results_{args.graph_type}_{agent_name}_{env.num_nodes}_{env.num_servers}.csv'
        np.savetxt(output_file_name, np.array(estimates), delimiter=',')

    return estimates

# Run one experiment per parameter combination. Each iteration:
#   1. copies the combination onto ``args``;
#   2. builds a fresh environment;
#   3. optionally trains the agent (--optimize);
#   4. optionally evaluates it on generated request sequences against the
#      Greedy and optimal-online (Qtable_WQL) baselines (--test_seq).
for params in combinations:
    # Overwrite the CLI values with this combination; everything below reads its
    # configuration from ``args``.
    args.agent = params['agent']
    args.batch_size = params['batch_size']
    args.num_layers = params['num_layers']
    # Agent construction and the checkpoint / result file names read the layer
    # count from ``args.nl`` (the --nl flag). Mirror the swept value there, or
    # every combination would silently use the --nl default.
    args.nl = args.num_layers
    args.graph_type = params['graph_type']
    args.num_nodes = params['num_nodes']
    args.num_servers = params['num_servers']
    # var_distance is not read from the combination; the --var_distance flag is
    # used for every experiment.
    args.dir_graph = params['dir_graph']
    args.gamma = params['gamma']
    args.lr = params['lr']
    args.hidden_channels = params['hidden_channels']
    args.cp = params['CP']
    args.tr_att = params['tr_att']

    print(f"Running experiment with combination: {params}")

    Agent = agent_mapping[args.agent]

    # GCN_RL_GEN_ALL uses a graph-family name (for its constructor and its
    # checkpoint file name): the graph type minus its last three characters
    # (e.g. 'grid_gre_50' -> 'grid_gre'), unless --gmgt selects 'dir_check' or
    # 'dir_check_1'.
    if Agent is GCN_RL_GEN_ALL:
        if args.gmgt not in ['dir_check', 'dir_check_1']:
            general_model_gt = args.graph_type[:-3]
        else:
            general_model_gt = args.gmgt

    # Fresh environment for every experiment.
    env = KServerEnv(num_nodes=args.num_nodes, num_servers=args.num_servers, batch_size=args.batch_size,
                     graph_type=args.graph_type, device=args.device, uniform_random=args.uniform_random,
                     var_distance=args.var_distance, request_same_node=True, arrival_rates=True)

    # Agents other than the DQN/GCN ones use batch size 1.
    # NOTE(review): this runs after ``env`` was built, so it only affects later
    # reads of args.batch_size, not this environment.
    if args.agent not in ["dqn", "dqn_lins", "gcn_sl", "gcn_rl", "gcn_rl_gen", "gcn_rl_sl", "GCN_RL", "GCN_RL_GEN_ALL"]:
        args.batch_size = 1

    if args.optimize:
        if Agent in [GCN_RL, GCN_RL_SL]:
            agent = Agent(env, hidden_channels=args.hidden_channels, num_layers=args.nl, lr=args.lr,
                          shared_weights=False, gamma=args.gamma, tr_dist=args.tr_dist, tr_att=args.tr_att,
                          dir_graph=args.dir_graph, use_batch_norm=args.bn)
            train(agent, args.num_steps, print_results=args.print, save_results=args.save, display_results=args.display)
        elif Agent is GCN_RL_GEN_ALL:
            # Not given ``env``; the environment settings are passed explicitly.
            agent = GCN_RL_GEN_ALL(hidden_channels=args.hidden_channels, num_layers=args.nl, lr=args.lr,
                                   shared_weights=False, general_model_gt=general_model_gt, batch_size=args.batch_size,
                                   uniform_random=args.uniform_random, constant_probability=args.cp, arrival_rates=True,
                                   request_same_node=True, gamma=args.gamma, var_pr_ep=True, var_pr_ep_steps=30,
                                   use_batch_norm=args.bn, var_distance=args.var_distance, dir_graph=args.dir_graph,
                                   tr_dist=args.tr_dist, tr_att=args.tr_att, num_nodes=args.num_nodes, num_servers=args.num_servers)
            agent.optimize(args.num_steps, print_results=args.print, save_model=args.save, save_results=args.save,
                           display_results=args.display)
        elif Agent in [DQNAgent_10, DQNAgent_Lins]:
            agent = Agent(env, seed=42, lr=0.01, gamma=args.gamma)
            agent.optimize(args.num_steps, print_results=args.print, save_results=args.save, display_results=args.display)
        elif Agent in (GreedyPolicy, RandomPolicy):
            agent = Agent(env)
            print(f'{args.agent} estimate is {agent.estimate(100, print_results=args.print)[0].item()}')
        elif Agent is Qtable:
            agent = Qtable(env, lr=0.1, gamma=args.gamma)
            agent.optimize(args.num_steps, print_results=args.print, save_results=args.save, display_results=args.display)
        elif Agent is Qtable_WQL:
            agent = Qtable_WQL(env, lr=0.1, gamma=args.gamma)
            agent.compare_policies()

    if args.test_seq:
        num_sequences = 10
        sequence_len = 4000
        # Fixed seed (42), presumably so every agent is evaluated on the same
        # sequences. ``requests``, ``state`` and ``sequence_len`` are module
        # globals that ep_estimates_gr_opton also reads.
        requests, state = generate_requests(env, 42, sequence_len, num_sequences=num_sequences)

        print(args.batch_size)

        if args.print_seq:
            # Most frequent value of each request sequence, printed next to the
            # index of the largest entry of env.probabilities for comparison.
            most_frequent_elements = [torch.mode(row).values.item() for row in requests]
            print("The most frequent elements for each row are:", most_frequent_elements)
            index_of_largest_value = env.probabilities.index(max(env.probabilities))
            print("The index of the largest value is:", index_of_largest_value)
            print(state)

        agent_name = Agent.__name__
        if Agent in [GCN_RL, GCN_RL_SL, DQNAgent_10, DQNAgent_Lins, GCN_RL_GEN_ALL]:
            # Evaluation environment with one batch entry per test sequence.
            env = KServerEnv(num_nodes=args.num_nodes, num_servers=args.num_servers, batch_size=num_sequences,
                             graph_type=args.graph_type, device=args.device, uniform_random=args.uniform_random,
                             request_same_node=True, arrival_rates=True, var_distance=args.var_distance)

            if Agent in [DQNAgent_10, DQNAgent_Lins]:
                agent = Agent(env, seed=42, lr=0.01, gamma=args.gamma)
            elif Agent is GCN_RL_GEN_ALL:
                # The general model is built here only to count its parameters.
                # Evaluation runs a GCN_RL network into which the general
                # model's checkpoint is loaded below.
                tot_params = GCN_RL_GEN_ALL(hidden_channels=args.hidden_channels, num_layers=args.nl, lr=0.001,
                                            shared_weights=False, general_model_gt=general_model_gt, gamma=args.gamma,
                                            var_distance=args.var_distance, tr_dist=args.tr_dist, tr_att=args.tr_att,
                                            dir_graph=args.dir_graph, use_batch_norm=args.bn, constant_probability=args.cp).total_params()
                agent = GCN_RL(env, hidden_channels=args.hidden_channels, num_layers=args.nl, lr=0.001,
                               shared_weights=False, gamma=args.gamma, tr_dist=args.tr_dist, tr_att=args.tr_att,
                               dir_graph=args.dir_graph, use_batch_norm=args.bn)
            else:
                agent = Agent(env, hidden_channels=args.hidden_channels, num_layers=args.nl, lr=0.001,
                              shared_weights=False, gamma=args.gamma, tr_dist=args.tr_dist, tr_att=args.tr_att,
                              dir_graph=args.dir_graph, use_batch_norm=args.bn)

            model_dir = f'results/gen_testing/VD{args.var_distance}/{agent_name}/train_results/models'
            if Agent in [GCN_RL, GCN_RL_SL]:
                model_path = f'{model_dir}/model_{args.graph_type}_{agent_name}_{args.num_nodes}_{args.num_servers}_hidden_channels{args.hidden_channels}__gamma{args.gamma}_vd{args.var_distance}_td{args.tr_dist}_ta{args.tr_att}_seed42_DG{args.dir_graph}_bs{args.batch_size}_lr{args.lr}_nl{args.nl}.pth'
            elif Agent is GCN_RL_GEN_ALL:
                model_path = f'{model_dir}/model_{general_model_gt}_CP{args.cp}_{agent_name}_{args.num_nodes}_{args.num_servers}_hidden_channels{args.hidden_channels}__gamma{args.gamma}_vd{args.var_distance}_td{args.tr_dist}_ta{args.tr_att}_seed42_DG{args.dir_graph}_bs{args.batch_size}_lr{args.lr}_nl{args.nl}.pth'
            else:
                model_path = f'{model_dir}/model_{args.graph_type}_{agent_name}_{args.num_nodes}_{args.num_servers}_gamma{args.gamma}_vd{args.var_distance}.pth'
            # map_location lets a checkpoint saved on another device (e.g. a GPU)
            # be loaded onto --device.
            agent.q_network.load_state_dict(torch.load(model_path, map_location=args.device))

            # Load the Greedy and optimal-online baselines for these sequences,
            # computing and saving them first if they are missing.
            baseline_dir = f'results/gen_testing/VD{args.var_distance}'
            opt_file = f'{baseline_dir}/Qtable_WQL/results_{args.graph_type}_Qtable_WQL_{args.num_nodes}_{args.num_servers}.csv'
            greedy_file = f'{baseline_dir}/GreedyPolicy/results_{args.graph_type}_GreedyPolicy_{args.num_nodes}_{args.num_servers}.csv'
            try:
                opt_estimate = np.loadtxt(opt_file, delimiter=',')
                greedy_estimate = np.loadtxt(greedy_file, delimiter=',')
            except (OSError, ValueError):
                # Missing or unreadable baseline file. A bare ``except`` here
                # would also have swallowed KeyboardInterrupt and real bugs.
                print('Estimates for Greedy and Optimal Online have to be computed...')
                baseline_env = KServerEnv(num_nodes=args.num_nodes, num_servers=args.num_servers, batch_size=1, graph_type=args.graph_type,
                                          device=args.device, uniform_random=args.uniform_random, request_same_node=True,
                                          arrival_rates=True, var_distance=args.var_distance)
                # Dedicated loop variable: do not clobber ``Agent`` / ``env``.
                for baseline_agent in [GreedyPolicy, Qtable_WQL]:
                    ep_estimates_gr_opton(baseline_agent, baseline_env, num_sequences)
                opt_estimate = np.loadtxt(opt_file, delimiter=',')
                greedy_estimate = np.loadtxt(greedy_file, delimiter=',')

            print(f'Estimates for {args.agent} are starting to be computed')

            if args.agent in ['gcn_rl_gen', 'dqn', 'dqn_lins', 'gcn_rl', 'GCN_RL', 'GCN_RL_GEN_ALL']:
                # All sequences are evaluated in one batched call; element [3] of
                # the result is averaged over dim 1.
                estimates = agent.estimate_seq(state, requests)[3]
                mean_estimate = torch.mean(estimates, dim=1)
            else:
                estimates = []
                for i in range(num_sequences):
                    print(f'sequence{i} started')
                    estimate = agent.estimate_seq(state[i].unsqueeze(0), requests[i].unsqueeze(0))[0]
                    estimates.append(estimate)
                estimates = torch.stack(estimates).to(args.device)
                mean_estimate = estimates / sequence_len
            transposed_mean_estimate = mean_estimate.unsqueeze(1).transpose(0, 1)

            # Baseline arrays -> tensors of shape (1, N) on the target device.
            opt_estimate = torch.tensor(opt_estimate).to(args.device).view(1, -1)
            greedy_estimate = torch.tensor(greedy_estimate).to(args.device).view(1, -1)
            diff = transposed_mean_estimate / opt_estimate
            diff_greedy = greedy_estimate / opt_estimate

            mean_estimate_div_optimal = round(diff.mean().item(), 3)
            mean_greedy_div_optimal = round(diff_greedy.mean().item(), 3)
            sd_estimate_div_optimal = round(diff.std().item(), 3)
            sd_greedy_div_optimal = round(diff_greedy.std().item(), 3)

            print('Mean of Estimate Divided by Optimal Online:', mean_estimate_div_optimal)
            print('Mean of Greedy Divided by Optimal Online:', mean_greedy_div_optimal)
            print('SD of Estimate Divided by Optimal Online:', sd_estimate_div_optimal)
            print('SD of Greedy Divided by Optimal Online:', sd_greedy_div_optimal)

            diff = torch.round(diff, decimals=3)
            print('Estimate:', transposed_mean_estimate)
            print('Estimate Divided by Optimal Online:', diff.tolist())

            if args.save:
                # For GCN_RL_GEN_ALL the parameter count comes from the general
                # model built above; every other agent reports its own.
                if args.agent not in ['gcn_rl_gen', 'GCN_RL_GEN_ALL']:
                    tot_params = agent.total_params()
                print(tot_params)
                out_dir = f'results/gen_testing/VD{args.var_distance}/estimate_seq'
                os.makedirs(out_dir, exist_ok=True)
                # NOTE(review): tr_att is not part of this file name, so
                # combinations differing only in tr_att map to the same file.
                output_file_name = f'{out_dir}/{agent.env.graph_type}_{agent_name}_{agent.env.num_nodes}_{agent.env.num_servers}_hidden_channels{args.hidden_channels}__gamma{agent.gamma}_vd{agent.var_distance}_td{args.tr_dist}_seed42_DG{args.dir_graph}_bs{args.batch_size}_lr{args.lr}_nl{args.nl}_CP{args.cp}.csv'
                if os.path.exists(output_file_name):
                    # Existing results are never overwritten; report the skip
                    # instead of dropping this run's numbers silently.
                    print(f'{output_file_name} already exists; results not saved')
                else:
                    with open(output_file_name, 'w', newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow(['graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'hidden_channels', 'num_layers', 'batch_size','CP', 'dir_graph', 'var_distance', 'tr_dist', 'tr_att', 'num_params', 'mean_estimate_div_optimal', 'mean_greedy_div_optimal', 'sd_estimate_div_optimal', 'sd_greedy_div_optimal'])
                        writer.writerow([agent.env.graph_type, agent_name, agent.gamma, agent.lr, agent.env.num_nodes, agent.env.num_servers, args.hidden_channels, agent.num_layers, args.batch_size, args.cp, args.dir_graph, args.var_distance, args.tr_dist, args.tr_att, tot_params, mean_estimate_div_optimal, mean_greedy_div_optimal, sd_estimate_div_optimal, sd_greedy_div_optimal])
        else:
            ep_estimates_gr_opton(Agent, env, num_sequences)