"""
Similarity Measures for Knowledge Distillation

This module implements various similarity measures for comparing representations
between teacher and student models during knowledge distillation.

Available measures:
- Linear: Procrustes alignment with SVD-based projection as well as nuclear norm based alignment
- CKA: Centered Kernel Alignment with linear or Gaussian kernel
- Euclidean: Mean Squared Error with zero-padding dimension matching
- Energy: Energy distance metric
"""

import torch
from torch import Tensor
from torch.nn.functional import pad, mse_loss
import random
from typing import Literal, Tuple, Optional, List


def gaussian_kernel(X: torch.Tensor, Y: torch.Tensor, sigma: Optional[float] = None) -> torch.Tensor:
    """
    Compute Gaussian kernel matrix between two sets of vectors.

    Entries are ``K[i, j] = exp(-||X[i] - Y[j]||^2 / sqrt(sigma))``.

    Args:
        X: First set of vectors (n_samples, n_features)
        Y: Second set of vectors (m_samples, n_features)
        sigma: Kernel bandwidth parameter, as a Python number or a tensor.
            If None (or 0), half of the median squared pairwise distance is used.

    Returns:
        Kernel matrix of shape (n_samples, m_samples)
    """
    X_norm_sq = (X ** 2).sum(dim=1).view(-1, 1)  # (n, 1)
    Y_norm_sq = (Y ** 2).sum(dim=1).view(1, -1)  # (1, m)
    # ||x||^2 + ||y||^2 - 2<x, y> can dip slightly below zero through floating
    # point cancellation; clamp so kernel values never exceed 1.
    dist_sq = (X_norm_sq + Y_norm_sq - 2 * X @ Y.T).clamp_min(0)

    if not sigma:
        sigma = torch.median(dist_sq) / 2

    # torch.sqrt only accepts tensors, so a user-supplied float must be converted.
    sigma = torch.as_tensor(sigma, dtype=dist_sq.dtype, device=dist_sq.device)
    return torch.exp(-dist_sq / torch.sqrt(sigma))


class MSE_w_padding(torch.nn.Module):
    """
    Mean Squared Error with padding support for dimension matching.

    When the feature (last) dimension of the two representations differs,
    X is right-padded with zeros up to Y's width before the MSE is taken.
    """

    def __init__(self, dim_matching='zero_pad', reduction="mean"):
        """
        Initialize MSE with padding.

        Args:
            dim_matching: Strategy for handling dimension mismatches
                ('zero_pad'; 'pca' is accepted but not implemented;
                'none' or None disables matching)
            reduction: Reduction method ('mean', 'sum', 'none' or None)
        """
        super(MSE_w_padding, self).__init__()
        self.dim_matching = dim_matching
        self.reduction = reduction

    def forward(self, X, Y):
        """
        Compute MSE between X and Y with dimension matching.

        Args:
            X: First tensor (batch_size, seq_len, features_x)
            Y: Second tensor (batch_size, seq_len, features_y)

        Returns:
            MSE loss value (reduced according to ``self.reduction``)

        Raises:
            ValueError: If the inputs are not 3D, differ in any dimension but
                the last, or cannot be matched with the chosen strategy.
            NotImplementedError: For ``dim_matching='pca'``.
        """
        if X.shape[:-1] != Y.shape[:-1] or X.ndim != 3 or Y.ndim != 3:
            raise ValueError(
                'Expected 3D input matrices to match in all dimensions but last. '
                f'But got {X.shape} and {Y.shape} instead.'
            )

        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(
                    f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                    f'Set dim_matching or change matrix dimensions.'
                )
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(
                        f'With `zero_pad` dimension matching expected X dimension to be smaller than Y. '
                        f'But got {X.shape} and {Y.shape} instead.'
                    )
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.dim_matching}')

        # mse_loss rejects reduction=None; treat it as 'none' like the other measures do.
        reduction = 'none' if self.reduction is None else self.reduction
        return mse_loss(X, Y, reduction=reduction)
        

