{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00803dd-6fa3-4f32-9b88-254c50c54c1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "import os\n",
    "os.environ['HF_HOME'] = CACHE_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f442e504-0c85-4cb5-bb5a-3711c0726562",
   "metadata": {},
   "source": [
    "# Sparse Probing\n",
    "\n",
    "Eval just one SAE first.\n",
    "\n",
    "To get activations, just use `sparse_probing.main.get_dataset_activations(dataset_name, config, model, batch size, stop_layer, hook_point, device)`. \n",
    "Then you have two sets of activations to train probes on:\n",
    "1. Baseline: `activation_collection.create_meaned_model_activations(above)` (this uses the raw model activations).\n",
    "   - You can try this with and without selecting the top k dimensions.\n",
    "2. Actual SAE: `activation_collection.get_sae_meaned_activations(above, sae, sae batch size)`; move everything to the CPU afterwards.\n",
    "\n",
    "After that, the activations can be plugged into the rest of the pipeline.\n",
    "\n",
    "The probe call does a few things (we're just gonna use sklearn):\n",
    "```\n",
    "for \"profession\" in train_activations.keys():\n",
    "    prepare_probe_data([train,test]_activations, profession, False)\n",
    "```\n",
    "\n",
    "To get top k: \n",
    "- `get_top_k_mean_diff_mask(train_acts, train_labels, k)` returns a mask\n",
    "- `apply_topk_mask_reduce_dim(acts, mask)` returns the activations restricted to the dimensions selected by the mask\n",
    "- For the transfer experiments, we reuse the masks computed on 70m when probing 160m\n",
    "\n",
    "Probing: `train_sklearn_probe(train_acts, train_labels, test_acts, test_labels)` returns probe and test accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbd6ab7d-8db6-4f05-9cab-20084f617546",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "from sae_lens import SAE\n",
    "from sae import Sae\n",
    "import torch\n",
    "from stitching.sae_utils import convert_eleuther_sae_to_BaseSAE\n",
    "# SAEBench stuff\n",
    "from sae_bench.evals.sparse_probing.eval_config import SparseProbingEvalConfig\n",
    "from sae_bench.evals.sparse_probing.main import get_dataset_activations\n",
    "import sae_bench.sae_bench_utils.activation_collection as activation_collection\n",
    "from sae_bench.evals.sparse_probing.probe_training import prepare_probe_data, get_top_k_mean_diff_mask, apply_topk_mask_reduce_dim, train_sklearn_probe\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c30826bd-b684-4430-ba68-f65f69b249d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = SparseProbingEvalConfig(model_name='pythia-70m-deduped')\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a3288e-4d88-4f9a-a52e-84ac0eb8a794",
   "metadata": {},
   "outputs": [],
   "source": [
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af78f093-a886-4ad9-b618-3fa342f12b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "config.llm_batch_size = 20\n",
    "config.sae_batch_size = 64\n",
    "config.lower_vram_usage = True\n",
    "config.k_values = [1,2,5,10,20,50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a40c45cd-8ad0-4ad8-a91f-7f0be153dd12",
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_70m = HookedTransformer.from_pretrained('pythia-70m-deduped', cache_dir=CACHE_DIR, device=device)\n",
    "layer_70m = 3\n",
    "pythia_160m = HookedTransformer.from_pretrained('pythia-160m-deduped', cache_dir=CACHE_DIR, device=device)\n",
    "layer_160m = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a027023-3916-49d6-a51b-41e66d90f339",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_eleuther_sae(repo_id, layer):\n",
    "    \"\"\"Load an EleutherAI SAE from the hub and wrap it with `convert_eleuther_sae_to_BaseSAE`.\n",
    "\n",
    "    The hookpoint requested is `layers.{layer-1}`. After conversion, `dtype` and `device`\n",
    "    attributes are copied from the encoder weights, then `normalize_decoder_vectors()` and\n",
    "    `get_rid_of_decoder_sub()` are called, in that order.\n",
    "    \"\"\"\n",
    "    hub_sae = Sae.load_from_hub(repo_id, hookpoint=f\"layers.{layer - 1}\").to(device)\n",
    "    wrapped = convert_eleuther_sae_to_BaseSAE(hub_sae)\n",
    "    wrapped.dtype = wrapped.W_enc.dtype\n",
    "    wrapped.device = wrapped.W_enc.device\n",
    "    wrapped.normalize_decoder_vectors()\n",
    "    wrapped.get_rid_of_decoder_sub()\n",
    "    return wrapped\n",
    "\n",
    "sae = _load_eleuther_sae(\"EleutherAI/sae-pythia-70m-deduped-32k\", layer_70m)\n",
    "sae_160m = _load_eleuther_sae(\"EleutherAI/sae-pythia-160m-deduped-32k\", layer_160m)\n",
    "#sae, cfg_dict, _ = SAE.from_pretrained(\n",
    "#    release = \"pythia-70m-deduped-res-sm\", # see other options in sae_lens/pretrained_saes.yaml\n",
    "#    sae_id = f\"blocks.{layer_70m-1}.hook_resid_post\", # won't always be a hook point\n",
    "#    device = device\n",
    "#)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aeb2210-86ef-48ea-88f1-5ed4df897689",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a87fa39-869b-4d38-9b7d-fa772fc0acad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_existing_probes_and_masks(\n",
    "    activations,\n",
    "    custom_masks,\n",
    "    custom_probes\n",
    "):\n",
    "    \"\"\"Score already-trained probes on new activations, without any retraining.\n",
    "\n",
    "    For every key of `activations`, the mask stored under the same key in `custom_masks`\n",
    "    is applied with `apply_topk_mask_reduce_dim`, and the probe stored under the same key\n",
    "    in `custom_probes` predicts the labels. Returns a dict mapping each key to the\n",
    "    resulting accuracy.\n",
    "    \"\"\"\n",
    "    def _score(profession):\n",
    "        features, targets = prepare_probe_data(activations, profession, False)\n",
    "        reduced = apply_topk_mask_reduce_dim(features, custom_masks[profession])\n",
    "        # The stored probes are queried through `.predict` on NumPy arrays.\n",
    "        reduced_np = reduced.float().cpu().numpy()\n",
    "        targets_np = targets.float().cpu().numpy()\n",
    "        return accuracy_score(targets_np, custom_probes[profession].predict(reduced_np))\n",
    "\n",
    "    return {profession: _score(profession) for profession in activations.keys()}\n",
    "\n",
    "def custom_train_probe(\n",
    "    train_activations,\n",
    "    test_activations,\n",
    "    custom_masks = None,\n",
    "    select_top_k=None,\n",
    "    verbose: bool = False,\n",
    "    l1_penalty = None,\n",
    "):\n",
    "    \"\"\"Train one sklearn probe per profession and evaluate it on the test split.\n",
    "\n",
    "    Feature selection, in order of precedence:\n",
    "    - `custom_masks` given: `custom_masks[profession]` is applied to both splits and\n",
    "      `select_top_k` is ignored. A `None` entry means no masking for that profession.\n",
    "    - `select_top_k` given: a mask is computed with `get_top_k_mean_diff_mask` on the\n",
    "      training split and applied to both splits.\n",
    "    - neither: the probe is trained on all dimensions and the recorded mask is `None`.\n",
    "\n",
    "    `l1_penalty` is accepted for signature compatibility but is not used.\n",
    "\n",
    "    Returns `(probes, test_accuracies, masks)`, three dicts keyed by profession. `masks`\n",
    "    is always a fresh dict holding the mask that was actually used for each profession.\n",
    "    \"\"\"\n",
    "    probes, test_accuracies, masks = {}, {}, {}\n",
    "\n",
    "    for profession in train_activations.keys():\n",
    "        train_acts, train_labels = prepare_probe_data(train_activations, profession, False)\n",
    "        test_acts, test_labels = prepare_probe_data(test_activations, profession, False)\n",
    "        if verbose:\n",
    "            # Shape dump is debug output: keep it behind the flag so sweeps over k stay quiet.\n",
    "            print(train_acts.shape, test_acts.shape)\n",
    "\n",
    "        if custom_masks is not None:\n",
    "            if verbose:\n",
    "                print(\"Mask provided, using this instead of select_top_k value\")\n",
    "            activation_mask_D = custom_masks[profession]\n",
    "        elif select_top_k is not None:\n",
    "            activation_mask_D = get_top_k_mean_diff_mask(train_acts, train_labels, select_top_k)\n",
    "        else:\n",
    "            activation_mask_D = None\n",
    "\n",
    "        if activation_mask_D is not None:\n",
    "            train_acts = apply_topk_mask_reduce_dim(train_acts, activation_mask_D)\n",
    "            test_acts = apply_topk_mask_reduce_dim(test_acts, activation_mask_D)\n",
    "        masks[profession] = activation_mask_D\n",
    "\n",
    "        activation_dim = train_acts.shape[1]\n",
    "        if verbose:\n",
    "            print(f\"Num non-zero elements: {activation_dim}\")\n",
    "\n",
    "        probe, test_accuracy = train_sklearn_probe(\n",
    "            train_acts,\n",
    "            train_labels,\n",
    "            test_acts,\n",
    "            test_labels,\n",
    "            verbose=False,\n",
    "        )\n",
    "        if verbose:\n",
    "            print(f\"Test accuracy for {profession}: {test_accuracy}\")\n",
    "\n",
    "        probes[profession] = probe\n",
    "        test_accuracies[profession] = test_accuracy\n",
    "\n",
    "    return probes, test_accuracies, masks\n",
    "\n",
    "def PCA_probe(\n",
    "    train_activations,\n",
    "    test_activations,\n",
    "    select_top_k=None,\n",
    "    verbose: bool = False,\n",
    "    l1_penalty = None,\n",
    "):\n",
    "    \"\"\"Train one sklearn probe per profession, optionally on a leading slice of the dimensions.\n",
    "\n",
    "    When `select_top_k` is given, a boolean mask is built that is False on `[:select_top_k]`\n",
    "    and True elsewhere, and is applied to both splits with `apply_topk_mask_reduce_dim`\n",
    "    (presumably keeping the dimensions where the mask is False -- verify against that helper).\n",
    "    Otherwise all dimensions are used and the recorded mask is None.\n",
    "\n",
    "    NOTE(review): the name suggests the inputs are expected to be PCA-projected so that the\n",
    "    leading dimensions are the principal components; nothing in this function enforces that.\n",
    "\n",
    "    `l1_penalty` is accepted but not used.\n",
    "\n",
    "    Returns `(probes, test_accuracies, masks)`, three dicts keyed by profession.\n",
    "    \"\"\"\n",
    "    probes, test_accuracies, masks = {}, {}, {}\n",
    "    use_leading_dims = select_top_k is not None\n",
    "\n",
    "    for profession in train_activations.keys():\n",
    "        fit_X, fit_y = prepare_probe_data(train_activations, profession, False)\n",
    "        eval_X, eval_y = prepare_probe_data(test_activations, profession, False)\n",
    "\n",
    "        leading_mask = None\n",
    "        if use_leading_dims:\n",
    "            leading_mask = torch.ones(fit_X.shape[1], dtype=torch.bool, device=fit_X.device)\n",
    "            leading_mask[:select_top_k] = False\n",
    "            fit_X = apply_topk_mask_reduce_dim(fit_X, leading_mask)\n",
    "            eval_X = apply_topk_mask_reduce_dim(eval_X, leading_mask)\n",
    "        masks[profession] = leading_mask\n",
    "\n",
    "        kept_dims = fit_X.shape[1]\n",
    "        if verbose:\n",
    "            print(f\"Num non-zero elements: {kept_dims}\")\n",
    "\n",
    "        fitted_probe, accuracy = train_sklearn_probe(\n",
    "            fit_X,\n",
    "            fit_y,\n",
    "            eval_X,\n",
    "            eval_y,\n",
    "            verbose=False,\n",
    "        )\n",
    "        if verbose:\n",
    "            print(f\"Test accuracy for {profession}: {accuracy}\")\n",
    "\n",
    "        probes[profession] = fitted_probe\n",
    "        test_accuracies[profession] = accuracy\n",
    "\n",
    "    return probes, test_accuracies, masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67addb2c-2c00-41ae-9838-a5b30a349949",
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_test_accuracy(test_accuracies: dict[str, float]) -> float:\n",
    "    return sum(test_accuracies.values()) / len(test_accuracies)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b0571ed-5713-40bf-9b3c-a288243c0f6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d304fe4-c2ec-4f4c-be9c-79f8f42311c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf116af-6663-41b3-8f3f-f0ba574818ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 70m baseline sweep. For every dataset: probe the meaned model activations, then for each k\n",
    "# probe the SAE features selected by mean difference and, as a control, k randomly chosen SAE\n",
    "# features. The top-k masks and probes are stored so later cells can reuse them.\n",
    "results_dict = {}\n",
    "masks_dict = {}\n",
    "probes_dict = {}\n",
    "for dataset_name in config.dataset_names:\n",
    "    print(\"Running dataset\", dataset_name)\n",
    "    all_train_acts_BLD, all_test_acts_BLD = get_dataset_activations(\n",
    "        dataset_name,\n",
    "        config,\n",
    "        pythia_70m,\n",
    "        config.llm_batch_size,\n",
    "        layer_70m,\n",
    "        f\"blocks.{layer_70m-1}.hook_resid_post\",\n",
    "        device,  # the device the models were loaded on, rather than a hard-coded 'cuda'\n",
    "    )\n",
    "    dataset_results_dict = {}\n",
    "    masks_results_dict = {}\n",
    "    probes_results_dict = {}\n",
    "\n",
    "    # LLM baseline: probe trained on every dimension of the meaned model activations.\n",
    "    all_train_acts_BD = activation_collection.create_meaned_model_activations(\n",
    "        all_train_acts_BLD\n",
    "    )\n",
    "    all_test_acts_BD = activation_collection.create_meaned_model_activations(all_test_acts_BLD)\n",
    "    llm_probes, llm_test_accuracies, _ = custom_train_probe(\n",
    "        all_train_acts_BD,\n",
    "        all_test_acts_BD,\n",
    "        select_top_k=None,\n",
    "        verbose=False\n",
    "    )\n",
    "    dataset_results_dict['llm_accuracy'] = average_test_accuracy(\n",
    "        llm_test_accuracies\n",
    "    )\n",
    "\n",
    "    # SAE features for both splits, moved to the CPU before probe training.\n",
    "    all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_train_acts_BLD, sae, config.sae_batch_size\n",
    "    )\n",
    "    all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_test_acts_BLD, sae, config.sae_batch_size\n",
    "    )\n",
    "    for key in all_sae_train_acts_BF.keys():\n",
    "        all_sae_train_acts_BF[key] = all_sae_train_acts_BF[key].cpu()\n",
    "        all_sae_test_acts_BF[key] = all_sae_test_acts_BF[key].cpu()\n",
    "\n",
    "    print(\"Training probes.\")\n",
    "    for k in tqdm(config.k_values):\n",
    "        sae_top_k_probes, sae_top_k_test_accuracies, sae_top_k_masks = custom_train_probe(\n",
    "            all_sae_train_acts_BF,\n",
    "            all_sae_test_acts_BF,\n",
    "            select_top_k=k,\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        # Random control: same mask shape/dtype as the top-k mask, with k random positions set\n",
    "        # to 0. Presumably 0/False marks a feature that is kept -- verify against\n",
    "        # apply_topk_mask_reduce_dim. Draw order follows the dict order, so results stay\n",
    "        # reproducible under the NumPy seed.\n",
    "        random_masks = {}\n",
    "        for key, top_k_mask in sae_top_k_masks.items():\n",
    "            random_mask = torch.ones_like(top_k_mask)\n",
    "            indices = np.random.choice(random_mask.shape[0], k, replace=False)\n",
    "            random_mask[indices] = 0\n",
    "            random_masks[key] = random_mask\n",
    "\n",
    "        random_sae_top_k_probes, random_sae_top_k_test_accuracies, _ = custom_train_probe(\n",
    "            all_sae_train_acts_BF,\n",
    "            all_sae_test_acts_BF,\n",
    "            custom_masks = random_masks,\n",
    "            select_top_k=None,\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        dataset_results_dict[f\"sae_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            sae_top_k_test_accuracies\n",
    "        )\n",
    "        dataset_results_dict[f\"random_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            random_sae_top_k_test_accuracies\n",
    "        )\n",
    "        masks_results_dict[f\"sae_top_{k}_masks\"] = sae_top_k_masks\n",
    "        probes_results_dict[f\"sae_top_{k}_probes\"] = sae_top_k_probes\n",
    "\n",
    "    results_dict[dataset_name] = dataset_results_dict\n",
    "    masks_dict[dataset_name] = masks_results_dict\n",
    "    probes_dict[dataset_name] = probes_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e11e559-e7fe-42a9-90ff-952291e4518a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "results_dict  #,# masks_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc71795-9c50-4df4-84c3-38c3fff37603",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d807858a-6d10-4d96-81b8-70efad31abeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so a long run is not lost to a missing folder.\n",
    "results_path_70m = 'sparse_probing/eleuther/70m_results_dict.json'\n",
    "os.makedirs(os.path.dirname(results_path_70m), exist_ok=True)\n",
    "with open(results_path_70m, 'w') as file:\n",
    "    json.dump(results_dict, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab364608-2c8b-4610-96ee-4af1eb96cb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.stitching_utils import open_experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b511e04-35cf-4498-a307-e84b55c9da68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained stitching maps between the two models' activations.\n",
    "# NOTE(review): this cell spells the run 'tough_forest_11' (underscores) while a later cell opens\n",
    "# 'tough-forest-11' (hyphens) -- confirm which spelling matches the checkpoint on disk.\n",
    "exp_name = 'tough_forest_11'\n",
    "project_name = f\"stitch_training_{pythia_70m.cfg.model_name}_to_{pythia_160m.cfg.model_name}_bidirectional_mse\"\n",
    "checkpoints_dir = os.path.join('checkpoints/', f\"{project_name}/\")\n",
    "# 512 and 768 are presumably the residual widths of pythia-70m and pythia-160m -- TODO confirm.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(512, 768, checkpoints_dir, exp_name, biases=True, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4344e1e8-cd3e-4ee8-9d73-fd11285d7b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No-retrain transfer: map 160m activations through the stitch, encode them with the 70m SAE,\n",
    "# and score the probes/masks that were fitted on 70m. Requires masks_dict and probes_dict from\n",
    "# the 70m probe-training cell.\n",
    "transfer_results_dict = {}\n",
    "for dataset_name in config.dataset_names:\n",
    "    print(\"Running dataset\", dataset_name)\n",
    "    all_train_acts_BLD, all_test_acts_BLD = get_dataset_activations(\n",
    "        dataset_name,\n",
    "        config,\n",
    "        pythia_160m,\n",
    "        config.llm_batch_size,\n",
    "        layer_160m,\n",
    "        f\"blocks.{layer_160m-1}.hook_resid_post\",\n",
    "        device,  # the device the models were loaded on, rather than a hard-coded 'cuda'\n",
    "    )\n",
    "\n",
    "    # Nothing is trained in this cell, so only the test split is transferred and encoded;\n",
    "    # the train-split SAE activations were previously computed and never read.\n",
    "    transferred_all_test_acts_BLD = {\n",
    "        key: tensor @ Pinv + biasinv for key, tensor in all_test_acts_BLD.items()\n",
    "    }\n",
    "    transferred_all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        transferred_all_test_acts_BLD, sae, config.sae_batch_size\n",
    "    )\n",
    "    transferred_all_sae_test_acts_BF = {\n",
    "        key: acts.cpu() for key, acts in transferred_all_sae_test_acts_BF.items()\n",
    "    }\n",
    "\n",
    "    dataset_results_dict = {}\n",
    "    print(\"Applying existing probes.\")\n",
    "    for k in tqdm(config.k_values):\n",
    "        sae_custom_masks = masks_dict[dataset_name][f\"sae_top_{k}_masks\"]\n",
    "        sae_custom_probes = probes_dict[dataset_name][f\"sae_top_{k}_probes\"]\n",
    "        test_accuracies = apply_existing_probes_and_masks(\n",
    "            transferred_all_sae_test_acts_BF,\n",
    "            sae_custom_masks,\n",
    "            sae_custom_probes\n",
    "        )\n",
    "        dataset_results_dict[f\"sae_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            test_accuracies\n",
    "        )\n",
    "    transfer_results_dict[dataset_name] = dataset_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3af4ffdf-548c-4511-9fa4-c4ec9f754280",
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b17c180f-1d76-4e2d-a5a1-46f635c24b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so a long run is not lost to a missing folder.\n",
    "no_retrain_results_path = 'sparse_probing/eleuther/160m_no_retrain_stitch_results_dict_no_inverse_penalty.json'\n",
    "os.makedirs(os.path.dirname(no_retrain_results_path), exist_ok=True)\n",
    "with open(no_retrain_results_path, 'w') as file:\n",
    "    json.dump(transfer_results_dict, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4f9b04d-d193-4dd5-8d1e-20e64d65816e",
   "metadata": {},
   "source": [
    "# Sparse Probing on Transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712bc9f4-c8f5-41eb-b73c-6f64821f1c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.stitching_utils import open_experiment, load_activation_store\n",
    "from stitching.sae_utils import BaseSAE, forward_modified, generate_modified, max_csim_transfer_to_orig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2bd1d1-b219-4591-b526-15d41e7e1cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_160m = HookedTransformer.from_pretrained('pythia-160m-deduped', cache_dir=CACHE_DIR, device=device)\n",
    "layer_160m = 4\n",
    "exp_name = 'tough-forest-11'\n",
    "# `model_cfg` is never defined in this notebook (NameError on a fresh kernel); derive the\n",
    "# project name from the loaded models, exactly as the earlier stitch-loading cell does.\n",
    "project_name = f\"stitch_training_{pythia_70m.cfg.model_name}_to_{pythia_160m.cfg.model_name}_bidirectional_mse\"\n",
    "checkpoints_dir = os.path.join('checkpoints/', f\"{project_name}/\")\n",
    "#run_name_id = 'tough-forest-11'\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(512, 768, checkpoints_dir, exp_name, device=device, biases=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9f2699-bdde-4f77-93f7-29ae964f6d8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build two SAEs: `orig_sae`, a copy of the 70m SAE's parameters, and `transferred_sae`, whose\n",
    "# weights absorb the stitch maps so it can be applied to the other model's activations directly.\n",
    "\n",
    "# Parameters are detached and cloned, so orig_sae's tensors do not share storage with `sae`.\n",
    "orig_sae = BaseSAE(\n",
    "    sae.W_enc.detach().clone(),\n",
    "    sae.W_dec.detach().clone(),\n",
    "    sae.b_enc.detach().clone(),\n",
    "    sae.b_dec.detach().clone(),\n",
    "    sae.activation_fn,\n",
    "    apply_b_dec=sae.apply_b_dec\n",
    ")\n",
    "orig_sae.normalize_decoder_vectors()\n",
    "orig_sae.get_rid_of_decoder_sub()\n",
    "# Weight folding, assuming the encoder pre-activation is x @ W_enc + b_enc and the decoder is\n",
    "# f @ W_dec + b_dec (TODO confirm against BaseSAE):\n",
    "#   encoder: (x @ Pinv + biasinv) @ W_enc + b_enc == x @ (Pinv @ W_enc) + (b_enc + biasinv @ W_enc)\n",
    "#   decoder: (f @ W_dec + b_dec) @ P + bias       == f @ (W_dec @ P) + (b_dec @ P + bias)\n",
    "# NOTE(review): apply_b_dec=False is presumably safe because get_rid_of_decoder_sub() was already\n",
    "# applied to orig_sae -- confirm.\n",
    "transferred_sae = BaseSAE(\n",
    "    Pinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.W_dec.detach().clone() @ P,\n",
    "    orig_sae.b_enc.detach().clone() + biasinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.b_dec.detach().clone() @ P + bias,\n",
    "    orig_sae.activation_fn,\n",
    "    apply_b_dec=False\n",
    ")\n",
    "# The decoder was multiplied by P, so it is normalized again here.\n",
    "transferred_sae.normalize_decoder_vectors()\n",
    "# dtype / device attributes are set by hand from the encoder weights, as for the SAEs loaded earlier.\n",
    "transferred_sae.dtype = transferred_sae.W_enc.dtype\n",
    "transferred_sae.device = transferred_sae.W_enc.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17a052c6-0f51-4e74-b79f-ed3357e29781",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Retrain-on-transfer: encode 160m activations with `transferred_sae`, then train fresh probes\n",
    "# restricted to the masks that were selected on 70m. Requires masks_dict from the 70m\n",
    "# probe-training cell.\n",
    "retrain_transfer_results_dict = {}\n",
    "for dataset_name in config.dataset_names:\n",
    "    print(\"Running dataset for transfer\", dataset_name)\n",
    "    all_train_acts_BLD, all_test_acts_BLD = get_dataset_activations(\n",
    "        dataset_name,\n",
    "        config,\n",
    "        pythia_160m,\n",
    "        config.llm_batch_size,\n",
    "        layer_160m,\n",
    "        f\"blocks.{layer_160m-1}.hook_resid_post\",\n",
    "        device,  # the device the models were loaded on, rather than a hard-coded 'cuda'\n",
    "    )\n",
    "    dataset_results_dict = {}\n",
    "\n",
    "    # LLM baseline: probe trained on every dimension of the meaned model activations.\n",
    "    all_train_acts_BD = activation_collection.create_meaned_model_activations(\n",
    "        all_train_acts_BLD\n",
    "    )\n",
    "    all_test_acts_BD = activation_collection.create_meaned_model_activations(all_test_acts_BLD)\n",
    "    llm_probes, llm_test_accuracies, _ = custom_train_probe(\n",
    "        all_train_acts_BD,\n",
    "        all_test_acts_BD,\n",
    "        select_top_k=None,\n",
    "        verbose=False\n",
    "    )\n",
    "    dataset_results_dict['llm_accuracy'] = average_test_accuracy(\n",
    "        llm_test_accuracies\n",
    "    )\n",
    "\n",
    "    # SAE features from the transferred SAE, moved to the CPU before probe training.\n",
    "    all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_train_acts_BLD, transferred_sae, config.sae_batch_size\n",
    "    )\n",
    "    all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_test_acts_BLD, transferred_sae, config.sae_batch_size\n",
    "    )\n",
    "    for key in all_sae_train_acts_BF.keys():\n",
    "        all_sae_train_acts_BF[key] = all_sae_train_acts_BF[key].cpu()\n",
    "        all_sae_test_acts_BF[key] = all_sae_test_acts_BF[key].cpu()\n",
    "\n",
    "    for k in config.k_values:\n",
    "        # Reuse the feature selection made on 70m; only the probe weights are refit.\n",
    "        sae_custom_masks = masks_dict[dataset_name][f\"sae_top_{k}_masks\"]\n",
    "        sae_top_k_probes, sae_top_k_test_accuracies, _ = custom_train_probe(\n",
    "            all_sae_train_acts_BF,\n",
    "            all_sae_test_acts_BF,\n",
    "            custom_masks = sae_custom_masks,\n",
    "            select_top_k=None,\n",
    "            verbose=False\n",
    "        )\n",
    "        dataset_results_dict[f\"sae_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            sae_top_k_test_accuracies\n",
    "        )\n",
    "\n",
    "    retrain_transfer_results_dict[dataset_name] = dataset_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ddd1b40-568e-4588-b596-94adda8c6e0a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "retrain_transfer_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a460433-885f-4fcb-bf81-25a25fc3c81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so a long run is not lost to a missing folder.\n",
    "retrain_results_path = 'sparse_probing/eleuther/160m_retrain_stitch_results_dict_no_inverse_penalty.json'\n",
    "os.makedirs(os.path.dirname(retrain_results_path), exist_ok=True)\n",
    "with open(retrain_results_path, 'w') as file:\n",
    "    json.dump(retrain_transfer_results_dict, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6156575c-6b5f-414a-be9f-31b72ed963b7",
   "metadata": {},
   "source": [
    "# Finally, 160m scratch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aced988-a0e1-4785-8dec-a2c79cc3a53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 160m from-scratch sweep, mirroring the 70m one but with the native 160m SAE.\n",
    "# The mask/probe stores carry a _160m suffix so they do not overwrite the 70m masks_dict /\n",
    "# probes_dict that the transfer cells read; otherwise re-running a transfer cell after this\n",
    "# one would silently use 160m masks.\n",
    "results_dict_160m = {}\n",
    "masks_dict_160m = {}\n",
    "probes_dict_160m = {}\n",
    "for dataset_name in config.dataset_names:\n",
    "    print(\"Running dataset\", dataset_name)\n",
    "    all_train_acts_BLD, all_test_acts_BLD = get_dataset_activations(\n",
    "        dataset_name,\n",
    "        config,\n",
    "        pythia_160m,\n",
    "        config.llm_batch_size,\n",
    "        layer_160m,\n",
    "        f\"blocks.{layer_160m-1}.hook_resid_post\",\n",
    "        device,  # the device the models were loaded on, rather than a hard-coded 'cuda'\n",
    "    )\n",
    "    dataset_results_dict = {}\n",
    "    masks_results_dict = {}\n",
    "    probes_results_dict = {}\n",
    "\n",
    "    # LLM baseline: probe trained on every dimension of the meaned model activations.\n",
    "    all_train_acts_BD = activation_collection.create_meaned_model_activations(\n",
    "        all_train_acts_BLD\n",
    "    )\n",
    "    all_test_acts_BD = activation_collection.create_meaned_model_activations(all_test_acts_BLD)\n",
    "    llm_probes, llm_test_accuracies, _ = custom_train_probe(\n",
    "        all_train_acts_BD,\n",
    "        all_test_acts_BD,\n",
    "        select_top_k=None,\n",
    "        verbose=False\n",
    "    )\n",
    "    dataset_results_dict['llm_accuracy'] = average_test_accuracy(\n",
    "        llm_test_accuracies\n",
    "    )\n",
    "\n",
    "    # SAE features for both splits, moved to the CPU before probe training.\n",
    "    all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_train_acts_BLD, sae_160m, config.sae_batch_size\n",
    "    )\n",
    "    all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(\n",
    "        all_test_acts_BLD, sae_160m, config.sae_batch_size\n",
    "    )\n",
    "    for key in all_sae_train_acts_BF.keys():\n",
    "        all_sae_train_acts_BF[key] = all_sae_train_acts_BF[key].cpu()\n",
    "        all_sae_test_acts_BF[key] = all_sae_test_acts_BF[key].cpu()\n",
    "\n",
    "    print(\"Training probes.\")\n",
    "    for k in tqdm(config.k_values):\n",
    "        sae_top_k_probes, sae_top_k_test_accuracies, sae_top_k_masks = custom_train_probe(\n",
    "            all_sae_train_acts_BF,\n",
    "            all_sae_test_acts_BF,\n",
    "            select_top_k=k,\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        # Random control: same mask shape/dtype as the top-k mask, with k random positions set\n",
    "        # to 0. Presumably 0/False marks a feature that is kept -- verify against\n",
    "        # apply_topk_mask_reduce_dim. Draw order follows the dict order, so results stay\n",
    "        # reproducible under the NumPy seed.\n",
    "        random_masks = {}\n",
    "        for key, top_k_mask in sae_top_k_masks.items():\n",
    "            random_mask = torch.ones_like(top_k_mask)\n",
    "            indices = np.random.choice(random_mask.shape[0], k, replace=False)\n",
    "            random_mask[indices] = 0\n",
    "            random_masks[key] = random_mask\n",
    "\n",
    "        random_sae_top_k_probes, random_sae_top_k_test_accuracies, _ = custom_train_probe(\n",
    "            all_sae_train_acts_BF,\n",
    "            all_sae_test_acts_BF,\n",
    "            custom_masks = random_masks,\n",
    "            select_top_k=None,\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        dataset_results_dict[f\"sae_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            sae_top_k_test_accuracies\n",
    "        )\n",
    "        dataset_results_dict[f\"random_top_{k}_test_accuracy\"] = average_test_accuracy(\n",
    "            random_sae_top_k_test_accuracies\n",
    "        )\n",
    "        masks_results_dict[f\"sae_top_{k}_masks\"] = sae_top_k_masks\n",
    "        probes_results_dict[f\"sae_top_{k}_probes\"] = sae_top_k_probes\n",
    "\n",
    "    results_dict_160m[dataset_name] = dataset_results_dict\n",
    "    masks_dict_160m[dataset_name] = masks_results_dict\n",
    "    probes_dict_160m[dataset_name] = probes_results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4c962b-8b46-449c-999a-e6d5e196cb65",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dict_160m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8503213a-703f-4c99-8a76-253374a57abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f09b6f-80fb-4b35-b83e-7032be374238",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so a long run is not lost to a missing folder.\n",
    "results_path_160m = 'sparse_probing/eleuther/160m_results_dict.json'\n",
    "os.makedirs(os.path.dirname(results_path_160m), exist_ok=True)\n",
    "with open(results_path_160m, 'w') as file:\n",
    "    json.dump(results_dict_160m, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
