"""
CKA (Centered Kernel Alignment) Similarity Analysis

Implements CKA similarity measurement to assess representation preservation
(Table 4 in paper).
"""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm


class CKAAnalyzer:
    """
    Compute CKA (Centered Kernel Alignment) similarity between models.

    Used to measure representation preservation (Section 5.4, Table 4).

    The kernel helpers (``centering``, ``linear_CKA``, ``rbf_CKA``) are static
    and operate on plain feature matrices; the instance methods handle
    extracting per-layer activations from ``nn.Module`` models via forward
    hooks and aggregating CKA scores over layer ranges.
    """

    # Attribute names under which supported architectures keep their stack of
    # numbered blocks (e.g. "model.layers.3", "transformer.h.3", "blocks.3").
    _LAYER_CONTAINER_NAMES = ('layers', 'h', 'blocks')

    def __init__(self, device: str = "cuda"):
        """
        Args:
            device: Device for computations
        """
        self.device = device

    @staticmethod
    def centering(K: torch.Tensor) -> torch.Tensor:
        """
        Center the kernel matrix, i.e. compute H K H with H = I - 11^T / n.

        Implemented by subtracting row/column means instead of building H
        explicitly: the result is the same, but it costs O(n^2) instead of two
        O(n^3) matmuls and automatically keeps the dtype/device of ``K``.

        Args:
            K: Kernel matrix (n x n)

        Returns:
            Centered kernel matrix (same shape, dtype and device as ``K``)
        """
        col_means = K.mean(dim=0, keepdim=True)
        row_means = K.mean(dim=1, keepdim=True)
        return K - col_means - row_means + K.mean()

    @staticmethod
    def _prepare_features(
        X: torch.Tensor, Y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Validate sample counts and bring both matrices to one float dtype."""
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"X and Y must have the same number of samples, "
                f"got {X.shape[0]} and {Y.shape[0]}"
            )

        def upcast(t: torch.Tensor) -> torch.Tensor:
            # Half-precision (and integer) inputs are promoted to float32:
            # Gram matrices and Frobenius norms are too lossy in fp16/bf16.
            return t if t.dtype in (torch.float32, torch.float64) else t.float()

        X, Y = upcast(X), upcast(Y)
        common = torch.promote_types(X.dtype, Y.dtype)
        return X.to(common), Y.to(common)

    @staticmethod
    def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float:
        """
        Compute linear CKA similarity between two feature matrices.

        CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)

        where X and Y have been centered column-wise.

        Args:
            X: Feature matrix (n x d1)
            Y: Feature matrix (n x d2)

        Returns:
            CKA similarity score [0, 1]. Returns 0.0 when either matrix has
            no variance across samples (the ratio would otherwise be 0/0).

        Raises:
            ValueError: If X and Y have a different number of rows.
        """
        X, Y = CKAAnalyzer._prepare_features(X, Y)

        # Center each feature (column) across samples
        X = X - X.mean(dim=0, keepdim=True)
        Y = Y - Y.mean(dim=0, keepdim=True)

        # CKA formula
        XtX = torch.matmul(X.t(), X)
        YtY = torch.matmul(Y.t(), Y)
        YtX = torch.matmul(Y.t(), X)

        hsic = torch.norm(YtX, p='fro') ** 2
        normalization = torch.norm(XtX, p='fro') * torch.norm(YtY, p='fro')

        if normalization.item() == 0:
            return 0.0
        return (hsic / normalization).item()

    @staticmethod
    def _rbf_kernel(features: torch.Tensor, sigma: Optional[float]) -> torch.Tensor:
        """Return the RBF Gram matrix exp(-||xi - xj||^2 / (2 sigma^2))."""
        gram = torch.matmul(features, features.t())
        sq_norms = torch.diagonal(gram)
        # ||xi - xj||^2 = ||xi||^2 + ||xj||^2 - 2 <xi, xj>; rounding can make
        # this slightly negative for near-identical rows, hence the clamp.
        sq_dists = (
            sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * gram
        ).clamp_min(0)

        if sigma is None:
            positive = sq_dists[sq_dists > 0]
            if positive.numel() == 0:
                # All samples coincide: every pairwise distance is zero.
                return torch.ones_like(sq_dists)
            # Median heuristic: sigma^2 is the median *squared* distance, so
            # the denominator is 2 * median (not 2 * median ** 2).
            two_sigma_sq = 2 * torch.median(positive)
        else:
            two_sigma_sq = 2 * sigma ** 2

        return torch.exp(-sq_dists / two_sigma_sq)

    @staticmethod
    def rbf_CKA(X: torch.Tensor, Y: torch.Tensor, sigma: Optional[float] = None) -> float:
        """
        Compute RBF kernel CKA similarity.

        Args:
            X: Feature matrix (n x d1)
            Y: Feature matrix (n x d2)
            sigma: RBF bandwidth applied to both X and Y. If None, the median
                heuristic is used separately for each matrix
                (sigma^2 = median of the positive squared pairwise distances),
                which makes the score invariant to rescaling X or Y.

        Returns:
            CKA similarity score [0, 1]. Returns 0.0 when either centered
            kernel is identically zero (e.g. all samples coincide).

        Raises:
            ValueError: If X and Y have a different number of rows, or if
                ``sigma`` is given but not strictly positive.
        """
        if sigma is not None and sigma <= 0:
            raise ValueError(f"sigma must be positive, got {sigma}")

        X, Y = CKAAnalyzer._prepare_features(X, Y)

        # Centered kernels
        K = CKAAnalyzer.centering(CKAAnalyzer._rbf_kernel(X, sigma))
        L = CKAAnalyzer.centering(CKAAnalyzer._rbf_kernel(Y, sigma))

        # CKA
        hsic = torch.sum(K * L)
        normalization = torch.sqrt(torch.sum(K * K)) * torch.sqrt(torch.sum(L * L))

        if normalization.item() == 0:
            return 0.0
        return (hsic / normalization).item()

    @torch.no_grad()
    def extract_layer_representations(
        self,
        model: nn.Module,
        dataloader,
        layer_indices: List[int],
        max_samples: int = 1000,
    ) -> Dict[int, torch.Tensor]:
        """
        Extract intermediate representations from specified layers.

        A forward hook is attached to each numbered block itself (e.g.
        ``model.layers.3``), not to its submodules, so every requested layer
        contributes exactly one activation per forward pass. 3-D outputs are
        mean-pooled over dim 1; note that the pooling is a plain mean and does
        not consult any attention mask present in the batch.

        The model is switched to eval mode and left there.

        Args:
            model: The model
            dataloader: Data loader yielding dict batches that are passed to
                the model as keyword arguments and contain 'input_ids'
            layer_indices: List of layer indices to extract
            max_samples: Maximum number of samples

        Returns:
            Dictionary mapping layer index to feature matrix (n_samples x d).
            Layers for which nothing was captured are omitted (a warning is
            printed).
        """
        model.eval()

        # Hooks to capture activations
        activations = {idx: [] for idx in layer_indices}

        def get_activation(layer_idx):
            def hook(module, inputs, output):
                # Store activation (handle tuples)
                act = output[0] if isinstance(output, tuple) else output
                # Take mean over sequence dimension
                if act.dim() == 3:  # (batch, seq, hidden)
                    act = act.mean(dim=1)
                activations[layer_idx].append(act.detach().cpu())
            return hook

        # Register hooks on the blocks themselves only. Hooking every module
        # whose name merely *contains* the layer index would also capture the
        # outputs of attention/MLP submodules and mix them into one list.
        hooks = []
        for name, module in model.named_modules():
            layer_idx = self._layer_block_index(name)
            if layer_idx is not None and layer_idx in activations:
                hooks.append(module.register_forward_hook(get_activation(layer_idx)))

        try:
            # Forward pass
            num_samples = 0
            for batch in tqdm(dataloader, desc="Extracting representations"):
                if num_samples >= max_samples:
                    break

                # Move batch to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                model(**batch)
                num_samples += batch['input_ids'].shape[0]
        finally:
            # Always detach the hooks, even if a forward pass raised, so the
            # model is not left with stale hooks holding on to `activations`.
            for handle in hooks:
                handle.remove()

        # Concatenate activations
        layer_features = {}
        for layer_idx in layer_indices:
            if activations[layer_idx]:
                # The last batch may overshoot max_samples; trim the excess.
                layer_features[layer_idx] = torch.cat(activations[layer_idx], dim=0)[:max_samples]
            else:
                print(f"Warning: No activations captured for layer {layer_idx}")

        return layer_features

    def compute_model_cka(
        self,
        model1: nn.Module,
        model2: nn.Module,
        dataloader,
        layer_ranges: Dict[str, List[int]],
        max_samples: int = 1000,
        method: str = "linear"
    ) -> Dict[str, float]:
        """
        Compute CKA similarity between two models for different layer ranges.

        Reproduces Table 4: CKA Similarity with SFT Base Model.

        Args:
            model1: First model (e.g., trained model)
            model2: Second model (e.g., SFT base)
            dataloader: Data loader
            layer_ranges: Dict with keys like 'Layers 0-15', 'Layers 16-31'
            max_samples: Number of samples for CKA computation
            method: 'linear' or 'rbf'

        Returns:
            Dictionary mapping range name to the CKA score averaged over the
            layers of that range captured in both models (0.0 if none were).

        Raises:
            ValueError: If ``method`` is neither 'linear' nor 'rbf'.
        """
        # Validate before running any (expensive) forward passes.
        if method not in ("linear", "rbf"):
            raise ValueError(f"Unknown CKA method {method!r}; expected 'linear' or 'rbf'")
        cka_fn = self.linear_CKA if method == "linear" else self.rbf_CKA

        results = {}

        for range_name, layer_indices in layer_ranges.items():
            print(f"\nComputing CKA for {range_name}...")

            # Extract representations from both models
            features1 = self.extract_layer_representations(
                model1, dataloader, layer_indices, max_samples
            )
            features2 = self.extract_layer_representations(
                model2, dataloader, layer_indices, max_samples
            )

            # Compute average CKA across layers in range
            cka_scores = []
            for layer_idx in layer_indices:
                if layer_idx in features1 and layer_idx in features2:
                    X = features1[layer_idx].to(self.device)
                    Y = features2[layer_idx].to(self.device)

                    cka = cka_fn(X, Y)

                    cka_scores.append(cka)
                    print(f"  Layer {layer_idx}: CKA = {cka:.4f}")

            # Plain float (not np.float64) to match the declared return type.
            avg_cka = float(np.mean(cka_scores)) if cka_scores else 0.0
            results[range_name] = avg_cka
            print(f"  Average CKA for {range_name}: {avg_cka:.4f}")

        return results

    def _extract_layer_index(self, module_name: str) -> Optional[int]:
        """
        Extract layer index from module name.

        Returns the number following the first 'layers' / 'h' / 'blocks'
        component of the dotted name (so submodules of a block map to the
        block's index), or None if the name contains no such pattern.
        """
        parts = module_name.split('.')
        for container, index in zip(parts, parts[1:]):
            if container in self._LAYER_CONTAINER_NAMES and index.isdigit():
                return int(index)
        return None

    def _layer_block_index(self, module_name: str) -> Optional[int]:
        """
        Return the layer index only if the name denotes the block itself.

        'model.layers.3' -> 3, whereas 'model.layers.3.mlp' -> None, because
        the index is not the final component of the dotted name.
        """
        parts = module_name.split('.')
        for pos, (container, index) in enumerate(zip(parts, parts[1:])):
            if container in self._LAYER_CONTAINER_NAMES and index.isdigit():
                return int(index) if pos + 2 == len(parts) else None
        return None