class LinearMeasure(torch.nn.Module):
    """
    Linear similarity measure using SVD-based alignment.

    This measure aligns the two representations with orthogonal projections
    obtained from the SVD of their cross-covariance and reports the Frobenius
    norm of the residual for each batch element. In ``approx`` mode it instead
    returns a single nuclear-norm based value for the whole batch.
    """

    def __init__(self,
                 alpha=1,
                 center_columns=True,
                 dim_matching='zero_pad',
                 svd_grad=False,
                 reduction='mean',
                 no_svd=True,
                 approx=False):
        """
        Initialize linear similarity measure.

        Args:
            alpha: Weight parameter, registered as the ``alpha`` buffer
                (not read by ``forward``)
            center_columns: Whether to subtract the mean before aligning
            dim_matching: Strategy for dimension matching
                ('zero_pad'; 'pca' is accepted but not implemented; 'none' or None)
            svd_grad: Whether to compute gradients through SVD
            reduction: Reduction method ('mean', 'sum', 'none' or None)
            no_svd: Whether to skip SVD when seq_len == 1 and use the
                closed-form norm difference instead
            approx: Whether to use the nuclear-norm approximation
        """
        super(LinearMeasure, self).__init__()
        self.register_buffer('alpha', torch.tensor(alpha))
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.center_columns = center_columns
        self.svd_grad = svd_grad
        self.reduction = reduction
        self.no_svd = no_svd
        self.approx = approx

    @staticmethod
    def _svd(A: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Thin (batched) SVD of ``A``, requesting cuSOLVER's ``gesvd`` only on CUDA."""
        # torch.linalg.svd rejects the `driver=` keyword for non-CUDA tensors,
        # so passing it unconditionally made this measure unusable on CPU.
        driver = "gesvd" if A.is_cuda else None
        return torch.linalg.svd(A, full_matrices=False, driver=driver)

    def partial_fit(self, X: Tensor, center_col_idx=1) -> Tuple[Tensor, Tensor]:
        """
        Compute the mean and the mean-centered input.

        Args:
            X: Input tensor (3D)
            center_col_idx: Dimension to average over when centering

        Returns:
            Tuple of (mean, centered_X). When ``center_columns`` is False the
            mean is a zero vector of length ``X.shape[2]`` and centered_X == X.
        """
        if self.center_columns:
            mx = torch.mean(X, dim=center_col_idx, keepdim=True)
        else:
            mx = torch.zeros(X.shape[2], dtype=X.dtype, device=X.device)
        wx = X - mx

        return mx, wx

    def create_sim_matrix(self, X: Tensor):
        """
        Create the linear Gram matrix of all rows after flattening.

        Args:
            X: Input tensor; all leading dimensions are flattened into rows

        Returns:
            Matrix ``X_flat @ X_flat.T``
        """
        X = torch.flatten(X, end_dim=-2)
        sim_m = X @ X.T
        return sim_m

    def fit(self, X: Tensor, Y: Tensor) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
        """
        Fit the aligning projections using SVD of the cross-covariance.

        Args:
            X: First tensor (batch, seq, features)
            Y: Second tensor (batch, seq, features), same feature size as X

        Returns:
            ((mean_x, U), (mean_y, V)) where ``U S V^T`` is the SVD of
            ``centered_X^T @ centered_Y`` for each batch element.
        """
        mx, wx = self.partial_fit(X)
        my, wy = self.partial_fit(Y)

        if self.svd_grad:
            wxy = torch.bmm(wx.transpose(1, 2), wy)
            U, _, Vt = self._svd(wxy)
        else:
            # Projections are treated as constants; gradients still reach X and Y
            # through `project`.
            with torch.no_grad():
                wxy = torch.bmm(wx.transpose(1, 2), wy)
                U, _, Vt = self._svd(wxy)

        return (mx, U), (my, Vt.transpose(1, 2))

    def project(self, X: Tensor, m: Tensor, w: Tensor):
        """
        Project tensor using learned transformation.

        Args:
            X: Input tensor (batch, seq, features)
            m: Mean to subtract (used only when ``center_columns`` is True)
            w: Batched projection matrix

        Returns:
            Projected tensor
        """
        if self.center_columns:
            return torch.bmm((X - m), w)
        else:
            return torch.bmm(X, w)

    def forward(self, X: Tensor, Y: Tensor):
        """
        Compute linear similarity between X and Y.

        Args:
            X: First tensor (batch_size, seq_len, features_x)
            Y: Second tensor (batch_size, seq_len, features_y)

        Returns:
            Residual norm after alignment, reduced per ``self.reduction``.
            In ``approx`` mode a single value for the whole batch is returned
            and neither ``reduction`` nor ``dim_matching`` is applied.

        Raises:
            ValueError: On mismatched/non-3D input, unmatched dimensions or an
                unknown reduction.
            NotImplementedError: For ``dim_matching='pca'``.
        """
        if X.shape[:-1] != Y.shape[:-1] or X.ndim != 3 or Y.ndim != 3:
            raise ValueError(
                'Expected 3D input matrices to match in all dimensions but last. '
                f'But got {X.shape} and {Y.shape} instead.'
            )

        if self.approx:
            # Nuclear-norm approximation: trace terms for X and Y minus twice the
            # nuclear norm of the cross term; centering is along the batch dim.
            mx, wx = self.partial_fit(X, center_col_idx=0)
            my, wy = self.partial_fit(Y, center_col_idx=0)

            wx = torch.flatten(wx, end_dim=-2)
            wy = torch.flatten(wy, end_dim=-2)

            X_flat = torch.flatten(X, end_dim=-2)
            Y_flat = torch.flatten(Y, end_dim=-2)

            K_x = X_flat.T @ wx
            K_y = Y_flat.T @ wy

            x_fro = torch.trace(K_x)
            y_fro = torch.trace(K_y)

            sq_trace = torch.norm(X_flat.T @ wy, p="nuc")

            return x_fro + y_fro - 2 * sq_trace

        # Handle dimension matching
        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(
                    f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                    f'Set dim_matching or change matrix dimensions.'
                )
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(
                        f'With `zero_pad` dimension matching expected X dimension to be smaller than Y. '
                        f'But got {X.shape} and {Y.shape} instead.'
                    )
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.dim_matching}')

        if self.no_svd and X.shape[1] == 1:
            # With a single row the optimal orthogonal alignment leaves a residual
            # of exactly | ||x|| - ||y|| |, so SVD can be skipped. This must be
            # the unsquared value to agree with the SVD branch below.
            mx, wx = self.partial_fit(X)
            my, wy = self.partial_fit(Y)

            x_norm = torch.linalg.norm(wx, dim=(1, 2))
            y_norm = torch.linalg.norm(wy, dim=(1, 2))

            norms = torch.abs(x_norm - y_norm)
        else:
            X_params, Y_params = self.fit(X, Y)
            norms = torch.linalg.norm(
                self.project(X, *X_params) - self.project(Y, *Y_params),
                ord="fro",
                dim=(1, 2)
            )

        # Apply reduction
        if self.reduction == 'mean':
            return norms.mean()
        elif self.reduction == 'sum':
            return norms.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return norms
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')


