import numpy as np
import torch
import torch.nn as nn
from gae2 import AE, DAE, RCAE, initialize_linear_layer_weight
from sph_n_ambient import *
from sph_n_DataUtil import getPosJacobianFromPos_torch, getNormalizedJacobianFromPos_torch
from torch_batch_svd import svd

def pinverse(mat, eps=1e-10):
    """Batched pseudo-inverse of ``mat`` computed from its SVD.

    Args:
        mat: batch of matrices. The code reads ``mat.shape[1]`` and
            ``mat.shape[2]`` and multiplies with ``torch.bmm``, so a 3-D
            ``(batch, rows, cols)`` tensor is expected.
        eps: singular values strictly smaller than this are treated as zero,
            i.e. their reciprocal is replaced by 0.

    Returns:
        ``V @ diag(1/s) @ U^T`` for every matrix of the batch, placed on the
        same device as the input.
    """
    device = mat.device
    if mat.is_cuda and (mat.shape[1] < 32 or mat.shape[2] < 32):
        # Small CUDA matrices use the batched SVD from torch_batch_svd.
        # NOTE(review): that kernel is assumed to be CUDA-only, hence the
        # is_cuda guard; CPU inputs of any size take the torch.svd path.
        u, s, v = svd(mat)
    else:
        # The decomposition is done on the CPU. The original author flagged
        # that this round trip may affect the computation graph -- check
        # before relying on gradients through this branch.
        cpu_mat = mat.cpu()
        try:
            u, s, v = torch.svd(cpu_mat)
        except RuntimeError:
            # torch.svd may have convergence issues; retry once on a slightly
            # perturbed copy (random jitter scaled by the mean entry).
            jitter = 1e-4 * cpu_mat.mean() * torch.rand(cpu_mat.shape)
            u, s, v = torch.svd(cpu_mat + jitter)

        # Return to wherever the input lived instead of assuming a GPU.
        u = u.to(device)
        s = s.to(device)
        v = v.to(device)
    s_inv = 1. / s
    # Zero out (near-)null directions instead of amplifying them.
    s_inv[s < eps] = 0.
    return torch.bmm(torch.bmm(v, torch.diag_embed(s_inv)), u.permute(0, 2, 1))
    
