'''Implementation of kernel functions.'''

import torch
import numpy as np


def euclidean_distances(samples, centers, squared=True):
    '''Pairwise Euclidean distances between the rows of two matrices.

    Relies on the expansion ||x - z||^2 = ||x||^2 - 2 <x, z> + ||z||^2.

    Args:
        samples: 2-D tensor with one point per row.
        centers: 2-D tensor with one point per row and as many columns
            as `samples`.
        squared: when True (default) the squared distances are returned
            as computed; when False they are clamped at zero and the
            square root is taken.

    Returns:
        tensor of shape (n_sample, n_center). With `squared=True` the
        entries are not clamped, so small negative values caused by
        rounding may remain.
    '''
    row_sq_norms = (samples ** 2).sum(dim=1, keepdim=True)
    if samples is centers:
        # Same object: reuse the norms instead of recomputing them.
        col_sq_norms = row_sq_norms
    else:
        col_sq_norms = (centers ** 2).sum(dim=1, keepdim=True)
    col_sq_norms = col_sq_norms.reshape(1, -1)

    # Start from the cross term and update that buffer in place.
    dist = torch.mm(samples, centers.t())
    dist.mul_(-2).add_(row_sq_norms).add_(col_sq_norms)

    if squared:
        return dist
    return dist.clamp_(min=0).sqrt_()

def euclidean_distances_M(samples, centers, M, squared=True):
    '''Pairwise distances under the quadratic form defined by `M`.

    Each entry is x^T M x - 2 x^T M z + z^T M z for a row x of
    `samples` and a row z of `centers`.

    Args:
        samples: 2-D tensor with one point per row.
        centers: 2-D tensor with one point per row.
        M: square matrix applied between the points.
        squared: when True (default) the values are returned as
            computed; when False they are clamped at zero and the
            square root is taken.

    Returns:
        tensor of shape (n_sample, n_center).
    '''
    def _quadratic_form(points):
        # Row-wise p^T M p, kept as a column vector.
        return torch.sum((points @ M) * points, dim=1, keepdim=True)

    row_terms = _quadratic_form(samples)
    col_terms = row_terms if samples is centers else _quadratic_form(centers)
    col_terms = col_terms.reshape(1, -1)

    # Cross term first, then accumulate into the same buffer in place.
    dist = torch.mm(samples, M @ centers.t())
    dist.mul_(-2).add_(row_terms).add_(col_terms)

    if squared:
        return dist
    return dist.clamp_(min=0).sqrt_()

def clamped_euclidean_distances_M(samples, centers, M, squared=True, clamp=1e-10):
    """Distances under the quadratic form `M`, with small values set to 0.

    Entries below `clamp` are replaced by exactly 0. The replacement is
    done with `torch.where`, so the result is differentiable even when
    `squared=False`: replaced entries receive zero gradient instead of
    the infinite slope of the square root at 0.

    Args:
        samples: 2-D tensor with one point per row.
        centers: 2-D tensor with one point per row.
        M: square matrix applied between the points.
        squared: when True (default) return the thresholded squared
            values; when False return their thresholded square root.
        clamp: threshold below which an entry is set to 0.

    Returns:
        tensor of shape (n_sample, n_center).
    """
    samples_norm = torch.sum((samples @ M) * samples, dim=1, keepdim=True)

    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = torch.sum((centers @ M) * centers, dim=1, keepdim=True)

    centers_norm = torch.reshape(centers_norm, (1, -1))

    distances = samples.mm(M @ torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)

    # Build the zero with the dtype and device of `distances`; a default
    # float32 zero would mismatch (or silently promote) other dtypes.
    zero = distances.new_zeros(1)
    distances = torch.where(distances < clamp, zero, distances)
    if not squared:
        # The first `where` already zeroed the small entries; this second
        # one keeps the sqrt branch out of the result for them, and the
        # first one blocks the non-finite gradient that sqrt yields at 0.
        distances = torch.where(distances < clamp, zero, torch.sqrt(distances))
    return distances


##############
#### GAUSSIAN
##############

