import os, torch, logging, glob, math
import igraph as ig
import networkx as nx
import numpy as np
import torch.nn.functional as F
import matplotlib.colors as mcolors
from matplotlib import pyplot as plt
from torch_scatter import scatter
from yacs.config import CfgNode as CN
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_networkx
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import network_dict
from torch_sparse import SparseTensor
from .transform.rrwp import add_node_attr
# import pygraphviz as pgv


# OCB node-type vocabulary: symbolic name -> integer id (the id is what is
# stored in column 0 of a graph's ``x``, see ``plot_pyg_graph``).
NODE_TYPE_OCB = {
    'R': 0,
    'C': 1,
    '+gm+': 2,
    '-gm+': 3,
    '+gm-': 4,
    '-gm-': 5,
    'sudo_in': 6,
    'sudo_out': 7,
    'In': 8,
    'Out': 9,
}

# Reverse lookup: integer id -> symbolic name.
INVERSE_NODE_TYPE_OCB = dict(zip(NODE_TYPE_OCB.values(), NODE_TYPE_OCB.keys()))


def make_wandb_name(cfg):
    """Return the wandb run name for an experiment.

    The run name is ``cfg.experiment_name`` rendered as a string. No other
    config field is read, so configs without ``dataset``/``model``/``gnn``/
    ``gt``/``name_tag`` sections are accepted. Dataset and model name
    fragments used to be computed here but were never part of the result.

    Args:
        cfg: Experiment config exposing an ``experiment_name`` attribute.

    Returns:
        str: The run name.
    """
    return f"{cfg.experiment_name}"

def flatten_dict(metrics):
    """Merge the latest train/val/test metrics into one flat dict for wandb.

    Args:
        metrics: Sequence of per-split metric histories, in the order
            train, val, test. Each history is a list of dicts; only its last
            dict is used.

    Returns:
        A flat dict whose keys are the metric names prefixed with
        "train/", "val/" or "test/" according to the split position.
    """
    split_names = ['train', 'val', 'test']
    flat = {}
    for idx, split_history in enumerate(metrics):
        latest = split_history[-1]
        for key, value in latest.items():
            flat[f"{split_names[idx]}/{key}"] = value
    return flat

def cfg_to_dict(cfg_node, key_list=None):
    """Convert a config node to dictionary.

    Yacs doesn't have a default function to convert the cfg object to plain
    python dict. The following function was taken from
    https://github.com/rbgirshick/yacs/issues/19

    Args:
        cfg_node: A yacs ``CfgNode`` or a leaf value.
        key_list: Path components leading to ``cfg_node``. They are used only to
            name the offending key in the warning. ``None`` (default) means the
            root.

    Returns:
        A nested ``dict`` for ``CfgNode`` inputs. Leaf values are returned
        unchanged. A warning is logged when a leaf's exact type is not one of
        tuple/list/str/int/float/bool.
    """
    # ``None`` instead of a mutable ``[]`` default, so no list object is ever
    # shared between calls.
    if key_list is None:
        key_list = []
    _VALID_TYPES = {tuple, list, str, int, float, bool}

    if not isinstance(cfg_node, CN):
        if type(cfg_node) not in _VALID_TYPES:
            logging.warning(f"Key {'.'.join(key_list)} with "
                            f"value {type(cfg_node)} is not "
                            f"a valid type; valid types: {_VALID_TYPES}")
        return cfg_node

    cfg_dict = dict(cfg_node)
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for k, v in cfg_dict.items():
        cfg_dict[k] = cfg_to_dict(v, key_list + [k])
    return cfg_dict
    
def remove_edges_with_attribute_value(graph, target_value):
    """
    Drop every edge whose ``edge_attr`` equals ``target_value``.

    Parameters:
    - graph (Data): A PyTorch Geometric Data object with ``edge_index`` and ``edge_attr``.
    - target_value: Edge-attribute value identifying the edges to remove.

    Returns:
    - Data: A new Data object holding only the surviving edges. ``x`` and every
      other stored attribute are passed through as-is. In particular, any other
      per-edge attribute is NOT filtered.

    NOTE(review): the comparison mask indexes the edge dimension of
    ``edge_index``, which presumably requires a 1-D ``edge_attr`` (one value per
    edge). Confirm before using with multi-dimensional edge features.
    """
    keep = graph.edge_attr != target_value
    kept_edge_index = graph.edge_index[:, keep]
    kept_edge_attr = graph.edge_attr[keep]

    node_features = graph.x
    reserved = ('x', 'edge_index', 'edge_attr')
    passthrough = {name: graph[name] for name in graph.keys() if name not in reserved}

    return Data(
        x=node_features,
        edge_index=kept_edge_index,
        edge_attr=kept_edge_attr,
        **passthrough,
    )


