#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from .get_model import *

# Registry with the default modules analyzed for each model in the ICML'24 AURA paper.
#
# Maps a model identifier to the list of module-name patterns selected for that
# model. Every pattern addresses a projection inside the MLP/FFN sub-block, with
# `*` placed where the per-block index appears in the module path (the exact
# matching semantics are defined by the code that consumes this registry).
ICML24_MODEL_MODULES_REGISTRY = {
    # Gated-MLP families: up / down / gate projections.
    "google/gemma-2-2b": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
    ],
    "mistralai/Mistral-7B-v0.1": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
    ],
    "Llama-2-7b": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
    ],
    "meta-llama/Meta-Llama-3-8B": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
    ],
    # GPT-2: the Hugging Face MLP exposes `c_fc` and `c_proj` submodules; there
    # is no `c_mlp`, so the output projection must be addressed as `c_proj`.
    "openai-community/gpt2": [
        "transformer.h.*.mlp.c_fc",
        "transformer.h.*.mlp.c_proj",
    ],
    "openai-community/gpt2-xl": [
        "transformer.h.*.mlp.c_fc",
        "transformer.h.*.mlp.c_proj",
    ],
    # MPT: feed-forward up / down projections.
    "mosaicml/mpt-7b": [
        "transformer.blocks.*.ffn.up_proj",
        "transformer.blocks.*.ffn.down_proj",
    ],
    "mosaicml/mpt-30b": [
        "transformer.blocks.*.ffn.up_proj",
        "transformer.blocks.*.ffn.down_proj",
    ],
    # Falcon: dense h->4h / 4h->h projections.
    "tiiuae/falcon-7b": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
    ],
    "tiiuae/falcon-7b-instruct": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
    ],
    "tiiuae/falcon-40b": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
    ],
}


# Broader per-model registry: for each model identifier, the MLP/FFN projection
# patterns together with patterns for that architecture's normalization layers.
# It is exposed under the "all" key of MODULE_NAMES_REGISTRY.
ALL_MODEL_MODULES_REGISTRY = {
    "mistralai/Mistral-7B-v0.1": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
        "model.layers.*.*_layernorm",
    ],
    "Llama-2-7b": [
        "model.layers.*.mlp.up_proj",
        "model.layers.*.mlp.down_proj",
        "model.layers.*.mlp.gate_proj",
        "model.layers.*.*_layernorm",
    ],
    # GPT-2: the Hugging Face MLP exposes `c_fc` and `c_proj` submodules; there
    # is no `c_mlp`, so the output projection must be addressed as `c_proj`.
    "openai-community/gpt2": [
        "transformer.h.*.mlp.c_fc",
        "transformer.h.*.mlp.c_proj",
        "transformer.h.*.ln_1",
        "transformer.h.*.ln_2",
    ],
    "openai-community/gpt2-xl": [
        "transformer.h.*.mlp.c_fc",
        "transformer.h.*.mlp.c_proj",
        "transformer.h.*.ln_1",
        "transformer.h.*.ln_2",
    ],
    "mosaicml/mpt-7b": [
        "transformer.blocks.*.ffn.up_proj",
        "transformer.blocks.*.ffn.down_proj",
        "transformer.blocks.*.norm_1",
        "transformer.blocks.*.norm_2",
    ],
    "mosaicml/mpt-30b": [
        "transformer.blocks.*.ffn.up_proj",
        "transformer.blocks.*.ffn.down_proj",
        "transformer.blocks.*.norm_1",
        "transformer.blocks.*.norm_2",
    ],
    # The 7B Falcon variants list a single `input_layernorm` per block, whereas
    # the 40B entry lists separate `ln_attn` / `ln_mlp` norms.
    "tiiuae/falcon-7b": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
        "transformer.h.*.input_layernorm",
    ],
    "tiiuae/falcon-7b-instruct": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
        "transformer.h.*.input_layernorm",
    ],
    "tiiuae/falcon-40b": [
        "transformer.h.*.mlp.dense_4h_to_h",
        "transformer.h.*.mlp.dense_h_to_4h",
        "transformer.h.*.ln_attn",
        "transformer.h.*.ln_mlp",
    ],
}

# Named presets of module-name patterns.
#
# Two value shapes coexist in this mapping:
#   * "icml24" and "all" point at the per-model registries
#     (dict: model identifier -> list of patterns);
#   * every other key maps directly to a flat list of pattern strings.
# Some patterns carry a ":<int>" suffix; its meaning is defined by the consumer
# of this registry (presumably an output index of the module -- TODO confirm).
MODULE_NAMES_REGISTRY = {
    "icml24": ICML24_MODEL_MODULES_REGISTRY,
    # Fails for Llama, no such layers
    "post_layernorm": [
        ".*post_attention_layernorm",
        ".*post_feedforward_layernorm",
    ],
    "attention": [".*o_proj"],
    "all_layernorm": [".*_layernorm"],
    "mlp": [".*down_proj"],
    "pre_layernorm": [".*o_proj", ".*down_proj"],
    "unet_layernorm": ["unet.*norm.*"],
    "text_encoder_layernorm": ["text_encoder.*norm.*"],
    "text_encoders_norm": [
        "text_encoder.*norm.*",
        "text_encoder_2.*norm.*",
    ],
    "transformer_norm": ["transformer.*attn:0"],
    "transformer_attn": [
        "transformer.transformer_blocks.[0-9]+.attn:0",
        "transformer.transformer_blocks.[0-9]+.attn:1",
        "transformer.single_transformer_blocks.[0-9]+.attn:0",
    ],
    # Same attention patterns as "transformer_attn", plus the two embedders.
    "transformer_attn_2": [
        "transformer.transformer_blocks.[0-9]+.attn:0",
        "transformer.transformer_blocks.[0-9]+.attn:1",
        "transformer.single_transformer_blocks.[0-9]+.attn:0",
        "transformer.x_embedder",
        "transformer.context_embedder",
    ],
    "transformer_blocks": ["transformer.transformer_blocks.[0-9]+:0"],
    # "transformer_blocks_<k>" for k = 0..19: both ":0" and ":1" patterns of
    # every `transformer_blocks` entry, followed by the ":0" pattern of the
    # first k `single_transformer_blocks` (indices 0..k-1; none when k == 0).
    # Built inline so no extra module-level names are introduced; each key gets
    # its own freshly created list.
    **{
        f"transformer_blocks_{depth}": [
            "transformer.transformer_blocks.[0-9]+:0",
            "transformer.transformer_blocks.[0-9]+:1",
            *(
                f"transformer.single_transformer_blocks.{block_idx}:0"
                for block_idx in range(depth)
            ),
        ]
        for depth in range(20)
    },
    "all": ALL_MODEL_MODULES_REGISTRY,
}
