
import math
import torch
import torch.nn as nn
from torch.nn.modules.module import Module
from geoopt import ManifoldParameter, Stiefel
from manifolds import SPDManifold


class SPDGraphConvolution(nn.Module):
    """Graph convolution layer whose node features are SPD matrices.

    The layer is a two-stage pipeline:

    1. ``self.linear`` -- a ``StiefelLinear`` transform applied to the node
       features, taking matrices of side ``in_features`` to side
       ``out_features``.
    2. ``self.agg`` -- a ``LogCholeskyAgg`` step that mixes the transformed
       node matrices using the adjacency matrix.

    ``forward`` consumes and produces an ``(features, adj)`` pair so layers
    can be chained; ``adj`` is passed through unchanged.
    """

    def __init__(self, manifold, in_features, out_features, args, use_bias=False, dropout=0.2, use_att=False, local_agg=False, nonlin=None):
        super().__init__()
        # Registration order (linear first, then agg) fixes the parameter
        # order seen by optimizers and state_dict.
        self.linear = StiefelLinear(
            manifold,
            in_features,
            out_features,
            use_bias,
            dropout,
            nonlin=nonlin,
            device=args.device,
        )
        self.agg = LogCholeskyAgg(
            manifold, out_features, dropout, use_att, local_agg
        )

    def forward(self, input):
        """Run transform-then-aggregate on an ``(features, adj)`` pair.

        Returns:
            A new ``(features, adj)`` tuple holding the aggregated features
            and the untouched ``adj``.
        """
        features, adj = input
        transformed = self.linear(features)
        aggregated = self.agg(transformed, adj)
        return aggregated, adj