class GDAE_sph_ambient(DAE):
    """Denoising autoencoder whose output is mapped back onto the sphere.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)

    The geometric helpers used below (``project_to_tangentSpace``,
    ``exponential_map``, ``logarithm_map``, ``project_to_sphere``,
    ``acos_square``) are not defined in this file; presumably they come from
    the ``sph_n_ambient`` star import -- confirm their exact semantics there.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        """Forward every argument unchanged to the ``DAE`` constructor."""
        super(GDAE_sph_ambient, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x):
        """Corrupt ``x`` with random noise and return the reconstruction.

        Gaussian noise with standard deviation ``self.noise_std`` is passed
        through ``project_to_tangentSpace`` and ``exponential_map`` to obtain
        the corrupted input; the network output is added to it as a residual
        and the sum goes through ``project_to_sphere``.
        """
        # empty_like keeps the noise on x's own device (and dtype) instead of
        # the legacy typed constructors that only know "CPU" or "current GPU".
        epsilon = torch.empty_like(x).normal_(0.0, self.noise_std)
        epsilon = project_to_tangentSpace(x, epsilon)
        x_tilde = exponential_map(x, epsilon)
        return project_to_sphere(x_tilde + self.autoencoder(x_tilde))

    def calculate_loss(self, x, fixed_noise = None):
        """Summed ``acos_square`` of the inner product between reconstruction and ``x``.

        Args:
            x: batch of inputs; the inner product is taken over dimension 1.
            fixed_noise: optional noise to use instead of sampling. It is
                projected with ``project_to_tangentSpace`` and applied with
                ``exponential_map`` before calling ``clean_forward``.

        Returns:
            Scalar tensor. Judging by the name ``acos_square`` this is
            presumably the summed squared geodesic distance -- confirm.
        """
        if fixed_noise is None:
            recon_corrupt = self.forward(x)
        else:
            epsilon = project_to_tangentSpace(x, fixed_noise)
            x_tilde = exponential_map(x, epsilon)
            recon_corrupt = self.clean_forward(x_tilde)
        return torch.sum(acos_square(torch.sum(recon_corrupt * x, dim=1)))

    def clean_forward(self, x):
        """Reconstruct ``x`` without adding any noise."""
        return project_to_sphere(x + self.autoencoder(x))

    def estimate_score(self, x, inCoord = False, dx_dxth = None):
        """Return Log_x(r(x)) / noise_std**2.

        The result is expressed in the ambient space unless ``inCoord`` is
        set (a coordinate transform may be required for comparison to other
        methods).

        Args:
            x: batch of inputs.
            inCoord: when true, the ambient vector is multiplied by the
                transpose of ``dx_dxth`` and reshaped to
                ``(-1, self.input_dim - 1)``.
            dx_dxth: optional precomputed Jacobian; when omitted it is
                obtained from ``getPosJacobianFromPos_torch(x, eps=1e-6)``.
        """
        r = self.clean_forward(x)
        log_x_r = logarithm_map(x, r)
        if not inCoord:
            return log_x_r / self.noise_std**2
        if dx_dxth is None:
            dx_dxth = getPosJacobianFromPos_torch(x, eps=1e-6)
        in_coord = torch.bmm(dx_dxth.permute(0, 2, 1), log_x_r.view(-1, self.input_dim, 1))
        return in_coord.view(-1, self.input_dim - 1) / self.noise_std**2

    def calculate_expected_loss(self, x, Niter):
        """Average the per-sample noisy reconstruction loss over ``Niter`` draws.

        Runs under ``torch.no_grad()``; each iteration calls ``forward`` (which
        samples fresh noise) and divides the summed loss by the batch size.
        """
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for _ in range(Niter):
                recon_corrupt = self.forward(x)
                lossSum += torch.sum(acos_square(torch.sum(recon_corrupt * x, dim=1))) / N
        return lossSum / Niter

    def get_derivative(self, x, create_graph=False, inCoord = False, dx_dxth = None):
        """Calculate the dr/dx value of ``clean_forward`` with autograd.

        Side effect: ``x.requires_grad`` is switched on if it was off.

        Args:
            x: batch of inputs.
            create_graph: when true the Jacobian itself stays differentiable.
            inCoord: when true the ambient Jacobian is wrapped as
                ``pinverse(dr_drth) @ dr_dx @ dx_dxth``.
            dx_dxth: optional precomputed Jacobian at ``x``; when omitted it
                is obtained from ``getPosJacobianFromPos_torch(x, eps=1e-6)``.

        Returns:
            Tensor whose dimension 1 indexes the output component and whose
            last dimension indexes the input component (before the optional
            coordinate change).
        """
        if not x.requires_grad:
            x.requires_grad = True
        y = self.clean_forward(x)
        n_samples, n_out = y.size(0), y.size(1)

        rows = []
        for i in range(n_out):
            # One-hot grad_outputs select output component i for the whole
            # batch; built on y's device so CPU inputs work as well.
            selector = torch.zeros(n_samples, n_out, device=y.device, dtype=y.dtype)
            selector[:, i] = 1
            # The graph is always retained because y is differentiated once
            # per output component.
            rows.append(torch.autograd.grad(outputs=y, inputs=x, grad_outputs=selector,
                                            retain_graph=True, create_graph=bool(create_graph))[0])
        dr_dx = torch.stack(rows, dim=1)

        if inCoord:
            if dx_dxth is None:
                dx_dxth = getPosJacobianFromPos_torch(x, eps=1e-6)
            dr_drth = getPosJacobianFromPos_torch(y, eps=1e-6)
            drth_dr = pinverse(dr_drth, eps = 1e-3)     # use quite a large eps for numerical stability...
            return torch.bmm(torch.bmm(drth_dr, dr_dx), dx_dxth)
        return dr_dx
        
class GRCAE_sph_ambient(RCAE):
    """Contractive-style autoencoder whose output is mapped back onto the sphere.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)

    ``project_to_sphere``, ``logarithm_map`` and ``acos_square`` are not
    defined in this file; presumably they come from the ``sph_n_ambient``
    star import -- confirm their exact semantics there.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        """Forward every argument unchanged to the ``RCAE`` constructor."""
        super(GRCAE_sph_ambient, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x):
        """Add the network output to ``x`` and pass the sum to ``project_to_sphere``."""
        return project_to_sphere(x + self.autoencoder(x))

    def calculate_loss(self, x, separate = False):
        """Reconstruction loss plus a Jacobian penalty weighted by ``noise_std**2``.

        The penalty is ``sum(dr_dx**2) - sum((dr_dx @ x)**2)``, i.e. the
        squared Jacobian entries minus the squared Jacobian-vector product
        with ``x``.

        Args:
            x: batch of inputs; ``get_derivative`` enables ``requires_grad``
                on it.
            separate: when true, return ``(reconstruction, penalty)`` as a
                tuple without applying the ``noise_std**2`` weight.
        """
        recon = self.forward(x)
        dr_dx = self.get_derivative(x, create_graph = True)
        dr_dx_times_x = torch.matmul(dr_dx, x.unsqueeze(-1))
        recon_loss = torch.sum(acos_square(torch.sum(recon * x, dim=1)))
        penalty = torch.sum(dr_dx**2) - torch.sum(dr_dx_times_x**2)
        if separate:
            return recon_loss, penalty
        return recon_loss + self.noise_std**2 * penalty

    def estimate_score(self, x, inCoord = False, dx_dxth = None):
        """Return Log_x(r(x)) / noise_std**2.

        The result is expressed in the ambient space unless ``inCoord`` is
        set (a coordinate transform may be required for comparison to other
        methods).

        Args:
            x: batch of inputs.
            inCoord: when true, the ambient vector is multiplied by the
                transpose of ``dx_dxth`` and reshaped to
                ``(-1, self.input_dim - 1)``.
            dx_dxth: optional precomputed Jacobian; when omitted it is
                obtained from ``getPosJacobianFromPos_torch(x, eps=1e-6)``.
        """
        r = self.forward(x)
        log_x_r = logarithm_map(x, r)
        if not inCoord:
            return log_x_r / self.noise_std**2
        if dx_dxth is None:
            dx_dxth = getPosJacobianFromPos_torch(x, eps=1e-6)
        in_coord = torch.bmm(dx_dxth.permute(0, 2, 1), log_x_r.view(-1, self.input_dim, 1))
        return in_coord.view(-1, self.input_dim - 1) / self.noise_std**2

    def get_derivative(self, x, create_graph=False, inCoord = False, dx_dxth = None):
        """Calculate the dr/dx value of ``forward`` with autograd.

        Side effect: ``x.requires_grad`` is switched on if it was off.

        Args:
            x: batch of inputs.
            create_graph: when true the Jacobian itself stays differentiable
                (needed when it enters a training loss).
            inCoord: when true the ambient Jacobian is wrapped as
                ``pinverse(dr_drth) @ dr_dx @ dx_dxth``.
            dx_dxth: optional precomputed Jacobian at ``x``; when omitted it
                is obtained from ``getPosJacobianFromPos_torch(x, eps=1e-6)``.

        Returns:
            Tensor whose dimension 1 indexes the output component and whose
            last dimension indexes the input component (before the optional
            coordinate change).
        """
        if not x.requires_grad:
            x.requires_grad = True
        y = self.forward(x)
        n_samples, n_out = y.size(0), y.size(1)

        rows = []
        for i in range(n_out):
            # One-hot grad_outputs select output component i for the whole
            # batch; built on y's device so CPU inputs work as well.
            selector = torch.zeros(n_samples, n_out, device=y.device, dtype=y.dtype)
            selector[:, i] = 1
            # The graph is always retained because y is differentiated once
            # per output component.
            rows.append(torch.autograd.grad(outputs=y, inputs=x, grad_outputs=selector,
                                            retain_graph=True, create_graph=bool(create_graph))[0])
        dr_dx = torch.stack(rows, dim=1)

        if inCoord:
            if dx_dxth is None:
                dx_dxth = getPosJacobianFromPos_torch(x, eps=1e-6)
            dr_drth = getPosJacobianFromPos_torch(y, eps=1e-6)
            drth_dr = pinverse(dr_drth, eps = 1e-3)     # use quite a large eps for numerical stability...
            return torch.bmm(torch.bmm(drth_dr, dr_dx), dx_dxth)
        return dr_dx
        
