import torch
import torch.nn as nn

from torch import Tensor
from typing import Union
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import OptPairTensor, OptTensor, Size
from torch_geometric.utils import scatter

from utils import create_activation


class ACM_GIN(MessagePassing):
    """Message-passing layer that mixes low-, high- and full-pass channels.

    For every target node the layer first computes the edge-weighted mean of
    its neighbours' features (weighted sum divided by the summed incoming
    edge weight).  Three channels are then derived from the node's own
    features ``x_r`` and that mean ``m``:

    * low-pass:  ``(x_r + m) / 2``
    * high-pass: ``(x_r - m) / 2``
    * full-pass: ``x_r``

    Each channel is embedded by its own network.  A per-channel projection
    followed by a sigmoid yields one gate per channel; the concatenated gates
    are divided by the temperature ``T``, passed through ``nn_mix`` and
    soft-maxed along dimension 1.  The output is the sum of the three channel
    embeddings weighted by the resulting columns 0, 1 and 2.

    Args:
        nn_lowpass: Network applied to the low-pass channel.
        nn_highpass: Network applied to the high-pass channel.
        nn_fullpass: Network applied to the full-pass channel.
        nn_lowpass_proj: Projection producing the low-pass gate.
        nn_highpass_proj: Projection producing the high-pass gate.
        nn_fullpass_proj: Projection producing the full-pass gate.
        nn_mix: Network applied to the concatenated gates before the softmax.
            NOTE(review): the projections are presumably expected to emit a
            single column each and ``nn_mix`` three columns, since only
            columns 0-2 of the mixed result are used -- confirm with callers.
        T: Temperature dividing the concatenated gates before mixing.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``"add"``.
    """

    def __init__(
        self,
        nn_lowpass: torch.nn.Module,
        nn_highpass: torch.nn.Module,
        nn_fullpass: torch.nn.Module,
        nn_lowpass_proj: torch.nn.Module,
        nn_highpass_proj: torch.nn.Module,
        nn_fullpass_proj: torch.nn.Module,
        nn_mix: torch.nn.Module,
        T: float = 3.0,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn_lowpass = nn_lowpass
        self.nn_highpass = nn_highpass
        self.nn_fullpass = nn_fullpass
        self.nn_lowpass_proj = nn_lowpass_proj
        self.nn_highpass_proj = nn_highpass_proj
        self.nn_fullpass_proj = nn_fullpass_proj
        self.nn_mix = nn_mix
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax(dim=1)
        self.T = T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all wrapped sub-networks with PyG's ``reset``."""
        reset(self.nn_lowpass)
        reset(self.nn_highpass)
        reset(self.nn_fullpass)
        reset(self.nn_lowpass_proj)
        reset(self.nn_highpass_proj)
        reset(self.nn_fullpass_proj)
        reset(self.nn_mix)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Run one channel-mixing convolution.

        Args:
            x: Node features, or a ``(source, target)`` pair of feature
                tensors.  The target entry must not be ``None``.
            edge_index: Edge indices; row 1 holds the target node of each
                edge.
            edge_weight: Optional per-edge weight.  When omitted every edge
                gets weight 1.
            size: Optional size hint forwarded to ``propagate``.

        Returns:
            The gate-weighted sum of the three channel embeddings, one row
            per target node.

        Raises:
            ValueError: If the target features ``x[1]`` are ``None``.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        x_r = x[1]
        if x_r is None:
            # All three channels are built from the target features, so there
            # is nothing meaningful to compute without them.
            raise ValueError(
                "ACM_GIN requires target node features; got x[1] = None"
            )

        if edge_weight is None:
            # Unweighted graph: treat every edge as having weight 1.
            edge_weight = x[0].new_ones(edge_index.size(1))

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

        # Normalise the weighted sum by the summed incoming edge weight.
        deg = scatter(edge_weight, edge_index[1], 0, out.size(0), reduce="sum")
        # Targets whose incoming weight sums to zero keep a zero aggregate
        # instead of dividing by zero.
        no_weight = deg == 0
        deg_inv = (1.0 / deg.masked_fill(no_weight, 1.0)).masked_fill(no_weight, 0.0)
        out = deg_inv.view(-1, 1) * out

        out_lowpass = (x_r + out) / 2.0
        out_highpass = (x_r - out) / 2.0

        # compute embeddings for each filter
        out_lowpass = self.nn_lowpass(out_lowpass)
        out_highpass = self.nn_highpass(out_highpass)
        out_fullpass = self.nn_fullpass(x_r)
        # compute importance weights per filter
        alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
        alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
        alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
        alpha_cat = torch.cat([alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1)
        alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))

        out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
        out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
        out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass

        return out

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each source-node feature row by its edge weight."""
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # The layer has no single ``nn`` attribute; report the three channel
        # networks and the temperature instead.
        return (
            f"{self.__class__.__name__}("
            f"nn_lowpass={self.nn_lowpass}, "
            f"nn_highpass={self.nn_highpass}, "
            f"nn_fullpass={self.nn_fullpass}, "
            f"T={self.T})"
        )


class ACM_GIN_model(nn.Module):
    """Sequential stack of ``num_layers`` ``ACM_GIN`` convolutions.

    The first layer consumes ``in_dim`` features, the last layer emits
    ``out_dim`` features and every other interface between layers has width
    ``hidden_dim``.  Each layer owns three two-layer MLPs (low-, high- and
    full-pass channel), three ``Linear(width, 1)`` gate projections and one
    ``Linear(3, 3)`` mixing layer.  These sub-modules are kept in the
    ``nns_*`` module lists and are also handed to the corresponding
    ``ACM_GIN`` instance in ``ACM_convs``.

    Args:
        in_dim: Width of the input node features.
        out_dim: Width of the final layer's output.
        num_layers: Number of ``ACM_GIN`` layers.
        hidden_dim: Width used inside the MLPs and between layers.
        batchnorm: If truthy, a ``BatchNorm1d`` follows each linear layer of
            the channel MLPs.
        activation: Passed to ``create_activation``; the resulting module is
            reused in every channel MLP.
    """

    def __init__(
        self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.gnn_batchnorm = batchnorm
        self.out_dim = out_dim

        self.ACM_convs = nn.ModuleList()
        self.nns_lowpass = nn.ModuleList()
        self.nns_highpass = nn.ModuleList()
        self.nns_fullpass = nn.ModuleList()
        self.nns_lowpass_proj = nn.ModuleList()
        self.nns_highpass_proj = nn.ModuleList()
        self.nns_fullpass_proj = nn.ModuleList()
        self.nns_mix = nn.ModuleList()

        self.activation = create_activation(activation)

        final_idx = self.num_layers - 1
        for layer_idx in range(self.num_layers):
            layer_in = in_dim if layer_idx == 0 else self.hidden_dim
            layer_out = self.out_dim if layer_idx == final_idx else self.hidden_dim

            # Gate projections read the channel embeddings, so their input
            # width equals this layer's output width.
            gate_banks = (
                self.nns_lowpass_proj,
                self.nns_highpass_proj,
                self.nns_fullpass_proj,
            )
            for gate_bank in gate_banks:
                gate_bank.append(nn.Linear(layer_out, 1))

            # Mixing layer acting on the three concatenated gates.
            self.nns_mix.append(nn.Linear(3, 3))

            # One embedding MLP per channel.
            channel_banks = (self.nns_lowpass, self.nns_highpass, self.nns_fullpass)
            for channel_bank in channel_banks:
                channel_bank.append(self._build_channel_mlp(layer_in, layer_out))

            conv = ACM_GIN(
                nn_lowpass=self.nns_lowpass[layer_idx],
                nn_highpass=self.nns_highpass[layer_idx],
                nn_fullpass=self.nns_fullpass[layer_idx],
                nn_lowpass_proj=self.nns_lowpass_proj[layer_idx],
                nn_highpass_proj=self.nns_highpass_proj[layer_idx],
                nn_fullpass_proj=self.nns_fullpass_proj[layer_idx],
                nn_mix=self.nns_mix[layer_idx],
            )
            self.ACM_convs.append(conv)

    def _build_channel_mlp(self, in_features, out_features):
        """Return Linear -> [BN] -> act -> Linear -> [BN] -> act as a Sequential."""
        stages = [nn.Linear(in_features, self.hidden_dim)]
        if self.gnn_batchnorm:
            stages.append(nn.BatchNorm1d(self.hidden_dim))
        stages.append(self.activation)
        stages.append(nn.Linear(self.hidden_dim, out_features))
        if self.gnn_batchnorm:
            stages.append(nn.BatchNorm1d(out_features))
        stages.append(self.activation)
        return nn.Sequential(*stages)

    def reset_parameters(self):
        """Re-initialise every ``Linear`` and ``BatchNorm1d`` in the model."""
        resettable = (nn.Linear, nn.BatchNorm1d)
        for module in self.modules():
            if isinstance(module, resettable):
                module.reset_parameters()

    def forward(self, x, edge_index, edge_attr, return_hidden=False):
        """Apply the layers in order, feeding ``edge_attr`` as edge weights.

        Returns the last layer's output, or ``(output, per_layer_outputs)``
        when ``return_hidden`` is truthy.
        """
        per_layer = []
        for layer_idx in range(self.num_layers):
            conv = self.ACM_convs[layer_idx]
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            per_layer.append(x)
        if not return_hidden:
            return x
        return x, per_layer


if __name__ == "__main__":
    # Smoke check: build a two-layer model with batch norm and report how
    # many trainable parameters it has.
    acm_gin = ACM_GIN_model(46, 46, 2, 256, True)
    n_trainable = 0
    for param in acm_gin.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    print(n_trainable)
    print("")
