
import torch, torch_sparse
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch.distributions.categorical import Categorical
from torch_sparse import SparseTensor
from .pyspice_utils import simulation
from .utils import reset_slice_dict_edges, torch_geometric_to_igraph, add_full_rrwp
# from torch_geometric.graphgym.config import cfg
from .metric_ocb import  compute_VUN, is_graph_valid, novelty_ratio, our_is_valid_circuit, unique_ratio
from .run_model_eval import remove_single_connection_nodes, to_dag
from .loader.datasets.analogenie_dataset import NAME_TO_ID_NODES


# Node-type ids corresponding to In, Out and net in the OCB encoding.
# TODO: update for AnalogGenie (at least add the new net types).
NET_TYPES_ID = [8, 9, 10]
# Presumably the pin node-type ids of the OCB encoding (they are the states forbidden
# during node sampling when pins are not used). TODO: update for AnalogGenie.
PIN_TYPES_ID = [6, 7]

# AnalogGenie net node ids, resolved by name through NAME_TO_ID_NODES.
NET_TYPES_ID_GENIE = [NAME_TO_ID_NODES[n] for n in ["VDD","VSS","VOUT","VB","net","IB","VCONT","IOUT","IIN","VCM","VREF","IREF","VCLK","VRF","VLO","VIF","VBB" ]]

# Per-dataset [mean, std] of the number of nodes, used to draw graph sizes when sampling from the prior.
MEAN_NODE_NUMBER = {'ocb_CktBench101': [12, 2.5], "AnaloGenie_pins" : [106,54], "AnalogGenie": [37,17]}


def _euler_step_sizes(ut, n_pow, t_start):
    '''Return the Euler step sizes of a power-law time grid going from ``t_start`` to 1.'''
    remaining_time = 1.0 - t_start
    # Power-law warping of the uniform grid, rescaled to the remaining time interval
    t_list = t_start + (1 - ((1 - ut) ** n_pow)) * remaining_time
    return t_list.diff()


def inference(model, euler_steps, num_samples=100, noise_e=0.1, noise_x=0.1, n_pow_e=8, n_pow_x=8, n_pow_f=1, current_t_x=0, current_t_e=0, 
              current_t_f=0, xt=None, cond_y=None, n_nodes=None):
    '''
    Performs model inference by Euler integration of the denoising process.

    Node types, edges and node features each follow their own time variable
    (t_x, t_e, t_f) and their own time discretization.

    Args:
        model: The trained model; its configuration is read from ``model.cfg``.
        euler_steps: Number of points of the time grid (``euler_steps - 1`` Euler updates at most).
        num_samples: Number of graphs to generate; overridden by ``xt.num_graphs`` when ``xt`` is given.
        noise_e, noise_x: Noise levels forwarded to the edge / node samplers.
        n_pow_e, n_pow_x, n_pow_f: Powers of the time discretization for edges, nodes and features.
        current_t_x, current_t_e, current_t_f: Starting times (each should be <= 1.0).
        xt: Current state of the batch (torch_geometric.data.Batch, None if denoising from the prior distribution).
        cond_y: Optional conditioning data; when given, an additional unconditional forward pass
            is run at every step for classifier-free guidance.
        n_nodes: Optional number of nodes, forwarded to the prior sampler.

    Returns:
        The denoised batch. If no time variable has any time left, ``xt`` is returned unchanged.

    Raises:
        ValueError: If ``euler_steps`` is smaller than 2 (the time grid would contain no step).
    '''

    # Conditioning signal goes into y_test
    c_bool = cond_y is not None
    cfg = model.cfg
    # Assert initial x_t is provided if needed
    if current_t_x > 0 or current_t_e > 0 or current_t_f > 0:
        assert xt is not None, 'A initial circuit must be provided if starting from t > 0.'
    if xt is not None:
        num_samples = xt.num_graphs

    # Init steps
    steps = 0
    t_x, t_e, t_f = current_t_x, current_t_e, current_t_f

    remaining_time_x, remaining_time_e, remaining_time_f = 1.0 - current_t_x, 1.0 - current_t_e, 1.0 - current_t_f
    if max(remaining_time_x, remaining_time_e, remaining_time_f) <= 0:
        # Nothing left to denoise: hand the input back instead of None
        return xt

    if euler_steps < 2:
        raise ValueError(f'euler_steps must be at least 2, got {euler_steps}.')

    # Step sizes for the remaining interval; the three schedules share the same length
    ut = torch.linspace(0, 1, euler_steps)
    dts_x = _euler_step_sizes(ut, n_pow_x, current_t_x)
    dts_e = _euler_step_sizes(ut, n_pow_e, current_t_e)
    dts_f = _euler_step_sizes(ut, n_pow_f, current_t_f)
    num_steps = len(dts_x)

    if (current_t_x == 0) or (current_t_e == 0) or (current_t_f == 0):
        xt = draw_from_prior(cfg, num_samples=num_samples, cond_y=cond_y, t_x=t_x, t_e=t_e, t_f=t_f, xt=xt, n_nodes=n_nodes)
        if c_bool:
            xt = preprocess_batch_y(cfg, xt)
        xt = xt.to(cfg.device)

    # The `steps < num_steps` guard stops the loop once the schedule is exhausted, even if
    # accumulated floating-point steps left a time variable marginally below 1.
    while steps < num_steps and ((t_x < 1.0) | (t_e < 1.0) | (t_f < 1.0)):

        last_step = steps == num_steps - 1
        dt_x, dt_e, dt_f = dts_x[steps], dts_e[steps], dts_f[steps] # will be 0 if t_* has reached 1

        xt = add_full_rrwp(xt.clone(), walk_length=cfg.posenc_RRWP.ksteps)
        # Per-node / per-edge time stamps
        xt.t_x = torch.full((xt.num_nodes,), t_x).to(cfg.device)
        xt.t_f = torch.full((xt.num_nodes,), t_f).to(cfg.device)
        xt.t_e = torch.full((xt.triu_edge_index.shape[1],), t_e).to(cfg.device)

        # Entries flagged as non-learnable (mask value 0) get their time set to 1
        if hasattr(xt, 'learnable_x'):
            xt.t_x = xt.t_x * xt.learnable_x + xt.t_x.new_ones(len(xt.t_x)) * (1 - xt.learnable_x)
        if hasattr(xt, 'triu_learnable_edge_attr'):
            xt.t_e = xt.t_e * xt.triu_learnable_edge_attr + xt.t_e.new_ones(len(xt.t_e)) * (1 - xt.triu_learnable_edge_attr)
        xt.t_e = xt.t_e.repeat(2)

        # Model forward
        with torch.no_grad():
            logits = model(xt.clone()) # (B, D, S)
            logits_dict = {'logits': logits}

        # Unconditional forward for classifier-free guidance
        if c_bool:
            with torch.no_grad():
                logits_uncond = model(xt.clone(), unconditional_prop=1)
                logits_dict.update({'logits_uncond': logits_uncond})

        # Denoise given current state and model output
        xt = update_xt(cfg, xt=xt, logits_dict=logits_dict, t_x=t_x, t_e=t_e, t_f=t_f, dt_x=dt_x,
                       dt_e=dt_e, dt_f=dt_f, noise_nodes=noise_x, noise_edges=noise_e)

        # Node types are final once t_x reaches 1 (or the schedule ends): store them as integers
        if last_step or t_x + dt_x >= 1.0:
            xt.x = xt.x.long()

        # Advance t only where necessary
        if remaining_time_x > 0:
            t_x = min(1, t_x + dt_x)
        if remaining_time_e > 0:
            t_e = min(1, t_e + dt_e)
        if remaining_time_f > 0:
            t_f = min(1, t_f + dt_f)

        steps += 1

    return xt


