'''Implementation of kernel functions.'''

import torch
import numpy as np


def euclidean_distances(samples, centers, squared=True):
    '''Pairwise Euclidean distances between the rows of two matrices.

    Uses the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 <x, z>, so the
    squared result may come out slightly negative from round-off; it is only
    clamped at 0 when ``squared`` is False (right before the square root).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        squared: if True return squared distances, otherwise distances.

    Returns:
        distance matrix of shape (n_sample, n_center).
    '''
    row_sq = torch.sum(samples ** 2, dim=1, keepdim=True)  # (n_sample, 1)
    # Skip a second reduction when both arguments are the same tensor object.
    if centers is samples:
        col_sq = row_sq
    else:
        col_sq = torch.sum(centers ** 2, dim=1, keepdim=True)
    col_sq = col_sq.reshape(1, -1)  # (1, n_center), broadcasts across rows

    dist = samples.mm(centers.t())
    dist.mul_(-2).add_(row_sq).add_(col_sq)
    if squared:
        return dist
    return dist.clamp_(min=0).sqrt_()

def euclidean_distances_M(samples, centers, M, squared=True):
    '''Pairwise M-weighted distances between the rows of two matrices.

    Entry (i, j) is x_i^T M x_i + z_j^T M z_j - 2 x_i^T M z_j, which equals
    (x_i - z_j)^T M (x_i - z_j) when M is symmetric. Round-off can make the
    squared value slightly negative; it is clamped at 0 only when ``squared``
    is False, right before the square root.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        M: of shape (n_feature, n_feature).
        squared: if True return squared distances, otherwise distances.

    Returns:
        distance matrix of shape (n_sample, n_center).
    '''
    def quad_form(rows):
        # Row-wise x^T M x, returned as a column vector.
        return torch.sum((rows @ M) * rows, dim=1, keepdim=True)

    row_q = quad_form(samples)
    col_q = row_q if centers is samples else quad_form(centers)
    col_q = col_q.reshape(1, -1)

    dist = samples.mm(M @ centers.t())
    dist.mul_(-2).add_(row_q).add_(col_q)
    if not squared:
        dist.clamp_(min=0).sqrt_()
    return dist

def clamped_euclidean_distances_M(samples, centers, M, squared=True, clamp=1e-10):
    """M-weighted pairwise distances with values below ``clamp`` set to 0.

    For every row x of ``samples`` and row z of ``centers`` this computes
    x^T M x + z^T M z - 2 x^T M z, which equals (x - z)^T M (x - z) when M is
    symmetric. Entries smaller than ``clamp`` (including small negative values
    produced by floating-point cancellation) are replaced with 0. With
    ``squared=False``, the square root of the remaining entries is returned.

    Unlike ``euclidean_distances_M(..., squared=False)``, the result stays
    differentiable when ``squared=False``. The sqrt branch is selected only
    for entries >= ``clamp``, and the infinite derivative of sqrt at 0 is
    masked out by the pair of ``torch.where`` calls below (assumes
    ``clamp > 0`` so both calls select the same entries).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        M: of shape (n_feature, n_feature).
        squared: if False, return square roots of the clamped distances.
        clamp: threshold below which a distance is set to exactly 0.

    Returns:
        distance matrix of shape (n_sample, n_center).
    """
    # Row-wise quadratic forms x^T M x, shape (n_sample, 1).
    samples_norm = (samples @ M)  * samples
    samples_norm = torch.sum(samples_norm, dim=1, keepdim=True)

    # Reuse the row norms when both arguments are the very same tensor object.
    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = (centers @ M) * centers
        centers_norm = torch.sum(centers_norm, dim=1, keepdim=True)

    # Shape (1, n_center) so it broadcasts across rows.
    centers_norm = torch.reshape(centers_norm, (1, -1))

    # Expand x^T M x + z^T M z - 2 x^T M z in place.
    distances = samples.mm(M @ torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)
    device = distances.device
    # Snap distances below `clamp` (including round-off negatives) to exactly 0.
    # NOTE(review): torch.zeros(1) is float32; for float64 distances this relies
    # on torch.where's dtype promotion — confirm for the torch version in use.
    distances = torch.where(distances < clamp, torch.zeros(1,device=device), distances)
    if not squared:
        # The masked entries are exactly 0 here, so sqrt's backward at them is
        # 0/0 = NaN. This relies on torch.where giving zero gradient to the
        # unselected branch: the where() above routes zero gradient to those
        # same entries, so the NaN never reaches samples/centers/M.
        distances = torch.where(distances < clamp, torch.zeros(1,device=device), torch.sqrt(distances))
    return distances


