import numpy as np
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add

from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.nn import (
    GATConv,
    GCNConv,
)
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros

from libs.torch_utils import to_edge_index

from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import to_scipy_sparse_matrix

from torch_sparse import SparseTensor

from libs.filter import (
    band_pass_filter,
    gaussian_wks_filter,
    wks_filter,
)
from libs.normalize import compute_dense_A, exp_gcn_norm


class SpectralFilterGATConv(MessagePassing):
    r"""Graph attention layer applied to a spectrally filtered graph.

    Subclasses implement :meth:`filter_laplacian`, which turns the input
    ``edge_index`` into a filtered operator. On the first forward call that
    operator is converted to an edge list with ``to_edge_index`` and cached on
    the module; every later call reuses the cached edge list and ignores the
    ``edge_index`` passed in. The attention computation itself is delegated to
    a wrapped :class:`GATConv`.

    Args:
        in_channels: Input feature size, passed to ``GATConv``.
        out_channels: Output channels, passed to ``GATConv``.
        heads: Number of attention heads, passed to ``GATConv``.
        e0, sigma, cutoff: Filter parameters that subclasses forward to their
            filter function.
        improved, add_self_loops: Stored on the module; not read by this class.
        bias: Passed to ``GATConv``.
        eigen_path, eigen_cache: Forwarded by subclasses to their filter
            function (presumably eigendecomposition caching — TODO confirm).
        use_weights: If ``True``, the filtered operator's edge values are fed
            to ``GATConv`` as a one-dimensional edge feature (``edge_dim=1``),
            which adds an edge projection to the wrapped ``GATConv``.
        concat: Passed to ``GATConv``.
        sparsifier: Forwarded by subclasses to their filter function.
        **kwargs: Extra ``MessagePassing`` keyword arguments.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int,
        e0: float,
        sigma: float,
        cutoff: float,
        improved: bool = False,
        add_self_loops: bool = True,
        bias: bool = True,
        eigen_path: str = None,
        eigen_cache: bool = True,
        use_weights: bool = False,
        concat: bool = True,
        sparsifier: str = 'cutoff',
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.use_weights = use_weights
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops
        self.improved = improved
        self.eigen_path = eigen_path
        self.eigen_cache = eigen_cache

        self.cutoff = cutoff
        self.e0 = e0
        self.sigma = sigma
        # GATConv only consumes ``edge_attr`` when it was built with an
        # ``edge_dim``; without it the filtered weights would be silently
        # dropped, so request a scalar edge feature when weights are used.
        self._conv = GATConv(
            in_channels,
            out_channels,
            heads,
            bias=bias,
            concat=concat,
            edge_dim=1 if use_weights else None,
        )
        self.reset_parameters()
        # Filtered graph, built lazily on the first forward call.
        self._eidx = None
        self._ew = None
        self.sparsifier = sparsifier

    def reset_parameters(self):
        """Re-initialise the wrapped ``GATConv`` (the cached graph is kept)."""
        self._conv.reset_parameters()

    def filter_laplacian(self, edge_index):
        """Build the filtered operator from ``edge_index``; implemented by subclasses."""
        raise NotImplementedError

    def forward(
        self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None
    ) -> Tensor:
        """Apply attention over the (cached) filtered graph.

        Args:
            x: Node features passed to ``GATConv``.
            edge_index: Graph connectivity. Only read on the first call, when
                the filtered graph is built and cached.
            edge_weight: Accepted for interface compatibility; unused.

        Returns:
            The output of the wrapped ``GATConv``.
        """
        if self._eidx is None:
            device = edge_index.device
            Ltilde = self.filter_laplacian(edge_index).to(device)
            eidx, ew = to_edge_index(Ltilde)
            self._eidx = eidx.to(device)
            if self.use_weights and ew is not None:
                # Match the feature dtype so GATConv's edge projection accepts it.
                self._ew = ew.to(device=device, dtype=x.dtype)
            else:
                self._ew = None
        # ``edge_attr=None`` is GATConv's default, so this covers both modes.
        return self._conv(x, edge_index=self._eidx, edge_attr=self._ew)



class FilteredGATConv(SpectralFilterGATConv):
    """GAT layer whose graph is produced by ``band_pass_filter``."""

    def filter_laplacian(self, edge_index):
        """Return the band-pass-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        spectral_args = (self.e0, self.sigma, self.cutoff)
        eigen_args = (self.eigen_path, self.eigen_cache)
        return band_pass_filter(
            edge_index, *spectral_args, *eigen_args, sparsifier=self.sparsifier
        )


