
import torch, torch_sparse
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch.distributions.categorical import Categorical
from torch_sparse import SparseTensor
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import accuracy_score
from .pyspice_utils import simulation
from .utils import reset_slice_dict_edges, torch_geometric_to_igraph, gym_to_igraph, add_full_rrwp, simul_outputs_to_bin_idx, load_classifier, \
    scale_x_features
from .metric_ocb import  compute_VUN, is_graph_valid, novelty_ratio, our_is_valid_circuit, unique_ratio
from .run_model_eval import to_dag
from .loader.datasets.analogenie_dataset import NAME_TO_ID_NODES, ID_TO_NAME_NODES, node2pins, NAME_TO_ID_PINS


# Node-type ids split into two groups. update_xt_vfm uses them to drop sampled
# edges whose two endpoints fall in the same group.
NET_TYPES_ID = [8, 9, 10]
PIN_TYPES_ID = [6, 7]

# Ids (looked up in NAME_TO_ID_NODES) of the node names listed below.
NET_TYPES_ID_GENIE = [
    NAME_TO_ID_NODES[n] for n in ["VDD", "VSS", "VOUT", "VB", "net", "IB", "VCONT", "IOUT", "IIN", "VCM", "VREF", "IREF", "VCLK", "VRF", "VLO", "VIF", "VBB"]
]

# Node-count statistics indexed by cfg.dataset.name; the prior samplers draw the
# number of nodes of each generated graph from a clipped normal built on them.
MEAN_NODE_NUMBER = {'ocb_CktBench101': [12, 2.5], "AnaloGenie_pins" : [106, 54], "AnalogGenie": [37, 17]} # mean, std


def inference(model, euler_steps, num_samples=100, noise_e=0.1, noise_x=0.1, n_pow_e=8, n_pow_x=6, n_pow_f=2, current_t_x=0, current_t_e=0, 
              current_t_f=0, xt=None, n_nodes=None, y_guidance=None, classifier_path=None):
    '''
    Performs model inference by Euler integration from the current times up to t = 1.

    Node types (x), edges (e) and node features (f) each follow their own time grid
    t = current_t + (1 - (1 - u) ** n_pow) * (1 - current_t), with u = linspace(0, 1, euler_steps).

    Args:
        model: The trained model; its configuration is read from `model.cfg`.
        euler_steps: Number of points of the time grid (at most euler_steps - 1 updates are performed).
        num_samples: Number of graphs drawn from the prior (replaced by xt.num_graphs when xt is given).
        noise_e: Noise level forwarded to the edge update.
        noise_x: Noise level forwarded to the node update.
        n_pow_e, n_pow_x, n_pow_f: Powers of the time discretization for edges / nodes / features.
        current_t_x, current_t_e, current_t_f: Starting times for nodes / edges / features.
        xt: Current state of the batch (torch_geometric.data.Batch, None if denoising from prior distribution).
            Required as soon as one of the starting times is > 0.
        n_nodes: Optional number of nodes, forwarded to the prior sampler.
        y_guidance: Optional conditioning data, forwarded to the prior sampler and to the update step.
        classifier_path: Optional path handed to `load_classifier` to enable classifier guidance.
    Returns:
        The updated batch, or None when none of the three starting times is below 1.
    Raises:
        ValueError: if euler_steps < 2 (the time grid would contain no step).
    '''

    cfg = model.cfg
    # Assert initial x_t is provided if needed
    if current_t_x > 0 or current_t_e > 0 or current_t_f > 0:
        assert xt is not None, 'A initial circuit must be provided if starting from t > 0.'
    if xt is not None:
        num_samples = xt.num_graphs
        xt = scale_x_features(xt.clone())

    # Init steps
    steps = 0
    t_x, t_e, t_f = current_t_x, current_t_e, current_t_f

    # Nothing to do if no time is left for any of the three components
    remaining_time_x, remaining_time_e, remaining_time_f = 1.0 - current_t_x, 1.0 - current_t_e, 1.0 - current_t_f
    if max(remaining_time_x, remaining_time_e, remaining_time_f) <= 0:
        return None

    if euler_steps < 2:
        raise ValueError(f'euler_steps must be at least 2 to define a time step, got {euler_steps}.')

    # Maybe load classifier for classifier guidance
    classifier = None
    if classifier_path is not None:
        classifier = load_classifier(classifier_path)
        classifier.to(cfg.device)

    ### Create time steps for the remaining interval
    ut = torch.linspace(0, 1, euler_steps)

    def _step_sizes(current_t, remaining_time, n_pow):
        # Warped grid on [0, 1], rescaled to the remaining time interval
        t_list = current_t + (1 - ((1 - ut) ** n_pow)) * remaining_time
        return t_list.diff()

    dts_x = _step_sizes(current_t_x, remaining_time_x, n_pow_x)
    dts_e = _step_sizes(current_t_e, remaining_time_e, n_pow_e)
    dts_f = _step_sizes(current_t_f, remaining_time_f, n_pow_f)
    num_steps = euler_steps - 1  # length of each dts_* tensor

    if (current_t_x == 0) or (current_t_e == 0) or (current_t_f == 0):
        xt = draw_from_prior(cfg, num_samples=num_samples, y_guidance=y_guidance, t_x=t_x, t_e=t_e, t_f=t_f, xt=xt, n_nodes=n_nodes)
        xt = xt.to(cfg.device)

    # The step bound protects against indexing past the grid when the accumulated
    # time ends up marginally below 1 because of floating-point rounding.
    while ((t_x < 1.0) or (t_e < 1.0) or (t_f < 1.0)) and steps < num_steps:

        if steps == num_steps - 1:
            # Last available step: consume exactly what is left so that every
            # component lands on t = 1 and the final-step logic is triggered.
            dt_x, dt_e, dt_f = 1.0 - t_x, 1.0 - t_e, 1.0 - t_f
        else:
            dt_x, dt_e, dt_f = dts_x[steps], dts_e[steps], dts_f[steps] # will be 0 if t_* has reached 1

        xt = add_full_rrwp(xt.clone(), walk_length=cfg.posenc_RRWP.ksteps)
        xt.t_x = torch.full((xt.num_nodes,), t_x).to(cfg.device)
        xt.t_f = torch.full((xt.num_nodes,), t_f).to(cfg.device)
        xt.t_e = torch.full((xt.triu_edge_index.shape[1],), t_e).to(cfg.device)

        # Entries that must stay fixed get a time index of 1
        if hasattr(xt, 'learnable_x'):
            xt.t_x = xt.t_x * xt.learnable_x + xt.t_x.new_ones(len(xt.t_x)) * (1 - xt.learnable_x)
        if hasattr(xt, 'supernode_x_index'):
            xt.t_x[xt.supernode_x_index] = 1.0
            xt.t_f[xt.supernode_x_index] = 1.0
        if hasattr(xt, 'triu_learnable_edge_attr'):
            xt.t_e = xt.t_e * xt.triu_learnable_edge_attr + xt.t_e.new_ones(len(xt.t_e)) * (1 - xt.triu_learnable_edge_attr.float())
        xt.t_e = xt.t_e.repeat(2)

        # Model forward
        with torch.no_grad():
            logits = model(xt.clone()) # (B, D, S)
            logits_dict = {'logits': logits}

        # Unconditional forward for classifier-free guidance
        if (y_guidance is not None) and (cfg.gt.conditioning_loss == 'cfg'):
            with torch.no_grad():
                logits_uncond = model(xt.clone(), unconditional_prop=1)
                logits_dict.update({'logits_uncond': logits_uncond})

        # Denoise given current state and model output
        xt = update_xt(cfg, xt=xt, logits_dict=logits_dict, t_x=t_x, t_e=t_e, t_f=t_f, dt_x=dt_x, dt_e=dt_e, 
                       dt_f=dt_f, noise_nodes=noise_x, noise_edges=noise_e, classifier=classifier, y_guidance=y_guidance)

        # Node types are final once the node time reaches 1
        if t_x + dt_x >= 1.0:
            xt.x = xt.x.long()

        # Advance t only where necessary
        if remaining_time_x > 0:
            t_x = min(1, t_x + dt_x)
        if remaining_time_e > 0:
            t_e = min(1, t_e + dt_e)
        if remaining_time_f > 0:
            t_f = min(1, t_f + dt_f)

        steps += 1

    return xt