def gym_to_igraph(graphgym_graph):
    """Convert a PyG/GraphGym graph into a directed igraph ``Graph``.

    Only edges whose source index is smaller than their target index are kept.
    Edges with ``source >= target``, including self-loops, are dropped.
    Vertex attribute ``type`` is column 0 of ``x``. When the graph has
    ``x_features``, its column 0 becomes vertex attribute ``feat``.
    """
    sources = graphgym_graph.edge_index[0].tolist()
    targets = graphgym_graph.edge_index[1].tolist()
    forward_edges = [(s, t) for s, t in zip(sources, targets) if s < t]

    ig_graph = ig.Graph(n=len(graphgym_graph.x), edges=forward_edges, directed=True)
    ig_graph.vs['type'] = graphgym_graph.get('x')[:, 0].cpu().numpy()
    if hasattr(graphgym_graph, 'x_features'):
        ig_graph.vs['feat'] = graphgym_graph.get('x_features')[:, 0].cpu().numpy()
    return ig_graph


def torch_geometric_to_igraph(data, directed=True):
    """
    Converts a PyTorch Geometric Data object to an igraph object.

    Parameters:
        data (torch_geometric.data.Data): The PyTorch Geometric Data object.
        directed (bool): Whether the resulting igraph graph should be directed.

    Returns:
        ig.Graph: The corresponding igraph object, with these attributes:
            - vertex ``type`` (int): the value of ``x`` for 1-D ``x``, otherwise
              column 0 of ``x``;
            - vertex ``feat`` (int): column 1 of ``x`` when it exists. It is
              overridden by ``x_features`` (clamped to >= 1) when that
              attribute is present;
            - edge ``edge_attr``: ``data.edge_attr`` as Python lists, when present.
    """
    edges = list(zip(data.edge_index[0].tolist(), data.edge_index[1].tolist()))

    x = getattr(data, 'x', None)
    # Fall back to num_nodes so a graph without ``x`` does not crash on len(None).
    num_nodes = len(x) if x is not None else data.num_nodes
    g = ig.Graph(n=num_nodes, edges=edges, directed=directed)

    if x is not None:
        # Convert to Python values once instead of calling .tolist() per node.
        for i, feature in enumerate(x.tolist()):
            if isinstance(feature, list):
                g.vs[i]["type"] = int(feature[0])
                if len(feature) > 1:
                    g.vs[i]["feat"] = int(feature[1])
            else:
                # 1-D x: store a plain int. A 0-d tensor here would hash by
                # identity and break dict lookups on the vertex type.
                g.vs[i]["type"] = int(feature)
        x_features = data.get('x_features', None)
        if x_features is not None:
            for i, feature in enumerate(x_features):
                g.vs[i]["feat"] = int(feature.clamp(min=1.0).item())

    edge_attr = getattr(data, 'edge_attr', None)
    if edge_attr is not None:
        g.es["edge_attr"] = edge_attr.tolist()

    return g


def networkx_to_igraph(nx_graph):
    """
    Build an iGraph graph mirroring a NetworkX graph, copying node and edge attributes.

    Each NetworkX node becomes one vertex, created in ``nx_graph.nodes`` order,
    whose ``name`` attribute is the NetworkX node identifier.

    :param nx_graph: A NetworkX graph (nx.Graph or nx.DiGraph).
    :return: An ``ig.Graph`` with the same directedness.
    """
    converted = ig.Graph(directed=nx_graph.is_directed())

    vertex_of = {}  # NetworkX node id -> iGraph vertex index
    for node, node_attrs in nx_graph.nodes(data=True):
        vertex_of[node] = converted.vcount()
        converted.add_vertex(name=node, **node_attrs)

    for src, dst, edge_attrs in nx_graph.edges(data=True):
        converted.add_edge(vertex_of[src], vertex_of[dst], **edge_attrs)

    return converted


def igraph_to_networkx(ig_graph):
    """
    Converts an iGraph graph to a NetworkX graph, copying all vertex and edge attributes.

    Node identifiers: when the vertices carry a ``type`` attribute, the type
    *value* is used as the NetworkX node id. Otherwise the igraph vertex index
    is used.

    NOTE(review): with ``type``-based ids, vertices that share a type map to the
    same NetworkX node (later vertices overwrite its attributes) and their
    edges merge. Confirm this is intended for graphs with repeated types.

    :param ig_graph: An iGraph graph (igraph.Graph).
    :return: ``nx.DiGraph`` if ``ig_graph`` is directed, else ``nx.Graph``.
    """
    graph_cls = nx.DiGraph if ig_graph.is_directed() else nx.Graph
    nx_graph = graph_cls()
    keyed_by_type = "type" in ig_graph.vs.attributes()

    def node_key(vertex):
        return vertex["type"] if keyed_by_type else vertex.index

    for vertex in ig_graph.vs:
        nx_graph.add_node(node_key(vertex), **vertex.attributes())

    for edge in ig_graph.es:
        head = node_key(ig_graph.vs[edge.source])
        tail = node_key(ig_graph.vs[edge.target])
        nx_graph.add_edge(head, tail, **edge.attributes())

    return nx_graph

