import torch 
import numpy as np
import networkx as nx

from torch import Tensor
from torch.types import Device

# Utils without specific semantics


def create_attention_mask(num_nodes: int, device: Device) -> Tensor:
    """Build a strictly upper-triangular mask over ``num_nodes ** 2`` positions.

    The sequence length is ``num_nodes ** 2``. Entry ``[0, i, j]`` of the
    result is 1 when ``j > i`` and 0 otherwise (the main diagonal is 0).

    NOTE(review): presumably this is a causal attention mask in which 1 marks
    positions that must not be attended to, with one sequence position per
    adjacency-matrix entry -- confirm against the caller.

    Args:
        num_nodes: Number of nodes; the mask covers ``num_nodes ** 2``
            sequence positions.
        device: Device on which the mask is allocated.

    Returns:
        A ``torch.uint8`` tensor of shape ``(1, seq_length, seq_length)``
        located on ``device``.
    """
    # TODO: make mask boolean to avoid warning
    seq_length = num_nodes ** 2
    attn_shape = (1, seq_length, seq_length)
    # Allocate once, directly on the target device and in the final dtype,
    # instead of building a float32 CPU tensor and converting/copying it.
    ones = torch.ones(attn_shape, dtype=torch.uint8, device=device)
    # Keeping only the part strictly above the diagonal is the same as
    # ``1 - tril(ones)``.
    return torch.triu(ones, diagonal=1)


def num_dags(output: Tensor, batch_size: int, num_nodes: int) -> int:
    """Count the batch entries whose adjacency matrix is a directed acyclic graph.

    Each of the first ``batch_size`` entries of ``output`` is reshaped to
    ``(num_nodes, num_nodes)``, converted to a NumPy array and loaded into a
    ``networkx.DiGraph`` as an adjacency matrix.

    NOTE(review): ``nx.from_numpy_array`` is understood to create an edge for
    every non-zero entry, so ``output`` is presumably expected to be
    binarised already -- confirm against the caller.

    Args:
        output: Tensor indexed by batch position; each entry must hold
            ``num_nodes * num_nodes`` values.
        batch_size: Number of leading entries of ``output`` to inspect.
        num_nodes: Number of nodes per graph.

    Returns:
        The number of inspected entries for which
        ``nx.is_directed_acyclic_graph`` is true.
    """

    def _is_dag(sample: Tensor) -> bool:
        # Detach and move to host memory so networkx can consume the matrix.
        adjacency = sample.reshape(num_nodes, num_nodes).detach().cpu().numpy()
        digraph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
        return nx.is_directed_acyclic_graph(digraph)

    # Booleans sum as 0/1, so this yields the count of acyclic graphs.
    return sum(_is_dag(output[idx]) for idx in range(batch_size))


def average_out_degree(output: Tensor, batch_size: int, num_nodes: int) -> float:
    """Return the batch average of the mean adjacency-matrix row sum.

    Each of the first ``batch_size`` entries of ``output`` is reshaped to
    ``(num_nodes, num_nodes)``; its rows are summed and the row sums are
    averaged, giving one value per entry. These values are then averaged
    over the batch.

    NOTE(review): a row sum equals a node's out-degree only if entry
    ``[i, j]`` encodes an edge from ``i`` to ``j`` with 0/1 values --
    confirm against the caller.

    Args:
        output: Tensor indexed by batch position; each entry must hold
            ``num_nodes * num_nodes`` values.
        batch_size: Number of leading entries of ``output`` to average over.
        num_nodes: Number of nodes per graph.

    Returns:
        The sum of the per-entry mean row sums divided by ``batch_size``.

    Raises:
        ZeroDivisionError: If ``batch_size`` is 0.
    """
    per_graph_means = []
    for idx in range(batch_size):
        # Detach and move to host memory so NumPy can reduce the matrix.
        adjacency = output[idx].reshape(num_nodes, num_nodes).detach().cpu().numpy()
        row_sums = np.sum(adjacency, axis=1)
        per_graph_means.append(np.mean(row_sums))
    # ``sum`` starts at 0 and adds in batch order.
    return sum(per_graph_means) / batch_size