
import manifolds
import torch
import torch.nn as nn

from geoopt import ManifoldParameter


class Decoder(nn.Module):
    """Base class for decoders that turn node embeddings into class outputs.

    ``__init__`` only stores ``c`` (presumably a curvature parameter — it is
    not used in this class). ``decode`` reads ``self.decode_adj`` and
    ``self.cls``, neither of which is set here, so a subclass (or caller)
    must define both before the default ``decode`` is used.
    """

    def __init__(self, c):
        super().__init__()
        self.c = c

    def decode(self, x, adj):
        """Run ``self.cls`` on the embeddings, optionally with the adjacency.

        Args:
            x: node embeddings passed to ``self.cls``.
            adj: adjacency passed alongside ``x`` only when
                ``self.decode_adj`` is truthy.

        Returns:
            ``self.cls(x)`` when ``decode_adj`` is falsy; otherwise the first
            element of the pair returned by ``self.cls((x, adj))``.
        """
        if not self.decode_adj:
            return self.cls.forward(x)
        # Graph-aware classifiers take (features, adjacency) and return a
        # pair; only the first element is the prediction.
        graph_input = (x, adj)
        probs, _ = self.cls.forward(graph_input)
        return probs


class SPDMLRDecoder(Decoder):
    """Multinomial-logistic-regression scores for SPD-matrix embeddings.

    Each class ``k`` owns a weight matrix ``cls[k]`` (symmetrised in
    ``decode``) and an anchor ``bias[k]`` that is passed through a Cholesky
    decomposition, so it must stay SPD. ``decode`` compares the Cholesky
    factor of each input with every anchor's factor and takes a weighted
    inner product with the pushed-forward class weight.
    NOTE(review): the formulas appear to follow the Log-Cholesky geometry of
    SPD matrices (Lin, 2019) — confirm against the reference derivation.
    """

    def __init__(self, c, args):
        super(SPDMLRDecoder, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)(dims=args.dim)
        # Upper-triangle index pairs; not read by ``decode`` in this class,
        # kept as attributes for compatibility.
        self.rows, self.cols = torch.triu_indices(args.dim, args.dim, device=args.device)
        self.dim_mat = args.dim
        # One (dim x dim) weight matrix per class, shape (n_classes, dim, dim).
        self.cls = torch.nn.Parameter(torch.Tensor(args.n_classes, self.dim_mat, self.dim_mat))
        # One anchor per class, initialised to the identity and optimised as a
        # geoopt ManifoldParameter on ``self.manifold``.
        self.bias = ManifoldParameter(torch.eye(self.dim_mat).repeat(args.n_classes, 1, 1), manifold=self.manifold)
        # Not applied in ``decode``; kept for compatibility.
        self.dropout = nn.Dropout(args.dropout)
        nn.init.xavier_uniform_(self.cls)

    def decode(self, x, adj):
        """Score every input matrix against every class.

        Args:
            x: batch of SPD matrices — assumed shape (batch, dim, dim); the
                batch axis is inserted at dim 1 below. If the first Cholesky
                attempt fails, ``x`` is passed through
                ``self.manifold.clamping`` and factorised again.
            adj: unused; accepted for interface compatibility with
                ``Decoder.decode``.

        Returns:
            Tensor of shape (batch, n_classes) with per-class scores.
        """
        try:
            xL = torch.linalg.cholesky(x, upper=False)
        except RuntimeError:
            # Input not numerically SPD: clamp via the manifold and retry.
            x = self.manifold.clamping(x)
            xL = torch.linalg.cholesky(x, upper=False)
        xL = xL.unsqueeze(1)  # (B, 1, d, d); broadcasts against (C, d, d)

        # Everything below that depends only on the class parameters is
        # computed once per class, (C, d, d), and broadcast over the batch.
        # The previous version expanded these to (B, C, d, d) before the
        # inverse and matmuls, redoing identical work for every sample.
        W = self.sym(self.cls)
        PL = torch.linalg.cholesky(self.bias, upper=False)
        PL_i = torch.inverse(PL)
        PWP = PL_i @ W @ PL_i.transpose(-1, -2)
        # W_p = PL · (strict-lower(PWP) + ½ diag(PWP)).
        W_p = PL @ (torch.tril(PWP, diagonal=-1)
                    + 0.5 * torch.diag_embed(torch.diagonal(PWP, dim1=-2, dim2=-1)))

        PL_diag = torch.diagonal(PL, dim1=-2, dim2=-1)  # (C, d)
        xL_diag = torch.diagonal(xL, dim1=-2, dim2=-1)  # (B, 1, d)

        # Strict-lower difference of the factors plus log-ratio of diagonals.
        A = (-torch.tril(PL, diagonal=-1) + torch.tril(xL, diagonal=-1)
             + torch.diag_embed(torch.log(1. / PL_diag * xL_diag)))  # (B, C, d, d)
        # Strict-lower part of W_p; its diagonal is divided by PL's diagonal.
        B = (torch.tril(W_p, diagonal=-1)
             + torch.diag_embed(1. / PL_diag * torch.diagonal(W_p, dim1=-2, dim2=-1)))  # (C, d, d)
        f_inner_product = torch.sum(A * B, dim=(-2, -1))  # (B, C)

        return f_inner_product

    def sym(self, x):
        """Return the symmetric part ``0.5 * (x^T + x)`` over the last two dims."""
        return 0.5 * (x.transpose(-1, -2) + x)

class SPDLinearDecoder(Decoder):
    """Linear classifier over the upper triangle (diagonal included) of each matrix.

    Each input matrix is flattened to its ``dim * (dim + 1) / 2`` upper-
    triangular entries, which are fed to a single ``nn.Linear`` layer.
    """

    def __init__(self, c, args):
        super().__init__(c)
        dim = args.dim
        self.manifold = getattr(manifolds, args.manifold)(dims=dim)
        # Row/column index pairs of the upper triangle, diagonal included.
        self.rows, self.cols = torch.triu_indices(dim, dim, device=args.device)
        # One input feature per upper-triangular entry: dim * (dim + 1) / 2.
        n_triu = self.rows.numel()
        self.proj = nn.Linear(n_triu, args.n_classes)
        # Not applied in ``decode``; kept for compatibility.
        self.dropout = nn.Dropout(args.dropout)

    def decode(self, x, adj):
        """Return linear class scores from the upper triangle of each matrix in ``x``.

        ``x`` is indexed as ``x[:, rows, cols]``, so it is assumed to be a
        batch of (dim x dim) matrices. ``adj`` is unused.
        """
        triu_feats = x[:, self.rows, self.cols]
        return self.proj(triu_feats)

class EucMLRDecoder(Decoder):
    """Euclidean scoring on the Cholesky factor of each input matrix.

    For class ``k`` the score is ``sum(L * cls[k] + bias[k])`` over the last
    two dimensions, where ``L`` is the lower Cholesky factor of the input.
    """

    def __init__(self, c, args):
        super().__init__(c)
        self.manifold = getattr(manifolds, args.manifold)(dims=args.dim)
        self.dim_mat = args.dim
        n_classes, d = args.n_classes, self.dim_mat
        # Per-class weight and bias matrices, each of shape (n_classes, d, d).
        self.cls = torch.nn.Parameter(torch.Tensor(n_classes, d, d))
        self.bias = torch.nn.Parameter(torch.eye(d).repeat(n_classes, 1, 1))
        # Not applied in ``decode``; kept for compatibility.
        self.dropout = nn.Dropout(args.dropout)
        nn.init.xavier_uniform_(self.cls)

    def decode(self, x, adj):
        """Return per-class scores for a batch of matrices ``x``; ``adj`` is unused.

        ``x`` is assumed to be (batch, dim, dim). If its Cholesky factorisation
        fails, ``x`` is first passed through ``self.manifold.clamping``.
        """
        n_classes = self.cls.size(0)
        try:
            chol = torch.linalg.cholesky(x, upper=False).unsqueeze(1).expand(-1, n_classes, -1, -1)
        except RuntimeError:
            x = self.manifold.clamping(x)
            chol = torch.linalg.cholesky(x, upper=False).unsqueeze(1).expand(-1, n_classes, -1, -1)
        # (B, C, d, d) against (C, d, d): broadcasting pairs each class's
        # weight/bias with every sample.
        per_entry = chol * self.cls + self.bias
        return per_entry.sum(dim=(-2, -1))


# Maps a model name to the decoder class it uses.
model2decoder = dict(
    SPDGCN=SPDMLRDecoder,
)

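# Decoder classes available in this module: SPDMLRDecoder, SPDLinearDecoder, EucMLRDecoder
# (only SPDMLRDecoder is currently registered in model2decoder).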

