import networkx as nx
import torch
from torch import Tensor


def average_shd(output: Tensor, target: Tensor, batch_size: int) -> float:
    """Return the number of element-wise disagreements per sample.

    Every position where ``output`` and ``target`` differ counts as one error.
    The total error count over the whole tensor is divided by ``batch_size``.
    If the inputs are adjacency matrices, a reversed edge therefore
    contributes two errors, one for each mismatched entry.

    Args:
        output: Predicted tensor. It is compared element-wise with
            ``target`` using torch broadcasting rules.
        target: Reference tensor.
        batch_size: Divisor used to average the total error count.

    Returns:
        A 0-dim tensor holding the averaged error count. The ``float``
        annotation does not match this; no ``.item()`` conversion is done.
    """
    mismatch_total = torch.ne(output, target).sum()
    return mismatch_total / batch_size


def num_dags(output: Tensor, batch_size: int, num_nodes: int) -> float:
    """Count how many of the first ``batch_size`` samples form acyclic digraphs.

    Each sample ``output[i]`` is reshaped into a ``(num_nodes, num_nodes)``
    matrix. The matrix is detached, moved to the CPU and converted to a NumPy
    array. ``nx.from_numpy_array`` then turns it into a ``networkx.DiGraph``.
    A sample counts when ``nx.is_directed_acyclic_graph`` accepts that graph.

    Args:
        output: Batched tensor whose leading dimension indexes samples. Each
            sample must contain ``num_nodes * num_nodes`` elements.
        batch_size: Number of leading samples to inspect.
        num_nodes: Side length of each square matrix.

    Returns:
        The number of acyclic samples, as a Python float.
    """

    def _is_acyclic(sample: Tensor) -> bool:
        # Graph construction needs a host-side NumPy array, not a tensor.
        matrix = sample.reshape(num_nodes, num_nodes).detach().cpu().numpy()
        digraph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return nx.is_directed_acyclic_graph(digraph)

    acyclic_count = sum(_is_acyclic(output[idx]) for idx in range(batch_size))
    return float(acyclic_count)


def average_out_degree(output: Tensor, batch_size: int, num_nodes: int) -> float:
    """Return the mean per-node row sum, averaged over the first ``batch_size`` samples.

    Each sample ``output[i]`` is reshaped into a ``(num_nodes, num_nodes)``
    matrix. Its rows are summed (``dim=1``), so row ``j`` is treated as the
    outgoing edges of node ``j``, which matches the function name. The
    per-sample mean of those row sums is then averaged over the batch.

    Bool and integer inputs (for example binary adjacency matrices) are
    promoted to float before reducing, because ``torch.mean`` rejects
    non-floating dtypes. Floating and complex inputs keep their dtype.

    Args:
        output: Batched tensor whose leading dimension indexes samples. Each
            sample must contain ``num_nodes * num_nodes`` elements.
        batch_size: Number of leading samples to average over.
        num_nodes: Side length of each square matrix.

    Returns:
        A 0-dim CPU tensor with the batch-averaged mean out-degree. The
        ``float`` annotation does not match this.

    Raises:
        ZeroDivisionError: If ``batch_size`` is 0. The accumulator is then
            still the int ``0``.
    """
    total_adj_degree = 0
    for batch_idx in range(batch_size):
        adj_curr = output[batch_idx].reshape(num_nodes, num_nodes).detach().cpu()
        # Summing a bool/int tensor yields int64, and torch.mean raises on
        # int64, so promote only non-floating, non-complex samples.
        if not (adj_curr.is_floating_point() or adj_curr.is_complex()):
            adj_curr = adj_curr.float()
        out_degrees = torch.sum(adj_curr, dim=1)
        total_adj_degree += torch.mean(out_degrees)
    return total_adj_degree / batch_size
