"""
Gated Graph Sequence Neural Network for sequence outputs
"""

import torch
import torch.nn.functional as F
from torch import nn

from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling


class GGSNN(nn.Module):
    def __init__(
        self,
        annotation_size,
        out_feats,
        n_steps,
        n_etypes,
        max_seq_length,
        num_cls,
    ):
        super(GGSNN, self).__init__()

        self.annotation_size = annotation_size
        self.out_feats = out_feats
        self.max_seq_length = max_seq_length

        self.ggnn = GatedGraphConv(
            in_feats=out_feats,
            out_feats=out_feats,
            n_steps=n_steps,
            n_etypes=n_etypes,
        )

        self.annotation_out_layer = nn.Linear(
            annotation_size + out_feats, annotation_size
        )

        pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)
        self.pooling = GlobalAttentionPooling(pooling_gate_nn)

        self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(self, graph, seq_lengths, ground_truth=None):
        etypes = graph.edata.pop("type")
        annotation = graph.ndata.pop("annotation").float()

        assert annotation.size()[-1] == self.annotation_size

        node_num = graph.number_of_nodes()

        all_logits = []
        for _ in range(self.max_seq_length):
            zero_pad = torch.zeros(
                [node_num, self.out_feats - self.annotation_size],
                dtype=torch.float,
                device=annotation.device,
            )

            h1 = torch.cat([annotation.detach(), zero_pad], -1)
            out = self.ggnn(graph, h1, etypes)
            out = torch.cat([out, annotation], -1)
            logits = self.pooling(graph, out)
            logits = self.output_layer(logits)
            all_logits.append(logits)

            annotation = self.annotation_out_layer(out)
            annotation = F.softmax(annotation, -1)

        all_logits = torch.stack(all_logits, 1)
        preds = torch.argmax(all_logits, -1)
        if ground_truth is not None:
            loss = sequence_loss(all_logits, ground_truth, seq_lengths)
            return loss, preds
        return preds


def sequence_loss(logits, ground_truth, seq_length=None):
    def sequence_mask(length):
        max_length = logits.size(1)
        batch_size = logits.size(0)
        range_tensor = torch.arange(
            0, max_length, dtype=seq_length.dtype, device=seq_length.device
        )
        range_tensor = torch.stack([range_tensor] * batch_size, 0)

        expanded_length = torch.stack([length] * max_length, -1)
        mask = (range_tensor < expanded_length).float()
        return mask

    loss = nn.CrossEntropyLoss(reduction="none")(
        logits.permute((0, 2, 1)), ground_truth
    )

    if seq_length is None:
        loss = loss.mean()
    else:
        mask = sequence_mask(seq_length)
        loss = (loss * mask).sum(-1) / seq_length.float()
        loss = loss.mean()
    return loss