################ for timer experiments ################
class GAE_vanilla(AE):
    """Plain autoencoder baseline with a squared-difference loss.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)
    """

    def __init__(self, dim, num_hidden_layers, useLeakyReLU=True, initial='default'):
        """Forward every argument unchanged to the ``AE`` constructor."""
        super().__init__(dim, num_hidden_layers, useLeakyReLU, initial)

    def forward(self, x):
        """Add the network output to ``x`` and pass the sum to ``project_to_sphere``."""
        residual = self.autoencoder(x)
        return project_to_sphere(x + residual)

    def calculate_loss(self, x):
        """Sum of squared element-wise differences between ``x`` and its reconstruction."""
        error = x - self.forward(x)
        return (error * error).sum()

class GAE_sph_ambient_geodesiconly(AE):
    """Autoencoder baseline that only swaps in the ``acos_square`` loss.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)
    """

    def __init__(self, dim, num_hidden_layers, useLeakyReLU=True, initial='default'):
        """Forward every argument unchanged to the ``AE`` constructor."""
        super().__init__(dim, num_hidden_layers, useLeakyReLU, initial)

    def forward(self, x):
        """Add the network output to ``x`` and pass the sum to ``project_to_sphere``."""
        residual = self.autoencoder(x)
        return project_to_sphere(x + residual)

    def calculate_loss(self, x):
        """Summed ``acos_square`` of the per-sample inner product with the reconstruction."""
        recon = self.forward(x)
        inner = torch.sum(recon * x, axis=1)
        return torch.sum(acos_square(inner))

