{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "import copy\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src import functional\n",
    "from src.models import ModelandTokenizer\n",
    "# from src.data import load_relation\n",
    "\n",
    "# Route DEBUG-and-above log records to stdout using the formats defined in logging_utils.\n",
    "logging.basicConfig(\n",
    "    stream=sys.stdout,\n",
    "    level=logging.DEBUG,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    ")\n",
    "\n",
    "# Logger named after this module, for messages emitted by the notebook itself.\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Model identifier handed to ModelandTokenizer (alternative: state-spaces/mamba-2.8b).\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b-slimpj\"\n",
    "\n",
    "# Build the model/tokenizer wrapper, requesting torch.float32.\n",
    "mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset.dataclasses import load_relation\n",
    "\n",
    "# Read the place-in-city relation from its JSON file.\n",
    "relation = load_relation(file=\"../data/relation/factual/place_in_city.json\")\n",
    "\n",
    "# Called with 0; kept as the cell's last expression so any return value is still displayed.\n",
    "relation.select_icl_examples(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the item at index 10: the question text, a blank line, then the answer.\n",
    "question, answer = relation[10]\n",
    "print(question, f\"{answer=}\", sep=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import filter_samples_by_model_knowledge\n",
    "# Rebind `relation` to the filtered result returned for the loaded model.\n",
    "# NOTE(review): presumably keeps only samples the model already answers correctly; confirm in src.functional.\n",
    "relation = filter_samples_by_model_knowledge(mt=mt, relation=relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tracing import calculate_average_indirect_effects\n",
    "from src.plotting import plot_trace_heatmap\n",
    "\n",
    "# Tunable: forwarded as `n_trials` to each calculate_average_indirect_effects call in this notebook.\n",
    "n_trials = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt = tokenizer EOS token, a space, then the relation's first prompt template.\n",
    "prompt_template = mt.tokenizer.eos_token + \" \" + relation.prompt_templates[0]\n",
    "\n",
    "# `mamba_block_hook` is omitted in this run, so the callee's default is used.\n",
    "aie = calculate_average_indirect_effects(\n",
    "    mt=mt,\n",
    "    prompt=prompt_template,\n",
    "    samples=relation.samples,\n",
    "    corruption_strategy=\"alt_patch\",\n",
    "    n_trials=n_trials,\n",
    "    # Plain string literal: the path has no placeholders, so the f-prefix was unnecessary.\n",
    "    save_path=\"../results/causal_tracing/aie/block_state.json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the latest `aie`, labelled with the last '/'-separated part of mt.name.\n",
    "plot_trace_heatmap(aie, modelname=mt.name.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `hook` serves both as the mamba_block_hook argument and as the output file stem.\n",
    "hook = \"mlp_after_silu\"\n",
    "\n",
    "# Prompt = tokenizer EOS token, a space, then the relation's first prompt template.\n",
    "prompt_template = mt.tokenizer.eos_token + \" \" + relation.prompt_templates[0]\n",
    "\n",
    "aie = calculate_average_indirect_effects(\n",
    "    mt=mt,\n",
    "    prompt=prompt_template,\n",
    "    samples=relation.samples,\n",
    "    n_trials=n_trials,\n",
    "    corruption_strategy=\"alt_patch\",\n",
    "    mamba_block_hook=hook,\n",
    "    save_path=f\"../results/causal_tracing/aie/{hook}.json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the latest `aie`, labelled with the last '/'-separated part of mt.name.\n",
    "plot_trace_heatmap(aie, modelname=mt.name.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `hook` serves both as the mamba_block_hook argument and as the output file stem.\n",
    "hook = \"ssm_after_ssm\"\n",
    "\n",
    "# Prompt = tokenizer EOS token, a space, then the relation's first prompt template.\n",
    "prompt_template = mt.tokenizer.eos_token + \" \" + relation.prompt_templates[0]\n",
    "\n",
    "aie = calculate_average_indirect_effects(\n",
    "    mt=mt,\n",
    "    prompt=prompt_template,\n",
    "    samples=relation.samples,\n",
    "    n_trials=n_trials,\n",
    "    corruption_strategy=\"alt_patch\",\n",
    "    mamba_block_hook=hook,\n",
    "    save_path=f\"../results/causal_tracing/aie/{hook}.json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the latest `aie`, labelled with the last '/'-separated part of mt.name.\n",
    "plot_trace_heatmap(aie, modelname=mt.name.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
