import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from typing import Tuple
from argparse import Namespace
from torch_geometric.utils import degree
from torch_geometric.transforms import BaseTransform
import itertools

class FullyConnected(BaseTransform):
    """Replace a graph's connectivity with the complete directed graph.

    After the transform, ``data['edge_index']`` holds every ordered pair
    ``(i, j)`` with ``0 <= i, j < data.num_nodes``. Self-loops ``(i, i)`` are
    included. Only ``edge_index`` is written; any other attribute stored on
    ``data`` is left as it was.
    """

    def __call__(self, data: Data) -> Data:
        """Overwrite ``data['edge_index']`` with all ``num_nodes ** 2`` pairs.

        Args:
            data: Graph to modify. It is mutated in place.

        Returns:
            The same ``data`` object, whose ``edge_index`` is an int64 tensor
            of shape ``(2, num_nodes ** 2)``. Edges are ordered source-major:
            ``(0, 0), (0, 1), ..., (0, n - 1), (1, 0), ...``. For a graph
            with zero nodes the result has shape ``(2, 0)``.
        """
        num_nodes = data.num_nodes
        node_ids = torch.arange(num_nodes, dtype=torch.long)
        # Row 0 repeats each source id num_nodes times in a row; row 1 cycles
        # through all target ids once per source. Together they enumerate the
        # Cartesian product in the same order as itertools.product, without
        # materialising num_nodes ** 2 Python lists first.
        sources = node_ids.repeat_interleave(num_nodes)
        targets = node_ids.repeat(num_nodes)
        # Stacking two 1-D tensors always yields a 2-row tensor, so the empty
        # graph produces shape (2, 0) instead of failing on a 1-D transpose.
        data['edge_index'] = torch.stack([sources, targets], dim=0)
        return data
