import torch
from torch_geometric.data import Data

from federatedscope.core.mlp import MLP
from federatedscope.gfl.model.gcn import GCN_Net
from federatedscope.gfl.model.sage import SAGE_Net
from federatedscope.gfl.model.gat import GAT_Net
from federatedscope.gfl.model.gin import GIN_Net
from federatedscope.gfl.model.gpr import GPR_Net


class GNN_Net_Link(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 max_depth=2,
                 dropout=.0,
                 gnn='gcn',
                 layers=2):
        r"""GNN model with LinkPredictor for link prediction tasks.

        Arguments:
            in_channels (int): input channels.
            out_channels (int): output channels.
            hidden (int): hidden dim for all modules.
            max_depth (int): number of layers for gnn.
            dropout (float): dropout probability.
            gnn (str): name of gnn type, use ("gcn" or "gin").
            layers (int): number of layers for LinkPredictor.

        """
        super(GNN_Net_Link, self).__init__()
        self.dropout = dropout

        # GNN layer
        if gnn == 'gcn':
            self.gnn = GCN_Net(in_channels=in_channels,
                               out_channels=hidden,
                               hidden=hidden,
                               max_depth=max_depth,
                               dropout=dropout)
        elif gnn == 'sage':
            self.gnn = SAGE_Net(in_channels=in_channels,
                                out_channels=hidden,
                                hidden=hidden,
                                max_depth=max_depth,
                                dropout=dropout)
        elif gnn == 'gat':
            self.gnn = GAT_Net(in_channels=in_channels,
                               out_channels=hidden,
                               hidden=hidden,
                               max_depth=max_depth,
                               dropout=dropout)
        elif gnn == 'gin':
            self.gnn = GIN_Net(in_channels=in_channels,
                               out_channels=hidden,
                               hidden=hidden,
                               max_depth=max_depth,
                               dropout=dropout)
        elif gnn == 'gpr':
            self.gnn = GPR_Net(in_channels=in_channels,
                               out_channels=hidden,
                               hidden=hidden,
                               K=max_depth,
                               dropout=dropout)
        else:
            raise ValueError(f'Unsupported gnn type: {gnn}.')

        dim_list = [hidden for _ in range(layers)]
        self.output = MLP([hidden] + dim_list + [out_channels],
                          batch_norm=True)

    def forward(self, data):
        if isinstance(data, Data):
            x, edge_index = data.x, data.edge_index
        elif isinstance(data, tuple):
            x, edge_index = data
        else:
            raise TypeError('Unsupported data type!')

        x = self.gnn((x, edge_index))
        return x

    def link_predictor(self, x, edge_index):
        x = x[edge_index[0]] * x[edge_index[1]]
        x = self.output(x)
        return x
