{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2fa1c1a-cc55-4114-a777-bbd132045d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "from sae import Sae\n",
    "import torch\n",
    "import os\n",
    "import json\n",
    "from transformer_lens import HookedTransformer, FactoredMatrix\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy\n",
    "from scipy.stats import spearmanr, pearsonr\n",
    "\n",
    "from stitching.losses import next_token_cross_entropy_loss, kl_div_loss\n",
    "from stitching.stitching_utils import open_experiment, load_activation_store\n",
    "from stitching.sae_utils import BaseSAE, forward_modified, generate_modified, max_csim_transfer_to_orig, convert_eleuther_sae_to_BaseSAE,argmax_csim_for_subset\n",
    "from stitching.generic_experiments import *\n",
    "from stitching.sae_utils import feature_activations, precision_recall_f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712fc154-d7a8-443d-9984-f5ec8f289959",
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "TRUNCATION_LENGTH = 512\n",
    "\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "\n",
    "modelA_name = 'pythia-70m-deduped'\n",
    "modelB_name = 'pythia-160m-deduped'\n",
    "\n",
    "# Pre-tokenized train/test splits. They are tokenized for model A and later fed to model B too\n",
    "# (assumes the two Pythia models share a tokenizer -- TODO confirm).\n",
    "tokenized_dataset = {\n",
    "    dataset_key: torch.load(\n",
    "        f'data/{modelA_name}_tokenized_dataset_200000_{dataset_key}_{TRUNCATION_LENGTH}.pt',\n",
    "        weights_only=True,\n",
    "    )\n",
    "    for dataset_key in ['train', 'test']\n",
    "}\n",
    "\n",
    "# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without one.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "123ce8c3-597a-4ad5-a220-6a090fe0b6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load both models with TransformerLens, sharing the download cache and target device.\n",
    "modelA, modelB = (\n",
    "    HookedTransformer.from_pretrained(model_name, cache_dir=CACHE_DIR, device=device)\n",
    "    for model_name in (modelA_name, modelB_name)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4cadda-edc8-48c0-b6fc-00a0c0c46703",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment hyperparameters ---\n",
    "layer_A, layer_B = 3, 4  # stitching layer in model A / model B\n",
    "run_name_id = 'tough-forest-11'  # training run whose stitch checkpoint is loaded\n",
    "checkpoints_dir = f'checkpoints/stitch_training_{modelA_name}_to_{modelB_name}_bidirectional_mse/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d4d92a-1cef-4513-aa7a-0050ad7d3bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitch for the selected run: P / bias map A -> B, Pinv / biasinv map B -> A.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(modelA.cfg.d_model, modelB.cfg.d_model, checkpoints_dir, run_name_id, biases=True, device=device)\n",
    "\n",
    "# Eleuther SAEs for both models, read at hookpoint `layers.{layer-1}`. The hub ids are derived\n",
    "# from the model names so they stay in sync if modelA_name / modelB_name are changed.\n",
    "sae_A = convert_eleuther_sae_to_BaseSAE(\n",
    "    Sae.load_from_hub(f'EleutherAI/sae-{modelA_name}-32k', hookpoint=f'layers.{layer_A-1}').to(device)\n",
    ")\n",
    "sae_B = convert_eleuther_sae_to_BaseSAE(\n",
    "    Sae.load_from_hub(f'EleutherAI/sae-{modelB_name}-32k', hookpoint=f'layers.{layer_B-1}').to(device)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e836023-a378-4f49-a5ed-26b4163b795c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Converted Eleuther SAEs (BaseSAE) expose the flag directly; other SAE classes (e.g. SAELens)\n",
    "# keep it in `cfg.apply_b_dec_to_input`.\n",
    "apply_b_dec = (\n",
    "    sae_A.apply_b_dec\n",
    "    if 'BaseSAE' in str(type(sae_A))\n",
    "    else sae_A.cfg.apply_b_dec_to_input\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "491a8b7e-ba42-4c39-bb77-76cc0d3eedb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clone_as_base_sae(sae, apply_b_dec):\n",
    "    \"\"\"Return a detached BaseSAE copy of `sae`, post-processed with\n",
    "    normalize_decoder_vectors() and get_rid_of_decoder_sub().\"\"\"\n",
    "    copied = BaseSAE(\n",
    "        sae.W_enc.detach().clone(),\n",
    "        sae.W_dec.detach().clone(),\n",
    "        sae.b_enc.detach().clone(),\n",
    "        sae.b_dec.detach().clone(),\n",
    "        sae.activation_fn,\n",
    "        apply_b_dec=apply_b_dec\n",
    "    )\n",
    "    copied.normalize_decoder_vectors()\n",
    "    copied.get_rid_of_decoder_sub()\n",
    "    return copied\n",
    "\n",
    "\n",
    "def transfer_sae(src, fwd, fwd_bias, inv, inv_bias):\n",
    "    \"\"\"Re-express `src`, an SAE on space S, as an SAE on space T via the stitch.\n",
    "\n",
    "    The stitch maps x_T ~ x_S @ fwd + fwd_bias and x_S ~ x_T @ inv + inv_bias.\n",
    "    Encoder: T-inputs are mapped into S first, so W_enc -> inv @ W_enc and\n",
    "    b_enc -> b_enc + inv_bias @ W_enc. Decoder: S-reconstructions are mapped into T,\n",
    "    so W_dec -> W_dec @ fwd and b_dec -> b_dec @ fwd + fwd_bias.\n",
    "    The result is built with apply_b_dec=False and then has normalize_decoder_vectors() applied.\n",
    "    \"\"\"\n",
    "    W_enc = src.W_enc.detach().clone()\n",
    "    transferred = BaseSAE(\n",
    "        inv @ W_enc,\n",
    "        src.W_dec.detach().clone() @ fwd,\n",
    "        src.b_enc.detach().clone() + inv_bias @ W_enc,\n",
    "        src.b_dec.detach().clone() @ fwd + fwd_bias,\n",
    "        src.activation_fn,\n",
    "        # NOTE(review): presumably correct because `src` already went through\n",
    "        # get_rid_of_decoder_sub() in clone_as_base_sae -- confirm.\n",
    "        apply_b_dec=False\n",
    "    )\n",
    "    transferred.normalize_decoder_vectors()\n",
    "    return transferred\n",
    "\n",
    "\n",
    "# SAE A in A-space, and its transfer into B-space (A -> B: x @ P + bias).\n",
    "orig_sae_A = clone_as_base_sae(sae_A, apply_b_dec)\n",
    "transferred_sae_B = transfer_sae(orig_sae_A, P, bias, Pinv, biasinv)\n",
    "\n",
    "# Read B's decoder-bias flag from sae_B itself rather than reusing the one derived from sae_A.\n",
    "apply_b_dec_B = sae_B.apply_b_dec if 'BaseSAE' in str(type(sae_B)) else sae_B.cfg.apply_b_dec_to_input\n",
    "\n",
    "# SAE B in B-space, and its transfer into A-space (B -> A: x @ Pinv + biasinv).\n",
    "orig_sae_B = clone_as_base_sae(sae_B, apply_b_dec_B)\n",
    "transferred_sae_A = transfer_sae(orig_sae_B, Pinv, biasinv, P, bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbcb7d0e-5c1c-4aa1-a7ab-b718c20d409b",
   "metadata": {},
   "source": [
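    "Create a random-weight baseline SAE with the same shape as the transferred SAE (model B's `d_model` x SAE A's dictionary size)"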
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1866dfa-9d6b-4618-b64c-e2933b7153de",
   "metadata": {},
   "outputs": [],
   "source": [
    "RANDOM_SAE_SEED = 0  # fixed seed so the random baseline is reproducible across runs\n",
    "\n",
    "\n",
    "def _kaiming_uniform(*shape):\n",
    "    \"\"\"New CPU tensor of the given shape, filled in place by torch.nn.init.kaiming_uniform_.\"\"\"\n",
    "    return torch.nn.init.kaiming_uniform_(torch.zeros(shape, requires_grad=False))\n",
    "\n",
    "\n",
    "# Random-weight SAE with the same shapes as transferred_sae_B (model B's d_model x SAE A's d_sae).\n",
    "# fork_rng restores the global CPU RNG state afterwards, so seeding here does not affect other cells.\n",
    "with torch.random.fork_rng(devices=[]):\n",
    "    torch.manual_seed(RANDOM_SAE_SEED)\n",
    "    random_sae = BaseSAE(\n",
    "        _kaiming_uniform(modelB.cfg.d_model, orig_sae_A.d_sae),\n",
    "        _kaiming_uniform(orig_sae_A.d_sae, modelB.cfg.d_model),\n",
    "        _kaiming_uniform(1, orig_sae_A.d_sae).flatten(),\n",
    "        _kaiming_uniform(1, modelB.cfg.d_model).flatten(),\n",
    "        sae_A.activation_fn\n",
    "    )\n",
    "random_sae = random_sae.to(device)\n",
    "random_sae.normalize_decoder_vectors()\n",
    "random_sae.get_rid_of_decoder_sub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5080bee3-ddc6-4267-a275-601ad118ac15",
   "metadata": {},
   "source": [
    "# Bottleneck Experiments / Stitching Evals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f30a0ce-ce60-42d7-a045-5f27e7bd1aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation data: the first 1000 test sequences, 10 per batch, in a fixed order.\n",
    "dataloader = torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False)\n",
    "\n",
    "# Model key -> (stitching layer, model); interventions refer to the models by these keys.\n",
    "models = dict(\n",
    "    A=(layer_A, modelA),\n",
    "    B=(layer_B, modelB),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e02e02-f1c3-41bf-be09-d84d57fbeb79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference results: identity, plain-SAE and zero interventions on each model.\n",
    "baseline_interventions = [\n",
    "    get_identity_intervention('orig_A', 'A'),\n",
    "    get_identity_intervention('orig_B', 'B'),\n",
    "    get_sae_intervention('sae_A', orig_sae_A, 'A'),\n",
    "    get_sae_intervention('sae_B', orig_sae_B, 'B'),\n",
    "    get_zero_intervention('zero_A', 'A'),\n",
    "    get_zero_intervention('zero_B', 'B'),\n",
    "]\n",
    "baseline_results = run_multi_model_interventions(dataloader, models, baseline_interventions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94603e13-e19a-48ac-a534-37bfe03bdae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the baseline results (one entry per intervention name).\n",
    "baseline_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f82ceba8-991b-42b5-8f48-781370e46e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to plain floats for JSON. float() accepts both 0-d tensors and floats, so unlike\n",
    "# `.item()` this cell can be re-run without failing; missing (None) results are written as null.\n",
    "baseline_results = {k: (float(v) if v is not None else None) for k, v in baseline_results.items()}\n",
    "os.makedirs('results', exist_ok=True)\n",
    "with open(f'results/{modelA_name}_{modelB_name}_baseline_results.json', 'w') as file:\n",
    "    json.dump(baseline_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54a1d1bf-405b-4924-a4ff-8e40e85d0738",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## KL Divergence\n",
    "\n",
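    "Measure the KL divergence of the stitched models (A->B and B->A) from each original model (A and B), and compare it to the baseline divergence between the unmodified models (A vs. B, in both directions)."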
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "850f7306-d656-4388-8e8f-16ae6cea9ec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-token KL divergences, keyed '<logits>_to_<reference>' and passed to kl_div_loss in that order:\n",
    "#   AB = model A up to layer_A, stitched with x @ P + bias, then model B from layer_B\n",
    "#   BA = model B up to layer_B, stitched with x @ Pinv + biasinv, then model A from layer_A\n",
    "#   A / B = the unmodified models\n",
    "# The stitch (including its biases) matches the A_to_B / B_to_A interventions used elsewhere.\n",
    "kl_pairs = ['AB_to_A', 'AB_to_B', 'BA_to_A', 'BA_to_B', 'B_to_A', 'A_to_B']\n",
    "per_token_kl = {pair: [] for pair in kl_pairs}\n",
    "with torch.inference_mode():\n",
    "    for sample in tqdm(dataloader):\n",
    "        sample = sample.to(device)\n",
    "        logits = {\n",
    "            'AB': modelB.forward(modelA.forward(sample, stop_at_layer=layer_A) @ P + bias, start_at_layer=layer_B),\n",
    "            'BA': modelA.forward(modelB.forward(sample, stop_at_layer=layer_B) @ Pinv + biasinv, start_at_layer=layer_A),\n",
    "            'A': modelA.forward(sample),\n",
    "            'B': modelB.forward(sample),\n",
    "        }\n",
    "        # NOTE(review): token id 0 is treated as padding and excluded -- confirm against the tokenizer.\n",
    "        mask = (sample != 0)\n",
    "        for pair in kl_pairs:\n",
    "            src, ref = pair.split('_to_')\n",
    "            kl = kl_div_loss(logits[src][mask], logits[ref][mask], reduction='none').sum(dim=-1)\n",
    "            per_token_kl[pair].append(kl.cpu())\n",
    "\n",
    "kl_div_results = {pair: torch.cat(values).mean().item() for pair, values in per_token_kl.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99e57c74-0193-4caf-a1be-1d54cc66ab4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean per-token KL divergence for each (logits, reference) pair.\n",
    "kl_div_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b95074f4-a714-4cce-9ab8-adb95f6d9de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The keys are categorical, so a bar chart (not a line plot) is the appropriate view.\n",
    "ax = pd.Series(kl_div_results).plot.bar(ylabel='mean per-token KL divergence', rot=45)\n",
    "ax.set_title('KL divergence between stitched and original models');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2acbcacd-7010-4766-a98f-cc6dff5e056f",
   "metadata": {},
   "source": [
    "## Run a single stitching experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7db2805d-d13e-4920-8080-66c1f6845309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitching interventions for the selected run. The lambdas look up P / Pinv / bias / biasinv\n",
    "# as globals at call time, so they always use whichever stitch is currently loaded.\n",
    "stitch_interventions = [\n",
    "    # direct stitches between the two models\n",
    "    Intervention('A_to_B', lambda x: x @ P + bias, 'A', 'B'),\n",
    "    Intervention('B_to_A', lambda x: x @ Pinv + biasinv, 'B', 'A'),\n",
    "    # round trips into the other model's space and back\n",
    "    Intervention('inverse_B_to_B', lambda x: (x @ Pinv + biasinv) @ P + bias, 'B', 'B'),\n",
    "    Intervention('inverse_A_to_A', lambda x: (x @ P + bias) @ Pinv + biasinv, 'A', 'A'),\n",
    "    # transferred SAEs evaluated inside the target model\n",
    "    get_sae_intervention('stitch_sae_A_to_B', transferred_sae_B, 'B'),\n",
    "    get_sae_intervention('stitch_sae_B_to_A', transferred_sae_A, 'A'),\n",
    "]\n",
    "experiment_results = run_multi_model_interventions(dataloader, models, stitch_interventions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e08ac1-08d1-4728-bf6e-89271d94f08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() accepts both 0-d tensors and floats, so this cell can be re-run safely.\n",
    "experiment_results = {k: (float(v) if v is not None else None) for k, v in experiment_results.items()}\n",
    "experiment_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6db664-0a02-4c71-bb99-4069437f4b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('results', exist_ok=True)  # the results directory may not exist yet\n",
    "with open(f'results/{run_name_id}_bottleneck_results.json', 'w') as file:\n",
    "    json.dump(experiment_results, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af721715-9bc2-4bff-bd9a-767ad11d9a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CE loss of each stitching intervention; horizontal lines mark the baselines of model A\n",
    "# (orange) and model B (red): solid = unmodified model, dashed = SAE reconstruction.\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.grid(False)\n",
    "sns.barplot(\n",
    "    pd.Series(experiment_results),\n",
    "    ax=ax,\n",
    "    label=f'{modelA_name}.{layer_A}.pre -> {modelB_name}.{layer_B}.pre',\n",
    ")\n",
    "for key, color, linestyle in [\n",
    "    ('orig_A', 'orange', 'solid'),\n",
    "    ('sae_A', 'orange', 'dashed'),\n",
    "    ('orig_B', 'red', 'solid'),\n",
    "    ('sae_B', 'red', 'dashed'),\n",
    "]:\n",
    "    ax.axhline(baseline_results[key], label=key, linestyle=linestyle, color=color)\n",
    "ax.set(xlabel='method', ylabel='(Reconstruction) CE loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d14562c0-3bb7-413d-b9b4-7a243cec3a36",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Run multiple stitching experiments (sweep over $\\alpha$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d02a20-1856-4a49-a791-a8a0146cd515",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitch runs trained with different alpha values: (run id, plot label).\n",
    "experiments = [\n",
    "    Experiment(run_id, label, {})\n",
    "    for run_id, label in [\n",
    "        ('gentle-sky-18', '$\\\\alpha = 0$'),\n",
    "        ('likely-grass-20', '$\\\\alpha = 0.1$'),\n",
    "        ('tough-forest-11', '$\\\\alpha = 1.0$'),\n",
    "        ('breezy-surf-19', '$\\\\alpha = 10.0$'),\n",
    "    ]\n",
    "]\n",
    "\n",
    "# Stitch-only interventions. The lambdas look up P / Pinv / bias / biasinv as globals when called,\n",
    "# which the per-run loop in the next cell relies on when it rebinds those names.\n",
    "bottlenecks = [\n",
    "    Intervention('A->B', lambda x: x @ P + bias, 'A', 'B'),\n",
    "    Intervention('B->A', lambda x: x @ Pinv + biasinv, 'B', 'A'),\n",
    "    Intervention('B->A->B', lambda x: (x @ Pinv + biasinv) @ P + bias, 'B', 'B'),\n",
    "    Intervention('A->B->A', lambda x: (x @ P + bias) @ Pinv + biasinv, 'A', 'A'),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8847abd-4853-48a0-83e3-2c942642bb24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate each stitch run. The `bottlenecks` lambdas read P / Pinv / bias / biasinv as globals,\n",
    "# so these names are rebound per run -- and restored afterwards, so later cells keep using the\n",
    "# stitch of `run_name_id` instead of whichever run happened to be loaded last.\n",
    "_selected_stitch = (P, Pinv, bias, biasinv)\n",
    "try:\n",
    "    for experiment in experiments:\n",
    "        P, Pinv, _beta, bias, biasinv = open_experiment(modelA.cfg.d_model, modelB.cfg.d_model, checkpoints_dir, experiment.run_id, biases=True, device=device)\n",
    "        current_sae = BaseSAE(\n",
    "            Pinv @ orig_sae_A.W_enc.detach().clone(),\n",
    "            orig_sae_A.W_dec.detach().clone() @ P,\n",
    "            orig_sae_A.b_enc.detach().clone() + biasinv @ orig_sae_A.W_enc.detach().clone(),\n",
    "            orig_sae_A.b_dec.detach().clone() @ P + bias,\n",
    "            orig_sae_A.activation_fn,\n",
    "            apply_b_dec=False\n",
    "        )\n",
    "        # Normalize like transferred_sae_B, so each run is built the same way as stitch_sae_A_to_B\n",
    "        # in the single-run evaluation.\n",
    "        current_sae.normalize_decoder_vectors()\n",
    "        experiment.results = run_multi_model_interventions(\n",
    "            dataloader,\n",
    "            models,\n",
    "            bottlenecks + [get_sae_intervention('stitch_sae_A_to_B', current_sae, 'B')],\n",
    "        )\n",
    "finally:\n",
    "    P, Pinv, bias, biasinv = _selected_stitch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bdbe403-195d-4098-b028-cb62278a5f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per experiment (named by experiment.method), one row per intervention. Rows with a\n",
    "# missing result in any experiment are dropped, then the table is melted to long form for seaborn.\n",
    "collected_dataframe = pd.DataFrame({})\n",
    "for experiment in experiments:\n",
    "    column = {name: (None if value is None else value.item()) for name, value in experiment.results.items()}\n",
    "    collected_dataframe[experiment.method] = pd.Series(column)\n",
    "collected_dataframe = collected_dataframe.dropna().reset_index(names='eval')\n",
    "melted_df = collected_dataframe.melt(id_vars='eval', var_name='method', value_name='CE loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33548fdb-7f1a-4502-8bc0-8f1ca9c781c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wide table: one row per intervention ('eval'), one column per stitch run.\n",
    "collected_dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feda7aba-afa9-45cf-8542-b4de0a3df491",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CE loss per intervention (x axis) for each stitch run (hue); horizontal lines mark the baselines of\n",
    "# model A (orange) and model B (red): solid = unmodified model, dashed = SAE reconstruction.\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.grid(True)\n",
    "sns.barplot(x='eval', y='CE loss', hue='method', palette='viridis', data=melted_df, ax=ax)\n",
    "for key, color, linestyle in [\n",
    "    ('orig_A', 'orange', 'solid'),\n",
    "    ('orig_B', 'red', 'solid'),\n",
    "    ('sae_A', 'orange', 'dashed'),\n",
    "    ('sae_B', 'red', 'dashed'),\n",
    "]:\n",
    "    ax.axhline(baseline_results[key], label=key, linestyle=linestyle, color=color)\n",
    "# The x axis holds the interventions (the 'eval' column); 'method' is the hue, not the x axis.\n",
    "ax.set(\n",
    "    xlabel='intervention',\n",
    "    ylabel='(Reconstruction) CE loss',\n",
    "    title=f'{modelA_name}.{layer_A}.pre -> {modelB_name}.{layer_B}.pre',\n",
    ")\n",
    "ax.legend()\n",
    "os.makedirs('results/figures', exist_ok=True)\n",
    "fig.savefig('results/figures/inversion_ablation.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb7cea81-68b1-4239-845e-02d37acc62a2",
   "metadata": {},
   "source": [
    "# RSA-style Analysis (REMEMBER TO NORMALIZE VECTORS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bfd88dd-9b8b-4a07-a116-a51ae8e49f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of SAE features processed per batch in the feature-level analyses that follow.\n",
    "feature_batch_size = 2_000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef555b77-5468-4643-8d4a-7c2dfb0eb49c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Attribution Correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1bf01ea-9a66-4924-9013-85455ce83f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the transferred-SAE attribution-correlation metrics of this run.\n",
    "activations_dir = 'activations_store/fr_fixed_attribution_correlation/eleuther/' + run_name_id + '/'\n",
    "os.makedirs(activations_dir, exist_ok=True)\n",
    "print(activations_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6042175-db57-4565-8cba-2717fe6d81a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribution correlation between SAE A (inside model A) and the transferred SAE (inside model B).\n",
    "# attribution = feature activation * that feature's direct logit weight (decoder row -> ln_final ->\n",
    "# unembed) for the actual next token; activations and attributions are Pearson-correlated per feature.\n",
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "        print('batch', i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(orig_sae_A.d_sae, i + feature_batch_size))\n",
    "\n",
    "        # (feature_batch_size, vocab_size)\n",
    "        logit_weights_A = modelA.unembed(modelA.ln_final(orig_sae_A.W_dec[feature_subset])).cpu().numpy()\n",
    "        logit_weights_B = modelB.unembed(modelB.ln_final(transferred_sae_B.W_dec[feature_subset])).cpu().numpy()\n",
    "\n",
    "        # (total_tokens, feature_batch_size) activations and (total_tokens,) next tokens on the same test data\n",
    "        orig_sae_activations, next_tokens, _ = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelA, layer_A, orig_sae_A, feature_subset,\n",
    "            preacts=False, return_next_tokens=True,\n",
    "        )\n",
    "        stitched_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelB, layer_B, transferred_sae_B, feature_subset,\n",
    "            preacts=False, return_next_tokens=False,\n",
    "        )\n",
    "        assert orig_sae_activations.shape[0] == next_tokens.shape[0]\n",
    "        assert stitched_sae_activations.shape[0] == next_tokens.shape[0]\n",
    "\n",
    "        scores_A = orig_sae_activations * logit_weights_A[:, next_tokens].T\n",
    "        scores_B = stitched_sae_activations * logit_weights_B[:, next_tokens].T\n",
    "        feature_columns = range(scores_A.shape[1])\n",
    "        attribution_corr = [pearsonr(scores_A[:, j], scores_B[:, j], axis=0).statistic for j in feature_columns]\n",
    "        activation_corr = [\n",
    "            pearsonr(orig_sae_activations[:, j], stitched_sae_activations[:, j], axis=0).statistic\n",
    "            for j in feature_columns\n",
    "        ]\n",
    "\n",
    "        np.savez(\n",
    "            os.path.join(activations_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'),\n",
    "            activation_correlation=np.array(activation_corr),\n",
    "            attribution_correlation=np.array(attribution_corr),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f59d4783-7bd4-48cb-b28c-ad092f3abb1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the random-SAE baseline metrics (no placeholders, so a plain string).\n",
    "activations_dir = 'activations_store/fr_fixed_attribution_correlation/random/'\n",
    "os.makedirs(activations_dir, exist_ok=True)\n",
    "print(activations_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ccb191-22f1-475a-a7d3-41e40d9fddb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the same attribution / activation correlations, but against the random-weight SAE in model B.\n",
    "# attribution = feature activation * that feature's direct logit weight (decoder row -> ln_final ->\n",
    "# unembed) for the actual next token.\n",
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "        print('batch', i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(orig_sae_A.d_sae, i + feature_batch_size))\n",
    "\n",
    "        # (feature_batch_size, vocab_size)\n",
    "        logit_weights_A = modelA.unembed(modelA.ln_final(orig_sae_A.W_dec[feature_subset])).cpu().numpy()\n",
    "        logit_weights_B = modelB.unembed(modelB.ln_final(random_sae.W_dec[feature_subset])).cpu().numpy()\n",
    "\n",
    "        # (total_tokens, feature_batch_size) activations and (total_tokens,) next tokens on the same test data\n",
    "        orig_sae_activations, next_tokens, orig_pred_next_tokens = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelA, layer_A, orig_sae_A, feature_subset,\n",
    "            preacts=False, return_next_tokens=True,\n",
    "        )\n",
    "        stitched_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelB, layer_B, random_sae, feature_subset,\n",
    "            preacts=False, return_next_tokens=False,\n",
    "        )\n",
    "        assert orig_sae_activations.shape[0] == next_tokens.shape[0]\n",
    "        assert stitched_sae_activations.shape[0] == next_tokens.shape[0]\n",
    "\n",
    "        scores_A = orig_sae_activations * logit_weights_A[:, next_tokens].T\n",
    "        scores_B = stitched_sae_activations * logit_weights_B[:, next_tokens].T\n",
    "        feature_columns = range(scores_A.shape[1])\n",
    "        attribution_corr = [pearsonr(scores_A[:, j], scores_B[:, j], axis=0).statistic for j in feature_columns]\n",
    "        activation_corr = [\n",
    "            pearsonr(orig_sae_activations[:, j], stitched_sae_activations[:, j], axis=0).statistic\n",
    "            for j in feature_columns\n",
    "        ]\n",
    "\n",
    "        np.savez(\n",
    "            os.path.join(activations_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'),\n",
    "            activation_correlation=np.array(activation_corr),\n",
    "            attribution_correlation=np.array(attribution_corr),\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2da0fd1-2f22-4274-9901-e8759db63eb7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Generate Spearman Correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411739ab-911b-437d-8142-3aec697dd7e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the per-feature Spearman / precision-recall metrics of this run.\n",
    "activations_dir = 'activations_store/eleuther/' + run_name_id + '/'\n",
    "os.makedirs(activations_dir, exist_ok=True)\n",
    "print(activations_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c555ffb-1ba4-4f4a-a36c-b88e590bee81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature agreement between SAE A (inside model A) and the transferred SAE (inside model B),\n",
    "# computed on pre-activations over the same test tokens and saved as one .npz file per feature batch.\n",
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "        print('batch', i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(orig_sae_A.d_sae, i + feature_batch_size))\n",
    "        orig_sae_activations, stitched_sae_activations = (\n",
    "            feature_activations(\n",
    "                torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "                model, layer, sae, feature_subset,\n",
    "                preacts=True,\n",
    "            )\n",
    "            for model, layer, sae in [(modelA, layer_A, orig_sae_A), (modelB, layer_B, transferred_sae_B)]\n",
    "        )\n",
    "\n",
    "        spearman_results = [\n",
    "            spearmanr(orig_sae_activations[:, j], stitched_sae_activations[:, j], axis=0)\n",
    "            for j in range(orig_sae_activations.shape[1])\n",
    "        ]\n",
    "        p, r, f1 = precision_recall_f1(orig_sae_activations, stitched_sae_activations, axis=0)\n",
    "        # NOTE: the '...B->A' keys hold values of the transferred SAE run inside model B; the key\n",
    "        # names are left unchanged because they are the on-disk format of these files.\n",
    "        np.savez(\n",
    "            os.path.join(activations_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'),\n",
    "            **{\n",
    "                'activation_count_A': (orig_sae_activations > 0).sum(axis=0),\n",
    "                'activation_count_B->A': (stitched_sae_activations > 0).sum(axis=0),\n",
    "                'max_activation_A': np.maximum(0, orig_sae_activations.max(axis=0)),\n",
    "                'max_activation_B->A': np.maximum(0, stitched_sae_activations.max(axis=0)),\n",
    "                'spearman_stat': np.array([res.statistic for res in spearman_results]),\n",
    "                'spearman_pvalue': np.array([res.pvalue for res in spearman_results]),\n",
    "                'precision': p,\n",
    "                'recall': r,\n",
    "                'f1': f1,\n",
    "            },\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35405874-3b05-45a5-939c-3bd2f5fd320c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell repeats the previous cell's computation with identical arguments and\n",
    "# writes the same files under activations_dir; one of the two cells can likely be removed.\n",
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "        print('batch', i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(orig_sae_A.d_sae, i + feature_batch_size))\n",
    "        orig_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelA, layer_A, orig_sae_A, feature_subset, preacts=True,\n",
    "        )\n",
    "        stitched_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "            modelB, layer_B, transferred_sae_B, feature_subset, preacts=True,\n",
    "        )\n",
    "\n",
    "        stats, pvalues = [], []\n",
    "        for j in range(orig_sae_activations.shape[1]):\n",
    "            result = spearmanr(orig_sae_activations[:, j], stitched_sae_activations[:, j], axis=0)\n",
    "            stats.append(result.statistic)\n",
    "            pvalues.append(result.pvalue)\n",
    "        p, r, f1 = precision_recall_f1(orig_sae_activations, stitched_sae_activations, axis=0)\n",
    "        results_dict = {\n",
    "            'activation_count_A': (orig_sae_activations > 0).sum(axis=0),\n",
    "            'activation_count_B->A': (stitched_sae_activations > 0).sum(axis=0),\n",
    "            'max_activation_A': np.maximum(0, orig_sae_activations.max(axis=0)),\n",
    "            'max_activation_B->A': np.maximum(0, stitched_sae_activations.max(axis=0)),\n",
    "            'spearman_stat': np.array(stats),\n",
    "            'spearman_pvalue': np.array(pvalues),\n",
    "            'precision': p,\n",
    "            'recall': r,\n",
    "            'f1': f1,\n",
    "        }\n",
    "        np.savez(os.path.join(activations_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'), **results_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339ec9b4-012f-41b1-89da-6355c9e499e3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Generate Spearman Correlations for Baselines"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c36b9680-75bf-4a86-84ad-2eaa50519bdc",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Random P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7eb5e6c-f00e-4f5a-bd1d-d2d863430fb8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae.d_sae, feature_batch_size):\n",
    "        print(\"batch\", i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(sae.cfg.d_sae, i+feature_batch_size))\n",
    "        orig_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(\n",
    "                tokenized_dataset['test'][:1000],\n",
    "                batch_size=10,\n",
    "                shuffle=False\n",
    "            ),\n",
    "            pythia_70m,\n",
    "            layer_70m,\n",
    "            orig_sae,\n",
    "            feature_subset,\n",
    "            preacts=True\n",
    "        )\n",
    "        stitched_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(\n",
    "                tokenized_dataset['test'][:1000],\n",
    "                batch_size=10,\n",
    "                shuffle=False\n",
    "            ),\n",
    "            pythia_160m,\n",
    "            layer_160m,\n",
    "            random_P_sae,\n",
    "            feature_subset,\n",
    "            preacts=True\n",
    "        )\n",
    "    \n",
    "        spearmans = {'statistic': [], 'pvalue': []}\n",
    "        for j in range(orig_sae_activations.shape[1]):\n",
    "            spearman_res = spearmanr(orig_sae_activations[:,j], stitched_sae_activations[:,j], axis=0)\n",
    "            spearmans['statistic'].append(spearman_res.statistic)\n",
    "            spearmans['pvalue'].append(spearman_res.pvalue)\n",
    "        p, r, f1 = precision_recall_f1(orig_sae_activations, stitched_sae_activations, axis=0)\n",
    "        max_activation_A = np.maximum(0, orig_sae_activations.max(axis=0))\n",
    "        max_activation_B_to_A = np.maximum(0, stitched_sae_activations.max(axis=0))\n",
    "        results_dict = {\n",
    "            'activation_count_A': (orig_sae_activations > 0).sum(axis=0),\n",
    "            'activation_count_B->A': (stitched_sae_activations > 0).sum(axis=0),\n",
    "            'max_activation_A': max_activation_A,\n",
    "            'max_activation_B->A': max_activation_B_to_A,\n",
    "            'spearman_stat': np.array(spearmans['statistic']),\n",
    "            'spearman_pvalue': np.array(spearmans['pvalue']),\n",
    "            'precision': p,\n",
    "            'recall': r,\n",
    "            'f1': f1\n",
    "        }\n",
    "        np.savez(f'activations_store/{layer_70m}_to_{layer_160m}_random_P_metrics_size_{feature_batch_size}_batch_{i}.npz', **results_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b973af30-4057-4e46-8fc1-e38d9ff34f81",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Random SAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f8a589-62c6-4e49-a711-92c3136bcffa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with torch.inference_mode():\n",
    "    for i in np.arange(0, orig_sae.d_sae, feature_batch_size):\n",
    "        print(\"batch\", i // feature_batch_size)\n",
    "        feature_subset = np.arange(i, min(sae.cfg.d_sae, i+feature_batch_size))\n",
    "        orig_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(\n",
    "                tokenized_dataset['test'][:1000],\n",
    "                batch_size=10,\n",
    "                shuffle=False\n",
    "            ),\n",
    "            pythia_70m,\n",
    "            layer_70m,\n",
    "            orig_sae,\n",
    "            feature_subset,\n",
    "            preacts=True\n",
    "        )\n",
    "        stitched_sae_activations = feature_activations(\n",
    "            torch.utils.data.DataLoader(\n",
    "                tokenized_dataset['test'][:1000],\n",
    "                batch_size=10,\n",
    "                shuffle=False\n",
    "            ),\n",
    "            pythia_160m,\n",
    "            layer_160m,\n",
    "            random_sae,\n",
    "            feature_subset,\n",
    "            preacts=True\n",
    "        )\n",
    "    \n",
    "        spearmans = {'statistic': [], 'pvalue': []}\n",
    "        for j in range(orig_sae_activations.shape[1]):\n",
    "            spearman_res = spearmanr(orig_sae_activations[:,j], stitched_sae_activations[:,j], axis=0)\n",
    "            spearmans['statistic'].append(spearman_res.statistic)\n",
    "            spearmans['pvalue'].append(spearman_res.pvalue)\n",
    "        p, r, f1 = precision_recall_f1(orig_sae_activations, stitched_sae_activations, axis=0)\n",
    "        max_activation_A = np.maximum(0, orig_sae_activations.max(axis=0))\n",
    "        max_activation_B_to_A = np.maximum(0, stitched_sae_activations.max(axis=0))\n",
    "        results_dict = {\n",
    "            'activation_count_A': (orig_sae_activations > 0).sum(axis=0),\n",
    "            'activation_count_B->A': (stitched_sae_activations > 0).sum(axis=0),\n",
    "            'max_activation_A': max_activation_A,\n",
    "            'max_activation_B->A': max_activation_B_to_A,\n",
    "            'spearman_stat': np.array(spearmans['statistic']),\n",
    "            'spearman_pvalue': np.array(spearmans['pvalue']),\n",
    "            'precision': p,\n",
    "            'recall': r,\n",
    "            'f1': f1\n",
    "        }\n",
    "        np.savez(f'activations_store/{layer_70m}_to_{layer_160m}_random_sae_metrics_size_{feature_batch_size}_batch_{i}.npz', **results_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47880d5a-2194-4c4f-8021-bef6d9c2626b",
   "metadata": {},
   "source": [
    "## Plot Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926f9165-41be-436b-b6a2-31d58afd4f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "stitched = load_activation_store('metrics', 2000, orig_sae_A.d_sae, subdir=f\"fr_fixed_attribution_correlation/eleuther/{run_name_id}/\")#f\"eleuther/{run_name_id}/\") #\n",
    "random_sae_baseline = load_activation_store('metrics', 2000, orig_sae_A.d_sae, subdir=f\"fr_fixed_attribution_correlation/random/\")#f\"eleuther/{run_name_id}/\") #\n",
    "\n",
    "#simple_stitched = load_activation_store('metrics', 2000, orig_sae_A.d_sae, subdir=f\"eleuther/tough-forest-11/\") #f\"winter-snowball-2/\"\n",
    "#trained_stitched = load_activation_store('trained_metrics', feature_batch_size, orig_sae.d_sae, subdir=f\"{run_name_id}/{corresponding_run_id}/\")\n",
    "#random_P_baseline = load_activation_store(f'{layer_70m}_to_{layer_160m}_random_P_metrics', 2000, sae.cfg.d_sae)\n",
    "#random_sae_baseline = load_activation_store(f'{layer_70m}_to_{layer_160m}_random_sae_metrics', 2000, sae.cfg.d_sae)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "643552c4-03e5-4554-b993-4a2e9385199a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sns.histplot(random_sae_baseline['attribution_correlation'], binrange=(-1, 1), binwidth=0.02, label='random SAE', log=True)\n",
    "sns.histplot(stitched['attribution_correlation'], binrange=(-1, 1), binwidth=0.02, label='stitched', log=True)\n",
    "\n",
    "#sns.histplot(random_P_baseline['spearman_stat'], binrange=(-1, 1), binwidth=0.01, label='random orthogonal P')\n",
    "#sns.histplot(simple_stitched['spearman_stat'], binrange=(-1, 1), binwidth=0.01, label='seed stitched')\n",
    "plt.xlabel('correlation')\n",
    "plt.title('pythia')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bebea28b-748d-45be-b4ba-d0a79d2b67b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#structural_features = torch.load('structural_features.pt', weights_only=True)\n",
    "#low_density = torch.load('low_density_features.pt', weights_only=True)\n",
    "#mid_density = torch.load('mid_density_features.pt', weights_only=True)\n",
    "#high_density = torch.load('high_density_features.pt', weights_only=True)\n",
    "#not_dead = torch.load('not_dead_features.pt', weights_only=True)\n",
    "#pth = \"pythia-70m_marks_structural.pt\"\n",
    "#pth = \"pythia-70m-eleuther_structural.pt\"\n",
    "pth = \"pythia-70m-deduped.3_structural.pt\"\n",
    "res_dict = torch.load(pth, weights_only=True)\n",
    "structural_features, densities = res_dict['structural'], res_dict['densities']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe3d7962-7d8f-47d6-91ef-3c2c1dcb1e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.nanmean(stitched['attribution_correlation'][densities > 1e-6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5361a44-e741-455f-9aef-5039705d318e",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.nanmean(stitched['attribution_correlation'][structural_features & (densities > 1e-6)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376d0cd1-9716-4a04-b5c1-36e691df7074",
   "metadata": {},
   "outputs": [],
   "source": [
    "bins = 50\n",
    "key = 'attribution_correlation'\n",
    "binrange = (-1, 1)\n",
    "not_dead = densities > 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c752ca-3317-4846-858a-2a8ade2ddc70",
   "metadata": {},
   "outputs": [],
   "source": [
    "hist_orig, bin_edges = np.histogram(stitched[key][not_dead], bins=bins, range=binrange)\n",
    "hist_good, bin_edges = np.histogram(stitched[key][structural_features], bins=bins, range=binrange, density=True)\n",
    "#hist_other, bin_edges = np.histogram(stitched['activation_correlation'][torch.logical_and(structural_features, torch.logical_or(high_density, mid_density))], bins=bins, range=(0,1))\n",
    "hist_semantic, bin_edges = np.histogram(stitched[key][~structural_features], bins=bins, range=binrange, density=True)\n",
    "#hist_trash, bin_edges = np.histogram(stitched['activation_correlation'][torch.logical_and(~structural_features, torch.logical_or(high_density, mid_density))], bins=bins, range=(0,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f398759e-0c16-466d-bb99-f00892524fba",
   "metadata": {},
   "outputs": [],
   "source": [
    "#expected_ratio = len(stitched['spearman_stat'][structural_features]) / len(stitched['spearman_stat'][not_dead])\n",
    "#expected_ratio = np.where(bin_counts1 == 0, bin_counts1, expected_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80ca3a8d-fca3-4cb3-817e-0f3434163a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "plt.grid(True)\n",
    "ax1_color = 'tab:gray'\n",
    "ax1_alpha = 0.5\n",
    "ax2_color = 'tab:blue'\n",
    "ax2_alpha = 1.0\n",
    "ax1.hist(stitched[key][not_dead], bins=bins, range=binrange, alpha=ax1_alpha, color=ax1_color, log=True)\n",
    "ax2 = ax1.twinx()\n",
    "ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_good / hist_orig, 0), color='tab:blue', alpha=ax2_alpha,label='structural')\n",
    "#ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_other / hist_orig, 0), color='tab:orange', alpha=ax2_alpha,label='structural_hi_density')\n",
    "ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_semantic / hist_orig, 0), color='tab:green', alpha=ax2_alpha,label='semantic')\n",
    "#ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_trash / hist_orig, 0), color='tab:red', alpha=ax2_alpha,label='semantic_hi_density')\n",
    "\n",
    "# plt.plot([patch.get_x() for patch in hist1.patches], expected_ratio, linestyle='dashed', label='expected_density_ratio')\n",
    "plt.legend()\n",
    "ax1.set_ylabel('alive log count', color=ax1_color)\n",
    "ax2.set_ylabel('frac of bucket',color=ax2_color, alpha=ax2_alpha)\n",
    "ax2.set_ylim(0)\n",
    "ax2.tick_params(axis='y', labelcolor=ax2_color)\n",
    "ax1.tick_params(axis='y', labelcolor=ax1_color)\n",
    "plt.title('Alive Features Histogram')\n",
    "ax1.set_xlabel('spearman correlation')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bdc3ff6-c6d0-4d02-a658-6c87438a32d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "alives = stitched[key][not_dead]\n",
    "structural_spearmans = alives[structural_features[not_dead]]\n",
    "semantic_spearmans = alives[~structural_features[not_dead]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1059ad-9f7d-4298-9b5a-dcff69e5a65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the JointGrid\n",
    "g = sns.JointGrid()\n",
    "\n",
    "# Plot the joint and only the x marginal\n",
    "g.plot_joint(sns.histplot, data=semantic_spearmans, binrange=binrange, bins=bins, stat='density', label='semantic', alpha=0.5)\n",
    "g.plot_joint(sns.histplot, data=structural_spearmans, binrange=binrange, bins=bins, stat='density', label='structural', alpha=0.5)\n",
    "\n",
    "#g.ax_marg_x.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_good / hist_orig, 0), color='tab:orange', label='structural')\n",
    "\n",
    "g.ax_marg_x.fill_between((bin_edges[:-1] + bin_edges[1:]) / 2, 0, np.nan_to_num(hist_good / hist_semantic, 0), color='tab:orange',alpha=0.5)\n",
    "g.ax_marg_x.fill_between((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_good / hist_semantic, 0), 1000, color='tab:blue',alpha=0.5)\n",
    "g.ax_marg_x.set(ylabel='fraction', ylim=[0,2.0])\n",
    "#g.ax_marg_x.axhline(0.5, color='black', alpha=0.5, linestyle='dashed')\n",
    "\n",
    "#ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_other / hist_orig, 0), color='tab:orange', alpha=ax2_alpha,label='structural_hi_density')\n",
    "#ax2.plot((bin_edges[:-1] + bin_edges[1:]) / 2, np.nan_to_num(hist_semantic / hist_orig, 0), color='tab:green', alpha=ax2_alpha,label='semantic')\n",
    "# Hide the y marginal\n",
    "g.ax_marg_y.remove()\n",
    "# Shrink the height of the top marginal plot\n",
    "box = g.ax_marg_x.get_position()\n",
    "g.ax_marg_x.set_position([box.x0, box.y0, box.width, box.height * 1.0])\n",
    "g.ax_marg_x.yaxis.set_visible(True)\n",
    "g.ax_marg_x.grid(True, color='gray')\n",
    "g.ax_marg_x.spines['left'].set_visible(True)\n",
    "g.ax_marg_x.tick_params(axis='y', which='both', left=True, labelleft=True)  # Turn on ticks & labels\n",
    "g.ax_marg_x.set_ylabel(\"frac\")  # Optional\n",
    "g.ax_joint.set_xlabel('correlation')\n",
    "g.ax_joint.set_ylabel('density')\n",
    "g.ax_joint.set_xlim(binrange)\n",
    "g.ax_joint.legend()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b0a346-6459-42cc-a842-a2500f835dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1,1, figsize=(5,4))\n",
    "chub = sns.histplot(semantic_spearmans, binrange=binrange, bins=bins, stat='density', label='semantic', ax=ax,alpha=0.5)\n",
    "sns.histplot(structural_spearmans, binrange=binrange, bins=bins, stat='density', label='structural',ax=ax, alpha=0.5)\n",
    "\n",
    "\n",
    "#axins = inset_axes(ax, width=\"30%\", height=\"30%\", loc=\"center right\")\n",
    "#sns.histplot(semantic_spearmans,binrange=(-0.5, 1.0), bins=100, stat='count',label='semantic', ax=axins,alpha=0.5)\n",
    "#sns.histplot(structural_spearmans, binrange=(-0.5, 1.0), bins=100, stat='count',label='structural',ax=axins, alpha=0.5)\n",
    "#axins.set(ylabel='')\n",
    "\n",
    "# Define zoomed-in region\n",
    "#xlim, ylim = (0.5, 1), (-0.05, 50)  # Adjust these values based on your data\n",
    "#axins.set_xlim(xlim)\n",
    "#axins.set_ylim(ylim)\n",
    "\n",
    "# Draw rectangle and connecting lines\n",
    "#mark_inset(ax, axins, loc1=2, loc2=4, fc=\"none\", ec=\"black\", lw=1)\n",
    "\n",
    "\n",
    "ax.legend()\n",
    "ax.set_xlabel('correlation')\n",
    "ax.set_ylabel('density')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12dcd499-6f59-4299-9e43-2c3bd93ad55c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# bucket by percentile rather than histogram buckets\n",
    "def bucket_indices_by_percentile(values, num_buckets=50):\n",
    "    # Step 1: Sort values and keep track of original indices\n",
    "    sorted_indices = np.argsort(values)\n",
    "    sorted_values = np.array(values)[sorted_indices]\n",
    "    \n",
    "    # Step 2: Compute the bucket boundaries\n",
    "    n = len(values)\n",
    "    bucket_boundaries = [int(n * i / num_buckets) for i in range(num_buckets + 1)]\n",
    "    \n",
    "    # Step 3: Assign indices to buckets\n",
    "    buckets = [[] for _ in range(num_buckets)]\n",
    "    for i, index in enumerate(sorted_indices):\n",
    "        for b in range(num_buckets):\n",
    "            if bucket_boundaries[b] <= i < bucket_boundaries[b + 1]:\n",
    "                buckets[b].append(index)\n",
    "                break\n",
    "    return buckets, bucket_boundaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2375f9c-edf9-4208-804c-5d265b570555",
   "metadata": {},
   "outputs": [],
   "source": [
    "buckets, bucket_boundaries = bucket_indices_by_percentile(stitched['spearman_stat'][not_dead])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd97d86-e92e-4861-a7c9-43140a8b4a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "#category_masks = {\n",
    "#    'structural_low_density': structural_and_low[not_dead],\n",
    "#    'structural_hi_density': torch.logical_and(structural_features, torch.logical_or(high_density, mid_density))[not_dead],\n",
    "#    'semantic_low_density': torch.logical_and(~structural_features, low_density)[not_dead],\n",
    "#    'semantic_hi_density': torch.logical_and(~structural_features, torch.logical_or(high_density, mid_density))[not_dead],\n",
    "#}\n",
    "category_masks = {\n",
    "    #'structural_low_density': structural_and_low[not_dead],\n",
    "    #'structural_hi_density': torch.logical_and(structural_features, torch.logical_or(high_density, mid_density))[not_dead],\n",
    "    #'semantic_hi_density': \n",
    "    'structural': structural_features[not_dead],\n",
    "    'semantic': ~structural_features[not_dead]\n",
    "    #'semantic_low_density': torch.logical_and(~structural_features, low_density)[not_dead]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fdddf2f-b802-4025-953f-6aefc0873912",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_results = {\n",
    "    category: [] for category in category_masks.keys()\n",
    "}\n",
    "for bucket in buckets:\n",
    "    for category, mask in category_masks.items():\n",
    "        count = mask[bucket].sum()\n",
    "        category_results[category].append(count / len(bucket))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57629c8c-b0db-492b-96fe-8443f54b7c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a06a993-a1d7-45ad-b8fd-2e1e6dafacb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "alives = stitched['spearman_stat'][not_dead]\n",
    "structural_spearmans = alives[structural_features[not_dead]]\n",
    "semantic_spearmans = alives[~structural_features[not_dead]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "791286a1-ff6c-4179-808d-7801c115304f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1,1, figsize=(5,4))\n",
    "chub = sns.histplot(semantic_spearmans, binrange=(-0.5, 1.0), bins=100, stat='count', label='semantic', ax=ax,alpha=0.5)\n",
    "sns.histplot(structural_spearmans, binrange=(-0.5, 1.0), bins=100, stat='count', label='structural',ax=ax, alpha=0.5)\n",
    "\n",
    "\n",
    "axins = inset_axes(ax, width=\"30%\", height=\"30%\", loc=\"center right\")\n",
    "sns.histplot(semantic_spearmans,binrange=(-0.5, 1.0), bins=100, stat='count',label='semantic', ax=axins,alpha=0.5)\n",
    "sns.histplot(structural_spearmans, binrange=(-0.5, 1.0), bins=100, stat='count',label='structural',ax=axins, alpha=0.5)\n",
    "axins.set(ylabel='')\n",
    "\n",
    "# Define zoomed-in region\n",
    "xlim, ylim = (0.5, 1), (-0.05, 50)  # Adjust these values based on your data\n",
    "axins.set_xlim(xlim)\n",
    "axins.set_ylim(ylim)\n",
    "\n",
    "# Draw rectangle and connecting lines\n",
    "mark_inset(ax, axins, loc1=2, loc2=4, fc=\"none\", ec=\"black\", lw=1)\n",
    "\n",
    "\n",
    "ax.legend()\n",
    "ax.set_xlabel('spearman correlation')\n",
    "ax.set_ylabel('count')\n",
    "ax.set_title('pythia-70m.3_pre->pythia-160m.4_pre')\n",
    "plt.savefig('results/figures/semantic_structural_histogram_70m_160m.svg', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0497959b-2d9a-4e34-ade5-680408667718",
   "metadata": {},
   "source": [
    "# SAE Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12c39dfd-c250-4699-96a2-eb8c59173503",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.sae_viz import get_activations_for_feature, display_acts, highlight_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99e5a03c-2fc1-4d22-9061-581e36f874f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_idx = 27154  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8cc1431-a9fc-4d5a-a73b-885c7061cca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "acts, toks = get_activations_for_feature(val_dataloader, feat_idx, modelA, layer_A, orig_sae_A)\n",
    "stitched_acts, stitched_toks = get_activations_for_feature(val_dataloader, feat_idx, modelB, layer_B, transferred_sae)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "558c5e4d-4244-4f66-a8d7-bbdc54a30266",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "display_acts(toks,acts,modelA,k=100,ctx=25,upper_cap=50,display_density=False,verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c63160-eeb3-4ecb-a593-9f53c9cd1a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "display_acts(stitched_toks,stitched_acts,k=10, ctx=25, display_density=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f485ef2-0c10-4fb0-b427-74a681c449ff",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# get activations on text\n",
    "@torch.inference_mode()\n",
    "def get_top_features(sample, model, layer, sae, k, measurement='max'):\n",
    "    # mark out max activation\n",
    "    device = next(model.parameters()).device\n",
    "    sample = sample.to(device)\n",
    "    logits = model(sample, stop_at_layer=layer)\n",
    "    acts = sae.encode(logits)\n",
    "    ignore_mask = (sample == 0)\n",
    "    acts = acts[~ignore_mask]\n",
    "\n",
    "    # we want to get \n",
    "    if measurement == 'max':\n",
    "        over_sample_top_features = acts.max(dim=0).values\n",
    "    elif measurement == 'mean':\n",
    "        over_sample_top_features = acts.mean(dim=0)\n",
    "    \n",
    "    over_sample_top_features = over_sample_top_features.argsort(dim=-1, descending=True)[...,:k]\n",
    "    activations_over_text = acts[..., over_sample_top_features]\n",
    "    \n",
    "    return over_sample_top_features.cpu().numpy(), activations_over_text.cpu().numpy(), sample[~ignore_mask].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb58af86-b1c5-42da-9335-c80dfa1f3a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = \"Era el mejor de los tiempos y era el peor de los tiempos; la edad de la sabiduría y también de la locura; la época de las creencias y de la incredulidad; la era de la luz y de las tinieblas; la primavera de la esperanza y el invierno de la desesperación. Todo lo poseíamos, pero nada teníamos; íbamos directamente al cielo y nos extraviábamos en el camino opuesto. En una palabra, aquella época era tan parecida a la actual, que nuestras más notables autoridades insisten en que, tanto en lo que se refiere al bien como al mal, sólo es aceptable la comparación en grado superlativo.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031a0a7e-3e8e-4b18-acca-ba5a961d7d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_features, top_feature_activations, masked_sample = get_top_features(modelA.to_tokens(sample), modelA, layer_A, orig_sae_A, 20, measurement='mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e62b0b-5f2a-43d1-9140-354e268e00e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(top_feature_activations.shape[1]):\n",
    "    print(top_features[i], top_feature_activations[:, i].mean())\n",
    "    highlight_tokens(modelA.to_str_tokens(masked_sample),top_feature_activations[:,i])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