def update_xt(cfg, logits_dict, **kwargs):
    """
    Converts model outputs into predicted clean-state distributions and dispatches to the
    framework-specific update rule.

    Args:
        cfg: Configuration (reads cfg.dataset, cfg.gt, cfg.framework.type and cfg.train.prior).
        logits_dict: Dict with key 'logits' (conditional model output) and optionally
            'logits_uncond' (unconditional output, used for classifier-free guidance).
        **kwargs: Forwarded to the selected update function (xt, times, step sizes, noise levels).

    Returns:
        The updated batch returned by update_xt_vfm or update_xt_marginal.

    Raises:
        NotImplementedError: If neither the 'vfm' framework nor the 'marginal' prior is configured.
    """
    cond_logits = logits_dict['logits']
    use_features = (cfg.dataset.node_features_dim > 1) or cfg.gt.get("sizing", False)
    is_vfm = cfg.framework.type == 'vfm'

    x1_probs = F.softmax(cond_logits.x, dim=-1) # (B, D, S)
    e1_probs = F.softmax(cond_logits.edge_attr, dim=-1)
    x1_probs_dict, e1_probs_dict = {'x1_probs': x1_probs}, {'e1_probs': e1_probs}

    # Mb device sizing
    x1_features = None
    if use_features:
        x1_features = cond_logits.x_features.clip(min=1e-1, max=cfg.dataset.nnode_features - 1)

    # Conditional gen
    if 'logits_uncond' in logits_dict:
        uncond_logits = logits_dict['logits_uncond']
        x1_probs_uncond = F.softmax(uncond_logits.x, dim=-1) # (B, D, S)
        e1_probs_uncond = F.softmax(uncond_logits.edge_attr, dim=-1)
        guidance = cfg.gt.guidance_strength
        # Mb device sizing
        if use_features:
            x1_features_uncond = uncond_logits.x_features.clip(min=0.0, max=cfg.dataset.nnode_features - 1)
            x1_features = guidance * x1_features + (1 - guidance) * x1_features_uncond
        if is_vfm:
            # VFM combines conditional and unconditional probabilities directly
            x1_probs = guidance * x1_probs + (1 - guidance) * x1_probs_uncond
            e1_probs = guidance * e1_probs + (1 - guidance) * e1_probs_uncond
        else:
            # Otherwise the guidance is applied downstream, on the rate matrices
            x1_probs_dict.update({'x1_probs_uncond': x1_probs_uncond})
            e1_probs_dict.update({'e1_probs_uncond': e1_probs_uncond})

    kwargs['x1_features'] = x1_features if use_features else None

    if is_vfm:
        kwargs.update({'logits': cond_logits, 'x1_probs': x1_probs, 'e1_probs': e1_probs})
        return update_xt_vfm(cfg, **kwargs)
    if cfg.train.prior == 'marginal':
        kwargs.update({'x1_probs_dict': x1_probs_dict, 'e1_probs_dict': e1_probs_dict})
        return update_xt_marginal(cfg, **kwargs)

    # Previously this case fell through and returned None, which only failed later in the caller
    raise NotImplementedError(
        f"No update rule for framework type {cfg.framework.type!r} with prior {cfg.train.prior!r}."
    )
    

def update_xt_features(xt, x1_features, t, dt):
    """
    Moves the continuous node features one Euler step towards the predicted clean features.

    Rows whose per-node time ``xt.t_f`` equals 1 are left untouched. ``xt.x_features`` is
    modified in place and returned. When the scalar time ``t`` equals 1 nothing is updated.
    """
    if t == 1:
        return xt.x_features

    node_times = xt.t_f
    active = node_times != 1
    time_left = 1 - node_times[active, None]

    current = xt.x_features
    # Linear interpolation step: x <- x * (1 - dt / (1 - t)) + dt * x1 / (1 - t)
    stepped = current[active] * (1 - dt / time_left) + dt * x1_features[active] / time_left
    current[active] = stepped
    return current


def update_xt_sizing(xt, x1_features, t, dt, **kwargs):
    """
    Updates only the feature columns of ``xt.x`` (every column after the first) with the
    result of update_xt_features, and returns ``xt``. Extra keyword arguments are ignored.
    """
    xt.x[:, 1:] = update_xt_features(xt, x1_features, t, dt)
    return xt


def sample_features_from_prior(cfg, xt):
    """
    Replaces the feature columns of ``xt.x`` (all columns but the first) by uniform noise in
    [0, cfg.dataset.nnode_features), clipped from below at 0.1. ``xt.x`` is cast to float.
    """
    xt.x = xt.x.float()
    n_rows = len(xt.x)
    n_feature_cols = cfg.dataset.node_features_dim - 1
    uniform_noise = torch.rand((n_rows, n_feature_cols), device=xt.x.device)
    scaled_noise = uniform_noise * cfg.dataset.nnode_features
    xt.x[:, 1:] = scaled_noise.clip(min=1e-1)
    return xt


