import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.autograd import Variable

def initialize_linear_layer_weight(layer, initial, nonlinearity):
    """Re-initialise ``layer.weight`` in place according to a named scheme.

    Args:
        layer: Object exposing a ``weight`` tensor (e.g. ``nn.Linear``).
        initial: ``'kaiming'`` or ``'xavier'`` selects the matching uniform
            initialiser from ``torch.nn.init``. Any other value leaves the
            weight exactly as it was.
        nonlinearity: Name of the activation that follows the layer, as
            understood by ``torch.nn.init`` (e.g. ``'leaky_relu'``,
            ``'tanh'``, ``'linear'``).

    Returns:
        None. The bias is never touched.
    """
    # 0.1 is handed to torch.nn.init as the leaky-ReLU slope parameter
    # (``a`` for Kaiming, ``param`` for the Xavier gain).
    if initial == 'xavier':
        scale = init.calculate_gain(nonlinearity, param=0.1)
        init.xavier_uniform_(layer.weight, gain=scale)
    elif initial == 'kaiming':
        init.kaiming_uniform_(layer.weight, a=0.1, nonlinearity=nonlinearity)
    return None

class AE(nn.Module):
    """Fully connected residual autoencoder: ``forward(x) = x + net(x)``.

    ``net`` is stored as ``self.autoencoder`` and maps ``input_dim`` features
    back to ``input_dim`` features through ``num_hidden_layers`` hidden
    layers of width ``h_dim``.

    Args:
        dim: Two-element sequence ``[input_dim, h_dim]``.
        num_hidden_layers: Number of activation layers. ``0`` yields a single
            ``input_dim -> input_dim`` linear map with no activation.
        useLeakyReLU: If true the activations are ``LeakyReLU(0.1)``,
            otherwise ``Tanh``.
        initial: Weight initialisation scheme forwarded to
            ``initialize_linear_layer_weight``.
    """

    def __init__(self, dim, num_hidden_layers, useLeakyReLU = True, initial = 'default'):
        super(AE, self).__init__()
        [input_dim, h_dim] = dim
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.num_layers = num_hidden_layers
        nonlinearity = 'leaky_relu' if useLeakyReLU else 'tanh'

        network_layers = []
        if num_hidden_layers == 0:
            # Degenerate case: one linear map, initialised as a linear layer.
            layer = nn.Linear(input_dim, input_dim)
            initialize_linear_layer_weight(layer, initial, 'linear')
            network_layers.append(layer)
        else:
            layer = nn.Linear(input_dim, h_dim)
            initialize_linear_layer_weight(layer, initial, nonlinearity)
            network_layers.append(layer)
            for i in range(num_hidden_layers):
                if useLeakyReLU:
                    network_layers.append(nn.LeakyReLU(0.1, inplace=True))
                else:
                    network_layers.append(nn.Tanh())
                if i < num_hidden_layers - 1:
                    layer = nn.Linear(h_dim, h_dim)
                    initialize_linear_layer_weight(layer, initial, nonlinearity)
                else:
                    # The last linear layer has no activation after it, so it
                    # is initialised with the 'linear' gain.
                    layer = nn.Linear(h_dim, input_dim)
                    initialize_linear_layer_weight(layer, initial, 'linear')
                network_layers.append(layer)

        self.autoencoder = torch.nn.Sequential(*network_layers)

    def forward(self, x):
        """Return the reconstruction ``x + self.autoencoder(x)``."""
        return x + self.autoencoder(x)

    def calculate_loss(self, x):
        """Return the summed squared difference between ``x`` and ``forward(x)``."""
        diff = x - self.forward(x)
        return torch.sum(diff*diff)

    def get_derivative(self, x, create_graph=False):
        """Compute the per-sample Jacobian of the reconstruction w.r.t. ``x``.

        Args:
            x: 2-D input tensor ``(N, D)``. If it does not already require
                gradients, ``requires_grad`` is switched on **in place**.
            create_graph: If true the returned Jacobian is itself
                differentiable (needed when it is part of a loss).

        Returns:
            Tensor of shape ``(N, D_out, D)`` obtained by stacking one
            ``torch.autograd.grad`` result per output column along dim 1.
            Provided the network treats batch rows independently, entry
            ``[n, i, j]`` is ``d y[n, i] / d x[n, j]``.
        """
        def one_hot_columns(reference):
            # One selector per output column, allocated with the same device
            # and dtype as the output so this also works on CPU, on
            # non-default GPUs and with non-float32 models.
            rows, cols = reference.size(0), reference.size(1)
            for col in range(cols):
                selector = torch.zeros(rows, cols, dtype=reference.dtype,
                                       device=reference.device)
                selector[:, col] = 1
                yield selector

        if not x.requires_grad:
            x.requires_grad = True
        y = x + self.autoencoder(x)
        # retain_graph=True keeps the graph alive across the per-column calls;
        # create_graph=True additionally records the backward pass itself.
        result = [
            torch.autograd.grad(outputs=y, inputs=x, grad_outputs=unit,
                                create_graph=create_graph, retain_graph=True)[0]
            for unit in one_hot_columns(y)
        ]
        jacobian = torch.stack(result, dim=1)
        return jacobian

    def get_derivative_old(self, x, create_graph=False):
        """Legacy Jacobian computation based on ``Tensor.backward``.

        Returns a tensor with the same layout as :meth:`get_derivative`.
        Unlike that method, ``backward`` also accumulates gradients into the
        ``.grad`` of every other leaf in the graph (e.g. the network
        parameters), and ``x.grad`` is reset to ``None`` on return.
        """
        if not x.requires_grad:
            x.requires_grad = True
        y = x + self.autoencoder(x)
        grad_set = []
        x.grad = None
        for i in range(self.input_dim):
            # Selector lives on the same device/dtype as the output.
            temp = torch.zeros_like(y)
            temp[:,i] = 1
            if create_graph:
                y.backward(gradient=temp, create_graph=True)
            else:
                y.backward(gradient=temp, retain_graph=True)
            grad_set.append(x.grad)
            # Drop the reference so the next backward starts from scratch
            # instead of accumulating into the stored tensor.
            x.grad = None
        return torch.stack(grad_set, dim=1)
    