##############
#### GAUSSIAN
##############

def gaussian(samples, centers, bandwidth):
    '''Gaussian kernel exp(-||x - z||^2 / (2 * bandwidth**2)).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    neg_gamma = -(1. / (2 * bandwidth ** 2))
    sq_dist = euclidean_distances(samples, centers)
    # Clamp away round-off negatives from the norm expansion, then exponentiate.
    return sq_dist.clamp_(min=0).mul_(neg_gamma).exp_()

def gaussian_M(samples, centers, bandwidth, M):
    '''Gaussian kernel under the M-weighted squared distance.

    Entry (i, j) is exp(-d_M(x_i, z_j) / (2 * bandwidth**2)), where d_M is the
    squared distance from ``euclidean_distances_M``, clamped at 0 against
    round-off.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).
        M: of shape (n_feature, n_feature).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    neg_gamma = -(1. / (2 * bandwidth ** 2))
    sq_dist = euclidean_distances_M(samples, centers, M)
    return sq_dist.clamp_(min=0).mul_(neg_gamma).exp_()

def gaussian_inner_M(samples, centers, bandwidth, M):
    '''Exponential inner-product kernel exp(x^T M z / (2 * bandwidth**2)).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).
        M: of shape (n_feature, n_feature).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    gamma = 1. / (2 * bandwidth ** 2)
    inner = samples @ M @ centers.T
    return inner.mul_(gamma).exp_()

def gaussian_M_wagop(X, L, sol, M, diag_only=False):
    '''Gradient-based feature matrix for the Gaussian-M kernel predictor.

    Let G = d/dU [ trace(sol @ K(X, X @ U) @ sol.T) / c ] evaluated at U = I,
    where K = gaussian_M(., ., L, M) and c = sol.shape[0]. Returns G^T G.

    Args:
        X: of shape (n, d).
        L: kernel bandwidth.
        sol: of shape (c, n) (= alpha.T).
        M: of shape (d, d).
        diag_only: if True, return only the diagonal of the result.

    Returns:
        (d, d) matrix G^T G, or its diagonal of shape (d,) when diag_only.
    '''
    with torch.enable_grad():  # works even if the caller is under no_grad
        # U multiplies X, so it must share X's dtype/device (torch.eye
        # defaults to float32, which breaks X @ U for float64 inputs).
        U = torch.eye(M.shape[0], dtype=X.dtype, device=X.device,
                      requires_grad=True)
        K = gaussian_M(X, X @ U, L, M)
        val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
        # autograd.grad differentiates w.r.t. U only; val.backward() would
        # also accumulate into .grad of any other leaf (e.g. a trainable M).
        (wagopsqrt,) = torch.autograd.grad(val, U)
    wagop = wagopsqrt.T @ wagopsqrt
    if diag_only:
        wagop = torch.diag(wagop)
    return wagop

def gaussian_M_agop(X, L, sol, M, diag_only=False):
    '''Average gradient outer product (AGOP) of the Gaussian-M kernel predictor.

    The predictor is f_c(z) = sum_j sol[c, j] * K(z, X_j) with
    K = gaussian_M(., ., L, M). At each training point x_i (M assumed
    symmetric),
        grad f_c(x_i) = (1 / L**2) * sum_j sol[c, j] K(x_i, X_j) M (X_j - x_i),
    and the AGOP is (1 / n) * sum_i sum_c grad f_c(x_i) grad f_c(x_i)^T,
    i.e. averaged over the n data points (consistent with laplacian_M_agop).

    Args:
        X: of shape (n, d).
        L: kernel bandwidth.
        sol: of shape (c, n) (= alpha.T).
        M: of shape (d, d).
        diag_only: if True, return only the diagonal.

    Returns:
        (d, d) AGOP matrix, or its diagonal of shape (d,) when diag_only.
    '''
    n = X.shape[0]
    K = gaussian_M(X, X, L, M)  # (n, n)
    B = X @ M                   # (n, d)
    # Self term: G1[c, i, :] = (sum_j K[i, j] sol[c, j]) * B[i, :]
    G1 = (sol @ K.T)[:, :, None] * B[None, :, :]
    # Neighbour term: G2[c, i, :] = sum_j K[i, j] sol[c, j] B[j, :].
    # Forming sol[c, j] * B[j, :] first caps the largest intermediate at
    # (c, n, d) rather than a possible (n, n, d) from a 3-operand einsum.
    G2 = torch.einsum('ij,cjl->cil', K, sol[:, :, None] * B[None, :, :])
    G = (G2 - G1) / L**2        # (c, n, d)
    # Sum outer products over outputs and points; average over the n points
    # (G.shape[0] is the number of outputs here, not the number of samples).
    agop = torch.einsum('cij,cil->jl', G, G) / n
    if diag_only:
        agop = torch.diag(agop)
    return agop

###############
#### LAPLACIAN
###############

def laplacian(samples, centers, bandwidth):
    '''Laplacian kernel exp(-||x - z|| / bandwidth).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    dist = euclidean_distances(samples, centers, squared=False)
    neg_gamma = -(1. / bandwidth)
    return dist.clamp_(min=0).mul_(neg_gamma).exp_()

