[
    {
        "description": "This metric evaluates the sum of max activations in BatchNorm layers weighted by their layer depth.",
        "code": "import torch\nimport torch.nn as nn\n\ndef heuristic_3(model, inputs, targets):\n    weighted_max = []\n    hooks = []\n    depth = 0\n    \n    def hook_fn(module, input, output):\n        weighted_max.append((output.abs().max() * depth).detach())\n    \n    for layer in model.modules():\n        if isinstance(layer, nn.BatchNorm2d):\n            hooks.append(layer.register_forward_hook(hook_fn))\n            depth += 1\n    \n    with torch.no_grad():\n        model(inputs)\n    \n    for hook in hooks:\n        hook.remove()\n    \n    if not weighted_max:\n        return 0.0\n    \n    return torch.sum(torch.stack(weighted_max)).item()",
        "score": 0.7978655290087604
    },
    {
        "description": "This metric computes the sum of max activations across all BatchNorm layers.",
        "code": "import torch\nimport torch.nn as nn\n\ndef heuristic_2(model, inputs, targets):\n    max_vals = []\n    hooks = []\n    \n    def hook_fn(module, input, output):\n        max_vals.append(output.abs().max().detach())\n    \n    for layer in model.modules():\n        if isinstance(layer, nn.BatchNorm2d):\n            hooks.append(layer.register_forward_hook(hook_fn))\n    \n    with torch.no_grad():\n        model(inputs)\n    \n    for hook in hooks:\n        hook.remove()\n    \n    if not max_vals:\n        return 0.0\n    \n    return torch.sum(torch.stack(max_vals)).item()",
        "score": 0.7686068602207567
    },
    {
        "description": "This metric measures the sum of mean activations divided by their max values in residual block outputs.",
        "code": "import torch\nimport torch.nn as nn\n\ndef heuristic_5(model, inputs, targets):\n    ratios = []\n    hooks = []\n    \n    def hook_fn(module, input, output):\n        mean_val = output.abs().mean()\n        max_val = output.abs().max()\n        ratios.append((mean_val / (max_val + 1e-6)).detach())\n    \n    for layer in model.modules():\n        if isinstance(layer, nn.Sequential) and any(isinstance(m, nn.Conv2d) for m in layer):\n            hooks.append(layer.register_forward_hook(hook_fn))\n    \n    with torch.no_grad():\n        model(inputs)\n    \n    for hook in hooks:\n        hook.remove()\n    \n    if not ratios:\n        return 0.0\n    \n    return torch.sum(torch.stack(ratios)).item()",
        "score": 0.7523746520153696
    },
    {
        "description": "This metric measures the sum of variance-to-mean ratios in BatchNorm layer activations.",
        "code": "import torch\nimport torch.nn as nn\n\ndef heuristic_2(model, inputs, targets):\n    ratios = []\n    hooks = []\n    \n    def hook_fn(module, input, output):\n        var_val = output.var()\n        mean_val = output.abs().mean()\n        ratios.append((var_val / (mean_val + 1e-6)).detach())\n    \n    for layer in model.modules():\n        if isinstance(layer, nn.BatchNorm2d):\n            hooks.append(layer.register_forward_hook(hook_fn))\n    \n    with torch.no_grad():\n        model(inputs)\n    \n    for hook in hooks:\n        hook.remove()\n    \n    if not ratios:\n        return 0.0\n    \n    return torch.sum(torch.stack(ratios)).item()",
        "score": 0.7397967923141099
    },
    {
        "description": "This metric evaluates the sum of mean activations multiplied by their standard deviations in BatchNorm layers.",
        "code": "import torch\nimport torch.nn as nn\n\ndef heuristic_5(model, inputs, targets):\n    products = []\n    hooks = []\n    \n    def hook_fn(module, input, output):\n        mean_activation = output.abs().mean()\n        std_activation = output.std()\n        products.append((mean_activation * std_activation).detach())\n    \n    for layer in model.modules():\n        if isinstance(layer, nn.BatchNorm2d):\n            hooks.append(layer.register_forward_hook(hook_fn))\n    \n    with torch.no_grad():\n        model(inputs)\n    \n    for hook in hooks:\n        hook.remove()\n    \n    if not products:\n        return 0.0\n    \n    return torch.sum(torch.stack(products)).item()",
        "score": 0.7324196401049915
    }
]