
import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
import torch_geometric
from Policies.NET import Net
import neptune
from KServerEnv import KServerEnv 
import csv
from Policies.GraphDataset import MyData
from torch_geometric.utils import to_dense_adj
import os


class GCN_RL_GEN_ALL():

    def __init__(self, seed = 42, gamma=0.99, lr=0.001, num_layers = 12, memory_size=10000, device = 'cpu', hidden_channels = 128, shared_weights = False, general_model_gt = 'tree', uniform_random = False, constant_probability = True, var_pr_ep = False, var_pr_ep_steps = 40, request_same_node = False, arrival_rates = False, use_batch_norm=False, var_distance = False, tr_dist = False, tr_att = False, dir_graph = False, batch_size=512, num_nodes = None, num_servers = None):
        """Build one environment and replay memory per graph size, plus online/target Q-networks.

        Args:
            seed: seed for ``random``, ``numpy`` and ``torch``; cuDNN is also made deterministic.
            gamma: discount factor of the TD target.
            lr: Adam learning rate of the online network.
            num_layers, hidden_channels, shared_weights, use_batch_norm: forwarded to ``Net``.
            memory_size: capacity of each per-graph-size replay deque.
            device: torch device for networks and tensors.
            general_model_gt: graph family; 'SF' trains on 24 nodes, 'EM' on 74 nodes,
                anything else on 9/25/36/64 nodes (ignored when ``num_nodes`` is given).
            uniform_random: use 2 input channels (no probability channel) instead of 3.
            constant_probability: read node probabilities from the env's stored ``graph_pbs``.
            var_pr_ep, var_pr_ep_steps: periodically resample probabilities during training.
            request_same_node, batch_size: forwarded to ``KServerEnv``.
            arrival_rates: scale the probability channel by the number of nodes.
            var_distance, tr_dist, tr_att, dir_graph: graph-encoding options forwarded to ``Net``
                (``var_distance`` also to ``KServerEnv``).
            num_nodes: if given, train on this single graph size only.
            num_servers: fixed server count for every graph; when None a graph with
                ``n`` nodes gets ``round(n / 6)`` servers.

        Raises:
            ValueError: for inconsistent flag combinations (tr_dist without var_distance,
                tr_att without tr_dist, uniform_random without constant_probability) or if
                the environment is not a general model.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Reject inconsistent configurations before building any environment.
        if not var_distance and tr_dist:
            raise ValueError("Unfeasible configuration")
        if not tr_dist and tr_att:
            raise ValueError("Unfeasible configuration")
        if uniform_random and not constant_probability:
            raise ValueError("Unfeasible configuration")

        self.seed = seed
        self.hidden_channels = hidden_channels
        self.general_model_gt = general_model_gt
        self.var_distance = var_distance
        self.tr_dist = tr_dist
        self.tr_att = tr_att
        self.dir_graph = dir_graph
        self.uniform_random = uniform_random
        self.constant_probability = constant_probability
        self.use_batch_norm = use_batch_norm
        self.var_pr_ep = var_pr_ep
        self.var_pr_ep_steps = var_pr_ep_steps
        self.arrival_rates = arrival_rates
        self.device = device
        self.lr = lr
        self.gamma = gamma
        self.num_layers = num_layers

        # Graph sizes to train on: a single size when num_nodes is given,
        # otherwise a family-dependent default list.
        self.num_nodes = num_nodes
        if self.num_nodes is not None:
            self.num_nodes_list = [num_nodes]
            self.num_servers = num_servers
        elif self.general_model_gt == 'SF':
            self.num_nodes_list = [24]
        elif self.general_model_gt == 'EM':
            self.num_nodes_list = [74]
        else:
            self.num_nodes_list = [9, 25, 36, 64]

        # NOTE(review): the env's constant_probability flag is driven by var_pr_ep
        # (not by the constant_probability argument); kept as in the original.
        constant_probability_env = bool(self.var_pr_ep)

        self.env_list = []
        for graph_nodes in self.num_nodes_list:
            # Use the caller's fixed server count, else scale with THIS graph's size.
            # A separate local is required: rebinding `num_servers` here would freeze
            # the first graph's count for every later graph.
            env_servers = num_servers if num_servers is not None else round(graph_nodes / 6)
            self.env_list.append(KServerEnv(graph_nodes, env_servers, batch_size=batch_size, device = self.device, general_model=True, general_model_gt = self.general_model_gt, uniform_random = self.uniform_random, constant_probability = constant_probability_env, request_same_node = request_same_node, var_distance=self.var_distance))
        self.env_list_test = self.env_list.copy()
        self.env = self.env_list[-1]
        if not self.env.general_model:
            raise ValueError("The environment general_model argument must be True")

        self.batch_size = self.env.batch_size
        # One replay buffer per graph size, parallel to env_list.
        self.memory_list = [deque(maxlen=memory_size) for _ in self.num_nodes_list]

        # Node-feature channels: server mask + request mask (+ probabilities unless uniform_random).
        self.in_channells = 2 if self.uniform_random else 3
        net_kwargs = dict(in_channels=self.in_channells, hidden_channels = hidden_channels, num_layers = self.num_layers, shared_weights = shared_weights, use_batch_norm = use_batch_norm, var_distance = self.var_distance, tr_dist = tr_dist, tr_att = tr_att, dir_graph = self.dir_graph)
        # Online action-value network Q and target network Q' (initialised as an exact copy).
        self.q_network = Net(**net_kwargs).to(self.device)
        self.target_network = Net(**net_kwargs).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # Reward logs start with num_servers placeholder entries (skipped when
        # averaging); zero-filled so they never hold uninitialised memory.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.env.num_servers, device=self.device)
   
    def print_network_weights(self):
        network_weights = self.q_network.state_dict()
        # Print the weights of the network
        for name, param in network_weights.items():
            print(f"Layer: {name}\nWeights: {param}")

    def total_params(self):
        return sum(p.numel() for p in self.q_network.parameters())

    # def cm_var_dist(self, env):
    #     # Assuming you have a 9x9 matrix `original_matrix`


    #     # Assuming you have a `cost_matrix` of size `env.num_nodes x env.num_nodes`
    #     original_matrix = env.cost_matrix

    #     # Create a zero matrix of appropriate size
    #     fin_matrix = torch.zeros(env.num_nodes * env.batch_size, env.num_nodes * env.batch_size)

    #     # Fill the diagonals with the repeated matrix
    #     for i in range(env.batch_size):
    #         fin_matrix[i * env.num_nodes:(i + 1) * env.num_nodes, i * env.num_nodes:(i + 1) * env.num_nodes] = original_matrix

    #     return fin_matrix
    

    def node_pbs_formation(self, state_batch, env): 
        
        _, graph_indices = state_batch
        node_pbs = torch.zeros(env.batch_size, env.num_nodes).to(self.device)
        for i in range(env.batch_size):
            node_pbs[i] = env.graph_pbs[graph_indices[i]].view(1, -1)
        return node_pbs
        
    
    def observation_formation(self, state_batch, env, node_pbs = 1,  constant_probability = True): 
        """Encode a batch of k-server states as graph inputs for the Q-network.

        Node features per sample, laid out as (channels, num_nodes):
          * channel 0: 1 at every node in ``state[i][:-1]`` (server positions),
          * channel 1: 1 at node ``state[i][-1]`` (the current request),
          * channel 2 (only when ``uniform_random`` is False): per-node probabilities,
            taken from ``env.graph_pbs`` when ``constant_probability`` is True or from
            ``node_pbs[i]`` otherwise, multiplied by ``env.num_nodes`` if ``arrival_rates``.

        Graph structure is attached in both feature modes.

        Args:
            state_batch: ``(state, graph_indices)``; ``graph_indices[i]`` selects the graph
                (edges, costs, probabilities) of sample ``i`` inside ``env``.
            env: environment owning the graphs of this batch.
            node_pbs: per-sample probability rows, used only when ``constant_probability`` is False.
            constant_probability: whether to read probabilities from ``env.graph_pbs``.

        Returns:
            With ``dir_graph``: ``MyData`` built from the features transposed to
            (batch, num_nodes, channels) and the stacked dense adjacencies.
            Otherwise: the collated torch_geometric batch of per-sample ``Data`` graphs
            (carrying ``edge_weight`` when ``var_distance`` is set).
        """
        state, graph_indices = state_batch
        num_channels = 2 if self.uniform_random else 3
        X = torch.zeros(env.batch_size, num_channels, env.num_nodes).to(self.device)
        # Self-loop identity shared by every dense adjacency: build it once, not per sample.
        im = torch.eye(env.num_nodes).to(self.device) if self.dir_graph else None
        edge_index_list = []
        data_list = []
        for i in range(env.batch_size):
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1
            if not self.uniform_random:
                if constant_probability:
                    X[i][2] = env.graph_pbs[graph_indices[i]].view(1, -1)
                else:
                    X[i][2] = node_pbs[i]
                if self.arrival_rates:
                    X[i][2] = X[i][2] * env.num_nodes

            # Both feature modes need the graph (uniform_random included), otherwise
            # the collation step below has nothing to work with.
            graph_idx = graph_indices[i].item()
            if self.dir_graph:
                edge_index_list.append(self._obs_dense_adjacency(env, graph_idx, im))
            else:
                data_list.append(self._obs_sparse_graph(env, graph_idx, X[i]))

        if self.dir_graph:
            X = torch.transpose(X, -1, -2)
            edge_index = torch.stack(edge_index_list, dim=0)
            return MyData(X, edge_index)

        train_loader = torch_geometric.loader.DataLoader(data_list, batch_size=self.batch_size, shuffle=False)
        return next(iter(train_loader))

    def _obs_dense_adjacency(self, env, graph_idx, im):
        """Return the dense adjacency of graph ``graph_idx`` with ``im`` (self-loops) added."""
        if self.var_distance:
            # Cost matrix plus identity; turned into an inverse-distance weight unless
            # raw distances are used (tr_dist).
            # NOTE(review): zero off-diagonal costs would become inf here — confirm
            # cost_matrices_nn never contains them.
            adjacency = env.cost_matrices_nn[graph_idx].unsqueeze(0) + im
            if not self.tr_dist:
                adjacency = 1/adjacency
        else:
            adjacency = to_dense_adj(env.graph_edge_indices[graph_idx]).to(self.device) + im
        return adjacency.squeeze(0)

    def _obs_sparse_graph(self, env, graph_idx, x):
        """Wrap one sample's (channels, num_nodes) features and graph ``graph_idx`` in a PyG ``Data``."""
        edge_index = env.graph_edge_indices[graph_idx]
        if not self.var_distance:
            return Data(x = x.T, edge_index = edge_index)
        cost_matrix = env.cost_matrices[graph_idx]
        # Look up cost_matrix[src, dst] for every edge through row-major flat indices.
        linear_indices = edge_index[0] * cost_matrix.size(1) + edge_index[1]
        edge_weight = cost_matrix.view(-1)[linear_indices]
        return Data(x = x.T, edge_index = edge_index, edge_weight = edge_weight)

          
        # if self.dir_graph:
            
        #     X = torch.transpose(X, -1, -2)
        #     edge_index_list = []
        #     for i in range(self.batch_size):
        #         if self.var_distance: 
        #             original_tensor = 1/env.cost_matrices[graph_indices[i].item()].unsqueeze(0)
        #             edge_index = torch.where(torch.isinf(original_tensor), torch.zeros_like(original_tensor), original_tensor)
        #         else:
        #             edge_index = env.graph_edge_indices[graph_indices[i].item()]
        #             edge_index = to_dense_adj(edge_index)

        #         edge_index_list.append(edge_index.squeeze(0))
            
        #     edge_index = torch.stack(edge_index_list, dim=0)
        #     # print(edge_index.size())
    
        #     # edge_index = edge_index.repeat(self.batch_size, 1, 1)
        #     data = MyData(X, edge_index)
            
        #     return data 
        # else: 

        #     data_list = []
        #     for i in range(self.batch_size):
        #         x = X[i]
        #         if self.var_distance: 
        #             # Flatten the cost matrix into a 1D tensor
        #             edge_index = env.graph_edge_indices[graph_indices[i].item()]
        #             cost_matrix = env.cost_matrices[graph_indices[i].item()]

        #             cost_vector = cost_matrix.view(-1)

        #             # Compute the linear indices corresponding to the edges in the edge_index
        #             linear_indices = edge_index[0] * cost_matrix.size(1) + edge_index[1]

        #             # Extract edge weights from the cost vector using linear indices
        #             edge_weight = cost_vector[linear_indices]

        #             res = Data(x = x.T, edge_index = edge_index, edge_weight = edge_weight)
                
        #         else:

        #             res = Data(x = x.T, edge_index = env.graph_edge_indices[graph_indices[i].item()]) # here we have to feed a particular edge_index
                
        #         data_list.append(res)
        #         train_loader = torch_geometric.loader.DataLoader(data_list, batch_size=self.batch_size, shuffle=False)
        #         data = next(iter(train_loader))
            
        #     return data

    def get_action(self, state_batch, env, epsilon=0.1):
        # with probability epsilon 
        if random.random() < epsilon:
            state = state_batch[0].to(self.device)
            random_indices = torch.randint(low = 0, high=env.num_servers, size=(self.batch_size,)).to(self.device)
            action_batch = torch.gather(state[:, :env.num_servers], dim=1, index=random_indices.unsqueeze(1))
            if self.env.request_same_node: 
                for i in range(self.env.batch_size):
                    if state[i][-1] in state[i][:-1]:  
                        action_batch[i] = state[i][-1]
            return action_batch.long()
        else:
            with torch.no_grad():
                data = self.observation_formation(state_batch, env).to(self.device)
                qt_index = state_batch[0][:,:env.num_servers].to(self.device)
                dis_req = env.distance_request(state_batch).to(self.device)
                if self.var_distance == True and self.dir_graph == False:
                    q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req,  edge_weight=data.edge_weight)
                else:
                    q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
                q_values = q_values.reshape(env.batch_size, -1)

                # create empty tensor C of size NxM
                C = torch.zeros_like(qt_index).to(self.device)

                # loop through each row of qt_index
                for i in range(qt_index.shape[0]):
                    # use row i of qt_index as an index to extract a row from q_values
                    row_b = q_values[i, qt_index[i].long()]
                    # assign the extracted row to the corresponding row in C
                    C[i] = row_b
                    
                max_index = torch.argmax(C, dim =1)
                action_batch = torch.gather(qt_index, 1, max_index.view(-1, 1))
                if self.env.request_same_node: 
                    for i in range(self.env.batch_size):
                        if state_batch[0][i][-1] in state_batch[0][i][:-1]:  
                            action_batch[i] = state_batch[0][i][-1]

        
                return action_batch
        
    def update(self, memory, env):
        

        batch = random.sample(memory, self.batch_size)
        concatenated = [torch.cat(tensors, dim=0) for tensors in zip(*batch)]

        if self.constant_probability:
            state, action_batch, reward_batch, next_state, graph_indices = concatenated[0], concatenated[1], concatenated[2], concatenated[3], concatenated[4]
            state_batch = state, graph_indices
            next_state_batch = next_state, graph_indices

        else: 
            state, action_batch, reward_batch, next_state, graph_indices, node_pbs, node_pbs_next = concatenated[0], concatenated[1], concatenated[2], concatenated[3], concatenated[4], concatenated[5], concatenated[6]
            state_batch = state, graph_indices
            next_state_batch = next_state, graph_indices
            # print(state, action_batch, reward_batch, next_state, graph_indices, node_pbs, node_pbs_next) 

        if self.constant_probability: 
            data = self.observation_formation(state_batch, env).to(self.device)
        else: 
            data = self.observation_formation(state_batch, env, node_pbs, constant_probability = False).to(self.device)
            # print(data.x)

        # data = self.observation_formation(state_batch, env).to(self.device)
        dis_req = env.distance_request(state_batch).to(self.device)
        if self.var_distance == True and self.dir_graph == False:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req,  edge_weight=data.edge_weight)
        else:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
        q_values = q_values.reshape(env.batch_size, -1)
        q_values = q_values.gather(1, action_batch.long())
        # print(dis_req)

        next_qt_index = next_state_batch[0][:,:env.num_servers]
        # data_next = self.observation_formation(next_state_batch, env).to(self.device)
        if self.constant_probability:
            data_next = self.observation_formation(next_state_batch, env).to(self.device)
        else:
            data_next = self.observation_formation(next_state_batch, env, node_pbs_next, constant_probability= False).to(self.device) 
            # print(data_next.x)
        dis_req_next = env.distance_request(next_state_batch).to(self.device)

        if self.var_distance == True and self.dir_graph == False:
            next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next, edge_weight=data_next.edge_weight)
        else:
            next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next)
        next_q_values = next_q_values.reshape(env.batch_size, -1)
        # print(dis_req_next)
        # create empty tensor C of size NxM
        C = torch.zeros_like(next_qt_index).to(self.device)
        # loop through each row of qt_index
        for i in range(next_qt_index.shape[0]):
            # use row i of qt_index as an index to extract a row from q_values
            row_b = next_q_values[i, next_qt_index[i].long()]
            # assign the extracted row to the corresponding row in C
            C[i] = row_b

        next_q_values = C.max(1)[0].unsqueeze(1)

        expected_q_values = reward_batch.view(-1, 1) + self.gamma * next_q_values 
        loss = F.mse_loss(q_values, expected_q_values)
        # print(q_values, next_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Update target network
        self.soft_update_target_network()

    def remember(self, state, action, reward, next_state, memory, node_pbs = 1, node_pbs_next = 1):
        if self.constant_probability:
            for i in range(self.batch_size):
                memory.append((state[0][i].unsqueeze(0).to(self.device), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[0][i].unsqueeze(0).to(self.device), state[1][i].unsqueeze(0).to(self.device)))
                
        else: 
            for i in range(self.batch_size):
                memory.append((state[0][i].unsqueeze(0).to(self.device), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[0][i].unsqueeze(0).to(self.device), state[1][i].unsqueeze(0).to(self.device), node_pbs[i].unsqueeze(0), node_pbs_next[i].unsqueeze(0)))
    
    def soft_update_target_network(self, tau=0.001):

      for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)


  
    
    
    def optimize(self, num_steps=200, estimate_steps = 40, epsilon_decay = False, explr = 0.6, display_results = False, print_results = False, save_model = False, save_results = False, decay_rate = 0.0005):
        if save_model == True and self.num_nodes != None:
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}'):  
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}')
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results')  
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models')  
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/train_curves'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/train_curves')
        # state has to contain different graph with different sizes 
        # or the batch size can be equal to 1 to overcome this bottleneck 
        # there has to be an info about graph type as well 
        state_list = []
        for env in self.env_list:
            state_list.append(env.reset())
        steps_for_display = int(10000/self.batch_size)
        # steps_for_display = 1
        num_steps = int(num_steps*1000/self.batch_size)

        
        initial_percentage = explr  
        initial_limit = int(num_steps * initial_percentage)  


        if display_results:
            self.run = neptune.init_run(
                project="iliyasbektas/kserver",
                api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJiOTRhNmFlNi0xMzU0LTRiNGUtODZmYy05ZWQyMDA4ZjJiZDQifQ==",
            )  # your credentials 
            self.run["agent"] = self.class_name
            self.run["graph_type"] = self.general_model_gt
            self.run["gamma"] = self.gamma
            self.run["UR"] = self.uniform_random
            self.run["HCH"] = self.hidden_channels
            self.run["NL"] = self.num_layers
            self.run["CP"] = self.constant_probability
            self.run["lr"] = self.lr
            self.run["BN"] = self.use_batch_norm  
            self.run["var_distance"] = self.var_distance  
            self.run["CP"] = self.constant_probability
        tot_params = self.total_params()
        print('Starting Training....')   
        best_estimate_tot = float('-inf')
        for step in range(num_steps):
            new_state_list = []
            for index, state in enumerate(state_list):
                if epsilon_decay:
                    epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon)*np.exp(-decay_rate*step)
                else:
                    if step < initial_limit:
                        epsilon = 0.5
                    else: 
                        epsilon = 0.1

                # we need a graph type here due to edge index 
                if self.constant_probability:
                    action = self.get_action(state,  self.env_list[index], epsilon).to(self.device)
                    next_state, reward, _ = self.env_list[index].step(action, state)    
                    self.remember(state,
                    action.to(self.device),
                    reward.to(self.device),
                    next_state,
                    self.memory_list[index]
                    )
                else: 
                    node_pbs = self.node_pbs_formation(state, self.env_list[index])
                    action = self.get_action(state, self.env_list[index], epsilon).to(self.device)
                    next_state, reward, _ = self.env_list[index].step(action, state)
                    node_pbs_next = self.node_pbs_formation(state, self.env_list[index])

                    # print(state, next_state)
                    # for l in range(env.batch_size):
                    #     print(self.env_list[index].graph_pbs[state[1][l]].view(1, -1)*self.env_list[index].num_nodes)
                    # print(node_pbs, node_pbs_next)
                    
                    self.remember(state,
                    action.to(self.device),
                    reward.to(self.device),
                    next_state,
                    self.memory_list[index],
                    node_pbs,
                    node_pbs_next
                    )
                
                self.update(self.memory_list[index], self.env_list[index])
                
                state = next_state
                self.total_reward = torch.cat((self.total_reward, reward), 0)
                new_state_list.append(state)
                # print(self.env_list[index].graph_pbs)

                if ((step+1)  % self.var_pr_ep_steps == 0) or ((step+1) == num_steps): 
                    if self.constant_probability == False:
                        if self.var_pr_ep: 
                            if self.general_model_gt in ['SF', 'EM', 'dir_check', 'dir_check_1']: 
                                self.env_list[index].graph_pbs = [] 
                                for i in range(self.env_list[index].graph_number_gm): 
                                    self.env_list[index].graph_pbs.append(torch.FloatTensor(self.env_list[index].add_pb()).to(self.device))
                            else:
                                self.env_list[index].regenerate_graph_gm()
                                
                        # # print(self.env_list[index].graph_pbs)
                        # if self.var_pr_ep:
                        #     self.env_list[index].graph_pbs = [] 
                        #     for i in range(self.env_list[index].graph_number_gm): 
                        #         self.env_list[index].graph_pbs.append(torch.FloatTensor(self.env_list[index].add_pb()).to(self.device))

            if print_results:
                if ((step+1)  % steps_for_display == 0) :
                    estimate_tot = self.estimate_all(estimate_steps)[0]
                    average_reward = torch.mean(self.total_reward[self.env.num_servers:])
                    step1000 = self.round_to_nearest_10000((step+1)*self.batch_size)
                    print(f"Step {step1000}, Epsilon {epsilon:.2f}, Average Reward {average_reward:.2f}, Estimate all {estimate_tot:.2f}") 
                    # av_estimates_list_pr, raw_result_pr = self.estimate(estimate_steps, print_results = print_results)
                    

            if save_model:
                if ((step+1)  % steps_for_display == 0) : 
                    if self.constant_probability: 
                        # print(True)
                        if print_results == False: 
                            estimate_tot = self.estimate_all(estimate_steps)[0]
                        if estimate_tot > best_estimate_tot:
                            
                            if self.num_nodes != None: 
                                # if self.general_model_gt == 'dir_check':
                                #     torch.save(self.q_network.state_dict(), f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/model_grsubop_{self.class_name}_{self.env.num_nodes}_{self.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                                #     torch.save(self.q_network.state_dict(), f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/model_grop_{self.class_name}_{self.env.num_nodes}_{self.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                                # else:
                                    torch.save(self.q_network.state_dict(), f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/model_{self.general_model_gt}_CP{self.constant_probability}_{self.class_name}_{self.env.num_nodes}_{self.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                            else:
                                torch.save(self.q_network.state_dict(), f'results/single_model_results/uniform{self.uniform_random}/models/model_GenModelAll_hch{self.hidden_channels}_nl{self.num_layers}_{self.general_model_gt}_UR{self.uniform_random}_CP{self.constant_probability}_lr{self.lr}_nl{self.num_layers}_varprep{self.var_pr_ep}{self.var_pr_ep_steps}_rqsmnd{self.env.request_same_node}_AR{self.arrival_rates}_gamma{self.gamma}_BN{self.use_batch_norm}_VD{self.var_distance}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                            best_estimate_tot = estimate_tot
                    else:
                        if print_results == False: 
                            estimate_tot = self.estimate_all(estimate_steps)[0]
                        if estimate_tot > best_estimate_tot: 
                            if self.num_nodes != None: 
                                torch.save(self.q_network.state_dict(), f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/model_{self.general_model_gt}_CP{self.constant_probability}_{self.class_name}_{self.env.num_nodes}_{self.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                            else:
                                torch.save(self.q_network.state_dict(), f'results/single_model_results/uniform{self.uniform_random}/models/model_GenModelAll_hch{self.hidden_channels}_nl{self.num_layers}_{self.general_model_gt}_UR{self.uniform_random}_CP{self.constant_probability}_lr{self.lr}_nl{self.num_layers}_varprep{self.var_pr_ep}{self.var_pr_ep_steps}_rqsmnd{self.env.request_same_node}_AR{self.arrival_rates}_gamma{self.gamma}_BN{self.use_batch_norm}_VD{self.var_distance}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                            best_estimate_tot = estimate_tot
            if save_results:
                if self.num_nodes != None: 
                    if ((step+1)  % steps_for_display == 0) :
                        if print_results == False:
                            estimate_tot = self.estimate_all(estimate_steps)[0]
                            average_reward = torch.mean(self.total_reward[self.env.num_servers:])
                            step1000 = self.round_to_nearest_10000((step+1)*self.batch_size)
                        
                        # # save the curve of estimates
                        # output_file_name_wqs = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/train_curve/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_wqs.csv'
                        
                        # if not os.path.exists(output_file_name_wqs):
                        #     with open(output_file_name_wqs, 'w', newline='') as f:
                        #         writer = csv.writer(f)
                        #         writer.writerow(['step','epsilon', 'graph_type', 'agent', 'gamma', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'VD', 'DG', 'batch_size', 'tot_params','estimate'])
                        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate_tot.item(), 3)]) 
                        # else: 
                        #     with open(output_file_name_wqs, 'a', newline='') as f:  # Open in append mode
                        #         writer = csv.writer(f)
                        #         # Append new row to CSV file
                        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate_tot.item(), 3)])
                        
                        # save the average reward curve
                        train_curve = f'results/gen_testing/VD{self.var_distance}/train_curves/train_curve_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}_ar.csv'

                        if not os.path.exists(train_curve):
                            with open(train_curve, 'w', newline='') as f:
                                writer = csv.writer(f)
                                writer.writerow(['step', 'graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'VD', 'tr_dist', 'tr_att', 'DG', 'batch_size', 'tot_params','average_reward', 'estimate'])
                                writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate_tot.item(), 3)]) 
                        else: 
                            with open(train_curve, 'a', newline='') as f:  # Open in append mode
                                writer = csv.writer(f)
                                # Append new row to CSV file
                                writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate_tot.item(), 3)]) 

                            
                        # if not os.path.exists(train_curve):
                        #     with open(train_curve, 'w', newline='') as f:
                        #         writer = csv.writer(f)
                        #         writer.writerow(['step','epsilon', 'graph_type', 'agent', 'gamma', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'CP', 'VD', 'DG', 'batch_size', 'tot_params','average_reward', 'estimate'])
                        #         writer.writerow([step1000, epsilon,  self.general_model_gt, self.class_name, self.gamma, self.env.num_nodes, self.num_servers, self.seed, self.hidden_channels, self.num_layers, self.constant_probability, self.var_distance, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate_tot.item(), 3)]) 
                        # else: 
                        #     with open(train_curve, 'a', newline='') as f:  # Open in append mode
                        #         writer = csv.writer(f)
                        #         # Append new row to CSV file
                        #         writer.writerow([step1000, epsilon,  self.general_model_gt, self.class_name, self.gamma, self.env.num_nodes, self.num_servers, self.seed, self.hidden_channels, self.num_layers, self.constant_probability, self.var_distance, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate_tot.item(), 3)]) 


            #     if ((step+1)  % steps_for_display == 0) : 
            #         if print_results: 
            #             av_estimates_list, raw_result = av_estimates_list_pr, raw_result_pr
            #         else: 
            #             av_estimates_list, raw_result = self.estimate(estimate_steps)
            #         if ((step+1)  % steps_for_display == 0) : 
            #             output_file_name = f'results/single_model_results/uniform{self.uniform_random}/results_GenModelAll_hch{self.hidden_channels}_nl{self.num_layers}_{self.general_model_gt}_UR{self.uniform_random}_CP{self.constant_probability}_lr{self.lr}_nl{self.num_layers}_varprep{self.var_pr_ep}{self.var_pr_ep_steps}_rqsmnd{self.env.request_same_node}_AR{self.arrival_rates}_gamma{self.gamma}_BN{self.use_batch_norm}_VD{self.var_distance}_DG{self.dir_graph}_bs{self.batch_size}.csv'
            #             with open(output_file_name, 'w', newline='') as f:
            #                 writer = csv.writer(f)
            #                 writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'estimate', 'q1', 'q3'])
            #                 for ix, output_data in enumerate(av_estimates_list): 
            #                     for tree, value in output_data.items():
            #                         writer.writerow([tree, 'gen_all', 0.99 , self.num_nodes_list[ix], round(value.item(), 3), '', ''])
            #                         # print(self.num_nodes_list[ix], tree, value) 
                  
            #             output_file_name_raw = f'results/single_model_results/uniform{self.uniform_random}/raw_results/results_GenModelAll_hch{self.hidden_channels}_nl{self.num_layers}_{self.general_model_gt}_UR{self.uniform_random}_raw_results_CP{self.constant_probability}_lr{self.lr}_varprep{self.var_pr_ep}{self.var_pr_ep_steps}_rqsmnd{self.env.request_same_node}_AR{self.arrival_rates}_gamma{self.gamma}_BN{self.use_batch_norm}_VD{self.var_distance}_DG{self.dir_graph}_bs{self.batch_size}.csv'
            #             with open(output_file_name_raw, 'w', newline='') as f:
            #                     writer = csv.writer(f)
            #                     writer.writerow(['graph_type', 'num_nodes', ])
            #                     for ix, output_data in enumerate(raw_result): 
            #                         for tree, value in output_data.items():
            #                             writer.writerow([tree, self.num_nodes_list[ix], value])

            if display_results :
                if ((step+1)  % steps_for_display == 0) :
                    self.run["Average_Reward"].append(torch.mean(self.total_reward[self.env.num_servers:]))
                    if print_results == True or save_results == True:  
                        self.run["Estimate_all"].append(estimate_tot) 
                    else: 
                        self.run["Estimate_all"].append(self.estimate_all(estimate_steps)[0]) 
        if display_results :
            self.run.stop()
                            

            

            

    
    

    def estimate(self, num_steps = 1, print_results = False):

            state_list = []
            for env in self.env_list:
                state_list.append(env.reset())

            # state = self.env.reset()

            num_steps = int(num_steps*1000/self.batch_size)
            result_dict = {}
            list_of_dicts = [{} for _ in range(len(self.num_nodes_list))]

            for step in range(num_steps):
                new_state_list = []
                for index, state in enumerate(state_list):

                    # action = self.get_action(state, 0)
                    # next_state, reward, _ = self.env.step(action, state)

                    action = self.get_action(state,  self.env_list[index], 0).to(self.device)
                    next_state, reward, _ = self.env_list[index].step(action, state)

                    A = reward.view(-1, 1)
                    B = state[1]

                    for indx in range(len(self.env_list[index].graph_types)):
                        elements = A[B == indx]
                        if self.env_list[index].graph_types[indx] not in list_of_dicts[index]:
                            list_of_dicts[index][self.env_list[index].graph_types[indx]] = [elements]
                        else:
                            list_of_dicts[index][self.env_list[index].graph_types[indx]].append(elements)

                    state = next_state
                    new_state_list.append(state)

                state_list = new_state_list   
                
            
            list_of_average_dicts = [{} for _ in range(len(self.num_nodes_list))]
            for ix, result_dict in enumerate(list_of_dicts):
                for index, elements_list in result_dict.items():
                    combined_tensor = torch.cat(elements_list, dim=0)
                    average_tensor = torch.mean(combined_tensor, dim=0)
                    list_of_average_dicts[ix][index] = average_tensor
                # if print_results:
                #     print("Number of nodes:", self.num_nodes_list[ix]) 
                #     for index, average_tensor in list_of_average_dicts[ix].items():   
                #         # print("Number of nodes:", self.num_nodes_list[ix]) 
                #         print(index, round(average_tensor.item(), 3))
            
            if print_results:
                for ix, result_dict in enumerate(list_of_dicts):
                    print("Number of nodes:", self.num_nodes_list[ix]) 
                    for index, average_tensor in list_of_average_dicts[ix].items():   
                        # print("Number of nodes:", self.num_nodes_list[ix]) 
                        print(index, round(average_tensor.item(), 3))
            
            return list_of_average_dicts, list_of_dicts
        

               
            




    @property
    def class_name(self):
        return self.__class__.__name__


    
            
                    
    def estimate_all(self, num_steps = 1):  


        state_list = []
        for env in self.env_list_test:
            state_list.append(env.reset())  


        # state = self.env.reset() 
        
        num_steps = int(num_steps*1000/self.batch_size)


        for step in range(num_steps):
            new_state_list = []
            for index, state in enumerate(state_list):
                action = self.get_action(state,  self.env_list[index], 0).to(self.device)
                next_state, reward, _ = self.env_list[index].step(action, state)
                state = next_state
                self.total_reward_estimate = torch.cat((self.total_reward_estimate, reward), 0)
                new_state_list.append(state)

                # if ((step+1)  % self.var_pr_ep_steps == 0): 
                #     if self.var_pr_ep:
                #         self.env_list[index].graph_pbs = [] 
                #         for i in range(self.env_list[index].graph_number_gm): 
                #             self.env_list[index].graph_pbs.append(torch.FloatTensor(self.env_list[index].add_pb()).to(self.device))
                state_list = new_state_list   
                                   
        estimates = self.total_reward_estimate[-(num_steps*self.batch_size):]  
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates
    
    
    def round_to_nearest_10000(self, number):
        return round(number / 10000) * 10000