"""Diagnose Qronos statistics computation.

This script analyzes the Qronos statistics (Sig_X, Sig_hX, Sig_X_hX) to check for:
1. Correlation between X and X_hat (how different are quantized activations?)
2. Condition number of Sig_hX (numerical stability)
3. How much W* differs from W (see compute_w_star_analysis; main() does not call it)
4. Symmetry properties (Sig_X, Sig_hX should be symmetric)
5. Whether the Qronos loss formula is correct

Usage:
    python scripts/diagnose_qronos_stats.py --qronos_run <path_to_qronos_run>
"""

import argparse
import os
import pickle
import sys
from pathlib import Path

import torch

# Put the parent of this script's directory at the front of sys.path.
# NOTE(review): presumably that parent is the repository root, so project
# modules resolve when the script is run directly -- confirm against the layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def load_qronos_stats(run_dir: str, layer_id: int, weight: str):
    """Return the unpickled Qronos statistics for one weight of one layer.

    Two candidate files under ``<run_dir>/qronos_stats`` are probed in order
    -- the ``attention`` variant first, then ``feed_forward`` -- and the first
    one that exists is loaded.

    Args:
        run_dir: Directory of the quantization run.
        layer_id: Layer index, used to build the file name.
        weight: Weight name, used to build the file name.

    Returns:
        The object stored in the first matching pickle file.

    Raises:
        FileNotFoundError: If neither candidate file exists.
    """
    stats_dir = f"{run_dir}/qronos_stats"
    candidates = (
        f"{stats_dir}/layers.{layer_id}.{module}.{weight}.pkl"
        for module in ("attention", "feed_forward")
    )
    stats_file = next((c for c in candidates if os.path.exists(c)), None)
    if stats_file is None:
        raise FileNotFoundError(f"Could not find stats for layer {layer_id}.{weight}")

    # NOTE: unpickling can execute arbitrary code; only load trusted runs.
    with open(stats_file, "rb") as handle:
        return pickle.load(handle)