def update_xt(cfg, logits_dict, **kwargs):
    """
    Turn the model outputs into probabilities and dispatch to the update rule of the configured framework.

    Args:
        cfg: Configuration; reads cfg.gt (sizing, guidance_strength), cfg.framework.type and cfg.train.prior.
        logits_dict: Dict with the model output under 'logits' and, for classifier-free guidance,
            the unconditional output under 'logits_uncond'.
        **kwargs: Forwarded untouched to the selected update function (xt, times, step sizes, noise, ...).
    Returns:
        Whatever the selected update function returns (the updated batch).
    Raises:
        NotImplementedError: if neither the 'vfm' framework nor the 'marginal' prior is configured.
    """
    use_sizing = cfg.gt.get("sizing", False)
    cond_out = logits_dict['logits']

    x1_probs = F.softmax(cond_out.x, dim=-1) # (B, D, S)
    e1_probs = F.softmax(cond_out.edge_attr, dim=-1)
    x1_probs_dict, e1_probs_dict = {'x1_probs': x1_probs}, {'e1_probs': e1_probs}
    # Maybe device sizing
    x1_features = cond_out.x_features if use_sizing else None

    # Conditional generation (classifier-free guidance)
    if 'logits_uncond' in logits_dict:
        uncond_out = logits_dict['logits_uncond']
        strength = cfg.gt.guidance_strength
        x1_probs_uncond = F.softmax(uncond_out.x, dim=-1) # (B, D, S)
        e1_probs_uncond = F.softmax(uncond_out.edge_attr, dim=-1)
        if use_sizing:
            x1_features = strength * x1_features + (1 - strength) * uncond_out.x_features
        if cfg.framework.type == 'vfm':
            # VFM mixes the probabilities directly
            x1_probs = strength * x1_probs + (1 - strength) * x1_probs_uncond
            e1_probs = strength * e1_probs + (1 - strength) * e1_probs_uncond
        else:
            # Other frameworks combine conditional / unconditional rates downstream
            x1_probs_dict.update({'x1_probs_uncond': x1_probs_uncond})
            e1_probs_dict.update({'e1_probs_uncond': e1_probs_uncond})

    kwargs.update({'x1_features': x1_features})

    if cfg.framework.type == 'vfm':
        kwargs.update({'logits': cond_out, 'x1_probs': x1_probs, 'e1_probs': e1_probs})
        return update_xt_vfm(cfg, **kwargs)
    if cfg.train.prior == 'marginal':
        kwargs.update({'x1_probs_dict': x1_probs_dict, 'e1_probs_dict': e1_probs_dict})
        return update_xt_marginal(cfg, **kwargs)

    # Fail loudly instead of handing None back to the sampling loop
    raise NotImplementedError(
        f"No update rule for framework type {cfg.framework.type!r} with prior {cfg.train.prior!r}."
    )
    

def update_xt_features(cfg, xt, x1_features, t, dt, classifier=None, y_guidance=None):
    """
    Euler update of the continuous node features.

    Args:
        cfg: Configuration; only read (cfg.gt.guidance_strength_features) when a classifier is given.
        xt: Current batch; reads xt.x_features and the per-node times xt.t_f.
        x1_features: Model output used as the update direction for the features.
        t: Current feature time; at t == 1 the features are returned untouched.
        dt: Step size.
        classifier: Optional classifier for gradient-based guidance.
        y_guidance: Target class indices indexed as [sample, property]; required with a classifier.
    Returns:
        The updated features. Rows whose time index is 1 are left unchanged. On the last
        step the result is mapped to integers in [1, 100].

    Note:
        xt.x_features is updated in place before the (possibly re-mapped) result is returned.
    """
    if t == 1:
        return xt.x_features

    # Nodes whose time index is already 1 are frozen
    non_one_indices = xt.t_f != 1
    features = xt.x_features

    # Maybe classifier guidance
    if classifier is not None:
        with torch.enable_grad():
            classifier.zero_grad()
            cls_inpt = xt.clone()
            cls_inpt.x_features.requires_grad_(True)
            # Classifier forward
            pred_xt, _ = classifier(cls_inpt.clone())
            # log_softmax is the numerically stable form of log(softmax(.))
            pred_xt_probs = F.log_softmax(pred_xt, dim=-1)
            # Log-probability of the target class of each property, averaged, then backpropagated
            pred_xt_probs = torch.stack(
                [pred_xt_probs[range(len(pred_xt)), i, y_guidance[range(len(pred_xt)), i]]
                 for i in range(y_guidance.shape[1])], dim=1)
            pred_xt_probs = pred_xt_probs.mean(dim=-1, keepdim=True)
            mean = pred_xt_probs.mean()
            mean.backward()
            # Finally, update features in the direction of the gradient of log p(y|x_f)
            x1_features = x1_features + cfg.gt.guidance_strength_features * cls_inpt.x_features.grad

    features[non_one_indices] = xt.x_features[non_one_indices] + dt * x1_features[non_one_indices]

    # If it's the last step then map back to the original data range [1; 100]
    if t >= 1.0 - dt:
        features = (features * 50 + 50).int().clip(min=1, max=100)

    return features


def update_xt_sizing(xt, x1_features, t, dt, **kwargs):
    """
    Update only the feature columns of xt.x (every column but the first) and return xt.

    update_xt_features expects the configuration as its first argument; it is taken
    from the optional `cfg` keyword (None is enough when no classifier guidance is used).
    """
    xt_feats = update_xt_features(kwargs.get('cfg'), xt, x1_features, t, dt)
    xt.x[:, 1:] = xt_feats
    return xt


def sample_features_from_prior(cfg, xt):
    """
    Replace the feature columns of xt.x (all columns but the first) by uniform noise in [-1, 1).

    xt.x is cast to float first; the first column is left as is. Returns xt.
    """
    xt.x = xt.x.float()
    n_rows = len(xt.x)
    n_feature_cols = cfg.dataset.node_features_dim - 1
    noise = torch.rand((n_rows, n_feature_cols), device=xt.x.device) * 2 - 1
    xt.x[:, 1:] = noise
    return xt


def classifier_free_guidance_continuous(y_cond, y_uncond, guidance_scale=3.0):
    """Move the unconditional prediction towards the conditional one, scaled by guidance_scale."""
    guidance_direction = y_cond - y_uncond
    return y_uncond + guidance_scale * guidance_direction


def update_xt_marginal(cfg, xt, x1_probs_dict, e1_probs_dict, x1_features, t_x, t_e, t_f, dt_x,
                       dt_e, dt_f, noise_nodes, noise_edges, **kwargs):
    """
    One denoising step for the marginal prior: node types, node features, then edges.

    Each component is only touched while its own time is below 1. Extra keyword
    arguments are forwarded to the node sampler and to the feature update.
    Returns the updated batch.
    """

    # Node types
    if t_x < 1:
        sampled_types = sample_nodes_marginal(cfg, xt, x1_probs_dict, t_x, dt_x, noise_nodes, **kwargs)
        xt.x = sampled_types[:, None].float()

    # Device sizes (a single zero column when sizing is disabled)
    if t_f < 1:
        if cfg.gt.get("sizing", False):
            xt.x_features = update_xt_features(cfg, xt, x1_features, t_f, dt_f, **kwargs)
        else:
            xt.x_features = xt.x.new_zeros((len(xt.x), 1))

    # Edges
    if t_e < 1:
        triu_idx, triu_attr = sample_edges_marginal(cfg, xt, e1_probs_dict, t_e, dt_e, noise_edges)

        # Only edges with a strictly positive attribute are stored
        kept = torch.nonzero(triu_attr > 0)[:, 0]
        triu_idx = triu_idx[:, kept]
        triu_attr = triu_attr[kept]

        # Add the reverse direction, then order the edges by source node
        both_ways = torch.cat([triu_idx, torch.flip(triu_idx, dims=[0])], dim=1)
        row_sorted, row_order = both_ways.sort()
        # NOTE(review): edge_attr below is not permuted with row_order; this looks
        # harmless only as long as all kept attributes are identical - confirm.
        xt.edge_index = torch.stack([row_sorted[0], both_ways[1, row_order[0]]])
        xt.edge_attr = torch.cat([triu_attr, triu_attr], dim=0)
        xt = reset_slice_dict_edges(xt)

    return xt