def laplacian_M(samples, centers, bandwidth, M):
    '''Laplacian kernel under the M-weighted distance.

    Entry (i, j) is exp(-sqrt(d_M(x_i, z_j)) / bandwidth), with d_M the squared
    distance from ``euclidean_distances_M``.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).
        M: of shape (n_feature, n_feature).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    dist = euclidean_distances_M(samples, centers, M, squared=False)
    neg_gamma = -(1. / bandwidth)
    return dist.clamp_(min=0).mul_(neg_gamma).exp_()

def clamped_laplacian_M(samples, centers, bandwidth, M, clamp=1e-10):
    '''Laplacian-M kernel that stays differentiable near zero distance.

    Same as ``laplacian_M`` but built on ``clamped_euclidean_distances_M``:
    distances below ``clamp`` are treated as exactly 0, so their gradient
    through the distance is 0 instead of NaN (derivatives around K(x, x)
    are set to 0).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).
        M: of shape (n_feature, n_feature).
        clamp: distance threshold forwarded to clamped_euclidean_distances_M.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    dist = clamped_euclidean_distances_M(samples, centers, M,
                                         squared=False, clamp=clamp)
    neg_gamma = -(1. / bandwidth)
    return dist.clamp_(min=0).mul_(neg_gamma).exp_()

