import numpy as np
import torch


def _to_float64_2d(feats, name):
    """Convert a feature matrix (torch tensor or array-like) to a 2-D float64 ndarray.

    Args:
        feats: torch.Tensor or anything ``np.asarray`` accepts.
        name: Argument name used in error messages.

    Returns:
        A float64 ndarray of shape (num_samples, num_features).

    Raises:
        ValueError: If the input is not two-dimensional.
    """
    if isinstance(feats, torch.Tensor):
        # .detach() is required because .numpy() raises on tensors that require grad.
        # .double() also covers dtypes NumPy cannot represent (e.g. bfloat16).
        feats = feats.detach().cpu().double().numpy()
    # Promote to float64: the cubic kernel overflows in float16 and accumulates
    # rounding error in lower precisions.
    feats = np.asarray(feats, dtype=np.float64)
    if feats.ndim != 2:
        raise ValueError(
            f"{name} must be 2-D (num_samples, num_features), got shape {feats.shape}"
        )
    return feats


def kernel_distance(feats1, feats2, num_subsets=100, max_subset_size=1000, *, seed=None):
    """Estimate a KID-style squared MMD between two sets of feature vectors.

    Uses the polynomial kernel ``k(u, v) = (u . v / n + 1) ** 3``, where ``n`` is
    the feature dimension. Each of ``num_subsets`` rounds draws ``m`` rows without
    replacement from each feature set. The per-round unbiased MMD^2 estimates are
    then averaged.

    Args:
        feats1: Feature matrix of shape (N1, n), as a torch tensor or array-like.
        feats2: Feature matrix of shape (N2, n), as a torch tensor or array-like.
        num_subsets: Number of random subsets to average over (must be >= 1).
        max_subset_size: Upper bound on the subset size ``m``.
        seed: Optional seed for ``np.random.default_rng``. If None, the global
            ``np.random`` state is used, as in the original implementation.

    Returns:
        The averaged MMD^2 estimate as a Python float. It may be slightly negative
        because the estimator is unbiased.

    Raises:
        ValueError: On non-2-D inputs, mismatched or zero feature dimensions,
            ``num_subsets < 1``, or a subset size below 2.
    """
    feats1 = _to_float64_2d(feats1, "feats1")
    feats2 = _to_float64_2d(feats2, "feats2")
    if feats1.shape[1] != feats2.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {feats1.shape[1]} vs {feats2.shape[1]}"
        )
    n = feats1.shape[1]
    if n == 0:
        raise ValueError("feature dimension must be at least 1")
    if num_subsets < 1:
        raise ValueError(f"num_subsets must be >= 1, got {num_subsets}")

    m = min(feats1.shape[0], feats2.shape[0], max_subset_size)
    if m < 2:
        # The unbiased within-set term divides by (m - 1).
        raise ValueError(f"subset size must be >= 2, got {m}")

    # Both the np.random module and a Generator expose choice(a, size, replace=...).
    rng = np.random if seed is None else np.random.default_rng(seed)

    total = 0.0
    for _ in range(num_subsets):
        x = feats2[rng.choice(feats2.shape[0], m, replace=False)]
        y = feats1[rng.choice(feats1.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        # Dropping the diagonal gives the unbiased (U-statistic) within-set terms.
        total += (a.sum() - np.trace(a)) / (m - 1) - b.sum() * 2 / m
    kid = total / num_subsets / m
    return float(kid)