def plot_pyg_graph(graph, save_path=None):
    """Draw a PyG graph with nodes labelled by their OCB type name.

    Labels come from ``INVERSE_NODE_TYPE_OCB`` applied to column 0 of
    ``graph.x``. The same column drives the node colours (viridis colormap).

    Args:
        graph: PyG ``Data`` whose ``x[:, 0]`` holds integer node-type ids.
        save_path: If truthy, the figure is saved there, shown, and then closed.
            Otherwise the open figure is returned for the caller to handle.

    Returns:
        The matplotlib ``Figure`` when no ``save_path`` is given, else ``None``.
    """
    G = to_networkx(graph, to_undirected=False)
    pos = nx.spring_layout(G)
    fig = plt.figure(figsize=(8, 8))
    # Convert the type column once instead of once per node.
    type_ids = graph.x[:, 0].tolist()
    node_labels = {i: f'{INVERSE_NODE_TYPE_OCB[type_ids[i]]}' for i in range(graph.num_nodes)}
    nx.draw(G, pos, with_labels=True, labels=node_labels, node_color=graph.x[:, 0].squeeze().cpu().numpy(), cmap=plt.get_cmap('viridis'))

    if not save_path:
        return fig

    fig.savefig(save_path)
    plt.show()
    # Release the figure. Without this every saved plot stays open and
    # matplotlib accumulates figures across calls.
    plt.close(fig)