def update_xt_vfm(cfg, xt, logits, x1_probs, e1_probs, x1_features, t_x, t_e, dt_x, dt_e, t_f=None, dt_f=None, **kwargs):
    """
    One denoising step for the VFM framework.

    Args:
        cfg: Configuration (reads cfg.dataset.node_features_dim).
        xt: Current batch; reads/writes xt.xt_logits, xt.x, xt.edge_index and xt.edge_attr.
        logits: Conditional model output (fields x, edge_attr, edge_index).
        x1_probs, e1_probs: Node / edge probabilities (possibly mixed with the unconditional ones).
        x1_features: Model output for the node features.
        t_x, t_e: Current node / edge times.
        dt_x, dt_e: Node / edge step sizes.
        t_f, dt_f: Feature time and step size; default to the node values when omitted.
        **kwargs: Ignored extra arguments of the generic update call.
    Returns:
        The updated batch.
    """
    if t_f is None:
        t_f = t_x
    if dt_f is None:
        dt_f = dt_x

    # Nodes
    if t_x + dt_x < 1.0:
        xt.xt_logits = xt.xt_logits * (1 - dt_x / (1 - t_x)) + dt_x * x1_probs / (1 - t_x)
    else:
        # if t = 1, probabilities are (re-)computed from the output from the conditional model only;
        # --> this is specific to the VFM framework which outputs categorical probabilities instead of an actual
        # velocity field, which no longer verify the simplex condition after applying the classifier-free guidance combination
        x1_probs = F.softmax(logits.x, dim=-1) # (B, D, S)
        xt.x[:, 0] = Categorical(probs=x1_probs).sample((1,)).squeeze()

    # Device sizes
    # NOTE(review): update_xt_features reads xt.x_features / xt.t_f while this function
    # stores the features in xt.x[:, 1:] - confirm that VFM batches carry both.
    if cfg.dataset.node_features_dim > 1:
        xt_feats = update_xt_features(cfg, xt, x1_features, t_f, dt_f)
    else:
        xt_feats = xt.x.new_zeros((len(xt.x), 1))
    xt.x[:, 1:] = xt_feats

    # Edges
    if t_e + dt_e < 1.0:
        # Keep one direction per edge, oriented as (min, max), and blend with the prediction
        xt_edge_idx = xt.edge_index[:, ::2]
        xt_edge_idx = torch.stack([xt_edge_idx.min(dim=0)[0], xt_edge_idx.max(dim=0)[0]])
        xt_edge_attr = xt.edge_attr[::2]
        xt_edge_idx, xt_edge_attr = torch_sparse.coalesce(
            torch.cat([xt_edge_idx, logits.edge_index], dim=1),
            torch.cat([xt_edge_attr * (1 - dt_e / (1 - t_e)), dt_e * e1_probs / (1 - t_e)], dim=0),
            xt.num_nodes, xt.num_nodes,
            op="sum"
        )
    else:
        e1_probs = F.softmax(logits.edge_attr, dim=-1)
        xt_edge_idx = logits.edge_index
        xt_edge_attr = Categorical(probs=e1_probs).sample((1,)).squeeze()

        # Prune edges between nodes of the same type
        node_type = xt.x.new_zeros(xt.x[:, 0].shape)
        for net_id in NET_TYPES_ID:
            node_type = node_type + (xt.x[:, 0] == net_id).int()
        for net_id in PIN_TYPES_ID:
            node_type = node_type + 2 * (xt.x[:, 0] == net_id).int()
        i_type = node_type[xt_edge_idx[0, :]]
        o_type = node_type[xt_edge_idx[1, :]]
        xt_edge_attr[i_type == o_type] = 0

        xt_edge_idx = xt_edge_idx[:, xt_edge_attr > 0]
        xt_edge_attr = xt_edge_attr[xt_edge_attr > 0]

    # Flip and concat with the same ordering as when drawing from p0
    xt_edge_idx = xt_edge_idx.repeat_interleave(2, dim=1)
    xt_edge_idx[:, ::2] = torch.flip(xt_edge_idx[:, ::2], dims=[0])
    xt_edge_attr = xt_edge_attr.repeat_interleave(2, dim=0)

    xt.edge_index = xt_edge_idx
    xt.edge_attr = xt_edge_attr
    xt = reset_slice_dict_edges(xt)

    return xt


def draw_from_prior(cfg, **kwargs):
    """
    Sample an initial batch from the prior selected by the configuration.

    Only the keywords each generator actually accepts are forwarded, so the same
    call (num_samples, y_guidance, t_x, t_e, t_f, xt, n_nodes) works for every prior.

    Raises:
        NotImplementedError: if the configured framework / prior has no generator.
    """
    if cfg.framework.type == 'vfm':
        return generate_random_graph_logits(
            cfg,
            num_samples=kwargs['num_samples'],
            # The VFM generator names the conditioning signal `cond_y`
            cond_y=kwargs.get('cond_y', kwargs.get('y_guidance')),
            t_x=kwargs.get('t_x', 0),
            t_e=kwargs.get('t_e', 0),
            xt=kwargs.get('xt'),
        )
    if cfg.train.prior == 'masked':
        return generate_masked_graph(cfg, num_samples=kwargs['num_samples'])
    if cfg.train.prior == 'marginal':
        return generate_marginal_graph(cfg, **kwargs)

    raise NotImplementedError(
        f"No prior sampler for framework type {cfg.framework.type!r} with prior {cfg.train.prior!r}."
    )


def generate_marginal_graph(cfg, num_samples, y_guidance, t_x, t_e, t_f, xt, n_nodes):
    """
    Build the initial batch for the marginal prior.

    For every component (node types, node features, edges) whose time is 0, a fresh
    sample is drawn; otherwise the values of the provided batch `xt` are reused.

    Args:
        cfg: Configuration (reads cfg.dataset.name, cfg.node_type_pmf, cfg.edge_ratio, cfg.device, cfg.gt, cfg.gnn).
        num_samples: Number of graphs to generate.
        y_guidance: Optional conditioning data, indexed per graph and stored as `y`.
        t_x, t_e, t_f: Starting times of node types / edges / features.
        xt: Optional existing batch (anchor graphs for completion or link prediction).
        n_nodes: None (draw node counts from MEAN_NODE_NUMBER), an integer shared by all
            graphs, or a sequence with one node count per graph.
    Returns:
        A torch_geometric Batch of `num_samples` graphs.
    """

    # Generate a random number of nodes
    if n_nodes is None:
        mu_nnodes, std_nnodes = MEAN_NODE_NUMBER[cfg.dataset.name]
        min_nnodes, max_nnodes = mu_nnodes - int(2 * std_nnodes), mu_nnodes + int(2 * std_nnodes)
        nnodes = (np.random.normal(size=num_samples) * std_nnodes + mu_nnodes).astype(int).clip(min=min_nnodes, max=max_nnodes)
    elif isinstance(n_nodes, (int, np.integer)):
        # Also accepts NumPy integer scalars, which a plain `type(...) == int` test rejects
        nnodes = [n_nodes] * num_samples
    else:
        # Then it must be a list of length num_samples
        nnodes = n_nodes
    batch_list = []
    for i in range(num_samples):

        data_dict = {}
        # Nodes
        if t_x == 0:
            # If xt is provided then sample only additional node - for a full denoising then xt must be None or t_x > 0.
            num_nodes = nnodes[i] if xt is None else nnodes[i] - xt.get_example(i).num_nodes # will throw an error if num_nodes if too small.
            probs = torch.tensor(cfg.node_type_pmf).to(torch.device(cfg.device))
            categorical_dist = Categorical(probs=probs)
            x = categorical_dist.sample((num_nodes, 1)).float()
            if xt is not None:
                # Existing nodes are anchors (not learnable); new ones are appended after them
                x_i = xt.get_example(i)
                anchor_x = torch.cat([x_i.x.new_ones(len(x_i.x),), x_i.x.new_zeros(len(x),)])
                data_dict.update({'learnable_x': 1 - anchor_x})
                x = torch.cat([x_i.x, x], dim=0)
                num_nodes = len(x)
            if cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False):
                # The conditioning signal is encoded as the type of one extra node appended last
                if cfg.gnn.n_spec == 1:
                    mult = torch.tensor([0, 0, 0])
                    mult[cfg.gnn.spec_dim] = 1
                else:
                    mult = torch.tensor([1, cfg.gnn.n_bins, cfg.gnn.n_bins ** 2])
                supernode_type = (y_guidance[[i]] * mult).sum(dim=1, keepdim=True).to(cfg.device)
                x = torch.cat([x, supernode_type])
                data_dict.update({'supernode_x_index': torch.tensor(num_nodes)})
                num_nodes += 1
        else:
            x_i = xt.get_example(i)
            num_nodes = x_i.num_nodes
            x = x_i.x
        data_dict.update({'x': x})

        # Features
        if t_f == 0:
            if cfg.gt.get("sizing", False):
                x_features = torch.rand((num_nodes, 1), device=cfg.device) * 2 - 1 # Range [-1; 1]
                if cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False):
                    x_features[-1] = 0 # Put supernode features to zero.
            else:
                x_features = None
        else:
            x_i = xt.get_example(i)
            x_features = x_i.x_features
        data_dict.update({'x_features': x_features})

        # Edges
        if t_e == 0:
            row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
            all_connections = torch.stack((row, col), dim=0)
            x0_edge_attr = (torch.rand(all_connections.size(1)) < cfg.edge_ratio).long()
            if not x0_edge_attr.any(): # Ensure not all edges are zero (messes with slice_dict)
                x0_edge_attr[np.random.randint(len(x0_edge_attr))] = 1

            # If xt is provided then two use cases: circuit completion if new nodes were added, \
            # in which case existing edges are preserved, else link prediction.
            if (xt is not None) and (xt.get_example(i).num_nodes < num_nodes):
                x_i = xt.get_example(i)
                directed_index = x_i.edge_index[0] < x_i.edge_index[1]
                # anchor_e marks the candidate edges that already exist in the anchor graph
                all_connections, anchor_e = torch_sparse.coalesce(
                    torch.cat([all_connections.to(x_i.edge_index.device), x_i.edge_index[:, directed_index]], dim=1),
                    torch.cat([x_i.edge_index.new_zeros(len(x0_edge_attr)), x_i.edge_attr[directed_index]], dim=0),
                    num_nodes, num_nodes,
                    op="max"
                )
                x0_edge_attr = (anchor_e + x0_edge_attr.to(x_i.edge_index.device)).clamp(max=1)
                data_dict.update({'triu_learnable_edge_attr': 1 - anchor_e})

            # Suppress 0 edges, then flip
            x0_edge_idx = all_connections[:, x0_edge_attr > 0].clone()
            x0_edge_attr = x0_edge_attr[x0_edge_attr > 0]

            x0_edge_idx = torch.cat([x0_edge_idx, torch.flip(x0_edge_idx, dims=[0])], dim=1)
            x0_edge_attr = torch.cat([x0_edge_attr, x0_edge_attr], dim=0)
        else:
            x_i = xt.get_example(i)
            x0_edge_idx = x_i.edge_index
            x0_edge_attr = x_i.edge_attr
            all_connections = x_i.triu_edge_index
        data_dict.update({'edge_index': x0_edge_idx, 'edge_attr': x0_edge_attr, 'triu_edge_index': all_connections})
        if y_guidance is not None:
            data_dict.update({'y': y_guidance[[i]]})

        if cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False): # TODO currently not compatible w/ completion
            # Candidate edges whose second endpoint is the last node (the supernode) are not learnable
            data_dict.update({'triu_learnable_edge_attr': all_connections[1] != (num_nodes - 1)})

        # Create a graph data object
        graph = Data(**data_dict)
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch


