import torch
import torch.nn as nn
import scipy as scp

class Projection(nn.Module):
    """Project a Euclidean gradient onto the tangent space of a manifold.

    Supported manifold names (case-insensitive): 'stiefel', 'grassmann',
    'sphere', 'spd', 'euc' and 'qconv'.
    """

    # Manifold names accepted by __init__ (compared after lower-casing).
    _MANIFOLDS = ('stiefel', 'sphere', 'grassmann', 'qconv', 'spd', 'euc')

    def __init__(self, manifold=None):
        """Validate and store the manifold name.

        Args:
            manifold: name of the manifold (case-insensitive).

        Raises:
            ValueError: if ``manifold`` is None or not a supported name.
        """
        super(Projection, self).__init__()
        # Test for None *before* calling .lower(); otherwise a missing name
        # fails with AttributeError instead of the intended ValueError.
        if manifold is None:
            raise ValueError('Manifold needs to be provided')
        self.manifold = manifold.lower()
        if self.manifold not in self._MANIFOLDS:
            raise ValueError('Manifold needs to be stiefel, grassmann or spd or euc or sphere')

    def forward(self, param_val=None, euclidean_grad=None):
        """Return the projection of ``euclidean_grad`` onto the tangent space at ``param_val``.

        Args:
            param_val: point on the manifold. The matrix manifolds accept an
                ``(n, p)`` matrix or a stack ``(n, p, k)`` indexed along the
                last dim; every branch needs at least 2 dims.
            euclidean_grad: Euclidean gradient with the same shape as
                ``param_val`` (a 2-D gradient is accepted for a 3-D point and
                treated as a single slice).

        Returns:
            The projected gradient, shaped like the (possibly promoted)
            gradient. For 'qconv' an ``(n, p, 1)`` input yields ``(n, p)``.
            The stiefel / grassmann / spd results live on ``param_val``'s
            device and dtype.
        """
        if euclidean_grad.dim() == 2 and param_val.dim() == 3:
            euclidean_grad = euclidean_grad.unsqueeze(dim=2)
        assert param_val.size() == euclidean_grad.size()
        n, p = euclidean_grad.size(0), euclidean_grad.size(1)

        if self.manifold == 'sphere':
            # g - <x, g> x: the tangent projection when x has unit norm.
            val = torch.sum(param_val * euclidean_grad)
            return euclidean_grad - param_val * val

        if self.manifold == 'qconv':
            return self._project_qconv(param_val, euclidean_grad)

        # Matrix manifolds operate on an (n, p, k) stack; promote 2-D input.
        squeeze_back = euclidean_grad.dim() == 2
        if squeeze_back:
            param_val = param_val.unsqueeze(dim=2)
            euclidean_grad = euclidean_grad.unsqueeze(dim=2)
        k = euclidean_grad.size(2)

        # Allocate on param_val's device/dtype so CUDA or float64 inputs are
        # not silently returned as a CPU float32 tensor.
        projected_grad = param_val.new_zeros(n, p, k)
        if self.manifold == 'stiefel':
            X = param_val
            U = euclidean_grad.to(X)  # torch.mm needs matching device/dtype
            for i in range(k):
                XtU = torch.mm(X[:, :, i].t(), U[:, :, i])
                symXtU = 0.5 * (XtU + XtU.t())
                projected_grad[:, :, i] = U[:, :, i] - torch.mm(X[:, :, i], symXtU)
        elif self.manifold == 'grassmann':
            X = param_val
            U = euclidean_grad.to(X)  # torch.mm needs matching device/dtype
            for i in range(k):
                XtU = torch.mm(X[:, :, i].t(), U[:, :, i])
                projected_grad[:, :, i] = U[:, :, i] - torch.mm(X[:, :, i], XtU)
        elif self.manifold == 'spd':
            assert n == p
            for i in range(k):
                grad_vec = euclidean_grad[:, :, i]
                # Keep only the symmetric part of the gradient.
                projected_grad[:, :, i] = 0.5 * (grad_vec + grad_vec.t())
        elif self.manifold == 'euc':
            projected_grad = euclidean_grad
        else:
            raise ValueError('Manifold needs to be provided')

        if squeeze_back:
            # Drop only the stack dim added above; a bare squeeze() would also
            # drop n == 1 or p == 1 and change the output's shape.
            projected_grad = projected_grad.squeeze(2)
        return projected_grad

    @staticmethod
    def _project_qconv(param_val, euclidean_grad):
        """Project for the 'qconv' manifold via a Sylvester solve; returns a 2-D tensor."""
        # Import explicitly: `import scipy` alone may not load scipy.linalg.
        from scipy.linalg import solve_sylvester

        eta = euclidean_grad.clone()
        if param_val.dim() == 3:
            assert param_val.size(2) == 1
            # Remove only the trailing singleton dim so n == 1 / p == 1 survive.
            param_val = param_val.squeeze(2)
            eta = eta.squeeze(2)

        L = param_val.clone()
        LtL = torch.mm(L.t(), L)
        SS = torch.mm(LtL, LtL)
        RHS = torch.mm(eta.t(), torch.mm(L, LtL)) - torch.mm(torch.mm(LtL, L.t()), eta)
        # SciPy needs host NumPy arrays: detach from autograd *and* move to CPU
        # (calling .numpy() on a tensor that requires grad raises).
        SS_np = SS.detach().cpu().numpy()
        RHS_np = RHS.detach().cpu().numpy()
        # Solve SS @ Omega + Omega @ SS = RHS, then restore SS's dtype/device.
        Omega = torch.from_numpy(solve_sylvester(SS_np, SS_np, RHS_np)).type_as(SS)
        return eta + torch.mm(L, Omega.t())