def plot_nx_graph(graph, save_path):
    """Draw a NetworkX graph labelled and coloured by each node's ``type`` attribute.

    The figure is saved to ``save_path``, shown, and then closed.

    Args:
        graph: NetworkX graph whose nodes all carry a ``type`` attribute.
        save_path: Output file for the figure.
    """
    pos = nx.spring_layout(graph)
    fig = plt.figure(figsize=(8, 8))
    node_types = {i: graph.nodes[i]["type"] for i in graph.nodes}
    node_labels = {i: f'{t}' for i, t in node_types.items()}
    nx.draw(graph, pos, with_labels=True, labels=node_labels, node_color=list(node_types.values()), cmap=plt.get_cmap('viridis'))
    fig.savefig(save_path)
    plt.show()
    # Close so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_igraph(graph, path=None, layout="auto", vertex_size=20, vertex_color="skyblue", edge_color="black", edge_width=1):
    """
    Plots an igraph.Graph object using matplotlib.

    Parameters:
        graph (igraph.Graph): The graph to be plotted.
        path (str or None): File to save the figure to. When ``None`` (default),
            the figure is shown with ``plt.show()`` instead of being saved.
        layout (str): Layout type for node positioning (default: "auto").
        vertex_size (int): Size of the nodes.
        vertex_color (str): Color of the nodes.
        edge_color (str): Color of the edges.
        edge_width (int): Width of the edges.

    Nodes are labelled with their vertex index. The figure is always closed
    before returning.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    layout_positions = graph.layout(layout)

    ig.plot(
        graph,
        target=ax,
        layout=layout_positions,
        vertex_size=vertex_size,
        vertex_color=vertex_color,
        edge_color=edge_color,
        edge_width=edge_width,
        vertex_label=[v.index for v in graph.vs]  # Label nodes with index
    )

    # The default path=None used to be passed straight to savefig.
    if path is None:
        plt.show()
    else:
        fig.savefig(path)
    plt.close(fig)

    
    
def add_node(graph, node_id, label, shape='box', style='filled'):
    """Add one styled node to a Graphviz-style graph.

    ``label`` is an integer node-type code. It is translated into a short text
    label and a fill colour. Unknown codes get an empty label and an
    ``'aliceblue'`` fill.

    Args:
        graph: Object exposing ``add_node(node_id, **attrs)`` (a pygraphviz
            ``AGraph`` in ``draw_ckt``).
        node_id: Identifier of the node to create.
        label: Integer node-type code.
        shape: Graphviz node shape.
        style: Graphviz node style.
    """
    # (type code, text label, fill colour), tested in this order with ``==``.
    type_styles = (
        (8, 'input', 'orchid'),
        (9, 'output', 'pink'),
        (0, 'R', 'yellow'),
        (1, 'C', 'lawngreen'),
        (2, '+gm+', 'cyan'),
        (3, '-gm+', 'lightblue'),
        (4, '+gm-', 'deepskyblue'),
        (5, '-gm-', 'dodgerblue'),
        (6, 'gm_i', 'aliceblue'),
        (7, 'gm_o', 'aliceblue'),
        (10, 'n', 'silver'),
    )
    text, fill = '', 'aliceblue'
    for code, code_text, code_fill in type_styles:
        if label == code:
            text, fill = code_text, code_fill
            break

    graph.add_node(
        node_id,
        label=f"{text}",
        color='black',
        fillcolor=fill,
        shape=shape,
        style=style,
        fontsize=24,
    )


def draw_ckt(g, path, backbone=False):
    """Render an igraph circuit graph to an image file with Graphviz ``dot``.

    Args:
        g: igraph ``Graph`` whose vertices carry a ``type`` attribute, or
            ``None``. With ``None``, a single placeholder node of type 0 is drawn.
        path: Output file path passed to ``AGraph.draw``.
        backbone: If True, edges ``idx-1 -> idx`` get weight 1 and all other
            edges weight 0. If False, every edge gets weight 0.

    Raises:
        ImportError: if pygraphviz is not installed.
    """
    # The module-level ``import pygraphviz as pgv`` is commented out, which left
    # ``pgv`` undefined here. Import lazily so the rest of the module still
    # loads without pygraphviz.
    try:
        import pygraphviz as pgv
    except ImportError as err:
        raise ImportError("draw_ckt requires pygraphviz (pip install pygraphviz)") from err

    graph = pgv.AGraph(directed=True, strict=True, fontname='Helvetica', arrowtype='open')
    if g is None:
        add_node(graph, 0, 0)
    else:
        for idx in range(g.vcount()):
            add_node(graph, idx, g.vs[idx]['type'])
        # Build the in-adjacency list once; it used to be rebuilt per vertex.
        in_adjacency = g.get_adjlist(ig.IN)
        for idx, predecessors in enumerate(in_adjacency):
            for node in predecessors:
                weight = 1 if (backbone and node == idx - 1) else 0
                graph.add_edge(node, idx, weight=weight)
    graph.layout(prog='dot')
    graph.draw(path)

def visualize_and_save_graphs(graph, output_path, layout="fruchterman_reingold",
                             vertex_size=20, edge_width=1, use_vertex_types=True,
                             show_type_labels=True, vertex_property="type"):
    """
    Visualize an igraph object and save the figure to disk.

    Parameters:
    -----------
    graph : igraph.Graph
        The input graph to visualize. It is not modified.
    output_path : str
        Where to save the figure. It is used as a file path. If it names an
        existing directory, the figure is written to ``<output_path>/graph.png``.
    layout : str, optional
        "fruchterman_reingold" (default), "kamada_kawai" or "circle".
        Any other value falls back to igraph's automatic layout.
    vertex_size : int, optional
        Size of vertices in the visualization (default: 20)
    edge_width : int, optional
        Width of edges in the visualization (default: 1)
    use_vertex_types : bool, optional
        Whether to colour vertices by ``vertex_property`` (default: True).
        If the graph lacks that attribute, a single colour is used instead.
    show_type_labels : bool, optional
        Whether to display the ``vertex_property`` value next to each node
        (default: True). Only used when colouring by type.
    vertex_property : str, optional
        The vertex attribute to use for coloring and labeling (default: "type")

    Returns:
    --------
    None
    """
    if use_vertex_types and vertex_property not in graph.vs.attributes():
        print(f"Warning: Graph does not have '{vertex_property}' vertex attribute. Using a single visualization.")
        use_vertex_types = False

    # Layout of the full graph.
    layout_fns = {
        "fruchterman_reingold": graph.layout_fruchterman_reingold,
        "kamada_kawai": graph.layout_kamada_kawai,
        "circle": graph.layout_circle,
    }
    layout_coords = layout_fns.get(layout, graph.layout_auto)()

    # The two branches used to disagree on what output_path meant (file vs.
    # directory). Resolve it once here, accepting either.
    if os.path.isdir(output_path):
        destination = os.path.join(output_path, "graph.png")
    else:
        destination = output_path

    def get_color_palette(n):
        """Pick ``n`` named CSS4 colours, skipping very light ones."""
        all_colors = list(mcolors.CSS4_COLORS.keys())
        good_colors = [c for c in all_colors if not any(x in c for x in
                       ['white', 'snow', 'ivory', 'beige', 'linen', 'mint', 'azure', 'aliceblue'])]
        # Spread the picks across the list. max(n, 1) avoids a division by
        # zero for a graph with no vertices.
        step = max(1, len(good_colors) // max(n, 1))
        return [good_colors[i * step % len(good_colors)] for i in range(n)]

    plt.figure(figsize=(12, 10))

    if use_vertex_types:
        property_values = sorted(set(graph.vs[vertex_property]))
        color_palette = get_color_palette(len(property_values))
        value_to_color = dict(zip(property_values, color_palette))
        vertex_colors = [value_to_color[value] for value in graph.vs[vertex_property]]

        # Labels are passed explicitly. Nothing is written back to the caller's
        # graph (no "label" vertex attribute side effect).
        ig.plot(
            graph,
            target=plt.gca(),
            layout=layout_coords,
            vertex_size=vertex_size,
            vertex_color=vertex_colors,
            vertex_label=graph.vs[vertex_property] if show_type_labels else None,
            vertex_label_dist=-1.5,  # Position label slightly above the node
            vertex_label_size=8,     # Adjust size of label text
            vertex_label_angle=0,    # Keep text horizontal
            edge_width=edge_width,
            edge_color="gray",
        )
        plt.title("Full Graph (colored by vertex type)")
    else:
        ig.plot(
            graph,
            target=plt.gca(),
            layout=layout_coords,
            vertex_size=vertex_size,
            vertex_color="lightblue",
            edge_width=edge_width,
            edge_color="gray",
        )
        plt.title("Graph Visualization")

    plt.savefig(destination, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"All visualizations saved to {output_path}")
    
# def reset_slice_dict_nodes(batch, keep_nodes):
#     """
#     Recompute the space used by each sample in the batched x attribute
#     """

