"""Vector analysis utilities for Q/K distributions and visualization.

Includes Gaussian fitting, Mahalanobis distance computation, and histogram /
2-D PCA plotting helpers, collected into a reusable module.
"""

from __future__ import annotations

import numpy as np
import torch
import matplotlib.pyplot as plt


def fit_gaussian_from_K(K: torch.Tensor, shrink: float = 1e-3, diagonal: bool = True, eps: float = 1e-6):
    """Fit a Gaussian to the rows of ``K`` for later Mahalanobis scoring.

    Args:
        K: Samples, one per row (shape ``(N, d)``); cast to ``float32``.
        shrink: Weight of the linear shrinkage
            ``(1 - shrink) * cov + shrink * (trace(cov) / d) * I`` applied to
            the full covariance. Ignored when ``diagonal`` is True.
        diagonal: If True, model independent per-feature variances; otherwise
            fit a full covariance and return its Cholesky factor.
        eps: Added to the variances (diagonal mode) or to the covariance
            diagonal (full mode) so the result stays invertible even for
            constant features or a single sample.

    Returns:
        ``(mu, inv_std, chol)`` where ``mu`` has shape ``(1, d)``. In diagonal
        mode ``inv_std`` is ``1 / sqrt(var + eps)`` and ``chol`` is None; in
        full mode ``inv_std`` is None and ``chol`` is the lower-triangular
        Cholesky factor of the regularized covariance.

    Raises:
        ValueError: if ``K`` has no rows.
    """
    K = K.to(torch.float32)
    n = K.shape[0]
    if n == 0:
        raise ValueError("fit_gaussian_from_K() needs at least one sample")
    mu = K.mean(dim=0, keepdim=True)
    X = K - mu
    # Unbiased estimators divide by N - 1, which is 0 for a single sample
    # (torch.var then yields NaN). Fall back to zero spread and let ``eps``
    # keep the inverse finite.
    if diagonal:
        var = X.var(dim=0, unbiased=True) if n > 1 else torch.zeros_like(X[0])
        inv_std = 1.0 / torch.sqrt(var + eps)
        return mu, inv_std, None

    d = X.shape[1]
    cov = (X.T @ X) / max(n - 1, 1)
    tr = torch.trace(cov) / d
    eye = torch.eye(d, device=K.device, dtype=K.dtype)
    # Shrink toward the mean variance, plus an ``eps`` jitter (mirroring the
    # diagonal branch) so Cholesky succeeds even if ``cov`` is all zeros.
    cov = (1 - shrink) * cov + (shrink * tr + eps) * eye
    chol = torch.linalg.cholesky(cov)
    return mu, None, chol


def mahalanobis(X: torch.Tensor, mu: torch.Tensor, inv_std=None, chol=None):
    """Mahalanobis distance of each vector in ``X`` to a fitted Gaussian.

    Args:
        X: Vectors along the last dimension (e.g. shape ``(N, d)``); cast to
            ``float32``.
        mu: Mean broadcastable against ``X`` (e.g. ``(1, d)``); cast to
            ``float32``.
        inv_std: Per-feature ``1 / std`` (diagonal model). Takes precedence
            over ``chol`` when both are given.
        chol: Lower-triangular Cholesky factor ``L`` of the covariance
            ``Sigma = L @ L.T`` (full model).

    Returns:
        Distances with the leading shape of ``X - mu`` (the last dimension is
        reduced), e.g. ``(N,)``.

    Raises:
        ValueError: if neither ``inv_std`` nor ``chol`` is provided.
    """
    X = X.to(torch.float32)
    mu = mu.to(torch.float32)
    D = X - mu
    if inv_std is not None:
        Z = D * inv_std
        md2 = (Z * Z).sum(dim=-1)
    elif chol is not None:
        # d^T Sigma^{-1} d == ||L^{-1} d||^2, so one triangular solve suffices
        # and the squared norm is non-negative by construction.
        flat = D.reshape(-1, D.shape[-1])
        Y = torch.linalg.solve_triangular(chol.to(D.dtype), flat.T, upper=False)
        md2 = (Y * Y).sum(dim=0).reshape(D.shape[:-1])
    else:
        raise ValueError("mahalanobis() requires either inv_std or chol")
    return torch.sqrt(torch.clamp(md2, min=0.0))