'''
-------------------------------- RETRACTION ----------------------------------------------------------------------------
'''
class Retraction(nn.Module):
    """Map a point plus a tangent step back onto a manifold.

    Supported manifold names (case-insensitive): 'stiefel', 'grassmann',
    'sphere', 'spd', 'euc' and 'qconv'.
    """

    # Manifold names accepted by __init__ (compared after lower-casing).
    _MANIFOLDS = ('stiefel', 'sphere', 'grassmann', 'qconv', 'spd', 'euc')

    def __init__(self, manifold=None, parameter=None, config=None):
        """Validate and store the manifold name.

        Args:
            manifold: name of the manifold (case-insensitive).
            parameter: accepted for interface compatibility; unused here.
            config: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``manifold`` is None or not a supported name.
        """
        super(Retraction, self).__init__()
        # Test for None *before* calling .lower(); otherwise a missing name
        # fails with AttributeError instead of the intended ValueError.
        if manifold is None:
            raise ValueError('Manifold needs to be provided')
        self.manifold = manifold.lower()
        if self.manifold not in self._MANIFOLDS:
            raise ValueError('Manifold needs to be stiefel, grassmann or spd or euc or sphere')

    def forward(self, param_val, projected_grad, mul_fac=None):
        """Retract ``param_val + mul_fac * projected_grad`` onto the manifold.

        Args:
            param_val: current point, ``(n, p)`` or a stack ``(n, p, k)``
                indexed along the last dim (at least 2 dims are required).
            projected_grad: tangent step with the same shape as
                ``param_val`` (a 2-D step is accepted for a 3-D point).
            mul_fac: step-size multiplier; ``None`` means 1.

        Returns:
            The retracted point, shaped like the (possibly promoted) input.
        """
        if projected_grad.dim() == 2 and param_val.dim() == 3:
            projected_grad = projected_grad.unsqueeze(dim=2)
        assert param_val.size() == projected_grad.size()
        n, p = projected_grad.size(0), projected_grad.size(1)

        if self.manifold == 'sphere':
            if mul_fac is None:
                mul_fac = 1.0
            val = param_val + projected_grad * mul_fac
            # Rescale the Euclidean step back to unit norm.
            return val / torch.norm(val)

        if self.manifold == 'qconv':
            if mul_fac is None:
                mul_fac = 1.0
            # Plain additive update; this manifold is not re-normalised here.
            return param_val + projected_grad * mul_fac

        # Matrix manifolds operate on an (n, p, k) stack; promote 2-D input.
        squeeze_back = projected_grad.dim() == 2
        if squeeze_back:
            param_val = param_val.unsqueeze(dim=2)
            projected_grad = projected_grad.unsqueeze(dim=2)
        k = projected_grad.size(2)

        if self.manifold == 'stiefel':
            if mul_fac is None:
                mul_fac = 1.0
            Y = param_val + mul_fac * projected_grad.type_as(param_val)
            for i in range(k):
                Q, R = torch.linalg.qr(Y[:, :, i])
                # Flip Q's columns so diag(R) is non-negative (sign 0 -> +1),
                # making the QR-based retraction unique.
                Y[:, :, i] = torch.mm(Q, torch.diag(torch.sign(torch.sign(torch.diag(R)) + 0.5)))
            retracted_point = Y

        elif self.manifold == 'grassmann':
            if mul_fac is None:
                mul_fac = 1.0
            Y = param_val + mul_fac * projected_grad.type_as(param_val)
            for i in range(k):
                u, _, v = torch.svd(Y[:, :, i])
                # Polar factor U V^T of the stepped matrix.
                Y[:, :, i] = torch.mm(u, v.t())
            retracted_point = Y

        elif self.manifold == 'spd':
            if mul_fac is not None:
                grad_val = mul_fac * projected_grad
            else:
                grad_val = projected_grad
            grad_val = grad_val.type_as(param_val)
            assert n == p
            # Allocate on param_val's device/dtype (not a CPU float32 buffer).
            retracted_point = param_val.new_zeros(n, p, k)
            for i in range(k):
                X = param_val[:, :, i]
                xi = grad_val[:, :, i]
                # Second-order retraction R_X(xi) = X + xi + 1/2 xi X^{-1} xi,
                # so R_X(0) = X. solve(X, xi) gives X^{-1} xi without an
                # explicit inverse (and never inverts the gradient).
                W = X + xi + 0.5 * torch.mm(xi, torch.linalg.solve(X, xi))
                # Symmetrise to discard round-off asymmetry.
                retracted_point[:, :, i] = 0.5 * (W + W.t())

        elif self.manifold == 'euc':
            if mul_fac is None:
                mul_fac = 1.0
            retracted_point = param_val + mul_fac * projected_grad

        else:
            raise ValueError('Manifold needs to be provided')

        if squeeze_back:
            # Drop only the stack dim added above; a bare squeeze() would also
            # drop n == 1 or p == 1 and change the output's shape.
            retracted_point = retracted_point.squeeze(2)
        return retracted_point
