from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import get_laplacian
from .utils import calc_norm, get_resolvent

from typing import Optional, Tuple

class HoloConv(MessagePassing):
    r"""Graph convolution mixing powers of the Laplacian and of its resolvent.

    The graph structure is fixed at construction time. Two matrices are
    precomputed from ``edge_index`` and stored as buffers:

    * ``edges`` / ``norm``: the Laplacian ``L`` returned by ``get_laplacian``
      with ``normalization=None``. Its weights are optionally divided by
      ``calc_norm(edges, weights)`` and then by ``normalizing_factor_plus``.
      The result is written ``L_hat`` below.
    * ``R_edges`` / ``R_norm``: the resolvent ``R = (L' - omega * I)^{-1}``.
      Here ``L'`` is the *unscaled* ``L``, optionally divided by ``calc_norm``
      and then by ``normalizing_factor_minus``.

    ``forward`` returns

    .. math::
        \sum_{k=0}^{K_+ - 1} \mathrm{lin}^+_k(\hat{L}^k x)
        + \sum_{k=1}^{K_-} \mathrm{lin}^-_{k-1}(R^k x) + b

    Each matrix power is applied as one ``propagate`` hop with the stored
    edge weights.

    NOTE(review): for a symmetric (undirected) Laplacian a hop equals a
    matrix product. For directed input the orientation depends on
    ``MessagePassing``'s ``flow`` setting — confirm this is intended.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        K_plus: Number of Laplacian terms (powers ``0 .. K_plus - 1``).
        K_minus: Number of resolvent terms (powers ``1 .. K_minus``).
        edge_index: Graph connectivity used to build both operators.
        omega: Shift used in the resolvent ``(L' - omega * I)^{-1}``.
        L_singular_value_normalization: If truthy, divide the Laplacian
            weights by ``calc_norm`` of the Laplacian. Presumably this is a
            spectral / singular-value norm; see ``.utils`` to confirm.
        R_singular_value_normalization: Same as above, but for the matrix
            that is inverted to form the resolvent.
        normalizing_factor_plus: Extra divisor for the Laplacian weights.
            Skipped when ``None``.
        normalizing_factor_minus: Extra divisor applied before inversion.
            Skipped when ``None``.
        bias: If ``True``, add a learnable bias (initialized to zero).
        edge_weight: Optional weights for ``edge_index``.
        batch: Accepted for API compatibility; currently unused.
        num_nodes: Number of nodes. Defaults to ``edge_index.max() + 1``.
            Pass it explicitly when the highest-indexed nodes are isolated.
        **kwargs: Forwarded to ``MessagePassing``. ``aggr`` defaults to
            ``'add'``.

    Raises:
        ValueError: If ``K_plus + K_minus`` is not positive.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K_plus: int,
        K_minus: int,
        edge_index: Tensor,
        omega=-1.,
        L_singular_value_normalization=False,
        R_singular_value_normalization=False,
        normalizing_factor_plus=1,
        normalizing_factor_minus=1,
        bias: bool = True,
        edge_weight=None,
        batch: OptTensor = None,
        num_nodes: Optional[int] = None,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        # Explicit check instead of `assert`, which is stripped under -O.
        if K_plus + K_minus <= 0:
            raise ValueError(
                f'K_plus + K_minus must be positive, got '
                f'K_plus={K_plus}, K_minus={K_minus}')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.L_singular_value_normalization = L_singular_value_normalization
        self.K_plus = K_plus
        self.K_minus = K_minus

        if K_plus > 0:
            self.lins_plus = torch.nn.ModuleList([
                Linear(in_channels, out_channels, bias=False,
                       weight_initializer='glorot') for _ in range(K_plus)
            ])
        if K_minus > 0:
            self.lins_minus = torch.nn.ModuleList([
                Linear(in_channels, out_channels, bias=False,
                       weight_initializer='glorot') for _ in range(K_minus)
            ])

        if bias:
            self.bias = Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        if num_nodes is None:
            # Use both rows: a node may appear only as a target. A single
            # tensor reduction replaces a Python-level scan of the tensor.
            # NOTE: one fixed graph is baked in, so batches whose structure
            # differs from `edge_index` are not supported.
            num_nodes = int(edge_index.max()) + 1

        # ---- Laplacian ------------------------------------------------------
        L_edges, L_weights = self.__norm__(
            edge_index, num_nodes, edge_weight, batch=batch)

        # Keep the unscaled weights: the resolvent applies its own scaling.
        L_norm = L_weights
        if L_singular_value_normalization:
            L_norm = L_norm / calc_norm(L_edges, L_norm)
        if normalizing_factor_plus is not None:
            L_norm = L_norm / normalizing_factor_plus
        self.register_buffer('edges', L_edges)
        self.register_buffer('norm', L_norm)

        # ---- Resolvent ------------------------------------------------------
        R_edges, R_norm = self.get_resolvent(
            L_edges,
            L_weights,
            omega=omega,
            R_singular_value_normalization=R_singular_value_normalization,
            normalizing_factor=normalizing_factor_minus,
        )
        self.register_buffer('R_edges', R_edges)
        self.register_buffer('R_norm', R_norm)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every linear map and zero the bias (if any)."""
        if self.K_plus > 0:
            for lin in self.lins_plus:
                lin.reset_parameters()
        if self.K_minus > 0:
            for lin in self.lins_minus:
                lin.reset_parameters()
        zeros(self.bias)

    def get_resolvent(
        self,
        L_unnormalized_edges: Tensor,
        L_unnormalized_weights: Tensor,
        omega=-1.,
        R_singular_value_normalization: Optional[bool] = False,
        dtype: Optional[torch.dtype] = None,
        num_nodes: Optional[int] = None,
        normalizing_factor: Optional[float] = None,
    ) -> Tuple[Tensor, OptTensor]:
        """Return the resolvent ``(L' - omega * I)^{-1}`` in COO form.

        ``L'`` is the matrix given by ``(L_unnormalized_edges,
        L_unnormalized_weights)``. If ``R_singular_value_normalization`` is
        truthy, it is divided by ``calc_norm`` of those entries. If
        ``normalizing_factor`` is not ``None``, it is also divided by that
        factor.

        The inverse is computed on a dense ``N x N`` matrix, so memory grows
        quadratically with the number of nodes.

        Args:
            L_unnormalized_edges: ``(2, E)`` indices of the Laplacian.
            L_unnormalized_weights: Matching Laplacian values.
            omega: Diagonal shift subtracted before inversion.
            R_singular_value_normalization: Divide by ``calc_norm`` first.
            dtype: Unused; kept for signature compatibility.
            num_nodes: Matrix size. Defaults to
                ``L_unnormalized_edges.max() + 1``.
            normalizing_factor: Optional extra divisor.

        Returns:
            ``(edge_index, edge_weight)`` of the non-zero entries of the
            inverse.
        """
        if num_nodes is None:
            num_nodes = int(L_unnormalized_edges.max()) + 1

        # Entries sharing an index are summed when densified.
        L = torch.sparse_coo_tensor(
            L_unnormalized_edges,
            L_unnormalized_weights,
            size=(num_nodes, num_nodes),
        ).to_dense()

        if R_singular_value_normalization:
            L = L / calc_norm(L_unnormalized_edges, L_unnormalized_weights)
        if normalizing_factor is not None:
            L = L / normalizing_factor

        identity = torch.eye(num_nodes, dtype=L.dtype, device=L.device)
        R = torch.linalg.inv(L - omega * identity).to_sparse()
        return R.indices(), R.values()

    def __norm__(
        self,
        edge_index: Tensor,
        num_nodes: Optional[int],
        edge_weight: OptTensor,
        dtype: Optional[torch.dtype] = None,
        batch: OptTensor = None,
    ) -> Tuple[Tensor, Tensor]:
        """Return ``(edge_index, edge_weight)`` of the graph Laplacian.

        This is a thin wrapper around ``get_laplacian`` with
        ``normalization=None``. ``batch`` is accepted but unused.
        """
        return get_laplacian(edge_index, edge_weight, normalization=None,
                             dtype=dtype, num_nodes=num_nodes)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
    ) -> Tensor:
        """Apply the Laplacian/resolvent filter to node features ``x``.

        ``edge_index`` is ignored. The operators were fixed at construction
        time and live in the ``edges``/``norm`` and ``R_edges``/``R_norm``
        buffers.
        """
        out = None

        if self.K_plus > 0:
            # Term k sees L_hat^k x for k = 0 .. K_plus - 1 (k = 0 is x itself).
            h = x
            out = self.lins_plus[0](h)
            for lin in self.lins_plus[1:]:
                h = self.propagate(self.edges, x=h, norm=self.norm, size=None)
                out = out + lin(h)

        if self.K_minus > 0:
            # Term k sees R^k x for k = 1 .. K_minus (no identity term here).
            h = x
            for lin in self.lins_minus:
                h = self.propagate(self.R_edges, x=h, norm=self.R_norm,
                                   size=None)
                term = lin(h)
                out = term if out is None else out + term

        if self.bias is not None:
            out = out + self.bias
        return out

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each neighbor feature row by its edge weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # Read the stored counts: `lins_plus` does not exist when K_plus == 0.
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K_plus={self.K_plus}, '
                f'K_minus={self.K_minus})')