def analyze_stats(Sig_X, Sig_hX, Sig_X_hX, layer_id, weight):
    """Print a diagnostic report for one layer's statistics and return it.

    The report covers symmetry, traces, two correlation proxies, the
    conditioning and eigenvalue spectrum of ``Sig_hX``, diagonal statistics,
    and the relative distance of ``Sig_X_hX`` from ``Sig_hX`` and ``Sig_X``.

    Args:
        Sig_X: Torch matrix; expected to be square (not validated here).
        Sig_hX: Torch matrix of the same size as ``Sig_X``.
        Sig_X_hX: Torch matrix of the same size as ``Sig_X``.
        layer_id: Label used in the printed header and the result dict.
        weight: Label used in the printed header and the result dict.

    Returns:
        A dict of plain Python values with keys ``layer_id``, ``weight``,
        ``n``, ``sym_X``, ``sym_hX``, ``asym_X_hX``, ``tr_X``, ``tr_hX``,
        ``tr_X_hX``, ``correlation``, ``frob_corr``, ``cond_number``,
        ``n_neg_eig``, ``diff_X_hX_to_hX`` and ``diff_X_hX_to_X``.
        ``cond_number`` is ``inf`` and ``n_neg_eig`` is ``-1`` when the
        eigendecomposition of ``Sig_hX`` fails.

    Raises:
        ZeroDivisionError: If ``trace(Sig_X)`` is zero, or the product
            ``trace(Sig_X) * trace(Sig_hX)`` is zero.
    """
    n = Sig_X.shape[0]
    banner = '=' * 60

    print(f"\n{banner}")
    print(f"Layer {layer_id}.{weight} (n={n})")
    print(banner)

    # 1. Check symmetry
    sym_X = torch.allclose(Sig_X, Sig_X.T, rtol=1e-5, atol=1e-8)
    sym_hX = torch.allclose(Sig_hX, Sig_hX.T, rtol=1e-5, atol=1e-8)
    asym_X_hX = torch.norm(Sig_X_hX - Sig_X_hX.T) / torch.norm(Sig_X_hX)

    print("\n1. Symmetry:")
    print(f"   Sig_X symmetric:   {sym_X}")
    print(f"   Sig_hX symmetric:  {sym_hX}")
    print(f"   Sig_X_hX asymmetry: {asym_X_hX:.6f} (0 = symmetric)")

    # 2. Trace and scale
    tr_X = torch.trace(Sig_X).item()
    tr_hX = torch.trace(Sig_hX).item()
    tr_X_hX = torch.trace(Sig_X_hX).item()

    print("\n2. Traces (scale):")
    print(f"   trace(Sig_X):    {tr_X:.6e}")
    print(f"   trace(Sig_hX):   {tr_hX:.6e}")
    print(f"   trace(Sig_X_hX): {tr_X_hX:.6e}")
    print(f"   ratio tr(Sig_hX)/tr(Sig_X): {tr_hX/tr_X:.4f}")

    # 3. Correlation between X and X_hat
    # correlation = tr(Sig_X_hX) / sqrt(tr(Sig_X) * tr(Sig_hX))
    correlation = tr_X_hX / (tr_X * tr_hX) ** 0.5

    # Frobenius-norm analogue of the trace correlation above.
    frob_corr = torch.norm(Sig_X_hX, 'fro') / (
        torch.norm(Sig_X, 'fro') * torch.norm(Sig_hX, 'fro')
    ) ** 0.5

    print("\n3. Correlation X vs X_hat:")
    print(f"   Trace correlation:    {correlation:.6f}")
    print(f"   Frobenius correlation: {frob_corr.item():.6f}")
    print("   (1.0 = X and X_hat identical, lower = more different)")

    # 4. Condition number of Sig_hX
    # The eigenvalues are computed once here and reused for the spectrum in
    # section 5, so the decomposition of Sig_hX is not repeated.
    eig_hX = None
    eig_hX_error = None
    try:
        eig_hX = torch.linalg.eigvalsh(Sig_hX)
    except Exception as exc:  # not a bare except: let Ctrl-C / SystemExit through
        eig_hX_error = exc

    if eig_hX is None:
        cond_number = float('inf')
        n_neg_eig = -1
    else:
        positive = eig_hX[eig_hX > 0]
        if positive.numel() > 0:
            # Ratio of largest to smallest strictly positive eigenvalue.
            cond_number = (positive.max() / positive.min()).item()
            # Counts non-positive eigenvalues (zeros included).
            n_neg_eig = (eig_hX <= 0).sum().item()
        else:
            cond_number = float('inf')
            n_neg_eig = eig_hX.numel()

    print("\n4. Condition of Sig_hX:")
    print(f"   Condition number: {cond_number:.2e}")
    print(f"   Negative eigenvalues: {n_neg_eig}")

    # 5. Eigenvalue spectrum
    try:
        eig_X = torch.linalg.eigvalsh(Sig_X)
    except Exception as exc:
        spectrum_error = exc
    else:
        # Sig_X succeeded; surface the Sig_hX failure from section 4, if any.
        spectrum_error = eig_hX_error

    if spectrum_error is None:
        print("\n5. Eigenvalue spectrum:")
        print(f"   Sig_X:  min={eig_X.min():.2e}, max={eig_X.max():.2e}, median={eig_X.median():.2e}")
        print(f"   Sig_hX: min={eig_hX.min():.2e}, max={eig_hX.max():.2e}, median={eig_hX.median():.2e}")
    else:
        print(f"\n5. Eigenvalue spectrum: failed ({spectrum_error})")

    # 6. Diagonal statistics
    diag_X = torch.diag(Sig_X)
    diag_hX = torch.diag(Sig_hX)
    diag_X_hX = torch.diag(Sig_X_hX)

    print("\n6. Diagonal statistics:")
    print(f"   Sig_X diag:    mean={diag_X.mean():.2e}, std={diag_X.std():.2e}")
    print(f"   Sig_hX diag:   mean={diag_hX.mean():.2e}, std={diag_hX.std():.2e}")
    print(f"   Sig_X_hX diag: mean={diag_X_hX.mean():.2e}, std={diag_X_hX.std():.2e}")
    # The 1e-10 offset avoids dividing by an exactly-zero diagonal entry.
    print(f"   diag(Sig_X_hX) / diag(Sig_hX): mean={((diag_X_hX / (diag_hX + 1e-10)).mean()):.4f}")

    # 7. Check if Sig_X_hX is close to Sig_hX (would mean X ≈ X_hat)
    diff_normalized = torch.norm(Sig_X_hX - Sig_hX, 'fro') / torch.norm(Sig_hX, 'fro')
    diff_to_X = torch.norm(Sig_X_hX - Sig_X, 'fro') / torch.norm(Sig_X, 'fro')

    print("\n7. Cross-covariance analysis:")
    print(f"   ||Sig_X_hX - Sig_hX|| / ||Sig_hX||: {diff_normalized:.4f}")
    print(f"   ||Sig_X_hX - Sig_X|| / ||Sig_X||:   {diff_to_X:.4f}")
    print("   (0 = identical, high = very different)")

    return {
        'layer_id': layer_id,
        'weight': weight,
        'n': n,
        'sym_X': sym_X,
        'sym_hX': sym_hX,
        'asym_X_hX': asym_X_hX.item(),
        'tr_X': tr_X,
        'tr_hX': tr_hX,
        'tr_X_hX': tr_X_hX,
        'correlation': correlation,
        'frob_corr': frob_corr.item(),
        'cond_number': cond_number,
        'n_neg_eig': n_neg_eig,
        'diff_X_hX_to_hX': diff_normalized.item(),
        'diff_X_hX_to_X': diff_to_X.item(),
    }


