import torch

class CKA(object):
    """Linear Centered Kernel Alignment (CKA) between two representations.

    ``X`` and ``Y`` are compared through their row-by-row Gram matrices
    (``X @ X.T``), so they must have the same number of rows; the number of
    columns may differ.
    """

    def __init__(self, device: torch.device):
        # Kept for backward compatibility. Intermediate tensors are now created
        # on the inputs' own device, so this value is no longer consulted.
        self.device = device

    def centering(self, K: torch.Tensor) -> torch.Tensor:
        """Return the double-centered matrix ``H @ K @ H`` with ``H = I - 11^T / n``.

        Evaluated as ``K - column_means - row_means + grand_mean``, which is
        algebraically identical but avoids materialising ``H`` and two O(n^3)
        matmuls, and never mixes devices. Integer and half-precision inputs
        are promoted to float32; float32 and float64 keep their precision.

        Args:
            K: square ``(n, n)`` matrix.

        Returns:
            The centered ``(n, n)`` matrix, on ``K``'s device.
        """
        if K.dtype not in (torch.float32, torch.float64):
            K = K.to(dtype=torch.float32)
        col_means = K.mean(dim=0, keepdim=True)  # (1, n): mean of each column
        row_means = K.mean(dim=1, keepdim=True)  # (n, 1): mean of each row
        return K - col_means - row_means + K.mean()

    def linear_HSIC(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Return the unnormalized linear-kernel HSIC ``<H Kx H, H Ky H>_F``.

        The usual ``1 / (n - 1)^2`` factor is omitted because it cancels in
        ``linear_CKA``.
        """
        Lx = X @ X.T
        Ly = Y @ Y.T
        return torch.sum(self.centering(Lx) * self.centering(Ly))

    def linear_CKA(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Return linear CKA: ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))``.

        The small additive constant in the denominator avoids division by zero
        when either representation is constant across rows.
        """
        hsic = self.linear_HSIC(X, Y)
        var1 = torch.sqrt(self.linear_HSIC(X, X))
        var2 = torch.sqrt(self.linear_HSIC(Y, Y))
        return hsic / (var1 * var2 + 1e-12)
    
class Procrustes(object):
    """Orthogonal Procrustes (squared) distance and similarity between representations.

    Inputs are 2-D tensors with the same number of rows. Rows are treated as
    samples and columns as features (assumed from the ``(bs x bs)`` core
    matrix built in ``svd_nuclear_norm`` -- confirm against callers).
    """

    def __init__(self, device):
        # Kept for backward compatibility; computations run on the inputs' device.
        self.device = device

    @staticmethod
    def _center_and_scale(Z, eps=1e-12):
        """Center each feature (column) over samples and scale to unit Frobenius norm."""
        Z = Z - Z.mean(dim=0, keepdim=True)
        # The clamp leaves a constant (zero-variance) input at zero instead of
        # dividing 0 by 0 and producing NaNs.
        return Z / torch.linalg.norm(Z).clamp_min(eps)

    def svd_nuclear_norm(self, X, Y):
        """Return the nuclear norm ``||X^T Y||_*`` computed from thin SVDs of X and Y.

        With ``X = Ux Sx Vx^T`` and ``Y = Uy Sy Vy^T``,
        ``X^T Y = Vx (Sx Ux^T Uy Sy) Vy^T``; ``Vx`` and ``Vy`` have orthonormal
        columns, so the singular values of ``X^T Y`` are those of the small core
        ``Sx Ux^T Uy Sy``. This avoids forming the (features x features)
        product. Leading batch dimensions are supported.
        """
        Ux, Sx, _ = torch.linalg.svd(X, full_matrices=False)
        Uy, Sy, _ = torch.linalg.svd(Y, full_matrices=False)
        cross = Ux.transpose(-2, -1) @ Uy
        # diag(Sx) @ cross @ diag(Sy) via broadcasting
        core = Sx.unsqueeze(-1) * cross * Sy.unsqueeze(-2)
        return torch.linalg.svdvals(core).sum(dim=-1)

    def orthogonal_procrustes_distance(self, X, Y, normalize=True):
        """Return ``min_Q ||X - Y Q||_F^2 = ||X||_F^2 + ||Y||_F^2 - 2 ||X^T Y||_*``.

        Args:
            X, Y: 2-D tensors with the same number of rows.
            normalize: if True, center each feature and scale each input to unit
                Frobenius norm first, so the result lies in ``[0, 2]``.

        Returns:
            Scalar tensor holding the squared Procrustes distance.
        """
        if normalize:
            X = self._center_and_scale(X)
            Y = self._center_and_scale(Y)
        xy_nuc = self.svd_nuclear_norm(X, Y)
        # Use the actual squared norms: "2 - 2 * nuc" is only valid for
        # unit-norm inputs and goes negative otherwise.
        return X.pow(2).sum() + Y.pow(2).sum() - 2 * xy_nuc

    def orthogonal_procrustes_similarity(self, X, Y, normalize = True):
        """Return a similarity score derived from the Procrustes distance.

        normalize=True: ``1 - distance``, in ``[-1, 1]``.
        normalize=False: ``2 * ||X^T Y||_*`` (the historical value of this method).
        """
        if not normalize:
            return 2.0 * self.svd_nuclear_norm(X, Y)
        return 1.0 - self.orthogonal_procrustes_distance(X, Y, normalize=True)
    
class DifferentiableKNN(object):
    """Soft k-nearest-neighbour agreement between two representations of one batch.

    Args:
        device: stored for backward compatibility; computations follow the inputs.
        k: number of neighbours per anchor (self excluded).
        temperature: softmax temperature over the top-k similarities; 0 gives
            uniform weights ``1 / k`` on the hard top-k.
        metric: ``'cosine'`` or ``'euclidean'`` (negative squared distance).
        return_distributions: if True, also return the neighbour matrices P, Q.
    """

    def __init__(self, device, k, temperature=500, metric='cosine', return_distributions=False):
        self.device = device
        self.k = k
        self.temperature = temperature
        self.metric = metric
        self.return_distributions = return_distributions

    def soft_topk(self, logits, k, temperature):
        """Return a ``(b, n)`` matrix with a weight distribution on each row's top-k entries.

        Entries outside the top-k are 0 and each row sums to 1. With
        ``temperature == 0`` the k entries get ``1 / k`` each; otherwise they
        get ``softmax(topk_values / temperature)``.
        """
        topk_vals, topk_idx = torch.topk(logits, k, dim=1)
        if temperature == 0.0:
            weights = torch.full_like(topk_vals, 1.0 / k)
        else:
            # Out-of-place softmax: normalising exp() output in place would
            # invalidate the tensor autograd saved for exp's backward pass.
            weights = torch.softmax(topk_vals / temperature, dim=1)
        return torch.zeros_like(logits).scatter(1, topk_idx, weights)

    @staticmethod
    def _pairwise_sq_dists(x):
        """Return the ``(b, b)`` matrix of squared Euclidean distances between rows of x."""
        xx = (x ** 2).sum(dim=1, keepdim=True)
        return xx + xx.T - 2 * x @ x.T

    def soft_knn_alignment_topk(self, f, g):
        """Return the mean soft-kNN overlap of two views of the same batch.

        For each anchor i, ``P[i]`` and ``Q[i]`` are (soft) top-k neighbour
        distributions over the other rows under ``f`` and ``g``; the score is
        ``mean_i sum_j P[i, j] * Q[i, j]``. With temperature 0 and identical
        neighbourhoods each anchor contributes ``1 / k``.

        Args:
            f, g: 2-D tensors with the same number of rows (samples).

        Returns:
            ``m_soft`` (scalar tensor), or ``(m_soft, P, Q)`` when
            ``return_distributions`` is set.

        Raises:
            ValueError: if the row counts differ or the metric is unknown.
        """
        if f.shape[0] != g.shape[0]:
            raise ValueError(
                f'f and g must have the same number of rows, got {f.shape[0]} and {g.shape[0]}'
            )
        b = f.size(0)

        if self.metric == 'euclidean':
            S_f = -self._pairwise_sq_dists(f)
            S_g = -self._pairwise_sq_dists(g)
        elif self.metric == 'cosine':
            # L2-normalise rows so the Gram matrix really holds cosine similarities.
            f_unit = torch.nn.functional.normalize(f, dim=1)
            g_unit = torch.nn.functional.normalize(g, dim=1)
            S_f = f_unit @ f_unit.T
            S_g = g_unit @ g_unit.T
        else:
            raise ValueError('metric must be \'euclidean\' or \'cosine\'')

        # Exclude self-matches with the dtype's most negative value. A fixed
        # constant such as -1e8 would outrank real neighbours whose negative
        # squared distance is below it.
        diag_mask = torch.eye(b, device=f.device, dtype=torch.bool)
        S_f = S_f.masked_fill(diag_mask, torch.finfo(S_f.dtype).min)
        S_g = S_g.masked_fill(diag_mask, torch.finfo(S_g.dtype).min)

        P = self.soft_topk(S_f, self.k, self.temperature)
        Q = self.soft_topk(S_g, self.k, self.temperature)
        m_soft = (P * Q).sum(dim=1).mean()
        return (m_soft, P, Q) if self.return_distributions else m_soft