import networkx as nx
import torch
from torch import Tensor

def dodiscover_shd(pred: Tensor, target: Tensor, double_for_anticausal: bool = False) -> Tensor:
    """Compute the structural Hamming distance (SHD) between two adjacency matrices.

    Presumably follows dodiscover's ``structure_hamming_dist`` (per the name):
    every entry where ``pred`` and ``target`` disagree is an error.

    NOTE(review): inputs are assumed to be 2-D square adjacency matrices of the
    same shape -- ``.T`` only swaps rows/columns correctly for 2-D tensors.

    Args:
        pred: Predicted adjacency matrix.
        target: Ground-truth adjacency matrix.
        double_for_anticausal: If True, return the plain count of differing
            entries, so a reversed edge counts as two errors. If False, errors at
            ``(i, j)`` and/or ``(j, i)`` count once per unordered node pair
            (a mismatch on the diagonal therefore contributes 0.5).

    Returns:
        A 0-dim tensor holding the distance (floating point when
        ``double_for_anticausal`` is False, because of the division by 2).
    """
    diff = torch.abs(target - pred)
    if double_for_anticausal:
        return torch.sum(diff)
    # Symmetrize so a reversed edge (errors at both (i, j) and (j, i)) is counted once;
    # diff is non-negative, so clamping at 1 ignores the double count.
    pair_errors = torch.clamp(diff + diff.T, max=1)
    # Each unordered pair appears twice in the symmetric matrix.
    return torch.sum(pair_errors) / 2
    

def average_shd(output, target, batch_size: int) -> float:
    """Return the mean structural Hamming distance over the first ``batch_size`` samples.

    Each pair ``(output[i], target[i])`` is scored with ``dodiscover_shd`` using
    ``double_for_anticausal=False``, so a reversed edge counts once.

    NOTE(review): ``output[i]`` / ``target[i]`` are passed straight through
    without reshaping, so they are assumed to already be 2-D adjacency matrices.
    (``num_dags`` and ``average_out_degree`` reshape each sample; confirm that
    callers never pass flattened rows here.)

    Args:
        output: Batch of predicted adjacency matrices, indexed along the first axis.
        target: Batch of ground-truth adjacency matrices aligned with ``output``.
        batch_size: Number of leading samples to average over.

    Returns:
        The mean SHD as a Python float.

    Raises:
        ZeroDivisionError: If ``batch_size`` is 0.
    """
    total_shd = 0
    for i in range(batch_size):
        total_shd += dodiscover_shd(pred=output[i], target=target[i], double_for_anticausal=False)
    # Convert once at the end. This honours the ``-> float`` annotation, matches
    # ``num_dags``, and releases the device tensor / autograd graph the per-sample
    # SHD tensors may carry.
    return float(total_shd / batch_size)


def num_dags(output: Tensor, batch_size: int, num_nodes: int) -> float:
    """Count how many of the first ``batch_size`` graphs in ``output`` are acyclic.

    Each sample is reshaped to ``(num_nodes, num_nodes)``, moved to a CPU numpy
    array, and loaded into a networkx ``DiGraph`` via ``nx.from_numpy_array``.
    Per the networkx docs, nonzero entries become edges, so a nonzero diagonal
    entry is a self-loop and makes that graph cyclic.

    Returns:
        The number of acyclic graphs, as a Python float.
    """

    def _is_dag(sample: Tensor) -> bool:
        # Detach before .numpy(): that call refuses tensors that require grad.
        matrix = sample.reshape(num_nodes, num_nodes).detach().cpu().numpy()
        digraph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return nx.is_directed_acyclic_graph(digraph)

    dag_count = sum(_is_dag(output[idx]) for idx in range(batch_size))
    return float(dag_count)


def average_out_degree(output: Tensor, batch_size: int, num_nodes: int) -> float:
    """Return the mean per-node out-degree, averaged over the first ``batch_size`` samples.

    Each sample is reshaped to a ``(num_nodes, num_nodes)`` adjacency matrix.
    Row sums are the (possibly weighted) out-degrees; they are averaged over
    nodes, then over samples.

    Args:
        output: Batch of adjacency matrices (flattened or 2-D per sample).
        batch_size: Number of leading samples to average over.
        num_nodes: Number of nodes per graph.

    Returns:
        The mean out-degree as a Python float.

    Raises:
        ZeroDivisionError: If ``batch_size`` is 0.
    """
    total_adj_degree = 0.0
    for batch_idx in range(batch_size):
        adj_curr = output[batch_idx].detach().cpu().reshape(num_nodes, num_nodes)
        if not adj_curr.is_floating_point():
            # torch.mean rejects integer/bool dtypes; binary adjacency matrices
            # (long/bool) would otherwise raise here.
            adj_curr = adj_curr.to(torch.float64)
        out_degrees = torch.sum(adj_curr, dim=1)
        total_adj_degree += torch.mean(out_degrees).item()
    return total_adj_degree / batch_size
