{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "parent_dir = os.path.abspath('..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "from nnsight import LanguageModel\n",
    "from activation_utils import SparseAct\n",
    "import torch as t\n",
    "import plotly.graph_objects as go\n",
    "from loading_utils import load_examples\n",
    "from dictionary_learning import AutoEncoder\n",
    "from dictionary_learning.dictionary import IdentityDict\n",
    "from ablation import run_with_ablations\n",
    "from scipy import interpolate\n",
    "import math\n",
    "from statistics import stdev\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the first GPU when CUDA is available; otherwise fall back to the CPU so that\n",
    "# loading the model does not fail outright on a machine without a GPU.\n",
    "device = 'cuda:0' if t.cuda.is_available() else 'cpu'\n",
    "model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device, dispatch=True)\n",
    "\n",
    "# Index of the first transformer layer to explain.\n",
    "# A negative value additionally includes the token embedding.\n",
    "start_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load submodules\n",
    "# `submodules` lists, in layer order, the attention output, MLP output and layer output\n",
    "# of every layer from `start_layer` onwards.\n",
    "submodules = []\n",
    "# a negative start_layer means: also include the token embedding\n",
    "if start_layer < 0: submodules.append(model.gpt_neox.embed_in)\n",
    "# Clamp the first index to 0. With a negative start_layer, range(start_layer, n) would\n",
    "# yield negative indices, and layers[-1] is the *last* layer under Python indexing, so\n",
    "# its submodules would be inserted right after the embedding and then listed again at the end.\n",
    "for i in range(max(start_layer, 0), len(model.gpt_neox.layers)):\n",
    "    submodules.extend([\n",
    "        model.gpt_neox.layers[i].attention,\n",
    "        model.gpt_neox.layers[i].mlp,\n",
    "        model.gpt_neox.layers[i]\n",
    "    ])\n",
    "\n",
    "# Name of every submodule of the model (not only those in `submodules`); these names are\n",
    "# the keys under which node scores are looked up in the saved circuits.\n",
    "submod_names = {\n",
    "    model.gpt_neox.embed_in : 'embed'\n",
    "}\n",
    "for i in range(len(model.gpt_neox.layers)):\n",
    "    submod_names[model.gpt_neox.layers[i].attention] = f'attn_{i}'\n",
    "    submod_names[model.gpt_neox.layers[i].mlp] = f'mlp_{i}'\n",
    "    submod_names[model.gpt_neox.layers[i]] = f'resid_{i}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load dictionaries\n",
    "dict_id = 10\n",
    "\n",
    "# dictionary size = expansion factor x width of the activations being decomposed\n",
    "activation_dim = 512\n",
    "expansion_factor = 64\n",
    "dict_size = expansion_factor * activation_dim\n",
    "\n",
    "# pretrained autoencoder for the embedding and for the attention / MLP / layer output of every layer\n",
    "feat_dicts = {\n",
    "    model.gpt_neox.embed_in : AutoEncoder.from_pretrained(\n",
    "        f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_{dict_size}/ae.pt', device=device\n",
    "    )\n",
    "}\n",
    "for i in range(len(model.gpt_neox.layers)):\n",
    "    layer = model.gpt_neox.layers[i]\n",
    "    feat_dicts.update({\n",
    "        layer.attention : AutoEncoder.from_pretrained(\n",
    "            f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device\n",
    "        ),\n",
    "        layer.mlp : AutoEncoder.from_pretrained(\n",
    "            f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device\n",
    "        ),\n",
    "        layer : AutoEncoder.from_pretrained(\n",
    "            f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device\n",
    "        ),\n",
    "    })\n",
    "\n",
    "# neuron baseline: an identity dictionary for each analysed submodule\n",
    "neuron_dicts = {\n",
    "    submod : IdentityDict(activation_dim).to(device)\n",
    "    for submod in submodules\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use mean ablation\n",
    "def ablation_fn(x):\n",
    "    \"\"\"Return the mean of `x` over dim 0, expanded back to the shape of `x`.\n",
    "\n",
    "    `x` only needs to provide `mean(dim=...)` and `expand_as(...)`.\n",
    "    \"\"\"\n",
    "    averaged = x.mean(dim=0)\n",
    "    return averaged.expand_as(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get m(C) for the circuits obtained by thresholding nodes at each of the given thresholds\n",
    "def get_fcs(\n",
    "        dataset,\n",
    "        model,\n",
    "        submodules,\n",
    "        dictionaries,\n",
    "        ablation_fn,\n",
    "        thresholds,\n",
    "        length,\n",
    "        handle_errors = 'default', # also 'remove' or 'resid_only'\n",
    "        use_neurons = False,\n",
    "        random = False\n",
    "):\n",
    "    \"\"\"Score the circuits obtained by thresholding saved node scores for one dataset.\n",
    "\n",
    "    Relies on the notebook globals `device`, `dict_size`, `activation_dim` and `submod_names`.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset name; selects both the circuit file under ../circuits and the\n",
    "            `{dataset}_test.json` examples file.\n",
    "        model: model to evaluate; must expose `trace`, `embed_out` and `gpt_neox.layers`.\n",
    "        submodules: submodules forwarded to run_with_ablations; one node mask is built per entry.\n",
    "        dictionaries: mapping from submodule to dictionary, forwarded to run_with_ablations.\n",
    "        ablation_fn: forwarded to run_with_ablations.\n",
    "        thresholds: iterable of thresholds; a node is selected when the absolute value of\n",
    "            its saved score is strictly greater than the threshold.\n",
    "        length: forwarded to load_examples as `length`.\n",
    "        handle_errors: 'default' keeps the thresholded `resc` masks, 'remove' sets every\n",
    "            `resc` mask to False, 'resid_only' sets it to False for every submodule that\n",
    "            is not an element of model.gpt_neox.layers.\n",
    "        use_neurons: if True, load the `dictid` circuit file and size the empty masks with\n",
    "            activation_dim instead of dict_size.\n",
    "        random: if True, replace every thresholded `act` mask by a Bernoulli sample whose\n",
    "            rate is the thresholded circuit's overall node fraction, and set every `resc`\n",
    "            mask to True.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'fm' (mean metric of the unablated model on the clean inputs), 'fempty'\n",
    "        (mean metric with all-False node masks) and, for each threshold, a dict with the\n",
    "        keys 'n_nodes', 'fc', 'fccomp', 'faithfulness' and 'completeness'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if handle_errors is not one of the three supported modes.\n",
    "    \"\"\"\n",
    "    # fail fast on a misspelled mode instead of silently evaluating the 'default' setting\n",
    "    if handle_errors not in ('default', 'remove', 'resid_only'):\n",
    "        raise ValueError(f'unknown handle_errors mode: {handle_errors!r}')\n",
    "\n",
    "    # load data\n",
    "    if not use_neurons:\n",
    "        circuit = t.load(f'../circuits/{dataset}_train_dict10_node0.1_edge0.01_n100_aggnone.pt')['nodes']\n",
    "    else:\n",
    "        circuit = t.load(f'../circuits/{dataset}_train_dictid_node0.1_edge0.01_n100_aggnone.pt')['nodes']\n",
    "    # NOTE(review): 40 is presumably the number of examples to load - confirm against load_examples\n",
    "    examples = load_examples(f'/share/projects/dictionary_circuits/data/phenomena/{dataset}_test.json', 40, model, length=length)\n",
    "    # use the notebook-level `device` (as the node masks below do) rather than a hard-coded\n",
    "    # 'cuda:0', so inputs, masks, model and dictionaries always share one device\n",
    "    clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to(device)\n",
    "    clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device=device)\n",
    "    patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to(device)\n",
    "    patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device=device)\n",
    "    def metric_fn(model):\n",
    "        # per example: final-position output at the clean answer minus that at the patch answer\n",
    "        return (\n",
    "            - t.gather(model.embed_out.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) +\n",
    "            t.gather(model.embed_out.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)\n",
    "        )\n",
    "    \n",
    "    with t.no_grad():\n",
    "        out = {}\n",
    "\n",
    "        # get F(M)\n",
    "        with model.trace(clean_inputs):\n",
    "            metric = metric_fn(model).save()\n",
    "        fm = metric.value.mean().item()\n",
    "\n",
    "        out['fm'] = fm\n",
    "\n",
    "        # get m(∅): every node mask is all-False\n",
    "        fempty = run_with_ablations(\n",
    "            clean_inputs,\n",
    "            patch_inputs,\n",
    "            model,\n",
    "            submodules,\n",
    "            dictionaries,\n",
    "            nodes = {\n",
    "                submod : SparseAct(\n",
    "                    act=t.zeros(dict_size if not use_neurons else activation_dim, dtype=t.bool), \n",
    "                    resc=t.zeros(1, dtype=t.bool)).to(device)\n",
    "                for submod in submodules\n",
    "            },\n",
    "            metric_fn=metric_fn,\n",
    "            ablation_fn=ablation_fn,\n",
    "        ).mean().item()\n",
    "        out['fempty'] = fempty\n",
    "\n",
    "        for threshold in thresholds:\n",
    "            out[threshold] = {}\n",
    "            # boolean masks of the nodes whose absolute saved score exceeds the threshold\n",
    "            nodes = {\n",
    "                submod : circuit[submod_names[submod]].abs() > threshold for submod in submodules\n",
    "            }\n",
    "\n",
    "            if handle_errors == 'remove':\n",
    "                for k in nodes: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)\n",
    "            elif handle_errors == 'resid_only':\n",
    "                for k in nodes:\n",
    "                    if k not in model.gpt_neox.layers: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)\n",
    "\n",
    "            n_nodes = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()\n",
    "            if random:\n",
    "                # sample `act` masks at the same overall rate as the thresholded circuit\n",
    "                total_nodes = sum([n.act.numel() + n.resc.numel() for n in nodes.values()])\n",
    "                p = n_nodes / total_nodes\n",
    "                for k in nodes:\n",
    "                    nodes[k].act = t.bernoulli(t.ones_like(nodes[k].act, dtype=t.float) * p).to(device).to(dtype=t.bool)\n",
    "                    nodes[k].resc = t.ones_like(nodes[k].resc, dtype=t.bool).to(device)\n",
    "                # report the size of the sampled masks, not of the thresholded circuit\n",
    "                out[threshold]['n_nodes'] = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()\n",
    "            else:\n",
    "                out[threshold]['n_nodes'] = n_nodes\n",
    "            \n",
    "            out[threshold]['fc'] = run_with_ablations(\n",
    "                clean_inputs,\n",
    "                patch_inputs,\n",
    "                model,\n",
    "                submodules,\n",
    "                dictionaries,\n",
    "                nodes=nodes,\n",
    "                metric_fn=metric_fn,\n",
    "                ablation_fn=ablation_fn,\n",
    "            ).mean().item()\n",
    "            out[threshold]['fccomp'] = run_with_ablations(\n",
    "                clean_inputs,\n",
    "                patch_inputs,\n",
    "                model,\n",
    "                submodules,\n",
    "                dictionaries,\n",
    "                nodes=nodes,\n",
    "                metric_fn=metric_fn,\n",
    "                ablation_fn=ablation_fn,\n",
    "                complement=True\n",
    "            ).mean().item()\n",
    "            # both scores are normalised by the gap between the full model and the empty circuit\n",
    "            out[threshold]['faithfulness'] = (out[threshold]['fc'] - fempty) / (fm - fempty)\n",
    "            out[threshold]['completeness'] = (out[threshold]['fccomp'] - fempty) / (fm - fempty)\n",
    "    return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset : number of tokens in inputs from dataset\n",
    "datasets = {\n",
    "    'rc' : 6,\n",
    "    'nounpp' : 5,\n",
    "    'simple' : 2,\n",
    "    'within_rc' : 5\n",
    "}\n",
    "thresholds = t.logspace(-4, 0, 15, 10).tolist()\n",
    "\n",
    "# setting : (dictionaries, extra keyword arguments for get_fcs)\n",
    "settings = {\n",
    "    'features' : (feat_dicts, {}),\n",
    "    'features_wo_errs' : (feat_dicts, {'handle_errors' : 'remove'}),\n",
    "    'features_wo_some_errs' : (feat_dicts, {'handle_errors' : 'resid_only'}),\n",
    "    'neurons' : (neuron_dicts, {'use_neurons' : True}),\n",
    "}\n",
    "\n",
    "# outs[setting][dataset] is the get_fcs result for that setting and dataset\n",
    "outs = {\n",
    "    setting : {\n",
    "        dataset : get_fcs(\n",
    "            dataset,\n",
    "            model,\n",
    "            submodules,\n",
    "            dictionaries,\n",
    "            ablation_fn=ablation_fn,\n",
    "            thresholds=thresholds,\n",
    "            length=length,\n",
    "            **extra_kwargs\n",
    "        ) for dataset, length in datasets.items()\n",
    "    }\n",
    "    for setting, (dictionaries, extra_kwargs) in settings.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot faithfulness results\n",
    "fig = go.Figure()\n",
    "\n",
    "# one colour per setting\n",
    "colors = {\n",
    "    'features' : 'blue',\n",
    "    'features_wo_errs' : 'red',\n",
    "    'features_wo_some_errs' : 'green',\n",
    "    'neurons' : 'purple',\n",
    "}\n",
    "\n",
    "for setting, subouts in outs.items():\n",
    "    # per-dataset curves: node count and faithfulness at every threshold\n",
    "    node_counts = {\n",
    "        dataset : [subouts[dataset][thr]['n_nodes'] for thr in thresholds]\n",
    "        for dataset in datasets\n",
    "    }\n",
    "    faith = {\n",
    "        dataset : [subouts[dataset][thr]['faithfulness'] for thr in thresholds]\n",
    "        for dataset in datasets\n",
    "    }\n",
    "\n",
    "    # node-count interval covered by every dataset, shrunk by one node on each side\n",
    "    x_min = max(min(counts) for counts in node_counts.values()) + 1\n",
    "    x_max = min(max(counts) for counts in node_counts.values()) - 1\n",
    "    fs = {\n",
    "        dataset : interpolate.interp1d(node_counts[dataset], faith[dataset])\n",
    "        for dataset in datasets\n",
    "    }\n",
    "    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()\n",
    "\n",
    "    # faint line for each individual dataset\n",
    "    for dataset in datasets:\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x = node_counts[dataset],\n",
    "            y = faith[dataset],\n",
    "            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False\n",
    "        ))\n",
    "\n",
    "    # solid line: average of the interpolated per-dataset curves\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=xs,\n",
    "        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],\n",
    "        mode='lines', line=dict(color=colors[setting]), name=setting\n",
    "    ))\n",
    "\n",
    "fig.update_xaxes(range=(0, 1700))\n",
    "fig.update_yaxes(range=(0, 1.1))\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis_title='Nodes',\n",
    "    yaxis_title='Faithfulness',\n",
    "    width=800,\n",
    "    height=375,\n",
    "    # fully transparent plot background\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    # grey gridlines, mirrored axis lines and outside ticks on both axes\n",
    "    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),\n",
    "    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),\n",
    ")\n",
    "\n",
    "# fig.show()\n",
    "fig.write_image('faithfulness.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot completeness results\n",
    "fig = go.Figure()\n",
    "\n",
    "# one colour per setting\n",
    "colors = {\n",
    "    'features' : 'blue',\n",
    "    'features_wo_errs' : 'red',\n",
    "    'features_wo_some_errs' : 'green',\n",
    "    'neurons' : 'purple'\n",
    "}\n",
    "\n",
    "for setting, subouts in outs.items():\n",
    "\n",
    "    # node-count interval covered by every dataset, shrunk by one node on each side\n",
    "    x_min = max([min(subouts[dataset][thr]['n_nodes'] for thr in thresholds) for dataset in datasets]) + 1\n",
    "    x_max = min([max(subouts[dataset][thr]['n_nodes'] for thr in thresholds) for dataset in datasets]) - 1\n",
    "    # per-dataset interpolation of completeness as a function of node count\n",
    "    fs = {\n",
    "        dataset : interpolate.interp1d([subouts[dataset][thr]['n_nodes'] for thr in thresholds], [subouts[dataset][thr]['completeness'] for thr in thresholds])\n",
    "        for dataset in datasets\n",
    "    }\n",
    "    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()\n",
    "\n",
    "    # faint line for each individual dataset\n",
    "    for dataset in datasets:\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x = [subouts[dataset][thr]['n_nodes'] for thr in thresholds],\n",
    "            y = [subouts[dataset][thr]['completeness'] for thr in thresholds],\n",
    "            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False\n",
    "        ))\n",
    "    # solid line: average of the interpolated per-dataset curves\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=xs,\n",
    "        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],\n",
    "        mode='lines', line=dict(color=colors[setting]), name=setting\n",
    "    ))\n",
    "\n",
    "fig.update_xaxes(range=(0,300))\n",
    "fig.update_yaxes(range=(-.15, 1))\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis_title='Nodes',\n",
    "    # this figure plots the 'completeness' scores, so label the axis accordingly\n",
    "    yaxis_title='Completeness',\n",
    "    width=800,\n",
    "    height=375,\n",
    "    # fully transparent plot background\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    # grey gridlines, mirrored axis lines and outside ticks on both axes\n",
    "    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),\n",
    "    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),\n",
    ")\n",
    "# fig.show()\n",
    "fig.write_image('completeness.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
