{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b360b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f696b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "! which python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "115a4228",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "torch.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f54a334",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "logger = logging.getLogger(__name__)\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d46d0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"mamba-3b\", device=device, fp16=False)\n",
    "# mt = models.load_model(\"gptj\", device=device, fp16=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0620a27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import baukit\n",
    "\n",
    "# mt.model.config.fused_add_norm = False\n",
    "\n",
    "# for block_path in models.determine_layer_paths(mt):\n",
    "#     block = baukit.get_module(mt.model, block_path)\n",
    "#     setattr(block, \"fused_add_norm\", False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c253f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.model.backbone.layers[0].fused_add_norm, mt.model.config.fused_add_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9fcd03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fe59d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import Mamba\n",
    "\n",
    "isinstance(mt.model, Mamba)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e27323d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"mamba\" in str(type(mt.model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a921c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "predict_next_token(\n",
    "    mt = mt, \n",
    "    prompt = mt.tokenizer.eos_token + \" The capital of {} is\".format(\"France\"),\n",
    "    # prompt = mt.tokenizer.eos_token + \" The superlative of {} is\".format(\"good\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f9f678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mt.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f51154e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "\n",
    "# relation_names = [r.name for r in dataset.relations]\n",
    "# relation_options = Menu(choices = relation_names, value = relation_names)\n",
    "# show(relation_options) # !caution: tested in a Jupyter notebook. baukit visualizations are not supported in VS Code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a17444f",
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_name = \"country capital city\"\n",
    "relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "print(f\"{relation.name} -- {len(relation.samples)} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "train, test = relation.split(5)\n",
    "print(\"\\n\".join([sample.__str__() for sample in train.samples]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145bf9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "layer = 20\n",
    "beta = 8\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b33032",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.operators import JacobianIclMeanEstimator\n",
    "# experiment_utils.set_seed(12345)\n",
    "\n",
    "# estimator = JacobianIclMeanEstimator(\n",
    "#     mt = mt, \n",
    "#     h_layer = layer,\n",
    "#     beta = beta\n",
    "# )\n",
    "# operator = estimator(\n",
    "#     relation.set(\n",
    "#         samples=train.samples, \n",
    "#     )\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6155e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def add_npz_extension(path: str) -> str:\n",
    "    \"\"\"Return `path` unchanged if it already ends with \".npz\", otherwise append that suffix.\"\"\"\n",
    "    return path if path.endswith(\".npz\") else path + \".npz\"\n",
    "\n",
    "\n",
    "def load_o1_approxes(path: str, sample_subjects: Optional[list[str]] = None):\n",
    "    \"\"\"Load cached linear-operator approximations from the directory `path`.\n",
    "\n",
    "    Args:\n",
    "        path: directory containing the cached files.\n",
    "        sample_subjects: file names to load, with or without the \".npz\" suffix.\n",
    "            When None, every entry of `path` is loaded, in sorted order.\n",
    "\n",
    "    Returns:\n",
    "        A list holding one loaded approximation per requested file, in request order.\n",
    "    \"\"\"\n",
    "    # os.listdir returns entries in arbitrary order; sort so repeated runs load the same sequence.\n",
    "    to_load = sample_subjects if sample_subjects is not None else sorted(os.listdir(path))\n",
    "    return [\n",
    "        functional.load_cached_linear_operator(\n",
    "            file_path=add_npz_extension(os.path.join(path, cached_file))\n",
    "        )\n",
    "        for cached_file in to_load\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "508e6d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################################\n",
    "n_train_samples = 5\n",
    "##################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bbc72d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "root_path = \"../results/cached_o1_approxes\"\n",
    "path = os.path.join(\n",
    "    root_path,\n",
    "    mt.name,\n",
    "    relation_name.lower().replace(\" \", \"_\"),\n",
    "    str(layer),\n",
    ")\n",
    "\n",
    "# os.listdir order is arbitrary; sort so the seeded draw below always sees the same population.\n",
    "all_cached_files = sorted(os.listdir(path))\n",
    "assert len(all_cached_files) >= n_train_samples, (\n",
    "    f\"need {n_train_samples} cached approximations in {path}, found {len(all_cached_files)}\"\n",
    ")\n",
    "\n",
    "# Draw from a locally seeded generator so the same few-shot files are picked on every run.\n",
    "train_subj_files = random.Random(12345).sample(all_cached_files, n_train_samples)\n",
    "\n",
    "train_approxes = load_o1_approxes(\n",
    "    path = path,\n",
    "    sample_subjects = train_subj_files\n",
    ")\n",
    "\n",
    "train_samples = [\n",
    "    data.RelationSample.from_dict(approx.metadata[\"sample\"])\n",
    "    for approx in train_approxes\n",
    "]\n",
    "\n",
    "train_relation = relation.set(samples=train_samples)\n",
    "\n",
    "# Hold out the training samples while keeping the dataset order: a bare set difference\n",
    "# iterates in hash order, which changes between kernel restarts.\n",
    "held_out = set(train_relation.samples)\n",
    "test_relation = relation.set(\n",
    "    samples = list(dict.fromkeys(s for s in relation.samples if s not in held_out))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af14f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = torch.stack([approx.weight for approx in train_approxes]).mean(dim=0)\n",
    "bias = torch.stack([approx.bias for approx in train_approxes]).mean(dim=0)\n",
    "\n",
    "prompt_template = relation.prompt_templates[0]\n",
    "\n",
    "prompt_template_icl = functional.make_prompt(\n",
    "    mt = mt,\n",
    "    prompt_template=prompt_template,\n",
    "    subject=\"{}\",\n",
    "    examples = train_samples\n",
    ")\n",
    "\n",
    "print(prompt_template_icl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a2b071",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import LinearRelationOperator\n",
    "operator = LinearRelationOperator(\n",
    "    mt = mt,\n",
    "    weight = weight,\n",
    "    bias = bias,\n",
    "    h_layer = train_approxes[0].h_layer,\n",
    "    z_layer = train_approxes[0].z_layer,\n",
    "    beta = 5,\n",
    "    prompt_template = prompt_template_icl,\n",
    "    metadata = {\n",
    "        \"Jh\": [\n",
    "            (approx.weight @ approx.h).detach().cpu().numpy()\n",
    "            for approx in train_approxes\n",
    "        ],\n",
    "        \"|w|\": [approx.weight.norm().item() for approx in train_approxes],\n",
    "        \"|b|\": [approx.bias.norm().item() for approx in train_approxes],\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dfcbb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.lens import logit_lens\n",
    "from src import models\n",
    "\n",
    "# logit_lens(mt = mt, h = operator.metadata[\"Jh\"][0].to(models.determine_device(mt)) + operator.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ad24c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "predict_next_token(\n",
    "    mt = mt, \n",
    "    prompt = mt.tokenizer.eos_token + \" The capital of {} is\".format(\"France\"),\n",
    "    # prompt = mt.tokenizer.eos_token + \" The superlative of {} is\".format(\"good\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c1b6eda",
   "metadata": {},
   "source": [
    "# Checking $faithfulness$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d79b613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter `test_relation`, which was built to exclude the few-shot samples used in the operator's prompt.\n",
    "# Reading from `test_relation` (instead of re-filtering `test`) also keeps this cell idempotent on re-run.\n",
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt, test_relation=test_relation, prompt_template=operator.prompt_template, batch_size=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ad18a70",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = [s for s in test.samples if s.subject == \"France\"][0]\n",
    "print(sample)\n",
    "operator(subject = sample.subject).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e717d47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject],\n",
    "    h_layer= operator.h_layer,\n",
    ")\n",
    "h = hs_and_zs.h_by_subj[sample.subject]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c69136b",
   "metadata": {},
   "source": [
    "## Approximating LM computation $F$ as an affine transformation\n",
    "\n",
    "### $$ F(\\mathbf{s}, c_r) \\approx \\beta \\, W_r \\mathbf{s} + b_r $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8bf8b1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the beta stored on the operator instead of repeating the literal passed to its constructor.\n",
    "z = operator.beta * (operator.weight @ h) + operator.bias\n",
    "\n",
    "lens.logit_lens(\n",
    "    mt = mt,\n",
    "    h = z,\n",
    "    get_proba = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0725d34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-1 faithfulness: share of test samples whose first prediction is accepted by\n",
    "# functional.is_nontrivial_prefix against the sample's object.\n",
    "correct = 0\n",
    "wrong = 0\n",
    "for sample in test.samples:\n",
    "    predictions = operator(subject = sample.subject).predictions\n",
    "    top_prediction = predictions[0]\n",
    "    known_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=top_prediction.token, target=sample.object\n",
    "    )\n",
    "    print(f\"{sample.subject=}, {sample.object=}, \", end=\"\")\n",
    "    print(f'predicted=\"{functional.format_whitespace(top_prediction.token)}\", (p={top_prediction.prob}), known=({functional.get_tick_marker(known_flag)})')\n",
    "    \n",
    "    correct += known_flag\n",
    "    wrong += not known_flag\n",
    "\n",
    "# The few-shot filter can leave no test samples; report NaN instead of raising ZeroDivisionError.\n",
    "n_evaluated = correct + wrong\n",
    "faithfulness = correct / n_evaluated if n_evaluated else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Faithfulness (@1) = {faithfulness}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a13389",
   "metadata": {},
   "source": [
    "# $causality$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da2f8eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "rank = 100\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ac7213",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "test_targets = functional.random_edit_targets(test.samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d83c9b",
   "metadata": {},
   "source": [
    "## setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a13c0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "source = test.samples[0]\n",
    "target = test_targets[source]\n",
    "\n",
    "f\"Changing the mapping ({source}) to ({source.subject} -> {target.object})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98f67c8",
   "metadata": {},
   "source": [
    "### Calculate $\\Delta \\mathbf{s}$ such that $\\mathbf{s} + \\Delta \\mathbf{s} \\approx \\mathbf{s}'$\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img align=\"center\" src=\"../causality_crop\" style=\"width:80%;\"/>\n",
    "</p>\n",
    "\n",
    "Under the relation $r =\\, $*plays the instrument*, and given the subject $s =\\, $*Miles Davis*, the model will predict $o =\\, $*trumpet* **(a)**; and given the subject $s' =\\, $*Cat Stevens*, the model will now predict $o' =\\, $*guitar* **(b)**. \n",
    "\n",
    "If the computation from $\\mathbf{s}$ to $\\mathbf{o}$ is well-approximated by $operator$ parameterized by $W_r$ and $b_r$ **(c)**, then $\\Delta{\\mathbf{s}}$ **(d)** should tell us the direction of change from $\\mathbf{s}$ to $\\mathbf{s}'$. Thus, $\\tilde{\\mathbf{s}}=\\mathbf{s}+\\Delta\\mathbf{s}$ would be an approximation of $\\mathbf{s}'$ and patching $\\tilde{\\mathbf{s}}$ in place of $\\mathbf{s}$ should change the prediction to $o'$ = *guitar* "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c632ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    operator, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    rank = 100,\n",
    "    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target\n",
    "):\n",
    "    \"\"\"Compute the change to the source subject's representation aimed at the target's output.\n",
    "\n",
    "    delta_s = w_p_inv @ (z_target - z_source), where `w_p_inv` is the result of\n",
    "    `functional.low_rank_pinv` on `operator.weight` and the z's come from\n",
    "    `functional.compute_hs_and_zs` on the operator's prompt template.\n",
    "\n",
    "    Args:\n",
    "        operator: object supplying `mt`, `weight`, `prompt_template` and `h_layer`.\n",
    "        source_subject: subject whose z is subtracted.\n",
    "        target_subject: subject whose z is the minuend.\n",
    "        rank: rank forwarded to `functional.low_rank_pinv`.\n",
    "        fix_latent_norm: when not None, z_source is rescaled to this norm and z_target\n",
    "            is rescaled to the same norm before the difference is taken.\n",
    "\n",
    "    Returns:\n",
    "        `(delta_s, hs_and_zs)`; the tensors stored in `hs_and_zs` are not modified.\n",
    "    \"\"\"\n",
    "    w_p_inv = functional.low_rank_pinv(\n",
    "        matrix = operator.weight,\n",
    "        rank=rank,\n",
    "    )\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = operator.mt,  # the operator's own model, not whatever the notebook-global `mt` is bound to\n",
    "        prompt_template = operator.prompt_template,\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= operator.h_layer,\n",
    "        z_layer=-1,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[target_subject]\n",
    "\n",
    "    if fix_latent_norm is not None:\n",
    "        # Rescale out of place: `*=` would overwrite the tensors held by `hs_and_zs`,\n",
    "        # which is handed back to the caller.\n",
    "        z_source = z_source * (fix_latent_norm / z_source.norm())\n",
    "        z_target = z_target * (z_source.norm() / z_target.norm())\n",
    "\n",
    "    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())\n",
    "\n",
    "    return delta_s, hs_and_zs\n",
    "\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    operator = operator,\n",
    "    source_subject = source.subject,\n",
    "    target_subject = target.subject,\n",
    "    rank = rank\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1c7e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import baukit\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    \"\"\"Build an `edit_output(output, layer)` callback that writes `h` into the output of `int_layer`.\n",
    "\n",
    "    Outputs of any other layer are returned unchanged.\n",
    "    \"\"\"\n",
    "    def edit_output(output, layer):\n",
    "        # Only the layer named `int_layer` is edited; everything else passes through.\n",
    "        if(layer != int_layer):\n",
    "            return output\n",
    "        # In-place write at index `subj_idx` of the second axis, for every index of the first axis.\n",
    "        # NOTE(review): presumably the axes are (batch, token, ...) - confirm against functional.untuple.\n",
    "        functional.untuple(output)[:, subj_idx] = h \n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = operator.prompt_template.format(source.subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source.subject,\n",
    ")\n",
    "\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    mt.model, layers = [h_layer, z_layer],\n",
    "    edit_output=get_intervention(\n",
    "#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual\n",
    "        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s\n",
    "        int_layer = h_layer, \n",
    "        subj_idx = h_index\n",
    "    )\n",
    ") as traces:\n",
    "    outputs = mt(\n",
    "        input_ids = inputs.input_ids,\n",
    "        attention_mask = inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "logits = outputs.logits[0][-1] if hasattr(outputs, \"logits\") else outputs[0][-1]\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = logits, \n",
    "    get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c272c1",
   "metadata": {},
   "source": [
    "## Measuring causality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51efa257",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.editors import LowRankPInvEditor\n",
    "\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = LowRankPInvEditor(\n",
    "    lre=operator,\n",
    "    rank=rank,\n",
    "    svd=svd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88be35dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# precomputing latents to speed things up\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject for sample in test.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size = 2\n",
    ")\n",
    "\n",
    "# Top-1 causality: share of edits whose first predicted token is accepted by\n",
    "# functional.is_nontrivial_prefix against the target's object.\n",
    "success = 0\n",
    "fails = 0\n",
    "\n",
    "for sample in test.samples:\n",
    "    target = test_targets.get(sample)\n",
    "    assert target is not None\n",
    "    edit_result = editor(\n",
    "        subject = sample.subject,\n",
    "        target = target.subject\n",
    "    )\n",
    "    \n",
    "    top_token = edit_result.predicted_tokens[0]\n",
    "    success_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=top_token.token, target=target.object\n",
    "    )\n",
    "    \n",
    "    print(f\"Mapping {sample.subject} -> {target.object} | edit result={top_token} | success=({functional.get_tick_marker(success_flag)})\")\n",
    "    \n",
    "    success += success_flag\n",
    "    fails += not success_flag\n",
    "\n",
    "# An empty test set would otherwise raise ZeroDivisionError; report NaN instead.\n",
    "n_edits = success + fails\n",
    "causality = success / n_edits if n_edits else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Causality (@1) = {causality}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d14acef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import save_linear_operator\n",
    "\n",
    "save_linear_operator(\n",
    "    approx = operator,\n",
    "    file_name = \"lre_capital\",\n",
    "    path = \"cached\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1cbbdde",
   "metadata": {},
   "outputs": [],
   "source": [
    "operator_loaded = functional.load_cached_linear_operator(mt = mt, file_path = \"cached/lre_capital.npz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69a2069b",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68c55042",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = mt.tokenizer.eos_token + \" Michael Jordan professionally played the sport of\"\n",
    "tokenized = mt.tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "models.determine_layer_paths(mt)[layer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5459d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "import baukit\n",
    "\n",
    "layer_out = models.determine_layer_paths(mt)[layer]\n",
    "mixer_out = models.determine_layer_paths(mt)[layer] + \".mixer\"\n",
    "layer_in = models.determine_layer_paths(mt)[layer+1]\n",
    "final_layer = models.determine_layer_paths(mt)[-1]\n",
    "final_mixer = models.determine_layer_paths(mt)[-1] + \".mixer\"\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    module=mt.model,\n",
    "    layers=[\n",
    "        layer_out,\n",
    "        mixer_out,\n",
    "        layer_in,\n",
    "        final_mixer,\n",
    "        final_layer,\n",
    "        \"backbone\"\n",
    "    ],\n",
    "    retain_input=True,\n",
    ") as traces:\n",
    "    output = mt(**tokenized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87121bf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "traces[\"backbone\"].output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7473e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "traces[final_layer].output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5242c6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class RMSNorm(nn.Module):\n",
    "    \"\"\"Hand-written RMS normalisation that reuses the weight of `mt.model`'s `backbone.norm_f` module.\"\"\"\n",
    "    def __init__(self,\n",
    "                 d_model: int,\n",
    "                 eps: float = 1e-5):\n",
    "        # NOTE(review): `d_model` is accepted but never read inside this class.\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        # Taken from the notebook-global `mt`; no new weight is created here.\n",
    "        self.weight = baukit.get_module(mt.model, \"backbone.norm_f\").weight\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `x * rsqrt(mean(x**2 over the last axis) + eps) * weight`.\"\"\"\n",
    "        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight\n",
    "\n",
    "        return output\n",
    "\n",
    "custom_rms = RMSNorm(d_model = models.determine_hidden_size(mt))\n",
    "block_output, residual = traces[final_layer].output\n",
    "backbone_output = custom_rms(block_output + residual)\n",
    "\n",
    "backbone_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2240b1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.allclose(traces[mixer_out].output, traces[layer_out].output[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c1334d",
   "metadata": {},
   "outputs": [],
   "source": [
    "baukit.get_module(mt.model, \"backbone.norm_f\").bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96443785",
   "metadata": {},
   "outputs": [],
   "source": [
    "hasattr(mt.model, \"backbone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecaaf9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_layers = 32\n",
    "n_approx = 25\n",
    "time_per_approx = 7 # in minutes\n",
    "size_per_approx = 26 # in MB\n",
    "n_threads = 2\n",
    "\n",
    "\n",
    "total_time = n_layers * n_approx * time_per_approx\n",
    "total_time /= n_threads\n",
    "total_time /= 60\n",
    "\n",
    "total_size = n_layers * n_approx * size_per_approx\n",
    "total_size /= 1024\n",
    "\n",
    "f\"Time: {total_time} hours, Size: {total_size} GB\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4285ebcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "layers = np.arange(0, 64, 2).tolist()\n",
    "layers[::2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f7ae65e",
   "metadata": {},
   "outputs": [],
   "source": [
    "layers[1::2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6ee9a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\" \".join([str(l) for l in layers[::2]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8ef6748",
   "metadata": {},
   "outputs": [],
   "source": [
    "(16 * time_per_approx * n_approx)/60"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c101157c",
   "metadata": {},
   "outputs": [],
   "source": [
    "21 * 47"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c4b8862",
   "metadata": {},
   "outputs": [],
   "source": [
    "(47 * 32 * 25 * 26) // 1024 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a2ccdb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "512/8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6272f3b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