def generate_masked_graph(cfg, num_samples, **kwargs):
    """
    Build a batch of fully masked, fully connected graphs.

    Args:
        cfg: Configuration (reads cfg.dataset.nedge_types, node_features_dim, nnode_features, nnode_types).
        num_samples: Number of graphs to generate.
        **kwargs: Accepted and ignored, so the generic prior call (y_guidance, times, xt,
            n_nodes, ...) can be forwarded unchanged.
    Returns:
        A torch_geometric Batch whose node and edge attributes all hold the mask value
        (last index of the corresponding vocabulary).
    """
    # Node counts: normal(10, 2) truncated to integers and clipped to [4, 16]
    nnodes = (np.random.normal(size=num_samples) * 2 + 10).astype(int).clip(min=4, max=16)
    batch_list = []
    for i in range(num_samples):
        num_nodes = int(nnodes[i])

        # All ordered pairs (i, j) with i != j, in row-major order
        node_ids = torch.arange(num_nodes)
        row = node_ids.repeat_interleave(num_nodes)
        col = node_ids.repeat(num_nodes)
        off_diagonal = row != col
        edges = torch.stack([row[off_diagonal], col[off_diagonal]], dim=0)

        # Every edge starts in the masked state
        edge_attributes = torch.full((edges.size(1),), cfg.dataset.nedge_types - 1)

        # Every node feature starts in the masked state; column 0 holds the node type
        x = torch.full((num_nodes, cfg.dataset.node_features_dim), cfg.dataset.nnode_features - 1)
        x[:, 0] = cfg.dataset.nnode_types - 1

        # Finally, add upper triangular connections (useful for edge remasking)
        row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
        all_connections = torch.stack((row, col), dim=0)

        # Create a graph data object
        graph = Data(
            x=x,
            edge_index=edges,
            edge_attr=edge_attributes,
            triu_edge_index=all_connections
        )
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch


def generate_random_graph_logits(cfg, num_samples, cond_y=None, t_x=0, t_e=0, xt=None, y_guidance=None, **kwargs):
    """
    Build the initial batch for the VFM framework (random node / edge logits).

    Args:
        cfg: Configuration (reads cfg.dataset and cfg.device).
        num_samples: Number of graphs to generate.
        cond_y: Optional conditioning data, indexed per graph and stored as `y`.
        t_x: Node time; nodes are sampled when 0, otherwise taken from `xt`.
        t_e: Edge time; edges are sampled when 0, otherwise taken from `xt`.
        xt: Existing batch, required when t_x or t_e is not 0.
        y_guidance: Alias of `cond_y`, used when `cond_y` is not given.
        **kwargs: Accepted and ignored, so the generic prior call can be forwarded unchanged.
    Returns:
        A torch_geometric Batch of `num_samples` graphs.
    """
    if cond_y is None:
        cond_y = y_guidance

    # Generate a random number of nodes
    mu_nnodes, std_nnodes = MEAN_NODE_NUMBER[cfg.dataset.name]
    if cfg.dataset.get("use_pins", False) and 'ocb' in cfg.dataset.name:
        mu_nnodes += 7
        std_nnodes += 1
    min_nnodes, max_nnodes = mu_nnodes - int(2 * std_nnodes), mu_nnodes + int(2 * std_nnodes)
    nnodes = (np.random.normal(size=num_samples) * std_nnodes + mu_nnodes).astype(int).clip(min=min_nnodes, max=max_nnodes)
    batch_list = []
    for i in range(num_samples):

        if t_x == 0:
            num_nodes = nnodes[i]

            # Nodes
            x0 = torch.randn((num_nodes, cfg.dataset.nnode_types)) * 2 - 1

            # Device types and sizes; both parts are created on cfg.device so that
            # they can be concatenated whatever the device is
            x_node = torch.zeros(num_nodes, 1, device=cfg.device).float()
            x_feature = torch.rand((len(x_node), cfg.dataset.node_features_dim - 1), device=cfg.device) * cfg.dataset.nnode_features
            x = torch.cat([x_node, x_feature.clip(min=1e-1)], dim=1)
        else:
            x_i = xt.get_example(i)
            x = x_i.x
            num_nodes = x_i.num_nodes
            x0 = x_i.xt_logits

        if t_e == 0:
            # Edges
            row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
            all_connections = torch.stack((row, col), dim=0)
            x0_edge_attr = torch.randn((all_connections.size(1), 2)) * 2 - 1

            # Flip and concat - it is important to preserve the ordering when adding later with model prediction
            x0_edge_idx = all_connections.repeat_interleave(2, dim=1)
            x0_edge_idx[:, ::2] = torch.flip(x0_edge_idx[:, ::2], dims=[0])
            x0_edge_attr = x0_edge_attr.repeat_interleave(2, dim=0)
        else:
            x_i = xt.get_example(i)
            x0_edge_idx = x_i.edge_index
            x0_edge_attr = x_i.edge_attr
            all_connections = x_i.triu_edge_index

        # Create a graph data object
        graph = Data(
            x=x, xt_logits=x0, edge_index=x0_edge_idx,
            edge_attr=x0_edge_attr, triu_edge_index=all_connections, y=cond_y[[i]] if cond_y is not None else None
        )
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch

def compute_node_rate_matrix_marginal(xt_oh, x1_probs, t, S, noise, pmf, forbid_states=None):
    """
    Rate matrix of the node-type jump process under the marginal prior.

    Args:
        xt_oh: One-hot encoding of the current node types, one row per node.
        x1_probs: Predicted clean-state probabilities, same layout as xt_oh.
        t: Current time, broadcastable against the rows of xt_oh.
        S: Number of node types.
        noise: Noise level.
        pmf: Prior probability of each node type.
        forbid_states: Node types that must never be jumped to; defaults to [6, 7].
    Returns:
        A tensor shaped like xt_oh: off-diagonal entries are the jump rates, and the
        entry of the current state is minus their sum (each row sums to zero).
    """
    if forbid_states is None:
        forbid_states = [6, 7]

    mj = torch.tensor(pmf).repeat(len(x1_probs), 1).to(xt_oh.device)
    # Prior mass and predicted probability of the current state
    mxt = (xt_oh * mj).sum(dim=1)[:, None]
    pxt = (xt_oh * x1_probs).sum(dim=1)[:, None]

    pj_fact = (1 - mj + mxt) / (S * (1 - t) * mxt.clamp(min=1e-4)) + noise * (t + (1 - t) * mxt) / ((1 - t) * mj.clamp(min=1e-4))
    not_pj_not_pxt_fact = (mxt - mj).clamp(min=0) / (S * (1 - t) * mxt.clamp(min=1e-4))

    # Marginal rate matrix
    rm = pj_fact * x1_probs + noise * pxt + not_pj_not_pxt_fact * (1 - x1_probs - pxt)
    rm[:, forbid_states] = 0
    # Handle stationary transitions
    rm = rm * (1 - xt_oh) - xt_oh * (rm * (1 - xt_oh)).sum(dim=1)[:, None]

    return rm


