
import numpy as np
import torch
import torch.nn as nn

import manifolds
import layers.sym_layers as sym_layers



class Encoder(nn.Module):
    """
    Abstract base for encoders that run ``self.layers`` over node features.

    Subclasses must set ``self.layers`` (a module) and the flag
    ``self.encode_graph`` before ``encode`` is called; neither is set here.
    """

    def __init__(self, c):
        super().__init__()
        # Stored unchanged; presumably a curvature parameter (TODO confirm).
        self.c = c

    def encode(self, x, adj):
        """Run ``self.layers`` on the inputs and return the resulting features.

        If ``self.encode_graph`` is truthy, the layer stack receives the pair
        ``(x, adj)`` and must return a two-element result whose first element
        is returned. Otherwise the stack receives ``x`` alone and its result
        is returned directly.
        """
        if not self.encode_graph:
            return self.layers.forward(x)
        # Graph-aware stacks thread (features, adjacency) through each layer;
        # only the final features are of interest to the caller.
        graph_input = (x, adj)
        features, _ = self.layers.forward(graph_input)
        return features


class SPDGCN(Encoder):
    """
    SPDGCN encoder: a stack of ``sym_layers.SPDGraphConvolution`` layers that
    operate on ``args.dim x args.dim`` matrices.

    ``vec2cho2spd`` maps Euclidean node features to symmetric positive
    semi-definite matrices. It is not called from within this class.
    NOTE(review): presumably the caller applies it before ``encode`` — confirm.
    """

    def __init__(self, c, args):
        """Build the projection and the SPD graph-convolution stack.

        Args:
            c: forwarded unchanged to ``Encoder.__init__``.
            args: config object. Reads ``manifold`` (the name of a class in the
                ``manifolds`` module, instantiated with no arguments),
                ``num_layers`` (must be > 1), ``feat_dim``, ``dim`` and
                ``dropout``. The whole object is also passed to every
                ``SPDGraphConvolution``.
        """
        super().__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        assert args.num_layers > 1, "SPDGCN requires args.num_layers > 1"
        self.feat_dim = args.feat_dim
        self.mat_dim = args.dim
        # NOTE(review): not read by any method of this class; kept for any
        # external users.
        self.mat_feat = round(np.sqrt(self.feat_dim * 2))
        # Entries in the lower triangle (diagonal included) of a
        # mat_dim x mat_dim matrix, i.e. the free entries of the factor
        # built in vec2cho2spd.
        self.dim = self.mat_dim * (self.mat_dim + 1) // 2
        self.weight = nn.Linear(self.feat_dim, self.dim)
        torch.nn.init.xavier_uniform_(self.weight.weight)
        torch.nn.init.zeros_(self.weight.bias)
        self.dropout = nn.Dropout(args.dropout)
        # NOTE(review): not used by any method of this class.
        self.act = nn.ReLU()
        # num_layers identical square (mat_dim -> mat_dim) layers, built in
        # order so parameter initialisation order is preserved.
        self.layers = nn.Sequential(*(
            sym_layers.SPDGraphConvolution(
                self.manifold, self.mat_dim, self.mat_dim, args
            )
            for _ in range(args.num_layers)
        ))
        # Makes Encoder.encode feed the pair (x, adj) through the stack.
        self.encode_graph = True

    def encode(self, x, adj):
        """Delegate to ``Encoder.encode`` (graph mode: layers receive ``(x, adj)``)."""
        return super().encode(x, adj)

    def vec2cho2spd(self, u: torch.Tensor) -> torch.Tensor:
        """Map feature vectors to symmetric PSD matrices via a triangular factor.

        ``u`` goes through dropout and is then projected by ``self.weight`` to
        ``mat_dim * (mat_dim + 1) // 2`` values per vector. These values fill
        the lower triangle (diagonal included) of a ``mat_dim x mat_dim``
        factor ``L``, column by column. The result is ``L @ L^T``.

        Args:
            u: features whose last dimension is ``feat_dim``. Any leading
                batch dimensions are preserved, e.g. ``(num_nodes, feat_dim)``.

        Returns:
            Tensor of shape ``(*u.shape[:-1], mat_dim, mat_dim)`` with the
            dtype and device of the projected features.

        Note:
            ``L @ L^T`` is symmetric positive semi-definite. It is strictly
            positive definite only when every diagonal entry of ``L`` is
            non-zero, and nothing here enforces that.
        """
        node_feats = self.weight(self.dropout(u))
        # Build the index tensors on the same device as the features to avoid
        # a host->device transfer on every call.
        upper_rows, upper_cols = torch.triu_indices(
            self.mat_dim, self.mat_dim, device=node_feats.device
        )
        lower = node_feats.new_zeros(
            (*node_feats.shape[:-1], self.mat_dim, self.mat_dim)
        )
        # Swapping (row, col) of the upper-triangle indices addresses the
        # lower triangle, so feature k lands at (upper_cols[k], upper_rows[k]).
        lower[..., upper_cols, upper_rows] = node_feats
        return lower @ lower.transpose(-1, -2)
