import torch
import torch.nn as nn

from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from asdfghjkl.operations import FullBias

from .utils import (
    add_self_loop_adj_sparse,
    get_deg_adj_sparse,
    normalize_adj,
    normalize_adj_sparse
)


class GraphSAGEConv(nn.Module):
    """GraphSAGE-style convolution with degree-normalised (mean) aggregation.

    The layer computes ``lin1(aggregate(x)) + lin2(x)``. ``aggregate``
    divides each node's adjacency-weighted neighbour sum by its degree.
    ``lin1`` optionally carries a bias; ``lin2`` (the self/root transform)
    never does.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        bias: If ``True``, ``lin1`` learns an additive bias.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 bias: bool = True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Transform applied to the aggregated neighbour features.
        self.lin1 = nn.Linear(in_channels, out_channels, bias=bias)
        # Self (root) transform; bias-free so only lin1 contributes a bias.
        self.lin2 = nn.Linear(in_channels, out_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear transforms (lin1 first, then lin2)."""
        for layer in (self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, x, adj):
        """Mean-aggregate neighbour features and add the transformed self features.

        Args:
            x: Node feature matrix; assumed (num_nodes, in_channels) -- TODO confirm.
            adj: Adjacency matrix, given either as a sparse tensor
                (``adj.is_sparse``) or as a dense tensor.

        Returns:
            ``lin1(neighbour_sum / degree) + lin2(x)``.
        """
        if not adj.is_sparse:
            summed = adj @ x
            # Degrees below 1 (e.g. isolated nodes) are clamped to 1 so the
            # division below never divides by zero.
            counts = adj.sum(dim=1, keepdim=True).clamp(min=1)
        else:
            # NOTE(review): only the sparse path adds self-loops before
            # aggregating; the dense path does not. The same graph can
            # therefore produce different outputs depending on its storage
            # format. Confirm which behaviour is intended.
            adj_hat = add_self_loop_adj_sparse(adj)
            summed = torch.sparse.mm(adj_hat, x)
            counts = get_deg_adj_sparse(adj_hat).clamp(min=1).unsqueeze(1)
        mean_neighbors = summed / counts

        return self.lin1(mean_neighbors) + self.lin2(x)


class GCNConv(nn.Module):
    """Graph convolution computing ``normalize(adj) @ lin(x)``, plus an optional bias.

    The linear projection is applied to ``x`` before propagation. The bias is
    kept out of ``nn.Linear`` and applied by a separate ``FullBias`` module
    after propagation.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        bias: If ``True``, a ``FullBias`` module is applied to the output;
            otherwise ``_bias`` is registered as a ``None`` parameter.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 bias: bool = True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Bias-free: any bias is applied separately via FullBias below.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        if not bias:
            self.register_parameter('_bias', None)
        else:
            self._bias = FullBias(out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for the projection; reset the bias module if present."""
        nn.init.xavier_uniform_(self.lin.weight)
        # NOTE(review): relies on FullBias exposing reset_parameters() -- confirm.
        if self._bias is not None:
            self._bias.reset_parameters()

    def forward(self, x, adj):
        """Propagate projected features over the normalised adjacency.

        Args:
            x: Node feature matrix; assumed (num_nodes, in_channels) -- TODO confirm.
            adj: Adjacency matrix, given either as a sparse tensor
                (``adj.is_sparse``) or as a dense tensor.

        Returns:
            ``normalize(adj) @ lin(x)``, passed through the bias module when
            one is configured.
        """
        # Presumably both helpers apply the same normalisation (e.g. symmetric
        # D^-1/2 A D^-1/2) for their storage format -- confirm in .utils.
        normalize = normalize_adj_sparse if adj.is_sparse else normalize_adj
        out = normalize(adj) @ self.lin(x)
        if self._bias is not None:
            # FullBias is applied as a callable on the output; presumably it
            # adds a learnable per-channel bias -- confirm in asdfghjkl.
            out = self._bias(out)
        return out