def gaussian(samples, centers, bandwidth):
    '''Gaussian kernel matrix exp(-||x - z||^2 / (2 * bandwidth^2)).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    sq_dist = euclidean_distances(samples, centers)
    # Rounding can leave tiny negative squared distances; clip them.
    sq_dist.clamp_(min=0)
    scale = 1. / (2 * bandwidth ** 2)
    return sq_dist.mul_(-scale).exp_()

def gaussian_M(samples, centers, bandwidth, M):
    '''Gaussian kernel matrix using distances under the matrix `M`.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.
        M: square matrix passed on to `euclidean_distances_M`.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    sq_dist = euclidean_distances_M(samples, centers, M)
    # Rounding can leave tiny negative values; clip them.
    sq_dist.clamp_(min=0)
    scale = 1. / (2 * bandwidth ** 2)
    return sq_dist.mul_(-scale).exp_()

def gaussian_M_wagop(X, L, sol, M, diag_only=False):
    '''Autograd-based estimate built from the Gaussian-M kernel.

    Differentiates trace(sol @ K(X, X @ U) @ sol.T) / c with respect to
    U at U = I and returns grad.T @ grad.

    Args:
        X: data of shape (n, d).
        L: kernel bandwidth.
        sol: coefficients of shape (c, n) (= alpha.T).
        M: matrix of shape (d, d).
        diag_only: when True return only the diagonal as a 1-D tensor.

    Returns:
        tensor of shape (d, d), or (d,) when `diag_only` is True.

    Note:
        `backward()` is called internally, so any other leaf tensor in
        the graph that requires grad also gets its `.grad` accumulated.
    '''
    # Take both dtype and device from M: a default float32 identity
    # makes `X @ U` raise for float64 inputs.
    U = torch.eye(M.shape[0], dtype=M.dtype, device=M.device,
                  requires_grad=True)
    K = gaussian_M(X, X @ U, L, M)
    val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
    val.backward()
    wagopsqrt = U.grad
    wagop = wagopsqrt.T @ wagopsqrt
    if diag_only:
        wagop = torch.diag(wagop)
    return wagop

def gaussian_M_agop(X, L, sol, M, diag_only=False):
    '''Closed-form outer-product estimate for the Gaussian-M kernel.

    Args:
        X: data of shape (n, d).
        L: kernel bandwidth.
        sol: coefficients of shape (c, n).
        M: matrix of shape (d, d).
        diag_only: when True return only the diagonal as a 1-D tensor.

    Returns:
        tensor of shape (d, d), or (d,) when `diag_only` is True.
    '''
    kernel = gaussian_M(X, X, L, M)
    projected = X @ M

    # Both terms have shape (c, n, d); they differ in which index of the
    # kernel matrix the projected point is taken at.
    term_at_i = torch.einsum('il,ij,cj->cil', projected, kernel, sol)
    term_at_j = torch.einsum('jl,ij,cj->cil', projected, kernel, sol)
    grads = (-term_at_i + term_at_j) / L**2

    # Sum of outer products over the first two axes, then divide by the
    # size of the leading axis.
    outer_sum = torch.einsum('ikj,ikl->jl', grads, grads)
    result = outer_sum / grads.shape[0]
    return torch.diag(result) if diag_only else result

###############
#### LAPLACIAN
###############

def laplacian(samples, centers, bandwidth):
    '''Laplacian kernel matrix exp(-||x - z|| / bandwidth).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    dist = euclidean_distances(samples, centers, squared=False)
    dist.clamp_(min=0)
    inv_bandwidth = 1. / bandwidth
    return dist.mul_(-inv_bandwidth).exp_()

def laplacian_M(samples, centers, bandwidth, M):
    '''Laplacian kernel matrix using distances under the matrix `M`.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.
        M: square matrix passed on to `euclidean_distances_M`.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    dist = euclidean_distances_M(samples, centers, M, squared=False)
    dist.clamp_(min=0)
    inv_bandwidth = 1. / bandwidth
    return dist.mul_(-inv_bandwidth).exp_()

