"""Global isotropy and local anisotropy metrics."""

import dataclasses
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import numpy as np


@dataclasses.dataclass
class GlobalIsotropyMetrics:
    """Metrics for global isotropy of embeddings.

    The spectral fields are all functions of the covariance eigenvalues
    (written ``lambda`` below).  ``effective_dim`` and ``participation_ratio``
    are two spellings of the same quantity:
    ``(sum(lambda))^2 / sum(lambda^2) == 1 / sum((lambda_i / sum(lambda))^2)``.
    """

    eigenvalues: np.ndarray  # Covariance eigenvalues, sorted descending
    effective_dim: float  # (sum(lambda))^2 / sum(lambda^2); between 1 and embed_dim for a non-zero spectrum
    isotropy_score: float  # effective_dim / embed_dim, i.e. the fraction of dimensions "in use"
    participation_ratio: float  # 1 / sum((lambda_i / sum(lambda))^2); algebraically equal to effective_dim
    cosine_sim_mean: float  # Mean cosine similarity over sampled pairs of embeddings
    cosine_sim_std: float  # Std of cosine similarity over the same sampled pairs
    embed_dim: int  # Embedding dimension (d in the formulas above)


@dataclasses.dataclass
class LocalAnisotropyMetrics:
    """Metrics for local anisotropy of neighborhoods.

    Each neighborhood contributes one anisotropy value, one top-2 ratio and
    one effective dimensionality, all derived from the eigenvalues
    (``lambda``) of that neighborhood's covariance.  The scalar fields
    aggregate those per-neighborhood values.
    """

    mean_anisotropy: float  # Mean of lambda_1 / sum(lambda) across neighborhoods
    std_anisotropy: float  # Std of lambda_1 / sum(lambda) across neighborhoods
    anisotropy_values: np.ndarray  # Per-neighborhood anisotropy values, one entry per neighborhood
    mean_top2_ratio: float  # Mean of (lambda_1 + lambda_2) / sum(lambda) across neighborhoods
    mean_effective_dim: float  # Mean of (sum(lambda))^2 / sum(lambda^2) across neighborhoods


class OnlineCovarianceEstimator:
    """Online estimation of covariance matrix using Welford's algorithm.

    The estimator keeps three pieces of state, all in float64 on ``device``:

    * ``n``    -- number of samples absorbed so far,
    * ``mean`` -- running mean, shape ``(dim,)``,
    * ``M2``   -- running sum of outer products of deviations from the mean,
      shape ``(dim, dim)``; the sample covariance is ``M2 / (n - 1)``.

    ``update`` absorbs samples one at a time; ``update_batch`` merges a whole
    batch in one step.  The two may be freely interleaved.
    """

    def __init__(self, dim: int, device: torch.device):
        """Create an empty estimator for ``dim``-dimensional samples on ``device``."""
        self.dim = dim
        self.device = device
        self.n = 0
        self.mean = torch.zeros(dim, device=device, dtype=torch.float64)
        self.M2 = torch.zeros(dim, dim, device=device, dtype=torch.float64)

    def update(self, x: torch.Tensor):
        """Update with a batch of samples, one sample at a time. x: (batch, dim)"""
        x = x.to(self.device, dtype=torch.float64)

        for sample in x:
            self.n += 1
            # Deviation from the mean *before* and *after* including the
            # sample; their outer product is Welford's M2 increment.
            delta = sample - self.mean
            self.mean = self.mean + delta / self.n
            delta2 = sample - self.mean
            self.M2 = self.M2 + torch.outer(delta, delta2)

    def update_batch(self, x: torch.Tensor):
        """Batch update (more efficient for large batches). x: (batch, dim)

        Merges the batch statistics into the running statistics with the
        pairwise combination formula.  An empty batch is a no-op.
        """
        x = x.to(self.device, dtype=torch.float64)
        batch_size = x.shape[0]

        # The mean of zero rows is NaN and would poison the running state.
        if batch_size == 0:
            return

        batch_mean = x.mean(dim=0)
        centered = x - batch_mean
        batch_m2 = centered.T @ centered

        if self.n == 0:
            # First data seen: the batch statistics are the running statistics.
            self.mean = batch_mean
            self.M2 = batch_m2
            self.n = batch_size
            return

        new_n = self.n + batch_size
        delta = batch_mean - self.mean
        self.mean = self.mean + delta * batch_size / new_n
        # Within-group scatter of both groups plus the between-group term.
        self.M2 = self.M2 + batch_m2 + torch.outer(delta, delta) * self.n * batch_size / new_n
        self.n = new_n

    def get_covariance(self) -> torch.Tensor:
        """Return the (dim, dim) sample covariance matrix in float64.

        With fewer than two samples the covariance is undefined and a zero
        matrix (same dtype and device as the regular result) is returned.
        """
        if self.n < 2:
            return torch.zeros(self.dim, self.dim, device=self.device, dtype=torch.float64)
        return self.M2 / (self.n - 1)

    def get_mean(self) -> torch.Tensor:
        """Return the running mean, cast to float32."""
        return self.mean.float()


