{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27ab6e34-640d-48e3-be30-0f85d9c9683b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of the running kernel.\n",
    "%pip install scikit-learn\n",
    "%pip install seaborn\n",
    "%pip install -r turbo_sae/requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa8d1500-fa07-44ce-8496-999a5309ae02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "    \n",
    "from transformer_lens import HookedTransformer\n",
    "import transformer_lens.utils as utils\n",
    "from transformers import AutoConfig\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import hf_hub_download\n",
    "from safetensors.torch import load_model\n",
    "##########\n",
    "import torch\n",
    "import json\n",
    "import math\n",
    "import einops \n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch import nn\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformer_lens import utils\n",
    "\n",
    "from turbo_sae.sae import KronSAE, TopKSAE\n",
    "from turbo_sae.sae.config import TrainingConfig, SAEConfig\n",
    "import wandb\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import HookedTransformer, FactoredMatrix\n",
    "\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "import tqdm.auto as tqdm\n",
    "import plotly.express as px\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "import wandb\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "\n",
    "# Run-wide settings read by the cells below.\n",
    "device = \"cuda\"\n",
    "# NOTE(review): the covariance cells overwrite `hook_point` with\n",
    "# utils.get_act_name(\"resid_post\", 14) and no visible cell reads `layer_num`;\n",
    "# confirm which layer the SAEs were trained on.\n",
    "hook_point = \"resid_post\"\n",
    "layer_num = 5\n",
    "# Turn autograd off globally.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f853394e-40d6-463c-bba2-9f9ab23f8509",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the base language model in bfloat16, then move it onto the configured device.\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"Qwen/Qwen2.5-1.5B\",\n",
    "    dtype=torch.bfloat16,\n",
    ")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8cc9572-0f7f-4b46-95e6-607f8120b9f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sae(sae_type: str, path):\n",
    "    \"\"\"\n",
    "    Build an SAE of the requested type and load its checkpoint from `path`.\n",
    "\n",
    "    Parameters:\n",
    "    - sae_type: \"topk\" or \"kronsae\"; selects the SAE class to instantiate.\n",
    "    - path: directory containing `config.json` and `sae.pt`.\n",
    "\n",
    "    Returns:\n",
    "    - dict with keys 'sae' (the loaded SAE, moved to the global `device`)\n",
    "      and 'config' (the SAEConfig derived from the training config).\n",
    "\n",
    "    Raises:\n",
    "    - KeyError: if `sae_type` is not a known SAE type.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    sae_classes = {\n",
    "        \"topk\": TopKSAE,\n",
    "        \"kronsae\": KronSAE\n",
    "    }\n",
    "    if sae_type not in sae_classes:\n",
    "        raise KeyError(f\"Unknown sae_type {sae_type!r}; expected one of {sorted(sae_classes)}\")\n",
    "    sae_class = sae_classes[sae_type]\n",
    "\n",
    "    with open(f'{path}/config.json') as json_file:\n",
    "        data = json.load(json_file)\n",
    "    # Force both dtype fields to float32, whatever the config file records.\n",
    "    data['dtype'] = \"float32\"\n",
    "    data['sae_dtype'] = \"float32\"\n",
    "\n",
    "    train_config = TrainingConfig(**data)\n",
    "    sae_config = SAEConfig.from_training_config(train_config)\n",
    "\n",
    "    sae = sae_class(sae_config)\n",
    "    # Deserialize onto CPU so loading does not depend on the device the checkpoint\n",
    "    # was saved from; the module is moved to `device` afterwards.\n",
    "    weights = torch.load(f\"{path}/sae.pt\", map_location=\"cpu\", weights_only=True)\n",
    "    # strict=False is kept, but mismatches are reported: a skipped key leaves that\n",
    "    # parameter at whatever value the constructor gave it.\n",
    "    load_result = sae.load_state_dict(weights, strict=False)\n",
    "    missing = list(getattr(load_result, \"missing_keys\", []))\n",
    "    unexpected = list(getattr(load_result, \"unexpected_keys\", []))\n",
    "    if missing or unexpected:\n",
    "        warnings.warn(\n",
    "            f\"{path}/sae.pt does not fully match {sae_class.__name__}: \"\n",
    "            f\"missing keys {missing}, unexpected keys {unexpected}\"\n",
    "        )\n",
    "    sae.to(device)\n",
    "    return {'sae': sae, 'config': sae_config}\n",
    "    \n",
    "# Commented-out alternative checkpoint paths (not used below):\n",
    "# sae_paths = {\n",
    "#     \"topk\" : \"saes/16k_topk\",\n",
    "#     \"kronsae\" : \"saes/16k_kronsae_4_16\",\n",
    "# }\n",
    "# Load the TopK SAE checkpoint and keep the module and its config under separate names.\n",
    "d = get_sae(\"topk\", \"saes/1_5b_qwen/32k_topk\")\n",
    "topk_sae = d[\"sae\"]\n",
    "topk_sae_config = d[\"config\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2538d746-5831-4cb3-9d4b-cc826cdc8afb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the CC-MAIN-2024-51 configuration of FineWeb-Edu in streaming mode.\n",
    "dataset = load_dataset(\n",
    "    \"HuggingFaceFW/fineweb-edu\",\n",
    "    name=\"CC-MAIN-2024-51\",\n",
    "    split=\"train\",\n",
    "    streaming=True,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169d5dc0-32f8-4d50-b9a4-498e4671d1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "# --- CONFIG ----\n",
    "# Normalise `device` to a torch.device object.\n",
    "device = torch.device(device)\n",
    "# Hook name for \"resid_post\" at layer index 14 (hard-coded; the earlier `layer_num` is not used here).\n",
    "hook_point = utils.get_act_name(\"resid_post\", 14)\n",
    "# Expected to exist already from earlier cells: `topk_sae` (the SAE used by the hook below) and `model`.\n",
    "\n",
    "# --- RUNNING STATISTICS STATE ---\n",
    "# Reset on every run of this cell; updated by update_online_cov below.\n",
    "mu = None           # running mean of the SAE latents, Tensor[D]\n",
    "M2 = None           # running scatter matrix (sum of deviation outer products), Tensor[D, D]\n",
    "n  = 0              # total rows folded in so far\n",
    "\n",
    "@torch.no_grad()  # statistics only; never build an autograd graph\n",
    "def update_online_cov(x: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Fold one batch of rows into the running mean / scatter-matrix statistics.\n",
    "\n",
    "    Batches are merged with the pairwise update\n",
    "        M2_new = M2 + M2_batch + outer(delta, delta) * n * N / (n + N)\n",
    "    where delta is the batch mean minus the running mean.\n",
    "\n",
    "    x: Tensor of shape (N_batch, D), expected on `device` and float32.\n",
    "    Updates global mu, M2, and n. A batch with zero rows is ignored.\n",
    "    \"\"\"\n",
    "    global mu, M2, n\n",
    "\n",
    "    N = x.shape[0]\n",
    "    if N == 0:\n",
    "        # the mean of zero rows is NaN and would poison mu and M2 for good\n",
    "        return\n",
    "\n",
    "    mb = x.mean(dim=0)                                   # [D] batch mean\n",
    "    dev = x - mb                                         # [N, D] deviations from the batch mean\n",
    "    M2_batch = dev.t().mm(dev)                           # [D, D] within-batch scatter\n",
    "\n",
    "    if mu is None:\n",
    "        # First batch: the running statistics are exactly the batch statistics\n",
    "        mu = mb\n",
    "        M2 = M2_batch\n",
    "        n = N\n",
    "        return\n",
    "\n",
    "    n_new = n + N\n",
    "    delta = mb - mu                                      # [D] shift between batch and running mean\n",
    "    cross = torch.outer(delta, delta) * (n * N / n_new)  # [D, D] correction for that shift\n",
    "\n",
    "    mu = mu + delta * (N / n_new)\n",
    "    # in-place adds reuse the existing [D, D] accumulator\n",
    "    M2 += M2_batch\n",
    "    M2 += cross\n",
    "    n = n_new\n",
    "\n",
    "def resid_post_hook(resid_post, hook):\n",
    "    \"\"\"\n",
    "    Forward hook: encode the hooked activations with `topk_sae` and fold the\n",
    "    latents into the running covariance statistics.\n",
    "\n",
    "    resid_post: 3-D activation tensor, unpacked as [B, T, D].\n",
    "    hook: second positional argument passed by the hook machinery; unused.\n",
    "    Returns resid_post itself, unmodified.\n",
    "    \"\"\"\n",
    "    n_batch, n_pos, width = resid_post.shape\n",
    "    # merge batch and position into one row axis, then cast to float32 on `device`\n",
    "    rows = resid_post.view(n_batch * n_pos, width)\n",
    "    rows = rows.float().to(device)\n",
    "    latents = topk_sae.encode(rows)\n",
    "    update_online_cov(latents)\n",
    "    return resid_post\n",
    "\n",
    "# --- MAIN LOOP ---\n",
    "# Model and SAE on `device`; model in eval mode.\n",
    "model.to(device)\n",
    "topk_sae.to(device)\n",
    "model.eval()\n",
    "\n",
    "# One document per step, pulled from a fresh iterator over `dataset`.\n",
    "iter_ds = iter(dataset)\n",
    "num_iters = 5000\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in tqdm.tqdm(range(num_iters)):\n",
    "        sample = next(iter_ds)\n",
    "        tokens = model.to_tokens(sample[\"text\"])\n",
    "        tokens = tokens.to(device)\n",
    "        # only the hook's side effect matters; the forward output is thrown away\n",
    "        _ = model.run_with_hooks(tokens, fwd_hooks=[(hook_point, resid_post_hook)])\n",
    "\n",
    "# --- FINALIZE ---\n",
    "cov = M2 / (n - 1)                              # covariance with the n - 1 denominator\n",
    "std = cov.diagonal().sqrt()                     # [D]\n",
    "corrA = cov / (torch.outer(std, std) + 1e-6)    # [D, D]; 1e-6 keeps zero-variance latents finite\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25885fe2-5f45-4b6e-8d2b-1a6995415bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a second copy of the TopK checkpoint, then overwrite its W_dec, W_enc, b_dec and b_enc\n",
    "# tensors in place with torch.nn.init.uniform_ (default bounds).\n",
    "d = get_sae(\"topk\", \"saes/1_5b_qwen/32k_topk\")\n",
    "topk_sae_random = d[\"sae\"]\n",
    "torch.nn.init.uniform_(topk_sae_random.W_dec)\n",
    "torch.nn.init.uniform_(topk_sae_random.W_enc)\n",
    "torch.nn.init.uniform_(topk_sae_random.b_dec)\n",
    "torch.nn.init.uniform_(topk_sae_random.b_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2a5fe6-f69e-432d-b3be-188f3b40f777",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "# --- CONFIG ----\n",
    "# Normalise `device` to a torch.device object.\n",
    "device = torch.device(device)\n",
    "# Hook name for \"resid_post\" at layer index 14 (hard-coded; the earlier `layer_num` is not used here).\n",
    "hook_point = utils.get_act_name(\"resid_post\", 14)\n",
    "# Expected to exist already from earlier cells: `topk_sae_random` (the SAE used by the hook below) and `model`.\n",
    "\n",
    "# --- RUNNING STATISTICS STATE ---\n",
    "# Reset on every run of this cell; updated by update_online_cov below.\n",
    "mu = None           # running mean of the SAE latents, Tensor[D]\n",
    "M2 = None           # running scatter matrix (sum of deviation outer products), Tensor[D, D]\n",
    "n  = 0              # total rows folded in so far\n",
    "\n",
    "@torch.no_grad()  # statistics only; never build an autograd graph\n",
    "def update_online_cov(x: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Fold one batch of rows into the running mean / scatter-matrix statistics.\n",
    "\n",
    "    Batches are merged with the pairwise update\n",
    "        M2_new = M2 + M2_batch + outer(delta, delta) * n * N / (n + N)\n",
    "    where delta is the batch mean minus the running mean.\n",
    "\n",
    "    x: Tensor of shape (N_batch, D), expected on `device` and float32.\n",
    "    Updates global mu, M2, and n. A batch with zero rows is ignored.\n",
    "    \"\"\"\n",
    "    global mu, M2, n\n",
    "\n",
    "    N = x.shape[0]\n",
    "    if N == 0:\n",
    "        # the mean of zero rows is NaN and would poison mu and M2 for good\n",
    "        return\n",
    "\n",
    "    mb = x.mean(dim=0)                                   # [D] batch mean\n",
    "    dev = x - mb                                         # [N, D] deviations from the batch mean\n",
    "    M2_batch = dev.t().mm(dev)                           # [D, D] within-batch scatter\n",
    "\n",
    "    if mu is None:\n",
    "        # First batch: the running statistics are exactly the batch statistics\n",
    "        mu = mb\n",
    "        M2 = M2_batch\n",
    "        n = N\n",
    "        return\n",
    "\n",
    "    n_new = n + N\n",
    "    delta = mb - mu                                      # [D] shift between batch and running mean\n",
    "    cross = torch.outer(delta, delta) * (n * N / n_new)  # [D, D] correction for that shift\n",
    "\n",
    "    mu = mu + delta * (N / n_new)\n",
    "    # in-place adds reuse the existing [D, D] accumulator\n",
    "    M2 += M2_batch\n",
    "    M2 += cross\n",
    "    n = n_new\n",
    "\n",
    "def resid_post_hook(resid_post, hook):\n",
    "    \"\"\"\n",
    "    Forward hook: encode the hooked activations with `topk_sae_random` and fold\n",
    "    the latents into the running covariance statistics.\n",
    "\n",
    "    resid_post: 3-D activation tensor, unpacked as [B, T, D].\n",
    "    hook: second positional argument passed by the hook machinery; unused.\n",
    "    Returns resid_post itself, unmodified.\n",
    "    \"\"\"\n",
    "    n_batch, n_pos, width = resid_post.shape\n",
    "    # merge batch and position into one row axis, then cast to float32 on `device`\n",
    "    rows = resid_post.view(n_batch * n_pos, width)\n",
    "    rows = rows.float().to(device)\n",
    "    latents = topk_sae_random.encode(rows)\n",
    "    update_online_cov(latents)\n",
    "    return resid_post\n",
    "\n",
    "# --- MAIN LOOP ---\n",
    "model.to(device)\n",
    "# Move the SAE that this cell's hook encodes with.\n",
    "topk_sae_random.to(device)\n",
    "model.eval()                                            # switch to eval mode\n",
    "\n",
    "# Fresh iterator over `dataset`; one document is consumed per step.\n",
    "iter_ds = iter(dataset)\n",
    "\n",
    "num_iters = 5000\n",
    "with torch.no_grad():  # full-loop no-grad\n",
    "    for _ in tqdm.tqdm(range(num_iters)):\n",
    "        sample = next(iter_ds)\n",
    "        tokens = model.to_tokens(sample[\"text\"]).to(device)\n",
    "        # only the hook's side effect on the running statistics matters here\n",
    "        _ = model.run_with_hooks(\n",
    "            tokens,\n",
    "            fwd_hooks=[(hook_point, resid_post_hook)]\n",
    "        )\n",
    "\n",
    "# --- FINALIZE ---\n",
    "cov = M2 / (n - 1)                                      # covariance with the n - 1 denominator\n",
    "std = torch.sqrt(torch.diag(cov))                      # [D]\n",
    "# 1e-6 in the denominator keeps zero-variance latents finite\n",
    "corrA_rand = cov / (std[:, None] * std[None, :] + 1e-6)      # [D, D]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8b04ad-2ac8-449a-a051-c57743c3e55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.save(corr, \"topk_corr.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b949342-3f38-4c03-85ca-c9b1744a8f57",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70ead0a-9893-418a-b52b-6cdbf12155f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d98d31c-bbf8-479e-8d3d-5caa3ceec967",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74bf02a0-df6f-44a6-848e-71f126d2cafc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `sae_paths` is not read by the visible cells, and its paths differ from the checkpoint loaded below.\n",
    "sae_paths = {\"topk\": \"saes/16k_topk\", \"kronsae\": \"saes/16k_kronsae_4_16\"}\n",
    "\n",
    "# Load the KronSAE checkpoint; keep the module and its config under separate names.\n",
    "d = get_sae(\"kronsae\", \"saes/1_5b_qwen/32k_kronsae\")\n",
    "kronsae, kronsae_config = d[\"sae\"], d[\"config\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2502e1a-ca66-4ec4-89a8-e0a380471ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "# --- CONFIG ----\n",
    "# Normalise `device` to a torch.device object.\n",
    "device = torch.device(device)\n",
    "# Hook name for \"resid_post\" at layer index 14 (hard-coded; the earlier `layer_num` is not used here).\n",
    "hook_point = utils.get_act_name(\"resid_post\", 14)\n",
    "# Expected to exist already from earlier cells: `kronsae` (the SAE used by the hook below) and `model`.\n",
    "\n",
    "# --- RUNNING STATISTICS STATE ---\n",
    "# Reset on every run of this cell; updated by update_online_cov below.\n",
    "mu = None           # running mean of the SAE latents, Tensor[D]\n",
    "M2 = None           # running scatter matrix (sum of deviation outer products), Tensor[D, D]\n",
    "n  = 0              # total rows folded in so far\n",
    "\n",
    "@torch.no_grad()  # statistics only; never build an autograd graph\n",
    "def update_online_cov(x: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Fold one batch of rows into the running mean / scatter-matrix statistics.\n",
    "\n",
    "    Batches are merged with the pairwise update\n",
    "        M2_new = M2 + M2_batch + outer(delta, delta) * n * N / (n + N)\n",
    "    where delta is the batch mean minus the running mean.\n",
    "\n",
    "    x: Tensor of shape (N_batch, D), expected on `device` and float32.\n",
    "    Updates global mu, M2, and n. A batch with zero rows is ignored.\n",
    "    \"\"\"\n",
    "    global mu, M2, n\n",
    "\n",
    "    N = x.shape[0]\n",
    "    if N == 0:\n",
    "        # the mean of zero rows is NaN and would poison mu and M2 for good\n",
    "        return\n",
    "\n",
    "    mb = x.mean(dim=0)                                   # [D] batch mean\n",
    "    dev = x - mb                                         # [N, D] deviations from the batch mean\n",
    "    M2_batch = dev.t().mm(dev)                           # [D, D] within-batch scatter\n",
    "\n",
    "    if mu is None:\n",
    "        # First batch: the running statistics are exactly the batch statistics\n",
    "        mu = mb\n",
    "        M2 = M2_batch\n",
    "        n = N\n",
    "        return\n",
    "\n",
    "    n_new = n + N\n",
    "    delta = mb - mu                                      # [D] shift between batch and running mean\n",
    "    cross = torch.outer(delta, delta) * (n * N / n_new)  # [D, D] correction for that shift\n",
    "\n",
    "    mu = mu + delta * (N / n_new)\n",
    "    # in-place adds reuse the existing [D, D] accumulator\n",
    "    M2 += M2_batch\n",
    "    M2 += cross\n",
    "    n = n_new\n",
    "\n",
    "def resid_post_hook(resid_post, hook):\n",
    "    \"\"\"\n",
    "    Forward hook: encode the hooked activations with `kronsae` and fold the\n",
    "    latents into the running covariance statistics.\n",
    "\n",
    "    resid_post: 3-D activation tensor, unpacked as [B, T, D].\n",
    "    hook: second positional argument passed by the hook machinery; unused.\n",
    "    Returns resid_post itself, unmodified.\n",
    "    \"\"\"\n",
    "    n_batch, n_pos, width = resid_post.shape\n",
    "    # merge batch and position into one row axis, then cast to float32 on `device`\n",
    "    rows = resid_post.view(n_batch * n_pos, width)\n",
    "    rows = rows.float().to(device)\n",
    "    latents = kronsae.encode(rows)\n",
    "    update_online_cov(latents)\n",
    "    return resid_post\n",
    "\n",
    "# --- MAIN LOOP ---\n",
    "model.to(device)\n",
    "# Move the SAE that this cell's hook encodes with.\n",
    "kronsae.to(device)\n",
    "model.eval()                                            # switch to eval mode\n",
    "\n",
    "# Fresh iterator over `dataset`; one document is consumed per step.\n",
    "iter_ds = iter(dataset)\n",
    "\n",
    "num_iters = 5000\n",
    "with torch.no_grad():  # full-loop no-grad\n",
    "    for _ in tqdm.tqdm(range(num_iters)):\n",
    "        sample = next(iter_ds)\n",
    "        tokens = model.to_tokens(sample[\"text\"]).to(device)\n",
    "        # only the hook's side effect on the running statistics matters here\n",
    "        _ = model.run_with_hooks(\n",
    "            tokens,\n",
    "            fwd_hooks=[(hook_point, resid_post_hook)]\n",
    "        )\n",
    "\n",
    "# --- FINALIZE ---\n",
    "cov = M2 / (n - 1)                                      # covariance with the n - 1 denominator\n",
    "std = torch.sqrt(torch.diag(cov))                      # [D]\n",
    "# 1e-6 in the denominator keeps zero-variance latents finite\n",
    "corrB = cov / (std[:, None] * std[None, :] + 1e-6)      # [D, D]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa374ed-84a7-4252-90de-68574f8e926a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a second copy of the KronSAE checkpoint, then overwrite its W_dec, W_enc, b_dec and b_enc\n",
    "# tensors in place with torch.nn.init.uniform_ (default bounds).\n",
    "d = get_sae(\"kronsae\", \"saes/1_5b_qwen/32k_kronsae\")\n",
    "kronsae_random = d[\"sae\"]\n",
    "torch.nn.init.uniform_(kronsae_random.W_dec)\n",
    "torch.nn.init.uniform_(kronsae_random.W_enc)\n",
    "torch.nn.init.uniform_(kronsae_random.b_dec)\n",
    "torch.nn.init.uniform_(kronsae_random.b_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d2cdd8d-8132-4087-821e-62b1c5b2a8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "# --- CONFIG ----\n",
    "# Normalise `device` to a torch.device object.\n",
    "device = torch.device(device)\n",
    "# Hook name for \"resid_post\" at layer index 14 (hard-coded; the earlier `layer_num` is not used here).\n",
    "hook_point = utils.get_act_name(\"resid_post\", 14)\n",
    "# Expected to exist already from earlier cells: `kronsae_random` (the SAE used by the hook below) and `model`.\n",
    "\n",
    "# --- RUNNING STATISTICS STATE ---\n",
    "# Reset on every run of this cell; updated by update_online_cov below.\n",
    "mu = None           # running mean of the SAE latents, Tensor[D]\n",
    "M2 = None           # running scatter matrix (sum of deviation outer products), Tensor[D, D]\n",
    "n  = 0              # total rows folded in so far\n",
    "\n",
    "@torch.no_grad()  # statistics only; never build an autograd graph\n",
    "def update_online_cov(x: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Fold one batch of rows into the running mean / scatter-matrix statistics.\n",
    "\n",
    "    Batches are merged with the pairwise update\n",
    "        M2_new = M2 + M2_batch + outer(delta, delta) * n * N / (n + N)\n",
    "    where delta is the batch mean minus the running mean.\n",
    "\n",
    "    x: Tensor of shape (N_batch, D), expected on `device` and float32.\n",
    "    Updates global mu, M2, and n. A batch with zero rows is ignored.\n",
    "    \"\"\"\n",
    "    global mu, M2, n\n",
    "\n",
    "    N = x.shape[0]\n",
    "    if N == 0:\n",
    "        # the mean of zero rows is NaN and would poison mu and M2 for good\n",
    "        return\n",
    "\n",
    "    mb = x.mean(dim=0)                                   # [D] batch mean\n",
    "    dev = x - mb                                         # [N, D] deviations from the batch mean\n",
    "    M2_batch = dev.t().mm(dev)                           # [D, D] within-batch scatter\n",
    "\n",
    "    if mu is None:\n",
    "        # First batch: the running statistics are exactly the batch statistics\n",
    "        mu = mb\n",
    "        M2 = M2_batch\n",
    "        n = N\n",
    "        return\n",
    "\n",
    "    n_new = n + N\n",
    "    delta = mb - mu                                      # [D] shift between batch and running mean\n",
    "    cross = torch.outer(delta, delta) * (n * N / n_new)  # [D, D] correction for that shift\n",
    "\n",
    "    mu = mu + delta * (N / n_new)\n",
    "    # in-place adds reuse the existing [D, D] accumulator\n",
    "    M2 += M2_batch\n",
    "    M2 += cross\n",
    "    n = n_new\n",
    "\n",
    "def resid_post_hook(resid_post, hook):\n",
    "    \"\"\"\n",
    "    Forward hook: encode the hooked activations with `kronsae_random` and fold\n",
    "    the latents into the running covariance statistics.\n",
    "\n",
    "    resid_post: 3-D activation tensor, unpacked as [B, T, D].\n",
    "    hook: second positional argument passed by the hook machinery; unused.\n",
    "    Returns resid_post itself, unmodified.\n",
    "    \"\"\"\n",
    "    n_batch, n_pos, width = resid_post.shape\n",
    "    # merge batch and position into one row axis, then cast to float32 on `device`\n",
    "    rows = resid_post.view(n_batch * n_pos, width)\n",
    "    rows = rows.float().to(device)\n",
    "    latents = kronsae_random.encode(rows)\n",
    "    update_online_cov(latents)\n",
    "    return resid_post\n",
    "\n",
    "# --- MAIN LOOP ---\n",
    "# The device moves are disabled in this cell; only eval mode is (re)applied.\n",
    "model.eval()\n",
    "\n",
    "# One document per step, pulled from a fresh iterator over `dataset`.\n",
    "iter_ds = iter(dataset)\n",
    "num_iters = 5000\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in tqdm.tqdm(range(num_iters)):\n",
    "        sample = next(iter_ds)\n",
    "        tokens = model.to_tokens(sample[\"text\"])\n",
    "        tokens = tokens.to(device)\n",
    "        # only the hook's side effect matters; the forward output is thrown away\n",
    "        _ = model.run_with_hooks(tokens, fwd_hooks=[(hook_point, resid_post_hook)])\n",
    "\n",
    "# --- FINALIZE ---\n",
    "cov = M2 / (n - 1)                                   # covariance with the n - 1 denominator\n",
    "std = cov.diagonal().sqrt()                          # [D]\n",
    "corrB_rand = cov / (torch.outer(std, std) + 1e-6)    # [D, D]; 1e-6 keeps zero-variance latents finite\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a8c3c22-661e-4de3-a0e5-cacde6c41e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.save(corr.to(torch.bfloat16), \"mul_corr.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a711a25-ea8d-4a13-a320-245ad2271ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Correlation matrices as float32 numpy arrays: A / A_rand come from the TopK runs,\n",
    "# B / B_rand from the KronSAE runs.\n",
    "A = corrA.float().cpu().numpy()\n",
    "B = corrB.float().cpu().numpy()\n",
    "A_rand = corrA_rand.float().cpu().numpy()\n",
    "B_rand = corrB_rand.float().cpu().numpy()\n",
    "\n",
    "# Zero the diagonal of each matrix (in place).\n",
    "for corr_np in (A, B, A_rand, B_rand):\n",
    "    np.fill_diagonal(corr_np, 0)\n",
    "\n",
    "# Sort every row, then average over rows: entry k is the mean k-th smallest correlation.\n",
    "hist_topk = np.mean(np.sort(A, axis=1), axis=0)\n",
    "hist_mul = np.mean(np.sort(B, axis=1), axis=0)\n",
    "hist_topk_rand = np.mean(np.sort(A_rand, axis=1), axis=0)\n",
    "hist_mul_rand = np.mean(np.sort(B_rand, axis=1), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddbda8fc-9182-4ff8-9539-15e66b1a19e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import ks_2samp, kruskal\n",
    "\n",
    "# Two-sample Kolmogorov-Smirnov and Kruskal-Wallis tests: KronSAE vs TopK profiles.\n",
    "ks_result = ks_2samp(hist_mul, hist_topk)\n",
    "kw_result = kruskal(hist_mul, hist_topk)\n",
    "ks_val, ks_pvalue = ks_result.statistic, ks_result.pvalue\n",
    "kw_val, kw_pvalue = kw_result.statistic, kw_result.pvalue\n",
    "\n",
    "(ks_val, ks_pvalue, kw_val, kw_pvalue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43879c95-bccd-40d1-bdb4-279380b1f005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same two tests on the profiles of the randomised SAEs.\n",
    "ks_result = ks_2samp(hist_mul_rand, hist_topk_rand)\n",
    "kw_result = kruskal(hist_mul_rand, hist_topk_rand)\n",
    "ks_val, ks_pvalue = ks_result.statistic, ks_result.pvalue\n",
    "kw_val, kw_pvalue = kw_result.statistic, kw_result.pvalue\n",
    "\n",
    "(ks_val, ks_pvalue, kw_val, kw_pvalue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5797f6-5af8-41f5-93b6-16f2b890f419",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Density histograms of the rank-averaged correlations (hist_mul: KronSAE, hist_topk: TopK).\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 5), dpi=250)\n",
    "palette = sns.color_palette('tab10')\n",
    "\n",
    "# options shared by both unfilled step histograms\n",
    "hist_opts = dict(log_scale=False, stat=\"density\", bins=256 + 128, ax=ax, element=\"step\", fill=False)\n",
    "sns.histplot(hist_mul, label=\"KronSAE\", color=palette[0], **hist_opts)\n",
    "sns.histplot(hist_topk, label=\"TopK\", color=palette[1], **hist_opts)\n",
    "\n",
    "# symlog x-axis with linthresh 2e-3, log y-axis\n",
    "ax.set_xscale('symlog', linthresh=2e-3)\n",
    "ax.set_yscale('log')\n",
    "ax.grid(alpha=0.35)\n",
    "ax.legend(loc='upper right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('sae_acts_corr_distrib_qwen1.5b_final.pdf')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f13b40af-7e1b-49ee-baf4-8f2362e1540b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Figure for the comparison against the randomised SAEs; the dpi override is left disabled.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 6)) #, dpi=250)\n",
    "\n",
    "# tab10 palette; the [2:] slice is left disabled, so indexing starts at the first colour\n",
    "palette = sns.color_palette('tab10') #[2:]\n",
    "\n",
    "def proc_func(x):\n",
    "    \"\"\"\n",
    "    Return |x| sorted ascending and rescaled so that the largest magnitude is 1.\n",
    "\n",
    "    x: array-like of numbers; integer input is accepted.\n",
    "    Returns a new array and leaves the input untouched. Empty input gives an\n",
    "    empty array and all-zero input gives zeros (no division by zero).\n",
    "    \"\"\"\n",
    "    x_abs = np.sort(np.abs(x))\n",
    "    peak = x_abs.max() if x_abs.size else 0\n",
    "    # out-of-place division: an in-place \"/=\" fails on integer arrays\n",
    "    return x_abs / (peak if peak != 0 else 1)\n",
    "\n",
    "# Filled step histograms of the normalised magnitudes; the two randomised SAEs are drawn first.\n",
    "filled_opts = dict(log_scale=False, stat=\"density\", bins=256 + 128, ax=ax, element=\"step\", fill=True)\n",
    "sns.histplot(proc_func(hist_mul_rand), label=\"KronSAE (random)\", color=palette[2], **filled_opts)\n",
    "sns.histplot(proc_func(hist_topk_rand), label=\"TopK (random)\", color=palette[3], **filled_opts)\n",
    "sns.histplot(proc_func(hist_mul), label=\"KronSAE\", color=palette[0], **filled_opts)\n",
    "sns.histplot(proc_func(hist_topk), label=\"TopK\", color=palette[1], **filled_opts)\n",
    "\n",
    "# symlog x-axis with linthresh 0.4, log y-axis\n",
    "ax.set_xscale('symlog', linthresh=4e-1)\n",
    "ax.set_yscale('log')\n",
    "ax.grid(alpha=0.35)\n",
    "ax.legend(loc='upper right')\n",
    "\n",
    "# dashed vertical line at the mean of each normalised series, in the matching colour\n",
    "for series, colour in (\n",
    "    (hist_mul, palette[0]),\n",
    "    (hist_topk, palette[1]),\n",
    "    (hist_mul_rand, palette[2]),\n",
    "    (hist_topk_rand, palette[3]),\n",
    "):\n",
    "    plt.axvline(np.mean(proc_func(series)), linestyle='--', color=colour)\n",
    "\n",
    "plt.tight_layout()\n",
    "# saving is disabled for this figure:\n",
    "#plt.savefig(f'rand_sae_acts_corr_distrib_qwen1.5b.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6da26a9-5830-4fba-86d7-064fef5b2503",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import ttest_ind\n",
    "\n",
    "def get_corr_distrib(corr_marix, group_size: int = 16, max_samples: int = 100000, random_seed: int = 42):\n",
    "    \"\"\"\n",
    "    Collect within-group and between-group correlation samples.\n",
    "\n",
    "    Rows/columns are partitioned into contiguous groups of `group_size`.\n",
    "    \"Within\" samples are the strict upper triangle of each diagonal block;\n",
    "    \"between\" samples are the full blocks against int(sqrt(num_groups)) other\n",
    "    groups drawn at random without replacement.\n",
    "\n",
    "    Parameters:\n",
    "    - corr_marix: (N, N) correlation matrix (numpy array or torch.Tensor).\n",
    "      The misspelled name is kept so existing keyword callers keep working.\n",
    "    - group_size: int, size k of each group (contiguous blocks).\n",
    "    - max_samples: int, maximum number of samples kept per distribution.\n",
    "    - random_seed: int, seeds both the group sampling and the subsampling.\n",
    "\n",
    "    Returns:\n",
    "    - inner_np: 1-D numpy array of within-group correlations.\n",
    "    - between_np: 1-D numpy array of between-group correlations\n",
    "      (empty when there is only one group).\n",
    "    - bins: numpy array of 37 edges (36 equal-width bins) spanning both samples.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if N is not divisible by group_size.\n",
    "    \"\"\"\n",
    "    # as_tensor avoids copying the full matrix; it is only read below\n",
    "    corr_matrix = torch.as_tensor(corr_marix)\n",
    "    # torch drives the subsampling permutation, numpy drives the choice of partner\n",
    "    # groups; both are seeded so repeated calls return the same samples\n",
    "    torch.manual_seed(random_seed)\n",
    "    rng = np.random.default_rng(random_seed)\n",
    "\n",
    "    N = corr_matrix.size(0)\n",
    "    if N % group_size != 0:\n",
    "        raise ValueError(\"Matrix size must be divisible by group_size\")\n",
    "    num_groups = N // group_size\n",
    "    # partner groups per group: int(sqrt(num_groups)), capped by how many others exist\n",
    "    k_sample = min(num_groups - 1, int(np.sqrt(num_groups)))\n",
    "    # strict upper triangle of a group block; identical for every group\n",
    "    triu = torch.triu_indices(group_size, group_size, offset=1)\n",
    "\n",
    "    inner_vals = []\n",
    "    between_vals = []\n",
    "\n",
    "    # Collect samples\n",
    "    for g in range(num_groups):\n",
    "        s1 = g * group_size\n",
    "        e1 = s1 + group_size\n",
    "        block = corr_matrix[s1:e1, s1:e1]\n",
    "        inner_vals.append(block[triu[0], triu[1]])\n",
    "\n",
    "        if k_sample == 0:\n",
    "            continue\n",
    "        others = np.delete(np.arange(num_groups), g)\n",
    "        for h in rng.choice(others, size=k_sample, replace=False):\n",
    "            s2 = int(h) * group_size\n",
    "            e2 = s2 + group_size\n",
    "            between_vals.append(corr_matrix[s1:e1, s2:e2].flatten())\n",
    "\n",
    "    inner = torch.cat(inner_vals)\n",
    "    # with a single group there is nothing to compare against\n",
    "    between = torch.cat(between_vals) if between_vals else corr_matrix.new_empty(0)\n",
    "\n",
    "    # subsample if needed\n",
    "    def subsample(tensor):\n",
    "        total = tensor.numel()\n",
    "        if total > max_samples:\n",
    "            idx = torch.randperm(total)[:max_samples]\n",
    "            return tensor[idx]\n",
    "        return tensor\n",
    "\n",
    "    inner_np = subsample(inner).cpu().numpy()\n",
    "    between_np = subsample(between).cpu().numpy()\n",
    "\n",
    "    # shared bin edges so both histograms can be drawn on the same grid\n",
    "    combined = np.concatenate((inner_np, between_np))\n",
    "    num_bins = 36\n",
    "    bins = np.linspace(combined.min(), combined.max(), num_bins + 1)\n",
    "    return inner_np, between_np, bins\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3f71b8-1fea-4c06-b62c-8a3dca348dcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Within-/between-group correlation samples for B, the matrix built from the KronSAE run.\n",
    "inner_np, between_np, bins = get_corr_distrib(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9c9fa1-c87c-4a5a-945b-81acbfc35bf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "inner_np_rand, between_np_rand, bins_rand = get_corr_distrib(B_rand)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ecadde3-b6fc-4c08-a435-ff94690a603c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot both histograms on one axes\n",
    "plt.figure(figsize=(8, 6), dpi=250)\n",
    "\n",
    "palette = sns.color_palette('tab10')[2:]\n",
    "\n",
    "#Plot both positive and negative values together using symlog scale\n",
    "\n",
    "sns.histplot(inner_np, bins=bins,  kde=False, stat=\"density\", label='Within-Group', alpha=0.6, color=palette[0])\n",
    "sns.histplot(between_np, bins=bins,  kde=False, stat=\"density\", label='Between-Group', alpha=0.6, color=palette[1])\n",
    "\n",
    "# Avoid zero-density dominating by setting bins away from zero and using logarithmic scales\n",
    "#plt.xscale('symlog')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Correlation Value')\n",
    "plt.ylabel('Density (log scale)')\n",
    "plt.title('Within vs. Between Group Correlation Distributions')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'sae_acts_head_corr_distrib_qwen1.5b_final.pdf')\n",
    "plt.show()\n",
    "\n",
    "# t-test\n",
    "t_stat, p_value = ttest_ind(inner_np, between_np, equal_var=False)\n",
    "print(f\"T-statistic: {t_stat:.4f}, p-value: {p_value:.4e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ebde408-1499-4554-a6c5-0544cc4ff188",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot both histograms on one axes\n",
    "plt.figure(figsize=(8, 5), dpi=250)\n",
    "\n",
    "palette = sns.color_palette('tab10')\n",
    "\n",
    "#Plot both positive and negative values together using symlog scale\n",
    "\n",
    "sns.histplot(inner_np, bins=bins,  kde=False, stat=\"density\", label='Within-Group', alpha=0.3, color=palette[0], element=\"step\", linewidth=1.5  )\n",
    "sns.histplot(between_np, bins=bins,  kde=False, stat=\"density\", label='Between-Group', alpha=0.3, color=palette[1], element=\"step\", linewidth=1.5 )\n",
    "sns.histplot(inner_np_rand, bins=bins,  kde=False, stat=\"density\", label='Within-Group (rand)', alpha=0.99, color=palette[0], element=\"step\", fill=False, linestyle='--', linewidth=1.5 )\n",
    "sns.histplot(between_np_rand, bins=bins,  kde=False, stat=\"density\", label='Between-Group (rand)', alpha=0.99, color=palette[1], element=\"step\", fill=False, linestyle='--', linewidth=1.5 )\n",
    "m1 = np.mean(inner_np)\n",
    "m2 = np.mean(between_np)\n",
    "m3 = np.mean(inner_np_rand)\n",
    "m4 = np.mean(between_np_rand)\n",
    "\n",
    "# plt.axvline(m1, linestyle='-',  alpha=0.6,color=palette[0])\n",
    "# plt.axvline(m2, linestyle='-',  alpha=0.6,color=palette[1])\n",
    "# plt.axvline(m3, linestyle='--',  alpha=0.99, color=palette[0])\n",
    "# plt.axvline(m4, linestyle='--',   alpha=0.99, color=palette[1])\n",
    "\n",
    "# Avoid zero-density dominating by setting bins away from zero and using logarithmic scales\n",
    "#plt.xscale('symlog')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Pearson correlation', fontsize=14)\n",
    "plt.ylabel('Density (log scale)', fontsize=14)\n",
    "#plt.title('Within vs. Between Group Correlation Distributions')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'rand_sae_acts_head_corr_distrib_qwen1.5b.pdf')\n",
    "plt.show()\n",
    "print(m1, \n",
    "m2 ,\n",
    "m3 ,\n",
    "m4 ,)\n",
    "# t-test\n",
    "t_stat, p_value = ttest_ind(inner_np, between_np, equal_var=False)\n",
    "print(f\"T-statistic: {t_stat:.4f}, p-value: {p_value:.4e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d76aef7-5103-4aec-8740-dc6870f9b7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "m1 = np.median(inner_np)\n",
    "m2 = np.median(between_np)\n",
    "m3 = np.median(inner_np_rand)\n",
    "m4 = np.median(between_np_rand)\n",
    "m1,m2,m3,m4"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