def classifier_free_guidance_continuous(y_cond, y_uncond, guidance_scale=3.0):
    """Extrapolates from the unconditional prediction towards the conditional one by ``guidance_scale``."""
    guidance_direction = y_cond - y_uncond
    return y_uncond + guidance_scale * guidance_direction


def update_xt_marginal(cfg, xt, x1_probs_dict, e1_probs_dict, x1_features, t_x, t_e, t_f, dt_x,
                       dt_e, dt_f, noise_nodes, noise_edges, **kwargs):
    """
    Performs one denoising step under the marginal prior, separately for node types,
    node features and edges. Each part is only updated while its own time is below 1.

    Args:
        cfg: Configuration object.
        xt: Current batch; modified in place and returned.
        x1_probs_dict, e1_probs_dict: Predicted clean-state probabilities for nodes / edges
            (optionally with the unconditional predictions for classifier-free guidance).
        x1_features: Predicted clean node features (only used when features are enabled).
        t_x, t_e, t_f: Current times for node types, edges and features.
        dt_x, dt_e, dt_f: Corresponding step sizes.
        noise_nodes, noise_edges: Noise levels forwarded to the samplers.

    Returns:
        The updated batch.
    """

    # Nodes type
    if t_x < 1:
        # States 6 and 7 are forbidden unless pins are used or the dataset is AnalogGenie
        keep_all_states = cfg.dataset.get("use_pins", False) or cfg.dataset.name == "AnalogGenie"
        x1 = sample_nodes_marginal(xt, x1_probs_dict, t_x, dt_x, cfg.dataset.nnode_types, noise_nodes, num_classes=cfg.dataset.nnode_types, pmf=cfg.node_type_pmf,
                                guidance_strength=cfg.gt.guidance_strength, forbid_states=[] if keep_all_states else [6, 7])
        xt.x = x1[:, None].float()

    # Device sizes
    if t_f < 1:
        xt_feats = update_xt_features(xt, x1_features, t_f, dt_f) if (cfg.dataset.node_features_dim > 1) or cfg.gt.get("sizing", False) \
            else xt.x.new_zeros((len(xt.x), 1))
        xt.x_features = xt_feats

    # Edges
    if t_e < 1:
        xt_edge_idx, xt_edge_attr = sample_edges_marginal(cfg, xt, e1_probs_dict, t_e, dt_e, noise_edges)

        # Only existing edges (attribute > 0) are stored
        keep_edges = torch.nonzero(xt_edge_attr > 0)[:, 0]
        xt_edge_idx = xt_edge_idx[:, keep_edges]
        xt_edge_attr = xt_edge_attr[keep_edges]

        # Symmetrize: append the reversed edges, with the same attributes
        xt_edge_idx = torch.cat([xt_edge_idx, torch.flip(xt_edge_idx, dims=[0])], dim=1)
        xt_edge_attr = torch.cat([xt_edge_attr, xt_edge_attr], dim=0)

        # Sort edges by source node, applying the same permutation to the attributes so that
        # edge_attr[k] keeps describing edge_index[:, k]
        sorted_e, sorted_idx = xt_edge_idx.sort()
        xt_edge_idx = torch.stack([sorted_e[0], xt_edge_idx[1, sorted_idx[0]]])
        xt_edge_attr = xt_edge_attr[sorted_idx[0]]

        xt.edge_index = xt_edge_idx
        xt.edge_attr = xt_edge_attr
        xt = reset_slice_dict_edges(xt)

    return xt


def update_xt_vfm(cfg, xt, logits, x1_probs, e1_probs, x1_features, t_x, t_e, dt_x, dt_e, t_f=None, dt_f=None, **kwargs):
    """
    Performs one denoising step in the VFM framework.

    Args:
        cfg: Configuration object.
        xt: Current batch; modified in place and returned.
        logits: Conditional model output (its raw logits are used for the final sampling).
        x1_probs, e1_probs: Predicted (possibly guidance-combined) node / edge probabilities.
        x1_features: Predicted clean node features.
        t_x, t_e: Current node / edge times; dt_x, dt_e: corresponding step sizes.
        t_f, dt_f: Feature time and step size; default to the node values when not given.

    Returns:
        The updated batch.
    """
    # Features follow the node schedule when no dedicated feature time is provided
    if t_f is None:
        t_f = t_x
    if dt_f is None:
        dt_f = dt_x

    # Nodes
    if t_x + dt_x < 1.0:
        xt.xt_logits = xt.xt_logits * (1 - dt_x / (1 - t_x)) + dt_x * x1_probs / (1 - t_x)
    else:
        # if t = 1, probabilities are (re-)computed from the output from the conditional model only;
        # --> this is specific to the VFM framework which outputs categorical probabilities instead of an actual
        # velocity field, which no longer verify the simplex condition after applying the classifier-free guidance combination
        x1_probs = F.softmax(logits.x, dim=-1) # (B, D, S)
        xt.x[:, 0] = Categorical(probs=x1_probs).sample((1,)).squeeze()

    # Device sizes
    # NOTE(review): update_xt_features reads xt.t_f and xt.x_features, whereas this path keeps the
    # features in xt.x[:, 1:] -- confirm the VFM batches carry x_features before relying on this.
    xt_feats = update_xt_features(xt, x1_features, t_f, dt_f) if cfg.dataset.node_features_dim > 1 else xt.x.new_zeros((len(xt.x), 1))
    xt.x[:, 1:] = xt_feats

    # Edges
    if t_e + dt_e < 1.0:
        # Keep one direction per edge pair, oriented as (min, max)
        xt_edge_idx = xt.edge_index[:, ::2]
        xt_edge_idx = torch.stack([xt_edge_idx.min(dim=0)[0], xt_edge_idx.max(dim=0)[0]])
        xt_edge_attr = xt.edge_attr[::2]
        xt_edge_idx, xt_edge_attr = torch_sparse.coalesce(
            torch.cat([xt_edge_idx, logits.edge_index], dim=1),
            torch.cat([xt_edge_attr * (1 - dt_e / (1 - t_e)), dt_e * e1_probs / (1 - t_e)], dim=0),
            xt.num_nodes, xt.num_nodes,
            op="sum"
        )
    else:
        e1_probs = F.softmax(logits.edge_attr, dim=-1)
        xt_edge_idx = logits.edge_index
        xt_edge_attr = Categorical(probs=e1_probs).sample((1,)).squeeze()

        # Prune edges between nodes of the same type
        node_type = xt.x.new_zeros(xt.x[:, 0].shape)
        for net_id in NET_TYPES_ID:
            node_type = node_type + (xt.x[:, 0] == net_id).int()
        for net_id in PIN_TYPES_ID:
            node_type = node_type + 2 * (xt.x[:, 0] == net_id).int()
        i_type = node_type[xt_edge_idx[0, :]]
        o_type = node_type[xt_edge_idx[1, :]]
        xt_edge_attr[i_type == o_type] = 0

        xt_edge_idx = xt_edge_idx[:, xt_edge_attr > 0]
        xt_edge_attr = xt_edge_attr[xt_edge_attr > 0]

    # Flip and concat with the same ordering as when drawing from p0
    xt_edge_idx = xt_edge_idx.repeat_interleave(2, dim=1)
    xt_edge_idx[:, ::2] = torch.flip(xt_edge_idx[:, ::2], dims=[0])
    xt_edge_attr = xt_edge_attr.repeat_interleave(2, dim=0)

    xt.edge_index = xt_edge_idx
    xt.edge_attr = xt_edge_attr
    xt = reset_slice_dict_edges(xt)

    return xt