def compute_w_star_analysis(W, Sig_X, Sig_hX, Sig_X_hX, percdamp=0.0001):
    """Measure how far ``W* = W @ Sig_X_hX @ inv(Sig_hX)`` is from ``W``.

    ``Sig_hX`` is damped before the solve by adding
    ``percdamp * mean(diag(Sig_hX))`` to its diagonal.

    Args:
        W: Weight matrix. It is cast to the dtype of ``Sig_hX`` so that, for
            example, float32 weights can be used with float64 statistics.
        Sig_X: Unused; kept so the signature stays unchanged for callers.
        Sig_hX: Square matrix that is damped and solved against.
        Sig_X_hX: Matrix right-multiplied onto ``W`` before the solve.
        percdamp: Damping factor relative to the mean diagonal of ``Sig_hX``.

    Returns:
        A dict with ``W_star_diff_pct`` (relative Frobenius difference, in
        percent) and ``W_star_row_diff_mean`` / ``_max`` / ``_min`` (per-row
        relative differences, in percent). If the computation raises, all
        four metrics are NaN and an ``error`` key holds the exception text.
    """
    metric_keys = (
        'W_star_diff_pct',
        'W_star_row_diff_mean',
        'W_star_row_diff_max',
        'W_star_row_diff_min',
    )
    n = Sig_hX.shape[0]

    # Add damping
    damp = percdamp * torch.mean(torch.diag(Sig_hX))
    Sig_hX_damped = Sig_hX + damp * torch.eye(n, device=Sig_hX.device, dtype=Sig_hX.dtype)

    try:
        # Mixed dtypes would make the matmul below fail outright.
        W = W.to(dtype=Sig_hX.dtype)

        # Solving A^T Y = B^T and transposing gives Y^T = B @ inv(A),
        # which avoids forming the inverse explicitly.
        W_star = torch.linalg.solve(Sig_hX_damped.T, (W @ Sig_X_hX).T).T

        delta = W_star - W
        diff = torch.norm(delta, 'fro') / torch.norm(W, 'fro')
        row_diffs = torch.norm(delta, dim=1) / torch.norm(W, dim=1)

        return {
            'W_star_diff_pct': diff.item() * 100,
            'W_star_row_diff_mean': row_diffs.mean().item() * 100,
            'W_star_row_diff_max': row_diffs.max().item() * 100,
            'W_star_row_diff_min': row_diffs.min().item() * 100,
        }
    except Exception as e:
        # Best-effort diagnostic: report the failure with the same metric
        # keys as the success path, so callers can index either result.
        failed = {key: float('nan') for key in metric_keys}
        failed['error'] = str(e)
        return failed


def _split_csv(raw):
    """Return the stripped, non-empty tokens of a comma-separated string."""
    return [tok.strip() for tok in raw.split(',') if tok.strip()]