def sample_nodes_marginal(cfg, xt, x1_probs_dict, t, dt, noise, pmin=0, classifier=None, y_guidance=None):
    """
    Sample the node types of the next step under the marginal prior.

    Args:
        cfg: Configuration (reads cfg.dataset, cfg.node_type_pmf and cfg.gt.guidance_strength).
        xt: Current batch; reads the node types xt.x[:, 0] and the per-node times xt.t_x.
        x1_probs_dict: Dict with 'x1_probs' and, for classifier-free guidance, 'x1_probs_uncond'.
        t: Current node time (scalar), only used to detect the last step.
        dt: Step size.
        noise: Noise level; forced to 0 on the last step.
        pmin: Minimal probability kept on the current state when the step is shortened.
        classifier: Optional classifier for classifier guidance.
        y_guidance: Conditioning targets for classifier guidance.
    Returns:
        A long tensor with one node type per node; nodes whose time index is 1 keep their type.
    """
    num_classes = cfg.dataset.nnode_types
    pmf = cfg.node_type_pmf
    guidance_strength = cfg.gt.guidance_strength
    forbid_states = [] if (cfg.dataset.get("use_pins", False) or cfg.dataset.name == "AnalogGenie") else [6, 7]

    # No extra noise on the last step
    if t >= 1.0 - dt:
        noise = 0

    # Only nodes whose time index is not 1 are resampled
    node_times = xt.t_x
    non_one_indices = node_times != 1
    active_times = node_times[non_one_indices, None]

    x = xt.x[:, 0].long()

    xt_oh = F.one_hot(x[non_one_indices], num_classes=num_classes)
    rm = compute_node_rate_matrix_marginal(xt_oh, x1_probs_dict['x1_probs'][non_one_indices], active_times,
                                           num_classes, noise, pmf, forbid_states)
    # Unconditional rate matrix in case of classifier-free guidance
    if 'x1_probs_uncond' in x1_probs_dict:
        rm_uncond = compute_node_rate_matrix_marginal(xt_oh, x1_probs_dict['x1_probs_uncond'][non_one_indices],
                                                      active_times, num_classes, noise, pmf, forbid_states)
        # Geometric interpolation of the two rate matrices, then rebuild the diagonal
        rm = (rm.clamp(min=0) ** guidance_strength) * (rm_uncond.clamp(min=1e-5) ** (1 - guidance_strength))
        rm = rm * (1 - xt_oh) - xt_oh * (rm * (1 - xt_oh)).sum(dim=1)[:, None]
    elif classifier is not None:
        rm = get_guided_rates(cfg, xt, xt_oh, classifier, rm, y_guidance=y_guidance, forbid_states=forbid_states)

    # Ensure all probs are positive and sum to one by adjusting dt if necessary
    dt_min = (1 - pmin) / (rm * (1 - xt_oh)).sum(dim=1)[:, None]
    adjusted_dt = torch.min(torch.cat([dt_min, dt_min.new_full(dt_min.shape, dt)], dim=1), dim=1)[0]

    probs = xt_oh + rm * adjusted_dt[:, None]
    categorical_dist = Categorical(probs=probs)
    samples = categorical_dist.sample((1,)).squeeze()

    x[non_one_indices] = samples

    return x


def compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs, t, noise):
    """
    Rate matrix of the two-state edge jump process under the marginal prior.

    mj holds the prior probability of each edge state, mxt the prior probability of the
    current state, edge_attr_oh the one-hot current state and e1_probs the predicted
    clean-state probabilities. In the result, the entry of the current state is minus
    the sum of the other entries of its row.
    """
    time_left = 1 - t
    not_current = 1 - edge_attr_oh

    # Predicted probability of the current state
    pxt = (edge_attr_oh * e1_probs).sum(dim=1)[:, None]

    pj_fact = (1 - mj + mxt) / (2 * time_left * mxt) + noise * (t + time_left * mxt) / (time_left * mj.clamp(min=1e-4))
    not_pj_not_pxt_fact = (mxt - mj).clamp(min=0) / (2 * time_left * mxt)

    # Marginal rate matrix
    rates = pj_fact * e1_probs + noise * pxt + not_pj_not_pxt_fact * (1 - e1_probs - pxt)

    # Handle stationary transitions
    off_diagonal = rates * not_current
    return off_diagonal - edge_attr_oh * off_diagonal.sum(dim=1)[:, None]


def sample_edges_marginal(cfg, xt, e1_probs_dict, t, dt, noise, pmin=0):
    """
    Sample the state of every candidate edge for the next step under the marginal prior.

    Returns the candidate edge indices together with their attributes; entries whose
    time index is 1 keep their current value, the others are resampled.
    """
    # No extra noise on the last step
    if t >= 1.0 - dt:
        noise = 0

    # Densify directed edges: existing (row < col) edges merged with all candidates
    forward_only = xt.edge_index[0] < xt.edge_index[1]
    candidate_padding = xt.edge_attr.new_zeros(xt.triu_edge_index.size(1))
    dense_idx, dense_attr = torch_sparse.coalesce(
        torch.cat([xt.edge_index[:, forward_only], xt.triu_edge_index], dim=1),
        torch.cat([xt.edge_attr[forward_only], candidate_padding], dim=0),
        xt.num_nodes, xt.num_nodes,
        op="max"
    )

    # Do not denoise edges whose time index is already 1
    edge_times = xt.t_e[:int(0.5 * len(xt.t_e))]
    active = edge_times != 1
    active_times = edge_times[active, None]
    active_attr = dense_attr[active]

    edge_attr_oh = F.one_hot(active_attr, num_classes=2)
    stays = 1 - edge_attr_oh
    mj = torch.tensor([1 - cfg.edge_ratio, cfg.edge_ratio]).repeat(len(active_attr), 1).to(cfg.device)
    mxt = (edge_attr_oh * mj).sum(dim=1)[:, None]

    rm = compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs_dict['e1_probs'][active],
                                           active_times, noise)
    # Unconditional rate matrix in case of classifier-free guidance
    if 'e1_probs_uncond' in e1_probs_dict.keys():
        strength = cfg.gt.guidance_strength
        rm_uncond = compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs_dict['e1_probs_uncond'][active],
                                                      active_times, noise)
        rm = (rm.clamp(min=0) ** strength) * (rm_uncond.clamp(min=1e-5) ** (1 - strength))
        rm = rm * stays - edge_attr_oh * (rm * stays).sum(dim=1)[:, None]

    # Ensure all probs are positive and sum to one by adjusting dt if necessary
    dt_min = (1 - pmin) / (rm * stays).sum(dim=1)[:, None]
    dt_candidates = torch.cat([dt_min, dt_min.new_full(dt_min.shape, dt)], dim=1)
    adjusted_dt = torch.min(dt_candidates, dim=1)[0]

    probs = edge_attr_oh + rm * adjusted_dt[:, None]
    samples = Categorical(probs=probs).sample((1,)).squeeze()

    # Edges at t=1 are unchanged, others are replaced by new ones sampled from the cat dist.
    dense_attr[active] = samples

    return dense_idx, dense_attr


