import torch
import numpy as np

def get_pairwise_dist(x):
    """Return the pairwise distance matrix of ``x`` and the median positive distance.

    Args:
        x: tensor accepted by ``torch.cdist`` (presumably ``(n_points, n_features)``
            -- confirm against callers).

    Returns:
        Tuple ``(dists, median)``: ``dists`` is ``torch.cdist(x, x)`` and ``median``
        is the median of the strictly positive entries of its upper triangle.
    """
    dists = torch.cdist(x, x)
    # Keep one copy of each pair; the zeroed lower triangle is dropped by the
    # positivity filter below, as are the diagonal and any zero distances.
    upper = dists.triu()
    positive = upper[upper > 0]
    return dists, positive.median()

def get_pairwise_dist_by_dim(x):
    """Compute pairwise distances separately for every column of ``x``.

    Args:
        x: 2-D tensor; each column is treated as an independent set of 1-D
            points (presumably ``(n_points, n_features)`` -- confirm against
            callers).

    Returns:
        Tuple ``(pairwise_dist_by_dim, median_dist_by_dim)``:
        ``pairwise_dist_by_dim`` stacks one ``cdist`` matrix per column along
        dim 0, and ``median_dist_by_dim`` holds, for each column, the median of
        the strictly positive upper-triangle distances of that column (NaN when
        a column has no positive distance).
    """
    # Add a trailing axis so each column slice is an (n, 1) point set for cdist.
    pairwise_dist_by_dim = torch.vmap(lambda col: torch.cdist(col, col), in_dims=1)(x.unsqueeze(-1))
    # TODO - test timing with a chunk_size argument in vmap
    pairwise_dist_by_dim_tri = torch.triu(pairwise_dist_by_dim)
    # Mask each column with its OWN non-positive entries. Reusing the mask of
    # column 0 for every column would drop valid pairs that merely coincide in
    # column 0 and keep zero distances of pairs that coincide in another column.
    pairwise_dist_by_dim_tri_nan = pairwise_dist_by_dim_tri.masked_fill(
        pairwise_dist_by_dim_tri <= 0,
        float("nan"),
    )
    median_dist_by_dim = pairwise_dist_by_dim_tri_nan.flatten(start_dim=1).nanmedian(dim=1).values
    return pairwise_dist_by_dim, median_dist_by_dim

def rbf(pairwise_dist, h):
    """Evaluate the Gaussian (RBF) kernel ``exp(-d**2 / h)`` elementwise.

    Args:
        pairwise_dist: tensor of distances ``d``.
        h: bandwidth dividing the squared distances.

    Returns:
        Tensor of kernel values with the broadcast shape of the inputs.
    """
    squared = pairwise_dist.pow(2)
    return (-squared / h).exp()

def rbf_scale_by_dim(pairwise_dist, h, factors):
    """Evaluate ``exp(-d**2 / h)`` and scale a copy of it by each entry of ``factors``.

    Args:
        pairwise_dist: tensor of distances ``d`` (presumably ``(n, n)`` --
            confirm against callers).
        h: bandwidth dividing the squared distances.
        factors: tensor of scale factors; its first dimension sets how many
            times the kernel is repeated along the new axis at position 2.

    Returns:
        The kernel repeated along a new axis at position 2 and multiplied by
        ``factors`` broadcast over the two leading axes.
    """
    n_factors = factors.shape[0]
    kernel = (-pairwise_dist.pow(2) / h).exp()
    # Repeat the kernel once per factor along the inserted axis.
    stacked = torch.tile(kernel.unsqueeze(2), (n_factors,))
    return stacked * factors[None, None]

def d_rbf(pairwise_dist, x, h):
    """Return ``-2 / h * sum_j k_ij * (x_j - x_i)`` for the kernel ``k = rbf(pairwise_dist, h)``.

    When ``pairwise_dist`` holds the Euclidean distances between the rows of
    ``x``, row ``i`` of the result is the sum over ``j`` of the derivative of
    ``exp(-||x_j - x_i||**2 / h)`` with respect to ``x_j``.

    Args:
        pairwise_dist: distance matrix passed to ``rbf``.
        x: points, one per row (row count must match ``pairwise_dist``).
        h: kernel bandwidth.

    Returns:
        Tensor with the same shape as ``x``.
    """
    kernel = rbf(pairwise_dist, h)
    neighbour_sum = kernel @ x                     # sum_j k_ij * x_j
    row_mass = kernel.sum(dim=1, keepdim=True)     # sum_j k_ij, kept as a column
    return -2 * (neighbour_sum - row_mass * x) / h

