import gym
import numpy as np
import random
import pickle
import time
import networkx as nx
import copy
import torch, torch_geometric
from torch.autograd import Variable
import torch.nn.parallel
import scipy
from collections import deque, defaultdict
from enum import Enum

from typing import List

from Netlist2Graph.netlist_dataset import NetlistDataset
import infrastructure.pytorch_util as ptu
from networks.discriminators import DiscriminatorNet
from .domain_rewards import bayes_opt, domain_reward, supply_connectivity, sim_graph
#from .connectivity import ConnectivityUtils as cu

from mlcirc_pipelines.interfaces_rl import GraphToPygRLInterface
from mlcirc_pipelines.bayes_opt.bo_utils import CircuitBayesOptConfig
from mlcirc_utils import StandardGraph, ComponentEnum
from mlcirc_tech_sky130 import BaseSky130Configuration, Sky130TechPlugin
from mlcirc_utils.utils import BiDict

from mlcirc_utils.branch_detection import ConnectivityUtils as cu

from mlcirc_pipelines.interfaces_rl import get_enum_name_from_value, RLEdgeTerminalEnum, RLNodeTypeEnum
class CircuitFeatures(Enum):
    """Integer identifiers for the circuit features referenced in this module.

    Members are numbered consecutively from zero in declaration order:
    ``TYPE`` = 0, ``VTH`` = 1, ``LENGTH`` = 2, ``WIDTH`` = 3, ``FINGERS`` = 4.
    """
    # NOTE(review): names suggest device type / threshold / geometry
    # parameters -- confirm against the code that consumes these values.
    TYPE, VTH, LENGTH, WIDTH, FINGERS = range(5)
    
