"""Graph decoders."""
import manifolds
import torch.nn as nn
from layers.layers import Linear


class Decoder(nn.Module):
    """
    Decoder abstract class for node classification tasks.

    Subclasses are expected to assign ``self.cls`` (the module that produces
    the output) and may set ``self.decode_adj`` to choose what ``decode``
    feeds into it.

    Args:
        c: value stored as ``self.c``; this base class only stores it
            (presumably the manifold curvature -- TODO confirm).
    """

    # When True, ``decode`` passes the pair ``(x, adj)`` to ``self.cls`` and
    # keeps the first element of its result; when False only ``x`` is passed.
    # Subclasses may override this as a class attribute or in ``__init__``.
    decode_adj = False

    def __init__(self, c):
        super(Decoder, self).__init__()
        self.c = c

    def decode(self, x, adj):
        """Run ``self.cls`` on the inputs and return its output.

        Args:
            x: node representations.
            adj: adjacency information; only used when ``self.decode_adj``
                is True.

        Returns:
            The output of ``self.cls(x)``, or the first element of the pair
            returned by ``self.cls((x, adj))`` when ``decode_adj`` is set.
        """
        # Call the module itself (not ``.forward``) so registered hooks run.
        if self.decode_adj:
            # ``cls`` returns a pair here; only the first element is kept.
            probs, _ = self.cls((x, adj))
        else:
            probs = self.cls(x)
        return probs


class LinearDecoder(Decoder):
    """
    Linear decoder for hyperbolic node classification models.

    Node embeddings are mapped from ``self.manifold`` to the tangent space at
    the origin and then passed through a single ``Linear`` layer
    (``self.cls``) with an identity activation.

    Args:
        c: value forwarded to ``Decoder`` and used as the ``c=`` argument of
            the manifold maps in ``decode``.
        args: configuration object; reads ``manifold``, ``ht_hidden_num``,
            ``n_classes`` and ``dropout``.
    """

    def __init__(self, c, args):
        super(LinearDecoder, self).__init__(c)
        # Manifold class looked up by name on the ``manifolds`` package.
        self.manifold = getattr(manifolds, args.manifold)()
        self.input_dim = args.ht_hidden_num
        self.output_dim = args.n_classes
        # Bias flag handed to ``Linear`` (always on here).
        # NOTE(review): presumably ``use_bias`` -- confirm against layers.Linear.
        self.bias = 1
        self.cls = Linear(self.input_dim, self.output_dim, args.dropout, lambda x: x, self.bias)
        # ``decode`` is overridden below and never passes ``adj`` to ``cls``.
        self.decode_adj = False

    def decode(self, x, adj):
        """Map ``x`` to the tangent space at the origin and classify it.

        Args:
            x: node embeddings on ``self.manifold`` (presumably of width
                ``self.input_dim`` -- TODO confirm).
            adj: unused; kept so the signature matches ``Decoder.decode``.

        Returns:
            The output of ``self.cls`` applied to the tangent-space features.
        """
        # Move the last layer's output into the tangent space at the origin;
        # the feature dimension here is still the hidden-layer dimension.
        h = self.manifold.proj_tan0(self.manifold.logmap0(x, c=self.c), c=self.c)
        # Call the module itself (not ``.forward``) so registered hooks run.
        probs = self.cls(h)
        return probs

    def extra_repr(self):
        """Return the extra summary text shown in this module's repr."""
        return 'in_features={}, out_features={}, bias={}, c={}'.format(
                self.input_dim, self.output_dim, self.bias, self.c
        )


# Registry mapping a model name to the decoder class used for it.
model2decoder = dict(
    HGCN=LinearDecoder,
)

