import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add

from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.nn import GCNConv
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import (
    add_remaining_self_loops,
    to_dense_adj,
)
from libs.torch_utils import to_torch_coo_tensor

from torch_geometric.utils.num_nodes import maybe_num_nodes

from torch_sparse import SparseTensor
from scipy import sparse


# @torch.jit._overload
# def exp_gcn_norm(
# edge_index,
# edge_weight=None,
# lam=0.1,
# num_nodes=None,
# improved=False,
# add_self_loops=True,
# flow="source_to_target",
# dtype=None,
# ):
# # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor  # noqa
# pass


# @torch.jit._overload
# def exp_gcn_norm(
# edge_index,
# edge_weight=None,
# lam=0.1,
# num_nodes=None,
# improved=False,
# add_self_loops=True,
# flow="source_to_target",
# dtype=None,
# ):
# # type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor  # noqa
# pass
def fill_diag(adj_t: SparseTensor, fill_value: float):
    """Return ``adj_t`` with ``fill_value`` placed on its diagonal.

    This is a thin wrapper: the work is delegated to the ``fill_diag``
    method of ``adj_t`` and its result is handed back unchanged.
    """
    filled = adj_t.fill_diag(fill_value)
    return filled



def compute_dense_A(
    edge_index,
    edge_weight=None,
    num_nodes=None,
    improved=False,
    add_self_loops=True,
    flow="source_to_target",
    dtype=None,
):
    """Convert a graph to a dense (weighted) adjacency matrix.

    Args:
        edge_index: Either a ``[2, E]`` index tensor or a
            ``torch_sparse.SparseTensor`` adjacency.
        edge_weight: Optional per-edge weights for the index-tensor input;
            defaults to all ones. Ignored for ``SparseTensor`` input, whose
            own values are used (filled with ones when it has none).
        num_nodes: Number of nodes. When ``None`` it is inferred from the
            input.
        improved: If true, added self loops get weight 2.0 instead of 1.0.
        add_self_loops: Whether to add self loops before densifying.
        flow: Must be ``"source_to_target"`` or ``"target_to_source"``;
            only validated (index-tensor input), not otherwise used here.
        dtype: dtype of the default edge weights.

    Returns:
        A tuple ``(A, is_sparse_tensor)`` where ``A`` is the dense adjacency
        and ``is_sparse_tensor`` records whether the input was a
        ``SparseTensor``.
    """
    fill_value = 2.0 if improved else 1.0
    is_sparse_tensor = isinstance(edge_index, SparseTensor)

    # Get our data into the right format
    if is_sparse_tensor:
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1.0, dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        # torch_sparse.SparseTensor exposes its COO data through ``coo()``;
        # ``_indices()`` / ``_values()`` only exist on torch sparse tensors.
        row, col, value = adj_t.coo()
        edge_index = torch.stack([row, col], dim=0)
        edge_weight = value
        if num_nodes is None:
            num_nodes = adj_t.size(0)
    else:
        assert flow in ["source_to_target", "target_to_source"]
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        if edge_weight is None:
            edge_weight = torch.ones(
                (edge_index.size(1),), dtype=dtype, device=edge_index.device
            )
        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes
            )
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

    # Convert to dense. Passing ``max_num_nodes`` keeps trailing isolated
    # nodes so the matrix is always ``num_nodes x num_nodes``.
    A = to_dense_adj(
        edge_index=edge_index,
        edge_attr=edge_weight,
        max_num_nodes=num_nodes,
    )
    # Drop only the batch dimension: a bare ``squeeze()`` would also collapse
    # a single-node graph down to a 0-dim tensor.
    A = A.squeeze(0)
    if A.dim() == 3 and A.size(-1) == 1:
        # Edge weights given as ``[E, 1]`` leave a trailing singleton axis.
        A = A.squeeze(-1)
    return A, is_sparse_tensor


def exp_gcn_norm(
    A,
    lam,
    theta0,
    theta1,
    num_nodes=None,
    improved=False,
    add_self_loops=True,
    flow="source_to_target",
    dtype=None,
    is_sparse_tensor=False,
):
    """Exponential GCN normalisation of a dense adjacency matrix.

    Computes ``theta0 * I + theta1 * D^{-1/2} exp(lam * A) D^{-1/2}`` where
    ``D`` holds the column sums of ``exp(lam * A)``, then converts the result
    back to a sparse representation.

    I'm converting to a dense matrix first. This isn't optimal, but is fine
    for initial experimentation.

    Args:
        A: Dense square adjacency matrix.
        lam: Scalar multiplier applied to ``A`` before the matrix exponential.
        theta0: Weight of the identity term.
        theta1: Weight of the normalised exponential term.
        num_nodes: Number of nodes; inferred from ``A`` when ``None``.
        improved, add_self_loops, flow, dtype: Accepted for signature
            compatibility; not used by this function.
        is_sparse_tensor: Selects the return format (see below).

    Returns:
        ``(edge_index, edge_weight)`` when ``is_sparse_tensor`` is false,
        otherwise ``(adj_t, None)`` where ``adj_t`` is whatever
        ``to_torch_coo_tensor`` builds from the indices and values.
    """
    if num_nodes is None:
        # The default used to crash in ``torch.eye(None)``; A is square.
        num_nodes = A.size(-1)

    expA = torch.matrix_exp(lam * A)

    # D^{-1/2} from the column sums; zero-degree entries (inf) are zeroed.
    deg_inv_sqrt = torch.sum(expA, dim=0).pow(-0.5)
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float("inf"), 0)
    # Row/column scaling is equivalent to ``diag(d) @ expA @ diag(d)`` but
    # avoids materialising the diagonal matrix and two dense matmuls.
    expA = deg_inv_sqrt.unsqueeze(-1) * expA * deg_inv_sqrt.unsqueeze(0)

    # Create the identity directly on expA's device. ``get_device()`` returns
    # -1 for CPU tensors, which ``Tensor.to`` rejects.
    identity = torch.eye(num_nodes, device=expA.device)
    expA = theta0 * identity + theta1 * expA

    spExpA = expA.to_sparse()
    edge_index, edge_weight = spExpA._indices(), spExpA._values()

    if is_sparse_tensor:
        adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes)
        return adj_t, None
    return edge_index, edge_weight