import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor, OptPairTensor

import torch
from torch import Tensor
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import softmax, subgraph, remove_self_loops, add_self_loops

import pdb

class SparseAttention(MessagePassing):
    r"""Multi-head dot-product attention evaluated only along graph edges.

    Queries are projected from ``x`` and keys/values from ``key``; attention
    scores are computed per edge of ``edge_index``, normalised with
    :func:`softmax` over the ``index`` groups handed to :meth:`message`, and
    the weighted values are summed (``aggr='add'`` by default). The result is
    passed through an output projection followed by dropout.

    Args:
        in_channels (int or tuple): Feature size of the inputs. A tuple is
            read as ``(key/value input size, query input size)``.
        out_channels (int): Total output size across all heads. Must be
            divisible by :obj:`num_heads`.
        num_heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        qkv_bias (bool, optional): Whether the query/key/value projections
            carry a bias. (default: :obj:`False`)
        attn_drop (float, optional): Dropout probability applied to the
            normalised attention weights. (default: :obj:`0.`)
        proj_drop (float, optional): Dropout probability applied after the
            output projection. (default: :obj:`0.`)
        edge_dim (int, optional): Edge feature size. When set, edge features
            are projected to :obj:`out_channels` and added to keys and
            values. (default: :obj:`None`)
        **kwargs: Forwarded to :class:`MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_heads: int = 1,
        qkv_bias: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        edge_dim: Optional[int] = None,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super(SparseAttention, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.edge_dim = edge_dim
        self._alpha = None

        assert self.out_channels % self.num_heads == 0, (
            f'out_channels ({out_channels}) must be divisible by '
            f'num_heads ({num_heads})')
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = Linear(in_channels[0], out_channels, bias=qkv_bias)
        self.lin_query = Linear(in_channels[1], out_channels, bias=qkv_bias)
        self.lin_value = Linear(in_channels[0], out_channels, bias=qkv_bias)
        if edge_dim:
            # The edge projection must exist before reset_parameters() runs,
            # which touches it whenever edge features are enabled.
            self.lin_edge = Linear(edge_dim, out_channels, bias=False)
        else:
            # None (or 0) means "no edge features": keep the attribute
            # present so message() can test it with ``is not None``.
            self.lin_edge = self.register_parameter('lin_edge', None)
        self.proj = Linear(out_channels, out_channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initialise every learnable projection of the layer."""
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x, key, edge_index: Adj,
                edge_attr: OptTensor = None, return_attention_weights=None):
        r"""Run attention over the edges of :obj:`edge_index`.

        Args:
            x (Tensor): Features projected to queries (consumed as
                :obj:`query_i` in :meth:`message`).
            key (Tensor): Features projected to both keys and values
                (consumed as :obj:`key_j` / :obj:`value_j`).
            edge_index (Tensor or SparseTensor): Graph connectivity.
            edge_attr (Tensor, optional): Edge features; required when the
                layer was built with :obj:`edge_dim`. (default: :obj:`None`)
            return_attention_weights (bool, optional): If a :obj:`bool` is
                passed (either value), additionally return the attention
                weights: as the tuple :obj:`(edge_index, attention_weights)`
                for a tensor :obj:`edge_index`, or as a
                :class:`SparseTensor` carrying the weights as values.
                (default: :obj:`None`)

        Returns:
            The projected output with :obj:`out_channels` features per row,
            optionally paired with the attention weights as described above.
        """
        H, C = self.num_heads, self.out_channels // self.num_heads

        query = self.lin_query(x).view(-1, H, C)
        # Values must come from the raw source features, so project them
        # before ``key`` is rebound to the projected keys.
        value = self.lin_value(key).view(-1, H, C)
        key = self.lin_key(key).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr, size=None)

        # message() stashes the weights on the module; take them and clear
        # the slot so no tensor is kept alive between calls.
        alpha = self._alpha
        self._alpha = None

        # Concatenate the heads, then apply the output projection + dropout.
        out = out.reshape(out.size(0), -1)
        out = self.proj(out)
        out = self.proj_drop(out)

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
            return out, (edge_index, alpha)
        return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]):
        r"""Compute the attention-weighted value carried by each edge.

        Stores the (pre-dropout) attention weights in :obj:`self._alpha` so
        that :meth:`forward` can return them.
        """
        head_dim = self.out_channels // self.num_heads

        if self.lin_edge is not None:
            assert edge_attr is not None
            # Split the projected edge features into heads exactly like the
            # query/key/value tensors.
            edge_attr = self.lin_edge(edge_attr).view(-1, self.num_heads,
                                                      head_dim)
            key_j = key_j + edge_attr

        # Scale by the per-head width, i.e. the length of the vectors whose
        # dot product is taken.
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(head_dim)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.attn_drop, training=self.training)

        out = value_j
        if edge_attr is not None:
            # NOTE(review): without `edge_dim` the raw `edge_attr` is added
            # here, so it must already broadcast against the values.
            out = out + edge_attr

        return out * alpha.view(-1, self.num_heads, 1)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_heads={self.num_heads})')
        


