{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited modules (e.g. the local `src` package) before each execution,\n",
    "# so source changes are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Put the parent of the working directory on the import path so `from src import ...` resolves\n",
    "# (assumes the notebook is launched from a folder one level below the repo root — TODO confirm).\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "# Show the installed torch version and the CUDA version torch was built against.\n",
    "torch.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils, logging_utils\n",
    "\n",
    "# Route DEBUG-and-above records to stdout using the project's shared format/date format.\n",
    "logging.basicConfig(\n",
    "    stream=sys.stdout,\n",
    "    level=logging.DEBUG,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    ")\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "\n",
    "# Mamba-3B is loaded with fp16=False; the GPT-J alternative (commented out) uses fp16=True.\n",
    "mt = models.load_model(\n",
    "    \"mamba-3b\",\n",
    "    device=device,\n",
    "    fp16=False,\n",
    ")\n",
    "# mt = models.load_model(\"gptj\", device=device, fp16=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the relation dataset; a single relation is selected from it in the next cell.\n",
    "dataset = data.load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_name = \"country capital city\"\n",
    "\n",
    "# filter() returns the relations matching the requested names; keep the first match.\n",
    "matching_relations = dataset.filter(relation_names=[relation_name])\n",
    "relation = matching_relations[0]\n",
    "\n",
    "sample_count = len(relation.samples)\n",
    "print(f\"{relation.name} -- {sample_count} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "prompt_template = relation.prompt_templates[0]\n",
    "prompt_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled scratch code — every line is commented out and nothing here runs.\n",
    "# It printed each sample next to in-context examples taken from the remaining samples\n",
    "# via relation.split(train_size=min(3, len(relation.samples) - 1)).\n",
    "# for sample in relation.samples:\n",
    "#     icl_examples = (\n",
    "#         relation.set(samples=list(set(relation.samples) - set([sample])))\n",
    "#         .split(train_size=min(3, len(relation.samples) - 1))[0]\n",
    "#         .samples\n",
    "#     )\n",
    "#     print(f\"{sample} | {[str(s) for s in icl_examples]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled experiment — fully commented out. It picked train_samples by a fixed list\n",
    "# of president names, built an ICL prompt from them, and filtered test_relation with\n",
    "# functional.filter_relation_samples_based_on_provided_fewshots. With relation_name set to\n",
    "# \"country capital city\" these subjects would likely match no samples. The live pipeline\n",
    "# derives train_samples/test_relation from the cached approximations instead.\n",
    "# from src.sweeps import load_o1_approxes\n",
    "\n",
    "# # training_subjects = [\n",
    "# #     'Michael', 'Benjamin', 'Scarlett', 'Oliver', 'Tom'\n",
    "# # ]\n",
    "# # training_subjects = [s.lower().replace(\" \", \"_\") for s in training_subjects]\n",
    "\n",
    "# training_subjects = [\"John Adams\", \"Bill Clinton\", \"Andrew Jackson\", \"Barack Obama\", \"John Quincy Adams\"]\n",
    "# train_samples = [sample for sample in relation.samples if sample.subject in training_subjects]\n",
    "\n",
    "# test_relation = relation.set(\n",
    "#     samples = list(set(relation.samples) - set(train_samples))\n",
    "# )\n",
    "\n",
    "# icl_prompt = functional.make_prompt(\n",
    "#     mt=mt,\n",
    "#     prompt_template=prompt_template,\n",
    "#     subject=\"{}\",\n",
    "#     examples=train_samples,\n",
    "# )\n",
    "\n",
    "# test_relation = (\n",
    "#     functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "#         mt=mt,\n",
    "#         test_relation=test_relation,\n",
    "#         prompt_template=icl_prompt,\n",
    "#     )\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled — fully commented out. It relied on test_relation/train_samples from the\n",
    "# commented-out experiment cell and swept h_layer=[2, 4, 6, 8]; the live editing cell computes\n",
    "# hs/zs for the single chosen `layer` instead.\n",
    "# # Precompute all the hs to speed things up.\n",
    "# hs_by_subj, zs_by_subj = functional.compute_hs_and_zs(\n",
    "#     mt=mt,\n",
    "#     prompt_template=prompt_template,\n",
    "#     subjects=[x.subject for x in test_relation.samples],\n",
    "#     h_layer=[2, 4, 6, 8],\n",
    "#     z_layer=-1,\n",
    "#     batch_size=1,\n",
    "#     examples=train_samples,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first three samples of the relation serve as training/ICL subjects, expressed as the\n",
    "# file-name form used for the cached approximations.\n",
    "training_subjects = [\n",
    "    functional.subject_to_filename(sample.subject)\n",
    "    for sample in relation.samples[:3]\n",
    "]\n",
    "\n",
    "training_subjects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.sweeps import get_samples_corresponding_to_cached_file_names\n",
    "\n",
    "# Sanity check: presumably resolves the cached file names back to the relation's samples.\n",
    "cached_training_samples = get_samples_corresponding_to_cached_file_names(\n",
    "    relation=relation,\n",
    "    cached_file_names=training_subjects,\n",
    ")\n",
    "cached_training_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.sweeps import load_o1_approxes\n",
    "layer = 18\n",
    "\n",
    "# Cached approximations are stored per model / relation / layer.\n",
    "path = f\"../results/cache_o1_approxes/{mt.name}/{relation_name.lower().replace(' ', '_')}/{layer}\"\n",
    "print(path)\n",
    "\n",
    "train_approxes = load_o1_approxes(\n",
    "     path=path, sample_subjects=training_subjects\n",
    ")\n",
    "# Fail here with a clear message instead of an opaque IndexError/stack error in later cells.\n",
    "assert len(train_approxes) > 0, f\"no cached approximations loaded from {path}\"\n",
    "\n",
    "# Each cached approximation records the sample it was computed from.\n",
    "train_samples = [\n",
    "    data.RelationSample.from_dict(approx.metadata[\"sample\"])\n",
    "    for approx in train_approxes\n",
    "]\n",
    "train_relation = relation.set(samples=train_samples)\n",
    "\n",
    "# Keep the dataset's sample order for the test split. Iterating a set difference yields an\n",
    "# order driven by element hashes (string hashes are randomized per interpreter process), which\n",
    "# would make the seeded random_edit_targets draw in the editing cell differ between runs.\n",
    "train_sample_set = set(train_relation.samples)\n",
    "test_relation = relation.set(\n",
    "    samples=[s for s in relation.samples if s not in train_sample_set]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of cached approximations loaded; these are averaged into one operator in the next cell.\n",
    "len(train_approxes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the per-sample approximations into a single operator: W = mean_i W_i, b = mean_i b_i.\n",
    "weight = torch.mean(torch.stack([approx.weight for approx in train_approxes]), dim=0)\n",
    "bias = torch.mean(torch.stack([approx.bias for approx in train_approxes]), dim=0)\n",
    "\n",
    "# Prompt with the training samples as in-context examples; \"{}\" is presumably the subject slot.\n",
    "prompt_template_icl = functional.make_prompt(\n",
    "    mt=mt,\n",
    "    prompt_template=prompt_template,\n",
    "    subject=\"{}\",\n",
    "    examples=train_samples,\n",
    ")\n",
    "\n",
    "from src.operators import LinearRelationOperator\n",
    "\n",
    "# Per-sample diagnostics stored alongside the averaged operator.\n",
    "operator_metadata = {\n",
    "    \"Jh\": [(approx.weight @ approx.h).detach().cpu() for approx in train_approxes],\n",
    "    \"|w|\": [approx.weight.norm().item() for approx in train_approxes],\n",
    "    \"|b|\": [approx.bias.norm().item() for approx in train_approxes],\n",
    "}\n",
    "\n",
    "first_approx = train_approxes[0]\n",
    "operator = LinearRelationOperator(\n",
    "    mt=mt,\n",
    "    weight=weight,\n",
    "    bias=bias,\n",
    "    h_layer=first_approx.h_layer,\n",
    "    z_layer=first_approx.z_layer,\n",
    "    prompt_template=prompt_template_icl,\n",
    "    metadata=operator_metadata,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SVD of the operator weight (in float32), shared by every low-rank editor in the rank sweep.\n",
    "weight = operator.weight.clone()\n",
    "weight_fp32 = weight.float()\n",
    "svd = torch.svd(weight_fp32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import editors, metrics\n",
    "\n",
    "# Ranks to sweep, as plain ints (a torch.arange would make `rank` a 0-d tensor when passed to\n",
    "# the editor and would log as `rank=tensor(250)`).\n",
    "ranks = range(250, 260, 2)\n",
    "\n",
    "experiment_utils.set_seed(71745)\n",
    "\n",
    "# Each test sample is paired with an edit target sample (presumably drawn at random from the\n",
    "# same list); the target's object is what the edit should produce.\n",
    "test_samples = test_relation.samples\n",
    "test_targets = functional.random_edit_targets(test_samples)\n",
    "\n",
    "# h at `layer` and z (z_layer=-1) for every test subject. Computed once, outside the sweep,\n",
    "# because they do not depend on the rank.\n",
    "hs_by_subj, zs_by_subj = functional.compute_hs_and_zs(\n",
    "    mt=mt,\n",
    "    prompt_template=prompt_template,\n",
    "    subjects=[x.subject for x in test_relation.samples],\n",
    "    h_layer=[layer],\n",
    "    z_layer=-1,\n",
    "    batch_size=4,\n",
    "    examples=train_samples,\n",
    ")\n",
    "\n",
    "\n",
    "for rank in ranks:\n",
    "    editor = editors.LowRankPInvEditor(\n",
    "        lre=operator,\n",
    "        rank=rank,\n",
    "        n_samples=1,\n",
    "        n_new_tokens=1,\n",
    "        svd=svd,\n",
    "    )\n",
    "\n",
    "    pred_objects = []\n",
    "    targ_objects = []\n",
    "    for sample in test_samples:\n",
    "        target = test_targets.get(sample)\n",
    "        if target is None:\n",
    "            # Skipped samples drop out of the efficacy denominator, so make them visible.\n",
    "            logger.warning(f\"cannot edit {sample.subject!r}: no edit target, skipping\")\n",
    "            continue\n",
    "\n",
    "        z_original = zs_by_subj[sample.subject]\n",
    "        z_target = zs_by_subj[target.subject]\n",
    "        result = editor(\n",
    "            sample.subject,\n",
    "            target.subject,\n",
    "            z_original=z_original,\n",
    "            z_target=z_target,\n",
    "        )\n",
    "\n",
    "        pred = str(result.predicted_tokens[0])\n",
    "\n",
    "        tick = \"✗\"\n",
    "        if functional.is_nontrivial_prefix(\n",
    "            prediction=result.predicted_tokens[0].token,\n",
    "            target=target.object,\n",
    "        ):\n",
    "            tick = \"✓\"\n",
    "\n",
    "        logger.debug(\n",
    "            f\"editing: {layer=} {rank=} {sample.subject=} | {target.subject=} -> {target.object=} |>> {pred=} ({tick})\"\n",
    "        )\n",
    "\n",
    "        pred_objects.append([p.token for p in result.predicted_tokens])\n",
    "        targ_objects.append(target.object)\n",
    "\n",
    "    efficacy = metrics.recall(pred_objects, targ_objects)\n",
    "    logger.info(\"-\" * 80)\n",
    "    logger.info(f\"editing finished: {layer=} {rank=} {efficacy=}\")\n",
    "    logger.info(\"-\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate relations for a multi-relation sweep; only the uncommented names are active.\n",
    "# NOTE(review): `relation_names` is not referenced by any other cell in this notebook.\n",
    "relation_names = [\n",
    "    # \"name gender\",\n",
    "    # \"occupation gender\",\n",
    "    \"president birth year\", # even the caching failed\n",
    "    \"president election year\", # even the caching failed\n",
    "    # \"person lead singer of band\",\n",
    "    # \"country capital city\",\n",
    "    # \"country language\",\n",
    "    # \"country largest city\",\n",
    "    # \"city in country\",\n",
    "    \"characteristic gender\",\n",
    "    # ---------------------------------\n",
    "    # \"fruit outside color\",\n",
    "    # \"country currency\",\n",
    "    # \"food from country\",\n",
    "    # \"name birthplace\",\n",
    "    # \"name religion\",\n",
    "    # \"task person type\",\n",
    "    # \"fruit inside color\",\n",
    "    \"univ degree gender\",\n",
    "    # \"work location\",\n",
    "    # \"pokemon evolution\",\n",
    "    # \"occupation age\",\n",
    "    # ---------------------------------\n",
    "    # \"substance phase of matter\",\n",
    "    # \"task done by tool\",\n",
    "    # \"word sentiment\",\n",
    "    # \"adjective comparative\",\n",
    "    # \"object superclass\",\n",
    "    # \"verb past tense\",\n",
    "    # \"adjective superlative\",\n",
    "    # \"person university\",\n",
    "    # \"superhero archnemesis\",\n",
    "    # \"superhero person\",\n",
    "    # \"adjective antonym\",\n",
    "    # ---------------------------------\n",
    "    # big relations\n",
    "    # ---------------------------------\n",
    "    # \"word first letter\",\n",
    "    # \"word last letter\",\n",
    "    # \"company CEO\",\n",
    "    # \"plays pro sport\",\n",
    "    # \"star constellation name\",\n",
    "    # \"person plays instrument\",\n",
    "    # \"product by company\",\n",
    "    # \"company hq\",\n",
    "    # \"person occupation\",\n",
    "    # \"landmark in country\",\n",
    "    # \"person native language\",\n",
    "    # \"landmark on continent\",\n",
    "    # \"person sport position\",\n",
    "    # \"person father\",\n",
    "    # \"person mother\",\n",
    "]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fact",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