def clamped_laplacian_M(samples, centers, bandwidth, M, clamp=1e-10):
    """Laplacian-M kernel built on the thresholded distance.

    Differentiable, unlike `laplacian_M`: distances below `clamp` are
    set to 0 by `clamped_euclidean_distances_M`, so derivatives around
    K(x, x) are 0.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.
        M: square matrix passed on to the distance function.
        clamp: distance threshold below which an entry is set to 0.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    """
    assert bandwidth > 0
    dist = clamped_euclidean_distances_M(
        samples, centers, M, squared=False, clamp=clamp)
    dist.clamp_(min=0)
    inv_bandwidth = 1. / bandwidth
    return dist.mul_(-inv_bandwidth).exp_()

def clamped_laplacian_M_wagop(X, L, sol, M, diag_only=False, clamp=1e-10):
    '''Autograd-based estimate built from the clamped Laplacian-M kernel.

    Differentiates trace(sol @ K(X, X @ U) @ sol.T) / c with respect to
    U at U = I and returns grad.T @ grad.

    Args:
        X: data of shape (n, d).
        L: kernel bandwidth.
        sol: coefficients of shape (c, n) (= alpha.T).
        M: matrix of shape (d, d).
        diag_only: when True return only the diagonal as a 1-D tensor.
        clamp: distance threshold forwarded to `clamped_laplacian_M`.

    Returns:
        tensor of shape (d, d), or (d,) when `diag_only` is True.

    Note:
        `backward()` is called internally, so any other leaf tensor in
        the graph that requires grad also gets its `.grad` accumulated.
    '''
    # Take both dtype and device from M: a default float32 identity
    # makes `X @ U` raise for float64 inputs.
    U = torch.eye(M.shape[0], dtype=M.dtype, device=M.device,
                  requires_grad=True)
    K = clamped_laplacian_M(X, X @ U, L, M, clamp=clamp)
    val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
    val.backward()
    wagopsqrt = U.grad
    wagop = wagopsqrt.T @ wagopsqrt  # (FACT)(FACT^T)
    if diag_only:
        wagop = torch.diag(wagop)
    return wagop

def clamped_laplacian_M_fact(X, L, sol, M, diag_only=False, clamp=1e-10):
    '''Factor whose product with its transpose is the clamped-Laplacian WAGOP.

    Differentiates trace(sol @ K(X, X @ U) @ sol.T) / c with respect to
    U at U = I and returns the transpose of that gradient.

    Args:
        X: data of shape (n, d).
        L: kernel bandwidth.
        sol: coefficients of shape (c, n) (= alpha.T).
        M: matrix of shape (d, d).
        diag_only: must be False; a diagonal-only factor is not supported.
        clamp: distance threshold forwarded to `clamped_laplacian_M`.

    Returns:
        tensor of shape (d, d).

    Raises:
        AssertionError: if `diag_only` is True.
    '''
    # Reject the unsupported option before the forward and backward
    # passes instead of after them.
    assert not diag_only, "diag_only is not supported for the factor"

    # Take both dtype and device from M: a default float32 identity
    # makes `X @ U` raise for float64 inputs.
    U = torch.eye(M.shape[0], dtype=M.dtype, device=M.device,
                  requires_grad=True)
    K = clamped_laplacian_M(X, X @ U, L, M, clamp=clamp)
    val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
    val.backward()
    wagopsqrt = U.grad
    return wagopsqrt.T