class CKA(torch.nn.Module):
    """
    Centered Kernel Alignment (CKA) similarity measure.

    CKA measures the similarity between two representations by computing the
    alignment of their kernel (Gram) matrices:
    ``HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L))``. The kernels are built over
    samples, so X and Y may have different feature dimensions.
    """

    def __init__(self, dim_matching='zero_pad', reduction='mean', kernel="linear",
                 similarity_token_strategy="flatten", biased=False):
        """
        Initialize CKA similarity measure.

        Args:
            dim_matching: Strategy for dimension matching (validated and stored;
                not used by the computation)
            reduction: Reduction method ('mean', 'sum', 'none' or None)
            kernel: Kernel type ('linear', 'gaussian')
            similarity_token_strategy: Strategy for token handling ('flatten', 'random')
            biased: Whether to use the biased HSIC estimator
        """
        super(CKA, self).__init__()
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.reduction = reduction
        self.kernel = kernel
        self.similarity_token_strategy = similarity_token_strategy
        self.random_tokens = None
        self.biased = biased

    def generate_random_token_index(self, token_size, selected_size=10):
        """
        Draw and cache random token indices for the 'random' strategy.

        Args:
            token_size: Total number of tokens
            selected_size: Number of tokens to sample; capped at ``token_size``
                so short sequences do not make ``random.sample`` raise
        """
        self.random_tokens = random.sample(range(token_size), min(selected_size, token_size))

    def create_sim_matrix(self, X: Tensor):
        """
        Create the sample-by-sample kernel matrix of X.

        Args:
            X: Input tensor (batch, seq, features); leading dims are flattened
                into samples ('random' first keeps a cached subset of tokens)

        Returns:
            Kernel matrix of shape (n_samples, n_samples)

        Raises:
            ValueError: For an unknown token strategy or kernel.
        """
        if self.similarity_token_strategy == "flatten":
            X = torch.flatten(X, end_dim=-2)
        elif self.similarity_token_strategy == "random":
            if not self.random_tokens:
                self.generate_random_token_index(X.shape[1])
            X = torch.flatten(X[:, self.random_tokens, :], end_dim=-2)
        else:
            raise ValueError(f'Unrecognized similarity token strategy {self.similarity_token_strategy}')

        if self.kernel == "linear":
            return X @ X.T
        if self.kernel == "gaussian":
            return gaussian_kernel(X, X)
        raise ValueError(f'Unrecognized kernel {self.kernel}')

    def make_diag_zero(self, sim_m):
        """
        Set diagonal elements of similarity matrix to zero.

        Args:
            sim_m: Similarity matrix (n, n) or batch of them (b, n, n)

        Returns:
            Similarity matrix with zero diagonal
        """
        diagonal_mask = torch.eye(sim_m.shape[-1], dtype=torch.bool, device=sim_m.device)
        # masked_fill broadcasts the (n, n) mask over any leading batch dimension.
        return sim_m.masked_fill(diagonal_mask, 0)

    def HSIC(self, K, L):
        """
        Compute Hilbert-Schmidt Independence Criterion.

        Args:
            K: First similarity matrix (n, n)
            L: Second similarity matrix (n, n)

        Returns:
            HSIC value (the unbiased estimate is returned as its absolute value)

        Raises:
            ValueError: If the unbiased estimator is used with fewer than 4 samples.
            NotImplementedError: For batched (3D) input to the unbiased estimator.
        """
        n = K.shape[-1]

        if self.biased:
            # Centering matrix H = I - (1/n) 11^T, built in K's dtype/device so
            # the matmuls below never mix float32 with float64 inputs.
            H = torch.eye(n, dtype=K.dtype, device=K.device) - 1.0 / n
            kl = K @ H @ L @ H
            return (1 / n ** 2) * kl.diagonal().sum()

        # Unbiased HSIC: the denominators below vanish for n < 4.
        if n < 4:
            raise ValueError(f'Unbiased HSIC needs at least 4 samples, got {n}.')

        K = self.make_diag_zero(K)
        L = self.make_diag_zero(L)

        if K.ndim == 3:
            raise NotImplementedError('Batched (3D) unbiased HSIC is not implemented.')

        kl = K @ L
        trace = kl.diagonal().sum()
        middle = (K.sum() * L.sum()) / ((n - 1) * (n - 2))
        last = -2 * kl.sum() / (n - 2)
        return torch.abs((trace + middle + last) / (n * (n - 3)))

    def forward(self, X: Tensor, Y: Tensor):
        """
        Compute CKA similarity between X and Y.

        Args:
            X: First tensor (batch_size, seq_len, features_x)
            Y: Second tensor (batch_size, seq_len, features_y)

        Returns:
            CKA similarity value, reduced per ``self.reduction``

        Raises:
            ValueError: For an unknown reduction (or see ``create_sim_matrix``/``HSIC``).
        """
        X_sim_matrix = self.create_sim_matrix(X)
        Y_sim_matrix = self.create_sim_matrix(Y)

        self_hsic_x = self.HSIC(X_sim_matrix, X_sim_matrix)
        self_hsic_y = self.HSIC(Y_sim_matrix, Y_sim_matrix)
        cross_hsic = self.HSIC(X_sim_matrix, Y_sim_matrix)

        batched_cka = cross_hsic / (torch.sqrt(self_hsic_x) * torch.sqrt(self_hsic_y))

        if self.reduction == 'mean':
            return batched_cka.mean()
        elif self.reduction == 'sum':
            return batched_cka.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return batched_cka
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')