class CholeskyLinear(nn.Module):
    """Linear map between SPD matrices that acts on their Cholesky factors.

    For each input matrix ``x`` (side ``n``):

    1. Compute its lower-triangular Cholesky factor ``C``.
    2. Flatten the ``n * (n + 1) / 2`` lower-triangle entries of ``C``.
    3. Apply dropout and an ``nn.Linear``.
    4. Scatter the result into the lower triangle of an ``m x m`` matrix
       ``L``, where ``m = out_features``.
    5. Return ``L @ L^T``, which is symmetric positive semi-definite by
       construction.

    Args:
        manifold: Accepted but unused.
        in_features: Side length ``n`` of the input matrices.
        out_features: Side length ``m`` of the output matrices.
        bias: Whether the inner ``nn.Linear`` has a bias (initialised to 0).
        dropout: Dropout probability applied to the flattened factor.
        scale: Accepted but unused.
        fixscale: Accepted but unused.
        nonlin: Optional callable applied to the input before factorising.
        device: Device on which the index buffers are initially created.

    Note:
        After construction ``self.in_features`` / ``self.out_features`` hold
        the *triangular* sizes ``n(n+1)/2`` and ``m(m+1)/2``, not the matrix
        side lengths; the output side length is kept in ``self.out_dim``.
    """

    def __init__(self,
                 manifold,
                 in_features,
                 out_features,
                 bias=True,
                 dropout=0.1,
                 scale=10,
                 fixscale=False,
                 nonlin=None,
                 device=None):
        super().__init__()
        self.nonlin = nonlin
        # Output side length, cached so forward() does not have to recover it
        # from the index tensor on every call.
        self.out_dim = out_features
        self.in_features = in_features * (in_features + 1) // 2
        self.out_features = out_features * (out_features + 1) // 2
        self.bias = bias
        irows, icols = torch.triu_indices(in_features, in_features, device=device)
        orows, ocols = torch.triu_indices(out_features, out_features, device=device)
        # Non-persistent buffers follow `module.to(device)` but are kept out of
        # the state_dict, so checkpoints saved before this change still load.
        self.register_buffer("irows", irows, persistent=False)
        self.register_buffer("icols", icols, persistent=False)
        self.register_buffer("orows", orows, persistent=False)
        self.register_buffer("ocols", ocols, persistent=False)
        self.weight = nn.Linear(
            self.in_features, self.out_features, bias=bias)
        self.reset_parameters()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a batch of SPD matrices of side ``n`` to side ``out_dim``.

        Args:
            x: Tensor of shape ``(batch, n, n)``. It is expected to be SPD,
                possibly after ``nonlin`` is applied.

        Returns:
            Tensor of shape ``(batch, out_dim, out_dim)`` equal to ``L @ L^T``.
        """
        if self.nonlin is not None:
            x = self.nonlin(x)
        try:
            C = torch.linalg.cholesky(x, upper=False)
        except RuntimeError:
            # Factorisation failed; presumably SPDManifold.clamping pushes x
            # back towards SPD before retrying -- confirm in manifolds.
            x = SPDManifold.clamping(x)
            C = torch.linalg.cholesky(x, upper=False)
        # (irows, icols) enumerate the upper triangle; swapping them reads the
        # lower triangle of C, i.e. the entries of the Cholesky factor.
        nf = C[:, self.icols, self.irows]
        nf = self.weight(self.dropout(nf))
        L = torch.zeros((len(x), self.out_dim, self.out_dim),
                        device=x.device, dtype=x.dtype)
        # Same swap on the output side: fill the lower triangle of L.
        L[:, self.ocols, self.orows] = nf
        return L @ L.transpose(-1, -2)

    def reset_parameters(self):
        """Uniformly initialise the weights and zero the bias and column 0."""
        stdv = 1. / math.sqrt(self.out_features)
        nn.init.uniform_(self.weight.weight, -stdv, stdv)
        with torch.no_grad():
            # Zero the weights that read the first flattened entry, C[0, 0].
            # NOTE(review): this mirrors the original
            # `range(0, in_features, in_features)` loop, which only ever
            # visited column 0; the intent is unclear -- confirm.
            self.weight.weight[:, 0] = 0
        if self.bias:
            nn.init.constant_(self.weight.bias, 0)

class StiefelLinear(nn.Module):
    """Bilinear map ``X -> W^T X W`` with ``W`` on the Stiefel manifold.

    ``W`` has shape ``(in_features, out_features)``. It is stored as a geoopt
    ``ManifoldParameter`` on ``Stiefel()``, so a Riemannian optimizer keeps
    its columns orthonormal. Inputs of side ``in_features`` therefore map to
    outputs of side ``out_features``, and an SPD input stays SPD because
    ``W`` has full column rank.

    Args:
        manifold: Accepted but unused.
        in_features: Side length of the input matrices.
        out_features: Side length of the output matrices; must not exceed
            ``in_features``.
        bias: Accepted but unused.
        dropout: Accepted but unused.
        scale: Accepted but unused.
        fixscale: Accepted but unused.
        nonlin: Optional callable applied to the input before the map.
        device: Accepted but unused; the initial weight is created on the
            default device.

    Raises:
        ValueError: If ``in_features < out_features``. An
            ``in x out`` matrix cannot have orthonormal columns then.
    """

    def __init__(self,
                 manifold,
                 in_features,
                 out_features,
                 bias=True,
                 dropout=0.1,
                 scale=10,
                 fixscale=False,
                 nonlin=None,
                 device=None):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under
        # `python -O`; done before any geoopt work.
        if in_features < out_features:
            raise ValueError(
                "StiefelLinear requires in_features >= out_features, got "
                f"in_features={in_features}, out_features={out_features}"
            )
        self.nonlin = nonlin
        self.in_features = in_features
        self.out_features = out_features
        stiefel_manifold = Stiefel()
        w_init = stiefel_manifold.random((in_features, out_features))
        self.weight = ManifoldParameter(
            w_init, manifold=stiefel_manifold
        )

    def forward(self, x):
        """Return ``W^T @ x @ W`` after applying ``nonlin`` (if set).

        Args:
            x: Tensor whose last two dims are ``(in_features, in_features)``;
                leading dims broadcast.

        Returns:
            Tensor whose last two dims are ``(out_features, out_features)``.
        """
        if self.nonlin is not None:
            x = self.nonlin(x)

        out = self.weight.t() @ x @ self.weight

        return out

class LogCholeskyAgg(Module):
    """
    Log-Cholesky mean aggregation layer.

    Each node matrix ``X_i`` is factorised as ``C_i C_i^T`` (lower-triangular
    ``C_i``). The aggregation treats two parts of ``C_i`` differently:

    * The strictly lower triangle of ``C_i`` is combined linearly through
      ``adj``.
    * The log of the diagonal of ``C_i`` is also combined through ``adj``,
      then exponentiated.

    The aggregated factor ``L`` is mapped back to ``L L^T``.

    Args:
        manifold: Stored as ``self.manifold``; not used by ``forward``.
        in_features: Side length ``n`` of the node matrices.
        dropout: Stored as ``self.dropout``; not currently applied.
        use_att: If true, builds key/query ``StiefelLinear`` maps plus
            learnable ``bias`` / ``scale`` scalars. ``forward`` does not
            use them yet.
        local_agg: Stored as ``self.local_agg``; not currently used.
    """

    def __init__(self, manifold, in_features, dropout, use_att, local_agg):
        super(LogCholeskyAgg, self).__init__()
        self.manifold = manifold

        self.in_features = in_features
        self.dropout = dropout
        self.local_agg = local_agg
        self.use_att = use_att
        if self.use_att:
            self.key_linear = StiefelLinear(manifold, in_features, in_features)
            self.query_linear = StiefelLinear(manifold, in_features, in_features)
            self.bias = nn.Parameter(torch.zeros(()) + 20)
            self.scale = nn.Parameter(torch.zeros(()) + math.sqrt(in_features))

    def forward(self, x, adj):
        """Aggregate node SPD matrices with a Log-Cholesky weighted mean.

        Args:
            x: Tensor of shape ``(num_nodes, n, n)``; each slice is expected
                to be SPD.
            adj: ``(num_nodes, num_nodes)`` matrix usable as the first
                argument of ``torch.spmm`` (sparse). Presumably
                row-normalised so the result is a weighted mean -- confirm
                with the caller.

        Returns:
            Tensor of shape ``(num_nodes, n, n)``.
        """
        num_nodes, n, _ = x.size()

        try:
            C = torch.linalg.cholesky(x, upper=False)
        except RuntimeError:
            # Factorisation failed; presumably SPDManifold.clamping pushes x
            # back towards SPD before retrying -- confirm in manifolds.
            x = SPDManifold.clamping(x)
            C = torch.linalg.cholesky(x, upper=False)
        # Strictly-lower part (aggregated linearly), flattened per node.
        R = torch.tril(C, diagonal=-1).reshape(num_nodes, -1)
        # Log of the (positive) Cholesky diagonal, aggregated in log space.
        D = torch.log(torch.diagonal(C, dim1=1, dim2=2)).reshape(num_nodes, -1)

        agg_R = torch.spmm(adj, R).reshape(num_nodes, n, n)
        agg_D = torch.spmm(adj, D)

        # Only the diagonal must stay strictly positive for L to be a valid
        # Cholesky factor. Clamping all of L would turn negative off-diagonal
        # entries and the zero upper triangle into 1e-6 and corrupt the mean.
        diag = torch.exp(agg_D).clamp(min=1e-6)
        L = agg_R + torch.diag_embed(diag)
        return L @ L.transpose(-1, -2)

    def attention(self, x, adj):
        """Placeholder for attention-based aggregation; not implemented (returns None)."""
        pass



