from typing import Type, Callable
import torch
import torch.nn as nn

from .gcn import GCNConv


class FakeBilin(nn.Module):
    """Lightweight stand-in for ``nn.Bilinear`` that only rescales messages.

    Each edge gets a single scalar factor, obtained by projecting its feature
    vector onto a learned weight vector (initialised to ones). Every message
    row is multiplied by the factor of its edge.
    """

    def __init__(self, edge_feat_size: int, device: torch.device):
        super().__init__()
        initial = torch.empty(edge_feat_size)
        self.bilin_weights = nn.Parameter(initial)
        self.reset()
        self.to(device)

    def reset(self):
        """Reset the projection weights to all ones."""
        nn.init.ones_(self.bilin_weights)

    def forward(self, x_j, norm):
        """Scale the messages ``x_j`` by a learned projection of ``norm``.

        ``norm`` is multiplied with the weight vector via ``torch.mv``, which
        yields one scalar per leading row; that scalar is broadcast across all
        trailing dimensions of ``x_j``.
        """
        edge_scale = torch.mv(norm, self.bilin_weights)
        # Shape (E, 1, ..., 1) so the scale lines up with x_j's first axis.
        trailing_ones = [1] * (x_j.dim() - 1)
        return x_j * edge_scale.reshape(-1, *trailing_ones)


class Net(nn.Module):
    """Stack of ``GCNConv`` layers with residual connections and a pluggable readout.

    The first layer maps ``node_feat_size`` -> ``emb_size``; the remaining
    ``nlayers - 1`` layers map ``emb_size`` -> ``emb_size``. A single
    edge-feature transform (``self.bilin``) is built once and handed to every
    layer.

    Args:
        node_feat_size: Width of the input node features (``data.x``).
        edge_feat_size: Width of the edge features (``data.edge_attr``).
        emb_size: Output width of every layer.
        nlayers: Number of GCN layers. Values below 1 still build (and now
            register) the input layer.
        aggregate: Called as ``aggregate(node_embeddings)`` at the end of
            ``forward``; its result is what ``forward`` returns.
        device: Device all parameters are moved to.
        dropout: Dropout rate forwarded to each ``GCNConv``.
        act_fn: Activation forwarded to each ``GCNConv`` (identity by default).
        use_linear: Forwarded to each ``GCNConv``.
        bilin_type: Edge-feature transform: ``'full'`` (``nn.Bilinear``),
            ``'scale'`` (``FakeBilin``) or ``'none'`` (edge features ignored).
        gcn_aggr: Aggregation scheme forwarded to each ``GCNConv`` as ``aggr``.
        agg_degree_mean: ``agg_mean`` flag for the hidden/output layers; the
            input layer always uses ``agg_mean=False``.

    Raises:
        ValueError: If ``bilin_type`` is not ``'full'``, ``'scale'`` or ``'none'``.
    """

    def __init__(
            self, node_feat_size: int, edge_feat_size: int,
            emb_size: int, nlayers: int, aggregate: Type[nn.Module],
            device: torch.device, dropout: float = 0.5,
            act_fn: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
            use_linear: bool = True, bilin_type: str = 'full',
            gcn_aggr: str = 'add', agg_degree_mean: bool = False):
        super().__init__()
        self.emb_size = emb_size
        self.device = device
        self.act_fn = act_fn
        self.aggregate = aggregate

        # One edge-feature transform shared by every layer below.
        if bilin_type == 'full':
            self.bilin = nn.Bilinear(emb_size, edge_feat_size, emb_size, bias=False)
        elif bilin_type == 'scale':
            self.bilin = FakeBilin(edge_feat_size, device)
        elif bilin_type == 'none':
            self.bilin = None
        else:
            # Fail fast: otherwise self.bilin is never set and construction
            # dies later with an unrelated-looking AttributeError.
            raise ValueError(
                f"bilin_type must be 'full', 'scale' or 'none', got {bilin_type!r}")

        # Input layer (never uses degree-mean aggregation).
        self.layers = [GCNConv(node_feat_size, emb_size, act_fn, dropout, self.bilin,
                               use_linear=use_linear, aggr=gcn_aggr, agg_mean=False)]

        # Hidden & output layers.
        for _ in range(nlayers - 1):
            self.layers.append(GCNConv(emb_size, emb_size, act_fn, dropout, self.bilin,
                                       use_linear=use_linear, aggr=gcn_aggr,
                                       agg_mean=agg_degree_mean))

        # Register every layer as a submodule. A plain list plus add_module is
        # kept (instead of nn.ModuleList) so state_dict keys stay "layer_{i}.*".
        # Iterating the list itself, not range(nlayers), guarantees the input
        # layer is registered even when nlayers < 1.
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer_{i}", layer)

        self.to(self.device)

    def forward(self, data):
        """Run all layers and aggregate the collected node embeddings.

        Args:
            data: Graph object exposing ``x``, ``edge_index`` and ``edge_attr``
                attributes (presumably a PyG ``Data``/``Batch`` — TODO confirm).

        Returns:
            ``self.aggregate(node_embeddings)``, where ``node_embeddings`` holds
            the first layer's ``y`` output followed by the residual state
            ``x + y`` after every layer.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        if self.bilin is None:
            # No edge-feature transform configured: layers get no edge weights.
            edge_attr = None

        node_embeddings = []
        for ilayer, layer in enumerate(self.layers):
            # Each GCNConv returns a pair; their sum is the residual update.
            y, x = layer(x, edge_index, edge_weight=edge_attr)
            if ilayer == 0:
                node_embeddings.append(y)
            x = x + y
            node_embeddings.append(x)

        return self.aggregate(node_embeddings)
