"""Graph encoders."""
import torch.nn as nn
import manifolds
import layers.hyp_layers as hyp_layers


class Encoder(nn.Module):
    """
    Encoder abstract class.

    This base class defines neither ``self.layers`` nor ``self.encode_graph``;
    a subclass must set both before the default ``encode`` is used.
    """

    def __init__(self, c):
        super(Encoder, self).__init__()
        # Stored as-is; HGCN appends it as the last entry of its curvature list.
        self.c = c

    def encode(self, x, adj):
        """Run the input through ``self.layers``.

        Args:
            x: node feature input.
            adj: adjacency input; only forwarded when ``self.encode_graph`` is true.

        Returns:
            If ``self.encode_graph`` is true, the first item of the 2-tuple
            returned by ``self.layers((x, adj))``; otherwise the value of
            ``self.layers(x)``.
        """
        # Invoke the module itself (not ``.forward``) so that forward hooks
        # and pre-hooks registered on ``self.layers`` actually run.
        if self.encode_graph:
            output, _ = self.layers((x, adj))
            return output
        return self.layers(x)

class HGCN(Encoder):
    """
    Hyperbolic-GCN.

    Stacks ``hyp_layers.HyperbolicGraphConvolution`` layers in an
    ``nn.ModuleList``; ``encode`` runs node features through them in order.
    """

    def __init__(self, c, args):
        super(HGCN, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        assert args.ht_layer_num > 1
        # Layer count includes the input layer (hidden layers + 1).
        # NOTE(review): this increments the caller's ``args`` in place,
        # presumably so ``get_dim_act_curv`` sees the adjusted count; building
        # a second HGCN from the same ``args`` object bumps it again — confirm
        # that is intended.
        args.ht_layer_num += 1
        self.n_layers = args.ht_layer_num
        dims, acts, self.curvatures = hyp_layers.get_dim_act_curv(args)
        # The encoder's own ``c`` becomes the final curvature, so layer i is
        # built with c_in=curvatures[i] and c_out=curvatures[i + 1].
        self.curvatures.append(self.c)

        self.layers = nn.ModuleList()
        for idx, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            conv = hyp_layers.HyperbolicGraphConvolution(
                self.manifold,
                dim_in,
                dim_out,
                self.curvatures[idx],
                self.curvatures[idx + 1],
                args.dropout,
                acts[idx],
                1,
                0,
                0,
            )
            self.layers.append(conv)
        self.encode_graph = True

    def encode(self, x, adj, middle=False):
        """Encode node features with the stacked hyperbolic graph convolutions.

        ``x`` is passed through ``self.manifold.proj_tan0``, ``expmap0`` and
        ``proj`` (each with ``self.curvatures[0]``), then through the first
        ``self.n_layers - 1`` entries of ``self.layers``. Each layer is called
        with a ``(features, adj)`` tuple and returns a 3-tuple; the first item
        feeds the next layer and the third item is collected.

        Args:
            x: input node features.
            adj: adjacency passed unchanged to every layer.
            middle: when true, also return the collected per-layer items.

        Returns:
            The last layer's output, or ``(output, middle_feats)`` when
            ``middle`` is true.
        """
        c0 = self.curvatures[0]
        tangent = self.manifold.proj_tan0(x, c0)
        feats = self.manifold.proj(self.manifold.expmap0(tangent, c=c0), c=c0)

        layer_input = (feats, adj)
        middle_feats = []
        for layer_idx in range(self.n_layers - 1):
            output, _, mid = self.layers[layer_idx](layer_input)
            # Third item of each layer's result: the middle embedding.
            middle_feats.append(mid)
            layer_input = (output, adj)

        return (output, middle_feats) if middle else output