def _print_summary(results):
    """Print the per-layer summary table and the averages below it."""
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    print(f"\n{'Layer':<12} {'Corr':<8} {'Cond#':<12} {'||ΣXX̂-ΣX̂||':<12} {'Asym':<8}")
    print("-" * 60)

    for r in results:
        # Rows use the same column widths as the header so they line up.
        label = f"{r['layer_id']}.{r['weight']}"
        print(f"{label:<12} "
              f"{r['correlation']:<8.4f} "
              f"{r['cond_number']:<12.2e} "
              f"{r['diff_X_hX_to_hX']:<12.4f} "
              f"{r['asym_X_hX']:.4f}")

    avg_corr = sum(r['correlation'] for r in results) / len(results)

    # Infinite / huge condition numbers are left out of the average. If none
    # remain, report NaN instead of dividing by zero.
    usable_conds = [r['cond_number'] for r in results if r['cond_number'] < 1e15]
    avg_cond = sum(usable_conds) / len(usable_conds) if usable_conds else float('nan')

    print("-" * 60)
    print(f"Average correlation: {avg_corr:.4f}")
    print(f"Average condition #: {avg_cond:.2e}")


def _print_red_flags(results):
    """Print the layers whose metrics cross the warning thresholds."""
    print("\n" + "=" * 70)
    print("RED FLAGS")
    print("=" * 70)

    low_corr = [r for r in results if r['correlation'] < 0.9]
    if low_corr:
        print("\nLow correlation (< 0.9) - X and X_hat are very different:")
        for r in low_corr:
            print(f"  Layer {r['layer_id']}.{r['weight']}: corr={r['correlation']:.4f}")

    high_cond = [r for r in results if r['cond_number'] > 1e10]
    if high_cond:
        print("\nHigh condition number (> 1e10) - numerical instability:")
        for r in high_cond:
            print(f"  Layer {r['layer_id']}.{r['weight']}: cond={r['cond_number']:.2e}")

    asymmetric = [r for r in results if r['asym_X_hX'] > 0.1]
    if asymmetric:
        print("\nHigh asymmetry in Sig_X_hX (> 0.1):")
        for r in asymmetric:
            print(f"  Layer {r['layer_id']}.{r['weight']}: asym={r['asym_X_hX']:.4f}")

    if not low_corr and not high_cond and not asymmetric:
        print("\nNo obvious red flags detected.")


def main():
    """Command-line entry point: analyze the requested layers and weights.

    Parses ``--qronos_run``, ``--layers`` and ``--weights``, loads and
    analyzes each (layer, weight) pair, then prints a summary table and a
    list of red flags. Pairs whose stats file is missing are skipped; any
    other per-pair failure is reported with a traceback and the run continues.
    """
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument('--qronos_run', type=str, required=True,
                        help='Path to Qronos quantization run directory')
    parser.add_argument('--layers', type=str, default='0,1,2,5,10,15',
                        help='Comma-separated layer IDs to analyze')
    parser.add_argument('--weights', type=str, default='wq,wk,wv,wo',
                        help='Comma-separated weight types')
    args = parser.parse_args()

    # Empty tokens (e.g. from a trailing comma) are ignored rather than
    # crashing int('').
    layer_ids = [int(tok) for tok in _split_csv(args.layers)]
    weights = _split_csv(args.weights)

    print("=" * 70)
    print("Qronos Statistics Diagnostic")
    print("=" * 70)
    print(f"Run directory: {args.qronos_run}")
    print(f"Layers: {layer_ids}")
    print(f"Weights: {weights}")

    results = []

    for layer_id in layer_ids:
        for weight in weights:
            try:
                data = load_qronos_stats(args.qronos_run, layer_id, weight)

                # The analysis runs in float64.
                Sig_X = data['Sig_X'].double()
                Sig_hX = data['Sig_hX'].double()
                Sig_X_hX = data['Sig_X_hX'].double()

                results.append(analyze_stats(Sig_X, Sig_hX, Sig_X_hX, layer_id, weight))

            except FileNotFoundError as e:
                print(f"\nSkipping layer {layer_id}.{weight}: {e}")
            except Exception as e:
                # Keep going: one bad layer should not abort the whole report.
                print(f"\nError analyzing layer {layer_id}.{weight}: {e}")
                traceback.print_exc()

    if results:
        _print_summary(results)
        _print_red_flags(results)


# Run the diagnostic only when this file is executed as a script.
if __name__ == '__main__':
    main()