def compute_eigenvalues(cov: torch.Tensor) -> np.ndarray:
    """Compute eigenvalues of covariance matrix, sorted descending.

    Args:
        cov: Square symmetric matrix (covariance or Gram matrix).  It may
            live on any device and may carry autograd history; neither
            affects the result.

    Returns:
        1-D NumPy array of eigenvalues in descending order, with negative
        values (round-off on a positive semi-definite input) clipped to 0.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    cov_np = cov.detach().cpu().numpy()
    # eigvalsh returns eigenvalues in ascending order, so reversing is enough.
    eigenvalues = np.linalg.eigvalsh(cov_np)[::-1]
    return np.maximum(eigenvalues, 0)  # Ensure non-negative


def compute_global_isotropy(
    embeddings: torch.Tensor,
    n_pairs_for_cosine: int = 10000,
) -> GlobalIsotropyMetrics:
    """
    Compute global isotropy metrics from embeddings.

    The spectral metrics come from the eigenvalues of the sample covariance
    of ``embeddings``.  The cosine statistics come from randomly drawn pairs
    of distinct rows; pairs are drawn independently, so the same pair may be
    drawn more than once.

    Args:
        embeddings: (N, D) tensor of embeddings.  Autograd history is ignored.
        n_pairs_for_cosine: Number of random pairs for cosine similarity
            distribution (capped at N * (N - 1) / 2).

    Returns:
        GlobalIsotropyMetrics dataclass whose scalar fields are plain Python
        floats.  With fewer than two rows the covariance is undefined, so the
        eigenvalues are all zero and every scalar metric is 0.0.  With a
        single sampled pair the cosine std is reported as 0.0.
    """
    N, D = embeddings.shape
    # These are diagnostics only; dropping the graph lets the covariance be
    # handed to NumPy even when the embeddings require grad.
    embeddings = embeddings.detach()
    device = embeddings.device

    if N > 1:
        centered = embeddings - embeddings.mean(dim=0)
        cov = (centered.T @ centered) / (N - 1)
        eigenvalues = compute_eigenvalues(cov.float())
    else:
        # Dividing by N - 1 == 0 would give a NaN covariance.
        eigenvalues = np.zeros(D, dtype=np.float32)

    # Accumulate the spectral sums in float64 so the ratios below do not
    # inherit float32 round-off from the eigenvalue array.
    spectrum = eigenvalues.astype(np.float64)
    sum_eig = float(spectrum.sum())
    sum_eig_sq = float((spectrum ** 2).sum())

    # Effective dimensionality
    effective_dim = (sum_eig ** 2) / sum_eig_sq if sum_eig_sq > 0 else 0.0

    # Isotropy score
    isotropy_score = effective_dim / D if D > 0 else 0.0

    # Participation ratio
    if sum_eig > 0:
        normalized_eig = spectrum / sum_eig
        participation_ratio = 1.0 / float((normalized_eig ** 2).sum())
    else:
        participation_ratio = 0.0

    # Cosine similarity distribution
    cosine_sim_mean = 0.0
    cosine_sim_std = 0.0
    n_pairs = min(n_pairs_for_cosine, N * (N - 1) // 2)
    if n_pairs > 0:  # implies N >= 2
        idx1 = torch.randint(0, N, (n_pairs,), device=device)
        # A non-zero offset modulo N guarantees idx2 != idx1.
        offset = 1 + torch.randint(0, N - 1, (n_pairs,), device=device)
        idx2 = (idx1 + offset) % N
        cos_sims = F.cosine_similarity(embeddings[idx1], embeddings[idx2], dim=-1)
        cosine_sim_mean = cos_sims.mean().item()
        # The unbiased std of a single value is NaN; report 0.0 instead.
        if n_pairs > 1:
            cosine_sim_std = cos_sims.std().item()

    return GlobalIsotropyMetrics(
        eigenvalues=eigenvalues,
        effective_dim=effective_dim,
        isotropy_score=isotropy_score,
        participation_ratio=participation_ratio,
        cosine_sim_mean=cosine_sim_mean,
        cosine_sim_std=cosine_sim_std,
        embed_dim=D,
    )


def compute_local_anisotropy_single(
    neighborhood: torch.Tensor,
    debug: bool = False,
) -> Tuple[float, float, float]:
    """
    Compute local anisotropy for a single neighborhood.

    Args:
        neighborhood: (K, D) tensor of neighborhood embeddings.  Autograd
            history is ignored.
        debug: Print debug info

    Returns:
        anisotropy: lambda_1 / sum(lambda)
        top2_ratio: (lambda_1 + lambda_2) / sum(lambda); equals anisotropy
            when there is only one eigenvalue
        effective_dim: local effective dimensionality,
            (sum(lambda))^2 / sum(lambda^2)

        All three are plain Python floats.  Degenerate neighborhoods (fewer
        than two points, or essentially zero total variance) yield
        ``(0.0, 0.0, float(D))``.
    """
    K, D = neighborhood.shape

    if K < 2:
        return 0.0, 0.0, float(D)

    if debug:
        print(f"          [debug] neighborhood shape: {neighborhood.shape}, device: {neighborhood.device}", flush=True)

    # Drop the graph so the scatter matrix can be handed to NumPy even when
    # the embeddings require grad.
    neighborhood = neighborhood.detach()

    # Center the neighborhood
    centered = neighborhood - neighborhood.mean(dim=0)

    # Use Gram matrix (K×K) instead of covariance (D×D) when K < D.
    # X @ X.T and X.T @ X share their non-zero eigenvalues, and the K×K
    # problem is much smaller when e.g. K=16 and D=512.
    if K < D:
        label = "gram"
        scatter = (centered @ centered.T) / (K - 1)  # K×K matrix
    else:
        label = "cov"
        scatter = (centered.T @ centered) / (K - 1)  # D×D matrix

    if debug:
        print(f"          [debug] {label} shape: {scatter.shape}, computing eigenvalues...", flush=True)
    eigenvalues = compute_eigenvalues(scatter.float())
    if debug:
        print("          [debug] eigenvalues computed", flush=True)

    # Do the ratio arithmetic in float64 and return built-in floats, so the
    # caller's aggregated arrays have one dtype regardless of which branch
    # produced each value.
    spectrum = np.asarray(eigenvalues, dtype=np.float64)
    sum_eig = float(spectrum.sum())
    if sum_eig < 1e-10:
        return 0.0, 0.0, float(D)

    # Anisotropy: concentration in top eigenvalue
    anisotropy = float(spectrum[0]) / sum_eig

    # Top-2 ratio
    if len(spectrum) > 1:
        top2_ratio = float(spectrum[0] + spectrum[1]) / sum_eig
    else:
        top2_ratio = anisotropy

    # Local effective dimensionality
    sum_eig_sq = float((spectrum ** 2).sum())
    if sum_eig_sq > 0:
        effective_dim = (sum_eig ** 2) / sum_eig_sq
    else:
        effective_dim = float(D)

    return anisotropy, top2_ratio, effective_dim


def compute_local_anisotropy_batch(
    anchor_embeddings: torch.Tensor,
    candidate_embeddings: torch.Tensor,
    k: int = 16,
    batch_size: int = 100,
    verbose: bool = True,
) -> LocalAnisotropyMetrics:
    """
    Compute local anisotropy for semantic neighborhoods.

    For each anchor, find top-K candidates by cosine similarity,
    then compute local anisotropy of that neighborhood.  Normalization is
    used for ranking only: the neighborhood itself is built from the
    candidate embeddings exactly as passed in.

    Args:
        anchor_embeddings: (N, D) tensor of anchor embeddings (e.g., images)
        candidate_embeddings: (M, D) tensor of candidate embeddings (e.g., texts)
        k: Number of neighbors to use (capped at M)
        batch_size: Number of anchors whose similarities are computed at once
        verbose: Print progress, plus debug output for the first neighborhood

    Returns:
        LocalAnisotropyMetrics dataclass; ``anisotropy_values`` has one
        float64 entry per anchor.
    """
    N = anchor_embeddings.shape[0]

    # Rank by cosine similarity as documented.  Only the candidates need to
    # be unit-normalized: an anchor's norm scales its whole similarity row by
    # one positive constant and so cannot change that row's top-K.
    unit_candidates_t = F.normalize(candidate_embeddings, dim=-1).T
    top_k = min(k, candidate_embeddings.shape[0])

    anisotropy_values = []
    top2_ratios = []
    effective_dims = []

    # Process in batches
    for start in range(0, N, batch_size):
        if verbose and start % 100 == 0:
            print(f"        Semantic batch {start}/{N}...", flush=True)
        end = min(start + batch_size, N)

        # Similarities for ranking: (batch, M)
        sims = anchor_embeddings[start:end] @ unit_candidates_t

        # Get top-K indices
        _, top_indices = torch.topk(sims, k=top_k, dim=1)

        # Compute local anisotropy for each anchor
        for i, neighbor_idx in enumerate(top_indices):
            neighborhood = candidate_embeddings[neighbor_idx]
            # Debug output only for the very first neighborhood, and only
            # when the caller asked for output at all.
            aniso, top2, eff_dim = compute_local_anisotropy_single(
                neighborhood, debug=(verbose and start == 0 and i == 0)
            )
            anisotropy_values.append(aniso)
            top2_ratios.append(top2)
            effective_dims.append(eff_dim)

    # Fix the dtype so it does not depend on the scalar types collected above.
    anisotropy_values = np.asarray(anisotropy_values, dtype=np.float64)
    top2_ratios = np.asarray(top2_ratios, dtype=np.float64)
    effective_dims = np.asarray(effective_dims, dtype=np.float64)

    return LocalAnisotropyMetrics(
        mean_anisotropy=float(anisotropy_values.mean()),
        std_anisotropy=float(anisotropy_values.std()),
        anisotropy_values=anisotropy_values,
        mean_top2_ratio=float(top2_ratios.mean()),
        mean_effective_dim=float(effective_dims.mean()),
    )


def compute_local_anisotropy_random(
    candidate_embeddings: torch.Tensor,
    n_samples: int = 1000,
    k: int = 16,
    verbose: bool = True,
) -> LocalAnisotropyMetrics:
    """
    Compute local anisotropy for random neighborhoods (baseline).

    Each neighborhood is ``k`` distinct rows of ``candidate_embeddings``
    chosen uniformly at random (fewer if there are fewer than ``k`` rows).

    Args:
        candidate_embeddings: (M, D) tensor of embeddings
        n_samples: Number of random neighborhoods to sample
        k: Neighborhood size
        verbose: Print progress, plus debug output for the first neighborhood

    Returns:
        LocalAnisotropyMetrics dataclass; ``anisotropy_values`` has one
        float64 entry per sampled neighborhood.
    """
    M = candidate_embeddings.shape[0]
    device = candidate_embeddings.device

    anisotropy_values = []
    top2_ratios = []
    effective_dims = []

    for i in range(n_samples):
        if verbose and i % 100 == 0:
            print(f"        Random neighborhood {i}/{n_samples}...", flush=True)
        # The first k entries of a random permutation are k distinct indices.
        indices = torch.randperm(M, device=device)[:k]
        neighborhood = candidate_embeddings[indices]

        # Debug the first iteration to diagnose hangs, but stay silent when
        # the caller asked for no output.
        aniso, top2, eff_dim = compute_local_anisotropy_single(
            neighborhood, debug=(verbose and i == 0)
        )
        anisotropy_values.append(aniso)
        top2_ratios.append(top2)
        effective_dims.append(eff_dim)

    # Fix the dtype so it does not depend on the scalar types collected above.
    anisotropy_values = np.asarray(anisotropy_values, dtype=np.float64)
    top2_ratios = np.asarray(top2_ratios, dtype=np.float64)
    effective_dims = np.asarray(effective_dims, dtype=np.float64)

    return LocalAnisotropyMetrics(
        mean_anisotropy=float(anisotropy_values.mean()),
        std_anisotropy=float(anisotropy_values.std()),
        anisotropy_values=anisotropy_values,
        mean_top2_ratio=float(top2_ratios.mean()),
        mean_effective_dim=float(effective_dims.mean()),
    )


def compute_local_anisotropy_adversarial(
    candidate_sets: List[torch.Tensor],
    verbose: bool = True,
) -> LocalAnisotropyMetrics:
    """
    Aggregate local anisotropy over pre-built adversarial candidate sets.

    Every candidate set is treated as one neighborhood and scored
    independently; the per-set scores are then summarized.

    Args:
        candidate_sets: List of (K, D) tensors, one per candidate set taken
                       from an adversarial benchmark (e.g., SugarCrepe,
                       Winoground)
        verbose: Print a progress line every 100 sets

    Returns:
        LocalAnisotropyMetrics dataclass
    """
    n_sets = len(candidate_sets)

    # One (anisotropy, top2_ratio, effective_dim) triple per candidate set.
    per_set_scores = []
    for i, candidates in enumerate(candidate_sets):
        if verbose and i % 100 == 0:
            print(f"        Adversarial {i}/{n_sets}...", flush=True)
        per_set_scores.append(compute_local_anisotropy_single(candidates))

    anisotropy_arr = np.array([scores[0] for scores in per_set_scores])
    top2_arr = np.array([scores[1] for scores in per_set_scores])
    eff_dim_arr = np.array([scores[2] for scores in per_set_scores])

    return LocalAnisotropyMetrics(
        mean_anisotropy=anisotropy_arr.mean(),
        std_anisotropy=anisotropy_arr.std(),
        anisotropy_values=anisotropy_arr,
        mean_top2_ratio=top2_arr.mean(),
        mean_effective_dim=eff_dim_arr.mean(),
    )