class NormalizedAggr(MessagePassing):
    r"""Parameter-free neighbourhood aggregation with optional normalisation.

    When :obj:`normalize` is set, the connectivity is first passed through
    :func:`gcn_norm` (optionally cached after the first call); node features
    are then propagated along the resulting edges and reduced with the
    layer's aggregation (``'add'`` unless overridden via ``aggr``).

    Args:
        improved (bool, optional): Forwarded to :func:`gcn_norm`.
            (default: :obj:`False`)
        cached (bool, optional): If :obj:`True`, the normalised connectivity
            from the first call is stored and reused by later calls.
            (default: :obj:`False`)
        add_self_loops (bool, optional): Forwarded to :func:`gcn_norm`.
            (default: :obj:`True`)
        normalize (bool, optional): Whether to run :func:`gcn_norm` at all.
            (default: :obj:`True`)
        **kwargs: Forwarded to :class:`MessagePassing`.
    """
    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self, improved=False, cached=False,
                 add_self_loops=True, normalize=True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.normalize = normalize
        self.add_self_loops = add_self_loops
        self.cached = cached
        self.improved = improved

        # One cache slot per supported connectivity representation.
        self._cached_adj_t = None
        self._cached_edge_index = None

    def reset_parameters(self):
        r"""Drop any cached normalised connectivity."""
        self._cached_adj_t = None
        self._cached_edge_index = None

    def forward(self, x, edge_index, edge_weight=None):
        r"""Propagate :obj:`x` along :obj:`edge_index`.

        Args:
            x: Node features.
            edge_index (Tensor or SparseTensor): Graph connectivity.
            edge_weight (Tensor, optional): Edge weights handed to
                :func:`gcn_norm` and to :meth:`message`.
                (default: :obj:`None`)

        Returns:
            The aggregated node features produced by :meth:`propagate`.
        """
        if self.normalize and isinstance(edge_index, Tensor):
            stored = self._cached_edge_index
            if stored is not None:
                # Cache hit: reuse the normalised index/weight pair.
                edge_index, edge_weight = stored[0], stored[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
        elif self.normalize and isinstance(edge_index, SparseTensor):
            stored = self._cached_adj_t
            if stored is not None:
                edge_index = stored
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops)
                if self.cached:
                    self._cached_adj_t = edge_index

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)

    def message(self, x_j, edge_weight):
        r"""Scale each neighbour feature by its edge weight, if one exists."""
        if edge_weight is None:
            return x_j
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        r"""Fused message + aggregation as a sparse matrix product."""
        return matmul(adj_t, x, reduce=self.aggr)