def get_guided_rates(cfg, xt, xt_oh, classifier, rm, y_guidance, forbid_states):
    """
    Implements (exact) classifier guidance for discrete flows.

    For every node of every graph in `xt`, one candidate graph is built per allowed node
    type (the node's feature is overwritten with that type). The classifier scores all
    candidates as well as the unmodified batch; the ratio between both target
    probabilities re-weights the corresponding entries of the rate matrix.

    Args:
        cfg: Config; reads `dataset.nnode_types`, `posenc_RRWP.ksteps`, `device` and
            `gt.guidance_strength`.
        xt: Current batch (must expose `get_example`, `_slice_dict`, `t_x`, `t_f`, `t_e`).
        xt_oh: Mask of the current node states, with the same shape as `rm`
            (presumably a one-hot encoding — confirm against caller).
        classifier: Callable returning a tuple whose first element holds the logits,
            indexed here as [sample, spec, class].
        rm: Unguided node rate matrix, one column per node type.
        y_guidance: Target class indices, indexed as [graph, spec].
        forbid_states: Node types for which no candidate transition is evaluated.

    Returns:
        The guided rate matrix, same shape as `rm`. Entries masked by `xt_oh` hold the
        negated sum of the other entries of their row.
    """

    # Allowed target node types (everything except the forbidden states)
    node_types = torch.tensor([e for e in np.arange(cfg.dataset.nnode_types) if e not in forbid_states])

    # Classifier forward - all single-node transitions p(y|x_{t+dt},t)
    cls_input = []
    for i in range(len(xt)):
        g = xt.get_example(i).clone()

        # Per-graph slices of the node-level time indices
        t_x_i, t_x_o = xt._slice_dict['x'][i].item(), xt._slice_dict['x'][i + 1].item()
        t_x = xt.t_x[t_x_i: t_x_o]
        t_f = xt.t_f[t_x_i: t_x_o]
        # t_e is sliced with twice the triu_edge_index offsets
        # (presumably two entries per upper-triangular edge — TODO confirm)
        t_e_i, t_e_o = xt._slice_dict['triu_edge_index'][i].item(), xt._slice_dict['triu_edge_index'][i + 1].item()
        t_e = xt.t_e[2 * t_e_i: 2 * t_e_o]
        g = add_full_rrwp(g.clone(), walk_length=cfg.posenc_RRWP.ksteps)

        num_nodes = g.num_nodes
        
        for node_idx in range(num_nodes):
            # Stack one copy of x per allowed type, then overwrite row `node_idx` of each
            # copy with the corresponding type id (broadcast over all columns of x)
            rep = g.x.repeat(len(node_types), 1)
            modif_idx = node_idx + np.arange(len(node_types)) * num_nodes
            rep[modif_idx] = node_types[:, None].float().to(cfg.device)

            for ntype_idx, _ in enumerate(node_types):
                new_g = g.clone()
                new_g.x = rep[ntype_idx * num_nodes: (ntype_idx + 1) * num_nodes, :]
                new_g.t_x = t_x
                new_g.t_f = t_f
                new_g.t_e = t_e
                for attr in ['rrwp_index', 'rrwp_val', 'rrwp', 'log_deg', 'deg']:
                    new_g[attr] = g[attr]
                # Remember which graph of the batch this candidate was derived from
                new_g['graph_idx'] = torch.tensor(i).to(cfg.device)
                cls_input.append(new_g)
        
    # Candidates are ordered by graph, then node, then node type
    cls_input = Batch.from_data_list(cls_input)

    # Broadcast y_guidance over graphs
    all_spec = y_guidance[cls_input.graph_idx.cpu()]

    with torch.no_grad():
        pred, _ = classifier(cls_input.clone())
    pred_probs = F.softmax(pred, dim=-1)
    # For each spec, pick the probability of its target class, then average over specs
    pred_probs = torch.stack([pred_probs[range(len(pred)), i, all_spec[range(len(pred)), i]] for i in range(all_spec.shape[1])], dim=1)
    pred_probs = pred_probs.mean(dim=-1, keepdim=True)

    # Classifier forward - y_guidance probabilities if remaining in the same state p(y|x_t,t)
    with torch.no_grad():
        pred_xt, _ = classifier(xt.clone())
    pred_xt_probs = F.softmax(pred_xt, dim=-1)
    pred_xt_probs = torch.stack([pred_xt_probs[range(len(pred_xt)), i, y_guidance[range(len(pred_xt)), i]] for i in range(y_guidance.shape[1])], dim=1)
    # pred_xt_probs = pred_xt_probs[range(len(pred_xt)), 0, y_guidance[range(len(pred_xt)), 0]]
    pred_xt_probs = pred_xt_probs.mean(dim=-1, keepdim=True)

    # Compute the likelihood ratio
    prob_ratio = pred_probs / pred_xt_probs[cls_input.graph_idx.cpu()]

    # Add missing transitions (forbidden states): their guidance factor stays at 1.
    # The view assumes the rows of rm follow the candidate order (graph, then node) — TODO confirm
    rm_guidance = rm.new_ones(rm.shape)
    rm_guidance[:, node_types] = prob_ratio.view(-1, len(node_types))

    # Finally, bias the rate matrix towards higher posterior y probabilities
    guided_rm = (rm_guidance ** cfg.gt.guidance_strength) * rm
    # Keep the entries not masked by xt_oh and set the masked entry of each row to minus their sum
    guided_rm = guided_rm * (1 - xt_oh) - xt_oh * (guided_rm * (1 - xt_oh)).sum(dim=1)[:, None]
    return guided_rm


def compute_valids(batch, cfg, train_loader=None):
    """
    Computes validity / diversity statistics for a batch of generated graphs.

    Each graph of the batch is converted to an igraph graph, simulated, and checked with
    `is_graph_valid` and `our_is_valid_circuit`. The outcome of both checks is also stored
    on the igraph object (`valid_graph`, `valid_circuit`) before it is handed to the
    uniqueness / novelty / VUN metrics.

    Args:
        batch: Batch of generated graphs (must expose `num_graphs` and `get_example`).
        cfg: Config; reads `gt.conditional_gen`, `gt.conditioning_loss`, `gt.supernode`,
            `gt.sizing`, `dataset.use_pins`, and, when accuracy is reported,
            `gnn.n_bins` and `round`.
        train_loader: Optional training loader. When truthy, 'novelty' and 'VUN' are
            reported as well.

    Returns:
        dict with 'valid_circuits', 'valid_graphs' and 'valid_sim' (fractions of the
        batch), 'uniqueness', optionally 'novelty' and 'VUN', and 'accuracy' when
        generation is conditional with the 'cfg' conditioning loss.
    """

    n_graphs = batch.num_graphs
    valid_circuits, valid_graphs, valid_sim = 0, 0, 0
    generated_graphs, simulation_out, valid_specs = [], [], []
    # Deliberately not called `compute_accuracy`: that name is a function of this module
    report_accuracy = cfg.gt.conditional_gen and (cfg.gt.conditioning_loss == 'cfg')

    for i in range(n_graphs):
        graph = batch.get_example(i).to('cpu').clone()
        i_graph = torch_geometric_to_igraph(graph)
        # i_graph = gym_to_igraph(graph)
        # If supernodes: drop the last vertex before evaluating the circuit
        if cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False):
            i_graph.delete_vertices(len(i_graph.vs['type']) - 1)
        # Without sizing, draw random integer features in [1, 100] for every vertex
        if cfg.gt.get('conditional_gen', False) and (not cfg.gt.get('sizing', False)):
            i_graph.vs['feat'] = (np.random.rand(len(i_graph.vs['type'])) * 100 + 1).astype(int).tolist()

        # Check if the graph can be simulated
        val_sim, sim_out = is_valid_sim(i_graph, with_pins=cfg.dataset.get("use_pins", False))
        valid_sim += val_sim
        if sim_out[0] is not None:
            simulation_out.append(sim_out)
            if graph.get('y', None) is not None:
                valid_specs.append(graph.y.cpu().numpy())

        # Tag the graph with both validity flags; downstream metrics receive the tagged graph
        i_graph.valid_graph = bool(is_graph_valid(i_graph))
        if i_graph.valid_graph:
            valid_graphs += 1

        i_graph.valid_circuit = bool(our_is_valid_circuit(i_graph))
        if i_graph.valid_circuit:
            valid_circuits += 1

        generated_graphs.append(i_graph)

    out_dict = dict()
    if report_accuracy:
        out_dict.update({'accuracy': 0.0})
        if len(simulation_out) > 1:
            simulation_out_bins = simul_outputs_to_bin_idx(np.array(simulation_out), nbins=cfg.gnn.n_bins)
            # Mean over specs of the agreement between simulated bins and requested specs
            accuracy = np.mean(np.array([accuracy_score(simulation_out_bins[:, i].cpu().numpy(), np.array(valid_specs)[:, 0, i]) \
                                         for i in range(simulation_out_bins.shape[1])]))
            out_dict.update({'accuracy': round(float(accuracy), cfg.round)})

    # Use the batch size itself rather than the leaked loop index, which is undefined
    # for an empty batch; an empty batch yields ratios of 0.
    denom = max(n_graphs, 1)
    uniqueness = unique_ratio(generated_graphs)
    out_dict.update({'valid_circuits': valid_circuits / denom, 'valid_graphs': valid_graphs / denom,
                     'valid_sim': valid_sim / denom, 'uniqueness': uniqueness})
    if train_loader:
        novelty = novelty_ratio(generated_graphs, train_loader)
        VUN = compute_VUN(generated_graphs)
        out_dict.update({'novelty': novelty, 'VUN': VUN})

    return out_dict


def is_valid_sim(graph, with_pins):
    """
    Tries to simulate a copy of `graph` and reports whether the simulation succeeded.

    Args:
        graph: Graph to simulate (must expose `copy()`); the input itself is not passed
            to the simulator.
        with_pins: Forwarded to `simulation`.

    Returns:
        `(1, (gain / 100, ugw / 1e9, pm / 90))`, each value rounded to 3 decimals, when the
        simulation runs; `(0, (None, None, None))` when it raises.
    """

    out_graph = graph.copy()
    try:
        # Sanity check
        sim = simulation(to_dag(out_graph), with_pins=with_pins)
        # sim = simulation(out_graph, features=features, with_pins=with_pins)
        return 1, (np.round(float(sim.gain[0] / 100), 3), np.round(float(sim.ugw / 1e9), 3), np.round(float(sim.pm / 90), 3))
    except Exception:
        # Best effort: any simulation error marks the graph as non-simulable. Only
        # `Exception` is caught so that KeyboardInterrupt / SystemExit still propagate.
        return 0, (None, None, None)


def get_simulation_outputs(graph):
    """
    Simulates a copy of `graph` and returns its scaled outputs.

    Returns:
        `(gain / 100, ugw / 1e9, pm / 90)`, each rounded to 3 decimals, or
        `(0.0, 0.0, 0.0)` when copying, converting or simulating the graph raises.
    """
    try:
        sim = simulation(to_dag(graph.copy()), 'default')
        return (np.round(float(sim.gain[0] / 100), 3), np.round(float(sim.ugw / 1e9), 3), np.round(float(sim.pm / 90), 3))
    except Exception:
        # Best effort fallback; KeyboardInterrupt / SystemExit are no longer swallowed
        return (0.0, 0.0, 0.0)


