import src.libs.baukit_fn as baukit_fn
import src.libs.nnsight_fn as nnsight_fn
import src.libs.pyvene_fn as pyvene_fn
import src.libs.transformerlens_fn as transformerlens_fn


# Hugging Face hub ids for the general-purpose models, keyed by short name.
# NOTE: the "gpt2" key resolves to the XL checkpoint (gpt2-xl), not the base
# gpt2 model. Extra models are kept below as commented-out entries so they can
# be switched on by moving them into the dict call.
MODELS = dict(
    gpt2="openai-community/gpt2-xl",
    # "gemma-7b": "google/gemma-7b",
    # "llama3-8b": "meta-llama/Meta-Llama-3.1-8B",
    # "llama3-70b": "meta-llama/Meta-Llama-3.1-70B",
)

# Hugging Face hub ids for the OPT model family, keyed by short name.
# "prep" points at the smallest checkpoint (opt-125m). Presumably it is used
# for a preparatory/warm-up run before the real sweep; TODO confirm against
# callers. The sized entries follow the "opt-<size>" -> "facebook/opt-<size>"
# pattern and are inserted smallest-to-largest, so iteration order matches the
# size order listed here.
OPT_MODELS = {"prep": "facebook/opt-125m"}
OPT_MODELS.update(
    (f"opt-{size}", f"facebook/opt-{size}")
    for size in ("125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b", "66b")
)


# Maps each interpretability library name to the `setup_environment` callable
# from that library's wrapper module under src.libs. Insertion order
# (baukit, pyvene, transformerlens, nnsight) is preserved for callers that
# iterate over this mapping.
LIBRARIES = dict(
    baukit=baukit_fn.setup_environment,
    pyvene=pyvene_fn.setup_environment,
    transformerlens=transformerlens_fn.setup_environment,
    nnsight=nnsight_fn.setup_environment,
)


# Two-level dispatch table: experiment name -> library name -> the wrapper
# module's implementation of that experiment. The inner key order
# (nnsight, baukit, transformerlens, pyvene) deliberately differs from
# LIBRARIES and is kept as-is, because callers may iterate over it.
EXPERIMENT_FUNCTIONS = dict(
    activation_patching=dict(
        nnsight=nnsight_fn.activation_patching,
        baukit=baukit_fn.activation_patching,
        transformerlens=transformerlens_fn.activation_patching,
        pyvene=pyvene_fn.activation_patching,
    ),
)
