import torch
import torch.nn.functional as F
import torchsort


def locality_reg(key_states, sample_cnt: int = 1024, sort_strength=1.0):
    """
    Locality regularizer.

    For every (batch, head), draws ``sample_cnt`` random pairs of key positions (i, j)
    and measures how well the cosine similarity of the two keys agrees, in rank, with
    the positional closeness of the pair.

    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param sample_cnt: number of position pairs sampled per (batch, head)
    :param sort_strength: soft_rank regularization strength
    :return: Spearman's rank correlation coefficients (bsz, num_heads).
    -1 means the keys are sorted in the reverse order of the closeness, 1 means the keys are sorted in the same order.
    """
    bsz, num_heads, k_len, head_dim = key_states.size()
    device = key_states.device
    pair_shape = (bsz, num_heads, sample_cnt)

    # Two independent draws of positions along the k_len axis (sampled with replacement).
    # (bsz, num_heads, sample_cnt)
    i = torch.randint(0, k_len, pair_shape, device=device)
    j = torch.randint(0, k_len, pair_shape, device=device)

    # Pick the key vectors at the sampled positions. The second gather must use `j`:
    # gathering both sides with `i` compares every key with itself and makes the
    # similarity constant, which defeats the regularizer.
    # (bsz, num_heads, sample_cnt, head_dim)
    ki = torch.gather(key_states, 2, i.unsqueeze(-1).expand(-1, -1, -1, head_dim))
    kj = torch.gather(key_states, 2, j.unsqueeze(-1).expand(-1, -1, -1, head_dim))

    # (bsz, num_heads, sample_cnt)
    sim = F.cosine_similarity(ki, kj, dim=-1)
    # Closeness is 1 when i == j and decays with the squared positional distance.
    # NOTE(review): exp(-d**2) underflows to 0 in float32 once |i - j| exceeds roughly
    # ten positions, so all distant pairs share the same closeness value.
    closeness = torch.exp(-(i - j).float() ** 2)

    # spearmanr works row-wise, so fold (bsz, num_heads) into one leading dimension
    # and unfold the per-row coefficients afterwards.
    result = spearmanr(
        sim.reshape(-1, sample_cnt),
        closeness.reshape(-1, sample_cnt),
        regularization_strength=sort_strength,
    ).reshape(bsz, num_heads)

    return result


def spearmanr(pred, target, **kw):
    """
    Differentiable Spearman rank correlation along the last dimension.

    Both inputs are turned into soft ranks with ``torchsort.soft_rank``; each row of
    ranks is then centred, scaled to unit L2 norm, and the two are combined with a
    dot product over the last dimension.

    :param pred: tensor of scores; in this file it is called with shape (rows, n)
    :param target: tensor of reference scores, used with the same shape as ``pred``
    :param kw: keyword arguments forwarded unchanged to ``torchsort.soft_rank``
    :return: correlation coefficients with the last dimension reduced away.
    A row whose centred ranks are all zero has zero norm, so the division yields NaN for that row.
    """

    def _center_and_normalize(ranks):
        # Subtract the per-row mean, then divide by the per-row L2 norm.
        centered = ranks - ranks.mean(dim=-1, keepdim=True)
        return centered / centered.norm(dim=-1, keepdim=True)

    pred_ranks = torchsort.soft_rank(pred, **kw)
    target_ranks = torchsort.soft_rank(target, **kw)

    pred_unit = _center_and_normalize(pred_ranks)
    target_unit = _center_and_normalize(target_ranks)

    # Dot product of two centred unit vectors is their Pearson correlation,
    # which on ranks is Spearman's coefficient.
    return (pred_unit * target_unit).sum(dim=-1)