def dist_QK_hist_from_states(
    key_states: torch.Tensor,
    query_states: torch.Tensor,
    q_reduce: str = "mean",
    diagonal: bool = True,
    device_for_fit: str = "cpu",
):
    """Per-KV-head Mahalanobis distances for comparing Q and K distributions.

    Each KV head ``h`` is paired with the contiguous query heads
    ``h * kv_group : (h + 1) * kv_group`` where ``kv_group = n_q // n_kv``.
    A Gaussian is fitted to that head's keys and another to its (reduced)
    queries, and the distances below are collected across all heads.

    Args:
        key_states: Tensor of shape ``(1, n_kv, S_k, d)``.
        query_states: Tensor of shape ``(1, n_q, S_q, d)`` with ``n_q`` a
            multiple of ``n_kv``. ``S_q`` may differ from ``S_k``.
        q_reduce: ``"mean"`` averages a group's query heads per position;
            ``"stack"`` keeps every query head's rows.
        diagonal: Passed to ``fit_gaussian_from_K`` (diagonal vs. full
            covariance).
        device_for_fit: Device on which fitting and distances are computed.

    Returns:
        ``(dQ, dK, dQQ)`` as 1-D numpy arrays concatenated over KV heads:
        queries vs. the key Gaussian, keys vs. the key Gaussian, and queries
        vs. a Gaussian fitted on the queries themselves.

    Raises:
        ValueError: on malformed shapes or an unknown ``q_reduce``.
    """
    if key_states.dim() != 4 or query_states.dim() != 4:
        raise ValueError("key_states and query_states must both be 4-D (B, heads, seq, dim)")
    B, n_kv, _, d = key_states.shape
    B_q, n_q, _, d_q = query_states.shape
    if B != 1 or B_q != 1:
        raise ValueError(f"batch size must be 1, got key={B}, query={B_q}")
    if d != d_q:
        raise ValueError(f"head dim mismatch: key={d}, query={d_q}")
    if n_kv == 0 or n_q % n_kv != 0:
        raise ValueError(f"n_q ({n_q}) must be a positive multiple of n_kv ({n_kv})")
    # Fail fast before copying tensors to the fitting device.
    if q_reduce not in ("mean", "stack"):
        raise ValueError("q_reduce must be 'mean' or 'stack'")
    kv_group = n_q // n_kv

    # One combined device+dtype conversion avoids an intermediate copy.
    K_all = key_states[0].to(device=device_for_fit, dtype=torch.float32).contiguous()
    Q_all = query_states[0].to(device=device_for_fit, dtype=torch.float32).contiguous()

    all_dK, all_dQ, all_dQQ = [], [], []
    for h in range(n_kv):
        K_h = K_all[h]
        Q_grp = Q_all[h * kv_group:(h + 1) * kv_group]
        Q_h = Q_grp.mean(dim=0) if q_reduce == "mean" else Q_grp.reshape(-1, d)

        mu_K, inv_std_K, chol_K = fit_gaussian_from_K(K_h, diagonal=diagonal)
        mu_Q, inv_std_Q, chol_Q = fit_gaussian_from_K(Q_h, diagonal=diagonal)
        all_dK.append(mahalanobis(K_h, mu_K, inv_std_K, chol_K).cpu().numpy())
        all_dQ.append(mahalanobis(Q_h, mu_K, inv_std_K, chol_K).cpu().numpy())
        all_dQQ.append(mahalanobis(Q_h, mu_Q, inv_std_Q, chol_Q).cpu().numpy())

    return (
        np.concatenate(all_dQ, axis=0),
        np.concatenate(all_dK, axis=0),
        np.concatenate(all_dQQ, axis=0),
    )


def plot_QK_hist(dQ: np.ndarray, dK: np.ndarray, dQQ: np.ndarray | None = None, title: str = "Model-X (all heads)", bins: int = 40):
    """Overlay distance histograms (K->K, Q->K, optionally Q->Q) on shared bins.

    All series are binned over the same ``(min, max)`` range taken across
    every provided array, so the bars line up and are directly comparable.

    Args:
        dQ: Query-to-key-Gaussian distances.
        dK: Key-to-key-Gaussian distances.
        dQQ: Optional query-to-query-Gaussian distances.
        title: Axes title.
        bins: Number of histogram bins.

    Returns:
        The ``matplotlib.figure.Figure`` holding the plot.
    """
    series = [dQ, dK] if dQQ is None else [dQ, dK, dQQ]
    shared_range = (min(s.min() for s in series), max(s.max() for s in series))

    # Draw order is fixed so later layers sit on top: K->K, then Q->K, then Q->Q.
    layers = [
        (dK, "K to K", "#e74c3c"),
        (dQ, "Q to K", "#3498db"),
    ]
    if dQQ is not None:
        layers.append((dQQ, "Q to Q", "#2ecc71"))

    fig, ax = plt.subplots()
    for values, label, color in layers:
        ax.hist(values, bins=bins, range=shared_range, alpha=0.7, label=label, color=color)
    ax.set_xlabel("Mahalanobis Distance")
    ax.set_ylabel("Frequency")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    return fig


