# Per-parameter permutation spec. Each value is a tuple with one slot per axis
# of the named parameter (two for "*.weight" entries, one for "*.bias" entries);
# a slot holds the name of the permutation applied along that axis, or None when
# that axis is left unpermuted.
# Layers 0-3 follow a regular pattern: layer i's weight is permuted by P_i on
# axis 0 and by P_{i-1} on axis 1 (None for layer 0), and layer i's bias by P_i.
# Layer 4 is the last layer: only axis 1 of its weight is permuted (by P_3).
# NOTE(review): the axis roles presumably follow a (out_features, in_features)
# weight layout such as torch.nn.Linear's — confirm against the model.
AXES_TO_PERM = {
    **{
        f"layer{i}.weight": (f"P_{i}", f"P_{i - 1}" if i > 0 else None)
        for i in range(4)
    },
    **{f"layer{i}.bias": (f"P_{i}",) for i in range(4)},
    "layer4.weight": (None, "P_3"),
    "layer4.bias": (None,),
}


# Inverse of AXES_TO_PERM: for every permutation name, the (parameter name,
# axis index) pairs it acts on. Permutation P_i touches axis 0 of layer i's
# weight, axis 1 of layer (i + 1)'s weight, and axis 0 of layer i's bias.
# Keep this table in sync with AXES_TO_PERM when layers are added or removed.
PERM_TO_AXES = {
    f"P_{i}": [
        (f"layer{i}.weight", 0),
        (f"layer{i + 1}.weight", 1),
        (f"layer{i}.bias", 0),
    ]
    for i in range(4)
}
