from cmath import sqrt
from generate_planar_graph import generate_planar_graph
import torch
import random
import networkx as nx
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
from torch_geometric.utils import convert
import math
import re
import os
import pickle



class KServerEnv:
    def __init__(self, 
                 num_nodes=9, 
                 num_servers=2, 
                 uniform_random=False, 
                 general_model=False, 
                 constant_probability=True, 
                 balanced_algorithm=False, 
                 request_same_node=True, 
                 arrival_rates=True, 
                 seq_req=False, 
                 var_distance=False, 
                 general_model_gt='tree', 
                 graph_type='tree_1', 
                 device="cpu", 
                 seed=123, 
                 batch_size=1):
        
        # Basic configurations
        self.num_nodes = num_nodes
        self.num_servers = num_servers
        self.batch_size = batch_size
        self.graph_type = graph_type
        self.general_model = general_model
        self.seed = seed
        self.balanced_algorithm = balanced_algorithm
        self.uniform_random = uniform_random
        self.constant_probability = constant_probability
        self.general_model_gt = general_model_gt
        self.request_same_node = request_same_node
        self.seq_req = seq_req
        self.var_distance = var_distance
        self.arrival_rates = arrival_rates

        # Validate configuration
        if not self.request_same_node and self.seq_req:
            raise ValueError("Unfeasible configuration: seq_req requires request_same_node to be True.")

        # Device configuration
        if device == "cuda":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Graph generation settings
        if self.batch_size < 50:
            self.graph_number_gm = self.batch_size
        else:
            self.graph_number_gm = 50

        # Initialize lists for storing graph properties
        self.graph_edge_indices = []
        self.cost_matrices = []
        self.cost_matrices_nn = []
        self.graph_pbs = []

        # Determine graph types based on general_model_gt
        if general_model_gt == 'all':
            self.graph_types = ['tree_1', 'grid_gre_1', 'cycle', 'line', 'grid']
        elif general_model_gt == 'dir_check':
            self.graph_types = ['grsubop', 'grop']
        elif general_model_gt == 'dir_check_1':
            self.graph_types = ['grsubop_1', 'grop_1']
        elif general_model_gt == 'SF':
            self.graph_types = ['SF']
        elif general_model_gt == 'EM':
            self.graph_types = ['EM']
        elif general_model_gt == 'grid_dir_sl':
            self.graph_types = ['grid_dir_50', 'grid_dir_51', 'grid_dir_52','grid_dir_53','grid_dir_54']
        else:
            self.graph_types = []

        # Load or create graph and related properties
        if (self.graph_type not in ["EM", "SF"]) and (self.general_model_gt not in ["EM", "SF"]) and (self.num_nodes > 200):
            file_path = f'data/big_graphs/{self.graph_type}_{self.num_nodes}.pkl'
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    self.graph = pickle.load(f)
                self.cost_matrix = torch.load(f'data/big_graphs/cost_matrix/{self.graph_type}_{self.num_nodes}.pth').to(self.device)
                self.probabilities = torch.load(f'data/big_graphs/probabilities/{self.graph_type}_{self.num_nodes}.pth')
            else:
                self.graph = self.create_graph(self.graph_type)
                with open(file_path, 'wb') as f:
                    pickle.dump(self.graph, f)
                seed = int(re.findall(r'\d+', self.graph_type)[0])
                self.add_graph_pb(self.graph, seed, self.graph_type)
                self.cost_matrix = self.build_cost_matrix(self.graph).to(self.device)
                torch.save(self.cost_matrix, f'data/big_graphs/cost_matrix/{self.graph_type}_{self.num_nodes}.pth')
                self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]
                torch.save(self.probabilities, f'data/big_graphs/probabilities/{self.graph_type}_{self.num_nodes}.pth')
            if os.path.exists(f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth'):
                self.cost_matrix_nn = torch.load(f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth').to(self.device)
            else:
                self.cost_matrix_nn = self.build_cost_matrix_nn(self.graph).to(self.device)
                torch.save(self.cost_matrix_nn, f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth')
        else:
            self.graph = self.create_graph(self.graph_type)
            if self.graph_type in ["EM", "SF"] or self.general_model_gt in ["EM", "SF"]:
                self.num_nodes = self.graph.number_of_nodes()
            if any(keyword in self.graph_type for keyword in ['grid_gre', 'tree', 'grid_dir']):
                seed = int(re.findall(r'\d+', self.graph_type)[0])
                self.add_graph_pb(self.graph, seed, self.graph_type)
            else:
                self.add_graph_pb(self.graph, self.seed, self.graph_type)
            self.cost_matrix = self.build_cost_matrix(self.graph).to(self.device)
            self.cost_matrix_nn = self.build_cost_matrix_nn(self.graph).to(self.device)
            self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]

        # General model configuration
        if self.general_model:
            if general_model_gt == 'grid':
                for rows in range(2, self.num_nodes + 1):
                    columns = self.num_nodes // rows
                    if rows * columns == self.num_nodes and rows > columns > 1:
                        graph = nx.grid_2d_graph(rows, columns)
                        mapping = {(i, j): i * columns + j for i, j in graph.nodes()}
                        graph = nx.relabel_nodes(graph, mapping)
                        graph_type = f'grid_col{columns}'
                        self.graph_types.append(graph_type)
                        self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
                        self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
                        self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
                        if not self.uniform_random:
                            self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, rows, graph_type)).to(self.device))
            elif general_model_gt in ['plane', 'tree', 'grid_gre', 'grid_dir', 'bn_grid_gre', 'psn_grid_gre', 'lgnm_grid_gre']:
                for i in range(self.graph_number_gm):
                    graph = self.create_graph_gm(i)
                    self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
                    self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
                    self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
                    graph_type = f'{self.general_model_gt}_{i}'
                    self.graph_types.append(graph_type)
                    if not self.uniform_random:
                        self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, i, graph_type)).to(self.device))
            else:
                for n, graph_type in enumerate(self.graph_types):
                    graph = self.create_graph(graph_type)
                    self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
                    self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
                    self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
                    if not self.uniform_random:
                        self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, n, graph_type)).to(self.device))

        self.reset()


    def reset(self):

        state_batch = torch.empty((0, self.num_servers+1))

        
        for i in range(self.batch_size):
          state = torch.LongTensor(random.sample(range(0, self.num_nodes), self.num_servers+1))
          sorted_state = torch.sort(state[:self.num_servers])[0]
          state = torch.cat([sorted_state, state[-1].unsqueeze(0)]) 
          state_batch = torch.cat((state_batch, state.unsqueeze(0)), 0)
        if self.general_model: 
            graph_indices = torch.randint(0, len(self.graph_edge_indices), (self.batch_size, 1))
            return state_batch, graph_indices
        else: 
            return state_batch

    def step(self, action_batch, state_batch, server_indices_batch = 1, next_req = 1):

        if self.general_model: 
            state, graph_indices = state_batch
            state = torch.clone(state).to(self.device)
        else:
            if self.balanced_algorithm: 
                state = torch.clone(state_batch).to(self.device)
                server_indices = torch.clone(server_indices_batch).to(self.device)
            else:
                state = torch.clone(state_batch).to(self.device)
        
        destination_node = state[:, -1]
        action_index = torch.where(state == action_batch)[1]
        
        # calculating rewards for each element from the batch
        reward_batch = torch.empty((1)).to(self.device) 
        for i in range(self.batch_size):
            if self.request_same_node:
                if self.general_model:  
                    if state[i][-1] in state[i][:-1]: 
                        min_cost = torch.tensor(0).unsqueeze(0).to(self.device)
                    else:     
                        min_cost = self.cost_matrices[graph_indices[i]][action_batch[i].long(), destination_node[i].long()]
                else:
                    if state[i][-1] in state[i][:-1]: 
                        min_cost = torch.tensor(0).unsqueeze(0).to(self.device)
                    else: 
                        min_cost = self.cost_matrix[action_batch[i].long(), destination_node[i].long()]
            else: 
                if self.general_model:  
                    min_cost = self.cost_matrices[graph_indices[i]][action_batch[i].long(), destination_node[i].long()]
                else:
                    min_cost = self.cost_matrix[action_batch[i].long(), destination_node[i].long()]
            reward = -min_cost
            reward_batch = torch.cat((reward_batch, reward), 0)
        reward_batch = reward_batch[1:]   
        # print(state, action_batch, reward_batch)      
        # state after sending particular server to requests
        destination_node = torch.reshape(destination_node, (self.batch_size, 1)).long()
        if self.request_same_node:
            state_upd = []
            for i in range(self.batch_size):
                if state[i][-1] in state[i][:-1]:
                    state_upd.append(state[i])
                else: 
                    state_ch = state[i].clone()
                    state_ch[action_index[i]] = state[i][-1]
                    state_upd.append(state_ch)
            state_upd = torch.stack(state_upd, dim=0)
        else: 
            state_upd = torch.where(action_index.unsqueeze(1).expand(-1, self.num_servers+1) == torch.arange(self.num_servers+1).to(self.device), destination_node, state)
    
        

        if self.request_same_node: 
            if self.uniform_random:
                random_int = torch.randint(0, self.num_nodes, size=(self.batch_size, 1)).to(self.device)      
            else: 
                if self.general_model == False:
                    if self.seq_req: 
                        random_int = next_req.to(self.device)
                    else: 
                        random_int = np.random.choice(list(self.graph.nodes), size=(self.batch_size,), p=self.probabilities, replace=True)
                        random_int = torch.tensor(random_int).reshape(self.batch_size, 1).to(self.device)
                        if self.constant_probability == False: 
                            self.add_graph_pb(self.graph, 1, self.graph_type, replicable = False)
                            self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]        
                else:
                    random_int = torch.zeros(self.batch_size, 1).to(self.device)
                    for i in range(self.batch_size):
                        random_int[i] = torch.Tensor(np.random.choice(list(self.graph.nodes), size=1, replace=True, p=self.graph_pbs[graph_indices[i]].cpu().detach().numpy()))
                    
                    if self.constant_probability == False: 
                        self.graph_pbs = [] 
                        for i in range(self.graph_number_gm): 
                            self.graph_pbs.append(torch.FloatTensor(self.add_pb(self.graph_types[i])).to(self.device))
                            
        else:
            if self.uniform_random:
                random_int = torch.randint(0, self.num_nodes, size=(self.batch_size, 1)).to(self.device)
                for i in range(self.batch_size):   
                    while random_int[i] in state_upd[:, : self.num_servers][i]:
                        random_int[i] = torch.randint(0, self.num_nodes, size=(1,))
            else: 
                if self.general_model == False:
                    if self.seq_req: 
                        random_int = next_req.to(self.device)
                    else:            
                        random_int = np.random.choice(list(self.graph.nodes), size=(self.batch_size,), p=self.probabilities, replace=True)
                        for i in range(self.batch_size):
                            while random_int[i] in state_upd[:, :self.num_servers][i]:
                                random_int[i] = np.random.choice(list(self.graph.nodes), size=1, p=self.probabilities)
                        random_int = torch.tensor(random_int).reshape(self.batch_size, 1).to(self.device)

                        if self.constant_probability == False: 
                            self.add_graph_pb(self.graph, 1, self.graph_type, replicable = False)
                            self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]     
                else:  
                    random_int = torch.zeros(self.batch_size, 1).to(self.device)
                    for i in range(self.batch_size):
                        random_int[i] = torch.Tensor(np.random.choice(list(self.graph.nodes), size=1, replace=True, p=self.graph_pbs[graph_indices[i]].cpu().detach().numpy()))
                        while float(random_int[i]) in state_upd[:, :self.num_servers][i]:
                            random_int[i] = torch.Tensor(np.random.choice(list(self.graph.nodes), size=1, replace=True, p=self.graph_pbs[graph_indices[i]].cpu().detach().numpy()))
                    
                    if self.constant_probability == False: 
                        self.graph_pbs = [] 
                        for i in range(self.graph_number_gm): 
                            self.graph_pbs.append(torch.FloatTensor(self.add_pb(self.graph_types[i])).to(self.device))
                # env.graph_pbs[graph_indices[i]].view(1, -1)
        
        

        # we sort it so that we do not repeat ourselves e.g. [2, 1, 3] server locations is the same as [3, 1, 2] 
        if self.balanced_algorithm: 
            # Create two vectors of size [1, 4]
            vector1 = state_upd[:, : self.num_servers]
            # vector2 = server_indices.clone()

            # Flatten the vectors to 1D tensors
            vector1 = vector1.view(-1)
    
            # Sort both vectors based on the values in vector1
            sorted_indices = torch.argsort(vector1)
            sorted_vector1 = vector1[sorted_indices]
            sorted_server_indices = server_indices[sorted_indices]

            # Reshape them back to size [1, 4] if needed
            sorted_state = sorted_vector1.view(1, -1)
            # random_int = next_req.to(self.device)
            sorted_state = torch.cat((sorted_state, random_int), dim=1)
            # sorted_vector2 = sorted_vector2.view(1, -1)
            return sorted_state, reward_batch, {}, sorted_server_indices
        else:
            sorted_state = torch.sort(state_upd[:, : self.num_servers], dim=1)[0]
            sorted_state = torch.cat((sorted_state, random_int), dim=1)

            if self.general_model: 
                return (sorted_state, graph_indices), reward_batch, {}
            else: 
                return sorted_state, reward_batch, {}
        

    # def build_cost_matrix(self):
    #     return torch.from_numpy(graph.rand(self.num_nodes, 30, 0.7)[1])
    
    def build_cost_matrix(self, graph):
        cost_matrix = torch.zeros(self.num_nodes, self.num_nodes)
        for i in range(self.num_nodes):
            for j in range(self.num_nodes):
                if i != j:
                    cost_matrix[i, j] = nx.dijkstra_path_length(graph, i, j)
        return cost_matrix
    

    def build_cost_matrix_nn(self, graph):
        large_value = 1e6
        cost_matrix = torch.full((self.num_nodes, self.num_nodes), large_value)
        
        for i in range(self.num_nodes):
            for j in range(self.num_nodes):
                if i != j:
                    if graph.has_edge(i, j):  # Check if there is an immediate edge between node i and node j
                        cost_matrix[i, j] = nx.dijkstra_path_length(graph, i, j)
                else: 
                    cost_matrix[i, j] = 0

        return cost_matrix



    def create_graph(self, graph_type,  graph_max_weight = 30, graph_density = 0.7):


        if graph_type == 'paper':
             # create a random adjency matrix
            np.random.seed(seed=self.seed)
            m = np.zeros((self.num_nodes, self.num_nodes))
            for i in range(self.num_nodes):
                for j in range(self.num_nodes):
                    if i != j:
                        p = np.random.random()
                        if p < graph_density:  # kind of sparsity
                            m[i][j] = np.random.randint(1, graph_max_weight + 1)
                            # m[i][j] = 1
                            m[j][i] = m[i][j]
                        else:
                            m[i][j] = np.inf
                            m[j][i] = m[i][j]

            # connect isolated nodes
            for i, row in enumerate(m):
                connected = self.check(row)  # check if is connected
                if not connected:
                    while True:
                        for j in range(self.num_nodes):
                            if i != j:
                                p = np.random.random()
                                if p < 0.9:  # kind of sparsity
                                    m[i][j] = np.random.randint(1, self.max_weight + 1)
                                    # m[i][j] = 1
                                    m[j][i] = m[i][j]

                        row = m[i][:]
                        connected = self.check(row)
                        if connected:
                            break

            G = nx.from_numpy_array(m)  
            return G

        if graph_type =='cycle':
            G = nx.cycle_graph(self.num_nodes)
            return G
        
        if graph_type =='line':
            # Create an empty graph
            np.random.seed(seed=self.seed)
            G = nx.Graph()

            # Add nodes to the graph
            num_nodes = self.num_nodes
            G.add_nodes_from(range(0, num_nodes-1))

            # Add edges to create a line graph
            for i in range(0, num_nodes-1):
                G.add_edge(i, i+1)
                if self.var_distance: 
                    weight = np.random.exponential() 
                    G.add_edge(i, i+1, weight=weight)
                else: 
                    G.add_edge(i, i+1)
            return G
        
        if graph_type =='grsubop': 
            
            G = nx.DiGraph()

            self.num_nodes = 3
            self.num_servers = 2
            # Add nodes
            G.add_node(0)
            G.add_node(1)
            G.add_node(2)

            # Add directed edges with weights
            G.add_edge(0, 1, weight=1)
            G.add_edge(1, 2, weight=2)
            G.add_edge(2, 1, weight=2)
            G.add_edge(2, 0, weight=1000)
            
            return G

        if graph_type =='grop': 
            
            G = nx.DiGraph()

            self.num_nodes = 3
            self.num_servers = 2
            # Add nodes
            G.add_node(0)
            G.add_node(1)
            G.add_node(2)

            # Add directed edges with weights
            G.add_edge(0, 1, weight=1)
            G.add_edge(1, 0, weight=1)
            G.add_edge(1, 2, weight=2)
            G.add_edge(2, 1, weight=2)
            G.add_edge(2, 0, weight=1000)
            
            return G

        if graph_type == 'grsubop_1': 
            self.num_nodes = 4
            self.num_servers = 2
            G = nx.DiGraph()
            G.add_node(0)
            G.add_node(1)
            G.add_node(2)
            G.add_node(3)
            
            G.add_edge(0, 1, weight=2)
            G.add_edge(1, 2, weight=3)
            G.add_edge(2, 1, weight=3)
            G.add_edge(2, 0, weight=1000)
            G.add_edge(2, 3, weight=2)
            G.add_edge(3, 2, weight=2)

            return G

        if graph_type == 'grop_1': 
            self.num_nodes = 4
            self.num_servers = 2
            G = nx.DiGraph()
            G.add_node(0)
            G.add_node(1)
            G.add_node(2)
            G.add_node(3)
            
            G.add_edge(0, 1, weight=2)
            G.add_edge(1, 0, weight=2)
            G.add_edge(1, 2, weight=3)
            G.add_edge(2, 1, weight=3)
            G.add_edge(2, 0, weight=1000)
            G.add_edge(2, 3, weight=2)
            G.add_edge(3, 2, weight=2)

            return G
        if graph_type == 'grsubop_2': 
            self.num_nodes = 5
            self.num_servers = 2
            G = nx.DiGraph()
            
            G.add_node(0)
            G.add_node(1)
            G.add_node(2)
            G.add_node(3)
            G.add_node(4)
            
            G.add_edge(0, 1, weight=1)
            G.add_edge(1, 2, weight=1)
            G.add_edge(2, 3, weight=2)
            G.add_edge(3, 2, weight=2)
            G.add_edge(3, 4, weight=1)
            G.add_edge(4, 0, weight=1000)
            G.add_edge(4, 2, weight=2)

            return G

        
        if 'toy' in graph_type:
            
            # Use regular expression to find digits
            matches = re.findall(r'\d+', graph_type)
            
            random.seed(int(matches[0]))

            G = nx.DiGraph()
            
            
            num_nodes = self.num_nodes
            high_cost = 1000
            
            
            for i in range(num_nodes):
                G.add_node(i)
            
            
            for i in range(num_nodes - 1):
                G.add_edge(i, i + 1, weight=random.randint(1, 5))
            
            
            G.add_edge(num_nodes - 1, 0, weight=high_cost)
            G.add_edge(num_nodes // 2, 0, weight=high_cost)
            
            
            for i in range(2, num_nodes - 1):
                if i % 2 == 0:
                    G.add_edge(i, random.randint(i + 2, num_nodes - 1), weight=random.randint(10, 20))
            
            return G
            
            
            
        if 'tree' in graph_type:
            
            # Use regular expression to find digits
            matches = re.findall(r'\d+', graph_type)
            seed = int(matches[0])
            random.seed(seed)
            np.random.seed(seed)

            # Create an empty graph
            G = nx.Graph()


            # Add 10 nodes to the graph
            for i in range(self.num_nodes):
                G.add_node(i)

            # Add edges to create a random tree
            for i in range(1, self.num_nodes):
                parent = random.choice(list(G.nodes)[:i])
                
                if self.var_distance: 
                    weight = np.random.exponential() 
                    G.add_edge(parent, i, weight=weight)
                else: 
                    G.add_edge(parent, i)
            
            return G

        if graph_type =='SF':
            random.seed(self.seed)
            np.random.seed(self.seed)
            df = pd.read_csv('data/TNs/SiouxFalls_net.csv')
            df.A = df.A - 1 
            df.B = df.B - 1

            G = nx.Graph()

            # if self.var_distance:
            #     for index, row in df.iterrows():
            #         source = row['A']
            #         target = row['B']
            #         weight = np.random.exponential() # Change the range for different distances
            #         G.add_edge(source, target, weight=weight)
            # else:
            G = nx.from_pandas_edgelist(df, source='A', target='B', edge_attr='a0', create_using=G)
            
            if self.var_distance: 
                for edge in G.edges():
                    weight = np.random.exponential() # Change the range for different distances
                    G.edges[edge]['weight'] = weight

            return G

        if graph_type =='EM': 
            random.seed(self.seed)
            np.random.seed(self.seed)
            df = pd.read_excel('data/TNs/EM.xlsx')
            G = nx.Graph()
            for index, row in df.iterrows():
                init_node = int(row['init_node']) - 1  # Subtract 1 from node labels
                term_node = int(row['term_node']) - 1  # Subtract 1 from node labels
                if self.var_distance: 
                    weight = np.random.exponential() # Change the range for different distances
                    G.add_edge(init_node, term_node, weight=weight)
                else: 
                    G.add_edge(init_node, term_node)
            return G

        # Set a random seed for reproducibility
        if 'plane' in graph_type:
            matches = re.findall(r'\d+', graph_type)
            G = generate_planar_graph(self.num_nodes, seed = int(matches[0]))
            return G
        

        if graph_type =='grid':
            
            
            rows = columns = int(math.sqrt(self.num_nodes))
            G = nx.grid_2d_graph(rows, columns)

            if self.var_distance: 
                for edge in G.edges():
                    weight = np.random.exponential() # Change the range for different distances
                    G.edges[edge]['weight'] = weight
            return G

        if 'grid_dir' in graph_type:
            

            matches = re.findall(r'\d+', graph_type)
            seed = int(matches[0])    
            G = self.create_grid_graph_dir(seed)

            return G
            
        if 'grid_gre' in graph_type:
            

            matches = re.findall(r'\d+', graph_type)
            seed = int(matches[0])

            G = self.create_grid_graph_gre(seed)
            
            while not nx.is_connected(G): 
                seed = (seed+1)*100
                
                G = self.create_grid_graph_gre(seed)
        
            
            return G
                
        

    def create_graph_gm(self, seed, regenerate_graph = False):

        if regenerate_graph:
            seed = random.randint(1, 100000)
            
            

        if self.general_model_gt in ['bn_grid_gre', 'psn_grid_gre', 'lgnm_grid_gre', 'grid_gre']:
            
            random.seed(seed)
            np.random.seed(seed)
            G = self.create_grid_graph_gre(seed)
            
            while not nx.is_connected(G):
                seed = (seed+1)*100
                G = self.create_grid_graph_gre(seed)
            
            return G
        
        if self.general_model_gt == 'grid_dir':
            random.seed(seed)
            np.random.seed(seed)
            G = self.create_grid_graph_dir(seed)
    
            return G

        if self.general_model_gt == 'plane':

            random.seed(seed)
            G = generate_planar_graph(self.num_nodes, seed = seed)
    
            return G

        if self.general_model_gt == 'tree': 

            random.seed(seed)
            np.random.seed(seed)

            # Create an empty graph
            G = nx.Graph()


            # Add n nodes to the graph
            for i in range(self.num_nodes):
                G.add_node(i)

        
            

            # Add edges to create a random tree
            for i in range(1, self.num_nodes):
                parent = random.choice(list(G.nodes)[:i])
                
                if self.var_distance: 
                    weight = np.random.exponential() 
                    G.add_edge(parent, i, weight=weight)
                else: 
                    G.add_edge(parent, i)

            
            return G
    
    def regenerate_graph_gm(self): 
        """Replace the general-model graph pool with freshly seeded random graphs.

        Builds ``self.graph_number_gm`` graphs via
        ``create_graph_gm(i, regenerate_graph=True)``, each with a random seed.
        All per-graph lists are refilled together, so that index ``k`` of
        ``graph_edge_indices``, ``cost_matrices``, ``cost_matrices_nn``,
        ``graph_types`` and (unless ``uniform_random``) ``graph_pbs`` refers to
        the same graph.
        """
        self.graph_edge_indices = []
        self.cost_matrices = []
        # Previously not cleared, so new matrices were appended after the
        # stale ones and cost_matrices_nn[k] no longer matched graph k.
        self.cost_matrices_nn = []
        self.graph_types = []
        self.graph_pbs = []

        for i in range(self.graph_number_gm):
            graph = self.create_graph_gm(i, regenerate_graph = True)
            graph_type = f'{self.general_model_gt}_{i}'
            self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
            self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
            self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
            # Refill graph_types, mirroring the initial pool construction.
            # step() indexes graph_types[i] when constant_probability is False,
            # which failed on the empty list before.
            self.graph_types.append(graph_type)
            if not self.uniform_random:
                self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, i, graph_type, replicable= False)).to(self.device))
    
    

    def create_grid_graph_dir(self, seed):
        """Build a randomly oriented, strongly connected directed grid graph.

        Starts from a square 2-D grid with side ``int(sqrt(num_nodes))`` and
        orients every grid edge in a random direction. While more than one
        strongly connected component remains, it adds the reverse of a random
        edge crossing the boundary of the smallest component. Nodes are finally
        relabelled from ``(i, j)`` to ``i * height + j``.

        Args:
            seed: Seeds the global ``random`` and ``np.random`` state before
                building.

        Returns:
            A ``networkx.DiGraph``. With ``self.var_distance``, every edge gets
            an exponential ``'weight'``.
        """

        random.seed(seed)
        np.random.seed(seed)
        

        # NOTE(review): if num_nodes is not a perfect square, the grid has
        # fewer than num_nodes nodes.
        width = height = int(math.sqrt(self.num_nodes))

        G_undirected = nx.grid_2d_graph(width, height)
        G = nx.DiGraph()
        G.add_nodes_from(G_undirected.nodes)
        # Orient each grid edge one way, chosen by a fair coin flip.
        for e in G_undirected.edges:
            if random.random() < 0.5:
                G.add_edge(e[0], e[1])
            else:
                G.add_edge(e[1], e[0])

        # Connectivity repair: collect edges leaving or entering the smallest
        # strongly connected component and add the reverse of one chosen at
        # random. The original edge is kept, so that pair becomes two-way.
        components = list(nx.algorithms.components.strongly_connected_components(G))
        while len(components) > 1:
            smallest = min(components, key=len)
            candidates = []
            for node in smallest:
                for out_edge in G.out_edges(node):
                    if out_edge[1] not in smallest:
                        candidates.append(out_edge)
                for in_edge in G.in_edges(node):
                    if in_edge[0] not in smallest:
                        candidates.append(in_edge)
            swap_edge = candidates[np.random.choice(len(candidates))]
            G.add_edge(swap_edge[1], swap_edge[0])
            components = list(nx.algorithms.components.strongly_connected_components(G))

        # Relabel (i, j) grid coordinates to integer ids.
        mapping = {(i, j): i * height + j for i, j in G.nodes()}
        G = nx.relabel_nodes(G, mapping)

        if self.var_distance:
            # add_edge on an existing edge only updates its attributes.
            for node, neighbor in G.edges():
                weight = np.random.exponential() 
                G.add_edge(node, neighbor, weight=weight)

        return G 
        
    def create_grid_graph_gre_f(self, seed): 
        """Build a randomly thinned square grid with random diagonal edges.

        Seeds ``random`` and ``np.random`` with ``seed``, then starts from a
        square grid with side ``int(sqrt(num_nodes))``. For each original grid
        edge (in ``list(G.edges)`` order):

        * If its first endpoint lies on the grid border, it is always kept.
        * Else, if the first endpoint's row index is even, it is removed when
          ``random.random() > p``, i.e. kept with probability ``p``.
        * Else, if the edge from ``(row - 1, col)`` to the first endpoint is
          absent (it may have been removed earlier in this loop), it is kept.
        * Otherwise it is removed with probability ``p``
          (``random.random() > 1 - p``).

        Then every node whose row and column are both odd is joined to each
        existing diagonal neighbour with probability ``q``. Nodes are
        relabelled from ``(i, j)`` to ``i * columns + j``.

        NOTE(review): the "Horizontal"/"Vertical" comments below label the
        branches by the parity of the first endpoint's row index, not by the
        edge's orientation. Confirm this is the intended rule.
        """
        random.seed(seed)
        np.random.seed(seed)
        
        p = 0.6057  # keep-probability in the even-row branch; removal probability in the other branch
        q = 0.4162  # probability of adding each diagonal edge

        
        rows = columns = int(math.sqrt(self.num_nodes))
        G = nx.grid_2d_graph(rows, columns)
        for edge in list(G.edges):
            if edge[0][0] == 0 or edge[0][0] == rows - 1 or edge[0][1] == 0 or edge[0][1] == columns - 1:
                # Always keep rim edges
                continue
            if edge[0][0] % 2 == 0:
                # Horizontal edge
                if random.random() > p:
                    G.remove_edge(edge[0], edge[1])
            else:
                # Vertical edge
                if not G.has_edge((edge[0][0] - 1, edge[0][1]), edge[0]):
                    # If no horizontal edge exists, keep the vertical edge
                    # if random.random() > p(1 - p):
                    #     G.remove_edge(edge[0], edge[1])
                    continue
                else:
                    if random.random() > (1 - p):
                        G.remove_edge(edge[0], edge[1])

        # Generate diagonal edges (adding edges between existing nodes does not
        # change the node set being iterated)
        for node in G.nodes:
            if node[0] % 2 != 0 and node[1] % 2 != 0:
                neighbors = [(node[0] + 1, node[1] + 1), (node[0] + 1, node[1] - 1),
                            (node[0] - 1, node[1] + 1), (node[0] - 1, node[1] - 1)]
                for neighbor in neighbors:
                    # random.random() is only drawn for neighbours inside the grid.
                    if neighbor in G.nodes and random.random() < q:
                        G.add_edge(node, neighbor)
            else: 
                pass
        # Relabel (i, j) grid coordinates to integer ids.
        mapping = {(i, j): i * columns + j for i, j in G.nodes()}
        G = nx.relabel_nodes(G, mapping)
    

        return G



    def create_grid_graph_gre_orig(self, seed):
        """Build a random "GRE" grid graph by sampling lattice edges in passes.

        Reseeds ``random`` and ``np.random`` with ``seed`` and works on a
        ``side x side`` lattice of ``(row, col)`` nodes, ``side = int(sqrt(num_nodes))``:

        1. Horizontal edges ``(i, j)-(i, j + 1)``, rows from last to first: always kept
           on rows ``0`` and ``side - 1``, otherwise kept with probability ``p``.
        2. Vertical edges ``(i, j)-(i - 1, j)``, same row order. They are always kept on
           columns ``0`` and ``side - 1``. They are also always kept when no horizontal
           edge joins ``(i, j - 1)`` to ``(i, j)``. Otherwise they are kept with
           probability ``1 - p * (1 - p)`` at node ``(1, 1)`` and ``p`` elsewhere.
        3. Each node whose row and column are both odd is linked to each of its four
           diagonal neighbours that are present in the graph with probability ``q``.

        When ``var_distance`` is set, every edge gets an ``Exp(1)`` ``'weight'``. Nodes
        are relabelled row-major to ``i * side + j``. The graph is built from its edge
        list, so ``G.nodes`` iterates in first-appearance order, not row-major order.

        Returns:
            ``nx.Graph`` with integer node labels.
        """
        random.seed(seed)
        np.random.seed(seed)
        p = 0.6057  # keep-probability for interior lattice edges (edge added when random() < p)
        q = 0.4162  # probability of adding each candidate diagonal edge

        rows = columns = int(math.sqrt(self.num_nodes))
        edge_list = []
        # Horizontal edges are also tracked in a set so the vertical pass can test
        # membership in O(1) rather than scanning edge_list (O(n) per test).
        # edge_list keeps insertion order, which fixes the weight-draw order below.
        horizontal_edges = set()

        # Pass 1: horizontal edges, from the last row up to row 0, left to right.
        for i in range(rows - 1, -1, -1):
            for j in range(columns - 1):
                edge = ((i, j), (i, j + 1))
                # Rim rows always keep their edges; ``or`` short-circuits so no
                # random number is drawn for them.
                if i == 0 or i == rows - 1 or random.random() < p:
                    edge_list.append(edge)
                    horizontal_edges.add(edge)

        # Pass 2: vertical edges to the row above (row 0 has none).
        for i in range(rows - 1, 0, -1):
            for j in range(columns):
                current_node = (i, j)
                if j == 0 or j == columns - 1:
                    keep = True  # rim columns always keep their vertical edges
                elif ((i, j - 1), current_node) not in horizontal_edges:
                    # No horizontal edge from the left: always keep this edge
                    # (presumably so the node is not left isolated).
                    keep = True
                elif i == 1 and j == 1:
                    keep = random.random() < (1 - p * (1 - p))
                else:
                    keep = random.random() < p
                if keep:
                    edge_list.append((current_node, (i - 1, j)))

        G = nx.Graph()
        if self.var_distance:
            for u, v in edge_list:
                G.add_edge(u, v, weight=np.random.exponential())
        else:
            G.add_edges_from(edge_list)

        # Diagonal edges around odd/odd nodes. Neighbours are existing nodes, so
        # adding edges while iterating G.nodes does not change the node set.
        for node in G.nodes:
            if node[0] % 2 == 0 or node[1] % 2 == 0:
                continue
            r, c = node
            for neighbor in ((r + 1, c + 1), (r + 1, c - 1), (r - 1, c + 1), (r - 1, c - 1)):
                if neighbor in G.nodes and random.random() < q:
                    if self.var_distance:
                        G.add_edge(node, neighbor, weight=np.random.exponential())
                    else:
                        G.add_edge(node, neighbor)

        mapping = {(i, j): i * columns + j for i, j in G.nodes()}
        return nx.relabel_nodes(G, mapping)



    # check connectivity
    def check(self, row):
        """Return True if ``row`` has more than one non-infinite entry.

        Non-infinite entries minus one are treated as the number of edges, and the
        row counts as connected when that number is positive. Presumably ``row`` is a
        cost-matrix row whose zero self-distance accounts for the subtracted one —
        confirm with callers. NaN entries count as non-infinite.
        """
        non_infinite = np.count_nonzero(np.logical_not(np.isinf(row)))
        return bool(non_infinite > 1)

    def average_hops(self):
        """Return the mean unweighted shortest-path length over all ordered node pairs.

        Uses ``remove_invalid_edges()`` for ``graph_type == 'paper'`` and ``self.graph``
        otherwise. Hop counts follow edge direction on directed graphs. The result is
        rounded to 2 decimals; graphs with fewer than two nodes return ``0.0``.

        Raises:
            nx.NetworkXNoPath: if some ordered pair of nodes has no path.
        """
        if self.graph_type == 'paper':
            graph = self.remove_invalid_edges()
        else:
            graph = self.graph

        num_nodes = graph.number_of_nodes()
        if num_nodes < 2:
            return 0.0  # no distinct pairs; avoids dividing by zero

        # One BFS per source instead of one path search per ordered pair.
        total_hops = 0
        for source, lengths in nx.all_pairs_shortest_path_length(graph):
            if len(lengths) < num_nodes:
                # Keep the per-pair behaviour: an unreachable pair is an error.
                missing = next(t for t in graph.nodes() if t not in lengths)
                raise nx.NetworkXNoPath(f"No path between {source} and {missing}.")
            total_hops += sum(lengths.values())  # includes source->source == 0

        num_pairs = num_nodes * (num_nodes - 1)
        return round(total_hops / num_pairs, 2)

    def remove_invalid_edges(self):
        """Return an undirected, unit-weight graph of ``self.graph``'s finite-weight edges.

        Edges whose ``'weight'`` equals ``np.inf`` are dropped. Every kept edge keeps its
        other attributes but gets ``weight = 1``. The result is built from edges only,
        so nodes without a kept edge are absent and no node attributes are carried
        over. ``self.graph`` is not modified.
        """
        kept = [
            (u, v, {**attrs, 'weight': 1})
            for u, v, attrs in self.graph.edges(data=True)
            if attrs['weight'] != np.inf
        ]
        unit_graph = nx.Graph()
        unit_graph.add_edges_from(kept)
        return unit_graph

    def plot_graph(self):
        """Show the environment graph with node labels.

        Draws ``remove_invalid_edges()`` for ``graph_type == 'paper'`` and ``self.graph``
        otherwise. Graph types containing ``'plane'`` use a planar layout; all others
        use a spring layout seeded with ``self.seed``.
        """
        graph = self.remove_invalid_edges() if self.graph_type == 'paper' else self.graph

        if 'plane' in self.graph_type:
            layout = nx.planar_layout(graph)
        else:
            layout = nx.spring_layout(graph, seed=self.seed)

        nx.draw(graph, pos=layout, with_labels=True,
                node_color='lightblue', edge_color='gray')
        plt.title("Graph Visualization")
        plt.show()

    def add_graph_pb(self, graph, seed, graph_type, replicable = True):
        
        if replicable:
            random.seed(seed)
            np.random.seed(seed)
        
        if 'bn' in graph_type:
            weights = np.random.binomial(n=10, p=0.5, size=self.num_nodes)  
        elif 'psn' in graph_type:
            weights = np.random.poisson(lam=5, size=self.num_nodes)  
        elif 'lgnm' in graph_type:
            weights = np.random.lognormal(mean=0, sigma=1, size=self.num_nodes)  
        else:
            weights = np.random.exponential(scale=1, size=self.num_nodes)

        weights_sum = np.sum(weights)
        probabilities = weights / weights_sum
        
        
        for i, node in enumerate(graph.nodes):
            graph.nodes[node]['probability'] = probabilities[i]

        if self.general_model == True:
            return probabilities

    def add_pb(self, graph_type):

        if 'bn' in graph_type:
            weights = np.random.binomial(n=10, p=0.5, size=self.num_nodes)  
        elif 'psn' in graph_type:
            weights = np.random.poisson(lam=5, size=self.num_nodes)  
        elif 'lgnm' in graph_type:
            weights = np.random.lognormal(mean=0, sigma=1, size=self.num_nodes)  
        else:
            weights = np.random.exponential(scale=1, size=self.num_nodes)

        weights_sum = np.sum(weights)
        probabilities = weights / weights_sum

        return probabilities


    def plot_graph_pb(self):
        """Draw the environment graph, sizing and labelling nodes by request probability.

        Probabilities are read from the ``'probability'`` attribute of ``self.graph``'s
        nodes. Node sizes are scaled linearly into [100, 1100], and each label shows the
        node id and its probability to two decimals. For ``graph_type == 'paper'`` the
        drawn graph is ``remove_invalid_edges()``.

        Layout by ``graph_type``:
          * ``'grid'`` or containing ``'grid_gre'`` / ``'grid_dir'``: integer node ``n``
            is placed at column ``n % side``, row ``n // side`` (square grid assumed);
          * containing ``'toy'``: spring layout on inverted edge weights;
          * containing ``'plane'``: planar layout;
          * otherwise: spring layout on inverted edge weights.

        When ``var_distance`` is set, edge weights are drawn as edge labels (in both
        directions for ``'grid_dir'``). Wherever inverted weights are used, an
        ``'inverted_weight'`` edge attribute is written onto the drawn graph. Its value
        is ``1 / weight``, with a missing weight treated as 1.
        """
        if self.graph_type == 'paper':
            graph = self.remove_invalid_edges()
        else:
            graph = self.graph

        # Read probabilities from self.graph: remove_invalid_edges() builds a fresh
        # graph from edges only, so for 'paper' the drawn graph has no node attributes.
        probability = {node: self.graph.nodes[node]['probability'] for node in graph.nodes}
        max_prob = max(probability.values())
        min_prob = min(probability.values())
        prob_range = max_prob - min_prob
        if prob_range > 0:
            # Scale sizes between 100 and 1100.
            node_sizes = [1000 * (probability[node] - min_prob) / prob_range + 100
                          for node in graph.nodes]
        else:
            # All probabilities equal (e.g. uniform): the linear scaling would divide
            # by zero, so every node gets the midpoint size.
            node_sizes = [600] * len(probability)
        node_labels = {node: f'{node}\n{probability[node]:.2f}' for node in graph.nodes}

        def _grid_coords():
            # Row-major lattice position of integer node labels (square grid assumed).
            # Parenthesised on purpose: ``-node // columns`` parses as
            # ``(-node) // columns``, which floors toward -inf and put every node
            # outside column 0 one row too low.
            columns = int(len(graph.nodes()) ** 0.5)
            return {node: (node % columns, -(node // columns)) for node in graph.nodes()}

        def _set_inverted_weights():
            # Spring layouts treat weight as attraction; invert so heavier edges draw
            # longer. Edges without a 'weight' count as weight 1.
            inverted = {(u, v): 1.0 / data.get('weight', 1.0)
                        for u, v, data in graph.edges(data=True)}
            nx.set_edge_attributes(graph, inverted, 'inverted_weight')

        def _weight_labels():
            return {(u, v): f'{graph[u][v]["weight"]:.2f}' for u, v in graph.edges()}

        def _draw(pos, **extra):
            nx.draw(graph, pos, with_labels=True, node_size=node_sizes,
                    node_color='lightblue', edge_color='gray', labels=node_labels, **extra)

        plt.figure(figsize=(12, 12))

        if self.graph_type == 'grid' or 'grid_gre' in self.graph_type:
            coords = _grid_coords()
            if self.var_distance:
                edge_labels = _weight_labels()
                _set_inverted_weights()
                pos = nx.spring_layout(graph, pos=coords, fixed=coords.keys(), weight='inverted_weight')
                _draw(pos)
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
            else:
                _draw(coords)

        elif 'grid_dir' in self.graph_type:
            coords = _grid_coords()
            if self.var_distance:
                edge_labels = {}
                for u, v in graph.edges():
                    # (u, v) comes from graph.edges(), so the outgoing label always exists.
                    outgoing_label = f'← {graph[u][v]["weight"]:.2f}'
                    incoming_label = f'{graph[v][u]["weight"]:.2f} →' if graph.has_edge(v, u) else ''
                    edge_labels[(u, v)] = f'{outgoing_label}\n{incoming_label}' if incoming_label else outgoing_label
                pos = nx.spring_layout(graph, pos=coords, fixed=coords.keys())
                _draw(pos, arrows=True)
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
            else:
                _draw(coords, arrows=True)

        elif 'toy' in self.graph_type:
            # Previously this branch drew nothing without var_distance and never showed.
            _set_inverted_weights()
            pos = nx.spring_layout(graph, weight='inverted_weight', seed=self.seed)
            _draw(pos)
            if self.var_distance:
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=_weight_labels())

        else:
            if 'plane' in self.graph_type:
                pos = nx.planar_layout(graph)
            else:
                _set_inverted_weights()
                pos = nx.spring_layout(graph, weight='inverted_weight', seed=self.seed)
            _draw(pos)
            if self.var_distance:
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=_weight_labels())

        plt.show()

    def get_min_distance(self, state):

        actions_all = state[:, :-1]
        A = self.cost_matrix[actions_all.squeeze().long()]
        B = state[:, -1]

        # Reshape B to match the dimensions of A
        B_reshaped = B.view(self.batch_size, 1, 1).long()

        # Use torch.gather() to extract scalars from A
        # result = torch.gather(A, self.num_servers, B_reshaped.expand(-1, self.num_servers, -1))
        result = torch.gather(A, 2, B_reshaped.expand(-1, self.num_servers, -1))

        min_distance, min_index = torch.min(result, dim=1)
        min_distance = torch.reshape(min_distance, (self.batch_size, -1))

        min_distance_actions = torch.gather(actions_all, 1, min_index.view(-1, 1))

        return min_distance, min_distance_actions

    def distance_request(self, state_batch): 

        if self.general_model:
            tensor_list = []
            for i in range(self.batch_size): 
                state = state_batch[0][i]
                graph_index = state_batch[1][i]
                cost_matrix = self.cost_matrices[graph_index]
                dis_req_state = cost_matrix[:, state[-1].long()].reshape(-1, 1)
                tensor_list.append(dis_req_state)

            result = torch.stack(tensor_list).reshape(-1, 1)
            return result
        
        else:
            A = self.cost_matrix.unsqueeze(0).repeat(self.batch_size, 1, 1)
            B = state_batch[:, -1]

            # Reshape B to match the dimensions of A
            B_reshaped = B.view(self.batch_size, 1, 1).long()

            # Use torch.gather() to extract scalars from A
            # result = torch.gather(A, self.num_servers, B_reshaped.expand(-1, self.num_servers, -1))
            result = torch.gather(A, 2, B_reshaped.expand(-1, self.num_nodes, -1))

            return result.view(-1, 1)





    def create_grid_graph_gre(self, seed):
        """Build a random "GRE" grid graph by thinning a square lattice.

        Reseeds only the ``random`` module with ``seed``. The method starts from a
        ``side x side`` lattice of ``(row, col)`` nodes (``side = int(sqrt(num_nodes))``):

        * edges whose first endpoint lies on the outer rim are always kept;
        * other horizontal edges (same row) are removed with probability ``1 - p``;
        * other vertical edges are removed with probability ``p``;
        * each node with an odd row and an odd column is then linked to each diagonal
          neighbour with probability ``q``.

        Nodes are relabelled row-major to ``i * side + j``.

        Note: the horizontal-edge test used to compare a coordinate with itself and
        never matched, so every interior edge used the vertical rule. With the test
        fixed, a given ``seed`` produces a different graph than before.

        Returns:
            ``nx.Graph`` with integer node labels.
        """
        random.seed(seed)
        p = 0.6057
        q = 0.4162  # probability of adding each candidate diagonal edge

        rows = columns = int(math.sqrt(self.num_nodes))
        G = nx.grid_2d_graph(rows, columns)
        # Iterate over a snapshot because edges are removed inside the loop.
        for u, v in list(G.edges):
            # Rim edges are always kept.
            # NOTE(review): only ``u`` is tested, so interior edges leaving a rim node
            # (e.g. (i, 0)-(i, 1)) are kept as well; confirm this is intended.
            if u[0] == 0 or u[0] == rows - 1 or u[1] == 0 or u[1] == columns - 1:
                continue
            if u[0] == v[0] and abs(u[1] - v[1]) == 1:
                # Horizontal edge: kept with probability p.
                if random.random() > p:
                    G.remove_edge(u, v)
            else:
                # Vertical edge: removed with probability p.
                # NOTE(review): the former if/else on ``G.has_edge((u[0] - 1, u[1]), u)``
                # had two identical branches (its comment said the edge should be
                # kept when no horizontal edge exists); confirm the intended rule.
                if random.random() > (1 - p):
                    G.remove_edge(u, v)

        # Diagonal edges around odd/odd nodes; neighbours are existing nodes, so the
        # node set does not change while iterating.
        for node in G.nodes:
            if node[0] % 2 != 0 and node[1] % 2 != 0:
                r, c = node
                for neighbor in ((r + 1, c + 1), (r + 1, c - 1), (r - 1, c + 1), (r - 1, c - 1)):
                    if neighbor in G.nodes and random.random() < q:
                        G.add_edge(node, neighbor)

        mapping = {(i, j): i * columns + j for i, j in G.nodes()}
        return nx.relabel_nodes(G, mapping)
            # for rows in range(2, self.num_nodes+1):
            #     columns = self.num_nodes // rows
            #     if rows * columns == self.num_nodes and rows > columns > 1:

                


# def create_grid_graph_gre(self, seed): 

    #         random.seed(seed)
    #         p = 0.6057  # Probability to remove horizontal and vertical edges
    #         q = 0.4162
    #         # q = 0  # Probability to add diagonal edges

    #         width = height = int(math.sqrt(self.num_nodes))
    #         # random grid (GRE)
    #         G = nx.grid_2d_graph(width, height)
    #         # remove edges
    #         p = 0.6057
    #         for x in range(width - 1): # left to right
    #             for y in range(height - 1): # bottom to top
    #                 # horizontal edge removal
    #                 if (y > 0):
    #                     if np.random.random() < (1 - p):
    #                         G.remove_edge((x, y), (x + 1, y))
    #                     # vertical edge removal
    #                 if (x > 0):
    #                     if (y == 0) or G.has_edge((x - 1, y), (x, y)):
    #                         if np.random.random() < p * (1 - p):
    #                             G.remove_edge((x, y), (x, y + 1))
            
            
    #         q = 0.4162
    #         for x in range(1, width, 2):
    #             for y in range(1, height, 2):
    #                 for edge_to in [(x - 1, y - 1), (x + 1, y - 1), (x - 1, y + 1), (x + 1, y + 1)]:
    #                     if G.has_node(edge_to) and np.random.random() < q:
    #                         # if self.var_distance: 
    #                             weight = np.random.randint(1, 31)  # random edge length between 1 and 30
    #                             G.add_edge((x, y), edge_to, weight=weight)
    #                         # else: 
    #                         #     G.add_edge((x, y), edge_to)



    #         mapping = {(i, j): i * height + j for i, j in G.nodes()}
    #         G = nx.relabel_nodes(G, mapping)

        

    #         return G




# class KServerEnv():
#     def __init__(self, num_nodes = 9, num_servers = 2, uniform_random = True, general_model = False, constant_probability = True, balanced_algorithm = False, request_same_node = False,  arrival_rates = False, seq_req = False, var_distance = False, general_model_gt = 'tree', graph_type = 'tree_1', device ="cpu", seed = 123, batch_size = 1):
#         self.num_nodes = num_nodes
     
#         self.num_servers = num_servers
#         self.batch_size = batch_size
#         self.graph_type = graph_type
#         self.general_model = general_model
#         self.seed = seed 
#         self.balanced_algorithm = balanced_algorithm
#         self.uniform_random = uniform_random
#         self.constant_probability = constant_probability
#         self.general_model_gt = general_model_gt
#         # print('gen model type')
#         # print(self.general_model_gt)
#         self.request_same_node = request_same_node
#         self.seq_req = seq_req
#         self.var_distance = var_distance
#         if self.request_same_node == False and self.seq_req== True: 
#             raise ValueError("Unfeasible configuration")
#         self.arrival_rates = arrival_rates
#         if self.batch_size < 50: 
#             self.graph_number_gm = self.batch_size
#         else:
#             self.graph_number_gm = 50
#         self.graph_edge_indices = []
#         self.cost_matrices = []
#         self.cost_matrices_nn = []
#         self.graph_pbs = []
#         if general_model_gt =='all':
#             self.graph_types = ['tree_1', 'grid_gre_1','cycle', 'line', 'grid']
#         elif general_model_gt =='dir_check':
#             self.graph_types = ['grsubop', 'grop']
#         elif general_model_gt =='dir_check_1':
#             self.graph_types = ['grsubop_1', 'grop_1']
#         elif general_model_gt =='SF':
#             self.graph_types = ['SF']
#         elif general_model_gt =='EM':
#             self.graph_types = ['EM']
#         else: 
#             self.graph_types = []
#         # print(self.graph_types)
#         if device == "cuda":
#             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         else:
#             self.device = torch.device(device)
#         # Since we don't want to regenerate big graphs all the time, we save them and reuse them after the first generation 
#         if (self.graph_type not in ["EM", "SF"])  and (self.general_model_gt not in ["EM", "SF"]) and (self.num_nodes > 200):

#             file_path = f'data/big_graphs/{self.graph_type}_{self.num_nodes}.pkl'

#             if os.path.exists(file_path):
#                 with open(file_path, 'rb') as f:    
#                     self.graph = pickle.load(f)
#                 self.cost_matrix = torch.load(f'data/big_graphs/cost_matrix/{self.graph_type}_{self.num_nodes}.pth').to(self.device)                
#                 self.probabilities = torch.load(f'data/big_graphs/probabilities/{self.graph_type}_{self.num_nodes}.pth')
#             else: 
#                 # creating and saving graph
#                 self.graph = self.create_graph(self.graph_type)
#                 with open(file_path, 'wb') as f:
#                     pickle.dump(self.graph, f)
#                 # doing this, since it is computationally expensive
#                 seed = int(re.findall(r'\d+', self.graph_type)[0]) 
#                 self.add_graph_pb(self.graph, seed, self.graph_type)
#                 # self.add_graph_pb(self.graph, self.seed, self.graph_type)
#                 # creating and  saving cost matrix 
#                 self.cost_matrix = self.build_cost_matrix(self.graph).to(self.device)                
#                 torch.save(self.cost_matrix, f'data/big_graphs/cost_matrix/{self.graph_type}_{self.num_nodes}.pth')
#                 # creating and saving probabilities
#                 self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]  
#                 torch.save(self.probabilities, f'data/big_graphs/probabilities/{self.graph_type}_{self.num_nodes}.pth')            
#             if os.path.exists(f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth'):    
#                 self.cost_matrix_nn = torch.load(f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth').to(self.device) 
                
#             else: 
#                 self.cost_matrix_nn = self.build_cost_matrix_nn(self.graph).to(self.device).to(self.device)                
#                 torch.save(self.cost_matrix_nn, f'data/big_graphs/cost_matrix_nn/{self.graph_type}_{self.num_nodes}.pth')
                
#         else:     
#             self.graph = self.create_graph(self.graph_type)
#             if self.graph_type in ["EM", "SF"] or self.general_model_gt in ["EM", "SF"]:  
#                 self.num_nodes = self.graph.number_of_nodes()
#                 # self.num_servers = round(self.num_nodes/6)
#             if any(keyword in self.graph_type for keyword in ['grid_gre', 'tree', 'grid_dir']):
#                 seed = int(re.findall(r'\d+', self.graph_type)[0])
#                 self.add_graph_pb(self.graph, seed, self.graph_type)
#                 # self.add_graph_pb(self.graph, self.seed, self.graph_type)
#                 # print(int(re.findall(r'\d+', self.graph_type)[0]))
#             else: 
#                 self.add_graph_pb(self.graph, self.seed, self.graph_type)
#             self.cost_matrix = self.build_cost_matrix(self.graph).to(self.device)
#             self.cost_matrix_nn = self.build_cost_matrix_nn(self.graph).to(self.device)
#             self.probabilities = [self.graph.nodes[node]['probability'] for node in self.graph.nodes]    

        
 
#         # print(self.graph_types)
#         if self.general_model: 
            
#             if general_model_gt == 'grid':
                
#                 for rows in range(2, self.num_nodes+1):
#                     columns = self.num_nodes // rows
#                     if rows * columns == self.num_nodes and rows > columns > 1:
#                         graph = nx.grid_2d_graph(rows, columns)
#                         mapping = {(i, j): i * columns + j for i, j in graph.nodes()}
#                         graph = nx.relabel_nodes(graph, mapping)
#                         graph_type = f'grid_col{columns}'
#                         self.graph_types.append(graph_type)
#                         self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
#                         self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
#                         self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
#                         if self.uniform_random==False:
#                             self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, rows, graph_type)).to(self.device))

#             elif general_model_gt in ['plane', 'tree', 'grid_gre', 'grid_dir', 'bn_grid_gre', 'psn_grid_gre', 'lgnm_grid_gre']: 
                
#                 for i in range(self.graph_number_gm): 
#                         graph = self.create_graph_gm(i)
#                         self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
#                         self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
#                         self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
#                         graph_type = f'{self.general_model_gt}_{i}'
#                         self.graph_types.append(graph_type)
#                         if self.uniform_random==False:
#                             self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, i, graph_type)).to(self.device))
#             else: 
#                 # print(3)
#                 for n, graph_type in enumerate(self.graph_types):
#                     # print(i)
#                     graph = self.create_graph(graph_type)
#                     self.graph_edge_indices.append(convert.from_networkx(graph).edge_index)
#                     self.cost_matrices.append(self.build_cost_matrix(graph).to(self.device))
#                     self.cost_matrices_nn.append(self.build_cost_matrix_nn(graph).to(self.device))
#                     if self.uniform_random==False:
#                         self.graph_pbs.append(torch.FloatTensor(self.add_graph_pb(graph, n, graph_type)).to(self.device))
#                 # print(f'Edge Indices: {self.graph_edge_indices}')
#         self.reset() 