def laplacian_M_agop(X, L, sol, M):
    '''Outer-product estimate for the Laplacian-M kernel on a subsample.

    Evaluation points are a random subsample of the rows of `X` (drawn
    with replacement through NumPy's global RNG) when `X` has more than
    20000 rows, and all of `X` otherwise.

    Args:
        X: data of shape (n, d).
        L: kernel bandwidth.
        sol: coefficients of shape (c, n); cast to float32 internally.
        M: matrix of shape (d, d).

    Returns:
        tensor of shape (d, d): the sum over evaluation points and
        outputs of G G^T, divided by the number of evaluation points.
    '''

    num_samples = 20000
    # Always drawn, even when the subsample ends up unused below, so the
    # global NumPy RNG state advances on every call.
    indices = np.random.randint(len(X), size=num_samples)

    # Subsample only when X has more rows than the index array.
    if len(X) > len(indices):
        x = X[indices, :]
    else:
        x = X

    # K has shape (n, m): rows index X, columns index evaluation points.
    K = laplacian_M(X, x, L, M)
    device = K.device

    dist = euclidean_distances_M(X, x, M, squared=False)
    # Snap near-zero distances to exactly 0 so the division below gives
    # inf there, which is then replaced by 0.
    dist = torch.where(dist < 1e-10, torch.zeros(1).float().to(device), dist)

    K = K/dist
    K[K == float("Inf")] = 0.

    a1 = sol.T.float()
    n, d = X.shape
    n, c = a1.shape
    m, d = x.shape

    # step1[i] is the (c, d) outer product of sol[:, i] and (X @ M)[i],
    # flattened to (n, c*d) so it can be contracted with K in one matmul.
    a1 = a1.reshape(n, c, 1)
    X1 = (X @ M).reshape(n, 1, d)
    step1 = a1 @ X1
    del a1, X1
    step1 = step1.reshape(-1, c*d)

    # step2[j, k, :] = sum_i K[i, j] * sol[k, i] * (X @ M)[i, :]
    step2 = K.T @ step1
    del step1

    step2 = step2.reshape(-1, c, d)

    # step3[j, k, :] = (sum_i sol[k, i] * K[i, j]) * (x @ M)[j, :]
    a2 = sol.float()
    step3 = (a2 @ K).T

    # Intermediates are deleted as soon as they are no longer needed.
    del K, a2

    step3 = step3.reshape(m, c, 1)
    x1 = (x @ M).reshape(m, 1, d)
    step3 = step3 @ x1

    # NOTE(review): G looks like the gradient (up to sign) of the kernel
    # predictor at each evaluation point - confirm against the derivation.
    G = (step2 - step3) * -1/L

    # Sum of outer products over evaluation points and outputs, divided
    # by the number of evaluation points. This rebinds the local name M.
    M = torch.einsum('ikj,ikl->jl', G, G)
    M = M / len(G)
    return M


###############
#### QUADRATIC
###############

def quadratic_kernel_M(p1, p2, M):
    '''Quadratic kernel matrix (p1 @ M @ p2.T) squared elementwise.

    Args:
        p1: points of shape (n1, d).
        p2: points of shape (n2, d).
        M: matrix of shape (d, d).

    Returns:
        kernel matrix of shape (n1, n2).
    '''
    return torch.square((p1 @ M) @ p2.T)


def quadratic_kernel_M_wagop(X, sol, M, diag_only=False):
    '''Autograd-based estimate built from the quadratic-M kernel.

    Differentiates trace(sol @ K(X, X @ U) @ sol.T) / c with respect to
    U at U = I and returns grad.T @ grad.

    Args:
        X: data of shape (n, d).
        sol: coefficients of shape (c, n) (= alpha.T).
        M: matrix of shape (d, d).
        diag_only: when True return only the diagonal as a 1-D tensor.

    Returns:
        tensor of shape (d, d), or (d,) when `diag_only` is True.

    Note:
        `backward()` is called internally, so any other leaf tensor in
        the graph that requires grad also gets its `.grad` accumulated.
    '''
    # Take both dtype and device from M: a default float32 identity
    # makes `X @ U` raise for float64 inputs.
    U = torch.eye(M.shape[0], dtype=M.dtype, device=M.device,
                  requires_grad=True)
    # Evaluate the kernel itself here; the previous code called this
    # function again and recursed without end.
    K = quadratic_kernel_M(X, X @ U, M)
    val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
    val.backward()
    wagopsqrt = U.grad
    wagop = wagopsqrt.T @ wagopsqrt
    if diag_only:
        wagop = torch.diag(wagop)
    return wagop


###############
#### DISPERSAL
###############

def dispersal(samples, centers, bandwidth, gamma):
    '''Dispersal kernel matrix exp(-||x - z||^gamma / bandwidth).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth, must be positive.
        gamma: dispersal factor (exponent applied to the distance).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers)
    # Squared distances can come out slightly negative from rounding;
    # a fractional power of a negative number is NaN, so clip first,
    # as the other kernels in this module do.
    kernel_mat.clamp_(min=0)
    kernel_mat.pow_(gamma / 2.)
    kernel_mat.mul_(-1. / bandwidth)
    kernel_mat.exp_()
    return kernel_mat