class CircuitEnv(gym.Env):
    """ This class extends the gym environment for circuits specifically.
    """
    PERF_LOW = -1
    PERF_HIGH = 0
    
    def __init__(self, max_nodes, num_types, edge_types, dataset, set_masks=True, simple=False,
                 embed_dim=None, n_gcn_layers=None, n_layers=None, layer_size=None, learning_rate=None,
                 gpu_id=0, stack_index=None, save_dir=None, diff=False, se=True):
        """Build the environment: spaces, dataset cache, discriminator, mask and BO settings.

        :max_nodes: The total number of nodes a graph can have.
        :num_types: The number of types of nodes in the graph.
        :edge_types: The number of edge (terminal) types; also the length of the
            one-hot edge attribute written by ``step``.
        :dataset: Object exposing ``len()`` and ``get(i) -> (graph, label)``; every
            entry is read eagerly and cached in ``self.graphs`` / ``self.labels``.
        :set_masks: Stored on the instance as ``self.set_masks``.
        :simple: When True no discriminator network is created.
        :embed_dim, n_gcn_layers, n_layers, layer_size, learning_rate: Forwarded to
            ``DiscriminatorNet`` (unused when ``simple`` is True).
        :gpu_id: Stored on the instance and forwarded to ``CircuitBayesOptConfig``.
        :stack_index: Stored on the instance as ``self.stack_index``.
        :save_dir: Stored on the instance as ``self.save_directory``.
        :diff, se: Forwarded to ``Mask``.
        """
        # ---- plain configuration -------------------------------------------
        self.save_directory = save_dir
        self.max_nodes = max_nodes
        self.num_types = num_types
        self.edge_types = edge_types
        self.set_masks = set_masks

        # ---- gym spaces ------------------------------------------------------
        # Five discrete action components: the two node indices to connect,
        # the edge type, a binary flag (historically described as "terminate
        # agent") and a 5-way choice that step() reads as action[4].
        action_dims = [self.max_nodes, self.max_nodes, self.edge_types, 2, 5]
        self.action_space = gym.spaces.MultiDiscrete(action_dims)

        # Placeholder observation space (flagged upstream as needing work).
        observation_shape = [1, self.max_nodes, self.max_nodes + self.num_types]
        self.observation_space = gym.Space(shape=observation_shape)

        # ---- dataset cache ---------------------------------------------------
        self.graph = StandardGraph()
        self.dataset = dataset
        dataset_size = self.dataset.len()
        self.graphs: List[StandardGraph] = []
        self.labels = []
        self.gpu_id = gpu_id

        # Pull every (graph, label) pair out of the dataset once, up front.
        for position in range(dataset_size):
            sample_graph, sample_label = self.dataset.get(position)
            self.graphs.append(sample_graph)
            self.labels.append(sample_label)

        # ---- discriminator / bookkeeping ------------------------------------
        if simple:
            self.discriminator = None
        else:
            self.discriminator = DiscriminatorNet(
                self.num_types, embed_dim, n_gcn_layers, n_layers, layer_size, learning_rate
            )
        self.counter = 0
        self.max_action = 20
        self.stack_index = stack_index
        self.mask_obj = Mask(self.max_nodes, self.num_types, self.edge_types, diff=diff, se=se)

        # ---- graph conversion and technology plugin -------------------------
        self.stdgraph_to_data_interface = GraphToPygRLInterface()
        configuration = BaseSky130Configuration()
        self.plugin = Sky130TechPlugin(configuration)

        # ---- Bayesian-optimisation configuration ----------------------------
        bo_specs = {
            'Ao_db': {'target': 17, 'minimize': False},
            'dc_power': {'target': 10e-3, 'minimize': True},
            'bw_3db': {'target': 10e6, 'minimize': False},
            'gain_difference': {'target': 0, 'minimize': True},
        }
        self.bo_config = CircuitBayesOptConfig(
            specs=bo_specs,
            base_netlist="/path/opamp_extra_biases.scs",
            base_harness_yaml_file='/path/opamp/meas_harness_nooutbias.yaml',
            run_dir="/path",
            mlcirc_tech_pdk="mlcirc_tech_sky130",
            base_pdk_configuration="BaseSky130Configuration",
            pdk_tech_plugin="Sky130TechPlugin",
            gpu_id=gpu_id,
        )
        self.construction_config = {
            "mlcirc_tech_pdk": "mlcirc_tech_sky130",
            "base_pdk_configuration": "BaseSky130Configuration",
            "pdk_tech_plugin": "Sky130TechPlugin",
            "pickle_path": "/path/pickle_graph.pkl",
        }

        # Same four specs as ``bo_specs`` plus the 'constraint' / 'exact'
        # flags (and a 'bound' for gain_difference).
        self.target_dict = {
            'Ao_db': {
                'target': 17, 'minimize': False, 'constraint': False, 'exact': False,
            },
            'dc_power': {
                'target': 10e-3, 'minimize': True, 'constraint': False, 'exact': False,
            },
            'bw_3db': {
                'target': 10e6, 'minimize': False, 'constraint': False, 'exact': False,
            },
            'gain_difference': {
                'target': 0, 'minimize': True, 'constraint': False, 'exact': True, 'bound': 1.25,
            },
        }

        
    def find_n_neighbors(self, adj_matrix, start_node, n):
        """
        Finds the nearest neighbors in a graph.
        
        :adj_matrix: An adjacency matrix of the graph.
        :start_node: The node to start from.
        :n: The number of neighbors to find.
        :return: List of neighbors.
        """
        num_nodes = adj_matrix.shape[0]
        visited = [False] * num_nodes
        distance = [None] * num_nodes
        queue = deque([start_node])
        
        visited[start_node] = True
        distance[start_node] = 0
        counter = 0
        
        while queue and counter < n:
            node = queue.popleft()
            for i in range(num_nodes):
                
                if adj_matrix[node][i] and not visited[i] and counter < n-1:
                    visited[i] = True
                    distance[i] = distance[node] + 1
                    counter = counter + 1
                    queue.append(i)

        neighbors_and_dist = [(node, dist) for node, dist in enumerate(distance) if dist != None]
        
        neighbors_and_dist.sort(key=lambda x: x[1])
        neighbors_and_dist = [i[0] for i in neighbors_and_dist]
        return neighbors_and_dist


    def reset(self, random_start=False,  random_circuit=False, index = None, num_sample=0, 
              verbose=False, iteration=None, traj_no=None, sample_nodes=None):
        """
        Resets the circuit environment to a (partial) circuit from the dataset.

        The selected dataset graph is relabelled to integer node ids (VSS -> 0,
        VDD -> 1, then the VIN/VOUT/VDD/VSS pin nets, then everything else),
        sub-sampled to ``sample_nodes`` and relabelled once more; the result is
        stored in ``self.graph``. The relabelled, un-sampled graph is stored in
        ``self.ground_truth``. Symmetry pairs of the ground truth are stored in
        ``self.symmetric_ground`` and, restricted to the sampled nodes, in
        ``self.symmetric_nodes`` (both ``None`` when no pairs are found).

        :random_start: When True and ``sample_nodes`` is None, the start node and
            the number of nodes to sample are drawn at random; otherwise sampling
            starts at node 0 with ``num_sample`` nodes.
        :random_circuit: When True the dataset graph is picked at random.
        :index: Dataset index to use when ``random_circuit`` is False
            (defaults to 2).
        :num_sample: Number of nodes to sample when ``random_start`` is False.
        :verbose: Turns on more print statements for the debug.
        :iteration: Stored on the instance as ``self.iteration``.
        :traj_no: Stored on the instance as ``self.traj_no``.
        :sample_nodes: Optional explicit list of (relabelled) node ids to keep;
            bypasses the neighbour search entirely.
        :return: Tuple ``(conv, graph_ground, self.graph, self.mask_obj,
            self.symmetric_nodes)`` where ``conv`` is the converted sampled graph
            and ``graph_ground`` is a copy of the relabelled ground truth.
        """
        self.iteration = iteration
        self.traj_no = traj_no
        self.graph = StandardGraph()

        # FIXME: the default dataset index is hard-coded to 2.
        if random_circuit:
            self.dataset_index = random.randint(0, len(self.graphs)-1)
        else:
            self.dataset_index = 2 if index is None else index

        self.ground_nodes = np.zeros([self.max_nodes, self.num_types])

        # Work on a copy so the cached dataset graph is never modified.
        graph = copy.deepcopy(self.graphs[self.dataset_index])

        # Identify original VSS and VDD
        original_nodes = graph.get_nodes()
        vss_node = next((n[0] for n in original_nodes if n[0] == "VSS"), None) # Check if VSS is in the graph
        vdd_node = next((n[0] for n in original_nodes if n[0] == "VDD"), None)

        # Create a new mapping that ensures: VSS→0, VDD→1, others assigned consecutively
        mapping = {}
        used_labels = set()

        if vss_node is not None:
            mapping[vss_node] = 0
            used_labels.add(0)

        if vdd_node is not None:
            mapping[vdd_node] = 1
            used_labels.add(1)

        # NOTE(review): the pin order is read from the node named "VSS" even
        # though vss_node may be None above -- confirm every dataset graph
        # contains a VSS node carrying custom_features/pin_order.
        pins_full = graph.get_node_features("VSS", ["custom_features", "pin_order"])
        pin_tags = ('VIN', 'VOUT', 'VDD', 'VSS')
        pins = [p for p in pins_full if any(tag in p for tag in pin_tags)]

        # First pass: pin nets get the lowest free labels. ``threshold`` follows
        # ``next_label`` while pins are assigned and later separates labels that
        # are kept as-is from labels that are compacted after sub-sampling.
        next_label = 2
        threshold = 2
        for node_id, node_type in original_nodes:
            if node_id not in mapping and (node_id in pins and node_type["component_type"] == ComponentEnum.NET):
                while next_label in used_labels:
                    next_label += 1
                    threshold += 1
                mapping[node_id] = next_label
                used_labels.add(next_label)

        # Second pass: every remaining node gets the next free label.
        for node_id, node_type in original_nodes:
            if node_id not in mapping:
                while next_label in used_labels:
                    next_label += 1
                mapping[node_id] = next_label
                used_labels.add(next_label)

        # Relabel the graph before doing anything else
        graph.relabel_nodes(mapping)

        # Now create the adjacency matrix from the re-labeled graph
        adj_gnd = graph.get_adjacency_matrix()

        # Fill the maximum-sized adjacency
        self.ground_adj = np.zeros([self.max_nodes, self.max_nodes])
        self.ground_adj[:adj_gnd.shape[0], :adj_gnd.shape[1]] = adj_gnd

        # If still need to sample from the graph, do it now (on the re-labeled adjacency).
        ground_truth_nodes = graph.get_num_nodes()

        # Stays None when the caller supplies sample_nodes, so the verbose
        # print below never hits an unbound local.
        node_start = None
        if sample_nodes is None:
            if random_start:
                node_start = random.randint(0, adj_gnd.shape[0] - 1)
                nodes_to_sample = random.randint(ground_truth_nodes - 3, ground_truth_nodes - 1)
            else:
                node_start = 0
                nodes_to_sample = num_sample
            sample_nodes = self.find_n_neighbors(adj_gnd, node_start, nodes_to_sample)

        print("sample_nodes", sample_nodes)
        # Nets outside the sample whose first one-hot node-type entry is not set.
        # Currently only reported; they are NOT appended to sample_nodes.
        net_nodes = []
        for n in graph.get_nodes():
            if (n[0] not in sample_nodes) and (n[1]["component_type"] == ComponentEnum.NET \
                                    and n[1]["custom_features"]["one_hot_node_type"][0] != 1):
                net_nodes.append(n[0])
        print("net_nodes", net_nodes)
        self.sampled_nodes = len(sample_nodes)
        if (verbose): 
            print("------------------------")
            print(f"Node to start from: {node_start}, Number of nodes to sample: {self.sampled_nodes}")
            print(f"Neighbors chosen are: {sample_nodes}")

        # actually do the sampling with the chosen node list
        graph.subsample(sample_nodes)
        self.graph: StandardGraph = graph

        #TODO fix this mapping
        # Labels below ``threshold`` are kept; the rest are renumbered
        # consecutively from ``threshold``. ``reverse`` maps each new label back
        # to the label the node had before this compaction.
        map_sample = {}
        reverse = {}
        avail = threshold
        for node in self.graph.get_nodes():
            if node[0] < threshold: 
                reverse[node[0]] = node[0]
                map_sample[node[0]] = node[0]
            else:   
                map_sample[node[0]] = avail
                reverse[avail] = node[0]
                avail += 1
        self.graph.relabel_nodes(mapping=map_sample)

        # Record, on every node, the label it carried before compaction.
        for n in self.graph.get_nodes():
            custom_features = self.graph.get_node_features(n[0], ["custom_features"])
            if custom_features is None:
                custom_features = {}
            custom_features["original_mapping"] = reverse[n[0]]
            self.graph.overwrite_node_field(n[0], "custom_features", custom_features)

        # Locate the supply nets in the sampled graph. self.PWR / self.GND are
        # only assigned when a net named VDD / VSS is present in the sample.
        for n in list(self.graph.get_nodes()):
            if n[1]["component_type"] == ComponentEnum.NET:
                try: 
                    name=n[1]["custom_features"]["node_name"]
                except (KeyError, TypeError):
                    # No custom features or no recorded name: use a synthetic one.
                    name = "NET" +str(n[0])
                
                if name=="VDD": 
                    self.PWR = n[0]
                elif name == "VSS":
                    self.GND = n[0]

        # Create and process the ground truth graph
        self.ground_truth = copy.deepcopy(self.graphs[self.dataset_index])
        self.ground_truth.relabel_nodes(mapping=mapping)
        graph_ground = copy.deepcopy(self.ground_truth)

        # Symmetry detection runs on its own copy of the ground truth.
        graph_ground_sym = copy.deepcopy(self.ground_truth)
        _connected, supply_paths = cu.supply_connectivity(graph_ground_sym, filter_paths=True)
        _stages, _remaining, pairs, pairs_nets = cu.get_circuit_symmetry(graph_ground_sym, supply_paths=supply_paths)

        if pairs.items() or pairs_nets.items():
            # Ground-truth symmetry: component pairs plus net pairs.
            self.symmetric_ground = copy.deepcopy(pairs)
            for k, v in pairs_nets.items():
                self.symmetric_ground.add(k, v)

            # Same pairs expressed in sampled-graph labels; a pair is kept only
            # when both of its members survived the sub-sampling.
            self.symmetric_nodes = BiDict()
            for pair_source in (pairs, pairs_nets):
                for k, v in pair_source.items():
                    if k in map_sample and v in map_sample:
                        self.symmetric_nodes.add( map_sample[k], map_sample[v])

        else:
            self.symmetric_ground = None
            self.symmetric_nodes = None


        gg_conv = self.stdgraph_to_data_interface.transform_to_custom_graph(graph_ground)
        self.edge_attr = gg_conv.edge_attr 
        self.edge_index = gg_conv.edge_index

        # Zero-padded adjacency of the sampled graph.
        adj = np.zeros([self.max_nodes, self.max_nodes])
        adj[:self.sampled_nodes, :self.sampled_nodes] = self.graph.get_adjacency_matrix()

        # Convert the sampled graph through the graph-to-PyG interface.
        conv = self.stdgraph_to_data_interface.transform_to_custom_graph(self.graph)
        
        self.ob_dict = {}
        self.ob_dict['adj'] = adj
        self.ob_dict['nodes'] = conv.node_type
        self.ob_dict['edge_index'] = conv.edge_index 
        self.ob_dict['edge_attr'] = conv.edge_attr
        self.ob = list(self.ob_dict.values())
        
        # Reset the counter
        self.counter = 0

        # NOTE(review): this padded tensor is not used afterwards in this
        # method; it is kept because the slice assignment fails when the node
        # features do not fit (max_nodes, num_types) -- confirm whether that
        # implicit check is still wanted.
        node_features_padded = torch.zeros(self.ground_nodes.shape).to(ptu.device)
        node_features_padded[:self.ob_dict['nodes'].shape[0], :self.ob_dict['nodes'].shape[1]] = \
            self.ob_dict['nodes']
        
        self.mask_obj.reset_mask(self.graph)
        return conv, graph_ground, self.graph, self.mask_obj, self.symmetric_nodes
    
    def step(self, action, domain=True):
        """
        Steps through the environment with validity checks
        :param action: is a 4 component vector 
        :return:
        """

        #old_circuit = copy.deepcopy(self.graph)
        total_nodes = self.graph.get_num_nodes() #old_circuit.number_of_nodes()
        info = {}
        reward = -2 #-1
        done = False
        import time

        ac00 = int(action[0][0])
        ac10 = int(action[1][0])
        _, _, _, _, conv = self.graph_to_obs(self.graph)
        new_graph = copy.deepcopy(self.graph)
        add_sym = False
        # RULE 1: One of the nodes we connect from and to must be existing, the second node
        #         must be within the valid range of scaffolded nodes ( + self.num_types not including 
        #        the input/output/VDD/VSS net nodes) 
        if ac00<total_nodes and ac10 < total_nodes + self.num_types-4 or \
           ac10<total_nodes and ac00 < total_nodes + self.num_types-4 :  
            # Reorder nodes into smaller index is node0, larger index is node1
            node0 = int(ac00) if ac00<ac10 else int(ac10)
            node1 = int(ac10) if ac00<ac10 else int(ac00)   

            # Adding a node + associated edge
            if node1 >= total_nodes: 
                existing_type = new_graph.get_node_features(node0,['component_type']) 
                    # ^ above returns index of one hot (0 if net, 1 if PMOS, etc)
                comp_type = RLNodeTypeEnum(node1-total_nodes)
                if comp_type == RLNodeTypeEnum.VDD or comp_type == RLNodeTypeEnum.GND or comp_type == RLNodeTypeEnum.INPUT \
                    or comp_type == RLNodeTypeEnum.OUTPUT:
                    comp_type = ComponentEnum.NET
                else: 
                    comp_type = ComponentEnum[comp_type.name]
                
                # check that action 2 is valid for the component type
                check2_type = comp_type if comp_type != ComponentEnum.NET else existing_type
                if check2_type == ComponentEnum.NFET or check2_type == ComponentEnum.PFET:
                    if int(action[2][0])<4:
                        check_act2 = True
                    else: 
                        check_act2 = False
                elif check2_type == ComponentEnum.RES or check2_type == ComponentEnum.CAP or check2_type == ComponentEnum.IND:
                    if 4<= int(action[2][0]) and int(action[2][0])<self.edge_types:
                        check_act2 = True
                    else: 
                        check_act2 = False
                else:
                    check_act2 = False
                # RULE 2: can only connect components to nets, check that this is true
                if (comp_type==ComponentEnum.NET or existing_type==ComponentEnum.NET) \
                    and (comp_type != existing_type) \
                    and check_act2: # only transistors FIXME

                    # RULE 3: One terminal of component may only be connected to one net
                    # This section collects connected terminals if the existing node is a component
                    if new_graph.get_node_features(node0,['component_type'])!=ComponentEnum.NET:
                        existing_edges = [0 for i in range(self.edge_types)]
                        # find all of the existing edges, add their attributes together to 
                        # collect all the exisitng terminal connections
                        for edge in self.graph.get_edges():
                            if edge[0] == node0 or edge[1] == node0:
                                current_attrs = new_graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                                existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                        unconnected_indices = [index for index, value in enumerate(existing_edges) if value == 0]
                        # if everything is connected any action is invalid
                        # for fets
                        if new_graph.get_node_features(node0,['component_type']) == ComponentEnum.NFET or \
                            new_graph.get_node_features(node0,['component_type']) == ComponentEnum.PFET:
                            check_fet = existing_edges[0:4]
                            unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                            if not unconnected_fet:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                        # for passives
                        elif new_graph.get_node_features(node0,['component_type']) == ComponentEnum.RES:
                            check_pass = existing_edges[4:]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if not unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                        elif new_graph.get_node_features(node0,['component_type']) == ComponentEnum.CAP or \
                            new_graph.get_node_features(node0,['component_type']) == ComponentEnum.IND:
                            check_pass = existing_edges[4:6]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if not unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                        # otherwise a new net can connect to D, G, S terminal (bulk should already be connected)
                    else:
                        # existing edges are all 0 if node is a net (nothing matters)
                        existing_edges = [0 for i in range(self.edge_types)]  

                    # Check edge type action against known connected terminals to prevent rule 3 violation
                    # everything ok, add node  
                    if existing_edges[int(action[2][0])] < 1:   

                        add_sym = int(action[4][0]) in [1, 4] and self.symmetric_ground is not None # if connecting to VDD or VSS, also add symmetry
                        if add_sym and (int(action[4][0]) == 4 and \
                            new_graph.get_node_features(node0,['component_type'])!=ComponentEnum.NET): # expansion only happens from a net
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                        else:
                            self.add_node(node1-total_nodes, new_graph, add_sym=add_sym)

                        edge_attr = [0 for i in range(self.edge_types)]
                        edge_attr[int(action[2][0])] = 1
                        existing_node_type = new_graph.get_node_features(node0,['component_type'])
                        new_graph.add_edge(node0, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                        new_graph.overwrite_edge_field(node0, total_nodes, "one_hot_edge_attr", edge_attr)
                        if int(action[4][0])>0:
                            if self.symmetric_nodes.in_bidict(node0):
                                node0_sym = self.symmetric_nodes.get(node0)
                            else:
                                node0_sym = node0
                            
                            if int(action[4][0]) == 1 and node0_sym != node0 and add_sym:
                                new_graph.add_edge(node0_sym, total_nodes+1, node_1_terminal_list =[], # FIXME put std terminals here?
                                            node_2_terminal_list =[], 
                                            phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(node0_sym, total_nodes+1, "one_hot_edge_attr", edge_attr)
                            
                            elif int(action[4][0]) == 2 and node0_sym != node0 and self.symmetric_nodes.in_bidict(node0) \
                                and existing_node_type ==ComponentEnum.NET :
                                new_graph.add_edge(node0_sym, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                            node_2_terminal_list =[], 
                                            phys_features=None, custom_features=None)
                                new_node_type = new_graph.get_node_features(total_nodes, ['component_type'])
                                sym_edge_attr = edge_attr.copy()
                                if new_node_type == ComponentEnum.PFET or new_node_type == ComponentEnum.NFET:
                                    if int(action[2][0]) == RLEdgeTerminalEnum.D.value:
                                        if sym_edge_attr[RLEdgeTerminalEnum.S.value] ==1:
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                        sym_edge_attr[RLEdgeTerminalEnum.D.value] = 0
                                        sym_edge_attr[RLEdgeTerminalEnum.S.value] = 1
                                    elif int(action[2][0]) == RLEdgeTerminalEnum.S.value:
                                        if sym_edge_attr[RLEdgeTerminalEnum.D.value] ==1:
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                        sym_edge_attr[RLEdgeTerminalEnum.S.value] = 0
                                        sym_edge_attr[RLEdgeTerminalEnum.D.value] = 1
                                    else: 
                                        return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                elif new_node_type in [ComponentEnum.RES, ComponentEnum.CAP, ComponentEnum.IND]:
                                    if int(action[2][0]) == RLEdgeTerminalEnum.M.value:
                                        if sym_edge_attr[RLEdgeTerminalEnum.P.value] ==1:
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                        sym_edge_attr[RLEdgeTerminalEnum.M.value] = 0
                                        sym_edge_attr[RLEdgeTerminalEnum.P.value] = 1
                                    elif int(action[2][0]) == RLEdgeTerminalEnum.P.value:
                                        if sym_edge_attr[RLEdgeTerminalEnum.M.value] ==1:
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                        sym_edge_attr[RLEdgeTerminalEnum.P.value] = 0
                                        sym_edge_attr[RLEdgeTerminalEnum.M.value] = 1
                                    else: 
                                        return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                                new_graph.add_edge(node0_sym, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(node0_sym, total_nodes, "one_hot_edge_attr", sym_edge_attr)
                            
                            elif int(action[4][0]) ==3 and node0_sym != node0 and existing_node_type != ComponentEnum.NET: 
                                new_graph.add_edge(node0_sym, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(node0_sym, total_nodes, "one_hot_edge_attr", edge_attr)

                            elif node0_sym == node0 and int(action[4][0]) == 4 and \
                                 existing_node_type == ComponentEnum.NET and add_sym: 
                                new_graph.add_edge(node0, total_nodes+1, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(node0, total_nodes+1, "one_hot_edge_attr", edge_attr)
                            
                            else:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                            

                        ## add bulk connection to supply when you add the node
                        if node1-total_nodes == 2: # if NFET
                            supply_node = self.GND
                        elif node1-total_nodes == 1:  # if PFET
                            supply_node = self.PWR
                        else:
                            supply_node = None
                        if supply_node is not None:
                            new_graph.add_edge(supply_node, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                        node_2_terminal_list =[], 
                                        phys_features=None, custom_features=None)
                            if supply_node == node0:
                                edge_bulk_attr = edge_attr
                                edge_bulk_attr[0] = 1
                            else: 
                                edge_bulk_attr = [1, 0, 0, 0, 0, 0, 0]
                            new_graph.overwrite_edge_field(supply_node, total_nodes, "one_hot_edge_attr", edge_bulk_attr)

                            if int(action[4][0])>0 and self.symmetric_nodes is not None and add_sym:
                                new_graph.add_edge(supply_node, total_nodes+1, node_1_terminal_list =[], # FIXME put std terminals here?
                                            node_2_terminal_list =[], 
                                            phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(supply_node, total_nodes+1, "one_hot_edge_attr", edge_bulk_attr)
                        
                        if node1-total_nodes == 3: # if resistor
                            new_graph.add_edge(self.PWR, total_nodes, node_1_terminal_list =[], # FIXME put std terminals here?
                                        node_2_terminal_list =[], 
                                        phys_features=None, custom_features=None)
                            if self.PWR == node0:
                                edge_bulk_attr = edge_attr
                                edge_bulk_attr[-1] = 1
                            else: 
                                edge_bulk_attr = [0, 0, 0, 0, 0, 0, 1]
                            new_graph.overwrite_edge_field(self.PWR, total_nodes, "one_hot_edge_attr", edge_bulk_attr)

                            if int(action[4][0])>0 and self.symmetric_nodes is not None and add_sym:
                                new_graph.add_edge(self.PWR, total_nodes+1, node_1_terminal_list =[], # FIXME put std terminals here?
                                            node_2_terminal_list =[], 
                                            phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(self.PWR, total_nodes+1, "one_hot_edge_attr", edge_bulk_attr)

                    else:
                        return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                elif not int(action[3][0]):
                    return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

            # Adding just an edge
            else: 
                # Rule 2 check
                existing_type = new_graph.get_node_features(node0,['component_type'])
                second_type = new_graph.get_node_features(node1,['component_type'])
                
                # check that action 2 is valid for the component type
                check2_type = second_type if second_type != ComponentEnum.NET else existing_type
                if check2_type == ComponentEnum.NFET or check2_type == ComponentEnum.PFET:
                    if int(action[2][0])<4:
                        check_act2 = True
                    else: 
                        check_act2 = False
                elif check2_type == ComponentEnum.RES or check2_type == ComponentEnum.CAP or check2_type == ComponentEnum.IND:
                    if 4<= int(action[2][0]) and int(action[2][0])<self.edge_types:
                        check_act2 = True
                    else: 
                        check_act2 = False
                else:
                    check_act2 = False

                if (existing_type==ComponentEnum.NET or second_type==ComponentEnum.NET) and (existing_type!=second_type) \
                    and check_act2:

                    # Rule 3 check
                    component_checked = node1 if existing_type==ComponentEnum.NET else node0
                    net_checked = node0 if existing_type==ComponentEnum.NET else node1
                    # Check the component does not connect anywhere else at the selected terminal
                    existing_edges = [0 for i in range(self.edge_types)]
                    edge_names = [None for i in range(self.edge_types)]
                    for edge in new_graph.get_edges():
                        if edge[0] == component_checked or edge[1] == component_checked:
                            current_attrs = new_graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                            existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                            # get the net node of every terminal
                            for i, val in enumerate(current_attrs): 
                                if val:
                                    edge_names[i] = edge[1] if edge[0] == component_checked  else edge[0]
                    zero_indices = [index for index, value in enumerate(existing_edges) if value == 0]

                    # for fets
                    if new_graph.get_node_features(component_checked,['component_type']) == ComponentEnum.NFET or \
                        new_graph.get_node_features(component_checked,['component_type']) == ComponentEnum.PFET:
                        check_fet = existing_edges[0:4]
                        unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                        if not unconnected_fet:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                    # for passives
                    elif new_graph.get_node_features(component_checked,['component_type']) == ComponentEnum.RES:
                        check_pass = existing_edges[4:]
                        unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                        if not unconnected_pass:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                    elif new_graph.get_node_features(component_checked,['component_type']) == ComponentEnum.CAP or \
                        new_graph.get_node_features(component_checked,['component_type']) == ComponentEnum.IND:
                        check_pass = existing_edges[4:6]
                        unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                        if not unconnected_pass:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                    # Rule 3b if a net N is already connected to the source, it can't connect to any other terminal
                    # action would therefore be invalid. 
                    if (edge_names[3] is not None) and (edge_names[3]==net_checked):
                        return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes 

                    if existing_edges[int(action[2][0])] < 1:
                        # determine the edge attribute to assign to the now valid connection
                        try: 
                            old_edge_attr = new_graph.get_edge_features(node0, node1, ['one_hot_edge_attr'])
                            old_edge_attr[int(action[2][0])] = 1
                            new_graph.overwrite_edge_field(node0, node1, "one_hot_edge_attr", old_edge_attr)
                        except:
                            old_edge_attr = [0 for i in range(self.edge_types)]
                            old_edge_attr[int(action[2][0])] = 1
                            new_graph.add_edge(node0, node1, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                            new_graph.overwrite_edge_field(node0, node1, "one_hot_edge_attr", old_edge_attr)

                        if int(action[4][0])>0:    
                            if int(action[4][0]) == 1 and self.symmetric_nodes is not None and \
                                (self.symmetric_nodes.in_bidict(node0) and self.symmetric_nodes.in_bidict(node1)):
                                node0_sym = self.symmetric_nodes.get(node0)
                                node1_sym = self.symmetric_nodes.get(node1)
                                new_graph.add_edge(node0_sym, node1_sym, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(node0_sym, node1_sym, "one_hot_edge_attr", old_edge_attr)
                            
                                      
                            elif int(action[4][0]) in [2,3] and self.symmetric_nodes is not None and \
                                (self.symmetric_nodes.in_bidict(node0) or self.symmetric_nodes.in_bidict(node1)):
                                
                                if self.symmetric_nodes.in_bidict(node0):
                                    sym_node = self.symmetric_nodes.get(node0)
                                    cm_node = node1
                                else:
                                    sym_node = self.symmetric_nodes.get(node1)
                                    cm_node = node0

                                cm_node_type = new_graph.get_node_features(cm_node, ['component_type'])
                                sym_node_type = new_graph.get_node_features(sym_node, ['component_type'])
                                if cm_node_type!=ComponentEnum.NET and int(action[4][0])==3:
                                    return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                                comp_node_type = cm_node_type if cm_node_type != ComponentEnum.NET else sym_node_type
                                comp_node = cm_node if cm_node_type != ComponentEnum.NET else sym_node
                                comp_edges = new_graph.get_node_edges(comp_node)
                                ce_edge_attr = [0 for i in range(self.edge_types)]
                                for ce in comp_edges:
                                    ce_attr = new_graph.get_edge_features(ce[0], ce[1], ['one_hot_edge_attr'])
                                    ce_edge_attr = [x + y for x, y in zip(ce_attr, ce_edge_attr)]

                                if int(action[4][0]) == 2 and cm_node_type != ComponentEnum.NET:
                                    try:
                                        sym_edge_attr = new_graph.get_edge_features(cm_node, sym_node, ['one_hot_edge_attr'])
                                    except:
                                        sym_edge_attr = [0 for i in range(self.edge_types)]

                                    if comp_node_type == ComponentEnum.PFET or comp_node_type == ComponentEnum.NFET:
                                        if int(action[2][0]) == RLEdgeTerminalEnum.D.value:
                                            if ce_edge_attr[RLEdgeTerminalEnum.S.value] ==1:
                                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                            sym_edge_attr[RLEdgeTerminalEnum.D.value] = 0
                                            sym_edge_attr[RLEdgeTerminalEnum.S.value] = 1
                                        elif int(action[2][0]) == RLEdgeTerminalEnum.S.value:
                                            if ce_edge_attr[RLEdgeTerminalEnum.D.value] ==1:
                                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                            sym_edge_attr[RLEdgeTerminalEnum.S.value] = 0
                                            sym_edge_attr[RLEdgeTerminalEnum.D.value] = 1
                                        else: 
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                    elif comp_node_type in [ComponentEnum.RES, ComponentEnum.CAP, ComponentEnum.IND]:
                                        if int(action[2][0]) == RLEdgeTerminalEnum.M.value:
                                            if ce_edge_attr[RLEdgeTerminalEnum.P.value] ==1:
                                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                            sym_edge_attr[RLEdgeTerminalEnum.M.value] = 0
                                            sym_edge_attr[RLEdgeTerminalEnum.P.value] = 1
                                        elif int(action[2][0]) == RLEdgeTerminalEnum.P.value:
                                            if ce_edge_attr[RLEdgeTerminalEnum.M.value] ==1:
                                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                            sym_edge_attr[RLEdgeTerminalEnum.P.value] = 0
                                            sym_edge_attr[RLEdgeTerminalEnum.M.value] = 1
                                        else: 
                                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                else: 
                                    if ce_edge_attr[int(action[2][0])]==1:
                                        return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                                    try:
                                        sym_edge_attr = new_graph.get_edge_features(cm_node, sym_node, ['one_hot_edge_attr'])
                                    except:
                                        sym_edge_attr = [0 for i in range(self.edge_types)]
                                    sym_edge_attr[int(action[2][0])] = 1
                                    
                                new_graph.add_edge(cm_node, sym_node, node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
                                new_graph.overwrite_edge_field(cm_node, sym_node, "one_hot_edge_attr", sym_edge_attr)

                            else:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                    else:
                        return conv, reward, done, info, self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                elif not int(action[3][0]): # failed the rule 2 check, but can continue to termination
                    return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
        else: 
            # does not update the old graph
            return conv, reward, done, info, self.counter, self.graph, self.mask_obj, self.symmetric_nodes

        # RULE 4: All components must be completely connected.
        # Rule 4 check if the agent thinks it is done
        if int(action[3][0]):
            # iterate through each node
            for n in new_graph.get_nodes():
                # if the node type is a component (not net), then check all the terminals
                if new_graph.get_node_features(n[0],['component_type']) != ComponentEnum.NET:
                    existing_edges = [0 for i in range(self.edge_types)]
                    # checking terminals by iterating through all the edges associated with node
                    for edge in new_graph.get_edges():
                        if edge[0] == n[0] or edge[1] == n[0]:
                            current_attrs = new_graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                            existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                    unconnected_indices = [index for index, value in enumerate(existing_edges) if value == 0]
                    # if there are edges that are unconnected (value is 0) and agent think it is 
                    # done, the action sequence is invalid
                    # if unconnected_indices:
                    #     return conv, reward, done, info, self.counter, self.graph, self.mask_obj 

                    # for fets
                    if new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.NFET or \
                        new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.PFET:
                        check_fet = existing_edges[0:4]
                        unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                        if unconnected_fet:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                    # for passives
                    elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.RES:
                        check_pass = existing_edges[4:]
                        unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                        if unconnected_pass:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                    elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.CAP or \
                        new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.IND:
                        check_pass = existing_edges[4:6]
                        unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                        if unconnected_pass:
                            return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                    # elif unconnected_indices: 
                    #     return conv, reward, done, info,self.counter, self.graph, self.mask_obj  
            # 4b there can only be one cluster (one graph, no islands)
            clusters = new_graph.get_clusters()
            if len(clusters)!= 1:
                return conv, reward, done, info, self.counter, self.graph, self.mask_obj, self.symmetric_nodes

        # All actions were valid if we get to this point
        # Increase counter, determine whether or not to terminate construction
        self.counter += 1
        if (new_graph.get_num_nodes()>=self.max_nodes) or (int(action[3][0])==1) or (self.counter >= self.max_action):
            done = True
        else:
            done = False   

        # copy over new graph with additions into the stored self.graph
        self.graph: StandardGraph = new_graph

        if add_sym:
            self.symmetric_nodes.add(total_nodes, total_nodes+1) 
        # organize graph information into format usable by gnns (torch_geometric Data)  
        node_features_padded, adj, self.ob_dict, self.ob, conv = self.graph_to_obs(new_graph)
        self.mask_obj.step_mask(new_graph, [int(a[0]) for a in action])
        
        # ------ Rewards
        if self.discriminator is None:
            # no densification
            if (done  or (self.counter  >= 10)) and domain:
                
                # iterate through each node
                for n in new_graph.get_nodes():
                    # if the node type is a component (not net), then check all the terminals
                    if new_graph.get_node_features(n[0],['component_type']) != ComponentEnum.NET:
                        existing_edges = [0 for i in range(self.edge_types)]
                        # checking terminals by iterating through all the edges associated with node
                        for edge in self.graph.get_edges():
                            if edge[0] == n[0] or edge[1] == n[0]:
                                current_attrs = new_graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                                existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                        unconnected_indices = [index for index, value in enumerate(existing_edges) if value == 0]

                        # for fets
                        if new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.NFET or \
                            new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.PFET:
                            check_fet = existing_edges[0:4]
                            unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                            if unconnected_fet:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                        # for passives
                        elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.RES:
                            check_pass = existing_edges[4:]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                            
                        elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.CAP or \
                            new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.IND:
                            check_pass = existing_edges[4:6]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                # 4b there can only be one cluster (one graph, no islands)
                clusters = new_graph.get_clusters()
                if len(clusters)!= 1:
                    return conv, reward, done, info, self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                
                g = copy.deepcopy(new_graph)
                try:
                    sup_connected, _ = cu.supply_connectivity(g, filter_paths=True)
                    component_pairs = self.symmetric_nodes.items() if self.symmetric_nodes is not None else []
                except Exception as e:
                    print("failed to check connectivity", self.counter, e)
                    component_pairs = []
                    connected = False
                    sup_connected = False
                end = time.time()

                if sup_connected: #check_half and connected:
                    if self.iteration is not None and self.traj_no is not None: 
                        itr_str = f"itr{self.iteration}_traj{self.traj_no}_step{self.counter}_"
                    else:
                        itr_str = f"step{self.counter}_"
                    custom_str =  f"run{self.gpu_id}_"+itr_str+time.strftime("%Y-%m-%d_%H-%M-%S") if self.stack_index is None \
                                        else f"run{self.gpu_id}_"+f'stack{self.stack_index}_' +itr_str+ time.strftime("%Y-%m-%d_%H-%M-%S")
 
                    result = sim_graph(g, self.bo_config, self.construction_config, symmetric_pairs=component_pairs,
                                        custom_str=custom_str, 
                                    save_directory=self.save_directory)
                    pareto_results = [result] if result is not None else []
                    domain_rew, reward_idx = domain_reward(pareto_results, self.target_dict, completion_reward=30)
 

                    if domain_rew >30:
                        reward += domain_rew+10
                        print("simulated good! ", reward, self.counter, domain_rew)
                    elif domain_rew >3:
                        reward += domain_rew
                        print("simulated! ", reward, self.counter, domain_rew)
                    else: 
                        reward += domain_rew
                        print("not good ", reward, self.counter, domain_rew)
                        done=False

                else: 
                    reward = reward - 2 #20
                    done = False
                    print("done and fail ", reward, self.counter, int(action[3][0]))


        else: 
            # discriminator reward
            prob_valid = self.discriminator([conv]).squeeze()
            
            # similarity reward
            if prob_valid >= 0.5: 
                reward = 1
            else:
                reward = -1
   

            if (done  or (self.counter >= 10)) and domain:
                # iterate through each node
                for n in new_graph.get_nodes():
                    # if the node type is a component (not net), then check all the terminals
                    if new_graph.get_node_features(n[0],['component_type']) != ComponentEnum.NET:
                        existing_edges = [0 for i in range(self.edge_types)]
                        # checking terminals by iterating through all the edges associated with node
                        for edge in self.graph.get_edges():
                            if edge[0] == n[0] or edge[1] == n[0]:
                                current_attrs = new_graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                                existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                        unconnected_indices = [index for index, value in enumerate(existing_edges) if value == 0]

                        # for fets
                        if new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.NFET or \
                            new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.PFET:
                            check_fet = existing_edges[0:4]
                            unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                            if unconnected_fet:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                        # for passives
                        elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.RES:
                            check_pass = existing_edges[4:]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes

                        elif new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.CAP or \
                            new_graph.get_node_features(n[0],['component_type']) == ComponentEnum.IND:
                            check_pass = existing_edges[4:6]
                            unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                            if unconnected_pass:
                                return conv, reward, done, info,self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                # 4b there can only be one cluster (one graph, no islands)
                clusters = new_graph.get_clusters()
                if len(clusters)!= 1:
                    return conv, reward, done, info, self.counter, self.graph, self.mask_obj, self.symmetric_nodes
                
                g = copy.deepcopy(new_graph)
            
                try:
                    sup_connected, _ = cu.supply_connectivity(g, filter_paths=True)
                    component_pairs = self.symmetric_nodes.items() if self.symmetric_nodes is not None else []
                except Exception as e:
                    print("failed to check connectivity", self.counter, e)
                    component_pairs = []
                    sup_connected = False
                if sup_connected:
                    if self.iteration is not None and self.traj_no is not None: 
                        itr_str = f"itr{self.iteration}_traj{self.traj_no}_step{self.counter}_"
                    else:
                        itr_str = f"step{self.counter}_"
                    custom_str =  f"run{self.gpu_id}_"+itr_str+time.strftime("%Y-%m-%d_%H-%M-%S") if self.stack_index is None \
                                        else f"run{self.gpu_id}_"+f'stack{self.stack_index}_' +itr_str+ time.strftime("%Y-%m-%d_%H-%M-%S")
                    try:
                        result = sim_graph(g, self.bo_config, self.construction_config, symmetric_pairs=component_pairs,
                                        custom_str=custom_str, 
                                    save_directory=self.save_directory)
                        pareto_results = [result] if result is not None else []
                        domain_rew, reward_idx = domain_reward(pareto_results, self.target_dict, completion_reward=30)
                    except: 
                        domain_rew, reward_idx = 0, 0

                    if domain_rew >30:
                        reward += domain_rew+10
                        print("simulated good! ", reward, self.counter, domain_rew)
                    elif domain_rew >3:
                        reward += domain_rew
                        print("simulated! ", reward, self.counter, domain_rew)
                    else: 
                        reward += domain_rew
                        print("not good ", reward, self.counter, domain_rew)
                        done=False
                else: 
                    reward = reward - 2 
                    done = False
                    print("done and fail ", reward, self.counter, prob_valid, int(action[3][0]))
                
            elif done: 
                if prob_valid > 0.5 and (new_graph.get_num_nodes() == self.ground_truth.get_num_nodes()):  
                    reward += 5
                else: 
                    reward = reward - 40 
        return conv, reward, done, info, self.counter, new_graph, self.mask_obj, self.symmetric_nodes

    def add_node(self, node_type, graph, add_sym = False):
        """Append one node (or a mirrored pair of nodes) to ``graph`` in place.

        :node_type: Integer index of the node type; used both as the hot index of
            the ``one_hot_node_type`` vector and, for non-zero values, as an
            ``RLNodeTypeEnum`` value.
        :graph: Graph object to extend. The new node is named with the current
            node count (``graph.get_num_nodes()``).
        :add_sym: When truthy, a second node named ``node count + 1`` is added
            right after the first one, using the very same component type and
            feature objects as the first node.
        """
        one_hot = np.zeros(self.num_types)
        one_hot[node_type] = 1
        next_id = graph.get_num_nodes()

        if node_type == 0:
            # Type 0 is a plain net: no physical features, no instance naming.
            comp_type = ComponentEnum.NET
            phys_features = None
            custom_feat = {}
        else:
            # Supply / IO node types are still nets as far as the component
            # enum is concerned; everything else maps by enum member name.
            net_like_values = (
                RLNodeTypeEnum.VDD.value,
                RLNodeTypeEnum.GND.value,
                RLNodeTypeEnum.INPUT.value,
                RLNodeTypeEnum.OUTPUT.value,
            )
            if node_type in net_like_values:
                comp_type = ComponentEnum.NET
            else:
                comp_type = ComponentEnum[RLNodeTypeEnum(node_type).name]

            instance_label = 'X' + str(next_id)
            phys_features, default_custom = self.plugin.generate_default_features(
                comp_type, instance_label)
            custom_feat = {
                "inst_name": default_custom.inst_name,
                "subckt_name": default_custom.subckt_name,
            }

        custom_feat["one_hot_node_type"] = one_hot

        # One node normally, two consecutive names when a symmetric twin is requested.
        new_names = (next_id, next_id + 1) if add_sym else (next_id,)
        for node_name in new_names:
            graph.add_node(comp_type, phys_features, custom_feat,
                           override_name=node_name)


    def lookup(self, spec, goal_spec):
        """Return the normalised distance between achieved and target specs.

        Computes ``(spec - goal_spec) / (goal_spec + spec)`` element-wise.

        :spec: Achieved specification value(s); a NumPy array, a plain
            list/tuple of numbers, or a scalar.
        :goal_spec: Iterable of target values; each entry is converted with
            ``float`` (so numeric strings are accepted).
        :return: NumPy array holding the normalised difference per entry.
        """
        # Both operands are coerced to arrays: with goal_spec left as a Python
        # list, the arithmetic below only worked when spec already was an
        # ndarray and raised TypeError for list/tuple/scalar input.
        goal_arr = np.asarray([float(e) for e in goal_spec], dtype=float)
        spec_arr = np.asarray(spec)
        # Entries where goal + spec == 0 follow NumPy division semantics
        # (inf/nan plus a RuntimeWarning) rather than raising.
        norm_spec = (spec_arr - goal_arr) / (goal_arr + spec_arr)
        return norm_spec
    
    def graph_to_obs(self, graph):
        """Build the observation structures for ``graph``.

        :graph: Graph to convert; must provide ``get_adjacency_matrix()`` and be
            accepted by ``self.stdgraph_to_data_interface``.
        :return: Tuple ``(node_features_padded, adj, ob_dict, ob, conv)`` where
            ``adj`` is the adjacency matrix zero-padded to
            ``max_nodes x max_nodes``, ``conv`` is whatever
            ``transform_to_custom_graph`` returns, ``ob_dict`` maps
            ``'adj'``/``'nodes'``/``'edge_index'``/``'edge_attr'`` to their
            values, ``ob`` is the same four values as a list, and
            ``node_features_padded`` is ``conv.node_type`` copied into a zero
            tensor shaped like ``self.ground_nodes`` on ``ptu.device``.
        """
        # Zero-pad the adjacency matrix into the fixed-size observation slot.
        raw_adj = graph.get_adjacency_matrix()
        rows, cols = raw_adj.shape[0], raw_adj.shape[1]
        adj = np.zeros([self.max_nodes, self.max_nodes])
        adj[:rows, :cols] = raw_adj

        conv = self.stdgraph_to_data_interface.transform_to_custom_graph(graph)
        node_types = conv.node_type
        ob_dict = {
            'adj': adj,
            'nodes': node_types,
            'edge_index': conv.edge_index,
            'edge_attr': conv.edge_attr,
        }

        # Node features are padded to the shape of the ground-truth node tensor.
        node_features_padded = torch.zeros(self.ground_nodes.shape).to(ptu.device)
        node_features_padded[:node_types.shape[0], :node_types.shape[1]] = node_types

        ob = [ob_dict[key] for key in ('adj', 'nodes', 'edge_index', 'edge_attr')]
        return node_features_padded, adj, ob_dict, ob, conv

    def create_node_dict(self, original_mapping=None):
        """Return a fresh node dictionary with every feature slot unset.

        :original_mapping: Value stored under the ``"original_mapping"`` key.
        :return: Dict with one ``None`` entry per ``CircuitFeatures`` member
            name, followed by the ``"original_mapping"`` entry.
        """
        node_dict = dict.fromkeys(feature.name for feature in CircuitFeatures)
        node_dict["original_mapping"] = original_mapping
        return node_dict

    
    def get_edge_difference(self, input_graph, dataset_graph):
        """Compare the edges of a partially built graph with its dataset graph.

        ``input_graph`` is deep-copied and the copy is relabelled so that every
        node is named by its ``["custom_features", "original_mapping"]``
        feature; the graph passed in is not relabelled.

        :input_graph: Graph under construction. Every node must expose the
            ``original_mapping`` custom feature.
        :dataset_graph: Reference graph whose node names are the values the
            ``original_mapping`` features refer to.
        :return: Tuple ``(og_map, reverse_map, common_edges, not_common_edges,
            remaining_edges)``. ``og_map`` maps current node name -> original
            name and ``reverse_map`` is the inverse. ``common_edges`` are the
            ``(u, v)`` pairs present in both graphs, ``not_common_edges`` those
            present in exactly one. ``remaining_edges`` holds the common edges
            whose ``one_hot_edge_attr`` differs plus the non-common edges that
            touch at least one node of the relabelled input graph.
        """
        expert_graph = copy.deepcopy(input_graph)

        # First remap the nodes to follow the original dataset naming.
        og_map = {}
        reverse_map = {}
        for n in expert_graph.get_nodes():
            original_name = expert_graph.get_node_features(
                n[0], ["custom_features", "original_mapping"])
            og_map[n[0]] = original_name
            reverse_map[original_name] = n[0]

        expert_graph.relabel_nodes(og_map)
        # A set gives O(1) membership tests in the final filtering loop.
        expert_nodes = {n[0] for n in expert_graph.get_nodes()}

        # Map each (u, v) pair to the attribute of its FIRST occurrence, which
        # is what a linear "first matching index" search would have selected,
        # without rescanning the edge list for every common edge.
        # NOTE(review): pairs are compared as ordered tuples; if get_edges()
        # can report the same connection as (v, u), it will count as a
        # different edge here - confirm against the graph implementation.
        expert_attr_by_edge = {}
        for edge in expert_graph.get_edges():
            expert_attr_by_edge.setdefault((edge[0], edge[1]), edge[2]['one_hot_edge_attr'])

        dataset_attr_by_edge = {}
        for edge in dataset_graph.get_edges():
            dataset_attr_by_edge.setdefault((edge[0], edge[1]), edge[2]['one_hot_edge_attr'])

        # Find common and not common edges between the expert and dataset.
        expert_edge_set = set(expert_attr_by_edge)
        dataset_edge_set = set(dataset_attr_by_edge)
        not_common_edges = expert_edge_set.symmetric_difference(dataset_edge_set)
        common_edges = expert_edge_set.intersection(dataset_edge_set)

        remaining_edges = set()
        for c in common_edges:
            # Same endpoints but different terminal attributes: still to be fixed.
            if expert_attr_by_edge[c] != dataset_attr_by_edge[c]:
                remaining_edges.add(c)
        for nc in not_common_edges:
            # Only edges reachable from nodes that already exist are relevant.
            if nc[0] in expert_nodes or nc[1] in expert_nodes:
                remaining_edges.add(nc)
        return og_map, reverse_map, common_edges, not_common_edges, remaining_edges

    def expert_step(self, input_graph, dataset_graph, 
                    og_map, reverse_map, common_edges, not_common_edges, remaining_edges, prev_connection= None):
        """Take one scripted expert step that moves ``input_graph`` towards ``dataset_graph``.

        A deep copy of ``input_graph`` is relabelled to the dataset's node ids,
        one edge from ``remaining_edges`` is chosen (nearest in a BFS from node
        0, preferring edges that touch ``prev_connection``), and that edge --
        together with a new node, its supply/bulk edge and a symmetric twin
        where applicable -- is added to the copy.  The copy is relabelled back
        with ``reverse_map`` before it is returned.

        The action vector has five slots:
        ``[0]`` source node (id from ``reverse_map``), ``[1]`` target node
        (``reverse_map`` id, plus the node-type index when the node is new),
        ``[2]`` edge-type index, ``[3]`` termination flag (1 when nothing
        remains), ``[4]`` symmetry code (0 single ended, 1 symmetric
        connection, 2 common mode reduction, 3 collapse to net, 4 common mode
        expansion -- meanings taken from the inline comments below).

        :input_graph: Current partial graph; only a deep copy is modified.
        :dataset_graph: Ground-truth graph the expert is reproducing.
        :og_map: Current node id -> dataset node id. Updated in place.
        :reverse_map: Dataset node id -> current node id. Updated in place.
        :common_edges: Set of edges present in both graphs. Updated in place.
        :not_common_edges: Set of edges present in only one graph. Updated in place.
        :remaining_edges: Set of edges the expert still has to add/complete. Updated in place.
        :prev_connection: Edge returned by the previous step, or None.
        :returns: ``(actions, conv, rew, done, info, expert_graph, common_edges,
            not_common_edges, remaining_edges, og_map, reverse_map, mask,
            new_connection)``.  ``conv`` is the fifth value returned by
            ``self.graph_to_obs``; ``rew`` is 5 when ``done`` else 1; ``info``
            is an empty dict; ``mask`` is a 4-element list (from, to, edge,
            term masks); ``new_connection`` is in dataset node ids.
        """
        # need the original graph being sampled, starting point, length of trajectory
        # first do a copy, we'll relabel everything and then label in consecutive order at the end
        expert_graph = copy.deepcopy(input_graph)
        # node count before anything is added; used below as the id of a newly added node
        total_nodes = expert_graph.get_num_nodes()

        #first remap the nodes to follow the original dataset
        expert_graph.relabel_nodes(og_map)

        # get the edge attributes
        # NOTE(review): both lists are reset to [] and rebuilt a few lines below,
        # so this first computation appears redundant.
        expert_edge_attr = [e[2]["one_hot_edge_attr"] for e in expert_graph.get_edges()]
        dataset_edge_attr = [e[2]["one_hot_edge_attr"] for e in dataset_graph.get_edges()]
        
        # snapshot of the node ids present BEFORE this step adds anything
        expert_graph_nodes = [n[0] for n in expert_graph.get_nodes()]
        # parallel lists: expert_edges[i] <-> expert_edge_attr[i]
        expert_edges = []
        expert_edge_attr = []
        for edge in expert_graph.get_edges():
            expert_edges.append((edge[0], edge[1]))
            expert_edge_attr.append(edge[2]['one_hot_edge_attr'])
        
        # parallel lists: dataset_edges[i] <-> dataset_edge_attr[i]
        dataset_edges = []
        dataset_edge_attr = []
        for edge in dataset_graph.get_edges():
            dataset_edges.append((edge[0], edge[1]))
            dataset_edge_attr.append(edge[2]['one_hot_edge_attr'])

        # [from node, to node, edge type, terminate flag, symmetry code]
        actions = [None for a in range(5)]

        # BFS over the dataset_graph edges, expanding only nodes that already
        # exist in expert_graph. Starts at node id 0 (presumably the ground net
        # -- TODO confirm).
        start_node = 0
        visited = set()
        queue = deque([(start_node, 0)])  # Queue for BFS
        distance_edges = defaultdict(list)  # BFS distance -> edges discovered at that distance
        distances = {}  # node -> BFS distance (written here, not read again in this method)


        while queue:
            node,dist = queue.popleft()
            # skip nodes already expanded and nodes the expert graph does not contain yet
            if node in visited or node not in expert_graph.get_nodes():
                continue
            visited.add(node)
            distances[node] = dist  # Store distance

            # NOTE(review): primary_edges / supply_edges are never used below.
            primary_edges = []
            supply_edges = []
            for edge in dataset_graph.get_edges():
                u, v = edge[0], edge[1]
        
                # Ignore any paths that have a bulk only connection
                # (attribute is exactly the first-slot-only or last-slot-only one-hot)
                one_hot_attr = dataset_graph.get_edge_features(u, v, "one_hot_edge_attr")
                if one_hot_attr != [1, 0, 0, 0, 0, 0, 0] and one_hot_attr != [0, 0, 0, 0, 0, 0, 1]:
                    # edges are stored oriented away from the node being expanded
                    if u == node and v not in visited:
                        distance_edges[dist + 1].append((u, v))  # Store edge in corresponding distance
                        queue.append((v, dist + 1))
                    elif v == node and u not in visited:
                        distance_edges[dist + 1].append((v, u))  # Store edge in corresponding distance
                        queue.append((u, dist + 1))

        # Candidate edges: the still-remaining edges at the first BFS distance
        # that has any. distance_edges keys were created in BFS order, so the
        # dict is iterated nearest-first.
        redge_popped = []
        for dist, edges in distance_edges.items():
            for edge in edges:
                if edge in remaining_edges or edge[::-1] in remaining_edges:
                    one_hot_attr = dataset_graph.get_edge_features(edge[0], edge[1], "one_hot_edge_attr")
                    if one_hot_attr != [1, 0, 0, 0, 0, 0, 0] and one_hot_attr != [0, 0, 0, 0, 0, 0, 1]:
                        # keep whichever orientation is actually stored in remaining_edges
                        if edge in remaining_edges:
                            redge_popped.append(edge)
                        else:
                            redge_popped.append(edge[::-1])
            if len(redge_popped)>0:
                break
        

        # Prefer candidates that share a node with the previous connection;
        # fall back to every candidate when none does.
        redge_pruned = []
        if prev_connection is not None:
            for r in redge_popped:
                if r[0] in prev_connection or r[1] in prev_connection:
                    redge_pruned.append(r)
        if len(redge_pruned)==0:
            redge_pruned = redge_popped
        # Pick a new connection that is in the bfs path
        # NOTE(review): random.choice raises IndexError when there is no candidate.
        new_connection = random.choice(list(redge_pruned))
        # Symmetry handling: only when at least one endpoint is registered in
        # self.symmetric_ground.
        if self.symmetric_ground is not None and \
            (self.symmetric_ground.in_bidict(new_connection[0]) or \
                 self.symmetric_ground.in_bidict(new_connection[1])): 

            if self.symmetric_ground.in_bidict(new_connection[0]) and \
                self.symmetric_ground.in_bidict(new_connection[1]):     
                actions[4] = 1 #symmetric connection
            else: 
                # exactly one endpoint is symmetric
                if new_connection[0] not in reverse_map.keys() or new_connection[1] not in reverse_map.keys(): # there is a new node
                    existing_node = new_connection[0] if (new_connection[1] not in reverse_map.keys()) else new_connection[1]
                    new_node = new_connection[1] if (new_connection[1] not in reverse_map.keys()) else new_connection[0]

                    if expert_graph.get_node_features(existing_node, ['component_type']) == ComponentEnum.NET:
                        if self.symmetric_ground.in_bidict(existing_node): # symmetric net
                            actions[4] = 2 # common mode reduction
                        else: # CM net, need to expand
                            actions[4] = 4 # common mode expansion
                    else: # component exists, collapse to net
                        actions[4] = 3 # collapse to net
                else:
                    # both endpoints already exist: classify by the type of the non-symmetric one
                    sym_node = new_connection[0] if self.symmetric_ground.in_bidict(new_connection[0]) \
                                else new_connection[1]
                    cm_node = new_connection[1] if sym_node == new_connection[0] else new_connection[0]
                    if expert_graph.get_node_features(cm_node, ['component_type']) == ComponentEnum.NET:
                        actions[4] = 3 # common mode component -- net -- component
                    else: 
                        actions[4] = 2

            # Build the mirrored edge: symmetric endpoints are replaced by their
            # partner, non-symmetric endpoints are kept as they are.
            if self.symmetric_ground.in_bidict(new_connection[0]):
                new_connection_sym0 = self.symmetric_ground.get(new_connection[0])
            else:
                new_connection_sym0 = new_connection[0]
            if self.symmetric_ground.in_bidict(new_connection[1]):
                new_connection_sym1 = self.symmetric_ground.get(new_connection[1])
            else:
                new_connection_sym1 = new_connection[1]
            new_connection_sym = (new_connection_sym0, new_connection_sym1)

            # check if actually in dataset edges
            if new_connection_sym not in dataset_edges and new_connection_sym[::-1] not in dataset_edges:
                new_connection_sym = None
                actions[4] = 0 # no symmetry, single ended
        else:
            new_connection_sym = None
            actions[4] = 0 # no symmetry, single ended
          
        # NOTE(review): only the exact orientation is matched here; [0] raises
        # IndexError if the chosen edge is not stored this way in dataset_edges.
        idx_data = [i for i, t in enumerate(dataset_edges) if t == new_connection][0]
        # extract the edge types that can be created from this edge connection
        data_arr = np.array(dataset_edge_attr[idx_data])
        # # order this so that the first node is always already in the circuit
        if new_connection[0] not in og_map.values():
            new_connection = (new_connection[1], new_connection[0])
            if self.symmetric_ground is not None and new_connection_sym is not None: 
                new_connection_sym = (new_connection_sym[1], new_connection_sym[0])

        # Add a new node
        if new_connection[1] not in og_map.values():
            node_dict = dataset_graph.get_node_features(new_connection[1])
            # NOTE(review): this updates the dict returned by the dataset graph;
            # whether that is a copy depends on get_node_features -- confirm.
            node_dict["custom_features"].update({'original_mapping': new_connection[1]})

            expert_graph.add_node(node_dict["component_type"], node_dict["phys_features"],
                                   node_dict["custom_features"],
                                    override_name=new_connection[1])
            if new_connection_sym is not None:
                # NOTE(review): the twin node is created from the same feature
                # dicts, so its 'original_mapping' is new_connection[1] as well.
                expert_graph.add_node(node_dict["component_type"], node_dict["phys_features"],
                                   node_dict["custom_features"],
                                    override_name=new_connection_sym[1])
            # add new possible connections to remaining edges
            # (edges joining the new node to a node that was not in the graph at the start of this step)
            for nc in not_common_edges:
                if (nc[0] == new_connection[1] and nc[1] not in expert_graph_nodes) or \
                    (nc[1] == new_connection[1] and nc[0] not in expert_graph_nodes):
                    remaining_edges.add(nc)
                if new_connection_sym is not None:
                    if (nc[0] == new_connection_sym[1] and nc[1] not in expert_graph_nodes) or \
                        (nc[1] == new_connection_sym[1] and nc[0] not in expert_graph_nodes):
                        remaining_edges.add(nc)

            # the new node takes the next consecutive id
            reverse_map[new_connection[1]] = total_nodes
            og_map[total_nodes] = new_connection[1]

            # have to check the og_map.values because new node could be a shared cm gate (ex.) 
            # then should not make an additional node entry
            if new_connection_sym is not None and new_connection_sym[1] not in og_map.values():
                reverse_map[new_connection_sym[1]] = total_nodes+1
                og_map[total_nodes+1] = new_connection_sym[1]

            # add to the action to indicate the correct node type
            # (index of the 1 in the node's one-hot type vector)
            node_type_add = np.where(np.array(dataset_graph.get_node_features(new_connection[1], 
                                                    ['custom_features','one_hot_node_type'])) == 1)[0][0]

            # Add bulk connection to power
            # Locate the supply nets by name.
            # NOTE(review): PWR / GND stay unbound if no net is named "VDD" /
            # "VSS", which would raise UnboundLocalError below.
            for n in list(expert_graph.get_nodes()):
                if n[1]["component_type"] == ComponentEnum.NET:
                    try: 
                        name=n[1]["custom_features"]["node_name"]
                    except:
                        # NOTE(review): bare except -- any failure falls back to a generated name
                        name = "NET" +str(n[0])
                    
                    if name=="VDD": 
                        PWR = n[0]
                    elif name == "VSS":
                        GND = n[0]                   
            # NFET -> VSS; PFET and RES -> VDD; every other type gets no supply edge
            if node_dict["component_type"] == ComponentEnum.NFET: #NFET
                supply_node = GND
            elif node_dict["component_type"] == ComponentEnum.PFET \
                or node_dict["component_type"] == ComponentEnum.RES:
                supply_node = PWR
            else:
                supply_node = None
            if supply_node is not None:
                expert_graph.add_edge(supply_node, new_connection[1], node_1_terminal_list =[], # FIXME put std terminals here?
                                        node_2_terminal_list =[], 
                                        phys_features=None, custom_features=None)
                if new_connection_sym is not None:
                    expert_graph.add_edge(supply_node, new_connection_sym[1], node_1_terminal_list =[], # FIXME put std terminals here?
                                           node_2_terminal_list =[], 
                                           phys_features=None, custom_features=None)
                # resistors get the last one-hot slot, FETs the first one
                if node_dict["component_type"] == ComponentEnum.RES:
                    expert_graph.overwrite_edge_field(supply_node, new_connection[1], "one_hot_edge_attr", [0, 0, 0, 0, 0, 0, 1])
                    if new_connection_sym is not None:
                        expert_graph.overwrite_edge_field(supply_node, new_connection_sym[1], "one_hot_edge_attr", [0, 0, 0, 0, 0, 0, 1])
                else:
                    expert_graph.overwrite_edge_field(supply_node, new_connection[1], "one_hot_edge_attr", [1, 0, 0, 0, 0, 0, 0])
                    if new_connection_sym is not None:
                        expert_graph.overwrite_edge_field(supply_node, new_connection_sym[1], "one_hot_edge_attr", [1, 0, 0, 0, 0, 0, 0])
                # remove edge if bulk to power is the only other edge
                # NOTE(review): [0] raises IndexError when the dataset has no
                # edge between the supply node and the new node.
                idx_data = [i for i, t in enumerate(dataset_edges) if \
                                t == (supply_node, new_connection[1]) or t==(new_connection[1],supply_node)][0]
                if [1, 0, 0, 0, 0, 0, 0] == dataset_edge_attr[idx_data] or [0, 0, 0, 0, 0, 0, 1] == dataset_edge_attr[idx_data]:
                    # the supply edge is already complete -> nothing remains for it
                    # NOTE(review): set.remove raises KeyError if the tuple is
                    # absent (e.g. stored in the opposite orientation).
                    remaining_edges.remove((supply_node, new_connection[1]))
                    #common_edges.remove((supply_node, new_connection[1]))
                    not_common_edges.remove((supply_node, new_connection[1]))
                    if new_connection_sym is not None:
                        remaining_edges.remove((supply_node, new_connection_sym[1]))
                        not_common_edges.remove((supply_node, new_connection_sym[1]))
                    

                else:
                    # update the not_common and common edge sets
                    # (the dataset edge carries more terminals, so it stays "common but incomplete")
                    if (supply_node, new_connection[1]) in not_common_edges: 
                        not_common_edges.remove((supply_node, new_connection[1]))
                        if new_connection_sym is not None:
                            not_common_edges.remove((supply_node, new_connection_sym[1]))
                    common_edges.add((supply_node, new_connection[1]))
                    if new_connection_sym is not None:
                        common_edges.add((supply_node, new_connection_sym[1]))
                
                # keep the local parallel lists in sync with the edges just added
                expert_edges.append((supply_node, new_connection[1]))
                if new_connection_sym is not None:
                    expert_edges.append((supply_node, new_connection_sym[1]))

                if node_dict["component_type"] == ComponentEnum.RES:
                    expert_edge_attr.append([0, 0, 0, 0, 0, 0, 1])
                    if new_connection_sym is not None:
                        expert_edge_attr.append([0, 0, 0, 0, 0, 0, 1])
                else:
                    expert_edge_attr.append([1, 0, 0, 0, 0, 0, 0])
                    if new_connection_sym is not None:
                        expert_edge_attr.append([1, 0, 0, 0, 0, 0, 0])

            actions[0] = reverse_map[new_connection[0]] 
            # for a new node the target slot is offset by the node-type index
            actions[1] = reverse_map[new_connection[1]] + node_type_add

        else:
            # both endpoints already exist
            actions[0] = reverse_map[new_connection[0]] 
            actions[1] = reverse_map[new_connection[1]] 
        
        # pick a connection to add. If it's in common edges pick a type that has not been added
        if new_connection in common_edges:
            idx_exp = [i for i, t in enumerate(expert_edges) \
                            if (t == new_connection or t==new_connection[::-1])][0]
            expert_arr = np.array(expert_edge_attr[idx_exp])
            # any slot where dataset and expert attributes disagree
            edge_type = random.choice(np.where(data_arr != expert_arr)[0]) 
        else:    
            # NOTE(review): edge_type is None when the dataset attribute has no
            # 1; the indexing with edge_type further down would then fail.
            edge_type = random.choice(np.where(data_arr == 1)[0]) if np.any(data_arr == 1) else None
        actions[2] = edge_type
        
        # Add the edge
        if new_connection in common_edges: # so edge exists, just need to update attribute
            idx_exp = [i for i, t in enumerate(expert_edges) \
                            if (t == new_connection or t==new_connection[::-1])][0]
            # old_edge_attr is the same list object stored in expert_edge_attr,
            # so setting the slot below edits it in place
            old_edge_attr = expert_edge_attr[idx_exp] 
            if old_edge_attr[edge_type] != 1:
                old_edge_attr[edge_type] = 1
                expert_graph.overwrite_edge_field(new_connection[0], new_connection[1], 
                                                   "one_hot_edge_attr", old_edge_attr)
                if new_connection_sym is not None:
                    expert_graph.overwrite_edge_field(new_connection_sym[0], new_connection_sym[1], 
                                                   "one_hot_edge_attr", old_edge_attr)                                  
                expert_edge_attr[idx_exp] = old_edge_attr
        else: 
            # brand-new edge carrying only the chosen terminal type
            edge_attr = [0 for i in range(self.edge_types)]
            edge_attr[edge_type] = 1
            expert_graph.add_edge(new_connection[0], new_connection[1], node_1_terminal_list =[], # FIXME put std terminals here?
                                    node_2_terminal_list =[], 
                                    phys_features=None, custom_features=None)
            expert_graph.overwrite_edge_field(new_connection[0], new_connection[1], 
                                                   "one_hot_edge_attr", edge_attr)
            if new_connection_sym is not None:
                expert_graph.add_edge(new_connection_sym[0], new_connection_sym[1], node_1_terminal_list =[], # FIXME put std terminals here?
                                        node_2_terminal_list =[], 
                                        phys_features=None, custom_features=None)
                expert_graph.overwrite_edge_field(new_connection_sym[0], new_connection_sym[1], 
                                                   "one_hot_edge_attr", edge_attr)
            # only the attribute list is extended here (expert_edges is not);
            # idx_exp points at the entry just appended
            expert_edge_attr.append(edge_attr)
            idx_exp = len(expert_edge_attr)-1

            # update the not_common and common edge sets
            if new_connection in not_common_edges: 
                not_common_edges.remove((new_connection[0], new_connection[1]))
                if new_connection_sym is not None:
                    not_common_edges.remove((new_connection_sym[0], new_connection_sym[1]))
                #not_common_edges.remove((new_connection[1], new_connection[0]))
            common_edges.add(new_connection)
            if new_connection_sym is not None:
                common_edges.add(new_connection_sym)


        # Only remove edges if the attributes are equal, otherwise keep as a common edge
        idx_data = [i for i, t in enumerate(dataset_edges) if t == new_connection or t==new_connection[::-1]][0]
        if expert_edge_attr[idx_exp] == dataset_edge_attr[idx_data]:
            # the edge now matches the dataset exactly -> it is finished
            remaining_edges.remove(dataset_edges[idx_data])
            common_edges.remove(new_connection)
            if new_connection_sym is not None:
                idx_data_sym = [i for i, t in enumerate(dataset_edges) if t == new_connection_sym or t==new_connection_sym[::-1]][0]
                # the twin is compared using the primary edge's expert attribute
                if expert_edge_attr[idx_exp] == dataset_edge_attr[idx_data_sym]:
                    remaining_edges.remove(dataset_edges[idx_data_sym])
                    common_edges.remove(new_connection_sym)

        # if subgraph == graph terminate 
        if not remaining_edges:
            done = True
            actions[3] = 1
            rew = 5 # because this is expert finished correctly
        else: 
            done = False
            actions[3] = 0
            rew = 1 # automatically reward 1 because this is expert
        # remap to consecutive
        expert_graph.relabel_nodes(reverse_map)
        # Organize graph into GNN understandable torch_geometric format
        # (only `conv` is used below; the other four return values are discarded)
        _, adj, ob_dict, ob, conv = self.graph_to_obs(expert_graph)
        # translate the symmetry pairs into the consecutive node ids,
        # keeping only pairs whose both members exist in the graph
        sym_rev = BiDict()
        if self.symmetric_ground is not None:
            for k,v in self.symmetric_ground.items():
                if k in reverse_map.keys() and v in reverse_map.keys():
                    sym_rev.add(reverse_map[k], reverse_map[v])
        else:
            sym_rev = None
        # The masks are queried BEFORE step_mask refreshes the mask object for
        # the updated graph.
        mask_nodes_from = self.mask_obj.get_mask_from()
        mask_nodes_to = self.mask_obj.get_mask_to( int(actions[0]))
        mask_edges = self.mask_obj.get_mask_edge( int(actions[0]), int(actions[1]))
        mask_term = copy.deepcopy(self.mask_obj.get_mask_term(  int(actions[0]), int(actions[1]), int(actions[2]), int(actions[4]), sym_rev))
        mask = [mask_nodes_from, mask_nodes_to, mask_edges, mask_term]

        self.mask_obj.step_mask(expert_graph, actions)
        info = {}
        return actions, conv, rew, done, info, expert_graph, common_edges, not_common_edges,\
                remaining_edges, og_map, reverse_map, mask, new_connection

    def get_expert_trajectory(self, num_trajectories, random_start=False, random_circuit=False, 
                              circuit_order=None, num_sample=0, sample_nodes_list=None):
        """Roll out the scripted expert and collect the resulting trajectories.

        For every trajectory the environment is reset, the edge difference to
        the ground-truth graph is computed, and ``expert_step`` is repeated
        until it reports ``done``.

        :num_trajectories: Upper bound on the number of rollouts.
        :random_start: Forwarded to ``self.reset``.
        :random_circuit: Forwarded to ``self.reset``.
        :circuit_order: Optional sequence; entry ``i`` is passed to
            ``self.reset`` for rollout ``i``. Collection stops once the
            sequence is exhausted, even if fewer than ``num_trajectories``
            rollouts were made.
        :num_sample: Forwarded to ``self.reset``.
        :sample_nodes_list: Optional sequence parallel to ``circuit_order``
            whose entries are forwarded as ``sample_nodes``; defaults to all
            ``None``. Only used when ``circuit_order`` is given.
        :returns: List with one dict per rollout holding the keys
            ``observation``, ``reward``, ``action``, ``next_observation``,
            ``terminal``, ``graph``, ``graph_sampled``, ``ground_truth`` and
            ``mask``.
        """
        started_at = time.time()
        collected = []
        order_pos = 0

        for _ in range(num_trajectories):
            # --- reset the environment for this rollout -------------------
            if circuit_order is None:
                reset_result = self.reset(random_start, random_circuit, num_sample=num_sample)
            elif order_pos < len(circuit_order):
                if sample_nodes_list is None:
                    sample_nodes_list = [None] * len(circuit_order)
                reset_result = self.reset(random_start, random_circuit, circuit_order[order_pos],
                                          num_sample=num_sample,
                                          sample_nodes=sample_nodes_list[order_pos])
                order_pos += 1
            else:
                # every requested circuit has been rolled out
                break
            ob, ground_graph, graph_sampled, _mask_obj, _symmetries = reset_result

            obs = []
            acs = []
            rewards = []
            next_obs = []
            terminals = []
            graphs = []
            masks = []

            expert_graph = copy.deepcopy(graph_sampled)
            # should always label VSS 0 and VDD 1
            (og_map, reverse_map, common_edges,
             not_common_edges, remaining_edges) = self.get_edge_difference(graph_sampled, ground_graph)

            # --- follow the expert until it signals completion ------------
            prev_connection = None
            done = False
            while not done:
                step_result = self.expert_step(expert_graph, ground_graph,
                                               og_map, reverse_map, common_edges,
                                               not_common_edges,
                                               remaining_edges, prev_connection)
                (ac, next_ob, rew, done, info, graph, common_edges,
                 not_common_edges, remaining_edges, og_map,
                 reverse_map, mask_array, connection) = step_result

                # record result of taking that action
                graphs.append(copy.deepcopy(graph))
                obs.append(copy.deepcopy(ob))
                acs.append(ac)
                rewards.append(rew)
                next_obs.append(next_ob)
                terminals.append(done)
                masks.append(mask_array)

                # carry state over to the next step
                ob = next_ob
                prev_connection = connection
                expert_graph = copy.deepcopy(graph)

            episode_statistics = {"l": len(rewards), "r": np.sum(rewards)}
            if "episode" in info:
                episode_statistics.update(info["episode"])

            collected.append({
                "observation": obs,
                "reward": np.array(rewards, dtype=np.float32),
                "action": np.array(acs, dtype=np.float32),
                "next_observation": next_obs,
                "terminal": np.array(terminals, dtype=np.float32),
                "graph": graphs,
                "graph_sampled": graph_sampled,
                "ground_truth": ground_graph,
                "mask": masks,
            })
            self.close()

        finished_at = time.time()
        print("trajectory time expert: ", finished_at - started_at)
        return collected
    
    
class Mask():
    def __init__(self, max_nodes, num_types, edge_types, diff=True, se=False):
        self.max_nodes = max_nodes
        self.num_types = num_types-4 # special net types cant be selected
        self.edge_types = edge_types
        self.mask_nodes = np.ones((self.max_nodes,))
        self.mask_nodes_to = np.ones((self.max_nodes,))
        self.mask_edges = np.ones((self.edge_types,))
        self.mask_term = np.ones((2,))
        self.mask_sym = np.ones((5,)) # 4 symmetry options
        self.graph_nodes = 0
        self.graph = None
        self.num_mask_nodes = 0
        self.factor = 1e-20
        self.diff = diff
        self.se = se
 

    def reset_mask(self, graph):
        # take out the higher end of nodes possible
        self.mask_nodes = np.zeros((self.max_nodes,))
        self.mask_nodes_to = np.zeros((self.max_nodes,))
        self.mask_edges = np.ones((self.edge_types,))
        self.mask_term = np.ones((2,))
        self.graph_nodes = graph.get_num_nodes()
        self.graph = copy.deepcopy(graph)
        self.force_add = False # not used

        for node in graph.get_nodes():
            # if 
            # can connect to net
            if graph.get_node_features(node[0], ["component_type"]) == ComponentEnum.NET: 
                self.mask_nodes[node[0]] = 1
            # need to check if existing components are fully connected
            else: 
                existing_edges = [0 for i in range(self.edge_types)]
                for edge in graph.get_edges():
                    if edge[0] == node[0] or edge[1] == node[0]:
                        current_attrs = graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                        existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                zero_indices = [index for index, value in enumerate(existing_edges) if value == 0]

                # for fets
                if graph.get_node_features(node[0],['component_type']) == ComponentEnum.NFET or \
                    graph.get_node_features(node[0],['component_type']) == ComponentEnum.PFET:
                    check_fet = existing_edges[0:4]
                    unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                    if unconnected_fet:
                        self.mask_nodes[node[0]] = 1

                # for passives
                else:
                    if graph.get_node_features(node[0],['component_type']) == ComponentEnum.RES:
                        check_pass = existing_edges[4:]
                    else:
                        check_pass = existing_edges[4:6]
                    unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                    if unconnected_pass:
                        self.mask_nodes[node[0]] = 1
        
        for idx, m in enumerate(self.mask_nodes): 
            if m<0.5: 
                self.mask_nodes[idx] = self.factor
    
    
    def step_mask(self, graph, action):
        # this is copied from reset, ideally would also use action
        self.graph_nodes = graph.get_num_nodes()
        self.graph = copy.deepcopy(graph)
        self.mask_nodes = np.zeros((self.max_nodes,))
        for node in graph.get_nodes():
            # if 
            # can connect to net
            if graph.get_node_features(node[0], ["component_type"]) == ComponentEnum.NET: 
                self.mask_nodes[node[0]] = 1
            # need to check if existing components are fully connected
            else: 
                existing_edges = [0 for i in range(self.edge_types)]
                for edge in graph.get_edges():
                    if edge[0] == node[0] or edge[1] == node[0]:
                        current_attrs = graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                        existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                zero_indices = [index for index, value in enumerate(existing_edges) if value == 0]

                # for fets
                if graph.get_node_features(node[0],['component_type']) == ComponentEnum.NFET or \
                    graph.get_node_features(node[0],['component_type']) == ComponentEnum.PFET:
                    check_fet = existing_edges[0:4]
                    unconnected_fet = [index for index, value in enumerate(check_fet) if value == 0]
                    if unconnected_fet:
                        self.mask_nodes[node[0]] = 1

                # for passives
                else:
                    if graph.get_node_features(node[0],['component_type']) == ComponentEnum.RES:
                        check_pass = existing_edges[4:]
                    else:
                        check_pass = existing_edges[4:6]
                    unconnected_pass = [index for index, value in enumerate(check_pass) if value == 0]
                    if unconnected_pass:
                        self.mask_nodes[node[0]] = 1
        
        for idx, m in enumerate(self.mask_nodes): 
            if m<0.5: 
                self.mask_nodes[idx] = self.factor

        return self.mask_nodes
    
    def get_mask_from(self):
        # return self.mask_nodes
        # if there are no connected edges, pick the ground node
        if len(self.graph.get_edges())==0: 
            mask = np.full(self.max_nodes, self.factor)
            mask[0] = 1
            return mask
        else: 
            return self.mask_nodes
    
    def get_mask_to(self, act0, sym=None):
        
        try:
            act0 = int(act0)
            self.mask_nodes_to = np.copy(self.mask_nodes)
            self.mask_nodes_to[act0] = self.factor
            act0_type = self.graph.get_node_features(act0,["component_type"])
            
            # for the existing nodes mask out any that dont meet criteria
            if act0_type != ComponentEnum.NET:
                edges = [(e[0], e[1]) for e in self.graph.get_edges()]
                for idx, m in enumerate(self.mask_nodes):
                    if m>0.5: 
                        check_type = self.graph.get_node_features(idx, ["component_type"])
                        if check_type != ComponentEnum.NET:
                            self.mask_nodes_to[idx] = self.factor
                        # if passive, can't connect plus and minus to same net
                        else:
                            if act0_type == ComponentEnum.IND or act0_type == ComponentEnum.RES or \
                            act0_type == ComponentEnum.CAP: 
                                if ((act0, idx) in edges or (idx, act0) in edges) and\
                                self.graph.get_edge_features(act0, idx, ['one_hot_edge_attr']) != [0,0,0,0,0,0,1]:
                                    self.mask_nodes_to[idx] = self.factor
            else: #picked a net
                for idx, m in enumerate(self.mask_nodes):
                    edges = [(e[0], e[1]) for e in self.graph.get_edges()]
                    if m>0.5: 
                        check_type = self.graph.get_node_features(idx, ["component_type"])
                        if check_type == ComponentEnum.NET:
                            self.mask_nodes_to[idx] = self.factor
                        else: # if passive, can't connect plus and minus to same net
                            if check_type == ComponentEnum.IND or check_type == ComponentEnum.RES or \
                            check_type == ComponentEnum.CAP: 
                                if ((act0, idx) in edges or (idx, act0) in edges) and\
                                self.graph.get_edge_features(act0, idx, ['one_hot_edge_attr']) != [0,0,0,0,0,0,1]:
                                    self.mask_nodes_to[idx] = self.factor

            if sym is not None: 
                if sym.in_bidict(act0) and act0_type == ComponentEnum.NET and self.se==False:
                    # picked a symmetric net, and not allowed to unbalance it 
                    # must only pick a transistor with the source and drain free
                    for idx, m in enumerate(self.mask_nodes):
                        if m<0.5: 
                            continue
                        check_type = self.graph.get_node_features(idx, ["component_type"]) 
                        if not sym.in_bidict(idx):
                            if check_type not in [ComponentEnum.NFET, ComponentEnum.PFET]:
                                self.mask_nodes_to[idx] = self.factor
                            else:
                                # check if source and drain are free
                                check_edges = self.graph.get_node_edges(idx)
                                source_free = True
                                drain_free = True
                                for e in check_edges:
                                    if self.graph.get_edge_features(e[0], e[1], ['one_hot_edge_attr'])[3] == 1:
                                        source_free = False
                                        self.mask_nodes_to[idx] = self.factor
                                        break
                                    elif self.graph.get_edge_features(e[0], e[1], ['one_hot_edge_attr'])[1] == 1:
                                        drain_free = False
                                        self.mask_nodes_to[idx] = self.factor
                                        break
                elif not sym.in_bidict(act0) and act0_type != ComponentEnum.NET and self.se==False:
                    # if asymetric component
                    # cannot connect to any symmetric nets if act0 is res, cap ind
                    # can only connect to symmetric nets of source and drain are free

                    if act0_type in [ComponentEnum.NFET, ComponentEnum.PFET]:
                        # check if source and drain are free
                        check_edges = self.graph.get_node_edges(act0)
                        source_free = True
                        drain_free = True
                        for e in check_edges:
                            if self.graph.get_edge_features(e[0], e[1], ['one_hot_edge_attr'])[3] == 1:
                                source_free = False
                                break
                            elif self.graph.get_edge_features(e[0], e[1], ['one_hot_edge_attr'])[1] == 1:
                                drain_free = False
                                break
                        
                    for idx, m in enumerate(self.mask_nodes):
                        if m<0.5: 
                            continue
                        check_type = self.graph.get_node_features(idx, ["component_type"]) 
                        if sym.in_bidict(idx) and act0_type in [ComponentEnum.RES, ComponentEnum.CAP, ComponentEnum.IND]: # existing passives wouldnt be able to connect
                            self.mask_nodes_to[idx] = self.factor
                        elif sym.in_bidict(idx) and act0_type in [ComponentEnum.NFET, ComponentEnum.PFET]:
                            if not (source_free and drain_free):
                                self.mask_nodes_to[idx] = self.factor

            # for the scaffolds allow either nets or components depending
            # if picked a component in the first action could only add a net
            # only allow scaffold if all types are available
            if self.graph_nodes+self.num_types<self.max_nodes:
                if act0_type != ComponentEnum.NET: 
                    self.mask_nodes_to[self.graph_nodes] = 1
                else: # if picked a net in first action can only add components
                    self.mask_nodes_to[self.graph_nodes+1:self.graph_nodes+self.num_types] = np.ones(self.num_types-1)
            return self.mask_nodes_to
        
        except:
            return self.mask_nodes

    def get_mask_edge(self, act0, act1):
        """Compute the mask over edge types for an edge between ``act0`` and ``act1``.

        Entries equal to 1 are selectable, entries equal to ``self.factor`` are
        masked out. The mask is stored on ``self.mask_edges`` and returned.

        Edge-type indices as used here: 1 = drain, 2 = gate, 3 = source (FET
        terminals); indices from 4 on are the passive terminals. Index 0 and
        the last index are handled as bulk connections (presumably FET bulk and
        passive bulk -- confirm against the edge enum). The three-element
        slice assignments below assume ``self.edge_types == 7``.

        :act0: Index of the first selected node.
        :act1: Index of the second selected node. Indices greater than or
            equal to ``self.graph_nodes`` refer to scaffold nodes that are not
            part of the graph yet.
        :return: numpy array of length ``self.edge_types``.

        Any ``Exception`` raised while inspecting the graph is swallowed on
        purpose (best effort): the mask built up to that point is returned.
        """
        self.mask_edges = np.ones((self.edge_types,))
        # for bulk, if added node, bulk would be connected.
        self.mask_edges[0] = self.factor
        self.mask_edges[-1] = self.factor

        try:
            larger_node = max(act0, act1)
            if larger_node >= self.graph_nodes:
                scaffold_offset = larger_node - self.graph_nodes
                if scaffold_offset > 2:  # if passive, only passive edges available
                    self.mask_edges[1:4] = [self.factor, self.factor, self.factor]
                    return self.mask_edges
                if scaffold_offset > 0:  # else if fet, only fet edges available
                    self.mask_edges[4:] = [self.factor, self.factor, self.factor]
                    return self.mask_edges
                if act0 == act1:
                    return self.mask_edges  # this action will be invalid anyways
                # The scaffold is a new net, so the existing node is the component.
                component_checked = min(act0, act1)
            else:
                component_checked = act0 if \
                    (self.graph.get_node_features(act0, ["component_type"]) != ComponentEnum.NET) \
                    else act1

            fet_types = (ComponentEnum.NFET, ComponentEnum.PFET)

            # Check where the component connects to already
            existing_edges = [0 for _ in range(self.edge_types)]
            edge_names = [None for _ in range(self.edge_types)]
            for edge in self.graph.get_edges():
                if edge[0] != component_checked and edge[1] != component_checked:
                    continue
                current_attrs = self.graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
                # Remember which neighbour sits on each occupied terminal.
                neighbour = edge[1] if edge[0] == component_checked else edge[0]
                for i, val in enumerate(current_attrs):
                    if val:
                        edge_names[i] = neighbour

                # mask out the edges for other types by adding them to the existing edges
                comp_type = self.graph.get_node_features(component_checked, ['component_type'])
                if comp_type in fet_types:
                    existing_edges[4:] = [1, 1, 1]
                # for passives
                elif comp_type == ComponentEnum.RES:
                    existing_edges[0:4] = [1, 1, 1, 1]
                elif comp_type in (ComponentEnum.CAP, ComponentEnum.IND):
                    existing_edges[0:4] = [1, 1, 1, 1]
                    existing_edges[-1] = 1  # No bulk connection allowed

            # overrides the first definition of 0 and -1 mask edges
            for i, e in enumerate(existing_edges):
                self.mask_edges[i] = self.factor if e == 1 else 1

            # Can't connect gate or drain to source
            # if the source name is node0 or node 1, mask out drain and gate
            comp_type = self.graph.get_node_features(component_checked, ['component_type'])
            if comp_type in fet_types:
                if (edge_names[3] is not None) and (edge_names[3] == act0 or edge_names[3] == act1):
                    self.mask_edges[1] = self.factor
                    self.mask_edges[2] = self.factor
                # if gate connected and the same to other actions, mask out the source
                elif (edge_names[2] is not None) and (edge_names[2] == act0 or edge_names[2] == act1):
                    self.mask_edges[3] = self.factor
                # if drain connected and the same to other actions, mask out the source
                elif (edge_names[1] is not None) and (edge_names[1] == act0 or edge_names[1] == act1):
                    self.mask_edges[3] = self.factor

            return self.mask_edges

        except Exception:
            # Best effort: fall back to the mask built so far, but let
            # KeyboardInterrupt / SystemExit propagate.
            return self.mask_edges
        
    def get_mask_term(self, act0, act1, act2, act4, sym=None):
        """Update and return ``self.mask_term`` for the edge action just chosen.

        ``self.mask_term`` must already exist; it is modified in place.
        ``mask_term[1]`` is set to 1 only while every FET and passive visited
        so far is fully connected, counting the terminal ``act2`` that is being
        connected right now on the node(s) touched by this action. As soon as
        one component fails the check, ``mask_term[1]`` is set to
        ``self.factor`` and the method returns. ``mask_term[0]`` is set to 1
        whenever the mask is written. (Presumably index 0 is "continue" and
        index 1 is "terminate" -- confirm with the caller.)

        If the graph contains no FET or passive node the mask is returned
        unchanged.

        :act0: Index of the first selected node.
        :act1: Index of the second selected node; indices greater than or
            equal to ``self.graph_nodes`` are scaffold nodes not yet in the graph.
        :act2: Edge type (terminal index) chosen for the new edge.
        :act4: Symmetry action; only the value 2 is treated specially here, and
            only when ``self.diff`` is set.
        :sym: Optional bidirectional map of symmetric node pairs.
        :return: ``self.mask_term``.
        """
        #self.mask_term = np.ones((2,))
        larger_node = max(act0, act1)
        smaller_node = min(act0, act1)
        scaffold_offset = larger_node - self.graph_nodes
        #clusters = self.graph.get_clusters()
        if scaffold_offset > 0:
            # can't finish if you just added anything other than passive with act4==2 in diff condition
            # if that was added then can check all the other nodes for finishing
            added_sym_passive = (
                scaffold_offset > 2 and act4 == 2 and self.diff
                and (sym is not None and sym.in_bidict(smaller_node))
            )
            if not added_sym_passive:
                self.mask_term[0] = 1
                self.mask_term[1] = self.factor
                return self.mask_term

        # check that all components are fully connected
        if larger_node >= self.graph_nodes:
            check_node_list = [smaller_node]
        else:
            check_node_list = [act0 if (self.graph.get_node_features(act0, ["component_type"]) != ComponentEnum.NET) else act1]

        # In differential mode the symmetric partner receives the same edge.
        if sym is not None and self.diff:
            if sym.in_bidict(check_node_list[0]):
                check_node_list.append(sym.get(check_node_list[0]))

        for n in self.graph.get_nodes():
            node = n[0]
            existing_edges = [0 for _ in range(self.edge_types)]
            for edge in self.graph.get_node_edges(node):
                current_attrs = self.graph.get_edge_features(edge[0], edge[1], 'one_hot_edge_attr')
                existing_edges = [x + y for x, y in zip(current_attrs, existing_edges)]
            ntype = self.graph.get_node_features(node, ["component_type"])

            if ntype in [ComponentEnum.NFET, ComponentEnum.PFET]:
                # do not need to check the bulk because that will always be connected
                check_fet = existing_edges[1:4]
                # +1 so the indices line up with the edge-type action
                unconnected_fet = [index + 1 for index, value in enumerate(check_fet) if value == 0]
                pass_condition = len(unconnected_fet) == 0 or \
                    (len(unconnected_fet) == 1 and (node in check_node_list) and (act2 in unconnected_fet))
                if self.diff == True and act4 == 2 and not pass_condition:
                    # Drain (1) and source (3) complement each other. Any other
                    # terminal has no complement, so the relaxed rule below
                    # simply does not apply (previously this left the name
                    # unbound and could raise UnboundLocalError).
                    act2_complement = None
                    if act2 == 1:
                        act2_complement = 3
                    elif act2 == 3:
                        act2_complement = 1
                    p2 = (len(unconnected_fet) == 2 and (node in check_node_list)
                          and (act2 in unconnected_fet) and (act2_complement in unconnected_fet)) \
                        and (sym is not None and not sym.in_bidict(node))
                    pass_condition = pass_condition or p2

            elif ntype in [ComponentEnum.RES, ComponentEnum.CAP, ComponentEnum.IND]:
                check_pass = existing_edges[4:-1]
                # +4 to match with action properly
                unconnected_pass = [index + 4 for index, value in enumerate(check_pass) if value == 0]
                pass_condition = len(unconnected_pass) == 0 or \
                    (len(unconnected_pass) == 1 and (node in check_node_list) and (act2 in unconnected_pass))

            else:
                # Nets (and any other node type) do not constrain termination.
                continue

            self.mask_term[0] = 1
            if not pass_condition:
                self.mask_term[1] = self.factor
                return self.mask_term
            self.mask_term[1] = 1

        return self.mask_term

    
    def get_mask_sym(self, act0, act1, act2, sym):
        """Compute the mask over the five symmetry actions for the chosen edge.

        Slots of the returned array (1 = allowed, ``self.factor`` = masked):
        0 asymmetric / single-ended, 1 duplicate, 2 collapse net to component,
        3 collapse components to net, 4 expand net to component.

        :act0: Index of the first selected node.
        :act1: Index of the second selected node; indices greater than or
            equal to ``self.graph_nodes`` are scaffold nodes not yet in the graph.
        :act2: Edge type (terminal index) chosen for the new edge.
        :sym: Bidirectional map of symmetric node pairs; only consulted when
            ``self.diff`` is set.
        :return: ``self.mask_sym``, a numpy array of length 5.
        """
        # A scaffold is not in the graph yet, so only the smaller index names
        # an existing node.
        touches_scaffold = max(act0, act1) >= self.graph_nodes
        existing_node = [min(act0, act1)] if touches_scaffold else [act0, act1]

        mask = np.full(5, self.factor)
        self.mask_sym = mask
        if self.se == True:
            mask[0] = 1  # asymmetry possible

        if not (self.diff == True):
            return mask

        def node_type(node):
            return self.graph.get_node_features(node, ["component_type"])

        if len(existing_node) == 1:
            (only,) = existing_node
            if sym.in_bidict(only):
                mask[1] = 1  # duplicate
                if node_type(only) == ComponentEnum.NET:
                    mask[2] = 1  # collapse net to component
                else:
                    mask[3] = 1  # collapse components to net
            else:
                if node_type(only) == ComponentEnum.NET:
                    mask[4] = 1  # Expand net to component
                mask[0] = 1  # single ended possible from CM node (cause its common)
        elif len(existing_node) == 2:
            first, second = existing_node
            first_is_sym = sym.in_bidict(first)
            second_is_sym = sym.in_bidict(second)

            if first_is_sym and second_is_sym:
                mask[1] = 1  # duplicate actions on symmetric components
            elif first_is_sym or second_is_sym:
                if first_is_sym:
                    sym_node, cm_node = first, second
                else:
                    sym_node, cm_node = second, first
                sym_node_type = node_type(sym_node)
                cm_node_type = node_type(cm_node)

                # Accumulate the terminals already used by the component of the pair.
                # in theory if the component is symmetric, the two components
                # should have the same connections
                comp_node = cm_node if cm_node_type != ComponentEnum.NET else sym_node
                terminal_use = [0 for _ in range(self.edge_types)]
                for comp_edge in self.graph.get_node_edges(comp_node):
                    attrs = self.graph.get_edge_features(comp_edge[0], comp_edge[1], ['one_hot_edge_attr'])
                    terminal_use = [x + y for x, y in zip(attrs, terminal_use)]

                if sym_node_type == ComponentEnum.NET:
                    # Net -- component -- Net
                    if act2 in [1, 3, 4, 5]:  # drain, source, plus, minus
                        # (occupied terminal, opposite terminal being requested)
                        opposite_pairs = ((1, 3), (3, 1), (4, 5), (5, 4))
                        already_connected = any(
                            terminal_use[used] and act2 == wanted
                            for used, wanted in opposite_pairs
                        ) or terminal_use[act2]
                        # already connected in that way, cant make cm connection
                        mask[2] = self.factor if already_connected else 1

                if cm_node_type == ComponentEnum.NET:
                    # component -- Net -- component
                    mask[3] = 1  # "expand" out a net. same edges
            else:
                mask[0] = 1  # asymmetry possible

        if mask.sum() < 0.5:
            # at least allow duplicate wrong actions, then this will become an
            # invalid action. don't want to allow all 0 mask because that's the
            # same as not masking, and could become asymmetrical
            mask[1] = 1

        return mask