def compute_accuracy(batch, classifier, cfg):
    """
    Scores the classifier's spec predictions on `batch`.

    Note that `batch` is modified in place: `x` is replaced by the float one-hot encoding
    of its first column and `edge_attr` gets a trailing dimension and is cast to float.

    Returns:
        {'accuracy': value}, the per-spec accuracy averaged over all specs and rounded
        to `cfg.round` digits.
    """

    # One-hot node types and column-shaped float edge attributes
    batch.x = F.one_hot(batch.x[:, 0], num_classes=cfg.dataset.nnode_types).float()
    batch.edge_attr = batch.edge_attr[:, None].float()

    # Positional encodings are added on a copy of the batch
    enriched = add_full_rrwp(batch.clone(), walk_length=cfg.posenc_RRWP.ksteps)

    # The classifier returns its predictions together with the ground truth
    logits, targets = classifier(enriched.clone())
    predicted = logits.argmax(dim=-1)

    per_spec = [
        accuracy_score(predicted[:, col].cpu().numpy(), targets[:, col].cpu().numpy())
        for col in range(predicted.shape[1])
    ]
    mean_accuracy = np.mean(np.array(per_spec))

    return {'accuracy': round(float(mean_accuracy), cfg.round)}


# def eval_simulation(batch, y): # Do we still need this?
#     for i in range(batch.num_graphs):  # Number of graphs in the batch
#         graph = batch.get_example(i).to('cpu')
#         i_graph = torch_geometric_to_igraph(graph)
#         sim = simulation(i_graph, y[i].numpy())
#         gain = sim.gain[0]
        

def eval_inference(model, num_samples, euler_steps, y_guidance=None, n_nodes=None, xt=None,
                   current_t_x=0, current_t_e=0):
    """
    Samples graphs with `inference` (no noise, fixed discretization powers) in eval mode
    and returns the statistics computed by `compute_valids` on the result.
    """
    model.eval()

    sampling_kwargs = dict(
        num_samples=num_samples, euler_steps=euler_steps,
        noise_e=0, noise_x=0,
        n_pow_e=6, n_pow_x=4, n_pow_f=4,
        y_guidance=y_guidance, n_nodes=n_nodes,
        xt=xt, current_t_x=current_t_x, current_t_e=current_t_e,
    )
    with torch.no_grad():
        # NB: the uniqueness metric depends on num_samples
        generated = inference(model, **sampling_kwargs)

    return compute_valids(generated, cfg=model.cfg)


def eval_inference_sizing(model, batch, euler_steps=20, n_pow=1):
    """
    Denoises `batch` with `inference` in eval mode and evaluates the result.

    Args:
        model: Trained model; `model.cfg` is forwarded to `compute_valids`.
        batch: Starting state, passed to `inference` as `xt`.
        euler_steps: Number of Euler steps.
        n_pow: Time-discretization power, applied to edges, nodes and features alike.

    Returns:
        The dict returned by `compute_valids` for the denoised batch.
    """
    model.eval()
    with torch.no_grad():
        # `inference` has no `n_pow` argument: it takes one power per modality
        # (n_pow_e / n_pow_x / n_pow_f), so the single value is forwarded to all three.
        denoised_batch = inference(model, euler_steps=euler_steps, n_pow_e=n_pow, n_pow_x=n_pow,
                                   n_pow_f=n_pow, xt=batch)
    sim_out = compute_valids(denoised_batch, cfg=model.cfg)
    return sim_out


def preprocess_batch_y(cfg, batch):
    """
    Encodes y features of a batch of circuits using RBF functions. The centroids of the RBFs are retrieved from the config.

    Args:
        cfg: Config; reads `y_std` (per-feature scale), `kmeans_centroids` (one sequence of
            `gt.n_rbf_centroids` centroids per feature) and `gt.conditional_dim` (feature
            index selected for the conditioning).
        batch: Object holding `y`, indexed as [sample, feature].

    Returns:
        The same `batch` object, with `c_init` set to the normalised kernel responses of
        the selected feature (float tensor).
    """

    # Create every helper tensor on the device of batch.y so that non-CPU batches work too
    device = batch.y.device
    features = batch.y / torch.tensor(cfg.y_std, device=device)
    n_feats = features.shape[-1]

    # Centroids as (n_feats, 1, n_centroids); the view fails if the config is inconsistent
    centroids = torch.cat(
        [torch.as_tensor(feat_centroids, device=device).flatten() for feat_centroids in cfg.kmeans_centroids]
    ).view(n_feats, 1, cfg.gt.n_rbf_centroids)

    # distances[f, b, k] = |features[b, f] - centroids[f, k]|, obtained by broadcasting
    distances = torch.abs(features.t().unsqueeze(-1) - centroids)

    # Kernels with per-feature temperature coefficients: exp(-0.5 * distance / temp).
    # The temperatures are hard-coded for three features.
    temp = torch.tensor([0.5, 1, 0.1], device=device)
    unnorm_rbf = torch.exp(-distances * 0.5 / temp[:, None, None])
    rbf = unnorm_rbf / unnorm_rbf.sum(dim=-1, keepdim=True)

    batch.c_init = rbf[[cfg.gt.conditional_dim]].transpose(0, 1).float()

    return batch


def prune(cfg, batch):
    """
    Removes nodes from every graph of `batch` and rebuilds the batch.

    A node is dropped when its column of the adjacency matrix sums to zero or, if
    `cfg.gt.node_pruning == 2`, when its type is the last one
    (`cfg.dataset.nnode_types - 1`). If that would leave a graph empty, its first node is
    kept. Remaining edges are re-indexed.

    NOTE(review): the adjacency is built without edge values, so the `edge_attr` written
    back is whatever `coo()` returns as values for such a matrix (presumably None) —
    confirm this is intended.
    """

    pruned_graphs = []

    for graph_id in range(batch.num_graphs):
        data = batch.get_example(graph_id)
        n_nodes = data.num_nodes

        # Nodes with at least one incident edge
        adjacency = SparseTensor.from_edge_index(data.edge_index, None, sparse_sizes=(n_nodes, n_nodes))
        mask = adjacency.sum(dim=0) > 0

        # Optionally discard the dedicated "prunable" node type as well
        if cfg.gt.node_pruning == 2:
            mask = mask & (data.x[:, 0] != cfg.dataset.nnode_types - 1)

        # Never return an empty graph
        if not mask.any():
            mask[0] = True

        # Restrict the adjacency to the kept nodes, which also re-indexes them
        src, dst, values = adjacency[mask, mask].t().coo()

        data.x = data.x[mask]
        data.edge_index = torch.stack([src, dst], dim=0)
        data.edge_attr = values
        pruned_graphs.append(data)

    return Batch.from_data_list(pruned_graphs)


def prune_end_nodes_batch(batch):
    """
    Applies `prune_end_nodes` to a clone of every graph of `batch` and returns the
    re-assembled batch.
    """
    pruned_graphs = [
        prune_end_nodes(batch.get_example(graph_id).clone())
        for graph_id in range(batch.num_graphs)
    ]
    return Batch.from_data_list(pruned_graphs)


def prune_end_nodes(graph):
    """
    Drops the nodes of `graph` whose adjacency column sums to at most 1, except nodes of
    type 8 or 9 (In / Out nodes), and re-indexes the remaining edges.

    The graph is modified in place and returned. If no node qualifies, the first node is
    kept so that the graph never becomes empty.
    """

    n_nodes = graph.num_nodes
    adjacency = SparseTensor.from_edge_index(graph.edge_index, None, sparse_sizes=(n_nodes, n_nodes))

    # Nodes with more than one neighbor
    well_connected = adjacency.sum(dim=0) > 1

    # In / Out nodes are always kept
    node_type = graph.x[:, 0]
    mask = well_connected | (node_type == 8) | (node_type == 9)

    # Never return an empty graph
    if not mask.any():
        mask[0] = True

    # Restrict the adjacency to the kept nodes, which also re-indexes them
    src, dst, values = adjacency[mask, mask].t().coo()

    graph.x = graph.x[mask]
    graph.edge_index = torch.stack([src, dst], dim=0)
    graph.edge_attr = values

    return graph


