import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
import torch_geometric
from Policies.NET import Net
import neptune
from KServerEnv import KServerEnv 
import csv
from Policies.GraphDataset import MyData
from torch_geometric.utils import to_dense_adj
import os
import re
import json


class GCN_RL_GEN_ALL_SL:
    """GNN agent for the k-server problem trained by regression on tabular Q-values.

    One ``KServerEnv`` is created per graph size in ``num_nodes_list``. For
    every graph type of each environment, a Q-table and its
    server-locations -> state-id map are loaded from disk. ``optimize`` fits
    the Q-network (``Net``) with an MSE loss to the Q-table entries at the
    server-occupied nodes. Meanwhile the environments are rolled forward with
    the network's greedy actions.
    """

    # Default root directory of the per-graph Q-tables (see ``_load_q_tables``).
    DEFAULT_QTABLE_ROOT = "/data/home/ifb5104/K_server_RL/data/qtable_wql"

    def __init__(self,
                 seed=42,
                 gamma=0.99,
                 lr=0.001,
                 num_layers=12,
                 memory_size=10000,
                 device='cpu',
                 hidden_channels=128,
                 shared_weights=False,
                 general_model_gt='tree',
                 uniform_random=False,
                 constant_probability=True,
                 var_pr_ep=False,
                 var_pr_ep_steps=40,
                 request_same_node=True,
                 arrival_rates=False,
                 use_batch_norm=False,
                 var_distance=False,
                 tr_dist=False,
                 tr_att=False,
                 dir_graph=False,
                 batch_size=512,
                 num_nodes=None,
                 num_servers=None,
                 qtable_root=None):
        """Build the environments, load their Q-tables, and create the networks.

        Args:
            seed: Seeds ``random``, ``numpy`` and ``torch``, and forces
                deterministic cuDNN.
            gamma: Discount factor. Within this class it is only logged and
                embedded in output file names.
            lr: Adam learning rate.
            num_layers, hidden_channels, shared_weights, use_batch_norm,
            tr_dist, tr_att, dir_graph: Forwarded to ``Net``. ``dir_graph``
                also switches the observation to dense adjacency matrices.
            memory_size: ``maxlen`` of the per-environment deques in
                ``self.memory_list``.
            device: Torch device for the networks, tensors and loaded Q-tables.
            general_model_gt: Graph family. 'SF' uses 24 nodes, 'EM' uses
                74 nodes, anything else uses [25, 49, 64, 81] nodes.
            uniform_random: If True, the observation has no probability
                channel (2 input channels instead of 3). Requires
                ``constant_probability``.
            constant_probability: If False (and ``var_pr_ep`` is set), request
                probabilities are redrawn every ``var_pr_ep_steps`` steps.
            var_pr_ep, var_pr_ep_steps: Probability-resampling switch and
                period.
            request_same_node: Forwarded to the environment. Also, in
                ``get_action``, a request on an occupied node is served by
                that node's server.
            arrival_rates: Scales the probability channel by ``num_nodes``.
            var_distance: Use weighted (cost-matrix) graphs. Required by
                ``tr_dist``.
            batch_size: Environment batch size.
            num_nodes: Restrict training to a single graph size. This also
                selects the ``results/gen_testing`` output layout.
            num_servers: Servers per environment. 4 when None.
            qtable_root: Root directory of the Q-tables. Defaults to
                ``DEFAULT_QTABLE_ROOT``.

        Raises:
            ValueError: On an infeasible flag combination, or if the
                environment is not a general model.
        """
        # Seed every RNG in play so that runs are reproducible.
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Basic parameters
        self.gamma = gamma
        self.lr = lr
        self.num_layers = num_layers
        self.device = device
        self.hidden_channels = hidden_channels
        self.shared_weights = shared_weights

        # Environment and model parameters
        self.general_model_gt = general_model_gt
        self.uniform_random = uniform_random
        self.constant_probability = constant_probability
        self.var_pr_ep = var_pr_ep
        self.var_pr_ep_steps = var_pr_ep_steps
        self.request_same_node = request_same_node
        self.arrival_rates = arrival_rates
        self.use_batch_norm = use_batch_norm
        self.var_distance = var_distance
        self.tr_dist = tr_dist
        self.tr_att = tr_att
        self.dir_graph = dir_graph

        # Validate flag combinations before doing any expensive file loading.
        if not self.var_distance and self.tr_dist:
            raise ValueError("Unfeasible configuration: var_distance must be True if tr_dist is True.")
        if not self.tr_dist and self.tr_att:
            raise ValueError("Unfeasible configuration: tr_dist must be True if tr_att is True.")
        if self.uniform_random and not self.constant_probability:
            raise ValueError("Unfeasible configuration: constant_probability must be True if uniform_random is True.")

        # Graph sizes to train on.
        if self.general_model_gt == 'SF':
            self.num_nodes_list = [24]
        elif self.general_model_gt == 'EM':
            self.num_nodes_list = [74]
        else:
            self.num_nodes_list = [25, 49, 64, 81]

        self.num_nodes = num_nodes
        self.num_servers = num_servers
        if self.num_nodes is not None:
            self.num_nodes_list = [num_nodes]

        # self.num_servers stays as given (it appears in checkpoint file names);
        # the environments fall back to 4 servers when it is None.
        env_num_servers = 4 if self.num_servers is None else self.num_servers
        if qtable_root is None:
            qtable_root = self.DEFAULT_QTABLE_ROOT

        # NOTE(review): the environment's constant_probability flag is driven by
        # var_pr_ep, not by self.constant_probability. Kept as-is; confirm intent.
        constant_probability_env = self.var_pr_ep

        self.env_list = []
        for n_nodes in self.num_nodes_list:
            env = KServerEnv(n_nodes, env_num_servers, batch_size=batch_size, device=self.device,
                             general_model=True, general_model_gt=self.general_model_gt,
                             uniform_random=self.uniform_random,
                             constant_probability=constant_probability_env,
                             request_same_node=self.request_same_node,
                             var_distance=self.var_distance)
            self._load_q_tables(env, qtable_root)
            self.env_list.append(env)

        # Shallow copy: the evaluation list holds the very same env objects.
        self.env_list_test = self.env_list.copy()
        self.env = self.env_list[-1]
        self.batch_size = self.env.batch_size

        if not self.env.general_model:
            raise ValueError("The environment general_model argument must be True.")

        # One buffer per graph size (not read elsewhere in this class).
        self.memory_list = [deque(maxlen=memory_size) for _ in self.num_nodes_list]

        # The probability channel is dropped for uniform request distributions.
        in_channels = 2 if self.uniform_random else 3
        self.q_network = self._build_network(in_channels)
        self.target_network = self._build_network(in_channels)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)

        # Reward histories. The first num_servers entries are uninitialised
        # placeholders that every reader slices off.
        self.total_reward = torch.empty((self.env.num_servers)).to(self.device)
        self.total_reward_estimate = torch.empty((self.env.num_servers)).to(self.device)

    def _build_network(self, in_channels):
        """Create a ``Net`` with this agent's architecture flags on ``self.device``."""
        return Net(
            in_channels=in_channels, hidden_channels=self.hidden_channels,
            num_layers=self.num_layers, shared_weights=self.shared_weights,
            use_batch_norm=self.use_batch_norm, var_distance=self.var_distance,
            tr_dist=self.tr_dist, tr_att=self.tr_att, dir_graph=self.dir_graph,
        ).to(self.device)

    def _load_q_tables(self, env, qtable_root):
        """Attach Q-tables and location->state-id maps for every graph of *env*.

        For each entry of ``env.graph_types``, files are read from
        ``<qtable_root>/var_distance<VD>/<group>/<graph_type>_<N>_<K>/``.
        The decoded ``locations_to_id.json`` is appended to
        ``env.locations_to_id_list`` and ``Q_table.pth`` to
        ``env.Q_table_list``, so list positions match graph indices.

        Raises:
            ValueError: If no graph group prefix can be extracted from a
                graph type name.
        """
        env.locations_to_id_list = []
        env.Q_policy_list = []
        env.Q_table_list = []
        for graph_type in env.graph_types:
            # Group = leading non-digit characters before an underscore.
            group_match = re.match(r'^\D+(?=_)', graph_type)
            if group_match is None:
                raise ValueError(f"Cannot derive a graph group from graph type {graph_type!r}.")
            path = os.path.join(
                qtable_root,
                f"var_distance{env.var_distance}",
                group_match.group(0),
                f"{graph_type}_{env.num_nodes}_{env.num_servers}",
            )
            with open(os.path.join(path, 'locations_to_id.json'), 'r') as f:
                locations_to_id_str_keys = json.load(f)
            # JSON object keys are strings. They must evaluate to the location
            # tuples that get_target_values uses as lookup keys.
            # SECURITY: eval() and torch.load (pickle) execute code from these
            # files. Only load trusted Q-table directories.
            locations_to_id = {eval(key): value for key, value in locations_to_id_str_keys.items()}
            # map_location lets tables saved on another device load here.
            q_table = torch.load(os.path.join(path, 'Q_table.pth'), map_location=self.device).to(self.device)
            env.locations_to_id_list.append(locations_to_id)
            env.Q_table_list.append(q_table)

    def print_network_weights(self):
        """Print every tensor of the Q-network's state dict."""
        for name, param in self.q_network.state_dict().items():
            print(f"Layer: {name}\nWeights: {param}")

    def total_params(self):
        """Return the number of scalar parameters in the Q-network."""
        return sum(p.numel() for p in self.q_network.parameters())

    def node_pbs_formation(self, state_batch, env):
        """Return a (batch, num_nodes) tensor of each sample's graph request probabilities."""
        _, graph_indices = state_batch
        node_pbs = torch.zeros(env.batch_size, env.num_nodes).to(self.device)
        for i in range(env.batch_size):
            node_pbs[i] = env.graph_pbs[graph_indices[i]].view(1, -1)
        return node_pbs

    def observation_formation(self, state_batch, env, node_pbs=1, constant_probability=True):
        """Build the batched graph observation fed to the Q-network.

        Node feature channels per sample:
          0: 1 at every node in the state's leading columns (server locations).
          1: 1 at the node in the last state column (the request).
          2: (only when ``uniform_random`` is False) node probabilities.
             These come from ``env.graph_pbs`` when *constant_probability* is
             True, otherwise from ``node_pbs[i]``. They are multiplied by
             ``num_nodes`` when ``arrival_rates`` is set.

        Returns:
            A ``MyData`` of stacked dense adjacency matrices when ``dir_graph``
            is set. Otherwise a PyG ``Batch`` of sparse graphs, carrying
            ``edge_weight`` from the cost matrix when ``var_distance`` is set.
        """
        state, graph_indices = state_batch
        num_channels = 2 if self.uniform_random else 3
        X = torch.zeros(env.batch_size, num_channels, env.num_nodes).to(self.device)
        eye = torch.eye(env.num_nodes).to(self.device) if self.dir_graph else None
        edge_index_list = []
        data_list = []
        for i in range(env.batch_size):
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1
            if not self.uniform_random:
                if constant_probability:
                    X[i][2] = env.graph_pbs[graph_indices[i]].view(1, -1)
                else:
                    X[i][2] = node_pbs[i]
                if self.arrival_rates:
                    X[i][2] = X[i][2] * env.num_nodes

            # The graph structure is needed in both feature modes.
            graph_idx = graph_indices[i].item()
            if self.dir_graph:
                edge_index_list.append(self._dense_adjacency(env, graph_idx, eye))
            else:
                data_list.append(self._sparse_graph(env, graph_idx, X[i]))

        if self.dir_graph:
            X = torch.transpose(X, -1, -2)
            return MyData(X, torch.stack(edge_index_list, dim=0))
        # Collate all samples into one disconnected batch graph.
        return torch_geometric.data.Batch.from_data_list(data_list)

    def _dense_adjacency(self, env, graph_idx, eye):
        """Dense adjacency (plus identity) of graph *graph_idx*, leading dim squeezed."""
        if self.var_distance:
            adjacency = env.cost_matrices_nn[graph_idx].unsqueeze(0) + eye
            if not self.tr_dist:
                # Without a learned distance transform, use inverse costs.
                adjacency = 1 / adjacency
        else:
            adjacency = to_dense_adj(env.graph_edge_indices[graph_idx]).to(self.device) + eye
        return adjacency.squeeze(0)

    def _sparse_graph(self, env, graph_idx, x):
        """PyG ``Data`` for graph *graph_idx* with node features ``x.T``.

        With ``var_distance``, edge (u, v) is weighted by ``cost_matrices[g][u, v]``.
        """
        edge_index = env.graph_edge_indices[graph_idx]
        if not self.var_distance:
            return Data(x=x.T, edge_index=edge_index)
        cost_matrix = env.cost_matrices[graph_idx]
        edge_weight = cost_matrix[edge_index[0], edge_index[1]]
        return Data(x=x.T, edge_index=edge_index, edge_weight=edge_weight)

    @staticmethod
    def _server_q_values(q_values, qt_index):
        """Select Q-values at the server nodes.

        Returns:
            ``(per_server, masked)``. ``per_server[i, j]`` equals
            ``q_values[i, qt_index[i, j]]``. ``masked`` equals ``q_values`` at
            those nodes and 0 elsewhere. Results keep ``q_values``' dtype, so
            integer ``qt_index`` tensors no longer truncate Q-values.
        """
        index = qt_index.long()
        per_server = torch.gather(q_values, 1, index)
        masked = torch.zeros_like(q_values).scatter(1, index, per_server)
        return per_server, masked

    @staticmethod
    def _serve_from_request_node(action_batch, state):
        """Replace the action with the request node wherever a server already sits on it."""
        request = state[:, -1:].to(action_batch.device)
        servers = state[:, :-1].to(action_batch.device)
        hit = (servers == request).any(dim=1, keepdim=True)
        return torch.where(hit, request.to(action_batch.dtype), action_batch)

    def get_action(self, state_batch, env):
        """Greedy action of the Q-network, plus the Q-values used as predictions.

        Returns:
            ``(action_batch, result)``. ``action_batch`` is (batch, 1) and holds
            the node of the server with the highest Q-value. If
            ``env.request_same_node`` is set and the request node is occupied,
            it holds the request node instead. ``result`` holds the Q-values at
            the server nodes and zeros elsewhere.
        """
        data = self.observation_formation(state_batch, env).to(self.device)
        qt_index = state_batch[0][:, :env.num_servers].to(self.device)
        dis_req = env.distance_request(state_batch).to(self.device)
        if self.var_distance and not self.dir_graph:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req, edge_weight=data.edge_weight)
        else:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
        q_values = q_values.reshape(env.batch_size, -1)

        per_server, result = self._server_q_values(q_values, qt_index)
        best_server = torch.argmax(per_server, dim=1)
        action_batch = torch.gather(qt_index, 1, best_server.view(-1, 1))
        if env.request_same_node:
            action_batch = self._serve_from_request_node(action_batch, state_batch[0])
        return action_batch.to(self.device), result.to(self.device)

    def get_target_values(self, state_batch, env):
        """Build the regression targets from each sample's graph Q-table.

        For sample i on graph g, the tuple of server locations is mapped to a
        state id via ``env.locations_to_id_list[g]``.
        ``env.Q_table_list[g][state_id, request]`` is then written at the
        server nodes of a (batch, num_nodes) zero tensor. Presumably it holds
        one value per server (TODO confirm the Q-table layout).
        """
        state, graph_indices = state_batch
        with torch.no_grad():
            qt_index = state[:, :env.num_servers].to(self.device)
            request = state[:, env.num_servers].long().to(self.device)
            targets = torch.zeros(env.batch_size, env.num_nodes, dtype=torch.float, device=self.device)
            for i, locations in enumerate(map(tuple, qt_index.int().tolist())):
                graph_idx = int(graph_indices[i])
                state_id = env.locations_to_id_list[graph_idx][locations]
                q_entry = env.Q_table_list[graph_idx][state_id, request[i]]
                targets[i, qt_index[i].long()] = q_entry.float().to(self.device)
        return targets

    def _train_step(self, step, num_steps, state_list):
        """Do one supervised update per environment and return the advanced states."""
        next_states = []
        for index, state in enumerate(state_list):
            env = self.env_list[index]
            action, preds = self.get_action(state, env)
            next_state, reward, _ = env.step(action, state)
            target_q_values = self.get_target_values(state, env)
            loss = F.mse_loss(preds, target_q_values.float())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.total_reward = torch.cat((self.total_reward, reward), 0)
            next_states.append(next_state)

            if (step + 1) % self.var_pr_ep_steps == 0 or (step + 1) == num_steps:
                self._resample_probabilities(env)
        return next_states

    def _resample_probabilities(self, env):
        """Redraw *env*'s request probabilities when per-episode variation is on."""
        if self.constant_probability or not self.var_pr_ep:
            return
        if self.general_model_gt in ['SF', 'EM', 'dir_check', 'dir_check_1']:
            # Fixed graph sets: only the probability vectors are redrawn.
            env.graph_pbs = []
            for _ in range(env.graph_number_gm):
                env.graph_pbs.append(torch.FloatTensor(env.add_pb()).to(self.device))
        else:
            env.regenerate_graph_gm()

    def _prepare_output_dirs(self, save_model, save_results):
        """Create every directory that checkpoints / training curves will be written to."""
        base = f'results/gen_testing/VD{self.var_distance}'
        if self.num_nodes is not None:
            if save_model:
                os.makedirs(f'{base}/{self.class_name}/train_results/models', exist_ok=True)
            if save_model or save_results:
                os.makedirs(f'{base}/train_curves', exist_ok=True)
        elif save_model:
            os.makedirs(f'results/single_model_results/uniform{self.uniform_random}/models', exist_ok=True)

    def _model_path(self):
        """Checkpoint path. The naming scheme depends on whether ``num_nodes`` was fixed."""
        if self.num_nodes is not None:
            return (
                f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/'
                f'model_{self.general_model_gt}_CP{self.constant_probability}_{self.class_name}'
                f'_{self.env.num_nodes}_{self.num_servers}_hidden_channels{self.hidden_channels}'
                f'__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}'
                f'_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth'
            )
        return (
            f'results/single_model_results/uniform{self.uniform_random}/models/'
            f'model_GenModelAllSL_hch{self.hidden_channels}_nl{self.num_layers}_{self.general_model_gt}'
            f'_UR{self.uniform_random}_CP{self.constant_probability}_lr{self.lr}_nl{self.num_layers}'
            f'_varprep{self.var_pr_ep}{self.var_pr_ep_steps}_rqsmnd{self.env.request_same_node}'
            f'_AR{self.arrival_rates}_gamma{self.gamma}_BN{self.use_batch_norm}_VD{self.var_distance}'
            f'_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth'
        )

    def _append_train_curve(self, step1000, tot_params, average_reward, estimate_tot):
        """Append one progress row to the training-curve CSV, writing the header if the file is new."""
        train_curve = (
            f'results/gen_testing/VD{self.var_distance}/train_curves/'
            f'train_curve_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}'
            f'_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}'
            f'_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}'
            f'_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}_ar.csv'
        )
        write_header = not os.path.exists(train_curve)
        with open(train_curve, 'a', newline='') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(['step', 'graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'VD', 'tr_dist', 'tr_att', 'DG', 'batch_size', 'tot_params', 'average_reward', 'estimate'])
            writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate_tot.item(), 3)])

    def _init_neptune_run(self):
        """Open a Neptune run and log this agent's hyper-parameters."""
        # Credentials come from the NEPTUNE_API_TOKEN environment variable.
        # Never commit API tokens to source control.
        self.run = neptune.init_run(
            project="iliyasbektas/kserver",
            api_token=os.environ.get("NEPTUNE_API_TOKEN"),
        )
        self.run["agent"] = self.class_name
        self.run["graph_type"] = self.general_model_gt
        self.run["gamma"] = self.gamma
        self.run["UR"] = self.uniform_random
        self.run["HCH"] = self.hidden_channels
        self.run["NL"] = self.num_layers
        self.run["CP"] = self.constant_probability
        self.run["lr"] = self.lr
        self.run["BN"] = self.use_batch_norm
        self.run["var_distance"] = self.var_distance
        self.run["bs"] = self.batch_size

    def optimize(self, num_steps=200, estimate_steps=40, display_results=False, print_results=False,
                 save_model=False, save_results=False):
        """Train the Q-network on all environments by regressing Q-table targets.

        Args:
            num_steps: Training length in thousands of samples per environment,
                i.e. ``num_steps * 1000 / batch_size`` update steps.
            estimate_steps: Evaluation length (same units) for ``estimate_all``.
            display_results: Log metrics to Neptune. The token is read from
                ``NEPTUNE_API_TOKEN``.
            print_results: Print progress at every reporting step.
            save_model: Checkpoint the network whenever the evaluation estimate
                improves.
            save_results: Append a row to a CSV training curve. This only
                applies when ``num_nodes`` was given.

        Reporting happens roughly every 10000 samples. A single
        ``estimate_all`` call per reporting step feeds every consumer.
        """
        self._prepare_output_dirs(save_model, save_results)
        state_list = [env.reset() for env in self.env_list]
        # max(1, ...) avoids a modulo-by-zero for batches larger than 10000.
        steps_for_display = max(1, int(10000 / self.batch_size))
        num_steps = int(num_steps * 1000 / self.batch_size)
        write_curve = save_results and self.num_nodes is not None
        needs_eval = print_results or display_results or save_model or write_curve

        if display_results:
            self._init_neptune_run()
        tot_params = self.total_params()
        print('Starting Training....')
        best_estimate_tot = float('-inf')
        try:
            for step in range(num_steps):
                # Carry the environments forward so each update sees new states.
                state_list = self._train_step(step, num_steps, state_list)
                if not needs_eval or (step + 1) % steps_for_display != 0:
                    continue

                estimate_tot = self.estimate_all(estimate_steps)[0]
                average_reward = torch.mean(self.total_reward[self.env.num_servers:])
                step1000 = self.round_to_nearest_10000((step + 1) * self.batch_size)

                if print_results:
                    print(f"Step {step1000}, Average Reward {average_reward:.2f}, Estimate all {estimate_tot:.2f}")
                if display_results:
                    self.run["Average_Reward"].append(average_reward.item())
                    self.run["Estimate_all"].append(estimate_tot.item())
                if save_model and estimate_tot > best_estimate_tot:
                    torch.save(self.q_network.state_dict(), self._model_path())
                    best_estimate_tot = estimate_tot
                if write_curve:
                    self._append_train_curve(step1000, tot_params, average_reward, estimate_tot)
        finally:
            # Close the Neptune run even if training is interrupted.
            if display_results:
                self.run.stop()

    def estimate(self, num_steps=1, print_results=False):
        """Roll out the greedy policy and average rewards per graph type.

        Args:
            num_steps: Rollout length in thousands of samples per environment.
            print_results: Print the per-graph-type averages.

        Returns:
            ``(list_of_average_dicts, list_of_dicts)``, one dict per
            environment. The first maps each graph type to its mean reward.
            The second maps each graph type to the list of per-step reward
            tensors.
        """
        state_list = [env.reset() for env in self.env_list]
        num_steps = int(num_steps * 1000 / self.batch_size)
        list_of_dicts = [{} for _ in range(len(self.num_nodes_list))]

        with torch.no_grad():
            for _ in range(num_steps):
                next_states = []
                for index, state in enumerate(state_list):
                    env = self.env_list[index]
                    action, _ = self.get_action(state, env)
                    next_state, reward, _ = env.step(action, state)

                    rewards = reward.view(-1, 1)
                    graph_ids = state[1].to(rewards.device)
                    for indx, graph_type in enumerate(env.graph_types):
                        list_of_dicts[index].setdefault(graph_type, []).append(rewards[graph_ids == indx])
                    next_states.append(next_state)
                state_list = next_states

        list_of_average_dicts = [
            {graph_type: torch.mean(torch.cat(chunks, dim=0), dim=0) for graph_type, chunks in per_env.items()}
            for per_env in list_of_dicts
        ]

        if print_results:
            for ix, averages in enumerate(list_of_average_dicts):
                print("Number of nodes:", self.num_nodes_list[ix])
                for graph_type, average_tensor in averages.items():
                    print(graph_type, round(average_tensor.item(), 3))

        return list_of_average_dicts, list_of_dicts

    @property
    def class_name(self):
        """Name of the concrete class (used in logs and file names)."""
        return self.__class__.__name__

    def estimate_all(self, num_steps=1):
        """Evaluate the greedy policy on every test environment.

        Args:
            num_steps: Rollout length in thousands of samples per environment.
                At least one step is always run.

        Returns:
            ``(mean, q25, q75, rewards)`` computed over all rewards collected
            in this call, across all environments. The rewards are also
            appended to ``self.total_reward_estimate``.
        """
        state_list = [env.reset() for env in self.env_list_test]
        # At least one step, so the statistics never cover an empty tensor.
        num_steps = max(1, int(num_steps * 1000 / self.batch_size))
        collected = []

        with torch.no_grad():
            for _ in range(num_steps):
                next_states = []
                for index, state in enumerate(state_list):
                    env = self.env_list_test[index]
                    action, _ = self.get_action(state, env)
                    next_state, reward, _ = env.step(action, state)
                    collected.append(reward)
                    next_states.append(next_state)
                state_list = next_states

        estimates = torch.cat(collected, 0)
        self.total_reward_estimate = torch.cat((self.total_reward_estimate, estimates), 0)
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def round_to_nearest_10000(self, number):
        """Round *number* to the nearest multiple of 10000 (banker's rounding on ties)."""
        return round(number / 10000) * 10000