from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import (
    Union, Optional, Dict, List,
    Any, Callable, Final, Tuple
)
from torch_geometric.nn.resolver import (
    activation_resolver)

from torch_geometric.utils import to_dense_adj

from .utils import (
    inverse_sigmoid,
    remove_flipped_duplicates,
    filter_edges,
)


class GraphBuilder(nn.Module):
    """Learnable edge weights over a fixed edge set, initialised from priors.

    The edge set is stored as a non-trainable buffer. Only the per-edge scores
    (``edge_values``) are trainable. ``forward`` maps the scores back to
    weights with a sigmoid.
    """

    def __init__(self,
                 prior_edge_index: torch.Tensor,
                 prior_edge_probs: torch.Tensor,
                 n_nodes: Optional[int] = None,
                 *args, **kwargs):
        """
        Args:
            prior_edge_index (Tensor): Edge index of the prior graph,
                registered as the buffer ``edge_index``. It is not trained,
                but it follows the module through ``.to(...)`` and is included
                in ``state_dict``. Presumably PyG-style ``(2, num_edges)``
                COO layout -- TODO confirm against callers.
            prior_edge_probs (Tensor): Prior probability of each edge. It is
                converted with ``inverse_sigmoid`` and stored as the trainable
                parameter ``edge_values``. Assumed to hold one value in (0, 1)
                per edge in ``prior_edge_index``. How exact 0/1 values are
                handled depends on ``inverse_sigmoid`` -- verify.
            n_nodes (int, optional): Number of nodes. Stored as
                ``self.shape = (n_nodes, n_nodes)``. When omitted, the shape is
                ``(None, None)``.
            *args, **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.shape = (n_nodes, n_nodes)

        # Optimise in score (pre-sigmoid) space: the sigmoid applied in
        # ``forward`` keeps the returned weights bounded between 0 and 1
        # no matter where gradient steps move the raw scores.
        prior_edge_score = inverse_sigmoid(prior_edge_probs)
        self.edge_values = torch.nn.Parameter(prior_edge_score)
        # Buffer rather than Parameter: the edge set itself is fixed.
        self.register_buffer("edge_index", prior_edge_index)

    def forward(self, *args, **kwargs):
        """Return ``(edge_index, edge_weights)``; all arguments are ignored.

        ``edge_weights`` is ``sigmoid(edge_values)`` and carries gradients
        back to ``edge_values``. Before any training it presumably reproduces
        ``prior_edge_probs``, if ``inverse_sigmoid`` is the exact inverse of
        the sigmoid.
        """
        return self.edge_index, torch.sigmoid(self.edge_values)