#     if len(keep_nodes) == len(batch.batch):
#         return batch
    
#     slice_dict = getattr(batch, '_slice_dict').copy()
#     batch_size = batch.batch.max().item() + 1
#     one = batch.batch.new_ones(len(keep_nodes))
#     counts = scatter(one, batch.batch[keep_nodes], dim=0, dim_size=batch_size, reduce='add')
#     # Make sure there is at least one node per graph, be it a pruned node
#     keep_nodes = torch.cat([
#         keep_nodes, slice_dict['x'][torch.nonzero(counts == 0)[:, 0].cpu().numpy()].to(keep_nodes.device)
#     ])
#     counts = counts.clamp(min=1)

#     new_slice_node_idx = torch.cat([slice_dict['x'].new_zeros(1,), counts.cumsum(0).to(slice_dict['x'].device)])

#     slice_dict['x'] = new_slice_node_idx
#     setattr(batch, '_slice_dict', slice_dict)

#     return batch, keep_nodes


def reset_slice_dict_edges(batch):
    """
    Recompute the per-graph offsets of ``edge_index`` (and ``edge_attr``) in ``batch._slice_dict``.

    Assumes edges are sorted by the graph they belong to, where an edge's graph
    is ``batch.batch`` of its source node. Graphs without any edge get an empty
    slice. This can happen during the first denoising steps. The same offsets
    are stored for ``edge_attr``.

    Args:
        batch: Batched graph exposing ``edge_index``, ``batch``, ``num_graphs``
            and ``_slice_dict``.

    Returns:
        The same ``batch`` object, whose ``_slice_dict`` is replaced by an
        updated copy.
    """
    slice_dict = getattr(batch, '_slice_dict').copy()
    ref = slice_dict['edge_index']
    x_edge_idx = batch.edge_index

    if x_edge_idx.size(1) == 0:
        # No edges at all: every graph gets an empty slice. The general path
        # below would index an empty ``torch.unique`` result and fail.
        new_slice_edge_idx = ref.new_zeros(batch.num_graphs + 1)
    else:
        edge_batch = batch.batch[x_edge_idx[0]]
        edge_batch_idx, counts = torch.unique(edge_batch, return_counts=True)

        # The new slice must have one boundary per graph (plus the leading 0).
        # Boundaries are therefore repeated to cover graphs without any edge:
        # the gap between consecutive graph ids that own edges, plus the
        # trailing graphs after the last one.
        repeats = torch.cat([torch.cat([edge_batch_idx.new_full((1,), -1), edge_batch_idx]).diff(), edge_batch_idx.new_full((1,), -1)])
        repeats[-1] = batch.num_graphs - edge_batch_idx[-1]

        new_slice_edge_idx = torch.cat([ref.new_zeros(1,), counts.cumsum(0).to(ref.device)])
        new_slice_edge_idx = new_slice_edge_idx.repeat_interleave(repeats.to(ref.device))

    slice_dict['edge_index'] = new_slice_edge_idx
    slice_dict['edge_attr'] = new_slice_edge_idx
    setattr(batch, '_slice_dict', slice_dict)

    return batch


def ensure_minimal_edge_count(edge_attr, xt):
    '''
    Search for samples for which all nodes are isolated and randomly sample one edge. This has practically no impact on
    performances (as it concerns only the first denoising steps) but ensures that no error appears in the slice_dict.

    A sample whose ``triu_edge_index`` slice is empty (e.g. a single-node graph) has no slot to place an edge in and is
    left unchanged.

    Params:
        edge_attr: The edge_attr tensor for all upper triangular elements of the adjacency matrix (modified in place).
        xt: Batched data exposing ``len(xt)`` (number of samples) and ``_slice_dict['triu_edge_index']``.
    Returns:
        Updated version of edge_attr with at least one non-zero edge per sample that has at least one slot.
    '''
    triu_slices = xt._slice_dict['triu_edge_index']
    # Number of upper-triangular slots owned by each sample.
    lengths = triu_slices.diff()
    # Equivalent to batch.batch but for triu_edge_index: the sample id of every slot.
    batch_edges = torch.arange(len(xt), device=lengths.device).repeat_interleave(lengths)
    # Count how many edges are predicted for each sample, on edge_attr's own device.
    # This used to read the global cfg.device, which had to equal edge_attr.device anyway.
    num_edges_per_sample = scatter(edge_attr, batch_edges.to(edge_attr.device), dim=0, dim_size=len(xt), reduce='add')
    # For every sample without an edge, set one randomly chosen slot to 1.
    # Samples with zero slots are skipped, since randint(0) would raise.
    sample_indices = torch.nonzero(num_edges_per_sample == 0)[:, 0].tolist()
    nonzero_edge_indices = [
        (triu_slices[i] + np.random.randint(int(lengths[i]))).item()
        for i in sample_indices
        if lengths[i] > 0
    ]
    edge_attr[nonzero_edge_indices] = 1

    return edge_attr


