"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
import time

import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon

import dgl


class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    """Per-node Tree-LSTM update: add the bias, apply the gates, emit (h, c)."""

    def hybrid_forward(self, F, iou, b_iou, c):
        """Turn pre-activation gates and an incoming cell state into (h, c).

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon; only ``broadcast_add`` is
            used from it.
        iou : Tensor
            Pre-activation values that are split into three equal parts
            along axis 1.
        b_iou : Tensor
            Bias broadcast-added to ``iou`` before the split.
        c : Tensor
            Cell state that the gated update is added to.

        Returns
        -------
        h : Tensor
            ``sigmoid(o) * tanh(c_new)``.
        c : Tensor
            ``sigmoid(i) * tanh(u) + c``.
        """
        biased = F.broadcast_add(iou, b_iou)
        # The three chunks are, in order: i (sigmoid), o (sigmoid), u (tanh).
        gate_in, gate_out, candidate = biased.split(num_outputs=3, axis=1)
        new_c = gate_in.sigmoid() * candidate.tanh() + c
        new_h = gate_out.sigmoid() * new_c.tanh()
        return new_h, new_c


class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    """Combine the child states found in a node's mailbox.

    Parameters
    ----------
    U_iou : gluon.Block
        Layer applied to the flattened child hidden states to produce the
        ``iou`` pre-activations.
    U_f : gluon.Block
        Layer applied to the flattened child hidden states to produce the
        forget-gate pre-activations; its output must be reshapeable to the
        shape of ``h``.
    """

    def __init__(self, U_iou, U_f):
        super(_TreeLSTMCellReduceFunc, self).__init__()
        self.U_iou = U_iou
        self.U_f = U_f

    def hybrid_forward(self, F, h, c):
        """Reduce child hidden/cell states to ``(iou, c)``.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon (unused here).
        h : Tensor
            Child hidden states; axis 1 is presumably the child axis, as it
            is for ``c`` -- TODO confirm against the DGL mailbox layout.
        c : Tensor
            Child cell states, summed over axis 1 after gating.

        Returns
        -------
        iou : Tensor
            ``U_iou`` applied to the flattened child hidden states.
        c : Tensor
            Sum over axis 1 of the forget-gated child cell states.
        """
        # Shape code 0 keeps axis 0 as-is; -1 folds every other axis into one.
        flat_h = h.reshape((0, -1))
        forget = self.U_f(flat_h).sigmoid()
        # Bring the gates back to the layout of `h` so they line up with `c`.
        forget = forget.reshape_like(h)
        gated_c = (forget * c).sum(axis=1)
        return self.U_iou(flat_h), gated_c


class _TreeLSTMCell(gluon.HybridBlock):
    """Pieces shared by the Tree-LSTM cell variants.

    Owns the ``iou`` bias parameter and provides the DGL message and
    apply-node callbacks; subclasses supply ``reduce_func`` and ``W_iou``.

    Parameters
    ----------
    h_size : int
        Hidden size; the bias parameter has shape ``(1, 3 * h_size)``.
    """

    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        bias_shape = (1, 3 * h_size)
        self.b_iou = self.params.get("bias", shape=bias_shape, init="zeros")

    def message_func(self, edges):
        """Send the source node's ``h`` and ``c`` along each edge."""
        return {key: edges.src[key] for key in ("h", "c")}

    def apply_node_func(self, nodes):
        """Compute the new ``h`` and ``c`` of ``nodes`` from ``iou`` and ``c``."""
        feats = nodes.data
        iou = feats["iou"]
        # Fetch the bias copy that lives on the same context as the features.
        bias = self.b_iou.data(iou.context)
        new_h, new_c = self._apply_node_func(iou, bias, feats["c"])
        return {"h": new_h, "c": new_c}


