{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the parent directory on sys.path so the `src` package imported below can be found.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "from src import functional\n",
    "import src.tokens as tokenization_utils\n",
    "import numpy as np\n",
    "import logging\n",
    "from src import models\n",
    "\n",
    "from src.utils import logging_utils\n",
    "\n",
    "# Route every record at DEBUG level or above to stdout, using the format\n",
    "# strings defined in src.utils.logging_utils.\n",
    "logging.basicConfig(\n",
    "    stream=sys.stdout,\n",
    "    level=logging.DEBUG,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Cell output: the library versions this session is running against.\n",
    "(torch.__version__, transformers.__version__, torch.version.cuda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Checkpoint identifier handed to ModelandTokenizer.\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b\"\n",
    "\n",
    "# Build the model/tokenizer bundle (`mt`) that the later cells operate on;\n",
    "# float32 is requested as the torch dtype.\n",
    "mt = ModelandTokenizer(torch_dtype=torch.float32, model_path=MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------- editable inputs -------------------\n",
    "# Entity whose fact is targeted by the edit request built further down.\n",
    "subject = \"The Space Needle\"\n",
    "# subject = \"The Statue of Liberty\"\n",
    "\n",
    "# Relation template; `{}` is the slot that receives the subject.\n",
    "prompt_template = \"{} is located in the city of\"\n",
    "# prompt_template = tokenization_utils.maybe_prefix_eos(\n",
    "#     mt.tokenizer, prompt_template\n",
    "# )\n",
    "# -------------------------------------------------------\n",
    "\n",
    "# Fully rendered prompt, shown as the cell output.\n",
    "prompt = prompt_template.format(subject)\n",
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "# Query the model (k=5) on the relation template filled with a different\n",
    "# landmark than `subject`; this runs before any weights are edited.\n",
    "# Pass `prompt=prompt` instead to probe the edit subject itself.\n",
    "predict_next_token(mt, prompt=prompt_template.format(\"Colosseum\"), k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.data.dataclasses import MultiCounterFactDataset\n",
    "\n",
    "# dataset = MultiCounterFactDataset(\"../data\")\n",
    "\n",
    "# Edit request passed to compute_v and apply_rome_to_model in later cells.\n",
    "request = {\n",
    "    \"prompt\": prompt_template,\n",
    "    \"subject\": subject,\n",
    "    \"target_new\": {\"str\": \"ROME\"},\n",
    "}\n",
    "\n",
    "# `subject` already carries its own article (\"The Space Needle\"), so the\n",
    "# question template must not put a second \"the\" in front of it.\n",
    "generation_prompts = [\n",
    "    f\"{subject} is located in the city of\",\n",
    "    f\"{subject}, which is in the city of\",\n",
    "    f\"Which city is {subject} in? It is in\",\n",
    "    f\"{subject} is made of\",\n",
    "    f\"{subject} is in\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome.compute_v import get_module_input_output_at_word\n",
    "\n",
    "# Context templates handed to compute_v in a later cell.\n",
    "# NOTE(review): presumably `{}` is where the request prompt gets substituted --\n",
    "# confirm against src.rome.compute_v.\n",
    "context_templates = [\n",
    "    \"{}\",\n",
    "    \"The first step to a new life is to. {}\",\n",
    "    \"Therefore, the best way to prevent this from. {}\",\n",
    "    \"Because the first time I saw the trailer. {}\",\n",
    "    \"I'm not sure if this is the. {}\",\n",
    "    \"You are here: Home / Archives for . {}\",\n",
    "]\n",
    "\n",
    "# Fetch the input and output of layer 15's `mixer.out_proj` module for the\n",
    "# request's subject, using the \"subject_last\" token strategy.\n",
    "l_input, l_output = get_module_input_output_at_word(\n",
    "    mt,\n",
    "    layer=15,\n",
    "    context_template=request[\"prompt\"],\n",
    "    word=request[\"subject\"],\n",
    "    module_template=mt.layer_name_format + \".mixer.out_proj\",\n",
    "    fact_token_strategy=\"subject_last\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome_utils import nethook\n",
    "\n",
    "tokenized = mt.tokenizer(\n",
    "    prompt,\n",
    "    return_tensors=\"pt\",\n",
    "    padding=True,\n",
    "    return_offsets_mapping=True,\n",
    ").to(mt.device)\n",
    "# Take the offset mapping out of `tokenized` and keep it separately.\n",
    "offsets = tokenized.pop(\"offset_mapping\")\n",
    "\n",
    "# (position, decoded token) pairs for the first sequence in the batch.\n",
    "list(enumerate(mt.tokenizer.decode(tok) for tok in tokenized.input_ids[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with nethook.Trace(\n",
    "#     module = mt.model,\n",
    "#     layer = mt.layer_name_format.format(15) + \".mixer\",\n",
    "#     retain_output = True,\n",
    "#     retain_input = True,\n",
    "# ) as tr:\n",
    "#     output = mt(**tokenized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome.rome_hparams import ROMEHyperParams\n",
    "\n",
    "hparams = ROMEHyperParams(\n",
    "    layers=[15],\n",
    "    fact_token=\"subject_last\",\n",
    "    v_num_grad_steps=25,\n",
    "    v_lr=5e-1,\n",
    "    v_loss_layer=models.determine_layers(mt)[-1],\n",
    "    v_weight_decay=0.5,\n",
    "    clamp_norm_factor=3,\n",
    "    kl_factor=0.0625,\n",
    "    mom2_adjustment=True,\n",
    "    context_template_length_params=[[5, 10], [10, 10]],\n",
    "\n",
    "    # Module-name templates, derived from the loaded model.\n",
    "    rewrite_module_tmp=mt.layer_name_format + \".mixer.in_proj\",\n",
    "    layer_module_tmp=mt.layer_name_format,\n",
    "    mlp_module_tmp=\"\",\n",
    "    attn_module_tmp=\"\",\n",
    "    ln_f_module=models.determine_final_layer_norm_path(mt),\n",
    "    lm_head_module=models.determine_lm_head_path(mt),\n",
    "\n",
    "    # Second-moment statistics settings.\n",
    "    mom2_dataset=\"wikipedia\",\n",
    "    mom2_n_samples=1000,\n",
    "    mom2_dtype=\"float32\",\n",
    "\n",
    "    mamba_block_non_ssm=True,  # will affect the non-ssm flow only, default is false\n",
    "    # mamba_block_ssm=True,  # will affect the ssm flow only, default is false\n",
    ")\n",
    "\n",
    "# `json` comes from the notebook's import cell; print the configured fields for inspection.\n",
    "print(json.dumps(vars(hparams), indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome.rome_main import get_context_templates\n",
    "\n",
    "# Read the length parameters from `hparams` (set to [[5, 10], [10, 10]] above)\n",
    "# rather than repeating the literal, so this preview cannot drift from the\n",
    "# configuration that the edit itself uses.\n",
    "get_context_templates(\n",
    "    mt=mt,\n",
    "    length_params=hparams.context_template_length_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mt.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome.compute_v import compute_v\n",
    "\n",
    "# Take the layer from `hparams.layers` (currently [15]) instead of hard-coding\n",
    "# it, so the value vector is always computed for the layer being edited.\n",
    "v = compute_v(\n",
    "    mt=mt,\n",
    "    request=request,\n",
    "    hparams=hparams,\n",
    "    layer=hparams.layers[0],\n",
    "    context_templates=context_templates,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Housekeeping step between the compute_v cell and the weight edit.\n",
    "# NOTE(review): presumably releases cached GPU memory -- confirm in src.functional.\n",
    "functional.free_gpu_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.rome.rome_main import (\n",
    "    apply_rome_to_model,\n",
    "    restore_weights,\n",
    "    save_weights,\n",
    ")\n",
    "\n",
    "# Run the edit for `request`. The call returns the model together with\n",
    "# `orig_weights`, which a later cell hands to restore_weights.\n",
    "# (`cache_template` is left unset.)\n",
    "model, orig_weights = apply_rome_to_model(mt=mt, requests=request, hparams=hparams)\n",
    "\n",
    "# Save the model's current weights under the same names that `orig_weights` holds.\n",
    "# NOTE(review): presumably these are the post-edit values -- confirm in src.rome.rome_main.\n",
    "rome_weights = save_weights(model, [*orig_weights.keys()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompts fed to generate_fast below: five about `subject`, plus two other\n",
    "# landmarks under the same relation.\n",
    "# `subject` already carries its own article (\"The Space Needle\"), so the\n",
    "# question template must not put a second \"the\" in front of it.\n",
    "generation_prompts = [\n",
    "    f\"{subject} is located in the city of\",\n",
    "    f\"{subject}, which is in the city of\",\n",
    "    f\"Which city is {subject} in? It is in\",\n",
    "    f\"{subject} is made of\",\n",
    "    f\"{subject} is in\",\n",
    "    \"The Statue of Liberty is located in the city of\",\n",
    "    \"Colosseum is located in the city of\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.generation import generate_fast\n",
    "\n",
    "# Put the saved ROME weights back on the model, then generate for every prompt.\n",
    "restore_weights(model, rome_weights)\n",
    "generate_fast(mt=mt, prompts=generation_prompts, max_out_len=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the original weights back on the model and generate for the same prompts,\n",
    "# for comparison with the output of the previous cell.\n",
    "restore_weights(model, orig_weights)\n",
    "generate_fast(mt=mt, prompts=generation_prompts, max_out_len=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fact",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