def preprocess_pins_topo_model_out(g):
    '''
    Reproduces the steps of preprocess_pins and preprocess_edges from the master loader. In particular, (1) map node types to the (larger) pin-level
    node type dictionary, (2) add edges between parent nodes (e.g. central N/PMOS) and pin nodes, (3) add placeholder edges between pin nodes
    and parent nodes' neighboring nodes, (4) suppress parent-neighbors edges.

    Args:
        g: igraph graph whose vertices carry a 'type' attribute (ids of ID_TO_NAME_NODES).

    Returns:
        data: torch_geometric `Data` with `x` (pin-level type ids, pin nodes appended after the original nodes), `edge_index` / `edge_attr`
            (every edge stored in both directions), `pin_neighbors` (for each parent node, `[pin indices, neighbor indices]`) and
            `triu_edge_index` (all strictly upper-triangular node pairs).
        infer_pin_pred: False when no node of `g` has a type listed in `node2pins`, True otherwise.
    '''

    # Pin prediction is only needed if at least one node has pins
    infer_pin_pred = True
    if not np.array([ID_TO_NAME_NODES[t] in node2pins.keys() for t in g.vs['type']]).any():
        infer_pin_pred = False 

    # Start from the edges of g, all with attribute 1
    edge_index = torch.tensor(g.get_edgelist(), dtype=torch.long).t().contiguous()
    edge_attr = torch.ones((edge_index.size(1),))
    
    learnable_edge_index = []
    x_to_append = []
    num_nodes = len(g.vs['type'])  # grows as pin nodes are appended
    parents_to_neighbors = []
    
    
    for parent_node, t in enumerate(g.vs['type']):
        
        parent_node_typename = ID_TO_NAME_NODES[t]
        
        if parent_node_typename in node2pins.keys():
            
            # n_neighbors = len(node2pins[parent_node_typename])
            # One new node per pin of the parent, indexed after all nodes created so far
            pin_types = torch.tensor([NAME_TO_ID_PINS[tname] for tname in node2pins[parent_node_typename]])
            pin_indices = torch.from_numpy(np.arange(num_nodes, num_nodes + len(pin_types)))
            num_nodes += len(pin_indices)
    
            # Append to graph.x
            x_to_append.append(pin_types[:, None])
    
            # Add edges from parent node to pins and from pins to neighbors
            parent_to_pins = torch.stack([torch.full((len(pin_indices),), fill_value=parent_node), pin_indices])
            unique_neighbors = torch.tensor(np.unique(g.neighbors(parent_node)))
    
            # Connect pins to all neighbors in the prior --> can be challenged
            # (row 0 holds the neighbor, row 1 the pin: every neighbor is paired with every pin)
            pins_to_all_neighbors = torch.stack([unique_neighbors.repeat_interleave(len(pin_indices)), pin_indices.repeat(len(unique_neighbors))])
            new_edges = torch.cat([parent_to_pins, pins_to_all_neighbors], dim=1)
            
            # Edges between parent node and neighbors will be suppressed
            parents_to_neighbors.append(torch.stack([torch.tensor(parent_node).repeat(len(unique_neighbors)), 
                                                     unique_neighbors]))
    
            edge_index = torch.cat([edge_index, new_edges], dim=1)
            edge_attr = torch.cat([edge_attr, edge_attr.new_ones(new_edges.size(1))])
    
            ## Add learnable edges argument to the graph
            # learnable_edge_index.append(pins_to_all_neighbors)
            learnable_edge_index.append([pin_indices.tolist(), unique_neighbors.tolist()])
    
    # Then, out of the for loop
    # if len(learnable_edge_index) > 0:
    #     learnable_edge_index = torch.cat(learnable_edge_index, dim=1)
    # else:
    #     learnable_edge_index = torch.tensor([])
    
    # Suppress edges between parent node and neighbors: the pairs are re-added in both
    # directions with value 0 and merged with op="min", so they end up with attribute 0
    if len(parents_to_neighbors) > 0:
        parents_to_neighbors = torch.cat(parents_to_neighbors, dim=1)
        nnode_coalesce = (edge_index.max() + 1).item()
        new_idx, new_attr = torch_sparse.coalesce(
            torch.cat([edge_index, torch.cat([parents_to_neighbors, parents_to_neighbors.flip(dims=[0])], dim=1)], dim=1),
            torch.cat([edge_attr, torch.zeros(2 * parents_to_neighbors.size(1))], dim=0),
            nnode_coalesce, nnode_coalesce,
            op="min"
        )
        # Keep the surviving edges whose source index is smaller than the target index.
        # NOTE(review): this filter (and the coalescing) only runs when at least one parent
        # node exists; edges listed as (larger, smaller) without their reverse are dropped — confirm intended.
        keep_indices = (new_attr > 0) & (new_idx[0, :] < new_idx[1, :])
        edge_index = new_idx[:, keep_indices]
        edge_attr = new_attr[keep_indices]

    # Flip edges: store every edge in both directions
    edge_index = torch.cat([edge_index, torch.flip(edge_index, dims=[0])], dim=1)
    edge_attr = torch.cat([edge_attr, edge_attr], dim=0)

    # Add triu edge index (all node pairs above the diagonal, pin nodes included)
    row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
    all_connections = torch.stack((row, col), dim=0)

    # # Finally add triu_learnable_edge_attr
    # if len(learnable_edge_index) > 0:
    #     adj_triu_size, n_learnable_edges = all_connections.size(1), learnable_edge_index.size(1)
    #     triu_edge_index, triu_learnable_edge_attr = torch_sparse.coalesce(
    #         torch.cat([all_connections, learnable_edge_index], dim=1), 
    #         torch.cat([torch.zeros(adj_triu_size), torch.ones(n_learnable_edges)]), num_nodes, num_nodes,
    #         op="max"
    #     )
    # else:
    #     triu_learnable_edge_attr = torch.tensor([])
    
    # Update x: original nodes re-encoded with the pin-level ids, followed by the pin nodes
    x = torch.tensor([NAME_TO_ID_PINS[ID_TO_NAME_NODES[t]] for t in g.vs['type']])[:, None]
    if len(x_to_append) > 0:
        x = torch.cat([x, torch.cat(x_to_append, dim=0)], dim=0)
    
            
    # data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr.long(), learnable_edge_index=learnable_edge_index, triu_edge_index=all_connections,
    #            triu_learnable_edge_attr=triu_learnable_edge_attr)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr.long(), pin_neighbors=learnable_edge_index, triu_edge_index=all_connections)

    return data, infer_pin_pred


def hungarian_assignment(edge_proba):
    '''
    Takes as input a matrix of shape n_pins x n_neighbors that gives the assignment probabilities of a central node's pins to its neighbors.
    The assignment problem supposes to link each pin to exactly one neighbor, and each neighbor to at least one pin.

    Every neighbor column is duplicated so that the cost matrix becomes square; all ways of distributing the `n_pins - n_neighbors`
    extra copies among the neighbors are tried, and the assignment with the highest total probability is returned.

    Args:
        edge_proba: 2-D tensor of shape (n_pins, n_neighbors).

    Returns:
        A list of tuples (pin_id, neighbor_id), one per pin. `pin_id` is the row index returned by scipy, `neighbor_id` a 0-dim long tensor.

    Raises:
        AssertionError: if there are more neighbors than pins.
        ValueError: if there are pins but no neighbor at all.
    '''
    from itertools import combinations_with_replacement

    N, M = edge_proba.shape
    max_rep = N - M

    # Raised explicitly (rather than with `assert`) so that the check survives `python -O`
    if max_rep < 0:
        raise AssertionError('Too many neighboring nodes.')
    if M == 0 and N > 0:
        raise ValueError('Cannot assign pins without any neighboring node.')

    labels = torch.arange(M)
    best_cost, best_assignment = None, None

    # Apply constraints: 1 neighbor per pin, >= 1 pin per neighbor. Each multiset of `max_rep` neighbors gives one
    # distinct vector of column repetitions; enumerating multisets directly avoids building M**max_rep duplicates.
    for extra in combinations_with_replacement(range(M), max_rep):
        r = torch.ones(M, dtype=torch.long)
        for neighbor in extra:
            r[neighbor] += 1

        # Column j of the expanded matrix stands for neighbor r_labels[j]
        r_labels = labels.repeat_interleave(r)
        cost_matrix = -edge_proba.repeat_interleave(r.to(edge_proba.device), dim=1)
        cost_np = cost_matrix.detach().cpu().numpy()
        row_ind, col_ind = linear_sum_assignment(cost_np)
        cost = cost_np[row_ind, col_ind].sum()

        # Strict comparison: on ties the first repetition pattern found is kept
        if best_cost is None or cost < best_cost:
            best_cost = cost
            best_assignment = [(i, r_labels[o]) for (i, o) in zip(row_ind, col_ind)]

    return best_assignment


def analyze_graph(graph):
    """
    Tells whether `graph`, taken as undirected, passes all validity checks: it is
    connected, contains a node of type 0 (VSS) and respects the pin constraints of
    `check_pin_connections`. All three checks are always evaluated.
    """

    undirected = graph.as_undirected()

    checks = [
        # Connectivity
        undirected.is_connected(),
        # VSS presence (type id 0)
        0 in undirected.vs["type"],
        # Pin connections
        check_pin_connections(undirected),
    ]

    # Isolated nodes are not checked here (see check_isolated_nodes)
    return all(checks)


def check_pin_connections(graph):
    """
    Return True when no node has more neighbors than its type has pins.

    C, R and L nodes have 2 pins; other pin counts come from `node2pins`, which takes
    precedence. Node types without a known pin count are not constrained.
    """

    pin_budget = {"C": 2, "R": 2, "L": 2, **{name: len(pins) for name, pins in node2pins.items()}}

    # Resolve all type names up front
    type_names = [ID_TO_NAME_NODES[type_id] for type_id in graph.vs['type']]

    # Stops at the first node that exceeds its pin budget
    return not any(
        len(graph.neighbors(node_id)) > pin_budget[name]
        for node_id, name in enumerate(type_names)
        if name in pin_budget
    )

def check_isolated_nodes(graph):
    """Return True when at least one node of the graph has degree 0."""
    for node_degree in graph.degree():
        if node_degree == 0:
            return True
    return False