import sys
sys.path.append('/data/home/ifb5104/K_server_RL')
from concurrent.futures import thread
from json.tool import main
import multiprocessing
from KServerEnv import KServerEnv 
from Policies.DQN import DQNAgent 
from Policies.DQN_10 import DQNAgent_10
# from Qtable import Qtable
from Policies.Qtable_MP import Qtable
from Policies.Random_Greedy import RandomPolicy, GreedyPolicy
from Policies.GCN_SL import GCN_SL
from Policies.GCN_RL import GCN_RL
from Policies.NET import Net
import csv
import os 
from Policies.GCN_RL_GEN import GCN_RL_GEN
from Policies.Balance import BalancePolicy
from Policies.Harmonic import HarmonicPolicy
from Policies.WFA import WorkFunction
from generate_requests import generate_requests

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from multiprocessing import Pool, cpu_count
from multiprocessing.pool import ThreadPool
import numpy as np
import psutil

import pandas as pd
import matplotlib.pyplot as plt

import glob

import argparse

import time

import os
import re

import neptune




def _run_tag(agent):
    """Return the hyper-parameter suffix shared by all artefact file names of one run."""
    return (
        f'{agent.env.graph_type}_{agent.class_name}_{agent.env.num_nodes}_{agent.env.num_servers}'
        f'_hidden_channels{agent.hidden_channels}__gamma{agent.gamma}_vd{agent.var_distance}'
        f'_td{agent.tr_dist}_ta{agent.tr_att}_seed{agent.seed}_DG{agent.dir_graph}'
        f'_bs{agent.batch_size}_lr{agent.lr}_nl{agent.num_layers}'
    )


def _ensure_output_dirs(agent):
    """Create every directory that train() writes into (parents are created implicitly)."""
    agent_dir = f'results/gen_testing/VD{agent.var_distance}/{agent.class_name}/train_results'
    os.makedirs(f'{agent_dir}/models', exist_ok=True)
    os.makedirs(f'{agent_dir}/raw_results', exist_ok=True)
    os.makedirs(f'results/gen_testing/VD{agent.var_distance}/train_curves', exist_ok=True)


def _result_files_exist(agent):
    """Return True when a result file for this agent configuration is already on disk."""
    try:
        file_paths = [
            f'results/gen_testing/VD{agent.var_distance}/{agent.class_name}/train_results/'
            f'results_{agent.env.graph_type}_{agent.class_name}_{agent.env.num_nodes}'
            f'_hidden_channels{agent.hidden_channels}__gamma{agent.gamma}.csv'
        ]
    except AttributeError:
        # An agent without these attributes cannot own such a file; do not
        # turn a run that used to work into a crash just for the skip check.
        return False
    return any(os.path.exists(file_path) for file_path in file_paths)


def _start_neptune_run(agent):
    """Open a Neptune run on agent.run and record the static run description."""
    # NOTE(review): a credential is committed in source. It is kept only as a
    # fallback so existing setups keep working; rotate it and rely on the
    # NEPTUNE_API_TOKEN environment variable instead.
    agent.run = neptune.init_run(
        project="iliyasbektas/kserver",
        api_token=os.environ.get(
            "NEPTUNE_API_TOKEN",
            "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJiOTRhNmFlNi0xMzU0LTRiNGUtODZmYy05ZWQyMDA4ZjJiZDQifQ==",
        ),
    )
    agent.run["agent"] = agent.class_name
    agent.run["num_nodes"] = agent.env.num_nodes
    agent.run["graph_type"] = agent.env.graph_type
    agent.run["gamma"] = agent.gamma
    agent.run["vd"] = agent.var_distance
    agent.run["td"] = agent.tr_dist
    agent.run["ta"] = agent.tr_att
    agent.run["seed"] = agent.seed


def _evaluate_progress(agent, estimate_steps, step):
    """Return (agent.estimate(...) result, mean recent reward, rounded sample count) for `step`."""
    step_estimate = agent.estimate(estimate_steps)
    # NOTE(review): agent.total_reward holds one entry per training iteration,
    # so this slice counts iterations, not individual samples - confirm the
    # intended window size.
    recent_rewards = agent.total_reward[(-estimate_steps * agent.batch_size):]
    average_reward = torch.mean(torch.cat(recent_rewards).view(-1))
    step1000 = agent.round_to_nearest_10000((step + 1) * agent.batch_size)
    return step_estimate, average_reward, step1000


def _append_train_curve_row(path, row):
    """Append `row` to the train-curve CSV at `path`, writing the header first for a new file."""
    write_header = not os.path.exists(path)
    with open(path, 'a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['step', 'graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'VD', 'tr_dist', 'tr_att', 'DG', 'batch_size', 'tot_params','average_reward', 'estimate'])
        writer.writerow(row)


