import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.nn import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


def glorot(value):
    """Fill ``value`` in place with Glorot (Xavier) uniform samples.

    Samples are drawn uniformly from ``[-bound, bound]`` with
    ``bound = sqrt(6 / (a + b))``, where ``a`` and ``b`` are the sizes of
    the last two dimensions of ``value``.

    Args:
        value: Tensor (or parameter) with at least two dimensions. It is
            overwritten through ``value.data``, so autograd does not record
            the write.
    """
    fan_a, fan_b = value.size(-2), value.size(-1)
    bound = math.sqrt(6.0 / (fan_a + fan_b))
    value.data.uniform_(-bound, bound)


class EFATLinear(MessagePassing):
    """Attention-based message-passing layer with a linear attention map.

    Node features are projected once by ``self.fun`` and tiled ``heads``
    times, so every head sees the same projected features. For each edge the
    features of both endpoints are concatenated, passed through a leaky ReLU
    and mapped by the linear layer ``self.att`` to ``heads * out_channels``
    attention logits. The logits are softmax-normalised over the edges that
    share the same ``index`` entry and used to scale the neighbour features
    element-wise. The aggregated result is finally averaged over the heads.

    Args:
        in_channels: Size of the last dimension of the input features.
        out_channels: Size of the output features (per head; heads are
            averaged, so this is also the final feature size).
        heads: Number of attention heads.
        negative_slope: Negative slope of the leaky ReLU applied to the
            concatenated endpoint features.
        dropout: Dropout probability applied to the normalised attention
            coefficients while training.
        add_self_loops: If ``True``, self-loops are added to the graph
            before propagation.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``'add'``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        **kwargs
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # Shared feature projection; its output is tiled per head in forward().
        self.fun = nn.Linear(in_channels, out_channels)
        # Maps the concatenated (target, source) features of an edge to one
        # attention logit per head and channel.
        self.att = Linear(
            2 * out_channels * heads,
            out_channels * heads,
            bias=True,
            weight_initializer='glorot',
        )
        # Scratch slot: message() stores the attention coefficients here so
        # that forward() can hand them back to the caller.
        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the attention map and the feature projection."""
        self.att.reset_parameters()
        self.fun.reset_parameters()

    def forward(self, x, edge_index, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Args:
            x: Node features. ``repeat(1, heads)`` below requires the
                projected tensor to be 2-D, i.e. ``(N, in_channels)``.
            edge_index: Graph connectivity, either a ``Tensor`` edge index
                or a ``torch_sparse.SparseTensor``.
            return_attention_weights: If this is a ``bool`` (``True`` *or*
                ``False``), the attention coefficients are returned as well.
                Leave it as ``None`` to get only the node features.

        Returns:
            The updated node features with the heads averaged out, or a
            tuple ``(out, (edge_index, alpha))`` for a ``Tensor``
            ``edge_index`` / ``(out, sparse_tensor_with_alpha_values)`` for a
            ``SparseTensor`` when attention weights were requested. The
            returned connectivity includes any self-loops added here.

        Raises:
            TypeError: If attention weights are requested and ``edge_index``
                is neither a ``Tensor`` nor a ``SparseTensor``.
        """
        H, C = self.heads, self.out_channels

        # Project once, then tile along the channel axis: (N, C) -> (N, H * C).
        x_l = self.fun(x).repeat(1, H)
        x_r = x_l  # source and target share the same projected features

        if self.add_self_loops:
            if isinstance(edge_index, SparseTensor):
                # utils.add_self_loops does not take a SparseTensor; set_diag
                # is the torch_sparse way of adding self-loops.
                edge_index = set_diag(edge_index)
            else:
                # NOTE: self-loops already present in `edge_index` are not
                # removed first, so they end up duplicated.
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=x_l.size(0))

        out = self.propagate(edge_index, x=(x_l, x_r), size=None)

        # Split the fused head/channel axis and average over the heads.
        out = out.view(-1, H, C, *out.size()[2:]).mean(dim=1)

        # Take the coefficients and clear the slot so no stale tensor is
        # kept alive between calls.
        alpha = self._alpha
        self._alpha = None

        if not isinstance(return_attention_weights, bool):
            return out

        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        if isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
        raise TypeError(
            'edge_index must be a Tensor or a SparseTensor to return '
            f'attention weights, got {type(edge_index).__name__}')

    def message(self, x_j, x_i, index, ptr, size_i):
        """Compute the attention-weighted message for every edge.

        ``x_i`` / ``x_j`` are the per-edge features gathered by
        ``propagate`` following PyG's ``_i`` / ``_j`` naming convention;
        ``index``, ``ptr`` and ``size_i`` are passed straight to the
        grouped ``softmax``.

        Returns:
            ``x_j`` scaled element-wise by the (dropout-regularised)
            attention coefficients.
        """
        # Concatenate both endpoints along the channel dimension.
        x = torch.cat((x_i, x_j), dim=-1)
        x = F.leaky_relu(x, self.negative_slope)

        # Put axis 1 last for the linear layer and restore it afterwards
        # (a no-op for 2-D inputs).
        alpha = self.att(torch.moveaxis(x, 1, -1))
        alpha = torch.moveaxis(alpha, -1, 1)
        alpha = softmax(alpha, index, ptr, size_i)
        # Stored before dropout, so the coefficients handed back by forward()
        # are the normalised ones.
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.fun.__class__.__name__}, '
                f'{self.in_channels}, {self.out_channels}, heads={self.heads})')

