import torch
import torch.nn as nn

def proxy4(model, inputs, targets):
    """Score ``model`` by combining BatchNorm activation rank with conv weight L1/L2 ratios.

    The score is ``bn_sum * ratio_sum`` where

    * ``bn_sum`` is the sum, over every ``nn.BatchNorm2d`` call made during a
      single forward pass on ``inputs``, of the stable rank
      ``||A||_F^2 / (||A||_2^2 + 1e-6)`` of that call's output reshaped to a
      ``(B*H*W, C)`` matrix ``A`` (rows are spatial positions across the
      batch, columns are channels);
    * ``ratio_sum`` is the sum, over every ``nn.Conv2d`` module, of
      ``mean_L1 / mean_L2``, where the means are taken over the per-output-
      filter L1 and L2 norms of the weight.

    Args:
        model: Network to score. Its forward pass runs once under
            ``torch.no_grad()`` in whatever train/eval mode it is currently
            in. NOTE: in train mode BatchNorm layers update their running
            statistics as a side effect of this call.
        inputs: Batch passed as ``model(inputs)``.
        targets: Unused here (presumably kept for a signature shared with
            other proxies -- confirm with callers).

    Returns:
        The score as a Python float. A model with no BatchNorm2d (or no
        Conv2d) layers contributes ``0.0`` for that factor, and a conv whose
        weights are all zero contributes a ratio of ``0.0`` instead of NaN.
    """
    bn_ranks = []
    ratios = []
    hooks = []

    def bn_hook(module, inp, out):
        if not isinstance(out, torch.Tensor):
            return
        batch, channels = out.shape[:2]
        # reshape (not view) so non-contiguous outputs are handled too.
        mat = out.reshape(batch, channels, -1).permute(0, 2, 1).reshape(-1, channels)
        frob_sq = torch.linalg.matrix_norm(mat, ord='fro') ** 2
        spec_sq = torch.linalg.matrix_norm(mat, ord=2) ** 2
        bn_ranks.append(frob_sq / (spec_sq + 1e-6))

    try:
        # Nothing here needs gradients; no_grad also avoids building an
        # autograd graph for the weight-norm computations.
        with torch.no_grad():
            for layer in model.modules():
                if isinstance(layer, nn.BatchNorm2d):
                    hooks.append(layer.register_forward_hook(bn_hook))
                elif isinstance(layer, nn.Conv2d):
                    weights = layer.weight
                    l1_norm = weights.abs().sum(dim=(1, 2, 3)).mean()
                    l2_norm = weights.norm(p=2, dim=(1, 2, 3)).mean()
                    # An all-zero weight would give 0/0 = NaN and poison the
                    # whole score; count it as a ratio of 0 instead.
                    ratios.append((l1_norm / l2_norm).item() if l2_norm > 0 else 0.0)
            model(inputs)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # caller's model is not left with stray hooks.
        for hook in hooks:
            hook.remove()

    bn_sum = torch.stack(bn_ranks).sum().item() if bn_ranks else 0.0
    ratio_sum = sum(ratios, 0.0)

    return bn_sum * ratio_sum
