import torch.nn as nn
from torch_geometric.nn.dense.linear import Linear
import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor


class ResolventConvLayerViaMatMul_molecules(nn.Module):
    r"""Graph convolution built from powers of a matrix supplied as edge data.

    For node features :math:`\mathbf{X}` the layer computes

    .. math::
        \mathbf{X}' = [\,\mathrm{lin}_0(\mathbf{X})\,]
            + \sum_{k=1}^{K_-} \mathrm{lin}_k(\mathbf{R}^k \mathbf{X})
            + \mathbf{b},

    where the bracketed term is included only when ``zero_order`` is true and
    :math:`\mathbf{R}` is the ``num_nodes x num_nodes`` matrix whose entries
    at ``edge_index`` are ``edge_attr`` (duplicate index pairs are summed,
    every other entry is zero). ``R`` is materialised as a dense tensor and
    applied with plain matrix multiplication, so memory grows quadratically
    with the number of nodes.

    NOTE(review): judging by the name, ``edge_attr`` presumably holds entries
    of a precomputed graph resolvent (e.g. ``(L - z I)^{-1}``) for small
    molecular graphs. Confirm against the data pipeline.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        zero_order: Whether to add the ``lin_zero(x)`` term. Must be boolean
            (``0``/``1`` are accepted and converted to ``bool``).
        K_minus: Number of powers ``R^1 .. R^K_minus`` to apply, each with
            its own linear map; ``0`` disables them.
        bias: If ``True``, add a learnable per-channel bias (reset to zero).

    Raises:
        ValueError: If ``zero_order`` is not boolean-like.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        zero_order: bool,
        K_minus: int,
        bias: bool = True,
    ):
        # Validate with an exception rather than `assert`, which is removed
        # when Python runs with -O.
        if zero_order not in (True, False):
            raise ValueError(
                f"zero_order must be True or False, got {zero_order!r}")

        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Store a plain bool so that e.g. 1 or numpy.bool_(True) behave
        # exactly like True in forward().
        self.zero_order = bool(zero_order)
        self.K_minus = K_minus

        # Created unconditionally (it is unused when zero_order is False) so
        # the parameter set / state_dict keys do not depend on zero_order.
        self.lin_zero = Linear(
            in_channels, out_channels, bias=False, weight_initializer='glorot')

        # One linear map per power R^1 .. R^K_minus.
        self.lins_minus = torch.nn.ModuleList([
            Linear(
                in_channels, out_channels, bias=False,
                weight_initializer='glorot') for _ in range(K_minus)
        ])

        if bias:
            self.bias = Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear map and zero the bias (if present)."""
        for lin in self.lins_minus:
            lin.reset_parameters()
        self.lin_zero.reset_parameters()
        zeros(self.bias)

    def forward(self, x: Tensor,
                edge_index: Tensor,
                edge_attr: Tensor = None,
                ) -> Tensor:
        """Apply the layer.

        Args:
            x: Node features with nodes along dim ``-2`` and channels along
                dim ``-1``, e.g. ``(num_nodes, in_channels)``.
            edge_index: ``(2, num_entries)`` row/column indices of the
                non-zero entries of ``R``.
            edge_attr: ``(num_entries,)`` values of those entries, or
                ``None`` to use ``1`` for every listed entry.

        Returns:
            Tensor with the leading shape of ``x`` and ``out_channels``
            channels in the last dimension.
        """
        out = self.lin_zero(x) if self.zero_order else None

        if len(self.lins_minus) > 0:
            num_nodes = x.size(-2)
            if edge_attr is None:
                # Unweighted case: every listed entry of R equals 1.
                edge_attr = x.new_ones(edge_index.size(1))
            # Pass the size explicitly. If it were inferred from the largest
            # index, nodes without entries at the end would shrink R below
            # num_nodes x num_nodes and the product with x would fail.
            resolvent_mat = torch.sparse_coo_tensor(
                edge_index, edge_attr, size=(num_nodes, num_nodes),
            ).to_dense()

            h = x
            for lin in self.lins_minus:
                h = resolvent_mat @ h  # h = R^k x for k = 1, 2, ...
                # Call the module (not `.forward`) so registered hooks run.
                term = lin(h)
                out = term if out is None else out + term

        if out is None:
            # No zero-order term and no powers: produce zeros with the same
            # shape, dtype and device the linear layers would give.
            out = torch.zeros_like(self.lin_zero(x))

        if self.bias is not None:
            out = out + self.bias

        return out