def plot_QK_pca2d(
    Q,
    K,
    title: str = "Model-X (all heads)",
    max_points_per_class: int = 5000,
    pca_sample_max: int = 20000,
):
    """Scatter Q and K vectors in a 2-D PCA space fitted on both sets.

    The principal axes are fitted on a random subsample of at most
    ``pca_sample_max`` rows (about half from Q, the remainder from K). At most
    ``max_points_per_class`` points of each set are drawn. Sampling uses a
    fixed seed (42), so the figure is reproducible.

    Args:
        Q: Query vectors, a 2-D ``torch.Tensor`` or ``np.ndarray`` ``(n_q, d)``.
        K: Key vectors, 2-D with the same number of columns as ``Q``.
        title: Accepted for API compatibility; intentionally not drawn.
        max_points_per_class: Cap on plotted points per set.
        pca_sample_max: Cap on rows used to fit the PCA axes.

    Returns:
        The ``matplotlib.figure.Figure`` holding the scatter plot.

    Raises:
        TypeError: if ``Q`` or ``K`` is neither a tensor nor an ndarray.
        ValueError: if ``Q``/``K`` are not 2-D or their widths differ.
    """

    def _to_numpy(arr):
        # Convert a tensor or ndarray to a float32 ndarray.
        if isinstance(arr, torch.Tensor):
            return arr.detach().to(torch.float32).cpu().numpy()
        if isinstance(arr, np.ndarray):
            return arr.astype(np.float32, copy=False)
        raise TypeError("Q/K must be torch.Tensor or np.ndarray")

    Q_np = _to_numpy(Q)
    K_np = _to_numpy(K)
    if Q_np.ndim != 2 or K_np.ndim != 2 or Q_np.shape[1] != K_np.shape[1]:
        raise ValueError(f"Q and K must be 2-D with equal widths, got {Q_np.shape} and {K_np.shape}")

    rng = np.random.default_rng(42)

    def _sample_rows(X: np.ndarray, max_n: int) -> np.ndarray:
        # Return at most max_n rows, drawn without replacement when subsampling.
        n = X.shape[0]
        if n <= max_n:
            return X
        idx = rng.choice(n, size=max_n, replace=False)
        return X[idx]

    half = max(1, pca_sample_max // 2)
    Q_pca = _sample_rows(Q_np, half)
    # Clamp so a tiny pca_sample_max cannot produce a negative sample size.
    K_pca = _sample_rows(K_np, max(0, pca_sample_max - Q_pca.shape[0]))
    X_pca = np.vstack([Q_pca, K_pca])

    mean = X_pca.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(X_pca - mean, full_matrices=False)
    components = Vt[:2].T
    if components.shape[1] < 2:
        # d == 1 (or a single pooled row) yields fewer than two axes; pad with
        # a zero axis so a PC2 column always exists for plotting.
        components = np.pad(components, ((0, 0), (0, 2 - components.shape[1])))

    # Subsample before projecting: row selection commutes with the projection,
    # so this draws the same points while projecting far fewer rows.
    Q2d = (_sample_rows(Q_np, max_points_per_class) - mean) @ components
    K2d = (_sample_rows(K_np, max_points_per_class) - mean) @ components

    color_q = "#3498db"
    color_k = "#e74c3c"

    fig, ax = plt.subplots()
    ax.scatter(Q2d[:, 0], Q2d[:, 1], s=6, alpha=0.3, color=color_q, label="Q")
    ax.scatter(K2d[:, 0], K2d[:, 1], s=6, alpha=0.3, color=color_k, label="K")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    # No title per request
    leg = ax.legend(loc='lower left', frameon=False)
    # Matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles`` (the old
    # name was removed in 3.9); prefer the new one, fall back for old versions.
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = getattr(leg, "legendHandles", [])
    for handle in handles:
        if handle is None:  # Matplotlib stores None for artists it could not proxy.
            continue
        # Make legend markers opaque and larger than the faint scatter dots.
        handle.set_alpha(0.8)
        if hasattr(handle, "set_sizes"):
            handle.set_sizes([70])
    ax.grid(True, linestyle="--", alpha=0.3)
    ax.set_aspect("auto")
    fig.tight_layout()
    return fig