def draw_from_prior(cfg, **kwargs):
    """
    Samples an initial batch from the prior selected by the configuration.

    The 'vfm' framework takes precedence; otherwise ``cfg.train.prior`` selects between the
    'masked' and 'marginal' priors. All keyword arguments are forwarded to the chosen sampler.

    Raises:
        ValueError: If the configured prior is not supported.
    """
    if cfg.framework.type == 'vfm':
        return generate_random_graph_logits(cfg, **kwargs)

    prior = cfg.train.prior
    if prior == 'masked':
        return generate_masked_graph(cfg, **kwargs)
    if prior == 'marginal':
        return generate_marginal_graph(cfg, **kwargs)

    # Previously this case silently returned None
    raise ValueError(f"Unsupported prior {prior!r} for framework type {cfg.framework.type!r}.")


def generate_marginal_graph(cfg, num_samples, cond_y, t_x, t_e, t_f, xt, n_nodes):
    """
    Builds a batch of graphs whose nodes, features and edges are drawn from the marginal
    prior wherever the corresponding time is 0, and copied from ``xt`` otherwise.

    Args:
        cfg: Configuration (reads cfg.dataset, cfg.gt, cfg.node_type_pmf, cfg.edge_ratio, cfg.device).
        num_samples: Number of graphs to build.
        cond_y: Optional conditioning tensor; row i is attached to graph i as ``y``.
        t_x, t_e, t_f: Starting times of node types, edges and features.
        xt: Optional existing batch. With ``t_x == 0`` its nodes are kept and only the missing
            nodes are sampled; its edges are preserved when new nodes were added.
        n_nodes: None (sizes drawn from MEAN_NODE_NUMBER), an integer (same size for all graphs)
            or a sequence with one size per graph.

    Returns:
        A torch_geometric Batch of ``num_samples`` graphs.
    """

    # Generate a random number of nodes
    if n_nodes is None:
        mu_nnodes, std_nnodes = MEAN_NODE_NUMBER[cfg.dataset.name]
        min_nnodes, max_nnodes = mu_nnodes - int(2 * std_nnodes), mu_nnodes + int(2 * std_nnodes)
        nnodes = (np.random.normal(size=num_samples) * std_nnodes + mu_nnodes).astype(int).clip(min=min_nnodes, max=max_nnodes)
    elif isinstance(n_nodes, (int, np.integer)):
        # Also accepts numpy integers, which `type(n_nodes) == int` rejected
        nnodes = [n_nodes] * num_samples
    else:
        # Then it must be a list of length num_samples
        nnodes = n_nodes

    # The node-type prior is the same for every graph: build it once
    node_type_dist = None
    if t_x == 0:
        node_type_dist = Categorical(probs=torch.tensor(cfg.node_type_pmf).to(torch.device(cfg.device)))

    batch_list = []
    for i in range(num_samples):

        data_dict = {}
        # Existing graph, if any (None when sampling from scratch)
        x_i = xt.get_example(i) if xt is not None else None

        # Nodes
        if t_x == 0:
            # If xt is provided then sample only additional node - for a full denoising then xt must be None or t_x > 0.
            num_nodes = nnodes[i] if x_i is None else nnodes[i] - x_i.num_nodes # will throw an error if num_nodes if too small.
            x = node_type_dist.sample((num_nodes, 1)).float()
            if x_i is not None:
                # Existing nodes are anchors (learnable flag 0), new ones are learnable (flag 1)
                anchor_x = torch.cat([x_i.x.new_ones(len(x_i.x),), x_i.x.new_zeros(len(x),)])
                data_dict.update({'learnable_x': 1 - anchor_x})
                x = torch.cat([x_i.x, x], dim=0)
                num_nodes = len(x)
        else:
            num_nodes = x_i.num_nodes
            x = x_i.x
        data_dict.update({'x': x})

        # Features
        if t_f == 0:
            if (cfg.dataset.node_features_dim > 1) or cfg.gt.get("sizing", False):
                x_features = torch.rand((num_nodes, 1), device=cfg.device) * cfg.dataset.nnode_features
                x_features = x_features.clip(min=1e-1)
            else:
                x_features = None
        else:
            x_features = x_i.x_features
        data_dict.update({'x_features': x_features})

        # Edges
        if t_e == 0:
            row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
            all_connections = torch.stack((row, col), dim=0)
            x0_edge_attr = (torch.rand(all_connections.size(1)) < cfg.edge_ratio).long()
            if not x0_edge_attr.any(): # Ensure not all edges are zero (messes with slice_dict)
                x0_edge_attr[np.random.randint(len(x0_edge_attr))] = 1

            # If xt is provided then two use cases: circuit completion if new nodes were added,
            # in which case existing edges are preserved, else link prediction.
            if (x_i is not None) and (x_i.num_nodes < num_nodes):
                directed_index = x_i.edge_index[0] < x_i.edge_index[1]
                all_connections, anchor_e = torch_sparse.coalesce(
                    torch.cat([all_connections.to(x_i.edge_index.device), x_i.edge_index[:, directed_index]], dim=1),
                    torch.cat([x_i.edge_index.new_zeros(len(x0_edge_attr)), x_i.edge_attr[directed_index]], dim=0),
                    num_nodes, num_nodes,
                    op="max"
                )
                x0_edge_attr = (anchor_e + x0_edge_attr.to(x_i.edge_index.device)).clamp(max=1)
                data_dict.update({'triu_learnable_edge_attr': 1 - anchor_e})

            # Suppress 0 edges, then flip
            x0_edge_idx = all_connections[:, x0_edge_attr > 0].clone()
            x0_edge_attr = x0_edge_attr[x0_edge_attr > 0]

            x0_edge_idx = torch.cat([x0_edge_idx, torch.flip(x0_edge_idx, dims=[0])], dim=1)
            x0_edge_attr = torch.cat([x0_edge_attr, x0_edge_attr], dim=0)
        else:
            x0_edge_idx = x_i.edge_index
            x0_edge_attr = x_i.edge_attr
            all_connections = x_i.triu_edge_index
        data_dict.update({'edge_index': x0_edge_idx, 'edge_attr': x0_edge_attr, 'triu_edge_index': all_connections})
        if cond_y is not None:
            data_dict.update({'y': cond_y[[i]]})

        # Create a graph data object
        graph = Data(**data_dict)
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch


