import geotorch
import math
import random
import torch
import torch.nn as nn

from .functions import ImplicitFunction
from ..tools.utils import projection_norm_inf

"""
------------------
Implicit Graph
------------------
Implicit graph layer allowing several parametrizations.

    linModule:
        cayley - W = \sigmoid(\mu) * (I-\skew[C]) @ \inv(I+\skew[C]) @ D(\rho) {D diagonal with entries either -1,1 with P(-1)=\rho}
        expm - W = \sigmoid(\mu) * \expm[C] {expm is the geotorch orthogonal B @ \expm( \inv[B] @ A ) }
        frob - W = \sigmoid(\mu) * CC' / ( || CC'||_\F + \eps )
        proj - W { ||W||_\inf < 1 / \lambda_\pf [A] }
        symm - W = .5 * (1 - \exp[\mu]) * I - CC'

------------------
Reference
------------------
https://github.com/SwiftieH/IGNN

"""
class ImplicitGraph(nn.Module):
    """Implicit graph layer with a choice of parametrisations for W.

    ``forward`` builds the input term ``b_Omega = Omega_1 @ U @ A`` and hands it,
    together with ``W = get_W(A_rho)``, to an ``ImplicitFunction``.

    Args:
        in_features: number of input features ``p`` (``Omega_1`` is ``m x p``).
        out_features: number of hidden features ``m`` (``W`` is ``m x m``).
        num_node: number of nodes ``n``; stored as ``self.n``, not used here.
        kappa: radius factor for ``linModule='proj'``; the projection in
            ``get_W`` is skipped when ``kappa`` is ``None`` or ``0``.
        b_direct: stored as ``self.b_direct``; not used in this class.
        mu: fixed value for the scale ``mu``. If ``None`` (and ``linModule`` is
            not ``None``) ``mu`` is a learnable parameter initialised to 1.
        linModule: one of ``'cayley'``, ``'expm'``, ``'frob'``, ``'proj'``,
            ``'symm'``; selects how ``get_W`` builds ``W``.
        device: device for ``mu`` (when created here), ``I`` and ``D``.
        rho: fraction of ``-1`` entries on the diagonal of ``D`` ('cayley' only).
        **kwargs: forwarded to ``ImplicitFunction``.
    """

    def __init__(self, in_features, out_features, num_node,
                 kappa=0.99, b_direct=False, mu=None, linModule='proj', device='cpu', rho=0, **kwargs):
        super().__init__()
        self.p = in_features
        self.m = out_features
        self.n = num_node
        self.k = kappa  # kappa=None (or 0) disables the 'proj' projection in get_W.
        self.b_direct = b_direct

        stdv = 1. / math.sqrt(self.m)
        self.linModule = linModule

        # Fixed tensors (I, D, a user-supplied mu) are non-persistent buffers, so
        # ``module.to(...)`` / ``.double()`` move and cast them together with the
        # parameters. The state_dict keys are unchanged because these tensors were
        # never saved before either.
        if (mu is None) and (linModule is not None):
            self.mu = nn.Parameter(torch.ones(1, dtype=torch.float, device=device))
        else:
            self.register_buffer('mu', torch.tensor(mu, dtype=torch.float, device=device), persistent=False)

        if self.linModule == 'expm':
            self.C = nn.Linear(self.m, self.m, bias=False)
            geotorch.orthogonal(self.C, 'weight')
            # Re-initialise the orthogonal weight with geotorch's 'torus' sampler.
            self.C.weight = self.C.parametrizations.weight[0].sample('torus')

        elif self.linModule == 'cayley':
            # +/-1 diagonal with (about) a fraction rho of -1 entries.
            diag = [1] * int((1 - rho) * self.m) + [-1] * int(rho * self.m)
            # Flooring both counts can leave the list short; pad with +1 so D is m x m.
            diag += [1] * (self.m - len(diag))
            random.shuffle(diag)
            self.register_buffer('D', torch.diag(torch.tensor(diag, device=device, dtype=torch.float)),
                                 persistent=False)
            self.C = nn.Linear(self.m, self.m, bias=False, dtype=torch.float)
            geotorch.skew(self.C, 'weight')

        else:
            self.C = nn.Parameter(torch.empty(self.m, self.m, dtype=torch.float))
            nn.init.uniform_(self.C, -stdv, stdv)

        self.register_buffer('I', torch.eye(self.m, dtype=torch.float, device=device), persistent=False)
        self.Omega_1 = nn.Parameter(torch.empty(self.m, self.p, dtype=torch.float))
        nn.init.uniform_(self.Omega_1, -stdv, stdv)
        self.func = ImplicitFunction(**kwargs)

    def forward(self, X_0, A, U, phi, A_rho=1.0, fw_mitr=300, bw_mitr=300, A_orig=None):
        """Run the implicit layer.

        Computes ``b_Omega = (A^T (U^T Omega_1^T)^T^T)^T = Omega_1 @ U @ A`` with
        ``torch.spmm``. It then returns ``self.func(W, X_0, A_check, b_Omega, phi,
        fw_mitr, bw_mitr)``, where ``W = self.get_W(A_rho)`` and ``A_check`` is
        ``A_orig`` if given, else ``A``.

        NOTE(review): U is presumably ``p x n`` and A ``n x n`` (possibly sparse)
        — TODO confirm against callers.
        """
        support = torch.spmm(torch.transpose(U, 0, 1), self.Omega_1.T).T  # Omega_1 @ U
        b_Omega = torch.spmm(torch.transpose(A, 0, 1), support.T).T       # (Omega_1 @ U) @ A
        A_check = A if A_orig is None else A_orig
        return self.func(self.get_W(A_rho), X_0, A_check, b_Omega, phi, fw_mitr, bw_mitr)

    def get_W(self, A_rho):
        """Return the ``m x m`` weight matrix W for the configured ``linModule``.

        Args:
            A_rho: scalar used by 'proj' to set the projection radius
                ``kappa / A_rho``; ignored by the other parametrisations.

        Raises:
            NotImplementedError: if ``linModule`` is not a supported name.
        """
        if self.linModule == 'cayley':
            S = self.C.weight  # skew-symmetric via geotorch.skew
            # (I - S)(I + S)^{-1} == (I + S)^{-1}(I - S): both factors are
            # polynomials in S and commute, so solve() replaces the explicit
            # inverse. The D(rho) factor is intentionally not applied here.
            return torch.sigmoid(self.mu) * torch.linalg.solve(self.I + S, self.I - S)

        elif self.linModule == 'expm':
            return torch.sigmoid(self.mu) * self.C.weight

        elif self.linModule == 'frob':
            CCt = self.C @ self.C.t()
            # Divide by ||CC'||_F + eps (eps outside the norm). Because
            # ||CC'||_2 <= ||CC'||_F, this keeps ||W||_2 <= sigmoid(mu) < 1;
            # eps only guards against C == 0.
            return torch.sigmoid(self.mu) * CCt / (torch.linalg.norm(CCt) + 1e-6)

        elif self.linModule == 'proj':
            # kappa of None or 0 disables the projection (A_rho is then unused).
            if self.k is not None and self.k != 0:
                self.C = projection_norm_inf(self.C, kappa=self.k / A_rho)
            return self.C

        elif self.linModule == 'symm':
            return (1 - torch.sigmoid(self.mu)) / 2 * self.I - self.C @ self.C.t()

        else:
            raise NotImplementedError(f"Unknown linModule: {self.linModule!r}")