from __future__ import annotations

from dataclasses import dataclass
import torch


def spectral_norm_bound(model: torch.nn.Module, iters: int = 1) -> float:
    """Estimate a product-of-spectral-norms bound over the Linear layers of *model*.

    Every ``torch.nn.Linear`` submodule yielded by ``model.modules()`` is
    considered. This includes ``model`` itself when it is a ``Linear``, and
    each distinct module is visited once. For each one, the largest singular
    value of its weight is estimated with ``iters`` steps of power iteration
    on ``W^T W``. The per-layer estimates, each clamped below at 1.0, are
    multiplied together. All other layer types are ignored.

    Power iteration approaches the top singular value from below, so with few
    iterations each factor may underestimate the true spectral norm. The
    result is therefore an estimate, not a guaranteed upper bound. The start
    vector is drawn from the global torch RNG, so results vary between calls
    unless the RNG is seeded.

    Args:
        model: Module whose ``Linear`` submodules are inspected.
        iters: Number of power-iteration steps per layer. ``0`` evaluates the
            random unit start vector directly.

    Returns:
        Product over Linear layers of ``max(1.0, sigma_estimate)``. Returns
        ``1.0`` when the model contains no Linear layers.

    Raises:
        ValueError: If ``iters`` is negative.
    """
    if iters < 0:
        raise ValueError(f"iters must be non-negative, got {iters}")
    bound = 1.0
    with torch.no_grad():
        for m in model.modules():
            if not isinstance(m, torch.nn.Linear):
                continue
            w = m.weight
            # Match the weight's dtype as well as its device; otherwise the
            # matmuls below fail for e.g. float64 models.
            v = torch.randn(w.size(1), device=w.device, dtype=w.dtype)
            v = v / (v.norm() + 1e-12)
            for _ in range(iters):
                v = w.t() @ (w @ v)
                v = v / (v.norm() + 1e-12)
            # ||W v|| for a unit vector v estimates sigma_max from below.
            # Evaluating it on the final v uses the most refined direction,
            # and it is well-defined even when iters == 0.
            sigma = (w @ v).norm().item()
            # Clamp at 1 so layers with small weights never shrink the bound.
            bound *= max(1.0, sigma)
    return float(bound)


def product_spectral_bounds(models: list[torch.nn.Module], iters: int = 1) -> float:
    """Multiply ``spectral_norm_bound`` over every model in *models*.

    Each model is evaluated independently with the same ``iters``. A module
    shared between several models therefore contributes once per model that
    contains it. An empty list yields ``1.0``.
    """
    total = 1.0
    per_model = (spectral_norm_bound(model, iters=iters) for model in models)
    for factor in per_model:
        total = total * factor
    return float(total)


def _demo():
    """Print the bound estimate for a freshly initialised 4x4 Linear layer."""
    layer = torch.nn.Linear(in_features=4, out_features=4)
    estimate = spectral_norm_bound(layer, iters=2)
    print(estimate)


if __name__ == "__main__":
    # Running the module as a script prints one sample estimate.
    _demo()
