import torch

from utils.config import cfg
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops

class FullyAdjacentTransform(BaseTransform):
    """Attach a fully connected edge set to a graph as ``expander_edge_index``.

    When ``cfg.use_fa`` is enabled, every ordered pair of distinct nodes
    ``(i, j)`` with ``i != j`` becomes a column of
    ``data.expander_edge_index`` (shape ``[2, N * (N - 1)]`` for ``N`` nodes).
    If the graph carries ``edge_attr``, an all-zeros
    ``data.expander_edge_attr`` with one row per new edge is added as well.
    The input ``data`` object is modified in place and returned.
    """

    # Datasets handled here as having no node features: when ``data.x`` is
    # missing they receive a constant 1-dimensional feature of ones.
    _FEATURELESS_DATASETS = ('COLLAB', 'REDDIT-BINARY', 'IMDB-BINARY')

    def __init__(self):
        # Zero-argument super() binds to ``self`` so BaseTransform.__init__
        # actually runs.
        super().__init__()

    def __call__(self, data):
        """Add the fully connected edges to ``data`` and return it.

        Args:
            data: A PyG-style graph object exposing ``num_nodes``, ``x`` and
                optionally ``edge_attr``.

        Returns:
            The same ``data`` object with ``expander_edge_index`` set (and
            ``expander_edge_attr`` when ``edge_attr`` is present). Returned
            unchanged when ``cfg.use_fa`` is falsy.
        """
        if not cfg.use_fa:
            return data

        num_nodes = data.num_nodes
        if cfg.dataset.name in self._FEATURELESS_DATASETS and data.x is None:
            data.x = torch.ones((num_nodes, 1))

        all_nodes = torch.arange(0, num_nodes)
        # All N*N ordered node pairs as a [2, N*N] index tensor.
        fully_edge_index = torch.cartesian_prod(all_nodes, all_nodes).T
        # Drop the N diagonal pairs (i, i).
        fully_edge_index_no_sl, _ = remove_self_loops(fully_edge_index)

        # PyG mini-batching offsets attributes whose name contains "index"
        # by the node count (as it does for edge_index), so this key collates
        # correctly across graphs in a batch.
        data.expander_edge_index = fully_edge_index_no_sl

        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None:
            # Zero features for the new edges. new_zeros matches edge_attr's
            # dtype and device; slicing the trailing shape keeps the same
            # feature layout and also supports 1-D edge_attr.
            data.expander_edge_attr = edge_attr.new_zeros(
                (fully_edge_index_no_sl.shape[1], *edge_attr.shape[1:])
            )

        return data
