import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device="cpu"):
        super(DGLRoutingLayer, self).__init__()
        self.batch_size = batch_size
        self.g = init_graph(in_nodes, out_nodes, f_size, device=device)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))
        self.device = device

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat
        batch_size = self.batch_size

        # step 2 (line 5)
        def cap_message(edges):
            if batch_size:
                return {"m": edges.data["c"].unsqueeze(1) * edges.data["u_hat"]}
            else:
                return {"m": edges.data["c"] * edges.data["u_hat"]}

        def cap_reduce(nodes):
            return {"s": th.sum(nodes.mailbox["m"], dim=1)}

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)

            # Execute step 1 & 2
            self.g.update_all(message_func=cap_message, reduce_func=cap_reduce)

            # step 3 (line 6)
            if self.batch_size:
                self.g.nodes[self.out_indx].data["v"] = squash(
                    self.g.nodes[self.out_indx].data["s"], dim=2
                )
            else:
                self.g.nodes[self.out_indx].data["v"] = squash(
                    self.g.nodes[self.out_indx].data["s"], dim=1
                )

            # step 4 (line 7)
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            if self.batch_size:
                self.g.edata["b"] = self.g.edata["b"] + (
                    self.g.edata["u_hat"] * v
                ).mean(dim=1).sum(dim=1, keepdim=True)
            else:
                self.g.edata["b"] = self.g.edata["b"] + (
                    self.g.edata["u_hat"] * v
                ).sum(dim=1, keepdim=True)


def squash(s, dim=1):
    sq = th.sum(s**2, dim=dim, keepdim=True)
    s_norm = th.sqrt(sq)
    s = (sq / (1.0 + sq)) * (s / s_norm)
    return s


def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
    g = dgl.DGLGraph()
    g.set_n_initializer(dgl.frame.zero_initializer)
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)
    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges use edge broadcasting
    for u in in_indx:
        g.add_edges(u, out_indx)

    g = g.to(device)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
    return g