@torch.no_grad()
def add_full_rrwp(data, walk_length):
    """Attach RRWP positional encodings and degree features to ``data``.

    Edge weights are ``data.edge_weight``. When ``cfg.framework.type == 'vfm'``,
    they are instead the softmax over ``data.edge_attr`` (last dim), column 0.

    When ``cfg.posenc_RRWP.enable`` is set, let ``P = D^{-1} A`` be the
    row-normalised dense adjacency (rows with zero degree stay zero). Then
    ``pe[..., k] = P^k`` for ``k = 0 .. K-1`` with ``K = max(walk_length, 2)``.
    From ``pe``:
        - ``rrwp`` (node attr) is its diagonal, shape (n, K);
        - ``rrwp_index`` / ``rrwp_val`` hold its non-zero entries, with row and
          col swapped. The exception is the ``defog`` framework with
          ``cfg.posenc_RRWP.spse``, where ``add_spse`` supplies them instead.

    Always sets ``data.log_deg = log(deg + 1)`` and ``data.deg = deg.long()``.
    Here ``deg`` is the row sum of the (weighted) adjacency.

    Returns:
        The updated data object. In the spse branch it is the new ``Batch``
        returned by ``add_spse``.
    """
    device = data.edge_index.device
    attr_name = "rrwp"
    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight
    if cfg.framework.type == 'vfm':
        edge_weight = F.softmax(data.edge_attr, dim=-1)[:, 0]

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes),
                                       )

    deg = adj.sum(dim=1)
    if cfg.posenc_RRWP.enable:
        # Compute D^{-1} A, reusing deg rather than summing the rows again.
        deg_inv = 1.0 / deg
        deg_inv[deg_inv == float('inf')] = 0
        adj = adj * deg_inv.view(-1, 1)
        adj = adj.to_dense()

        # pe_list[k] = (D^{-1} A)^k, starting from the identity.
        pe_list = [torch.eye(num_nodes, dtype=torch.float).to(device), adj.to(device)]
        out = adj
        for _ in range(2, walk_length):
            out = out @ adj
            pe_list.append(out)

        pe = torch.stack(pe_list, dim=-1)  # n x n x k
        abs_pe = pe.diagonal().transpose(0, 1)  # n x k

        if (cfg.framework.type == 'defog') and (cfg.posenc_RRWP.spse):
            data = add_spse(data, walk_length)
        else:
            # Built only where it is used (the spse branch discards it).
            rel_pe = SparseTensor.from_dense(pe, has_value=True)
            rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
            # The GRIT framework performs right-multiplication while adj is
            # row-normalized, so row and col are swapped. Both orders work,
            # but this one is more consistent.
            rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0)
            data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name}_index")
            data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name}_val")

        data = add_node_attr(data, abs_pe, attr_name=attr_name)

    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)

    return data


def add_spse(data, walk_length, count=False):
    """Attach simple-path encodings to every graph of a batch.

    Each graph is cloned and converted to NetworkX. It then receives
    ``rrwp_index`` / ``rrwp_val`` computed by ``get_spse_emb``, moved to the
    device of ``data.edge_index``. Finally the graphs are re-batched.

    Args:
        data: A PyG ``Batch`` (uses ``len(data)`` and ``data.get_example``).
        walk_length: Maximum path length, forwarded to ``get_spse_emb``.
        count: Forwarded to ``get_spse_emb``.

    Returns:
        A new ``Batch`` built from the augmented graphs.
    """
    target_device = data.edge_index.device

    def _augment(example_idx):
        example = data.get_example(example_idx).clone()
        pair_index, pair_val = get_spse_emb(to_networkx(example), walk_length, count)
        example.rrwp_index = pair_index.to(target_device)
        example.rrwp_val = pair_val.to(target_device)
        return example

    return Batch.from_data_list([_augment(example_idx) for example_idx in range(len(data))])