def clamped_laplacian_M_wagop(X, L, sol, M, diag_only=False, clamp=1e-10):
    '''Gradient-based feature matrix for the clamped Laplacian-M predictor.

    Let G = d/dU [ trace(sol @ K(X, X @ U) @ sol.T) / c ] evaluated at U = I,
    where K = clamped_laplacian_M(., ., L, M, clamp=clamp) and
    c = sol.shape[0]. Returns G^T G.

    Args:
        X: of shape (n, d).
        L: kernel bandwidth.
        sol: of shape (c, n) (= alpha.T).
        M: of shape (d, d).
        diag_only: if True, return only the diagonal of the result.
        clamp: distance threshold forwarded to clamped_laplacian_M.

    Returns:
        (d, d) matrix G^T G, or its diagonal of shape (d,) when diag_only.
    '''
    with torch.enable_grad():  # works even if the caller is under no_grad
        # U multiplies X, so it must share X's dtype/device (torch.eye
        # defaults to float32, which breaks X @ U for float64 inputs).
        U = torch.eye(M.shape[0], dtype=X.dtype, device=X.device,
                      requires_grad=True)
        K = clamped_laplacian_M(X, X @ U, L, M, clamp=clamp)
        val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
        # autograd.grad differentiates w.r.t. U only; val.backward() would
        # also accumulate into .grad of any other leaf (e.g. a trainable M).
        (wagopsqrt,) = torch.autograd.grad(val, U)
    wagop = wagopsqrt.T @ wagopsqrt
    if diag_only:
        wagop = torch.diag(wagop)
    return wagop



def _laplacian_M_grads(X, x, L, sol, M):
    '''Gradients of the Laplacian-M kernel predictor at the points ``x``.

    The predictor is f_k(z) = sum_j sol[k, j] * K(X_j, z) with
    K = laplacian_M(., ., L, M) = exp(-||X_j - z||_M / L). For z != X_j and
    M symmetric,
        grad_z K(X_j, z) = K(X_j, z) * M (X_j - z) / (L * ||X_j - z||_M),
    so with W[j, i] = K(X_j, x_i) / ||X_j - x_i||_M,
        grad f_k(x_i) = (1 / L) * (sum_j W[j, i] sol[k, j] M X_j
                                   - (sum_j W[j, i] sol[k, j]) M x_i).
    Pairs closer than 1e-10 (e.g. x_i == X_j, where the kernel has no
    gradient) are dropped by giving them weight 0.

    Args:
        X: kernel centers, of shape (n, d).
        x: evaluation points, of shape (m, d).
        L: kernel bandwidth (must be positive).
        sol: coefficients, of shape (c, n) (= alpha.T).
        M: of shape (d, d).

    Returns:
        tensor G of shape (m, c, d) with G[i, k, :] = df_k(x_i)/dx.
    '''
    assert L > 0
    n, d = X.shape
    m = x.shape[0]
    # Cast sol to X's dtype so both float32 and float64 inputs work.
    sol = sol.to(X.dtype)
    c = sol.shape[0]

    # Compute distances once and reuse them for the kernel:
    # exp(dist * (-1/L)) is exactly what laplacian_M(X, x, L, M) evaluates.
    dist = euclidean_distances_M(X, x, M, squared=False)  # (n, m)
    K = torch.exp(dist * (-1. / L))
    zero = torch.zeros((), dtype=K.dtype, device=K.device)
    W = torch.where(dist < 1e-10, zero, K / dist)  # (n, m)
    del dist, K

    # Neighbour term, built as one (m, n) @ (n, c*d) matmul so that no
    # (n, m, c) intermediate is formed:
    # attract[i, k, :] = sum_j W[j, i] * sol[k, j] * (M X_j)
    weighted_X = (sol.T[:, :, None] * (X @ M)[:, None, :]).reshape(n, c * d)
    attract = (W.T @ weighted_X).reshape(m, c, d)
    del weighted_X

    # Self term: self_term[i, k, :] = (sum_j W[j, i] * sol[k, j]) * (M x_i)
    weight_sums = (sol @ W).T  # (m, c)
    self_term = weight_sums[:, :, None] * (x @ M)[:, None, :]

    return (attract - self_term) / L