class TreeLSTMCell(_TreeLSTMCell):
    """Tree-LSTM cell that feeds the concatenated child states to its layers.

    Parameters
    ----------
    x_size : int
        Accepted so both cell types share a constructor signature; not used
        in this class.
    h_size : int
        Hidden size. The forget-gate layer emits ``2 * h_size`` units, which
        presumably fixes the arity at two children per internal node --
        TODO confirm against the input trees.
    """

    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        gate_units = 3 * h_size
        child_to_iou = gluon.nn.Dense(gate_units, use_bias=False)
        child_to_forget = gluon.nn.Dense(2 * h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(
            child_to_iou, child_to_forget
        )
        # Projects input embeddings to the iou pre-activations.
        self.W_iou = gluon.nn.Dense(gate_units, use_bias=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox's ``h``/``c`` messages to ``iou`` and ``c``."""
        inbox = nodes.mailbox
        iou, cell = self._reduce_func(inbox["h"], inbox["c"])
        return {"iou": iou, "c": cell}


class ChildSumTreeLSTMCell(_TreeLSTMCell):
    """Child-Sum Tree-LSTM cell: sums child hidden states before gating.

    Parameters
    ----------
    x_size : int
        Accepted so both cell types share a constructor signature; not used
        in this class.
    h_size : int
        Hidden size of the cell.
    """

    def __init__(self, x_size, h_size):
        # The base class requires h_size to create the shared iou bias;
        # omitting it raised TypeError on construction.
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        # flatten=False applies the layer on the last axis only, so each
        # child keeps its own forget gate. The default (flatten=True) would
        # collapse the child axis and the result could no longer be
        # multiplied element-wise with the per-child cell states.
        self.U_f = gluon.nn.Dense(h_size, flatten=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox's ``h``/``c`` messages to ``iou`` and ``c``.

        Returns a dict with ``iou`` computed from the sum of the child
        hidden states and ``c`` as the sum of the forget-gated child cell
        states; both sums run over axis 1 of the mailbox tensors.
        """
        child_h = nodes.mailbox["h"]
        child_c = nodes.mailbox["c"]
        h_tild = child_h.sum(axis=1)
        # One forget gate per child, same shape as child_c.
        f = self.U_f(child_h).sigmoid()
        c = (f * child_c).sum(axis=1)
        return {"iou": self.U_iou(h_tild), "c": c}


class TreeLSTM(gluon.nn.Block):
    """Tree-LSTM model that produces class logits for every node.

    Parameters
    ----------
    num_vocabs : int
        Number of rows of the embedding table.
    x_size : int
        Embedding width.
    h_size : int
        Hidden size handed to the cell.
    num_classes : int
        Number of units of the output dense layer.
    dropout : float
        Rate passed to ``gluon.nn.Dropout``.
    cell_type : str
        ``"nary"`` selects ``TreeLSTMCell``; any other value selects
        ``ChildSumTreeLSTMCell``.
    pretrained_emb : Tensor, optional
        When not ``None``, the embedding is initialized on ``ctx``
        immediately and its weight is set to this value.
    ctx : optional
        Context used for the early embedding initialization and as the
        target of ``g.to`` in ``forward``.
    """

    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        cell_type="nary",
        pretrained_emb=None,
        ctx=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.ctx = ctx

        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            # The weight must exist before set_data can overwrite it.
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)

        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)

        if cell_type == "nary":
            cell_cls = TreeLSTMCell
        else:
            cell_cls = ChildSumTreeLSTMCell
        self.cell = cell_cls(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; its ``graph``, ``wordid`` and ``mask`` fields
            are read.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        graph = batch.graph.to(self.ctx)
        mask = batch.mask

        # Embed the (masked) word ids and project them to iou space. The
        # mask is applied again afterwards so masked-out nodes get a zero
        # iou -- presumably those are nodes without a word; TODO confirm.
        word_vecs = self.embedding(batch.wordid * mask)
        projected = self.cell.W_iou(self.dropout(word_vecs))
        node_mask = mask.expand_dims(-1).astype(projected.dtype)

        node_feats = graph.ndata
        node_feats["iou"] = projected * node_mask
        node_feats["h"] = h
        node_feats["c"] = c

        # Run message passing over the graph in topological order.
        dgl.prop_nodes_topo(
            graph,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )

        # Classify every node from its final hidden state.
        final_h = self.dropout(graph.ndata.pop("h"))
        return self.linear(final_h)