def generate_masked_graph(cfg, num_samples, **kwargs):
    """
    Builds a batch of fully masked, fully connected graphs.

    Every node type, node feature and edge attribute is set to the last index of its
    vocabulary (the masked value). Graph sizes are drawn from a normal distribution
    (mean 10, std 2) truncated to integers and clipped to [4, 16].

    Args:
        cfg: Configuration (reads cfg.dataset).
        num_samples: Number of graphs to build.
        **kwargs: Ignored. Accepted so that draw_from_prior can forward the same keyword
            arguments to every prior sampler.

    Returns:
        A torch_geometric Batch of ``num_samples`` graphs.
    """

    nnodes = (np.random.normal(size=num_samples) * 2 + 10).astype(int).clip(min=4, max=16)
    batch_list = []
    for i in range(num_samples):
        # Number of nodes of this graph (between 4 and 16)
        num_nodes = nnodes[i]

        # Generate all possible edges
        row, col = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes))
        mask = row != col
        row, col = row[mask], col[mask]

        # Combine row and col into edges
        edges = torch.stack([row, col], dim=0)
        # Create an edge attribute tensor initialized to the masked value
        edge_attributes = torch.full((edges.size(1),), cfg.dataset.nedge_types - 1)

        # Create a full masked node feature tensor
        x = torch.full((num_nodes, cfg.dataset.node_features_dim), cfg.dataset.nnode_features - 1)
        x[:, 0] = cfg.dataset.nnode_types - 1

        # Finally, add upper triangular connections (useful for edge remasking)
        row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
        all_connections = torch.stack((row, col), dim=0)

        # Create a graph data object
        graph = Data(
            x=x,  # Node features filled with the masked value
            edge_index=edges,                      # Edge indices
            edge_attr=edge_attributes,             # Edge attributes
            triu_edge_index=all_connections
        )
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch


def generate_random_graph_logits(cfg, num_samples, cond_y, t_x, t_e, xt, t_f=None, n_nodes=None):
    """
    Builds the initial batch for the VFM framework: Gaussian node / edge logits on fully
    connected graphs wherever the corresponding time is 0, and values copied from ``xt`` otherwise.

    Args:
        cfg: Configuration (reads cfg.dataset and cfg.device).
        num_samples: Number of graphs to build.
        cond_y: Optional conditioning tensor; row i is attached to graph i as ``y``.
        t_x, t_e: Starting times of nodes and edges.
        xt: Existing batch, required when ``t_x`` or ``t_e`` is not 0.
        t_f: Unused. Accepted so that draw_from_prior can forward the same keyword
            arguments to every prior sampler.
        n_nodes: None (sizes drawn from MEAN_NODE_NUMBER), an integer (same size for all graphs)
            or a sequence with one size per graph.

    Returns:
        A torch_geometric Batch of ``num_samples`` graphs.
    """

    # Generate a random number of nodes
    if n_nodes is None:
        mu_nnodes, std_nnodes = MEAN_NODE_NUMBER[cfg.dataset.name]
        if cfg.dataset.get("use_pins", False) and 'ocb' in cfg.dataset.name:
            mu_nnodes += 7
            std_nnodes += 1
        min_nnodes, max_nnodes = mu_nnodes - int(2 * std_nnodes), mu_nnodes + int(2 * std_nnodes)
        nnodes = (np.random.normal(size=num_samples) * std_nnodes + mu_nnodes).astype(int).clip(min=min_nnodes, max=max_nnodes)
    elif isinstance(n_nodes, (int, np.integer)):
        nnodes = [n_nodes] * num_samples
    else:
        # Then it must be a list of length num_samples
        nnodes = n_nodes

    batch_list = []
    for i in range(num_samples):

        if t_x == 0:
            num_nodes = nnodes[i]

            # Nodes
            x0 = torch.randn((num_nodes, cfg.dataset.nnode_types)) * 2 + 1

            # Device types and sizes; both parts must live on the same device to be concatenated
            x_node = torch.zeros(num_nodes, 1, device=cfg.device).float()
            x_feature = torch.rand((len(x_node), cfg.dataset.node_features_dim - 1), device=cfg.device) * cfg.dataset.nnode_features
            x = torch.cat([x_node, x_feature.clip(min=1e-1)], dim=1)
        else:
            x_i = xt.get_example(i)
            x = x_i.x
            num_nodes = x_i.num_nodes
            x0 = x_i.xt_logits

        if t_e == 0:
            # Edges
            row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
            all_connections = torch.stack((row, col), dim=0)
            x0_edge_attr = torch.randn((all_connections.size(1), 2)) * 2 + 1

            # Flip and concat - it is important to preserve the ordering when adding later with model prediction
            x0_edge_idx = all_connections.repeat_interleave(2, dim=1)
            x0_edge_idx[:, ::2] = torch.flip(x0_edge_idx[:, ::2], dims=[0])
            x0_edge_attr = x0_edge_attr.repeat_interleave(2, dim=0)
        else:
            x_i = xt.get_example(i)
            x0_edge_idx = x_i.edge_index
            x0_edge_attr = x_i.edge_attr
            all_connections = x_i.triu_edge_index

        # Create a graph data object
        graph = Data(
            x=x, xt_logits=x0, edge_index=x0_edge_idx,
            edge_attr=x0_edge_attr, triu_edge_index=all_connections, y=cond_y[[i]] if cond_y is not None else None
        )
        batch_list.append(graph)

    data_batch = Batch.from_data_list(batch_list)

    return data_batch