def laplacian_M_agop(X, L, sol, M, num_samples=20000):
    '''Average gradient outer product (AGOP) of the Laplacian-M kernel predictor.

    Gradients are evaluated at the rows of X. When X has more than
    ``num_samples`` rows, they are evaluated instead at ``num_samples`` rows
    drawn uniformly with replacement using numpy's global RNG.

    Args:
        X: of shape (n, d).
        L: kernel bandwidth.
        sol: of shape (c, n) (= alpha.T).
        M: of shape (d, d).
        num_samples: maximum number of evaluation points.

    Returns:
        (d, d) matrix (1 / m) * sum_i J_i^T J_i over the m evaluation points,
        where J_i is the (c, d) Jacobian of the predictor at point i.
    '''
    if len(X) > num_samples:
        indices = np.random.randint(len(X), size=num_samples)
        x = X[indices, :]
    else:
        x = X
    G = _laplacian_M_grads(X, x, L, sol, M)  # (m, c, d)
    agop = torch.einsum('mcj,mcl->jl', G, G)
    return agop / len(G)


def laplacian_M_get_grads(X, L, sol, M):
    '''Gradients of the Laplacian-M kernel predictor at every training point.

    Returns:
        tensor G of shape (n, c, d), where G[i, k, :] is df_k(x_i)/dx in R^d.
    '''
    return _laplacian_M_grads(X, X, L, sol, M)


###############
#### QUADRATIC
###############


def quadratic_kernel_L_M(p1, p2, L, M):
    '''Squared M-weighted inner-product kernel ((p1 M p2^T) / L) ** 2.

    L is not a bandwidth in the usual sense: it only rescales every kernel
    entry by the same factor 1 / L**2.

    Args:
        p1: of shape (n1, d).
        p2: of shape (n2, d).
        L: global scale divisor.
        M: of shape (d, d).

    Returns:
        kernel matrix of shape (n1, n2).
    '''
    scaled = (p1 @ M) @ p2.T / L
    return scaled.square()

# def quadratic_kernel_M(p1, p2, M):
#     # p1, p2 are n x d
#     return quadratic_kernel_L_M(p1, p2, 1.0, M)

# def quadratic_kernel_M_wagop(X, sol, M, diag_only=False):    # X is n x d
#     return quadratic_kernel_L_M_wagop(X, 1.0, sol, M, diag_only)

# def quadratic_kernel_L_M_lgop(X, L, sol, M, diag_only=False):   
#     # sol is c x n (= alpha.T)
#     # M is d x d
#     device = M.device
#     d = M.shape[0]
#     U = torch.eye(d).to(device)
#     U.requires_grad = True
#     K = quadratic_kernel_L_M(X, X @ U, L, M)
#     val = torch.trace(sol @ K @ sol.T) / sol.shape[0]
#     val.backward()
#     lgop = U.grad
#     U = U.detach()
#     del U
#     # print(lgop)
#     if diag_only:
#         lgop = torch.diag(lgop)
#     return lgop

# def quadratic_kernel_L_M_wagop(X, L, sol, M, diag_only=False):   
#     lgop = quadratic_kernel_L_M_lgop(X, L, sol, M, diag_only)
#     wagop = lgop @ M
#     if diag_only:
#         wagop = torch.diag(wagop)
#     return wagop


###############
#### DISPERSAL
###############

def dispersal(samples, centers, bandwidth, gamma):
    '''Dispersal kernel exp(-||x - z|| ** gamma / bandwidth).

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth (must be positive).
        gamma: dispersal factor (exponent applied to the Euclidean distance).

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers)
    # Squared distances from the expansion ||x||^2 + ||z||^2 - 2<x, z> can be
    # slightly negative from round-off (e.g. on the diagonal when samples is
    # centers); a fractional power of a negative number is NaN, so clamp at 0
    # first, as gaussian() and laplacian() do.
    kernel_mat.clamp_(min=0)
    kernel_mat.pow_(gamma / 2.)  # squared distance ** (gamma/2) = distance ** gamma
    kernel_mat.mul_(-1. / bandwidth)
    kernel_mat.exp_()
    return kernel_mat