def rbf_by_dim(pairwise_dist_by_dim, h_by_dim):
    """Evaluate a Gaussian (RBF) kernel per dimension, each with its own bandwidth.

    Args:
        pairwise_dist_by_dim: tensor of distances whose FIRST axis indexes the
            dimension (presumably ``(n_features, n, n)`` as produced by
            ``get_pairwise_dist_by_dim`` -- confirm against callers).
        h_by_dim: bandwidths, broadcast against the LAST axis of the result.

    Returns:
        ``exp(-d**2 / h)`` with all axes of the input reversed, so the
        per-dimension axis ends up last.
    """
    # Build the reversed axis order with plain Python ints rather than
    # unpacking a torch.arange tensor into 0-d tensors.
    reversed_axes = tuple(range(pairwise_dist_by_dim.dim() - 1, -1, -1))
    pairwise_dist_by_dim_T = pairwise_dist_by_dim.permute(reversed_axes)
    # With the per-dimension axis last, h_by_dim broadcasts directly.
    return torch.exp(-torch.div(pairwise_dist_by_dim_T**2, h_by_dim))

def d_rbf_by_dim(pairwise_dist_by_dim, x, h_by_dim):
    """Return, per dimension ``s``, ``-2 / h_s * sum_j k_s[i, j] * (x[j, s] - x[i, s])``.

    ``k_s = exp(-d_s**2 / h_s)`` is the per-dimension Gaussian kernel, i.e. the
    same squared-distance form used by ``rbf_by_dim`` and ``d_rbf``.

    Args:
        pairwise_dist_by_dim: per-dimension distance matrices stacked along
            dim 0 (presumably ``(n_features, n, n)`` -- confirm against callers).
        x: 2-D tensor of points, one per row; its columns must line up with
            dim 0 of ``pairwise_dist_by_dim``.
        h_by_dim: 1-D tensor with one bandwidth per dimension.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # The distances must be squared: the kernel is exp(-d**2 / h), not exp(-d / h).
    k = torch.exp(-torch.div(pairwise_dist_by_dim**2, h_by_dim.unsqueeze(1).unsqueeze(1)))
    x_by_dim = x.T
    # Batched mat-vec product: for each dimension s, sum_j k_s[i, j] * x[j, s].
    weighted_sum = torch.matmul(k, x_by_dim.unsqueeze(-1)).squeeze(-1)
    # k.sum(dim=1) gives column sums, which equal the row sums when each
    # distance matrix is symmetric.
    return - 2 * torch.div(
        weighted_sum - k.sum(dim=1) * x_by_dim,
        h_by_dim.unsqueeze(1)
    ).T

def imq(pairwise_dist, h):
    """Evaluate the inverse multiquadric kernel ``1 / sqrt(1 + d**2 / (2 * h))`` elementwise.

    Args:
        pairwise_dist: tensor of distances ``d``.
        h: bandwidth; the squared distances are divided by ``2 * h``.

    Returns:
        Tensor of kernel values with the broadcast shape of the inputs.
    """
    scaled_sq = pairwise_dist.pow(2) / (2 * h)
    return 1 / (1 + scaled_sq).sqrt()

def d_imq(pairwise_dist, x, h):
    """Return ``-1 / (2 * h) * sum_j k_ij**3 * (x_j - x_i)`` for ``k = imq(pairwise_dist, h)``.

    When ``pairwise_dist`` holds the Euclidean distances between the rows of
    ``x``, row ``i`` of the result is the sum over ``j`` of the derivative of
    the inverse multiquadric kernel with respect to ``x_j``.

    Args:
        pairwise_dist: distance matrix passed to ``imq``.
        x: points, one per row (row count must match ``pairwise_dist``).
        h: kernel bandwidth.

    Returns:
        Tensor with the same shape as ``x``.
    """
    kernel_cubed = imq(pairwise_dist, h).pow(3)
    neighbour_sum = kernel_cubed @ x                     # sum_j k_ij**3 * x_j
    row_mass = kernel_cubed.sum(dim=1, keepdim=True)     # sum_j k_ij**3, kept as a column
    return -(neighbour_sum - row_mass * x) / (2 * h)