class EnergyMetric(torch.nn.Module):
    """
    Energy distance metric for comparing representations.

    The within-model terms are mean pairwise distances between repeats of the
    same class. The cross term is measured after an orthogonal alignment of Y
    onto X, found by iteratively reweighted least squares (see ``fit``). The
    result is ``sqrt(relu(E_xy - (E_xx + E_yy) / 2))``.
    """

    def __init__(self, n_iter=100, tol=1e-6, dim_matching='zero_pad', reduction='mean'):
        """
        Initialize energy metric.

        Args:
            n_iter: Number of reweighting iterations
            tol: Lower bound on sqrt(residual) used when computing weights
            dim_matching: Strategy for dimension matching
                ('zero_pad'; 'pca' is accepted but not implemented; 'none' or None)
            reduction: Reduction method ('mean', 'sum', 'none' or None)
        """
        super(EnergyMetric, self).__init__()
        self.n_iter = n_iter
        self.tol = torch.tensor(tol)
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.reduction = reduction

    @staticmethod
    def _pair_up(X: torch.Tensor, Y: torch.Tensor):
        """Expand (b, c, n_x, d) and (b, c, n_y, d) into row-aligned (b, c*n_x*n_y, d) pairs."""
        n_x = X.shape[2]
        n_y = Y.shape[2]
        # repeat_interleave gives x0,x0,..,x1,x1,..; tile gives y0,y1,..,y0,y1,..
        X_pairs = X.repeat_interleave(n_y, dim=2).flatten(start_dim=1, end_dim=2)
        Y_pairs = Y.tile(dims=(1, 1, n_x, 1)).flatten(start_dim=1, end_dim=2)
        if X_pairs.shape[1] != Y_pairs.shape[1]:
            raise ValueError(f"After permutation got {X_pairs.shape} and {Y_pairs.shape}")
        return X_pairs, Y_pairs

    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y: torch.Tensor):
        """
        Fit the orthogonal alignment of Y onto X (no gradients are tracked).

        Args:
            X: First tensor (batch, classes, repeats_x, features)
            Y: Second tensor (batch, classes, repeats_y, features)

        Returns:
            Tuple of (weights, transformation, losses, X_prod, Y_prod) where
            X_prod/Y_prod are the detached pairwise expansions of X and Y and
            losses holds the mean residual before and after each iteration.
        """
        X, Y = self._pair_up(X, Y)

        tol = self.tol.to(device=X.device, dtype=X.dtype)
        # Allocate on X's device/dtype so the metric also works on GPU.
        w = torch.ones(X.shape[:2], dtype=X.dtype, device=X.device)
        # Identity alignment is returned if n_iter == 0.
        T = torch.eye(X.shape[-1], dtype=X.dtype, device=X.device).repeat(X.shape[0], 1, 1)

        batch_loss = [torch.mean(torch.linalg.norm(X - Y, dim=-1), dim=-1)]

        for _ in range(self.n_iter):
            T = self.get_orth_matrix(w[:, :, None] * X, w[:, :, None] * Y)
            iter_result = torch.linalg.norm(X - torch.bmm(Y, T), dim=-1)
            batch_loss.append(torch.mean(iter_result, dim=-1))
            # w**2 == 1 / residual, so the weighted least-squares step minimizes
            # the sum of (unsquared) residual norms.
            w = 1 / torch.maximum(torch.sqrt(iter_result), tol)

        return w, T, batch_loss, X, Y

    def get_orth_matrix(self, X: torch.Tensor, Y: torch.Tensor):
        """
        Orthogonal T minimizing ``||X - Y @ T||_F`` for each batch element.

        Args:
            X: First tensor (batch, rows, features)
            Y: Second tensor (batch, rows, features)

        Returns:
            Batched orthogonal matrix ``V @ U^T`` from the SVD ``X^T Y = U S V^T``
        """
        U, _, Vt = torch.linalg.svd(torch.bmm(X.transpose(1, 2), Y))
        return torch.bmm(Vt.transpose(1, 2), U.transpose(1, 2))

    def get_dist_energy(self, X: torch.Tensor):
        """
        Mean pairwise distance between distinct repeats within each class.

        Args:
            X: Input tensor (batch, classes, repeats, features); with a single
                repeat there are no pairs and the mean is NaN

        Returns:
            Tensor of shape (batch,)
        """
        n = X.shape[2]
        combs = torch.combinations(torch.arange(n, device=X.device))
        X1 = torch.flatten(X[:, :, combs[:, 0], :], start_dim=1, end_dim=2)
        X2 = torch.flatten(X[:, :, combs[:, 1], :], start_dim=1, end_dim=2)

        return torch.mean(torch.linalg.norm(X2 - X1, dim=-1), dim=-1)

    def forward(self, X: torch.Tensor, Y: torch.Tensor):
        """
        Compute energy distance between X and Y.

        Args:
            X: First tensor (batch_size, num_classes, num_repeats_x, features_x)
            Y: Second tensor (batch_size, num_classes, num_repeats_y, features_y)

        Returns:
            Energy distance value, reduced per ``self.reduction``

        Raises:
            ValueError: On mismatched/non-4D input, unmatched dimensions or an
                unknown reduction.
            NotImplementedError: For ``dim_matching='pca'``.
        """
        if X.shape[:-2] != Y.shape[:-2] or X.ndim != 4 or Y.ndim != 4:
            raise ValueError(
                'Expected 4D input matrices to match in all dimensions but last two. '
                f'But got {X.shape} and {Y.shape} instead.'
            )

        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(
                    f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                    f'Set dim_matching or change matrix dimensions.'
                )
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(
                        f'With `zero_pad` dimension matching expected X dimension to be smaller than Y. '
                        f'But got {X.shape} and {Y.shape} instead.'
                    )
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.dim_matching}')

        _, T, _, _, _ = self.fit(X, Y)
        # fit() runs under no_grad, so the pairs it returns are detached; rebuild
        # them here so the cross term back-propagates into X and Y (T is fixed).
        X_prod, Y_prod = self._pair_up(X, Y)

        e_xx = self.get_dist_energy(X)
        e_yy = self.get_dist_energy(Y)
        Y_proj = torch.bmm(Y_prod, T)
        e_xy = torch.mean(torch.linalg.norm(X_prod - Y_proj, dim=-1), dim=-1)

        norms = torch.sqrt(torch.nn.functional.relu(e_xy - 0.5 * (e_xx + e_yy)))

        if self.reduction == 'mean':
            return norms.mean()
        elif self.reduction == 'sum':
            return norms.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return norms
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')