def train(agent, num_steps=200, estimate_steps =20, epsilon_decay = False, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005, failsafe = False, save_results = False):
    """Run the training loop of `agent` against `agent.env`.

    Each iteration asks the agent for an action, steps the environment,
    stores the transition with `agent.remember`, calls `agent.update` and
    appends the reward to `agent.total_reward`. Every `10000 / batch_size`
    iterations the agent is evaluated with `agent.estimate` and the result is
    printed, logged to Neptune and/or written to disk depending on the flags.

    Training is skipped entirely when a result file for this agent
    configuration already exists under ``results/gen_testing``.

    Args:
        agent: Agent object exposing `env`, `batch_size`, `device`,
            `constant_probability`, `get_action`, `remember`, `update`,
            `estimate`, `total_params`, `round_to_nearest_10000` and
            `total_reward` (a list).
        num_steps: Training length; converted to iterations as
            ``int(num_steps * 1000 / agent.batch_size)``.
        estimate_steps: Evaluation length passed to `agent.estimate`, converted
            the same way as `num_steps`.
        epsilon_decay: If true, epsilon decays exponentially from
            `agent.max_epsilon` towards `agent.min_epsilon`; otherwise it is
            0.5 for the first `explr` fraction of iterations and 0.1 after.
        explr: Fraction of iterations run with epsilon 0.5 when
            `epsilon_decay` is false.
        display_results: Log metrics to a Neptune run stored on `agent.run`.
        print_results: Print metrics at every evaluation and once at the end.
        decay_rate: Rate of the exponential epsilon decay.
        failsafe: Forwarded unchanged to `agent.get_action`.
        save_results: Write the best model, the train curve CSV and the raw
            results of the latest evaluation (skipped when `agent.gen` is
            truthy).

    Returns:
        None.
    """
    if save_results:
        _ensure_output_dirs(agent)

    if _result_files_exist(agent):
        print('Skipping training, as one or more result files exist.')
        return

    print('Experiment  started')

    state = agent.env.reset()
    # Clamp to 1 so a batch size above 10000 cannot cause a modulo-by-zero.
    steps_for_display = max(1, int(10000 / agent.batch_size))
    num_steps = int(num_steps * 1000 / agent.batch_size)
    estimate_steps = int(estimate_steps * 1000 / agent.batch_size)

    if display_results:
        _start_neptune_run(agent)

    initial_limit = int(num_steps * explr)
    tot_params = agent.total_params()
    best_estimate = float('-inf')
    # Iteration index at which the metrics below were last computed.
    metrics_step = None
    step_estimate = average_reward = step1000 = None
    start_time = time.time()

    try:
        for step in range(num_steps):
            if epsilon_decay:
                epsilon = agent.min_epsilon + (agent.max_epsilon - agent.min_epsilon)*np.exp(-decay_rate*step)
            elif step < initial_limit:
                epsilon = 0.5
            else:
                epsilon = 0.1

            state_dev = state.to(agent.device)
            if agent.constant_probability:
                # Constant request distribution: the agent is queried with epsilon 0.
                action = agent.get_action(state_dev, 0, failsafe = failsafe).to(agent.device)
                next_state, reward, _ = agent.env.step(action, state_dev)
                agent.remember(
                    state_dev,
                    action,
                    reward.to(agent.device),
                    next_state.to(agent.device),
                )
            else:
                # The request probabilities are read before and after the step
                # and stored with the transition.
                node_pbs = torch.FloatTensor(agent.env.probabilities).view(1, -1).to(agent.device)
                action = agent.get_action(state_dev, epsilon, failsafe = failsafe).to(agent.device)
                next_state, reward, _ = agent.env.step(action, state_dev)
                node_pbs_next = torch.FloatTensor(agent.env.probabilities).view(1, -1).to(agent.device)
                agent.remember(
                    state_dev,
                    action,
                    reward.to(agent.device),
                    next_state.to(agent.device),
                    node_pbs,
                    node_pbs_next,
                )

            agent.update()
            state = next_state
            agent.total_reward.append(reward)

            if (step + 1) % steps_for_display != 0:
                continue

            saving_now = save_results and not agent.gen
            if not (print_results or display_results or saving_now):
                continue

            elapsed_time = time.time() - start_time
            start_time = time.time()
            # Evaluate once per display step and share the result between
            # printing, Neptune logging and saving.
            step_estimate, average_reward, step1000 = _evaluate_progress(agent, estimate_steps, step)
            metrics_step = step

            if print_results:
                print(f"Step {step1000}, Epsilon {epsilon:.2f}, Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}, Time Taken: {elapsed_time:.2f} seconds")

            if display_results:
                agent.run["Estimate"].append(step_estimate[0])
                agent.run["average_reward"].append(average_reward)

            if saving_now:
                estimate, _, _, raw_result = step_estimate
                tag = _run_tag(agent)
                agent_dir = f'results/gen_testing/VD{agent.var_distance}/{agent.class_name}/train_results'

                # Keep only the weights of the best evaluation seen so far.
                if estimate.item() > best_estimate:
                    torch.save(agent.q_network.state_dict(), f'{agent_dir}/models/model_{tag}.pth')
                    best_estimate = estimate.item()

                train_curve = f'results/gen_testing/VD{agent.var_distance}/train_curves/train_curve_{tag}_ar.csv'
                _append_train_curve_row(
                    train_curve,
                    [step1000, agent.env.graph_type, agent.class_name, agent.gamma, agent.lr, agent.env.num_nodes, agent.env.num_servers, agent.seed, agent.hidden_channels, agent.num_layers, agent.var_distance, agent.tr_dist, agent.tr_att, agent.dir_graph, agent.batch_size, tot_params, round(average_reward.item(), 3), round(estimate.item(), 3)],
                )

                # The raw results file is overwritten, so it always holds the
                # latest evaluation.
                output_file_name_raw = f'{agent_dir}/raw_results/results_{tag}_raw.csv'
                with open(output_file_name_raw, 'w', newline='') as f:
                    csv.writer(f).writerow(raw_result.tolist())
    finally:
        # Close the Neptune run even when training aborts with an exception.
        if display_results:
            agent.run.stop()

    if print_results and num_steps > 0:
        # Re-evaluate when the last iteration was not a display step, so the
        # summary never depends on stale or missing metrics.
        if metrics_step != num_steps - 1:
            step_estimate, average_reward, step1000 = _evaluate_progress(agent, estimate_steps, num_steps - 1)
        print(f"Step {step1000}, Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}")
        