class Meanprop(MessagePassing):
    r"""Parameter-free propagation that averages neighbour features.

    Args:
        add_self_loops (bool, optional): If set, existing self-loops are
            replaced by exactly one self-loop per node before propagating.
            (default: :obj:`False`)
        **kwargs: Forwarded to :class:`MessagePassing`; ``aggr`` defaults to
            ``'mean'``.
    """

    def __init__(self, add_self_loops=False, **kwargs):
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)

        self.add_self_loops = add_self_loops

    def forward(self, x, edge_index, size=None):
        r"""Propagate :obj:`x` along :obj:`edge_index`.

        Args:
            x (Tensor or tuple): Node features; a single tensor is used for
                both sides of the propagation.
            edge_index (Tensor or SparseTensor): Graph connectivity.
            size (tuple, optional): Forwarded to :meth:`propagate`.
                (default: :obj:`None`)

        Returns:
            The aggregated node features produced by :meth:`propagate`.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        if self.add_self_loops:
            if isinstance(edge_index, SparseTensor):
                # The dense-index utilities below do not accept a
                # SparseTensor; set its diagonal instead.
                from torch_sparse import set_diag
                edge_index = set_diag(edge_index)
            else:
                # Pass the node count explicitly: inferring it from
                # `edge_index` would skip nodes whose index is larger than
                # any index that appears in an edge.
                num_nodes = x[0].size(self.node_dim)
                if x[1] is not None:
                    num_nodes = min(num_nodes, x[1].size(self.node_dim))
                if size is not None:
                    num_nodes = min(size)
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index,
                                               num_nodes=num_nodes)

        # propagate_type: (x: OptPairTensor)
        return self.propagate(edge_index, x=x, size=size)

    def message(self, x_j: Tensor):
        r"""Forward neighbour features unchanged; the reduction averages."""
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor):
        r"""Fused message + aggregation as a sparse matrix product."""
        # Drop stored values so every edge contributes with equal weight.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

class Interpolate(MessagePassing):
    r"""Parameter-free propagation of node features along graph edges.

    Neighbour features are forwarded unchanged and reduced with the
    aggregation configured on :class:`MessagePassing`.

    NOTE(review): the name suggests a weighted interpolation, but no
    weighting is applied by the visible code -- confirm the intended
    semantics before relying on :obj:`edge_attr` or
    :obj:`return_attention_weights`.

    Args:
        add_self_loops (bool, optional): If set, existing self-loops are
            replaced by exactly one self-loop per node before propagating.
            (default: :obj:`False`)
        **kwargs: Forwarded to :class:`MessagePassing`.
    """

    def __init__(self, add_self_loops=False, **kwargs):
        super().__init__(**kwargs)

        self.add_self_loops = add_self_loops

    def forward(self, x, edge_index, edge_attr=None, size=None, return_attention_weights=None):
        r"""Propagate :obj:`x` along :obj:`edge_index`.

        Args:
            x (Tensor or tuple): Node features; a single tensor is used for
                both sides of the propagation.
            edge_index (Tensor or SparseTensor): Graph connectivity.
            edge_attr (Tensor, optional): Edge features. They are kept
                aligned with :obj:`edge_index` and handed to
                :meth:`propagate`, but :meth:`message` does not consume
                them. (default: :obj:`None`)
            size (tuple, optional): Forwarded to :meth:`propagate`.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): Accepted for
                signature compatibility; currently unused.
                (default: :obj:`None`)

        Returns:
            The aggregated node features produced by :meth:`propagate`.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        if self.add_self_loops:
            if isinstance(edge_index, SparseTensor):
                # The dense-index utilities below do not accept a
                # SparseTensor; set its diagonal instead.
                from torch_sparse import set_diag
                edge_index = set_diag(edge_index)
            else:
                # Pass the node count explicitly: inferring it from
                # `edge_index` would skip nodes whose index is larger than
                # any index that appears in an edge.
                num_nodes = x[0].size(self.node_dim)
                if x[1] is not None:
                    num_nodes = min(num_nodes, x[1].size(self.node_dim))
                if size is not None:
                    num_nodes = min(size)
                # Thread `edge_attr` through both steps so it stays aligned
                # with the edited `edge_index`.
                edge_index, edge_attr = remove_self_loops(edge_index,
                                                          edge_attr)
                edge_index, edge_attr = add_self_loops(edge_index, edge_attr,
                                                       num_nodes=num_nodes)

        # propagate_type: (x: OptPairTensor)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    def message(self, x_j):
        r"""Forward neighbour features unchanged."""
        return x_j

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        r"""Reduce the messages with the base-class aggregation."""
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor):
        r"""Fused message + aggregation as a sparse matrix product."""
        # Drop stored values so every edge contributes with equal weight.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)