import torch
import json

def log_corrupted_minus_original_activations_norms(corrupted_minus_original_activations,
                                                   num_layers, num_heads, num_tokens,
                                                   log_dir) -> None:
    """Write the norm of each (corrupted - original) activation difference to JSON.

    For every attention head of every layer, and for every MLP layer, compute
    ``torch.norm`` of the stored difference and write the resulting floats to
    ``<log_dir>/corrupted_minus_original_activations_norms.json``. The file is
    written as ``{"attn": {layer: {head: norm}}, "mlp": {layer: norm}}``.

    Args:
        corrupted_minus_original_activations: Mapping with ``"attn"`` and
            ``"mlp"`` entries, indexed as ``["attn"][layer][head]`` and
            ``["mlp"][layer]``. Each indexed value must be something
            ``torch.norm`` accepts (presumably a float tensor; the caller
            decides the shape).
        num_layers: Number of layers to read (indices ``0 .. num_layers - 1``).
        num_heads: Number of attention heads to read per layer
            (indices ``0 .. num_heads - 1``).
        num_tokens: Currently unused; kept for call-site compatibility.
        log_dir: Output directory, as a ``str`` or ``os.PathLike``. It must
            already exist, because this function does not create it.

    Raises:
        KeyError: If ``"attn"`` or ``"mlp"`` is missing from
            ``corrupted_minus_original_activations``. Missing layer/head
            entries raise whatever the underlying container raises on
            indexing.
    """
    import pathlib

    # Use explicit checks rather than ``assert``, so validation still runs under ``python -O``.
    for key in ("attn", "mlp"):
        if key not in corrupted_minus_original_activations:
            raise KeyError(
                f"corrupted_minus_original_activations is missing the {key!r} entry"
            )

    attn_diffs = corrupted_minus_original_activations["attn"]
    mlp_diffs = corrupted_minus_original_activations["mlp"]

    # torch.norm with no ``p``/``dim`` reduces over all elements (2-norm / Frobenius).
    corrupted_minus_original_activations_norms = {
        "attn": {
            layer: {
                head: torch.norm(attn_diffs[layer][head]).item()
                for head in range(num_heads)
            }
            for layer in range(num_layers)
        },
        "mlp": {
            layer: torch.norm(mlp_diffs[layer]).item()
            for layer in range(num_layers)
        },
    }

    # Path() accepts both str and path-like objects; the original required .joinpath.
    out_path = pathlib.Path(log_dir) / "corrupted_minus_original_activations_norms.json"
    # Note: json.dump converts the integer layer/head keys to strings.
    with open(out_path, "w", encoding="utf-8") as file:
        json.dump(corrupted_minus_original_activations_norms, file, indent=4)