class WKSGATConv(SpectralFilterGATConv):
    """GAT layer whose graph is produced by ``wks_filter``."""

    def filter_laplacian(self, edge_index):
        """Return the WKS-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        spectral_args = (self.e0, self.sigma, self.cutoff)
        eigen_args = (self.eigen_path, self.eigen_cache)
        return wks_filter(
            edge_index, *spectral_args, *eigen_args, sparsifier=self.sparsifier
        )


class GaussianWKSGATConv(SpectralFilterGATConv):
    """GAT layer whose graph is produced by ``gaussian_wks_filter``."""

    def filter_laplacian(self, edge_index):
        """Return the Gaussian-WKS-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        spectral_args = (self.e0, self.sigma, self.cutoff)
        eigen_args = (self.eigen_path, self.eigen_cache)
        return gaussian_wks_filter(
            edge_index, *spectral_args, *eigen_args, sparsifier=self.sparsifier
        )


class SpectralFilterConv(MessagePassing):
    r"""Linear layer followed by propagation over a spectrally filtered graph.

    Subclasses implement :meth:`filter_laplacian`, which turns the input
    ``edge_index`` into a filtered operator. On the first forward call that
    operator is converted to an edge list with ``to_edge_index`` and cached on
    the module; every later call reuses the cached graph and ignores the
    ``edge_index`` passed in.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        e0, sigma, cutoff: Filter parameters that subclasses forward to their
            filter function.
        improved, add_self_loops: Stored on the module; not read by this class.
        bias: If ``True``, add a learnable bias initialised to zero.
        eigen_path, eigen_cache: Forwarded by subclasses to their filter
            function (presumably eigendecomposition caching — TODO confirm).
        use_weights: If ``True``, each message is scaled by the filtered
            operator's edge value; otherwise messages are summed unweighted.
        sparsifier: Forwarded by subclasses to their filter function.
        **kwargs: Extra ``MessagePassing`` keyword arguments; ``aggr``
            defaults to ``"add"``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        e0: float,
        sigma: float,
        cutoff: float,
        improved: bool = False,
        add_self_loops: bool = True,
        bias: bool = True,
        eigen_path: str = None,
        eigen_cache: bool = True,
        use_weights: bool = False,
        sparsifier: str = 'cutoff',
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

        self.use_weights = use_weights
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops
        self.improved = improved
        self.eigen_path = eigen_path
        self.eigen_cache = eigen_cache
        self.lin = Linear(
            in_channels, out_channels, bias=False, weight_initializer="glorot"
        )

        self.cutoff = cutoff
        self.e0 = e0
        self.sigma = sigma
        self.sparsifier = sparsifier

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels), requires_grad=True)
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        # Cached filtered graph: edge list and (optional) per-edge weights.
        self._cached_L = None
        self._cached_ew = None

    def reset_parameters(self):
        """Re-initialise the linear weight and zero the bias (cache is kept)."""
        self.lin.reset_parameters()
        zeros(self.bias)

    def filter_laplacian(self, edge_index):
        """Build the filtered operator from ``edge_index``; implemented by subclasses."""
        raise NotImplementedError

    def forward(
        self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None
    ) -> Tensor:
        """Transform ``x`` and propagate it over the (cached) filtered graph.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity. Only read on the first call, when
                the filtered graph is built and cached.
            edge_weight: Accepted for interface compatibility; unused.

        Returns:
            Aggregated node features, plus ``bias`` if enabled.
        """
        if self._cached_L is None:
            device = edge_index.device
            Ltilde = self.filter_laplacian(edge_index).to(device)
            eidx, ew = to_edge_index(Ltilde)
            self._cached_L = eidx.to(device)
            if self.use_weights and ew is not None:
                # Cast to the feature dtype so the weighted message does not
                # silently promote the output dtype.
                self._cached_ew = ew.to(device=device, dtype=x.dtype)
            else:
                self._cached_ew = None

        x = self.lin(x)
        out = self.propagate(self._cached_L, x=x, edge_weight=self._cached_ew)

        if self.bias is not None:
            out = out + self.bias
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Scale each source feature by its edge weight, if any."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


class GaussianWKSConv(SpectralFilterConv):
    """Spectral convolution over a graph produced by ``gaussian_wks_filter``.

    The redundant ``__init__`` that only forwarded to ``super().__init__`` has
    been dropped; construction is inherited unchanged from the parent.
    """

    def filter_laplacian(self, edge_index):
        """Return the Gaussian-WKS-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        return gaussian_wks_filter(
            edge_index,
            self.e0,
            self.sigma,
            self.cutoff,
            self.eigen_path,
            self.eigen_cache,
            sparsifier=self.sparsifier,
        )