def compute_node_rate_matrix_marginal(xt_oh, x1_probs, t, S, noise, pmf, forbid_states=[6, 7]):
    """
    Computes the node-type rate matrix under the marginal prior.

    Args:
        xt_oh: One-hot encoding of the current node states, one row per node.
        x1_probs: Predicted clean-state probabilities, same layout as ``xt_oh``.
        t: Current time (broadcastable against the rows).
        S: Number of states used in the normalization.
        noise: Noise level.
        pmf: Marginal probability of each state.
        forbid_states: Column indices whose incoming rate is set to 0.

    Returns:
        Rate matrix with one row per node; the entry of the current state is minus the sum of
        the other entries, so that each row sums to zero.
    """
    prior = torch.tensor(pmf).repeat(len(x1_probs), 1).to(xt_oh.device)
    # Prior mass and predicted probability of each node's current state
    prior_current = (xt_oh * prior).sum(dim=1, keepdim=True)
    pred_current = (xt_oh * x1_probs).sum(dim=1, keepdim=True)

    time_left = 1 - t
    base_denominator = S * time_left * prior_current.clamp(min=1e-4)

    target_weight = (1 - prior + prior_current) / base_denominator \
        + noise * (t + time_left * prior_current) / (time_left * prior.clamp(min=1e-4))
    other_weight = (prior_current - prior).clamp(min=0) / base_denominator

    # Marginal rate matrix
    rates = target_weight * x1_probs + noise * pred_current + other_weight * (1 - x1_probs - pred_current)
    rates[:, forbid_states] = 0

    # Handle stationary transitions
    off_diagonal = rates * (1 - xt_oh)
    return off_diagonal - xt_oh * off_diagonal.sum(dim=1, keepdim=True)


def sample_nodes_marginal(xt, x1_probs_dict, t, dt, S, noise, num_classes, pmf, forbid_states=[6, 7], pmin=0, guidance_strength=2):
    """
    Samples the node types of the next Euler step under the marginal prior.

    Nodes whose per-node time ``xt.t_x`` equals 1 keep their current type. When
    'x1_probs_uncond' is present in ``x1_probs_dict`` the conditional and unconditional rate
    matrices are combined geometrically with ``guidance_strength``.

    Returns:
        Long tensor with the node type of every node.
    """
    # No extra noise on the last step
    if t >= 1.0 - dt:
        noise = 0

    node_times = xt.t_x
    active = node_times != 1
    active_times = node_times[active, None]

    node_types = xt.x[:, 0].long()
    current_oh = F.one_hot(node_types[active], num_classes=num_classes)
    off_diag_mask = 1 - current_oh

    rates = compute_node_rate_matrix_marginal(current_oh, x1_probs_dict['x1_probs'][active], active_times,
                                              S, noise, pmf, forbid_states)
    # Unconditional rate matrix in case of classifier-free guidance
    if 'x1_probs_uncond' in x1_probs_dict.keys():
        rates_uncond = compute_node_rate_matrix_marginal(current_oh, x1_probs_dict['x1_probs_uncond'][active],
                                                         active_times, S, noise, pmf, forbid_states)
        rates = (rates.clamp(min=0) ** guidance_strength) * (rates_uncond.clamp(min=1e-5) ** (1 - guidance_strength))
        rates = rates * off_diag_mask - current_oh * (rates * off_diag_mask).sum(dim=1)[:, None]

    # Ensure all probs are positive and sum to one by adjusting dt if necessary
    max_dt = (1 - pmin) / (rates * off_diag_mask).sum(dim=1)[:, None]
    dt_candidates = torch.cat([max_dt, max_dt.new_full(max_dt.shape, dt)], dim=1)
    step = torch.min(dt_candidates, dim=1)[0]

    transition_probs = current_oh + rates * step[:, None]
    new_types = Categorical(probs=transition_probs).sample((1,)).squeeze()

    node_types[active] = new_types

    return node_types


def compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs, t, noise):
    """
    Computes the edge rate matrix under the marginal prior (two edge states).

    Args:
        mj: Marginal probability of each edge state, one row per edge.
        mxt: Marginal probability of each edge's current state (column vector).
        edge_attr_oh: One-hot encoding of the current edge states.
        e1_probs: Predicted clean-state probabilities.
        t: Current time (broadcastable against the rows).
        noise: Noise level.

    Returns:
        Rate matrix with one row per edge; the entry of the current state is minus the sum of
        the other entries, so that each row sums to zero.
    """
    time_left = 1 - t
    # Predicted probability of each edge's current state
    pred_current = (edge_attr_oh * e1_probs).sum(dim=1, keepdim=True)
    base_denominator = 2 * time_left * mxt

    target_weight = (1 - mj + mxt) / base_denominator \
        + noise * (t + time_left * mxt) / (time_left * mj.clamp(min=1e-4))
    other_weight = (mxt - mj).clamp(min=0) / base_denominator

    # Marginal rate matrix
    rates = target_weight * e1_probs + noise * pred_current + other_weight * (1 - e1_probs - pred_current)

    # Handle stationary transitions
    off_diagonal = rates * (1 - edge_attr_oh)
    return off_diagonal - edge_attr_oh * off_diagonal.sum(dim=1, keepdim=True)