def get_spse_emb(nx_graph, walk_length, count):
    """Compute simple-path encodings for every unordered node pair.

    For each pair (j, i) with j < i, all simple paths from j to i of length at
    most ``walk_length`` are enumerated. Entry ``l - 1`` of the pair's vector
    holds one of:
        - the number of such paths of length ``l``, if ``count`` is True;
        - ``0.5 * log(1 + number)``, otherwise.
    Each pair is emitted twice, as (i, j) and (j, i).

    NOTE(review): node ids are assumed to be the integers 0..n-1, since ``j``
    ranges over ``range(i)`` and is used as a node id. Confirm for inputs not
    produced by ``to_networkx``.

    Returns:
        (indices, val): ``indices`` is a long tensor of shape (2, P) and ``val``
        a float32 tensor of shape (P, walk_length). For graphs with fewer than
        two nodes, P = 0 and both tensors keep those 2-D shapes.
    """
    indices, val = [], []
    for i in nx_graph.nodes:
        for j in range(i):
            paths = nx.all_simple_paths(nx_graph, source=j, target=i, cutoff=walk_length)
            path_lengths = [len(p) - 1 for p in paths]
            idx, counts = np.unique(path_lengths, return_counts=True)
            spse_emb = np.zeros(walk_length)
            # np.long was removed in NumPy 1.24, so use an explicit int64 index dtype.
            slots = (idx - 1).astype(np.int64)
            spse_emb[slots] = counts if count else 0.5 * np.log(1 + counts)
            indices.extend([(i, j), (j, i)])
            val.extend([spse_emb, spse_emb])

    if not indices:
        # torch.tensor([]) would give 1-D tensors; keep the documented shapes.
        return (torch.empty((2, 0), dtype=torch.long),
                torch.empty((0, walk_length), dtype=torch.float))
    return torch.tensor(indices).T, torch.tensor(np.array(val)).to(torch.float)


def embed_1D_scalar(t, dim, max_period):
    """
    Sinusoidal embedding of a batch of (possibly fractional) scalars, e.g. timesteps.

    :param t: 1-D tensor of N values, one per batch element.
    :param dim: embedding width D. The first D//2 columns are cosines and the
                next D//2 are sines. A zero column is appended when D is odd.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) float32 tensor on ``t``'s device.
    """
    # Reference: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=t.device)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    embedding = torch.cat((args.cos(), args.sin()), dim=-1)
    if dim % 2 == 0:
        return embedding
    pad = embedding.new_zeros(embedding.size(0), 1)
    return torch.cat((embedding, pad), dim=-1)


def nodes_per_batch_sample(batch):
    """Return the number of nodes of each graph in ``batch``.

    The number of graphs is taken as ``batch.batch.max() + 1``. Entry ``g`` is
    the count of nodes assigned to graph ``g``. Uses ``torch_scatter.scatter``,
    which the original note says is faster than ``torch.bincount``.
    """
    assignment = batch.batch
    num_graphs = assignment.max().item() + 1
    ones = assignment.new_ones(assignment.size(0))
    return scatter(ones, assignment, dim=0, dim_size=num_graphs, reduce='add')


#### For AnalogGenie topology model

# Pin names of each multi-pin device: "<DEVICE>_<PIN>", in connection order.
node2pins = {
    device: [f"{device}_{pin}" for pin in pins]
    for device, pins in (
        ("PM", ("D", "G", "S", "B")),
        ("NM", ("D", "G", "S", "B")),
        ("NPN", ("C", "B", "E")),
        ("PNP", ("C", "B", "E")),
        ("DIO", ("P", "N")),
        ("XOR", ("A", "B", "VDD", "VSS", "Y")),
        ("PFD", ("A", "B", "QA", "QB", "VDD", "VSS")),
        ("INVERTER", ("A", "Q", "VDD", "VSS")),
        ("TRANSMISSION_GATE", ("A", "B", "C", "VDD", "VSS")),
    )
}

# Number of pins per device type. Passive C/R/L elements have two pins.
n_pins_dict = {"C": 2, "R": 2, "L": 2, **{device: len(pins) for device, pins in node2pins.items()}}

def analyze_graph(graph):
    """Return whether a generated circuit graph passes the validity checks.

    The graph is viewed as undirected and must satisfy all of:
      1. it is connected;
      2. some vertex has type 0 (presumably VSS, per the original comment; confirm
         against the type vocabulary);
      3. ``check_pin_connections`` passes.
    All three checks are computed before they are combined.
    """
    undirected = graph.as_undirected()

    connected = undirected.is_connected()
    has_vss = 0 in undirected.vs["type"]
    pins_ok = check_pin_connections(undirected)

    # The isolated-node check (check_isolated_nodes) is intentionally not applied.
    return has_vss and connected and pins_ok


def check_pin_connections(graph):
    """Return False if any vertex has more neighbours than its device has pins.

    Vertex ``type`` codes are mapped to names through ``ID_TO_NAME_NODES``.
    A name found in ``n_pins_dict`` limits that vertex to that many neighbours.
    Any other name is unconstrained.

    NOTE(review): ``ID_TO_NAME_NODES`` is neither defined nor imported in this
    module as shown, so any non-empty graph raises NameError here. Confirm
    where it is expected to come from.
    """
    names = [ID_TO_NAME_NODES[code] for code in graph.vs['type']]
    return not any(
        name in n_pins_dict and len(graph.neighbors(vid)) > n_pins_dict[name]
        for vid, name in enumerate(names)
    )