class WKSConv(SpectralFilterConv):
    """Spectral convolution over a graph produced by ``wks_filter``."""

    def filter_laplacian(self, edge_index):
        """Return the WKS-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        spectral_args = (self.e0, self.sigma, self.cutoff)
        eigen_args = (self.eigen_path, self.eigen_cache)
        return wks_filter(
            edge_index, *spectral_args, *eigen_args, sparsifier=self.sparsifier
        )


class FilteredConv(SpectralFilterConv):
    """Spectral convolution over a graph produced by ``band_pass_filter``."""

    def filter_laplacian(self, edge_index):
        """Return the band-pass-filtered operator for ``edge_index``.

        The filter settings (``e0``, ``sigma``, ``cutoff``, eigen caching
        options and ``sparsifier``) are read from the layer's attributes.
        """
        spectral_args = (self.e0, self.sigma, self.cutoff)
        eigen_args = (self.eigen_path, self.eigen_cache)
        return band_pass_filter(
            edge_index, *spectral_args, *eigen_args, sparsifier=self.sparsifier
        )


class ExpConv(MessagePassing):
    r"""Linear layer followed by propagation over an ``exp_gcn_norm`` graph.

    Dense-tensor inputs are normalised with ``exp_gcn_norm``, which receives
    the learnable parameters ``lam``, ``theta0`` and ``theta1``. ``SparseTensor``
    inputs are normalised with the standard ``gcn_norm`` instead.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        local_temp: If ``True``, learn one ``lam`` value per node (requires
            ``num_nodes``); otherwise a single shared ``lam``.
        improved: Forwarded to ``compute_dense_A`` / ``exp_gcn_norm`` /
            ``gcn_norm``.
        cached: If ``True``, reuse the normalised graph from the first call.
        add_self_loops: Forwarded to the normalisation helpers.
        normalize: If ``False``, propagate over ``edge_index`` as given.
        bias: If ``True``, add a learnable bias initialised to zero.
        num_nodes: Number of nodes; only required when ``local_temp=True``.
        **kwargs: Extra ``MessagePassing`` keyword arguments; ``aggr``
            defaults to ``"add"``.

    Raises:
        ValueError: If ``local_temp`` is set but ``num_nodes`` is left at -1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        local_temp: bool = False,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True,
        normalize: bool = True,
        bias: bool = True,
        num_nodes: int = -1,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.add_self_loops = add_self_loops
        # Honour the constructor argument (it used to be hard-coded to False).
        self.improved = improved
        self.normalize = normalize
        self._cached_adj_t = None
        self._cached_edge_index = None
        self.lin = Linear(
            in_channels, out_channels, bias=False, weight_initializer="glorot"
        )

        if local_temp:
            if num_nodes == -1:
                raise ValueError(
                    "Number of nodes wasn't provided, can't initialize local temperatures"
                )
            self.lam = Parameter(torch.rand(num_nodes))
        else:
            self.lam = Parameter(torch.rand(1))

        self.theta0 = Parameter(torch.rand(1))
        self.theta1 = Parameter(torch.rand(1))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()
        # Adjacency built lazily by ``compute_dense_A`` on the first call.
        self.A = None

    def reset_parameters(self):
        """Re-initialise the linear weight, zero the bias and drop norm caches.

        NOTE(review): ``lam``, ``theta0``, ``theta1`` and ``self.A`` are not
        reset here — confirm that is intended.
        """
        self.lin.reset_parameters()
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(
        self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None
    ) -> Tensor:
        """Normalise the graph (if enabled), transform ``x`` and propagate.

        Args:
            x: Node feature matrix; its ``node_dim`` size is used as the
                number of nodes.
            edge_index: Graph connectivity as a ``Tensor`` or ``SparseTensor``.
            edge_weight: Optional per-edge weights.

        Returns:
            Aggregated node features, plus ``bias`` if enabled.
        """
        if self.normalize:
            if self.A is None:
                self.A, self.is_sparse_tensor = compute_dense_A(
                    edge_index,
                    edge_weight,
                    x.size(self.node_dim),
                    self.improved,
                    self.add_self_loops,
                    self.flow,
                    x.dtype,
                )

            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    # NOTE(review): exp_gcn_norm receives the learnable
                    # lam/theta0/theta1, so with ``cached=True`` their values
                    # from the first call are frozen into the cache (and,
                    # presumably, that call's autograd graph is reused).
                    edge_index, edge_weight = exp_gcn_norm(  # yapf: disable
                        self.A,
                        self.lam,
                        self.theta0,
                        self.theta1,
                        x.size(self.node_dim),
                        self.improved,
                        self.add_self_loops,
                        self.flow,
                        x.dtype,
                        self.is_sparse_tensor,
                    )
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    # NOTE(review): this path uses plain gcn_norm, so the
                    # learnable lam/theta parameters have no effect here.
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        x.size(self.node_dim),
                        self.improved,
                        self.add_self_loops,
                        self.flow,
                        x.dtype,
                    )
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)

        if self.bias is not None:
            out = out + self.bias
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Scale each source feature by its edge weight, if any."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused sparse-dense product used when propagating over a ``SparseTensor``."""
        # ``torch.matmul`` takes no ``reduce`` argument; torch_sparse exposes
        # the reducing product as ``SparseTensor.matmul``.
        return adj_t.matmul(x, reduce=self.aggr)