import torch
from typing import Optional
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from greatx.nn.layers import Sequential
from torch_geometric.typing import (
    Adj,
    OptTensor,
    SparseTensor,
)
from torch_geometric.nn.conv.gcn_conv import gcn_norm


class ProjConv(MessagePassing):
    r"""Parameter-free graph convolution: propagates node features along the
    graph (optionally normalized with :func:`gcn_norm`) without any linear
    transformation or bias.

    Args:
        improved (bool): Forwarded to :func:`gcn_norm` as its ``improved``
            flag. (default: :obj:`False`)
        cached (bool): If :obj:`True`, the normalized graph computed on the
            first call to :meth:`forward` is stored and reused on every later
            call, ignoring the ``edge_index``/``edge_weight`` passed then.
            Only use this for a fixed graph; :meth:`reset_parameters` clears
            the cache. (default: :obj:`False`)
        add_self_loops (bool): Forwarded to :func:`gcn_norm`.
            (default: :obj:`True`)
        normalize (bool): Whether to normalize the graph with
            :func:`gcn_norm` before propagating. (default: :obj:`True`)
        **kwargs: Extra arguments for :class:`MessagePassing`; ``aggr``
            defaults to ``'add'``.
    """

    def __init__(self, improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 **kwargs):

        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        # Normalized-graph caches, filled by `forward` only when `cached=True`:
        # an (edge_index, edge_weight) tuple for Tensor inputs, and the
        # normalized SparseTensor for SparseTensor inputs.
        self._cached_edge_index = None
        self._cached_adj_t = None
        self.reset_parameters()

    def reset_parameters(self):
        """Clear the cached normalized graph (this layer has no weights)."""
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Propagate ``x`` over the graph.

        Args:
            x: Node features; ``x.size(self.node_dim)`` is used as the
                number of nodes when normalizing.
            edge_index: Graph connectivity, either a Tensor or a
                :class:`SparseTensor`.
            edge_weight: Optional per-edge weights.

        Returns:
            The aggregated node features.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    # Indexed (not unpacked) to stay TorchScript-friendly.
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Return the source-node features, scaled by the edge weight if any.

        ``edge_weight`` is reshaped to a column ``(-1, 1)`` before the
        multiplication.
        """
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused message passing for :class:`SparseTensor` adjacency."""
        # Imported locally so this works on PyG releases that ship
        # `torch_geometric.utils.spmm` and on older ones that only provide
        # `torch_sparse.matmul` (same `reduce=` keyword).
        try:
            from torch_geometric.utils import spmm
        except ImportError:
            from torch_sparse import matmul as spmm
        return spmm(adj_t, x, reduce=self.aggr)


class Propagate(torch.nn.Module):
    r"""Propagation model: a stack of :class:`ProjConv` layers, similar to
    GCN but without any linear transformation.

    Args:
        norm (bool): Passed to every :class:`ProjConv` as ``normalize``.
            (default: :obj:`True`)
        nlayers (int): Number of stacked :class:`ProjConv` layers.
            (default: :obj:`2`)
        **kwargs: Extra keyword arguments forwarded to every
            :class:`ProjConv` (e.g. ``add_self_loops``, ``improved``).
    """

    def __init__(self, norm: bool = True, nlayers: int = 2, **kwargs):
        super().__init__()
        layers = [ProjConv(normalize=norm, **kwargs) for _ in range(nlayers)]
        self.conv = Sequential(*layers)

    def reset_parameters(self):
        """Delegate to the ``Sequential`` container's ``reset_parameters``."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Run the layer stack on ``(x, edge_index, edge_weight)``.

        How the inputs are threaded between layers is decided by greatx's
        ``Sequential`` container.
        """
        return self.conv(x, edge_index, edge_weight)