def sample_edges_marginal(cfg, xt, e1_probs_dict, t, dt, noise, pmin=0):
    """
    Samples the edge states of the next Euler step under the marginal prior.

    The stored (sparse) edges are first densified over all candidate pairs in
    ``xt.triu_edge_index``; edges whose time is already 1 keep their current state.

    Args:
        cfg: Configuration (reads cfg.edge_ratio, cfg.device and cfg.gt.guidance_strength).
        xt: Current batch (reads edge_index, edge_attr, triu_edge_index, num_nodes and t_e).
        e1_probs_dict: Dict with 'e1_probs' and optionally 'e1_probs_uncond' (classifier-free guidance).
        t: Scalar edge time, only used to decide whether this is the last step.
        dt: Step size.
        noise: Noise level; set to 0 on the last step.
        pmin: Lower bound used when capping the step size.

    Returns:
        Tuple (edge index, edge attributes) over all candidate pairs, including zero-valued edges.
    """

    # No extra noise on the last step
    if t >= 1.0 - dt:
        noise = 0

    # Densify directed edges: keep one direction (source < target) of the stored edges and
    # merge it with every candidate pair, absent edges getting attribute 0
    one_way_edge_idx = xt.edge_index[0] < xt.edge_index[1]
    xt_edge_idx_all, xt_edge_attr_all = torch_sparse.coalesce(
        torch.cat([xt.edge_index[:, one_way_edge_idx], xt.triu_edge_index], dim=1),
        torch.cat([xt.edge_attr[one_way_edge_idx], xt.edge_attr.new_zeros(xt.triu_edge_index.size(1))], dim=0),
        xt.num_nodes, xt.num_nodes,
        op="max"
    )

    # Do not denoise edges whose time index is already 1
    # (only the first half of t_e is used; the caller duplicates it with .repeat(2))
    # NOTE(review): this assumes the coalesced edges, t_e and the model's edge probabilities
    # share the ordering of triu_edge_index -- confirm.
    t = xt.t_e[:int(0.5 * len(xt.t_e))]
    non_one_indices = t != 1
    xt_edge_attr = xt_edge_attr_all[non_one_indices]

    edge_attr_oh = F.one_hot(xt_edge_attr, num_classes=2)
    # Marginal prior over (no edge, edge), repeated for every edge to denoise
    mj = torch.tensor([1 - cfg.edge_ratio, cfg.edge_ratio]).repeat(len(xt_edge_attr), 1).to(cfg.device)
    # Prior mass of each edge's current state
    mxt = (edge_attr_oh * mj).sum(dim=1)[:, None]

    rm = compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs_dict['e1_probs'][non_one_indices], 
                                           t[non_one_indices, None], noise)
    # Unconditional rate matrix in case of classifier-free guidance
    if 'e1_probs_uncond' in e1_probs_dict.keys():
        rm_uncond = compute_edge_rate_matrix_marginal(mj, mxt, edge_attr_oh, e1_probs_dict['e1_probs_uncond'][non_one_indices], 
                                                      t[non_one_indices, None], noise)
        # Geometric combination of conditional and unconditional rates
        rm = (rm.clamp(min=0) ** cfg.gt.guidance_strength) * (rm_uncond.clamp(min=1e-5) ** (1 - cfg.gt.guidance_strength))
        # Recompute the entry of the current state so that each row sums to zero again
        rm = rm * (1 - edge_attr_oh) - edge_attr_oh * (rm * (1 - edge_attr_oh)).sum(dim=1)[:, None]

    # Ensure all probs are positive and sum to one by adjusting dt if necessary
    dt_min = (1 - pmin) / (rm * (1 - edge_attr_oh)).sum(dim=1)[:, None]
    adjusted_dt = torch.min(torch.cat([dt_min, dt_min.new_full(dt_min.shape, dt)], dim=1), dim=1)[0]

    # Transition probabilities of one Euler step
    probs = edge_attr_oh + rm * adjusted_dt[:, None]

    categorical_dist = Categorical(probs=probs)
    samples = categorical_dist.sample((1,)).squeeze()
    # samples = ensure_minimal_edge_count(samples, xt)

    # Edges at t=1 are unchanged, others are replaced by new ones sampled from the categorical distribution.
    xt_edge_attr_all[non_one_indices] = samples

    return xt_edge_idx_all, xt_edge_attr_all
    # return xt_edge_idx, samples


def compute_valids(batch, cfg, train_loader=None):
    """
    Computes validity metrics over a batch of generated graphs.

    Each graph is converted to igraph, checked for simulability, graph validity and circuit
    validity, and tagged with ``valid_graph`` / ``valid_circuit`` attributes.

    Args:
        batch: Batch of generated graphs.
        cfg: Configuration (reads cfg.dataset.get("use_pins", False)).
        train_loader: Optional training loader; when given, novelty and VUN are also computed.

    Returns:
        (valid_circuit_ratio, valid_graph_ratio, valid_sim_ratio, uniqueness), followed by
        (novelty, VUN) when ``train_loader`` is provided. Ratios are 0 for an empty batch.
    """
    num_graphs = batch.num_graphs
    use_pins = cfg.dataset.get("use_pins", False)

    valid_circuits, valid_graphs, valid_sim = 0, 0, 0
    generated_graphs = []
    for i in range(num_graphs):  # Number of graphs in the batch
        graph = batch.get_example(i).to('cpu')
        i_graph = torch_geometric_to_igraph(graph)

        # Check if the graph can be simulated
        valid_sim += is_valid_sim(i_graph, use_pins)

        i_graph.valid_graph = bool(is_graph_valid(i_graph))
        valid_graphs += i_graph.valid_graph

        i_graph.valid_circuit = bool(our_is_valid_circuit(i_graph))
        valid_circuits += i_graph.valid_circuit

        generated_graphs.append(i_graph)

    # Use the batch size rather than the loop variable, which is undefined for an empty batch
    denominator = max(num_graphs, 1)
    ratios = (valid_circuits / denominator, valid_graphs / denominator, valid_sim / denominator)

    uniqueness = unique_ratio(generated_graphs)
    if train_loader:
        novelty = novelty_ratio(generated_graphs, train_loader)
        VUN = compute_VUN(generated_graphs)
        return (*ratios, uniqueness, novelty, VUN)
    return (*ratios, uniqueness)


def is_valid_sim(graph, with_pins):
    """
    Returns 1 if a copy of ``graph`` can be converted to a DAG and simulated, 0 if any error occurs.

    Only ``Exception`` subclasses are treated as a failed simulation, so KeyboardInterrupt and
    SystemExit still propagate.
    """
    out_graph = graph.copy()
    try:
        # Sanity check: only whether the simulation runs matters, its result is discarded
        simulation(to_dag(out_graph), 'default', with_pins=with_pins)
    except Exception:
        return 0
    return 1


def get_simulation_outputs(graph):
    """
    Simulates a copy of ``graph`` and returns (gain / 100, pm / 90, ugw / 1e9), each rounded to
    3 decimals. Returns (0.0, 0.0, 0.0) if the simulation fails.

    Only ``Exception`` subclasses are treated as a failure, so KeyboardInterrupt and SystemExit
    still propagate.
    """
    try:
        sim = simulation(to_dag(graph.copy()), 'default')
        return (np.round(float(sim.gain[0] / 100), 3), np.round(float(sim.pm / 90), 3), np.round(float(sim.ugw / 1e9), 3))
    except Exception:
        return (0.0, 0.0, 0.0)


