
import sys
from typing import List, Tuple

import torch
import torch.nn as nn


# =============================================================================
# Vanilla OBS (singleton columns) - dependency-free reference
# =============================================================================

def _stable_cholesky_inv(A: torch.Tensor) -> torch.Tensor:
    """Return inv(A) using Cholesky with small jitter. Fallback to pinv if needed."""
    A = 0.5 * (A + A.T)
    mean_diag = float(torch.diagonal(A).abs().mean().item()) if A.numel() > 0 else 1.0
    jitter = max(1e-12 * max(1.0, mean_diag), 1e-12)
    eye = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
    for _ in range(6):
        try:
            L = torch.linalg.cholesky(A + jitter * eye)
            return torch.cholesky_inverse(L, upper=False)
        except RuntimeError:
            jitter *= 10.0
    return torch.linalg.pinv(A)


def build_H_and_G(X: torch.Tensor, lam: float = 1e-2, *, lam_is_relative: bool = True, x_dtype=torch.bfloat16) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Form the empirical Hessian H = (X^T X)/N and its damped inverse
    G = (H + λI)^{-1}.

    - X is cast to `x_dtype` (bfloat16 by default) *before* the Gram matrix is
      formed; the original note says this is meant to mimic the pruner's
      Linear preprocessor.
    - N is the number of rows of X (at least 1, so an empty batch does not
      divide by zero).
    - With `lam_is_relative=True` the damping is λ = lam * mean(diag(H));
      otherwise λ = lam is used as-is.

    Returns (H, G); H is float64 and G is computed from the float64 H.
    """
    samples = X.to(x_dtype)
    n_rows = max(1, int(samples.shape[0]))
    gram = (samples.T @ samples) / float(n_rows)

    # Symmetrise while still in x_dtype, then promote to float64.
    H = (0.5 * (gram + gram.T)).to(torch.float64)

    ridge = float(lam)
    if lam_is_relative:
        ridge = float(torch.diagonal(H).mean().item()) * ridge

    damped = H + ridge * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    return H, _stable_cholesky_inv(damped)


def obs_scores_singleton(W: torch.Tensor, G_local: torch.Tensor, alive: List[int]) -> torch.Tensor:
    """
    Per-column OBS saliency for block_size=1, computed from the inverse Hessian:
        s(j) = 0.5 * ||W[:, j]||_2^2 / G_local[p, p]
    where j is a global column index taken from `alive` and p is the position
    of j inside `alive` (i.e. its index in the local inverse).

    Diagonal entries whose magnitude is not above 1e-30 are replaced by 1e-30.
    Returns a float64 tensor on W's device, ordered like `alive`.
    """
    W64 = W.to(torch.float64)

    def _score(pos: int, col: int):
        # Local diagonal entry of the inverse, guarded against (near-)zero.
        g_diag = float(G_local[pos, pos].item())
        if not abs(g_diag) > 1e-30:
            g_diag = 1e-30
        column = W64[:, col]
        return 0.5 * (column @ column) / g_diag

    per_column = [_score(pos, col) for pos, col in enumerate(alive)]
    return torch.tensor(per_column, dtype=torch.float64, device=W.device)


def obs_update_singleton_mapped(W: torch.Tensor, G_local: torch.Tensor, alive: List[int], j_local: int) -> None:
    """
    Apply the exact single-column OBS weight update in place.

    `G_local` is the current inverse restricted to the active coordinates
    listed (in order) in `alive`; `j_local` is the position within `alive` of
    the column being removed. For every row w of W:
        Δw = -(w_j / G[j,j]) * G[:, j]
    with G[:, j] scattered to the global column positions given by `alive`
    (columns not in `alive` receive no change).

    The arithmetic runs in float64 and the result is written back into W in
    W's own dtype. Returns None.
    """
    G64 = G_local.to(torch.float64)
    W_work = W.to(torch.float64)

    pruned_col = alive[j_local]

    # Pivot G[j,j], guarded against (near-)zero.
    pivot = float(G64[j_local, j_local].item())
    if not abs(pivot) > 1e-30:
        pivot = 1e-30

    # Column j of the local inverse, expanded to the full input width.
    spread = torch.zeros(W.shape[1], dtype=torch.float64, device=W.device)
    spread[alive] = G64[:, j_local]

    coeff = -W_work[:, pruned_col] / pivot            # one coefficient per output row
    W_work.add_(coeff[:, None] * spread[None, :])     # rank-1 correction, (out, d)
    W.copy_(W_work.to(W.dtype))


def schur_downdate_inverse_singleton_local(G_local: torch.Tensor, j_local: int) -> torch.Tensor:
    """
    Remove one coordinate from an inverse via a rank-1 Schur downdate:
        G'_RR = G_RR - G_RP (1/G_PP) G_PR,   P = {j_local}, R = everything else.

    The result is the inverse over the survivors R, in the same order as the
    old coordinates with `j_local` dropped. Computation is done in float64 and
    the result is cast back to `G_local`'s dtype. A pivot G_PP whose magnitude
    is not above 1e-30 is replaced by 1e-30.
    """
    G64 = G_local.to(torch.float64)

    positions = torch.arange(G64.shape[0], device=G64.device)
    keep = positions[positions != j_local]

    pivot = float(G64[j_local, j_local].item())
    if not abs(pivot) > 1e-30:
        pivot = 1e-30

    col = G64[keep, j_local:j_local + 1]      # (d-1, 1)  -> G_RP
    row = G64[j_local:j_local + 1, keep]      # (1, d-1)  -> G_PR
    survivors_block = G64[keep][:, keep].clone()   # (d-1, d-1) -> G_RR

    downdated = survivors_block - (col @ row) / pivot
    return downdated.to(G_local.dtype)


def greedy_obs_prune_linear(layer: nn.Linear, X: torch.Tensor, steps: int, lam: float = 1e-2) -> Tuple[List[int], torch.Tensor]:
    """
    Greedy K-step vanilla OBS for a Linear(out,d) with bias=False and block_size=1.

    At each step the alive column with the smallest score
        s_j = 0.5 * ||w_j||^2 / G_jj      (G = inverse Hessian over alive columns)
    is pruned: the exact OBS update is applied to the weights, the pruned
    column is zeroed, and the inverse is Schur-downdated to the survivors.

    Args:
        layer: Linear layer without bias. Its weight is modified *in place*.
        X: calibration inputs, passed to `build_H_and_G` (cast to bfloat16
           there, with relative damping `lam`).
        steps: number of columns to prune. If it exceeds the number of input
           columns, pruning stops once every column has been pruned instead of
           failing on an empty candidate set.
        lam: relative damping factor forwarded to `build_H_and_G`.

    Returns:
        (pruned_global_order, final_weight_copy): global column indices in the
        order they were pruned, and a detached clone of the final weights.

    Raises:
        AssertionError: if the layer has a bias.
    """
    assert layer.bias is None, "This helper assumes bias=False"
    W = layer.weight.data
    d = W.shape[1]
    _, G = build_H_and_G(X, lam=lam, lam_is_relative=True, x_dtype=torch.bfloat16)  # mimic pruner calibration
    alive = list(range(d))
    pruned: List[int] = []

    for _ in range(steps):
        if not alive:
            # Candidates exhausted (steps > d): argmin over an empty score
            # tensor would raise, so stop here.
            break

        scores = obs_scores_singleton(W, G, alive)      # aligned with `alive`
        i_min = int(torch.argmin(scores).item())        # local index in alive
        j_global = alive[i_min]
        pruned.append(j_global)

        obs_update_singleton_mapped(W, G, alive, i_min) # true OBS update
        W[:, j_global] = 0                              # zero pruned column exactly
        G = schur_downdate_inverse_singleton_local(G, i_min)  # update inverse
        alive.pop(i_min)                                # keep `alive` aligned with G

    return pruned, layer.weight.detach().clone()


def joint_obs_prune_linear(layer: nn.Linear, X: torch.Tensor, P: List[int], lam: float = 1e-2) -> torch.Tensor:
    """
    One-shot (joint) OBS removal of the global column set P:
        ΔW_R = - (W_P @ (G_PP^{-1})^T) @ G_RP^T,  followed by W[:,P] = 0,
    where R is the complement of P and G is the damped inverse Hessian from
    `build_H_and_G` (bfloat16 calibration, relative damping `lam`).

    G_PP is solved via Cholesky on its symmetrised, lightly jittered form;
    if that raises a RuntimeError the pseudo-inverse of G_PP is used instead.
    The layer's weight is modified in place; a detached clone is returned.
    Requires a layer without bias (asserted).
    """
    assert layer.bias is None, "This helper assumes bias=False"
    weight = layer.weight.data
    n_cols = weight.shape[1]
    _, G = build_H_and_G(X, lam=lam, lam_is_relative=True, x_dtype=torch.bfloat16)

    dev = weight.device
    f64 = torch.float64

    # Survivor indices R = all columns not listed in P.
    keep_mask = torch.ones(n_cols, dtype=torch.bool, device=dev)
    keep_mask[P] = False
    R = torch.arange(n_cols, device=dev)[keep_mask]

    W_P = weight[:, P].to(f64)               # (out, |P|)
    G_PP = G[P][:, P].to(f64)                # (|P|, |P|)
    try:
        jittered = 0.5 * (G_PP + G_PP.T) + 1e-12 * torch.eye(G_PP.shape[0], device=dev, dtype=f64)
        chol = torch.linalg.cholesky(jittered)
        solved = torch.cholesky_solve(W_P.T.contiguous(), chol, upper=False)   # (|P|, out)
        coeffs = solved.T.contiguous()                                         # (out, |P|)
    except RuntimeError:
        coeffs = (torch.linalg.pinv(G_PP) @ W_P.T).T

    G_RP = G[R][:, P].to(f64)                # (|R|, |P|)
    correction = -coeffs @ G_RP.T            # (out, |R|)
    weight[:, R] += correction.to(weight.dtype)
    weight[:, P] = 0

    return layer.weight.detach().clone()


# =============================================================================
# Helper to prune K blocks via repeated certify_batch calls
# =============================================================================

def prune_k_via_certifier(pr, K: int) -> None:
    """
    Drive a pruner object until K blocks are pruned.

    Each round re-scores the current panel, asks `pr.certify_batch` for at
    most the remaining budget, and applies the returned indices through
    `pr.apply_joint_update_and_downdate`. The loop ends when K blocks are
    pruned, when no selected blocks remain, or when a round certifies nothing.

    `pr` is expected to expose get_pruned_blocks, get_selected_blocks,
    importance_all, certify_batch and apply_joint_update_and_downdate
    (presumably a HybridOBSLinearPruner — confirm against the caller).
    """
    while True:
        if len(pr.get_pruned_blocks()) >= K or len(pr.get_selected_blocks()) <= 0:
            return

        # Fresh scores for the current panel; every selected block is a candidate.
        panel_scores = pr.importance_all(return_tensor=True)
        all_candidates = torch.arange(len(pr.get_selected_blocks()))

        # Never request more than what is still missing or still available.
        remaining = K - len(pr.get_pruned_blocks())
        quota = min(remaining, len(pr.get_selected_blocks()))
        if quota <= 0:
            return

        # Early-stop criteria are neutralised so each round makes progress.
        certify_options = dict(
            candidates=all_candidates,
            max_try=quota,
            budget=quota,
            rel_worsen_tol=1e9,    # effectively disables the worsen check
            min_chol_diag=0.0,     # effectively disables the conditioning check
            time_budget_ms=None,   # no time limit
            device="cpu",
        )
        chosen = pr.certify_batch(panel_scores, **certify_options)

        if not chosen:
            # Nothing was certified this round: stop rather than spin.
            return

        pr.apply_joint_update_and_downdate(torch.tensor(chosen, dtype=torch.long))


# =============================================================================
# Tests (pytest-compatible)
# =============================================================================

def _setup_linear_and_data(d=20, out=20, n=400, seed=0):
    g = torch.Generator().manual_seed(seed)
    layer = nn.Linear(d, out, bias=False)
    X = torch.randn(n, d, generator=g)
    return layer, X


def test_vanilla_obs_greedy_equals_joint():
    """Greedy column-by-column OBS must agree with one joint OBS solve on the same set."""
    torch.manual_seed(0)
    n_in, n_out, n_samples, n_prune = 20, 20, 400, 10
    ridge = 1e-2

    # Two layers with identical starting weights: one per pruning path.
    greedy_layer, X = _setup_linear_and_data(n_in, n_out, n_samples, seed=123)
    joint_layer = nn.Linear(n_in, n_out, bias=False)
    with torch.no_grad():
        joint_layer.weight.copy_(greedy_layer.weight)

    # The joint path prunes exactly the columns the greedy path picked.
    picked, W_greedy = greedy_obs_prune_linear(greedy_layer, X, steps=n_prune, lam=ridge)
    W_joint = joint_obs_prune_linear(joint_layer, X, P=picked, lam=ridge)

    gap = (W_greedy - W_joint).abs()
    assert gap.max().item() < 1e-5
    assert gap.norm().item() < 1e-4


def test_hybrid_certify_joint_matches_vanilla_greedy():
    """
    Integration test: the certifier-driven pruner must prune the same column
    set as vanilla greedy OBS and end with (nearly) the same weights.

    HybridOBSLinearPruner is looked up under several module paths; if none of
    them can be imported the test is skipped via pytest.skip.
    """
    import importlib

    # Try multiple import paths; skip if pruner not found.
    HybridOBSLinearPruner = None
    for module_name in (
        "prune.hybrid_obs_pruner_certified",
        "hybrid_obs_pruner_certified",
        "prune.obs_hybrid_pruner",
        "obs_hybrid_pruner",
    ):
        try:
            HybridOBSLinearPruner = getattr(importlib.import_module(module_name), "HybridOBSLinearPruner")
        except Exception:
            # Best-effort lookup: any import-time failure just moves on to the next path.
            continue
        break
    if HybridOBSLinearPruner is None:
        import pytest  # type: ignore
        pytest.skip("HybridOBSLinearPruner not importable; skipping integration test.")

    torch.manual_seed(1)
    d, out, n, k = 20, 20, 400, 10
    lam = 1e-2

    # layer for vanilla greedy
    L_van, X = _setup_linear_and_data(d, out, n, seed=7)
    # identical layer for pruner path
    L_prn = nn.Linear(d, out, bias=False)
    with torch.no_grad():
        L_prn.weight.copy_(L_van.weight)

    # Vanilla greedy K-step
    order_g, Wg = greedy_obs_prune_linear(L_van, X, steps=k, lam=lam)

    # Setup user's pruner
    pr = HybridOBSLinearPruner(L_prn, block_size=1, device="cpu")
    pr.set_selected_blocks(list(range(d)))
    pr.add_batch(X)               # accumulate calibration statistics
    pr.finalize_calibration()     # finish calibration before scoring

    # Loop certify+apply until exactly k are pruned (or panel exhausted)
    prune_k_via_certifier(pr, K=k)

    # Compare sets (order may differ)
    set_van = set(order_g)
    set_prn = set(pr.get_pruned_blocks()[:k])
    assert set_van == set_prn, f"Different pruned sets: vanilla={sorted(set_van)}, pruner={sorted(set_prn)}"

    # Compare final weights
    Wp = L_prn.weight.detach().clone()
    diff = (Wg - Wp).abs()
    assert diff.max().item() < 1e-3, f"max diff {diff.max().item():.3e}"
    assert diff.norm().item() < 1e-2, f"l2 diff {diff.norm().item():.3e}"

def test_debug_compare_obs_vs_certifier_report(capfd):
    """
    Diagnostic test:
    - prints the pruned orders (vanilla vs certifier)
    - asserts the pruned sets match
    - computes survivor-only weight diffs and asserts they're small

    Skipped (instead of erroring) when HybridOBSLinearPruner cannot be
    imported from any of the known module paths.
    """
    import importlib

    # ---- config (matches other tests) ----
    torch.manual_seed(1)
    d, out, n, k = 20, 20, 400, 10
    lam = 1e-2

    # ---- build layer + data ----
    L_van, X = _setup_linear_and_data(d, out, n, seed=7)
    L_prn = nn.Linear(d, out, bias=False)
    with torch.no_grad():
        L_prn.weight.copy_(L_van.weight)

    # ---- vanilla greedy (reference) ----
    order_g, Wg = greedy_obs_prune_linear(L_van, X, steps=k, lam=lam)

    # ---- pruner path (fallback imports) ----
    HybridOBSLinearPruner = None
    for module_name in (
        "prune.hybrid_obs_pruner_certified",
        "hybrid_obs_pruner_certified",
        "prune.obs_hybrid_pruner",
        "obs_hybrid_pruner",
    ):
        try:
            HybridOBSLinearPruner = getattr(importlib.import_module(module_name), "HybridOBSLinearPruner")
        except Exception:
            # Best-effort lookup: try the next candidate path.
            continue
        break
    if HybridOBSLinearPruner is None:
        import pytest  # type: ignore
        pytest.skip("HybridOBSLinearPruner not importable; skipping diagnostic test.")

    pr = HybridOBSLinearPruner(L_prn, block_size=1, device="cpu")
    pr.set_selected_blocks(list(range(d)))
    pr.add_batch(X)
    pr.finalize_calibration()

    # prune exactly k via repeated certify_batch
    prune_k_via_certifier(pr, K=k)
    order_c = pr.get_pruned_blocks()[:k]

    # ---- compare orders / sets ----
    set_g, set_c = set(order_g), set(order_c)
    survivors = sorted(set(range(d)) - set_g)

    print("\n==== PRUNED BLOCKS (ORDERS) ====")
    print("Vanilla greedy order:     ", order_g)
    print("Certifier+joint order:    ", order_c)
    print("Same set?                 ", set_g == set_c)
    if set_g != set_c:
        print("Only in vanilla:          ", sorted(set_g - set_c))
        print("Only in certifier:        ", sorted(set_c - set_g))

    # Strong assertion: sets should match
    assert set_g == set_c, "Pruned sets differ between vanilla and certifier."

    # ---- survivor-only weight diffs ----
    Wp = L_prn.weight.detach().clone()
    if survivors:
        Gs = Wg[:, survivors]
        Ps = Wp[:, survivors]
        D = (Gs - Ps).abs()
        max_abs = float(D.max().item())
        mean_abs = float(D.mean().item())
        l2 = float(torch.linalg.norm((Gs - Ps).reshape(-1)).item())

        # Per-column L2 difference; report the worst (up to 10) survivor columns.
        col_l2 = torch.linalg.norm(Gs - Ps, dim=0)
        topk = min(10, col_l2.numel())
        vals, idx = torch.topk(col_l2, k=topk, largest=True)
        top_cols = [(int(survivors[int(i.item())]), float(v.item())) for v, i in zip(vals, idx)]

        print("\n==== WEIGHT DIFFS (SURVIVORS ONLY) ====")
        print(f"Survivor count:           {len(survivors)}")
        print(f"Max |ΔW|:                 {max_abs:.6e}")
        print(f"Mean |ΔW|:                {mean_abs:.6e}")
        print(f"L2 ||ΔW||:                {l2:.6e}")
        print("\nTop survivor columns by L2 diff (global_col, l2):")
        for gc, v in top_cols:
            print(f"  col {gc:>2}: {v:.6e}")

        # Tolerances: allow tiny drift from bf16 numerics
        assert max_abs < 1e-3
        assert l2 < 1e-2
    else:
        print("\nNo survivors (k == d), skipping survivor diff checks.")

    # Drain capfd's buffer.
    # NOTE(review): readouterr() returns and clears the captured text, so the
    # report printed above is consumed here rather than shown — confirm intended.
    capfd.readouterr()

# =============================================================================
# Main runner
# =============================================================================

def _run_smoke():
    """
    Minimal pytest-free smoke run.

    1. Prunes k columns greedily, repeats the same set with the joint solver,
       and prints the weight discrepancy between the two.
    2. Best-effort: runs the certifier integration test and prints PASS or the
       reason it was skipped/failed.
    3. Best-effort: prints the detailed vanilla-vs-certifier debug report.

    Steps 2 and 3 never abort the smoke run; their failures are printed.
    """
    print("Running smoke checks...")
    torch.manual_seed(7)
    d, out, n, k = 20, 20, 400, 10
    lam = 1e-2

    Lg, X = _setup_linear_and_data(d, out, n, seed=7)
    Lj = nn.Linear(d, out, bias=False)
    with torch.no_grad():
        Lj.weight.copy_(Lg.weight)

    order_g, Wg = greedy_obs_prune_linear(Lg, X, steps=k, lam=lam)
    Wj = joint_obs_prune_linear(Lj, X, P=order_g, lam=lam)

    max_abs = float((Wg - Wj).abs().max().item())
    l2 = float(torch.norm((Wg - Wj).flatten()).item())

    print("Greedy order:", order_g)
    print(f"Max |ΔW| (greedy vs joint with same P): {max_abs:.3e}")
    print(f"L2(ΔW): {l2:.3e}")

    # Integration (best-effort)
    try:
        test_hybrid_certify_joint_matches_vanilla_greedy()
        print("Integration vs. certify_batch: PASS")
    except (KeyboardInterrupt, SystemExit):
        raise
    except BaseException as e:
        # pytest.skip() signals through an exception that derives from
        # BaseException, so `except Exception` alone would let a skip escape.
        print("Integration vs. certify_batch: SKIPPED or FAILED:", repr(e))

    # Debug report (best-effort): it needs the optional pruner module as well,
    # so a missing pruner must not crash the smoke run.
    try:
        debug_compare_obs_vs_certifier(d=d, out=out, n=n, k=k, lam=lam, seed_data=7, seed_torch=1)
    except Exception as e:
        print("Debug comparison vs. certify_batch: SKIPPED or FAILED:", repr(e))

def debug_compare_obs_vs_certifier(*, d=20, out=20, n=400, k=10, lam=1e-2, seed_data=7, seed_torch=1):
    """
    Run vanilla greedy OBS vs. the certifier loop and print:
    - orders picked
    - pruned sets
    - weight diffs restricted to surviving columns
    - per-column L2 diff on survivors (top 10)

    All arguments are keyword-only: layer/data sizes (d, out, n), number of
    columns to prune (k), relative damping (lam), the data seed and the global
    torch seed. Nothing is returned; the report goes to stdout.

    Raises:
        ImportError: if HybridOBSLinearPruner cannot be imported from any of
            the known module paths.
    """
    import importlib

    torch.manual_seed(seed_torch)

    # ----- build layer + data -----
    L_van, X = _setup_linear_and_data(d, out, n, seed=seed_data)
    L_prn = nn.Linear(d, out, bias=False)
    with torch.no_grad():
        L_prn.weight.copy_(L_van.weight)

    # ----- vanilla greedy (reference) -----
    order_g, Wg = greedy_obs_prune_linear(L_van, X, steps=k, lam=lam)

    # ----- pruner path: import with fallbacks -----
    candidates = (
        "prune.hybrid_obs_pruner_certified",
        "hybrid_obs_pruner_certified",
        "prune.obs_hybrid_pruner",
        "obs_hybrid_pruner",
    )
    HybridOBSLinearPruner = None
    last_err = None
    for module_name in candidates:
        try:
            HybridOBSLinearPruner = getattr(importlib.import_module(module_name), "HybridOBSLinearPruner")
        except Exception as err:
            # Remember the failure and try the next candidate path.
            last_err = err
            continue
        break
    if HybridOBSLinearPruner is None:
        raise ImportError(
            "HybridOBSLinearPruner not importable from any of: " + ", ".join(candidates)
        ) from last_err

    pr = HybridOBSLinearPruner(L_prn, block_size=1, device="cpu")
    pr.set_selected_blocks(list(range(d)))
    pr.add_batch(X)
    pr.finalize_calibration()

    prune_k_via_certifier(pr, K=k)
    order_c = pr.get_pruned_blocks()[:k]

    # ----- report orders and sets -----
    set_g, set_c = set(order_g), set(order_c)
    # Survivors are defined by the vanilla result; identical for the certifier
    # only when the two pruned sets match.
    survivors = sorted(set(range(d)) - set_g)

    print("\n==== PRUNED BLOCKS (ORDERS) ====")
    print("Vanilla greedy order:     ", order_g)
    print("Certifier+joint order:    ", order_c)
    print("Same set?                 ", set_g == set_c)
    if set_g != set_c:
        print("Only in vanilla:          ", sorted(set_g - set_c))
        print("Only in certifier:        ", sorted(set_c - set_g))

    # ----- weight differences on survivors only -----
    Wp_final = L_prn.weight.detach().clone()

    if survivors:
        # restrict to surviving columns
        Gs = Wg[:, survivors]
        Ps = Wp_final[:, survivors]
        D = (Gs - Ps).abs()
        max_abs = float(D.max().item())
        mean_abs = float(D.mean().item())
        l2 = float(torch.linalg.norm((Gs - Ps).reshape(-1)).item())

        # per-survivor-column L2 diffs
        col_l2 = torch.linalg.norm(Gs - Ps, dim=0)  # (|survivors|,)
        topk = min(10, col_l2.numel())
        vals, idx = torch.topk(col_l2, k=topk, largest=True)
        top_cols = [(int(survivors[int(i.item())]), float(v.item())) for v, i in zip(vals, idx)]

        print("\n==== WEIGHT DIFFS (SURVIVORS ONLY) ====")
        print(f"Survivor count:           {len(survivors)}")
        print(f"Max |ΔW|:                 {max_abs:.6e}")
        print(f"Mean |ΔW|:                {mean_abs:.6e}")
        print(f"L2 ||ΔW||:                {l2:.6e}")

        print("\nTop survivor columns by L2 diff (global_col, l2):")
        for gc, v in top_cols:
            print(f"  col {gc:>2}: {v:.6e}")
    else:
        print("\nNo survivors (k == d), skipping survivor diff report.")

if __name__ == "__main__":
    # If invoked directly:
    # - If pytest is installed, run the tests via pytest and exit with its status
    # - Otherwise, run a simple smoke test (which also prints the debug report)
    try:
        import pytest  # type: ignore
    except ImportError:
        _run_smoke()
    else:
        # Kept outside the try so a failure inside pytest is not mistaken for
        # "pytest missing" and silently replaced by the smoke run.
        sys.exit(pytest.main([__file__]))
