from typing import Tuple, Optional
import os.path as osp
import numpy as np
import scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F
from .utils import MixedDropout
from torch_geometric.nn.dense.linear import Linear
import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor


class ReLUWithIdx(nn.Module):
    """ReLU that operates on a ``(tensor, index)`` pair.

    Only the tensor is rectified; the index object is handed back unchanged,
    so this module can sit in an ``nn.Sequential`` that threads indices along.
    """

    def forward(
            self, t_and_idx: Tuple[torch.Tensor, torch.LongTensor]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        activations, indices = t_and_idx
        rectified = F.relu(activations)
        return (rectified, indices)


def calc_laplacian(adj_matrix: sp.spmatrix) -> sp.spmatrix:
    """Return the combinatorial Laplacian ``D - A`` of an adjacency matrix.

    ``D`` is the diagonal matrix holding the row sums of ``A`` (the weighted
    degrees), so every row of the returned matrix sums to zero.
    """
    degrees = np.asarray(np.sum(adj_matrix, axis=1)).reshape(-1)
    laplacian = sp.diags(degrees) - adj_matrix
    return laplacian


def calc_norm(mat) -> float:
    """Return the 2-norm of ``mat``.

    For a 2-D array this is the spectral norm (largest singular value); for a
    1-D array it is the ordinary Euclidean norm.
    """
    norm_value = np.linalg.norm(mat, ord=2)
    return norm_value


def calc_resolvent_exact(
        t_matrix: sp.spmatrix,
        omega: float = 1,
        singular_value_normalization: bool = False,
        normalizing_factor: float = 1.0
) -> np.ndarray:
    """Return the dense resolvent ``inv(L' - omega * I)`` of a graph.

    ``L'`` is the Laplacian ``L = D - A`` of ``t_matrix`` (via
    ``calc_laplacian``), optionally divided by its spectral norm, and then
    divided by ``normalizing_factor``.

    Args:
        t_matrix: Sparse square adjacency matrix ``A``. Despite the name, the
            Laplacian is derived from it inside this function.
        omega: Diagonal shift. Every row of ``L`` sums to zero, so
            ``omega == 0`` always gives a singular matrix.
        singular_value_normalization: If truthy, scale ``L`` by
            ``1 / ||L||_2`` before applying ``normalizing_factor``. The scaling
            is skipped when ``L`` is all zeros (edgeless graph), whose norm is 0.
        normalizing_factor: Divisor applied to the (possibly normalised) ``L``.

    Returns:
        Dense ``(n, n)`` NumPy array, ``n = t_matrix.shape[0]``.

    Raises:
        numpy.linalg.LinAlgError: If ``L' - omega * I`` is singular.
    """
    nnodes = t_matrix.shape[0]
    laplacian = calc_laplacian(t_matrix).toarray()

    # Truthiness (not `is True`) so numpy booleans and other truthy flags work.
    if singular_value_normalization:
        spectral_norm = calc_norm(laplacian)
        # An all-zero Laplacian would otherwise become 0/0 = NaN everywhere.
        if spectral_norm > 0:
            laplacian = laplacian / spectral_norm

    laplacian = laplacian / normalizing_factor

    # NOTE(review): for a symmetric non-negative adjacency, L / ||L||_2 has
    # largest eigenvalue exactly 1, so singular_value_normalization combined
    # with this function's default omega=1 gives a (near-)singular matrix.
    return np.linalg.inv(laplacian - omega * np.eye(nnodes))


class ResolventPropagate(nn.Module):
    """Propagate features by left-multiplying them with a fixed dense matrix.

    The matrix is taken from ``resolvent_mat`` when given; otherwise it is
    computed from ``adj_matrix`` with ``calc_resolvent_exact``. It is stored
    as the persistent buffer ``mat`` in ``torch.get_default_dtype()``, so it
    moves with ``.to(...)`` and is included in ``state_dict()``.

    Args:
        adj_matrix: Adjacency matrix; only used when ``resolvent_mat`` is None.
        omega: Diagonal shift forwarded to ``calc_resolvent_exact``.
        drop_prob: ``None`` or ``0`` disables dropout; otherwise
            ``MixedDropout(drop_prob)`` is applied to the matrix on every
            forward call.
        singular_value_normalization: Forwarded to ``calc_resolvent_exact``.
        normalizing_factor: Forwarded to ``calc_resolvent_exact``.
        resolvent_mat: Optional precomputed matrix (NumPy array or tensor).
    """

    def __init__(
            self,
            adj_matrix: sp.spmatrix,
            omega: float = -1.,
            drop_prob: Optional[float] = None,
            singular_value_normalization: bool = False,
            normalizing_factor: float = 1.0,
            resolvent_mat=None):
        super().__init__()

        if resolvent_mat is None:
            resolvent_mat = calc_resolvent_exact(
                t_matrix=adj_matrix,
                omega=omega,
                singular_value_normalization=singular_value_normalization,
                normalizing_factor=normalizing_factor)
        # as_tensor accepts both ndarrays and tensors (from_numpy only the former).
        self.register_buffer(
            'mat',
            torch.as_tensor(resolvent_mat, dtype=torch.get_default_dtype()))

        if drop_prob is None or drop_prob == 0:
            # A module rather than a lambda keeps this object picklable
            # (e.g. for torch.save of a whole model).
            self.dropout = nn.Identity()
        else:
            self.dropout = MixedDropout(drop_prob)

    def forward(self, predictions: torch.FloatTensor):
        """Return ``dropout(mat) @ predictions``.

        ``predictions`` presumably has one row per node, matching ``mat``'s
        column count — TODO confirm against callers.
        """
        return self.dropout(self.mat) @ predictions


class ResolventConvLayerViaMatMul(nn.Module):
    """Graph convolution summing linear maps of repeatedly propagated features.

    With ``R`` the dense matrix held by ``ResolventPropagate`` (computed from
    ``adj_matrix`` or loaded from ``resolvent_path``), the layer computes::

        out = lin_zero(x)                              # only if zero_order
              + sum_{k=1..K_minus} lins_minus[k-1](R^k @ x)
              + bias                                   # only if bias

    ``R`` is passed through the propagation dropout on every application.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        zero_order: Whether to include the ``lin_zero(x)`` term.
        K_minus: Number of propagation steps; must be positive.
        adj_matrix: Adjacency matrix used when the resolvent is not loaded
            from ``resolvent_path``.
        omega: Diagonal shift for the resolvent computation.
        normalizing_factor_minus: Divisor applied to the Laplacian before
            inversion.
        singular_value_normalization: Whether the Laplacian is first scaled
            by its spectral norm.
        drop_prob: Dropout probability applied to the propagation matrix.
        bias: Whether to add a learnable bias.
        batch: Unused; accepted for interface compatibility.
        resolvent_path: Optional cache file for the resolvent. If it (or the
            same name with ``.npy`` appended) exists it is loaded; otherwise
            the resolvent is computed and saved there with ``np.save``.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        zero_order: bool,
        K_minus: int,
        adj_matrix: sp.spmatrix,
        omega: float = -1.,
        normalizing_factor_minus: float = 1,
        singular_value_normalization: bool = False,
        drop_prob: float = 0.0,
        bias: bool = True,
        batch: OptTensor = None,
        resolvent_path: Optional[str] = None,
        **kwargs,
    ):
        # Kept as asserts so the exception type seen by callers is unchanged.
        assert zero_order in (True, False)
        assert K_minus > 0

        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.zero_order = zero_order
        self.K_minus = K_minus

        self.lin_zero = Linear(
            in_channels, out_channels, bias=False, weight_initializer='glorot')

        self.lins_minus = torch.nn.ModuleList([
            Linear(
                in_channels, out_channels, bias=False,
                weight_initializer='glorot') for _ in range(K_minus)
        ])

        if bias:
            self.bias = Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        # Look for a cached resolvent: first the path exactly as given, then
        # the '.npy' name that np.save actually writes to.
        resolvent_mat = None
        cache_path = None
        if resolvent_path is not None:
            cache_path = self._npy_path(resolvent_path)
            for candidate in (resolvent_path, cache_path):
                if osp.exists(candidate):
                    resolvent_mat = np.load(candidate)
                    break

        self.propagate = ResolventPropagate(
            adj_matrix,
            omega=omega,
            drop_prob=drop_prob,
            singular_value_normalization=singular_value_normalization,
            normalizing_factor=normalizing_factor_minus,
            resolvent_mat=resolvent_mat)

        if cache_path is not None and resolvent_mat is None:
            # Freshly computed: persist the stored (default-dtype) matrix so
            # later constructions can skip the dense inversion.
            np.save(cache_path, self.propagate.mat.detach().cpu().numpy())

        self.reset_parameters()

    @staticmethod
    def _npy_path(path) -> str:
        """Return the file name ``np.save`` writes to for ``path``."""
        import os

        path = os.fspath(path)
        # np.save appends '.npy' to names lacking it; checking the bare name
        # alone would never find the cache it wrote.
        return path if path.endswith('.npy') else path + '.npy'

    def reset_parameters(self):
        """Re-initialise all linear layers and zero the bias (if present)."""
        for lin in self.lins_minus:
            lin.reset_parameters()
        self.lin_zero.reset_parameters()
        zeros(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer to node features ``x``.

        ``x`` is left-multiplied by the propagation matrix, so its first
        dimension presumably indexes nodes — TODO confirm against callers.
        """
        # Sub-modules are invoked via __call__ rather than .forward() so that
        # module hooks run. PyG's Linear is understood to do lazy
        # (in_channels=-1) initialisation in a forward pre-hook; confirm for
        # the installed version.
        out = self.lin_zero(x) if self.zero_order else None

        propagated = x
        for lin in self.lins_minus:
            # One more application of the propagation matrix per order.
            propagated = self.propagate(propagated)
            term = lin(propagated)
            out = term if out is None else out + term

        if out is None:
            # Reachable only when the K_minus > 0 assert was stripped
            # (python -O) and zero_order is False: keep an all-zeros result.
            out = torch.zeros_like(self.lin_zero(x))

        if self.bias is not None:
            out = out + self.bias

        return out