'''
-------------------------------------- PARALLEL TRANSPORT --------------------------------------------------------------
'''
class Parallel_Transport(nn.Module):
    """Move a tangent vector into the tangent space at ``param_val0``.

    For 'spd' and 'euc' the vector is returned unchanged; every other
    manifold re-projects it with ``Projection`` at ``param_val0`` (a
    projection-based transport rather than an exact parallel transport).
    """

    # Manifold names accepted by __init__ (compared after lower-casing).
    _MANIFOLDS = ('stiefel', 'sphere', 'grassmann', 'qconv', 'spd', 'euc')

    def __init__(self, manifold=None):
        """Validate and store the manifold name.

        Args:
            manifold: name of the manifold (case-insensitive).

        Raises:
            ValueError: if ``manifold`` is None or not a supported name.
        """
        super(Parallel_Transport, self).__init__()
        # Test for None *before* calling .lower(); otherwise a missing name
        # fails with AttributeError instead of the intended ValueError.
        if manifold is None:
            raise ValueError('Manifold needs to be provided')
        self.manifold = manifold.lower()
        if self.manifold not in self._MANIFOLDS:
            raise ValueError('Manifold needs to be stiefel, grassmann or spd or euc or sphere')

    def forward(self, param_val1=None, param_val0=None, projected_grad1=None):
        """Transport ``projected_grad1`` to the tangent space at ``param_val0``.

        Args:
            param_val1: not used by this implementation; kept for interface
                compatibility. NOTE(review): confirm that ``param_val0`` (not
                ``param_val1``) is the intended destination point.
            param_val0: point whose tangent space receives the vector.
            projected_grad1: tangent vector to transport.

        Returns:
            ``projected_grad1`` itself for 'spd' / 'euc', otherwise its
            projection at ``param_val0``.
        """
        if self.manifold in ('spd', 'euc'):
            return projected_grad1
        proj = Projection(self.manifold)
        return proj(param_val=param_val0, euclidean_grad=projected_grad1)