class GDAE_sph_ambient_expmaponly(DAE):
    """Denoising baseline: exponential-map corruption with a squared-difference loss.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        """Forward every argument unchanged to the ``DAE`` constructor."""
        super(GDAE_sph_ambient_expmaponly, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x):
        """Corrupt ``x`` with random noise and return the reconstruction.

        Gaussian noise with standard deviation ``self.noise_std`` is passed
        through ``project_to_tangentSpace`` and ``exponential_map`` to obtain
        the corrupted input; the network output is added to it as a residual
        and the sum goes through ``project_to_sphere``.
        """
        # empty_like keeps the noise on x's own device (and dtype) instead of
        # the legacy typed constructors that only know "CPU" or "current GPU".
        epsilon = torch.empty_like(x).normal_(0.0, self.noise_std)
        epsilon = project_to_tangentSpace(x, epsilon)
        x_tilde = exponential_map(x, epsilon)
        return project_to_sphere(x_tilde + self.autoencoder(x_tilde))

    def calculate_loss(self, x):
        """Sum of squared element-wise differences between ``x`` and its noisy reconstruction."""
        diff = x - self.forward(x)
        return torch.sum(diff * diff)

class GRCAE_sph_ambient_contractiveonly(RCAE):
    """Baseline combining a squared-difference loss with the Jacobian penalty.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to S^n)

    ``get_derivative`` is not defined in this class; presumably it is
    inherited from ``RCAE`` -- confirm.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU=True, initial='default'):
        """Forward every argument unchanged to the ``RCAE`` constructor."""
        super().__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x):
        """Add the network output to ``x`` and pass the sum to ``project_to_sphere``."""
        residual = self.autoencoder(x)
        return project_to_sphere(x + residual)

    def calculate_loss(self, x, separate=False):
        """Squared-difference loss plus a Jacobian penalty weighted by ``noise_std**2``.

        The penalty is ``sum(jac**2) - sum((jac @ x)**2)``. With ``separate``
        set, the two terms are returned as a tuple and the weight is not
        applied.
        """
        recon = self.forward(x)
        jac = self.get_derivative(x, create_graph=True)
        jac_times_x = torch.matmul(jac, x.unsqueeze(-1))
        error = x - recon
        recon_term = torch.sum(error * error)
        penalty_term = torch.sum(jac**2) - torch.sum(jac_times_x**2)
        if separate:
            return recon_term, penalty_term
        return recon_term + self.noise_std**2 * penalty_term