def check_isolated_nodes(graph):
    """Return True if at least one vertex of ``graph`` has degree 0."""
    return 0 in graph.degree()


def simul_outputs_to_bin_idx(simul_out, nbins=4):
    '''
    Discretise simulator outputs into per-metric bin indices.

    Params:
        simul_out: numpy array of shape (n_samples, 3). Its columns are
            presumably gain, UGW and PM, in that order (names taken from the
            bin tables).
        nbins: 4 selects the 4-edge tables. Any other value selects the
            10-edge tables.
    Returns:
        Integer torch tensor of shape (n_samples, 3). ``np.digitize(..., right=True)``
        assigns index k when edges[k-1] < value <= edges[k]. Values above the
        last edge are clamped into the last bin.
    '''
    # Bin edges pre-computed from the dataset, one table per column: (gain, ugw, pm).
    if nbins == 4:
        edge_tables = (
            np.array([0.986, 1.649, 2.07, 2.396]),
            np.array([1.259, 5.754, 13.183, 32.359]),
            np.array([1.128, 2.026, 3., 4.]),
        )
    else:
        edge_tables = (
            np.array([0.613, 0.903, 1.253, 1.478, 1.641, 1.848, 1.988, 2.138, 2.258, 2.396]),
            np.array([0.275, 0.776, 1.905, 3.715, 5.754, 8.318, 11.22, 15.136, 21.38, 32.359]),
            np.array([0.661, 1.002, 1.363, 1.836, 2.036, 2.431, 2.884, 3.068, 3.604, 4.]),
        )

    per_column = [
        np.digitize(simul_out[:, col], col_edges, right=True)
        for col, col_edges in enumerate(edge_tables)
    ]
    last_bin = len(edge_tables[2]) - 1
    return torch.from_numpy(np.stack(per_column, axis=1)).clamp(max=last_bin)


def load_classifier(classifier_path):
    """Build a classifier from ``<classifier_path>/config.yaml`` and load its weights.

    Args:
        classifier_path: Directory containing ``config.yaml`` and at least one
            ``*.ckpt`` file. The checkpoint dict must contain a
            ``'model_state'`` entry. When several checkpoints exist, the
            lexicographically first is used.

    Returns:
        The model built by ``create_model_from_cfg``, with weights loaded and
        put in eval mode.

    Raises:
        FileNotFoundError: if no ``*.ckpt`` file exists in ``classifier_path``.
    """
    # Load the classifier configuration file.
    cfg_classifier = CN()
    cfg_filep = os.path.join(classifier_path, 'config.yaml')
    cfg_classifier.set_new_allowed(True)
    cfg_classifier.merge_from_file(cfg_filep)

    classifier = create_model_from_cfg(cfg_classifier)

    # Sorted, so the chosen checkpoint does not depend on filesystem order.
    ckpt_paths = sorted(glob.glob(os.path.join(classifier_path, '*.ckpt')))
    if not ckpt_paths:
        raise FileNotFoundError(f"No '*.ckpt' checkpoint found in {classifier_path}")
    ckpt_path = ckpt_paths[0]
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict then copies the tensors into the model's own parameters.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model_state'])
    classifier.eval()
    print(f"Loaded checkpoint from {ckpt_path}.")

    return classifier


class GraphGymModule(torch.nn.Module):
    '''
    Thin wrapper mimicking PyG GraphGym's ``GraphGymModule``.

    The wrapped network is stored as ``self.model``, so its parameters appear
    under the ``model.`` prefix, and calls are forwarded to it unchanged.
    '''
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *inputs, **options):
        """Delegate the call to the wrapped model."""
        return self.model(*inputs, **options)


def create_model_from_cfg(model_cfg):
    r"""Instantiate a registered network from an existing configuration.

    The network class is looked up in GraphGym's ``network_dict`` by
    ``model_cfg.model.type``. It is built with
    ``dim_in=model_cfg.gt.dim_hidden``, ``dim_out=model_cfg.share.dim_out`` and
    ``model_cfg=model_cfg``.

    Args:
        model_cfg: The configuration node (e.g. a yacs ``CfgNode``).

    Returns:
        GraphGymModule: the network, wrapped so it is reachable as ``.model``.
    """
    input_dim = model_cfg.gt.dim_hidden
    output_dim = model_cfg.share.dim_out
    network_cls = network_dict[model_cfg.model.type]
    network = network_cls(dim_in=input_dim, dim_out=output_dim, model_cfg=model_cfg)
    return GraphGymModule(network)


def scale_x_features(batch, shift=50):
    """Replace ``batch.x_features`` by ``(float(x_features) - shift) / shift``.

    With the default ``shift=50``, values in [0, 100] map to [-1, 1].
    Returns the same ``batch`` object.
    """
    centred = batch.x_features.float() - shift
    batch.x_features = centred / shift
    return batch