# def eval_simulation(batch, y): # Do we still need this?
#     for i in range(batch.num_graphs):  # Number of graphs in the batch
#         graph = batch.get_example(i).to('cpu')
#         i_graph = torch_geometric_to_igraph(graph)
#         sim = simulation(i_graph, y[i].numpy())
#         gain = sim.gain[0]
        

def eval_inference(model, num_samples, euler_steps, noise, n_pow_e, n_pow_x, cond_y=None):
    """
    Generates ``num_samples`` graphs with the model in eval mode and returns their validity
    metrics as a dict. The same ``noise`` level is used for nodes and edges.
    Note that uniqueness depends on ``num_samples``.
    """
    model.eval()
    sampling_kwargs = dict(
        num_samples=num_samples,
        euler_steps=euler_steps,
        noise_e=noise,
        noise_x=noise,
        n_pow_e=n_pow_e,
        n_pow_x=n_pow_x,
        cond_y=cond_y,
    )
    with torch.no_grad():
        denoised_batch = inference(model, **sampling_kwargs)

    valid_circuits, valid_graphs, valid_sim, uniqueness = compute_valids(denoised_batch, cfg=model.cfg)
    metrics = {}
    metrics['valid_circuits'] = valid_circuits
    metrics['valid_graphs'] = valid_graphs
    metrics['valid_sim'] = valid_sim
    metrics['uniqueness'] = uniqueness
    return metrics


def eval_inference_sizing(model, batch, euler_steps=20, n_pow=1):
    """
    Runs inference starting from ``batch`` and returns the output of compute_valids.

    ``inference`` has no ``n_pow`` parameter (passing it raised a TypeError), so ``n_pow`` is
    applied to the node, edge and feature time discretizations alike.
    """
    model.eval()
    with torch.no_grad():
        denoised_batch = inference(model, euler_steps=euler_steps, n_pow_x=n_pow, n_pow_e=n_pow,
                                   n_pow_f=n_pow, xt=batch)
    sim_out = compute_valids(denoised_batch, cfg=model.cfg)
    return sim_out


def preprocess_batch_y(cfg, batch):
    """
    Encodes y features of a batch of circuits using RBF functions. The centroids of the RBFs are retrieved from the config.

    Args:
        cfg: Configuration (reads cfg.y_std, cfg.kmeans_centroids, cfg.gt.n_rbf_centroids and
            cfg.gt.conditional_dim).
        batch: Batch with a ``y`` tensor holding one row of features per circuit.

    Returns:
        The same batch, with ``c_init`` set to the normalized RBF encoding of the feature(s)
        selected by cfg.gt.conditional_dim.

    Raises:
        ValueError: If the number of centroid sets differs from the number of y features.
    """

    features = batch.y / torch.tensor(cfg.y_std)
    n_feats = features.shape[-1]

    centroids = list(cfg.kmeans_centroids)
    if len(centroids) != n_feats:
        raise ValueError(f'Expected {n_feats} sets of centroids, got {len(centroids)}.')

    # distances[f, n, k] = |y[n, f] - centroid[f][k]|, computed by broadcasting
    distances = torch.stack([
        (features[:, i, None] - torch.as_tensor(feat_centroids)[None, :]).abs()
        for i, feat_centroids in enumerate(centroids)
    ]).view(n_feats, len(features), cfg.gt.n_rbf_centroids)

    # Gaussian kernels with temperature coefficients (one per feature)
    temp = torch.tensor([0.5, 1, 0.1])
    unnorm_rbf = torch.exp(-distances * 0.5 / temp[:, None, None])
    rbf = unnorm_rbf / unnorm_rbf.sum(dim=-1, keepdim=True)

    batch.c_init = rbf[[cfg.gt.conditional_dim]].transpose(0, 1).float()

    return batch


def prune(cfg, batch):
    """
    Removes isolated nodes from every graph of the batch and, when cfg.gt.node_pruning == 2,
    also the nodes whose type is the last node type. At least one node is always kept.
    Edges are re-indexed accordingly and a new batch is returned.
    """
    pruned_graphs = []

    for idx in range(batch.num_graphs):
        graph = batch.get_example(idx)
        n = graph.num_nodes

        # Nodes with at least one connection
        adjacency = SparseTensor.from_edge_index(graph.edge_index, None, sparse_sizes=(n, n))
        node_mask = adjacency.sum(dim=0) > 0

        # Prunable node type
        if cfg.gt.node_pruning == 2:
            prunable_type = cfg.dataset.nnode_types - 1
            node_mask = node_mask & (graph.x[:, 0] != prunable_type)

        # Ensures there's at least one node per graph, even if it's the "prunable" class
        if not node_mask.any():
            node_mask[0] = True

        # Discard edges connecting isolated / prunable nodes and re-index
        row, col, edge_attr = adjacency[node_mask, node_mask].t().coo()

        graph.x = graph.x[node_mask]
        graph.edge_index = torch.stack([row, col], dim=0)
        graph.edge_attr = edge_attr
        pruned_graphs.append(graph)

    return Batch.from_data_list(pruned_graphs)


def prune_end_nodes_batch(batch):
    """Applies prune_end_nodes to a clone of every graph of the batch and re-batches the results."""
    pruned_graphs = [
        prune_end_nodes(batch.get_example(idx).clone())
        for idx in range(batch.num_graphs)
    ]
    return Batch.from_data_list(pruned_graphs)


def prune_end_nodes(graph):
    """
    Removes the nodes that have at most one neighbor, except nodes of type 8 or 9 (In / Out).
    At least one node is always kept. The graph is modified in place and returned.
    """
    n = graph.num_nodes
    adjacency = SparseTensor.from_edge_index(graph.edge_index, None, sparse_sizes=(n, n))

    # Nodes with more than one neighbor
    has_several_neighbors = adjacency.sum(dim=0) > 1

    # Keep In / Out nodes
    node_types = graph.x[:, 0]
    node_mask = has_several_neighbors | (node_types == 8) | (node_types == 9)

    # Ensures there's at least one node per graph, even if it's the "prunable" class
    if not node_mask.any():
        node_mask[0] = True

    # Discard edges connecting prunable nodes and re-index
    row, col, edge_attr = adjacency[node_mask, node_mask].t().coo()

    graph.x = graph.x[node_mask]
    graph.edge_index = torch.stack([row, col], dim=0)
    graph.edge_attr = edge_attr

    return graph