{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a2e888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb00e2a4-8e37-4a38-87dd-596ceb9c82f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "    \n",
    "from transformer_lens import HookedTransformer\n",
    "import transformer_lens.utils as utils\n",
    "from transformers import AutoConfig\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import hf_hub_download\n",
    "from safetensors.torch import load_model\n",
    "##########\n",
    "import torch\n",
    "import json\n",
    "import math\n",
    "import einops \n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch import nn\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformer_lens import utils\n",
    "\n",
    "from turbo_sae.sae import KronSAE, TopKSAE\n",
    "from turbo_sae.sae.config import TrainingConfig, SAEConfig\n",
    "import wandb\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import HookedTransformer, FactoredMatrix\n",
    "\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "import tqdm.auto as tqdm\n",
    "import plotly.express as px\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "import wandb\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "def get_sae(sae_type: str, path):\n",
    "    # Load transformer model\n",
    "    sae_class = {\n",
    "        \"topk\": TopKSAE,\n",
    "        \"kronsae\": KronSAE\n",
    "    }[sae_type]\n",
    "    \n",
    "    with open(f'{path}/config.json') as json_file:\n",
    "        data = json.load(json_file)\n",
    "        data['dtype'] = \"float32\"\n",
    "        data['sae_dtype'] = \"float32\"\n",
    "    \n",
    "    train_config = TrainingConfig(**data)\n",
    "    sae_config = SAEConfig.from_training_config(train_config)\n",
    "\n",
    "    sae = sae_class(sae_config)\n",
    "    weights = torch.load(f\"{path}/sae.pt\", weights_only=True)\n",
    "    sae.load_state_dict(weights, strict=False)\n",
    "    sae.to(device)\n",
    "    return {'sae': sae, 'config': sae_config}\n",
    "    \n",
    "sae_paths = {\n",
    "    \"topk\" : \"saes/16k_topk\", \n",
    "    \"kronsae\" : \"saes/16k_kronsae_4_16\",\n",
    "}\n",
    "d = get_sae(\"topk\", \"saes/1_5b_qwen/32k_topk\")\n",
    "topk_sae = d[\"sae\"]\n",
    "topk_sae_config = d[\"config\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09234dcd-458a-47e7-ac22-03f40797adfa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36cc8051",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "A = torch.load('mul_corr.pt', map_location='cpu').float().numpy()\n",
    "B = torch.load('topk_corr.pt', map_location='cpu').float().numpy()\n",
    "\n",
    "np.fill_diagonal(A, 0)\n",
    "np.fill_diagonal(B, 0)\n",
    "\n",
    "hist_mul = np.sort(A, axis=1).mean(axis=0)\n",
    "hist_topk = np.sort(B, axis=1).mean(axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c284eeae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Assuming hist_mul and hist_topk are your data arrays\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 3), dpi=250)\n",
    "\n",
    "palette = sns.color_palette('Set2', 2)\n",
    "\n",
    "# Plot both positive and negative values together using symlog scale\n",
    "sns.histplot(hist_mul,   log_scale=False, stat=\"density\", bins=256 + 128, \n",
    "             label=\"KronSAE\", ax=ax, element=\"step\", fill=False, color=palette[0])\n",
    "sns.histplot(hist_topk, log_scale=False, stat=\"density\", bins=256 + 128, \n",
    "             label=\"TopK\", ax=ax, element=\"step\", fill=False, color=palette[1])\n",
    "\n",
    "# Set symmetric logarithmic x-axis\n",
    "ax.set_xscale('symlog', linthresh=2e-3)  # Adjust linthresh based on your data\n",
    "\n",
    "# scales = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11]\n",
    "# scales = [-s for s in scales] + [0] + list(reversed(scales))\n",
    "\n",
    "# Customize ticks to show negative/positive explicitly\n",
    "# ax.set_xticks(scales)\n",
    "# ax.set_xticklabels([f'{s:.0e}' for s in scales])\n",
    "\n",
    "# Log scale for y-axis to see density differences\n",
    "ax.set_yscale('log')\n",
    "ax.grid(alpha=0.35)\n",
    "# ax.set_title('Positive & Negative Values (Symmetric Log Scale)')\n",
    "ax.legend(loc='upper right')\n",
    "\n",
    "# plt.axvline(np.median(hist_mul), linestyle='--',  color=palette[0])\n",
    "# plt.axvline(np.median(hist_topk), linestyle='--',  color=palette[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d8d697",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"KronSAE: mean {hist_mul.mean():.4e}, median {np.median(hist_mul):.5f}\")\n",
    "print(f\"TopK: mean {hist_topk.mean():.4e}, median {np.median(hist_topk):.5f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71499f5f-58c7-4e0e-b696-1e8bf79e8000",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import ttest_ind\n",
    "\n",
    "# Example usage:\n",
    "corr_matrix = torch.tensor(A)\n",
    "group_size = 16\n",
    "max_samples: int = 100000\n",
    "random_seed: int = 42\n",
    "\"\"\"\n",
    "Compare within-group and between-group correlation distributions.\n",
    "\n",
    "Parameters:\n",
    "- corr_matrix: torch.Tensor of shape (N, N), correlation matrix.\n",
    "- group_size: int, size k of each group (assumed contiguous blocks).\n",
    "- max_samples: int, maximum number of samples per distribution to plot/compare.\n",
    "- random_seed: int, for reproducibility of sampling.\n",
    "\n",
    "Returns:\n",
    "- t_stat: float, t-test statistic comparing inner vs. between.\n",
    "- p_value: float, two-sided p-value of the test.\n",
    "\"\"\"\n",
    "torch.manual_seed(random_seed)\n",
    "N = corr_matrix.size(0)\n",
    "if N % group_size != 0:\n",
    "    raise ValueError(\"Matrix size must be divisible by group_size\")\n",
    "num_groups = N // group_size\n",
    "\n",
    "inner_vals = []\n",
    "between_vals = []\n",
    "\n",
    "# Collect samples\n",
    "for g in range(num_groups):\n",
    "    s1 = g * group_size\n",
    "    e1 = s1 + group_size\n",
    "    block = corr_matrix[s1:e1, s1:e1]\n",
    "    triu = torch.triu_indices(group_size, group_size, offset=1)\n",
    "    inner_vals.append(block[triu[0], triu[1]])\n",
    "\n",
    "    # sample sqrt(num_groups) other groups\n",
    "    others = list(range(num_groups))\n",
    "    others.remove(g)\n",
    "    k_sample = min(len(others), int(np.sqrt(num_groups)))\n",
    "    chosen = np.random.choice(others, size=k_sample, replace=False)\n",
    "    for h in chosen:\n",
    "        s2 = h * group_size\n",
    "        e2 = s2 + group_size\n",
    "        between_vals.append(corr_matrix[s1:e1, s2:e2].flatten())\n",
    "\n",
    "inner = torch.cat(inner_vals)\n",
    "between = torch.cat(between_vals)\n",
    "\n",
    "# subsample if needed\n",
    "def subsample(tensor):\n",
    "    total = tensor.numel()\n",
    "    if total > max_samples:\n",
    "        idx = torch.randperm(total)[:max_samples]\n",
    "        return tensor[idx]\n",
    "    return tensor\n",
    "\n",
    "inner = subsample(inner)\n",
    "between = subsample(between)\n",
    "\n",
    "inner_np = inner.cpu().numpy()\n",
    "between_np = between.cpu().numpy()\n",
    "\n",
    "combined = np.concatenate((inner_np, between_np))\n",
    "min_val, max_val = combined.min(), combined.max()\n",
    "num_bins = 36\n",
    "bins = np.linspace(min_val, max_val, num_bins + 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7f45f06-6e80-41fb-8388-8c9ee4598ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# # Plot both histograms on one axis with same bins\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.histplot(inner_np, bins=bins, kde=False, stat=\"density\",\n",
    "#              label='Within-Group', alpha=0.6)\n",
    "# sns.histplot(between_np, bins=bins, kde=False, stat=\"density\",\n",
    "#              label='Between-Group', alpha=0.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfc79958-2132-473b-9597-1cb3c4d1dddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot both histograms on one axes\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.histplot(inner_np, bins=bins,  kde=False, stat=\"density\", label='Within-Group', alpha=0.6)\n",
    "sns.histplot(between_np, bins=bins,  kde=False, stat=\"density\", label='Between-Group', alpha=0.6)\n",
    "\n",
    "# Avoid zero-density dominating by setting bins away from zero and using logarithmic scales\n",
    "#plt.xscale('symlog')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Correlation Value')\n",
    "plt.ylabel('Density (log scale)')\n",
    "plt.title('Within vs. Between Group Correlation Distributions')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# t-test\n",
    "t_stat, p_value = ttest_ind(inner_np, between_np, equal_var=False)\n",
    "print(f\"T-statistic: {t_stat:.4f}, p-value: {p_value:.4e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be3689f-2746-416d-93ed-b77e4658a3bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"KronSAE: inner corr mean {inner_np.mean():.4e}, median {np.median(inner_np):.5f}\")\n",
    "print(f\"KronSAE: inner between mean {between_np.mean():.4e}, median {np.median(between_np):.5f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "103a7a33-c44a-4080-8c02-9c930d8f4063",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Assuming hist_mul and hist_topk are your data arrays\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 3), dpi=250)\n",
    "\n",
    "palette = sns.color_palette('Set2', 2)\n",
    "\n",
    "# Plot both positive and negative values together using symlog scale\n",
    "sns.histplot(hist_mul,   log_scale=False, stat=\"density\", bins=256 + 128, \n",
    "             label=\"KronSAE\", ax=ax[0], element=\"step\", fill=False, color=palette[0])\n",
    "sns.histplot(hist_topk, log_scale=False, stat=\"density\", bins=256 + 128, \n",
    "             label=\"TopK\", ax=ax[0], element=\"step\", fill=False, color=palette[1])\n",
    "\n",
    "# Set symmetric logarithmic x-axis\n",
    "ax[0].set_xscale('symlog', linthresh=2e-3)  # Adjust linthresh based on your data\n",
    "ax[0].set_ylabel('Density (log scale)')\n",
    "\n",
    "# scales = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11]\n",
    "# scales = [-s for s in scales] + [0] + list(reversed(scales))\n",
    "\n",
    "# Customize ticks to show negative/positive explicitly\n",
    "# ax.set_xticks(scales)\n",
    "# ax.set_xticklabels([f'{s:.0e}' for s in scales])\n",
    "\n",
    "# Log scale for y-axis to see density differences\n",
    "ax[0].set_yscale('log')\n",
    "ax[0].grid(alpha=0.35)\n",
    "# ax.set_title('Positive & Negative Values (Symmetric Log Scale)')\n",
    "ax[0].legend(loc='upper right')\n",
    "\n",
    "# plt.axvline(np.median(hist_mul), linestyle='--',  color=palette[0])\n",
    "# plt.axvline(np.median(hist_topk), linestyle='--',  color=palette[1])\n",
    "\n",
    "sns.histplot(inner_np, bins=bins, ax=ax[1], kde=False, stat=\"density\", label='Within-Group', alpha=0.6)\n",
    "sns.histplot(between_np, bins=bins, ax=ax[1], kde=False, stat=\"density\", label='Between-Group', alpha=0.6)\n",
    "\n",
    "# Avoid zero-density dominating by setting bins away from zero and using logarithmic scales\n",
    "#plt.xscale('symlog')\n",
    "ax[1].set_yscale('log')\n",
    "#ax[1].set_xlabel('Correlation Value')\n",
    "ax[1].set_ylabel('')\n",
    "#ax[1].set_title('Within vs. Between Group Correlation Distributions')\n",
    "ax[1].grid(alpha=0.35)\n",
    "fig.text(0.5, -0.00, 'Correlation Value', ha='center')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'corrs.pdf')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
