import torch
import torch.nn as nn

from spastra.stats import ModuleStatsCollector, StatConfig


class Tiny(nn.Module):
    """Minimal conv + linear network used as the model under test.

    A bias-free 3x3 convolution (3 -> 4 channels) is followed by a mean over
    dimensions 2 and 3 and a bias-free linear layer (4 -> 2 features).
    """

    def __init__(self):
        super().__init__()
        # Layers are created conv-first, then lin, so that seeded parameter
        # initialisation consumes the RNG in a fixed order.
        self.conv = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=3, bias=False
        )
        self.lin = nn.Linear(in_features=4, out_features=2, bias=False)

    def forward(self, x):
        """Return ``lin`` applied to the mean of ``conv(x)`` over dims 2 and 3."""
        pooled = self.conv(x).mean(dim=(2, 3))
        return self.lin(pooled)


def test_stats_sparsity_delta_bitflip():
    """Smoke-test the sparsity, delta and bitflip stats of the collector.

    The collector is primed on a freshly seeded ``Tiny`` model, a few leading
    weights of each layer are zeroed, and the first ``update`` is inspected.
    A second ``update`` follows a uniform ``1e-6`` shift of all weights.
    """
    torch.manual_seed(0)
    model = Tiny()

    # The same layer types are tracked for every statistic.
    tracked = (nn.Conv2d, nn.Linear)
    collector = ModuleStatsCollector(
        StatConfig(
            sparsity_types=tracked,
            delta_types=tracked,
            bitflip_types=tracked,
            eps=1e-8,
        )
    )

    def bitflip_total(stats, field):
        # Sum one integer field across the dict-valued bitflip entries;
        # non-dict entries are ignored.
        return sum(
            int(entry.get(field, 0))
            for entry in stats['bitflip'].values()
            if isinstance(entry, dict)
        )

    # Prime before touching the weights so that the edits below happen
    # after the collector has seen the initial model.
    collector.prime(model)

    # Zero the first few flattened weights of each layer to create sparsity.
    with torch.no_grad():
        for layer, count in ((model.conv, 3), (model.lin, 2)):
            layer.weight.view(-1)[:count] = 0

    first = collector.update(model)

    # Both layers must be reported under a sparsity key ending in their name.
    sparsity_names = tuple(first['sparsity'].keys())
    for suffix in ('conv', 'lin'):
        assert any(name.endswith(suffix) for name in sparsity_names)

    # Accumulate every delta value that converts to float; values that do
    # not convert are skipped on purpose.
    # NOTE(review): the weights changed after priming, so a strictly positive
    # delta is presumably expected, but only non-negativity is asserted here.
    delta_sum = 0.0
    for raw in first['delta'].values():
        try:
            as_float = float(raw)  # type: ignore[arg-type]
        except Exception:
            continue
        delta_sum += as_float
    assert delta_sum >= 0

    # Some entries were zeroed after priming, so the summed 'leaves' field is
    # expected to be positive (presumably counting zero/non-zero transitions;
    # confirm against the collector implementation).
    assert bitflip_total(first, 'leaves') > 0

    # Shift every weight by a tiny amount before the second update.
    with torch.no_grad():
        for layer in (model.conv, model.lin):
            layer.weight.add_(1e-6)

    second = collector.update(model)

    # The summed 'entries' field may legitimately be zero after such a small
    # shift, so only non-negativity is checked.
    assert bitflip_total(second, 'entries') >= 0