##############################################
        
################ previous models ################
class GDAE_sph_tangentSpaceProj(GDAE_sph_ambient):
    """Variant that treats the network output as a tangent vector.

    self.autoencoder: R^(n+1) -> R^(n+1)
    r: S^n -> S^n (project from R^(n+1) to T_xS^n to get tangent vector;
    then apply exponential map from T_xS^n to S^n)
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        """Forward every argument unchanged to the parent constructor."""
        super(GDAE_sph_tangentSpaceProj, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x):
        """Corrupt ``x`` with random noise and return the reconstruction.

        The corrupted point is built from Gaussian noise (std
        ``self.noise_std``) via ``project_to_tangentSpace`` and
        ``exponential_map``; the network output at that point is projected
        with ``project_to_tangentSpace`` and applied with ``exponential_map``.
        """
        # empty_like keeps the noise on x's own device (and dtype) instead of
        # the legacy typed constructors that only know "CPU" or "current GPU".
        epsilon = torch.empty_like(x).normal_(0.0, self.noise_std)
        epsilon = project_to_tangentSpace(x, epsilon)
        x_tilde = exponential_map(x, epsilon)
        v = project_to_tangentSpace(x_tilde, self.autoencoder(x_tilde))
        return exponential_map(x_tilde, v)

    def clean_forward(self, x):
        """Reconstruct ``x`` without adding any noise."""
        v = project_to_tangentSpace(x, self.autoencoder(x))
        return exponential_map(x, v)

class GDAE_sph_tangentSpace(GDAE_sph_ambient):
    """Variant whose network outputs coefficients in a tangent-space basis.

    self.autoencoder: R^(n+1) -> R^n
    r: S^n -> S^n (apply spherical coordinate (normalized) Jacobian basis
    (in R^(n+1 x n)) to R^n to get tangent vector in T_xS^n; then apply
    exponential map from T_xS^n to S^n)
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        """Build the parent network, then swap its last layer for an ``input_dim - 1`` output."""
        super(GDAE_sph_tangentSpace, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

        # Keep every child module except the final one and append a new
        # linear head that maps h_dim -> input_dim - 1.
        network_layers = list(self.autoencoder.children())[:-1]
        layer = nn.Linear(self.h_dim, self.input_dim - 1)
        initialize_linear_layer_weight(layer, initial, 'linear')
        network_layers.append(layer)
        self.autoencoder = torch.nn.Sequential(*network_layers)

    def forward(self, x):
        """Corrupt ``x`` with random noise and return the reconstruction.

        The corrupted point is built from Gaussian noise (std
        ``self.noise_std``) via ``project_to_tangentSpace`` and
        ``exponential_map``. The network output at that point is multiplied
        with the transposed result of ``getNormalizedJacobianFromPos_torch``
        to obtain a vector of size ``input_dim``, which is then applied with
        ``exponential_map``.
        """
        # empty_like keeps the noise on x's own device (and dtype) instead of
        # the legacy typed constructors that only know "CPU" or "current GPU".
        epsilon = torch.empty_like(x).normal_(0.0, self.noise_std)
        epsilon = project_to_tangentSpace(x, epsilon)
        x_tilde = exponential_map(x, epsilon)
        basis = getNormalizedJacobianFromPos_torch(x_tilde).permute(0, 2, 1)
        coeffs = self.autoencoder(x_tilde).view(-1, 1, self.input_dim - 1)
        v = torch.bmm(coeffs, basis).view(-1, self.input_dim)
        return exponential_map(x_tilde, v)

    def clean_forward(self, x):
        """Reconstruct ``x`` without adding any noise."""
        basis = getNormalizedJacobianFromPos_torch(x).permute(0, 2, 1)
        coeffs = self.autoencoder(x).view(-1, 1, self.input_dim - 1)
        v = torch.bmm(coeffs, basis).view(-1, self.input_dim)
        return exponential_map(x, v)
#################################################