# PyTorch related imports
import torch
from torch.nn import Parameter

from .helper import *
from .message_passing import MessagePassing


class CompGCNConv(MessagePassing):
    """Composition-based graph convolution over entity and relation embeddings.

    Each forward pass aggregates messages along three edge groups, each with
    its own weight matrix:

    * ``"in"``: the first half of the ``edge_index`` columns.
    * ``"out"``: the second half of the columns.
    * ``"loop"``: one implicit self-loop per entity.

    A message is the neighbour embedding composed with its relation embedding
    (see ``opn``) and then projected. Relation embeddings are projected with
    ``w_rel`` and returned alongside the updated entity embeddings.

    Args:
        in_channels: Width of the input entity and relation embeddings.
        out_channels: Width of the output embeddings.
        num_rels: Stored on the instance; within this class it is only used
            by ``__repr__``.
        act: Activation applied to the entity output (identity by default).
        dropout: Dropout probability applied to the "in" and "out" aggregates.
        use_bias: Whether to add a learnable bias before batch norm.
        opn: Composition operator: ``"corr"``, ``"sub"`` or ``"mult"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_rels,
        act=lambda x: x,
        dropout=0.0,
        use_bias=True,
        opn="corr",
    ):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_rels = num_rels
        self.act = act
        # Device of the most recent forward() input (None before the first call).
        self.device = None

        # One projection per edge group, plus one for the relation embeddings.
        self.w_loop = get_param((in_channels, out_channels))
        self.w_in = get_param((in_channels, out_channels))
        self.w_out = get_param((in_channels, out_channels))
        self.w_rel = get_param((in_channels, out_channels))
        # Learned embedding for the synthetic self-loop relation.
        self.loop_rel = get_param((1, in_channels))

        self.drop = torch.nn.Dropout(dropout)
        self.use_bias = use_bias
        self.opn = opn
        self.bn = torch.nn.BatchNorm1d(out_channels)

        if self.use_bias:
            self.register_parameter("bias", Parameter(torch.zeros(out_channels)))
        else:
            # Keep ``self.bias`` defined either way (standard nn.Module idiom).
            self.register_parameter("bias", None)

    def forward(self, x, edge_index, edge_type, rel_embed):
        """Run one CompGCN layer.

        Args:
            x: Entity embeddings; ``x.size(0)`` is the number of entities.
                Presumably ``(num_ent, in_channels)``. TODO: confirm.
            edge_index: Index tensor with 2 rows. The first ``E // 2`` columns
                are treated as "in" edges and the remaining columns as "out"
                edges. This assumes the caller concatenated forward and
                inverse edges in that order; verify against the caller.
            edge_type: Relation id for each column of ``edge_index``, split
                the same way.
            rel_embed: Relation embeddings, ``in_channels`` wide (they are
                concatenated with ``loop_rel``).

        Returns:
            Tuple ``(entity_out, relation_out)``. ``entity_out`` holds the
            activated, batch-normed entity embeddings. ``relation_out`` is
            ``rel_embed @ w_rel`` with the appended self-loop row removed.
        """
        # Resolve the device on every call so the layer keeps working after
        # being moved; caching the first device seen goes stale on .to().
        device = edge_index.device
        self.device = device

        # Append the self-loop relation as the last row of the relation table.
        rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)
        loop_rel_id = rel_embed.size(0) - 1
        num_edges = edge_index.size(1) // 2
        num_ent = x.size(0)

        # split into (s,r,o) and (o,r^-1,s) edges
        in_index, out_index = edge_index[:, :num_edges], edge_index[:, num_edges:]
        in_type, out_type = edge_type[:num_edges], edge_type[num_edges:]

        # Self-loop edges i -> i, built directly on the target device.
        ent_ids = torch.arange(num_ent, device=device)
        loop_index = torch.stack([ent_ids, ent_ids])
        loop_type = torch.full(
            (num_ent,), loop_rel_id, dtype=torch.long, device=device
        )

        # Degree-based edge weights for each direction; self-loops get none.
        in_norm = self.compute_norm(in_index, num_ent)
        out_norm = self.compute_norm(out_index, num_ent)

        in_res = self.propagate(
            "add",
            in_index,
            x=x,
            edge_type=in_type,
            rel_embed=rel_embed,
            edge_norm=in_norm,
            mode="in",
        )
        loop_res = self.propagate(
            "add",
            loop_index,
            x=x,
            edge_type=loop_type,
            rel_embed=rel_embed,
            edge_norm=None,
            mode="loop",
        )
        out_res = self.propagate(
            "add",
            out_index,
            x=x,
            edge_type=out_type,
            rel_embed=rel_embed,
            edge_norm=out_norm,
            mode="out",
        )
        # Average the three groups; dropout applies to in/out aggregates only.
        out = self.drop(in_res) * (1 / 3) + self.drop(out_res) * (1 / 3) + loop_res * (1 / 3)

        if self.use_bias:
            out = out + self.bias
        out = self.bn(out)

        ent_out = self.act(out)
        # Project relations too, dropping the self-loop row appended above.
        rel_out = torch.matmul(rel_embed, self.w_rel)[:-1]
        return ent_out, rel_out

    def rel_transform(self, ent_embed, rel_embed):
        """Compose entity and relation embeddings according to ``self.opn``.

        The supported operators are:

        * ``"corr"``: delegates to the ``ccorr`` helper (presumably circular
          correlation).
        * ``"sub"``: ``ent_embed - rel_embed``.
        * ``"mult"``: the element-wise product.

        Raises:
            NotImplementedError: If ``self.opn`` is not one of the above.
        """
        if self.opn == "corr":
            return ccorr(ent_embed, rel_embed)
        if self.opn == "sub":
            return ent_embed - rel_embed
        if self.opn == "mult":
            return ent_embed * rel_embed
        raise NotImplementedError(
            f"Unsupported composition operator: {self.opn!r}"
        )

    def message(self, x_j, edge_type, rel_embed, edge_norm, mode):
        """Build one message per edge.

        Steps:

        1. Look up each edge's relation embedding.
        2. Compose it with the neighbour features ``x_j``.
        3. Project the result with ``w_{mode}``.
        4. Scale by ``edge_norm`` when one is given.

        ``x_j`` is supplied by the project's ``MessagePassing.propagate``.
        It is presumably one row of ``x`` per edge; confirm against the
        base class.
        """
        weight = getattr(self, f"w_{mode}")
        rel_emb = torch.index_select(rel_embed, 0, edge_type)
        xj_rel = self.rel_transform(x_j, rel_emb)
        out = torch.mm(xj_rel, weight)

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

    def compute_norm(self, edge_index, num_ent):
        """Return per-edge weights ``deg[row]**-0.5 * deg[col]**-0.5``.

        ``deg`` counts, for each entity, how many columns of ``edge_index``
        have that entity in row 0. The result is D^{-1/2} A D^{-1/2}
        evaluated at the edges present.

        An entity that never appears in row 0 has degree 0. Its inverse root
        is set to 0 rather than ``inf``, so any edge whose tail has no
        outgoing edge in this ``edge_index`` gets weight 0.
        """
        row, col = edge_index  # head, tail
        edge_weight = torch.ones_like(row, dtype=torch.float)
        deg = torch.zeros(num_ent, dtype=edge_weight.dtype, device=edge_weight.device)
        deg.scatter_add_(0, row, edge_weight)  # number of edges per head entity
        deg_inv = deg.pow(-0.5)  # D^{-1/2}
        deg_inv[torch.isinf(deg_inv)] = 0
        return deg_inv[row] * edge_weight * deg_inv[col]  # D^{-1/2} A D^{-1/2}

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels}, num_rels={self.num_rels})"