class DAE(AE):
    """Denoising autoencoder: trains the residual network on noisy inputs.

    Args:
        dim: Two-element sequence ``[input_dim, h_dim]``.
        num_hidden_layers: Number of hidden activation layers.
        noise_std: Standard deviation of the isotropic Gaussian corruption.
        useLeakyReLU: Selects ``LeakyReLU(0.1)`` (true) or ``Tanh`` (false).
        initial: Weight initialisation scheme name.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        self.noise_std = noise_std
        super(DAE, self).__init__(dim, num_hidden_layers, useLeakyReLU, initial)

    def forward(self, x):
        """Corrupt ``x`` with fresh Gaussian noise and reconstruct it.

        The noise is allocated with the same shape, dtype and device as
        ``x`` so the method works on CPU, on any GPU and with non-float32
        models. It is not part of the autograd graph.

        Returns:
            ``x_noisy + self.autoencoder(x_noisy)`` with
            ``x_noisy = x + epsilon``, ``epsilon ~ N(0, noise_std**2)``.
        """
        epsilon = torch.empty_like(x).normal_(0.0, self.noise_std)
        corrupted = x + epsilon
        return corrupted + self.autoencoder(corrupted)

    def calculate_loss(self, x, fixed_noise = None):
        """Return the summed squared error between the denoised output and ``x``.

        Args:
            x: Clean input batch.
            fixed_noise: Optional pre-drawn corruption added to ``x``. When
                ``None`` a fresh corruption is sampled via :meth:`forward`.
        """
        if fixed_noise is None:
            recon_corrupt = self.forward(x)
        else:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        diff2 = recon_corrupt - x
        return torch.sum(diff2*diff2)

    def clean_forward(self, x):
        """Reconstruct ``x`` without adding any noise."""
        return x + self.autoencoder(x)

    def estimate_score(self, x):
        """Return ``self.autoencoder(x) / noise_std**2`` as the score estimate."""
        return (self.autoencoder(x)) / self.noise_std**2

    def calculate_expected_loss(self, x, Niter):
        """Average the per-sample denoising loss over ``Niter`` noise draws.

        Runs under ``torch.no_grad()``. Each draw contributes the summed
        squared error divided by the batch size ``x.size()[0]``.

        Raises:
            ZeroDivisionError: If ``Niter`` is 0.
        """
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for _ in range(Niter):
                recon_corrupt = self.forward(x)
                diff2 = recon_corrupt - x
                lossSum += torch.sum(diff2*diff2) / N
        return lossSum / Niter
    
class GDAE(DAE):
    """Denoising autoencoder whose noise and loss are shaped by a metric.

    Args:
        dim: Two-element sequence ``[input_dim, h_dim]``.
        num_hidden_layers: Number of hidden activation layers.
        noise_std: Standard deviation of the Gaussian noise before it is
            transformed by the inverse square-root metric.
        diagonal_metric: If true, metric arguments are combined with the data
            element-wise; otherwise they are batched matrices used with
            ``torch.bmm``.
        metricSqrtFunc: Optional callable mapping a reconstruction batch to a
            square-root metric. When given, it overrides the ``metric_sqrt``
            argument of the loss methods.
        useLeakyReLU: Selects ``LeakyReLU(0.1)`` (true) or ``Tanh`` (false).
        initial: Weight initialisation scheme name.
        corruptedMetric: If true, ``metricSqrtFunc`` is evaluated at the
            reconstruction of the corrupted input, otherwise at the
            reconstruction of the clean input.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, diagonal_metric = False, 
                 metricSqrtFunc = None, useLeakyReLU = True, initial = 'default', corruptedMetric = False):
        self.diagonal_metric = diagonal_metric
        self.metric_sqrt_func = metricSqrtFunc
        self.corrupted_metric = corruptedMetric
        super(GDAE, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def forward(self, x, metricInv_sqrt):
        """Corrupt ``x`` with metric-shaped Gaussian noise and reconstruct it.

        Args:
            x: 2-D input batch ``(N, D)``.
            metricInv_sqrt: Inverse square-root metric. In the diagonal case
                it is multiplied element-wise with the ``(N, D)`` noise
                (assumed broadcastable to that shape - TODO confirm against
                callers). Otherwise it is a batched ``(N, D, D)`` matrix
                applied to the ``(N, 1, D)`` noise row vector.

        Returns:
            ``x_noisy + self.autoencoder(x_noisy)``.
        """
        batch, width = x.size()[0], x.size()[1]
        if self.diagonal_metric:
            noise_shape = (batch, width)
        else:
            noise_shape = (batch, 1, width)
        # Match x's dtype and device so this works on CPU, on any GPU and
        # with non-float32 inputs; the raw noise carries no gradient.
        white = torch.empty(noise_shape, dtype=x.dtype,
                            device=x.device).normal_(0.0, self.noise_std)
        if self.diagonal_metric:
            epsilon = white * metricInv_sqrt
        else:
            epsilon = torch.bmm(white, metricInv_sqrt).view(batch, width)
        corrupted = x + epsilon
        return corrupted + self.autoencoder(corrupted)

    def _metric_weighted_sq_error(self, x, recon_corrupt, metric_sqrt):
        """Return the summed squared metric-weighted error ``x - recon_corrupt``."""
        if self.metric_sqrt_func is not None:
            # A metric function overrides the metric passed in by the caller.
            if self.corrupted_metric:
                metric_sqrt = self.metric_sqrt_func(recon_corrupt)
            else:
                metric_sqrt = self.metric_sqrt_func(self.clean_forward(x))
        if self.diagonal_metric:
            diff2 = (x - recon_corrupt) * metric_sqrt
        else:
            diff2 = torch.bmm((x - recon_corrupt).unsqueeze(1), metric_sqrt)
        return torch.sum(diff2*diff2)

    def calculate_loss(self, x, metricInv_sqrt, metric_sqrt, fixed_noise = None):
        """Return the metric-weighted denoising loss for one noise draw.

        Args:
            x: Clean input batch.
            metricInv_sqrt: Inverse square-root metric used to shape the
                noise (ignored when ``fixed_noise`` is given).
            metric_sqrt: Square-root metric weighting the error; ignored when
                a metric function was supplied at construction.
            fixed_noise: Optional pre-drawn corruption added to ``x``.
        """
        if fixed_noise is None:
            recon_corrupt = self.forward(x, metricInv_sqrt)
        else:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        return self._metric_weighted_sq_error(x, recon_corrupt, metric_sqrt)

    def estimate_score(self, x, metric):
        """Return the metric-weighted residual divided by ``noise_std**2``.

        In the diagonal case the result has the shape of ``x``; otherwise it
        is the raw ``torch.bmm`` output with a singleton middle dimension.
        """
        recon = x + self.autoencoder(x)
        diff = recon - x
        if self.diagonal_metric:
            score_est = (diff * metric) / self.noise_std**2
        else:
            score_est = torch.bmm(diff.unsqueeze(1), metric) \
            / self.noise_std**2
        return score_est

    def calculate_expected_loss(self, x, metricInv_sqrt, metric_sqrt, Niter):
        """Average the per-sample metric-weighted loss over ``Niter`` draws.

        Runs under ``torch.no_grad()``. Each draw contributes the summed
        squared weighted error divided by the batch size ``x.size()[0]``.

        Raises:
            ZeroDivisionError: If ``Niter`` is 0.
        """
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for _ in range(Niter):
                recon_corrupt = self.forward(x, metricInv_sqrt)
                lossSum += self._metric_weighted_sq_error(
                    x, recon_corrupt, metric_sqrt) / N
        return lossSum / Niter
        
        
class RCAE(AE):
    """Residual autoencoder trained with a Jacobian penalty.

    The loss is the squared network output plus ``noise_std**2`` times the
    squared entries of the reconstruction Jacobian returned by
    ``get_derivative``.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default'):
        self.noise_std = noise_std
        super(RCAE, self).__init__(dim, num_hidden_layers, useLeakyReLU, initial)

    def calculate_loss(self, x, separate = False):
        """Return the training loss for batch ``x``.

        Args:
            x: Input batch; ``get_derivative`` may switch on its
                ``requires_grad`` flag in place.
            separate: If true, return the two terms as an unweighted tuple
                ``(residual_term, jacobian_term)`` instead of their weighted
                sum.
        """
        residual = self.autoencoder(x)
        # create_graph=True keeps the Jacobian differentiable so the penalty
        # can be back-propagated.
        jacobian = self.get_derivative(x, create_graph = True)
        residual_term = torch.sum(residual*residual)
        jacobian_term = torch.sum(jacobian**2)
        if not separate:
            return residual_term + self.noise_std**2 * jacobian_term
        return residual_term, jacobian_term

    def estimate_score(self, x):
        """Return ``self.autoencoder(x) / noise_std**2`` as the score estimate."""
        variance = self.noise_std**2
        return (self.autoencoder(x)) / variance
    
class GRCAE(RCAE):
    """Jacobian-penalised autoencoder with metric-weighted loss terms.

    Args:
        dim: Two-element sequence ``[input_dim, h_dim]``.
        num_hidden_layers: Number of hidden activation layers.
        noise_std: Its square weights the Jacobian penalty.
        diagonal_metric: If true, metric arguments are combined element-wise;
            otherwise they are batched matrices.
        metricSqrtFunc: Optional callable mapping a reconstruction batch to
            the square-root metric applied on the output side of the
            Jacobian.
        useLeakyReLU: Selects ``LeakyReLU(0.1)`` (true) or ``Tanh`` (false).
        initial: Weight initialisation scheme name.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, diagonal_metric = False, 
                 metricSqrtFunc = None, useLeakyReLU = True, initial = 'default'):
        self.diagonal_metric = diagonal_metric
        self.metric_sqrt_func = metricSqrtFunc
        super(GRCAE, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)

    def calculate_loss(self, x, metricInv_sqrt, metric_sqrt, separate = False):
        """Return the metric-weighted reconstruction + Jacobian loss.

        Args:
            x: Input batch; ``get_derivative`` may switch on its
                ``requires_grad`` flag in place.
            metricInv_sqrt: Inverse square-root metric applied on the input
                side of the Jacobian.
            metric_sqrt: Square-root metric weighting the reconstruction
                error; also used on the output side of the Jacobian when no
                metric function was supplied.
            separate: If true, return the two unweighted terms as a tuple.
        """
        recon = self.forward(x)
        if self.metric_sqrt_func is None:
            metric_sqrt_r = metric_sqrt
        else:
            metric_sqrt_r = self.metric_sqrt_func(recon)

        jacobian = self.get_derivative(x, create_graph = True)
        residual = x - recon
        # NOTE(review): the residual is weighted by the caller-supplied
        # metric_sqrt even when a metric function is set - confirm intended.
        if not self.diagonal_metric:
            # metric_sqrt_r and metricInv_sqrt should be symmetric...
            weighted_residual = torch.bmm(residual.unsqueeze(1), metric_sqrt)
            weighted_jacobian = torch.matmul(
                metric_sqrt_r, torch.matmul(jacobian, metricInv_sqrt))
        else:
            weighted_residual = residual * metric_sqrt
            # Scale Jacobian rows by the output metric and columns by the
            # inverse input metric via broadcasting.
            weighted_jacobian = (metric_sqrt_r.unsqueeze(-1) * jacobian
                                 * metricInv_sqrt.unsqueeze(1))
        residual_term = torch.sum(weighted_residual*weighted_residual)
        jacobian_term = torch.sum(weighted_jacobian*weighted_jacobian)
        if separate:
            return residual_term, jacobian_term
        return residual_term + self.noise_std**2 * jacobian_term

    def estimate_score(self, x, metric):
        """Return the metric-weighted network output divided by ``noise_std**2``.

        In the diagonal case the result has the shape of ``x``; otherwise it
        is the raw ``torch.bmm`` output with a singleton middle dimension.
        """
        correction = self.autoencoder(x)
        if not self.diagonal_metric:
            return torch.bmm(correction.unsqueeze(1), metric) \
            / self.noise_std**2
        return (correction * metric) / self.noise_std**2