import copy
import torch
import pytest

from prune.hybrid_obs_pruner_certified import HybridOBSLinearPruner

# Show tensors in scientific notation with 6 decimals so small numeric
# differences are easy to read when printing tensors while debugging.
torch.set_printoptions(sci_mode=True, precision=6)

def make_calibrated_linear(seed=0, out_features=5, in_features=12, block_size=3, n_samples=256):
    """
    Create a small bias-free Linear layer and a calibrated HybridOBSLinearPruner.

    The weights (scaled by 0.1) and the calibration inputs X of shape
    (n_samples, in_features) (scaled by 0.5) are both drawn from one private
    generator seeded with ``seed``. Every full block of the input dimension
    is selected before calibration is run and finalized.

    NOTE: ``make_calibrated_linear`` is defined a second time later in this
    module. That later ``def`` rebinds the name, so the tests call that
    version and this one is effectively unused.

    Returns (layer, pruner).
    """
    rng = torch.Generator().manual_seed(seed)
    layer = torch.nn.Linear(in_features, out_features, bias=False)
    init = torch.randn(layer.weight.shape, generator=rng, dtype=layer.weight.dtype) * 0.1
    with torch.no_grad():
        layer.weight.copy_(init)

    pruner = HybridOBSLinearPruner(layer, block_size=block_size)
    # Floor division: a trailing partial block (if any) is not selected.
    full_blocks = list(range(in_features // block_size))
    pruner.set_selected_blocks(full_blocks)

    calib = torch.randn(n_samples, in_features, generator=rng) * 0.5  # X: (N, d)
    pruner.add_batch(calib)
    pruner.finalize_calibration(min_damping=1e-3, max_damping=1e2)
    return layer, pruner


def alt_importance_via_solve(pruner: "HybridOBSLinearPruner") -> torch.Tensor:
    """
    Alternate implementation of OBS scores that uses solve(G_pp, w) instead of cholesky_solve.
    Should match pruner.importance_all() very closely.

    Take each selected block b, in ``pruner.selected_blocks`` order. Let G_pp
    be its (k x k) diagonal block of ``pruner.G_CC`` and W_b its (out x k)
    weight slice. The score is then

        score[b] = 0.5 * sum over output rows w of  w^T G_pp^{-1} w

    The computation runs on the panel's device, in the panel's dtype promoted
    to at least float32. Weights stored in a different dtype or on a different
    device are therefore aligned with ``G_CC`` before solving.

    Returns:
        (B,) tensor of scores. If no blocks are selected, returns an empty tensor.
    """
    assert pruner.G_CC is not None and pruner.H_panel is not None
    G_CC = pruner.G_CC
    device = G_CC.device
    B = len(pruner.selected_blocks)
    if B == 0:
        return torch.empty(0, device=device)

    # Solve in the panel's precision, but never below float32.
    compute_dtype = torch.promote_types(G_CC.dtype, torch.float32)

    k = pruner.block_size * getattr(pruner, "kernel_elems", 1)
    out_rows = pruner.layer.weight.shape[0]

    # Gather W for all blocks in the selected order -> (out, B*k) -> (B, out, k)
    W2d = pruner.layer.weight.detach().flatten(1).to(device=device, dtype=compute_dtype)
    cols_order = [col for blk in pruner.selected_blocks for col in pruner.blocks_to_global_indices[blk]]
    W_all = W2d[:, cols_order]
    W_blocks = W_all.view(out_rows, B, k).permute(1, 0, 2).contiguous()  # (B, out, k)

    # Diagonal block G_PP for each block: (B, k, k), with G_pp[b] = G4[b, :, b, :]
    # (the two separated advanced indices put the B dimension first).
    G4 = G_CC.to(compute_dtype).reshape(B, k, B, k)
    blk_idx = torch.arange(B, device=device)
    G_pp = G4[blk_idx, :, blk_idx, :].contiguous()

    # Solve (G_pp) v = w   (batched over B*out right-hand sides; block-major order)
    G_batched = G_pp.repeat_interleave(out_rows, dim=0)          # (B*out, k, k)
    W_batched = W_blocks.reshape(B * out_rows, k, 1)             # (B*out, k, 1)
    try:
        V = torch.linalg.solve(G_batched, W_batched)             # (B*out, k, 1)
    except RuntimeError:
        # Singular / failing solve: fall back to the pseudo-inverse.
        V = torch.linalg.pinv(G_batched) @ W_batched

    # Any non-finite entries contribute zero to the score.
    V = torch.where(torch.isfinite(V), V, torch.zeros_like(V))
    Wb = W_batched.squeeze(-1)                                    # (B*out, k)
    Vb = V.squeeze(-1)                                            # (B*out, k)

    scores = 0.5 * (Wb * Vb).sum(dim=1)                           # (B*out,)
    return scores.view(B, out_rows).sum(dim=1).contiguous()       # (B,)


def clone_pruner(pr: HybridOBSLinearPruner) -> HybridOBSLinearPruner:
    """
    Deep-clone a pruner + layer state so we can apply different prune paths
    starting from *identical* weights and (G_CC, H_panel, bookkeeping).

    The fresh Linear layer matches the source layer's dtype, device and bias.
    A default-constructed Linear would use the default dtype and device and
    have no bias, so the copy could silently differ from the source.

    Panels (G_CC, H_panel, cols) are detached clones, so the two pruners never
    share storage. Any panel that is None stays None.

    Only the attributes assigned below are carried over. All other pruner
    state is whatever the HybridOBSLinearPruner constructor sets up.

    Args:
        pr: the calibrated pruner to copy.

    Returns:
        A new HybridOBSLinearPruner wrapping an independent copy of the layer.
    """
    src_weight = pr.layer.weight
    src_bias = getattr(pr.layer, "bias", None)

    # Clone the layer separately, preserving dtype / device / bias.
    layer2 = torch.nn.Linear(
        src_weight.shape[1],
        src_weight.shape[0],
        bias=src_bias is not None,
        device=src_weight.device,
        dtype=src_weight.dtype,
    )
    with torch.no_grad():
        layer2.weight.copy_(src_weight.detach())
        if src_bias is not None:
            layer2.bias.copy_(src_bias.detach())

    pr2 = HybridOBSLinearPruner(layer2, pr.block_size, use_chunking=pr.use_chunking, chunk_size=pr.chunk_size)

    # Copy bookkeeping
    pr2.selected_blocks = list(pr.selected_blocks)
    pr2.blocks_to_global_indices = copy.deepcopy(pr.blocks_to_global_indices)
    pr2.blocks_to_local_indices = copy.deepcopy(pr.blocks_to_local_indices)

    # Copy panels
    def _clone_opt(t):
        return None if t is None else t.detach().clone()

    pr2.G_CC = _clone_opt(pr.G_CC)
    pr2.H_panel = _clone_opt(pr.H_panel)
    pr2.cols = _clone_opt(pr.cols)

    return pr2


@pytest.mark.parametrize("seed", [0, 1, 2])
def test_importance_matches_alt_solve(seed):
    """The pruner's importance_all() must agree with the solve()-based reference scores."""
    pruner = make_calibrated_linear(
        seed=seed, out_features=6, in_features=12, block_size=3, n_samples=256
    )[1]

    actual = pr_scores = pruner.importance_all(return_tensor=True)
    reference = alt_importance_via_solve(pruner)

    # Only tiny numerical drift between the two solver paths is tolerated.
    torch.testing.assert_close(actual, reference, rtol=1e-4, atol=1e-6)
    del pr_scores


@pytest.mark.parametrize("seed", [0, 3])
def test_joint_prune_equals_rankk_stream(seed):
    """
    Jointly pruning two blocks must match rank-k streaming over the same blocks.

    Both pruners start from identical state via clone_pruner. Afterwards the
    weights, G_CC, H_panel and block bookkeeping must all agree.
    """
    # Build and calibrate, then clone so both paths start from identical state.
    _, pr_a = make_calibrated_linear(seed=seed, out_features=5, in_features=12, block_size=3, n_samples=192)
    pr_b = clone_pruner(pr_a)

    # Choose two blocks to prune jointly (both are full-sized):
    # blocks 0 and 2 are non-adjacent, which stresses the indexing.
    assert len(pr_a.selected_blocks) >= 3
    blk_ids = [pr_a.selected_blocks[0], pr_a.selected_blocks[2]]

    # Map block IDs to *local* positions within selected_blocks (-> [0, 2]).
    local_positions = [pr_a.selected_blocks.index(b) for b in blk_ids]
    device = pr_a.G_CC.device
    # One index tensor per path, so an in-place change made by one path
    # cannot leak into the other.
    B_loc_a = torch.tensor(local_positions, device=device, dtype=torch.long)
    B_loc_b = B_loc_a.clone()

    # --- Path A: joint multi-block update/downdate
    pr_a._apply_joint_update_and_downdate(B_loc_a)

    # --- Path B: rank-k streaming over the same union of blocks
    pr_b.prune_blocks_rank1_stream(B_loc_b)

    # --- Compare results
    torch.testing.assert_close(pr_a.layer.weight, pr_b.layer.weight, rtol=1e-5, atol=1e-6)
    torch.testing.assert_close(pr_a.G_CC, pr_b.G_CC, rtol=1e-5, atol=1e-6)
    torch.testing.assert_close(pr_a.H_panel, pr_b.H_panel, rtol=1e-5, atol=1e-6)

    assert pr_a.selected_blocks == pr_b.selected_blocks
    assert pr_a.blocks_to_global_indices.keys() == pr_b.blocks_to_global_indices.keys()
    assert pr_a.blocks_to_local_indices == pr_b.blocks_to_local_indices


@pytest.mark.parametrize("seed", [4])
def test_single_block_joint_vs_single(seed):
    """
    Sanity check: a joint prune over exactly one block equals the single-block prune.
    """
    _, joint = make_calibrated_linear(seed=seed, out_features=4, in_features=12, block_size=3, n_samples=128)
    single = clone_pruner(joint)

    # Target the first selected block (local position 0).
    first_block_id = joint.selected_blocks[0]
    only_first = torch.tensor([0], device=joint.G_CC.device, dtype=torch.long)

    joint._apply_joint_update_and_downdate(only_first)       # joint path, |B| = 1
    single._prune_block(local_idx=0, block_id=first_block_id)  # single-block path

    tol = {"rtol": 1e-5, "atol": 1e-6}
    torch.testing.assert_close(joint.layer.weight, single.layer.weight, **tol)
    for panel in ("G_CC", "H_panel"):
        torch.testing.assert_close(getattr(joint, panel), getattr(single, panel), **tol)

    assert joint.selected_blocks == single.selected_blocks
    assert joint.blocks_to_local_indices == single.blocks_to_local_indices


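# ----- Utilities -----------------------------------------------------------
# NOTE: this second make_calibrated_linear rebinds the module-level name. It
# therefore replaces the definition near the top of the file for every test,
# including the tests defined above it, which look the name up at call time.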
def make_calibrated_linear(seed=0, out_features=6, in_features=12, block_size=3, n_samples=256):
    """
    Build a tiny bias-free Linear layer and its pruner, select all full blocks,
    calibrate on random inputs, and finalize.

    The weights are drawn from N(0, 0.1^2) using a generator seeded with ``seed``.
    The calibration inputs X of shape (n_samples, in_features) are
    0.5 * N(0, 1), drawn from a generator seeded with ``seed + 1``.

    Private generators are used instead of ``torch.manual_seed``, so building a
    fixture never resets the process-wide RNG that other tests draw from. On
    CPU they yield the same streams as seeding the global generator with the
    same values.

    NOTE: this definition rebinds the name, replacing the earlier
    ``make_calibrated_linear`` in this module; all tests use this one.

    Returns (layer, pruner).
    """
    layer = torch.nn.Linear(in_features, out_features, bias=False)
    weight_gen = torch.Generator().manual_seed(seed)
    with torch.no_grad():
        layer.weight.normal_(mean=0.0, std=0.1, generator=weight_gen)

    pruner = HybridOBSLinearPruner(layer, block_size=block_size)
    # Select only full blocks (floor division skips a trailing partial block).
    n_blocks_full = in_features // block_size
    pruner.set_selected_blocks(list(range(n_blocks_full)))

    data_gen = torch.Generator().manual_seed(seed + 1)
    X = torch.randn(n_samples, in_features, generator=data_gen) * 0.5  # (N, d)
    pruner.add_batch(X)
    pruner.finalize_calibration(min_damping=1e-3, max_damping=1e2)

    return layer, pruner


# ----- Tests for certify_batch --------------------------------------------

@pytest.mark.parametrize("seed", [0, 1, 2])
def test_certify_batch_basic_validity(seed):
    """certify_batch returns a non-empty, duplicate-free long tensor of valid local indices."""
    _, pruner = make_calibrated_linear(seed=seed, out_features=6, in_features=12, block_size=3, n_samples=256)

    batch = pruner.certify_batch(
        pruner.importance_all(return_tensor=True),
        max_try=64,
        window=32,
        outsiders_k=3,
        slack=1e-3,
    )

    # Type, placement, and non-emptiness.
    assert isinstance(batch, torch.Tensor)
    assert batch.dtype == torch.long
    assert batch.device == pruner.G_CC.device
    assert batch.numel() >= 1

    # No duplicates, and every index addresses a selected block.
    picked = batch.tolist()
    assert len(set(picked)) == batch.numel()
    assert min(picked) >= 0
    assert max(picked) < len(pruner.selected_blocks)


@pytest.mark.parametrize("seed", [3, 4])
def test_certify_batch_margin_is_respected(seed):
    """
    Recreate the margin check from certify_batch for the returned set.
    LHS = max_i in top_r  [ score[i] + bi + slack ]
    RHS = min_j in outsiders [ score[j] - bj - slack ]
    Require LHS < RHS.

    Both bounds come from pr._score_drift_upper:
    - bi is evaluated against the rest of the batch (B_loc without i);
    - bj is evaluated against the whole batch B_loc.
    """
    _, pr = make_calibrated_linear(seed=seed, out_features=6, in_features=12, block_size=3, n_samples=256)
    scores = pr.importance_all(return_tensor=True)

    outsiders_k = 3
    slack = 1e-3
    B_loc = pr.certify_batch(scores, max_try=64, window=32, outsiders_k=outsiders_k, slack=slack)
    r = len(B_loc)
    # Fail clearly on an empty batch; otherwise max() below raises ValueError.
    assert r >= 1, "certify_batch returned an empty batch"

    # top_r (the r lowest scores) and the next outsiders_k, as in the
    # implementation. Slicing clamps at len(scores) by itself.
    order = torch.argsort(scores)
    top_r = order[:r]
    outsiders = order[r : r + outsiders_k]

    # Compute the bi/bj drift bounds with the provided helper.
    device = pr.G_CC.device
    lhs_terms = []
    for i in top_r.tolist():
        rest = B_loc[B_loc != i].to(device)
        bi = 0.0 if rest.numel() == 0 else float(pr._score_drift_upper(i, rest))
        lhs_terms.append(scores[i].item() + bi + slack)
    lhs = max(lhs_terms)

    if outsiders.numel() == 0:
        # No outsiders: the margin condition holds vacuously.
        return

    rhs = min(
        scores[j].item() - float(pr._score_drift_upper(j, B_loc)) - slack
        for j in outsiders.tolist()
    )
    assert lhs < rhs, f"cert margin violated: lhs={lhs:.3e} rhs={rhs:.3e} (r={r})"


@pytest.mark.parametrize("seed", [5, 6])
def test_certify_batch_slack_sensitivity(seed):
    """
    A larger slack must never yield a larger certified batch.
    """
    _, pr = make_calibrated_linear(seed=seed, out_features=6, in_features=12, block_size=3, n_samples=256)
    scores = pr.importance_all(return_tensor=True)

    # Certify with a tight slack first, then a loose one.
    n_tight, n_loose = (
        len(pr.certify_batch(scores, max_try=128, window=32, outsiders_k=3, slack=s))
        for s in (1e-4, 5e-3)
    )

    assert n_tight >= 1
    assert n_loose >= 1
    assert n_tight >= n_loose


@pytest.mark.parametrize("seed", [7])
def test_certify_batch_repeatability(seed):
    """
    certify_batch is deterministic: identical inputs yield an identical certified set.
    """
    _, pr = make_calibrated_linear(seed=seed, out_features=6, in_features=12, block_size=3, n_samples=256)
    scores = pr.importance_all(return_tensor=True)

    settings = {"max_try": 64, "window": 16, "outsiders_k": 2, "slack": 1e-3}
    first, second = (pr.certify_batch(scores, **settings) for _ in range(2))
    assert torch.equal(first, second)
