{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import math\n",
    "import torch\n",
    "import einops\n",
    "\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "from typing import Optional\n",
    "import pyrallis\n",
    "from pprint import pprint\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils import *\n",
    "from config import *\n",
    "from toy_models import *\n",
    "from data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImporanceWeightedMSE(nn.Module):\n",
    "    \"\"\"Importance-weighted squared-error loss.\n",
    "\n",
    "    The squared error is summed over the leading (batch) dimension, flattened,\n",
    "    multiplied elementwise by ``importance`` (a scalar or one weight per\n",
    "    flattened position) and summed to a scalar. Note: a batch *sum*, not a mean.\n",
    "    \"\"\"\n",
    "    # NOTE: class name is misspelled (\"Imporance\"); kept for compatibility.\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, predictions, targets, importance):\n",
    "        sub_total = ((predictions - targets) ** 2).sum(0).flatten()\n",
    "        # Reduce with torch rather than Python's builtin ``sum``, which iterates\n",
    "        # over the tensor element by element in the interpreter.\n",
    "        return (sub_total * importance).sum()\n",
    "\n",
    "def explained_variance_score(y_target, y_pred):\n",
    "    \"\"\"Mean explained variance of ``y_pred`` relative to ``y_target``.\n",
    "\n",
    "    Per position i: 1 - ||pred_i - target_i||^2 / ||target_i - mean(target)||^2,\n",
    "    with squared norms taken over the last dimension and the mean over dim 0;\n",
    "    the per-position scores are then averaged. Returns a 0-d tensor.\n",
    "    \"\"\"\n",
    "    target = y_target.float()\n",
    "    residual_sq = (y_pred.float() - target).pow(2).sum(-1).squeeze()\n",
    "    centered_sq = (target - target.mean(0)).pow(2).sum(-1).squeeze()\n",
    "    ratio = residual_sq / centered_sq\n",
    "    return (1 - ratio).mean()\n",
    "\n",
    "\n",
    "def train_autoencoder(\n",
    "    ae: SymAE, cfg: SynthConfig, device: str = \"cuda\", importance_score: Optional[float] = None, dataset = None\n",
    ") -> list:\n",
    "    \"\"\"\n",
    "    Trains the autoencoder for cfg.ae_train_steps steps on batches drawn from ``dataset``.\n",
    "    Shows training metrics in the progress bar; no validation.\n",
    "\n",
    "    Args:\n",
    "        ae: autoencoder to train; moved to ``device`` and switched to train mode.\n",
    "        cfg: config; reads ae_lr, ae_train_steps, save_checkpoint and ckpt_dir.\n",
    "        device: torch device string.\n",
    "        importance_score: if given, the loss is ImporanceWeightedMSE with weight\n",
    "            importance_score ** j for feature index j along dim 1 of the batch;\n",
    "            otherwise nn.MSELoss.\n",
    "        dataset: source of ready-made batches, wrapped in DataLoader(batch_size=None).\n",
    "\n",
    "    Returns:\n",
    "        List of per-step dicts with keys \"loss\", \"ev\" and \"step\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataset`` is None.\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        raise ValueError(\"train_autoencoder requires a dataset\")\n",
    "\n",
    "    ae.to(device)\n",
    "    ae.train()  # eval_ae() puts the model in eval mode\n",
    "    optimizer = torch.optim.Adam(ae.parameters(), lr=cfg.ae_lr)\n",
    "    if importance_score is not None:\n",
    "        mse_loss = ImporanceWeightedMSE()\n",
    "    else:\n",
    "        mse_loss = nn.MSELoss()\n",
    "\n",
    "    metrics_list = []\n",
    "    importance = None  # per-feature weights, built once the feature dim is known\n",
    "\n",
    "    dataloader = DataLoader(dataset, batch_size=None)\n",
    "    dataloader_iter = iter(dataloader)\n",
    "\n",
    "    pbar = tqdm(range(cfg.ae_train_steps))\n",
    "\n",
    "    for step in pbar:\n",
    "        x = next(dataloader_iter).to(device)\n",
    "        optimizer.zero_grad()\n",
    "        x_recon = ae(x)\n",
    "        if importance_score is not None:\n",
    "            if importance is None or importance.numel() != x.shape[1]:\n",
    "                importance = (importance_score ** torch.arange(0, x.shape[1])).to(device)\n",
    "            # Bug fix: the weights used to be computed but a constant 1.0 was\n",
    "            # passed, so the importance weighting never took effect.\n",
    "            loss = mse_loss(x_recon, x, importance)\n",
    "        else:\n",
    "            loss = mse_loss(x_recon, x)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            metrics = {\n",
    "                \"loss\": loss.item(),\n",
    "                \"ev\": explained_variance_score(x, x_recon).item(),\n",
    "                \"step\": step\n",
    "            }\n",
    "            metrics_list.append(metrics)\n",
    "\n",
    "        pbar.set_description(f\"[AE] metrics: {metrics}\")\n",
    "\n",
    "    if cfg.save_checkpoint:\n",
    "        os.makedirs(cfg.ckpt_dir, exist_ok=True)\n",
    "        torch.save(ae.state_dict(), os.path.join(cfg.ckpt_dir, \"sym_ae.pt\"))\n",
    "\n",
    "    return metrics_list\n",
    "\n",
    "@torch.no_grad()\n",
    "def make_decoder_weights_and_grad_unit_norm_(W_dec, dim=-1):\n",
    "    \"\"\"Renormalise ``W_dec`` to unit norm along ``dim`` and drop the radial gradient, in place.\n",
    "\n",
    "    The component of ``W_dec.grad`` parallel to the normalised weights is\n",
    "    subtracted, then the weights are overwritten with their normalised values.\n",
    "    Meant to run between ``backward()`` and ``optimizer.step()``. If\n",
    "    ``W_dec.grad`` is None only the weights are normalised (previously this\n",
    "    raised a TypeError).\n",
    "    \"\"\"\n",
    "    W_dec_normed = W_dec / W_dec.norm(dim=dim, keepdim=True)\n",
    "    if W_dec.grad is not None:\n",
    "        W_dec_grad_proj = (W_dec.grad * W_dec_normed).sum(dim, keepdim=True) * W_dec_normed\n",
    "        W_dec.grad -= W_dec_grad_proj\n",
    "    # In-place copy (allowed under no_grad) instead of rebinding ``.data``.\n",
    "    W_dec.copy_(W_dec_normed)\n",
    "\n",
    "\n",
    "def train_topk_sae(\n",
    "    sae: nn.Module,\n",
    "    ae: SymAE,\n",
    "    cfg: SynthConfig,\n",
    "    dataloader: DataLoader,\n",
    "    device: str = \"cuda\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Trains the top-k SAE on the latent codes from the SymAE,\n",
    "    including an auxiliary penalty to revive dead latents.\n",
    "\n",
    "    Args:\n",
    "        sae: model with ``encode``/``decode``, a ``dict_size`` attribute and a\n",
    "            ``W_dec`` parameter whose rows are decoder directions.\n",
    "        ae: trained SymAE; only ``ae.encode`` is used, under ``no_grad``.\n",
    "        cfg: reads fe_lr, fe_train_steps, n_batches_to_dead, topk2_aux,\n",
    "            aux_penalty, save_checkpoint and ckpt_dir.\n",
    "        dataloader: must yield at least cfg.fe_train_steps input batches.\n",
    "        device: torch device string.\n",
    "\n",
    "    Returns:\n",
    "        List of per-step dicts with keys \"loss\", \"recon_loss\", \"aux_loss\",\n",
    "        \"ev\", \"dead_features\" and \"step\".\n",
    "    \"\"\"\n",
    "    sae.to(device)\n",
    "    ae.to(device)\n",
    "    # Freeze the symmetrical AE and put the SAE in training mode. (This used to\n",
    "    # call sae.eval(), which trained the SAE in eval mode and left the AE as-is.)\n",
    "    ae.eval()\n",
    "    sae.train()\n",
    "\n",
    "    # Consecutive batches in which each latent never fired (dead-latent tracking).\n",
    "    num_batches_not_active = torch.zeros(sae.dict_size, device=device)\n",
    "    optimizer = torch.optim.Adam(sae.parameters(), lr=cfg.fe_lr)\n",
    "    mse_loss = nn.MSELoss()\n",
    "\n",
    "    dataloader_iter = iter(dataloader)\n",
    "    pbar = tqdm(range(cfg.fe_train_steps))\n",
    "    metrics_list = []\n",
    "\n",
    "    for step in pbar:\n",
    "        x = next(dataloader_iter).to(device)\n",
    "        with torch.no_grad():\n",
    "            h = ae.encode(x)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        acts = sae.encode(h)\n",
    "        h_pred = sae.decode(acts)\n",
    "\n",
    "        # primary reconstruction loss\n",
    "        recon_loss = mse_loss(h_pred, h)\n",
    "\n",
    "        # auxiliary loss: dead latents try to reconstruct the residual\n",
    "        dead = num_batches_not_active >= cfg.n_batches_to_dead\n",
    "        aux_loss = torch.tensor(0.0, device=device)\n",
    "        if dead.any():\n",
    "            residual = (h - h_pred).detach()\n",
    "            # NOTE(review): if sae.encode already applies its top-k mask, dead\n",
    "            # latents are mostly zero here and this term gives little gradient;\n",
    "            # the usual recipe uses pre-activations -- confirm against the SAE.\n",
    "            dead_acts = acts[:, dead]\n",
    "            k_aux = min(cfg.topk2_aux, dead.sum().item())\n",
    "            top_aux = torch.topk(dead_acts, k_aux, dim=-1)\n",
    "            aux_acts = torch.zeros_like(dead_acts).scatter(-1, top_aux.indices, top_aux.values)\n",
    "            # reconstruct only from the selected dead latents\n",
    "            x_aux_pred = aux_acts @ sae.W_dec[dead]\n",
    "            aux_loss = cfg.aux_penalty * F.mse_loss(x_aux_pred, residual)\n",
    "\n",
    "        loss = recon_loss + aux_loss\n",
    "        loss.backward()\n",
    "\n",
    "        # clip grads, then keep decoder rows unit-norm (drops the radial gradient)\n",
    "        torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1.)\n",
    "        make_decoder_weights_and_grad_unit_norm_(sae.W_dec)\n",
    "        optimizer.step()\n",
    "\n",
    "        # reset counters of latents that fired in this batch, increment the rest\n",
    "        active_now = acts.sum(0) > 0\n",
    "        num_batches_not_active = torch.where(active_now,\n",
    "                                             torch.zeros_like(num_batches_not_active),\n",
    "                                             num_batches_not_active + 1)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            ev = explained_variance_score(h, h_pred).item()\n",
    "            dead_count = int((num_batches_not_active >= cfg.n_batches_to_dead).sum())\n",
    "            metrics = {\n",
    "                \"loss\": loss.item(),\n",
    "                \"recon_loss\": recon_loss.item(),\n",
    "                \"aux_loss\": aux_loss.item(),\n",
    "                \"ev\": ev,\n",
    "                \"dead_features\": dead_count,\n",
    "                \"step\": step,\n",
    "            }\n",
    "            metrics_list.append(metrics)\n",
    "        pbar.set_description(f\"[SAE] metrics: {metrics}\")\n",
    "\n",
    "    if cfg.save_checkpoint:\n",
    "        os.makedirs(cfg.ckpt_dir, exist_ok=True)\n",
    "        torch.save(sae.state_dict(), os.path.join(cfg.ckpt_dir, \"topk_sae.pt\"))\n",
    "\n",
    "    return metrics_list\n",
    "\n",
    "\n",
    "def eval_sae(sae: nn.Module, ae: SymAE, eval_dataloader: DataLoader, eval_steps: int, device: str = \"cuda\"):\n",
    "    \"\"\"Evaluate an SAE on the AE's latent codes over ``eval_steps`` batches.\n",
    "\n",
    "    Returns:\n",
    "        total_loss: mean MSE between reconstructed and true latents (tensor).\n",
    "        total_ev: mean explained variance (float).\n",
    "        statistics: dict with\n",
    "            \"acts_std\" / \"non_zero_acts_mean\": per-batch std / mean of the\n",
    "                activations above 1e-8, averaged over batches;\n",
    "            \"acts_distribution\": counts of activations above 1e-8 summed over\n",
    "                dim 0, on CPU;\n",
    "            \"acts_value_distribution\": placeholder, never updated (0.).\n",
    "    \"\"\"\n",
    "    # Both models are frozen during evaluation.\n",
    "    ae.to(device)\n",
    "    ae.eval()\n",
    "    sae.to(device)\n",
    "    sae.eval()\n",
    "    criterion = nn.MSELoss()\n",
    "    active_threshold = 1e-8\n",
    "\n",
    "    batches = iter(eval_dataloader)\n",
    "    progress = tqdm(range(eval_steps))\n",
    "\n",
    "    total_loss = 0\n",
    "    total_ev = 0\n",
    "    statistics = {\n",
    "        \"acts_std\": 0.,\n",
    "        \"non_zero_acts_mean\": 0.,\n",
    "        \"acts_distribution\": torch.zeros(sae.dict_size),\n",
    "        \"acts_value_distribution\": 0.\n",
    "    }\n",
    "\n",
    "    for _ in progress:\n",
    "        batch = next(batches).to(device)\n",
    "        with torch.no_grad():\n",
    "            latent = ae.encode(batch)\n",
    "            codes = sae.encode(latent)\n",
    "            latent_hat = sae.decode(codes)\n",
    "            loss = criterion(latent_hat, latent)\n",
    "            total_ev += explained_variance_score(latent, latent_hat).item()\n",
    "\n",
    "            active_values = codes[codes > active_threshold]\n",
    "            statistics[\"acts_std\"] += active_values.std(-1).mean()\n",
    "            statistics[\"non_zero_acts_mean\"] += active_values.mean()\n",
    "            statistics[\"acts_distribution\"] += (codes.cpu() > active_threshold).float().sum(0)\n",
    "\n",
    "        total_loss += loss\n",
    "        progress.set_description(f\"[SAE] loss={loss.item():.6f}\")\n",
    "\n",
    "    total_loss /= eval_steps\n",
    "    total_ev /= eval_steps\n",
    "    for key in (\"acts_std\", \"non_zero_acts_mean\"):\n",
    "        statistics[key] = statistics[key] / eval_steps\n",
    "\n",
    "    return total_loss, total_ev, statistics\n",
    "\n",
    "\n",
    "def eval_ae(ae: SymAE, eval_dataloader: DataLoader, eval_steps: int, device: str = \"cuda\"):\n",
    "    \"\"\"Evaluate the autoencoder's reconstruction over ``eval_steps`` batches.\n",
    "\n",
    "    Returns:\n",
    "        (mean MSE loss as a tensor, mean explained variance as a float)\n",
    "    \"\"\"\n",
    "    ae.to(device)\n",
    "    ae.eval()\n",
    "    criterion = nn.MSELoss()\n",
    "\n",
    "    batches = iter(eval_dataloader)\n",
    "    progress = tqdm(range(eval_steps))\n",
    "    loss_sum, ev_sum = 0, 0\n",
    "\n",
    "    for _ in progress:\n",
    "        batch = next(batches).to(device)\n",
    "        with torch.no_grad():\n",
    "            recon = ae(batch)\n",
    "            loss = criterion(recon, batch)\n",
    "            ev_sum += explained_variance_score(batch, recon).item()\n",
    "        loss_sum += loss\n",
    "        progress.set_description(f\"[AE] loss={loss.item():.6f}\")\n",
    "\n",
    "    return loss_sum / eval_steps, ev_sum / eval_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single run: train a SymAE on synthetic data, then a TopK SAE and a KronSAE on its latents.\n",
    "cfg = SynthConfig()\n",
    "cfg.ae_train_steps = 20000\n",
    "cfg.fe_train_steps = 1000\n",
    "cfg.S = 8\n",
    "cfg.corr_inter = 0.0\n",
    "cfg.corr_intra = (0.85, 0.95)  # an earlier variant used 0.8\n",
    "cfg.topk2_aux = 32\n",
    "cfg.aux_penalty = 1/16\n",
    "cfg.sae_train_batch_size = 256  # or 32\n",
    "cfg.heavy_tail_scale = 1.\n",
    "set_all_seeds(cfg.seed)  # reproducibility\n",
    "\n",
    "# Use the best available accelerator. The previous hard-coded \"mps\" (commented\n",
    "# as \"forced to cuda\") failed on machines without Apple-silicon GPU support.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(\"==== Starting Training ====\")\n",
    "print(\"** Full Config **\")\n",
    "pprint(asdict(cfg))\n",
    "\n",
    "# Dataset structure passed to OnTheFlySynthDataset (an earlier variant used sizes = [32] * 8).\n",
    "sizes = [32, 16, 48, 32, 32, 32, 48, 16]\n",
    "\n",
    "mask = np.array([\n",
    "    [True,  False,  False, False, False, False, True,  True],\n",
    "    [False, True,   False, False, False, False, True,  True],\n",
    "    [False, False,  True,  False, False, False, False, False],\n",
    "    [False, False,  False, True , False, False, False, False],\n",
    "    [False, False,  False, False, True,  False, False, False],\n",
    "    [False, False,  False, False, False, True,  False, False],\n",
    "    [True,  True,   False, False, False, False, True,  False],\n",
    "    [True,  True,   False, False, False, False, False, True],\n",
    "], dtype=bool)\n",
    "\n",
    "# 1) Create SymAE\n",
    "ae = SymAE(input_dim=cfg.N, latent_dim=cfg.latent_dim, is_relu=True)\n",
    "print(\"\\n** Symmetrical Autoencoder (Phase 1) **\")\n",
    "print(ae)\n",
    "total_params_ae = sum(p.numel() for p in ae.parameters())\n",
    "print(f\"SymAE has {total_params_ae:,} parameters.\")\n",
    "\n",
    "# 2) Train SymAE\n",
    "dataset_ae = OnTheFlySynthDataset(\n",
    "    cfg,\n",
    "    batch_size=cfg.batch_size,\n",
    "    mask=mask,\n",
    "    sizes=sizes,\n",
    "    eps=1e-1,\n",
    ")\n",
    "ae_metrics = train_autoencoder(ae, cfg, device=device, dataset=dataset_ae)\n",
    "\n",
    "# 3) Create SAEs\n",
    "topk_sae = TopKSAE(input_dim=cfg.latent_dim, dict_size=cfg.hidden_dim, topk=cfg.topk)\n",
    "print(\"\\n** Top-K SAE (Phase 2) **\")\n",
    "print(topk_sae)\n",
    "total_params_topk_sae = sum(p.numel() for p in topk_sae.parameters())\n",
    "print(f\"TopKSAE has {total_params_topk_sae:,} parameters.\")\n",
    "\n",
    "# NOTE: this rebinds the global name `kronsae` from the model class to an\n",
    "# instance. Later cells read `kronsae.W_dec`, so the name is kept; code that\n",
    "# needs the class after this cell has run must not rely on the global name.\n",
    "kronsae = kronsae(input_dim=cfg.latent_dim, dict_size=cfg.hidden_dim, topk2=cfg.topk, num_heads=4, m_keys=2, n_keys=32)\n",
    "print(\"\\n** kronsae (Phase 2) **\")\n",
    "print(kronsae)\n",
    "total_params_kronsae = sum(p.numel() for p in kronsae.parameters())\n",
    "print(f\"kronsae has {total_params_kronsae:,} parameters.\")\n",
    "\n",
    "# 4) Train both SAEs on the same dataset (on_the_fly=False, size=cfg.fe_train_steps)\n",
    "dataset = OnTheFlySynthDataset(\n",
    "    cfg,\n",
    "    on_the_fly=False,\n",
    "    size=cfg.fe_train_steps,\n",
    "    batch_size=cfg.sae_train_batch_size,\n",
    "    mask=mask,\n",
    "    sizes=sizes,\n",
    "    eps=1e-1,\n",
    ")\n",
    "dataloader = DataLoader(dataset, batch_size=None)\n",
    "\n",
    "topk_sae_metrics = train_topk_sae(topk_sae, ae, cfg, dataloader, device=device)\n",
    "kronsae_metrics = train_topk_sae(kronsae, ae, cfg, dataloader, device=device)\n",
    "\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# KronSAE variants trained by get_trained_saes, keyed by name:\n",
    "# h -> num_heads, m -> m_keys, n -> n_keys of the kronsae constructor.\n",
    "setups = {\n",
    "    \"kronsae_h2_m4_n32\": {\"h\":2, \"m\":4, \"n\":32},\n",
    "    \"kronsae_h4_m2_n32\": {\"h\":4, \"m\":2, \"n\":32},\n",
    "    # \"kronsae_h4_m4_n16\": {\"h\":4, \"m\":4, \"n\":16},\n",
    "    # \"kronsae_h4_m8_n8\": {\"h\":4, \"m\":8, \"n\":8},\n",
    "    # \"kronsae_h8_m2_n16\": {\"h\":8, \"m\":2, \"n\":16},\n",
    "}\n",
    "\n",
    "\n",
    "# Synthetic-data configurations swept by get_trained_saes. Each entry gives the\n",
    "# 8 group `sizes` (each list sums to 256) and an 8x8 symmetric boolean `mask`,\n",
    "# both passed to OnTheFlySynthDataset. NOTE(review): presumably True marks which\n",
    "# pairs of groups are correlated -- confirm in data.py.\n",
    "mask_list = [\n",
    "    # (1) groups 0 and 1 each paired with groups 6 and 7; diagonal set only for 0, 1, 6, 7\n",
    "    {\n",
    "        \"sizes\": [32, 48, 16, 32, 32, 48, 16, 32 ],\n",
    "        \"mask\": np.array([\n",
    "            [True,  False,  False, False, False, False, True,  True],\n",
    "            [False, True,   False, False, False, False, True,  True],\n",
    "            [False, False,  False,  False, False, False, False, False],\n",
    "            [False, False,  False, False , False, False, False, False],\n",
    "            [False, False,  False, False, False,  False, False, False],\n",
    "            [False, False,  False, False, False, False,  False, False],\n",
    "            [True,  True,   False, False, False, False, True,  False],\n",
    "            [True,  True,   False, False, False, False, False, True],\n",
    "        ], dtype=bool)\n",
    "    },\n",
    "    # (2) banded: every group paired with itself and its immediate neighbours\n",
    "    {\n",
    "        \"sizes\": [32, 16, 48, 32, 32, 32, 48, 16],\n",
    "        \"mask\": np.array([\n",
    "            [True,  True,  False, False, False, False, False,  False],\n",
    "            [True, True,   True, False, False, False, False,  False],\n",
    "            [False, True,  True,  True, False, False, False, False],\n",
    "            [False, False,  True, True , True, False, False, False],\n",
    "            [False, False,  False, True, True,  True, False, False],\n",
    "            [False, False,  False, False, True, True,  True, False],\n",
    "            [False,  False,   False, False, False, True, True,  True],\n",
    "            [False,  False,   False, False, False, False, True, True],\n",
    "        ], dtype=bool)\n",
    "    },\n",
    "    # (3) a single dense block over groups 2-5, nothing else\n",
    "    {\n",
    "        \"sizes\": [32, 32, 32, 32, 32, 32, 32, 32],\n",
    "        \"mask\": np.array([\n",
    "            [False,  False, False, False, False,False, False, False],\n",
    "            [False, False,  False, False, False,False, False, False],\n",
    "            [False, False,  True,  True,  True, True,  False, False],\n",
    "            [False, False,  True,  True , True, True,  False, False],\n",
    "            [False, False,  True,  True,  True, True,  False, False],\n",
    "            [False, False,  True,  True,  True, True,  False, False],\n",
    "            [False,  False, False, False, False,False, False, False],\n",
    "            [False,  False, False, False, False,False, False, False],\n",
    "        ], dtype=bool)\n",
    "    },\n",
    "    # (4) main diagonal plus anti-diagonal (group i paired with group 7 - i)\n",
    "    {\n",
    "        \"sizes\": [32, 32, 32, 32, 32, 32, 32, 32],\n",
    "        \"mask\": np.array([\n",
    "            [True,  False,  False, False, False, False, False, True],\n",
    "            [False, True,   False, False, False, False, True, False],\n",
    "            [False, False,  True,  False, False, True, False, False],\n",
    "            [False, False,  False, True , True, False, False, False],\n",
    "            [False, False,  False, True, True,  False, False, False],\n",
    "            [False, False,  True, False, False, True,  False, False],\n",
    "            [False, True,  False, False, False, False, True,  False],\n",
    "            [True, False,  False, False, False, False, False, True],\n",
    "        ], dtype=bool)\n",
    "    },\n",
    "    # (5) anti-diagonal only (group i paired with group 7 - i; diagonal unset)\n",
    "    {\n",
    "        \"sizes\": [32, 32, 32, 32, 32, 32, 32, 32],\n",
    "        \"mask\": np.array([\n",
    "            [False,  False,  False, False, False, False, False, True],\n",
    "            [False, False,   False, False, False, False, True, False],\n",
    "            [False, False,  False,  False, False, True, False, False],\n",
    "            [False, False,  False, False , True, False, False, False],\n",
    "            [False, False,  False, True, False,  False, False, False],\n",
    "            [False, False,  True, False, False, False,  False, False],\n",
    "            [False, True,  False, False, False, False, False,  False],\n",
    "            [True, False,  False, False, False, False, False, False],\n",
    "        ], dtype=bool)\n",
    "    }\n",
    "]\n",
    "\n",
    "\n",
    "def _make_sweep_config():\n",
    "    \"\"\"Fresh SynthConfig shared by every run of the mask sweep.\n",
    "\n",
    "    Note corr_intra=(0.9, 0.9) here, unlike the (0.85, 0.95) of the single-run cell.\n",
    "    \"\"\"\n",
    "    cfg = SynthConfig()\n",
    "    cfg.ae_train_steps = 20000\n",
    "    cfg.fe_train_steps = 1000\n",
    "    cfg.S = 8\n",
    "    cfg.corr_inter = 0.0\n",
    "    cfg.corr_intra = (0.9, 0.9)\n",
    "    cfg.topk2_aux = 32\n",
    "    cfg.aux_penalty = 1/16\n",
    "    cfg.sae_train_batch_size = 256  # or 32\n",
    "    cfg.heavy_tail_scale = 1.\n",
    "    return cfg\n",
    "\n",
    "\n",
    "def get_trained_saes(setup, mask_setups, cfg, dataloader, ae, device):\n",
    "    \"\"\"Train a SymAE, a TopK SAE and one KronSAE per ``setup`` entry for every mask setup.\n",
    "\n",
    "    Args:\n",
    "        setup: mapping name -> {\"h\": num_heads, \"m\": m_keys, \"n\": n_keys}; one\n",
    "            KronSAE is built and trained per entry (this argument used to be\n",
    "            ignored in favour of the global ``setups`` and two hard-coded names).\n",
    "        mask_setups: list of {\"sizes\": ..., \"mask\": ...} dataset configurations.\n",
    "        cfg, dataloader, ae: unused -- every run builds its own config, AE and\n",
    "            dataloader. Kept so existing calls keep working.\n",
    "        device: torch device to train on (used to be overridden with \"mps\").\n",
    "\n",
    "    Returns:\n",
    "        One dict per mask setup with keys \"dataset\", \"ae\", \"topk\" and every\n",
    "        name in ``setup``.\n",
    "    \"\"\"\n",
    "    # The single-run cell rebinds the global ``kronsae`` from the model class to a\n",
    "    # trained instance; recover the class in that case so this works regardless\n",
    "    # of which cells ran before.\n",
    "    kronsae_cls = type(kronsae) if isinstance(kronsae, nn.Module) else kronsae\n",
    "\n",
    "    trained_saes = []\n",
    "    for mask_setup in mask_setups:\n",
    "        sizes = mask_setup[\"sizes\"]\n",
    "        mask = mask_setup[\"mask\"]\n",
    "        run_cfg = _make_sweep_config()\n",
    "        set_all_seeds(run_cfg.seed)  # reproducibility\n",
    "\n",
    "        # 1) Create and train a SymAE on this mask setup\n",
    "        run_ae = SymAE(input_dim=run_cfg.N, latent_dim=run_cfg.latent_dim, is_relu=True)\n",
    "        dataset_ae = OnTheFlySynthDataset(\n",
    "            run_cfg,\n",
    "            batch_size=run_cfg.batch_size,\n",
    "            mask=mask,\n",
    "            sizes=sizes,\n",
    "            eps=1e-1,\n",
    "        )\n",
    "        train_autoencoder(run_ae, run_cfg, device=device, dataset=dataset_ae)\n",
    "\n",
    "        # 2) Create the SAEs: TopK first, then the KronSAEs in ``setup`` order\n",
    "        topk_sae = TopKSAE(input_dim=run_cfg.latent_dim, dict_size=run_cfg.hidden_dim, topk=run_cfg.topk)\n",
    "        kron_saes = {\n",
    "            name: kronsae_cls(\n",
    "                input_dim=run_cfg.latent_dim,\n",
    "                dict_size=run_cfg.hidden_dim,\n",
    "                topk2=run_cfg.topk,\n",
    "                num_heads=spec[\"h\"],\n",
    "                m_keys=spec[\"m\"],\n",
    "                n_keys=spec[\"n\"],\n",
    "            )\n",
    "            for name, spec in setup.items()\n",
    "        }\n",
    "\n",
    "        # 3) Train every SAE on the same dataset\n",
    "        dataset = OnTheFlySynthDataset(\n",
    "            run_cfg,\n",
    "            on_the_fly=False,\n",
    "            size=run_cfg.fe_train_steps,\n",
    "            batch_size=run_cfg.sae_train_batch_size,\n",
    "            mask=mask,\n",
    "            sizes=sizes,\n",
    "            eps=1e-1,\n",
    "        )\n",
    "        sae_loader = DataLoader(dataset, batch_size=None)\n",
    "        train_topk_sae(topk_sae, run_ae, run_cfg, sae_loader, device=device)\n",
    "        for kron_sae in kron_saes.values():\n",
    "            train_topk_sae(kron_sae, run_ae, run_cfg, sae_loader, device=device)\n",
    "\n",
    "        trained_saes.append({\"dataset\": dataset, \"ae\": run_ae, \"topk\": topk_sae, **kron_saes})\n",
    "    return trained_saes\n",
    "\n",
    "\n",
    "trained_saes = get_trained_saes(setups, mask_list, cfg, dataloader, ae, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def direction_similarity(W_dec: torch.Tensor):\n",
    "    \"\"\"Cosine similarity between every pair of rows of ``W_dec`` (torch tensor result).\"\"\"\n",
    "    W_norm = F.normalize(W_dec, p=2, dim=1)\n",
    "    return W_norm @ W_norm.T\n",
    "\n",
    "def pearson_correlation(W_dec: torch.Tensor):\n",
    "    \"\"\"Pearson correlation between every pair of rows of ``W_dec`` (torch tensor result).\n",
    "\n",
    "    Pearson r is the cosine similarity of mean-centred rows. The previous version\n",
    "    passed ``elementwise_affine`` to ``F.layer_norm``, which that function does not\n",
    "    accept (TypeError), and would otherwise have returned about d * r, because\n",
    "    layer-normalised rows have squared norm about d rather than 1.\n",
    "    Rows with zero variance get correlation 0 (``F.normalize`` clamps the norm).\n",
    "    \"\"\"\n",
    "    centered = W_dec - W_dec.mean(dim=1, keepdim=True)\n",
    "    W_norm = F.normalize(centered, p=2, dim=1)\n",
    "    return W_norm @ W_norm.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "def find_one_to_one_mapping(set_A: torch.Tensor, set_B: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Match rows of ``set_A`` to rows of ``set_B`` minimising total Euclidean distance.\n",
    "\n",
    "    Solves the assignment problem with ``linear_sum_assignment`` on the pairwise\n",
    "    distance matrix. Returns a LongTensor on ``set_A.device`` of length\n",
    "    ``set_A.shape[0]``: entry i is the index of the row of ``set_B`` matched to row i,\n",
    "    or -1 if row i is unmatched (only when set_A has more rows than set_B).\n",
    "    NOTE: indexing with a result that contains -1 silently selects the last row.\n",
    "    \"\"\"\n",
    "    cost = torch.cdist(set_A, set_B).detach().cpu().numpy()  # (n, m)\n",
    "    matched_rows, matched_cols = linear_sum_assignment(cost)\n",
    "\n",
    "    dev = set_A.device\n",
    "    assignment = torch.full((set_A.shape[0],), -1, dtype=torch.long, device=dev)\n",
    "    assignment[matched_rows] = torch.as_tensor(matched_cols, dtype=torch.long, device=dev)\n",
    "    return assignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rv_coefficient(C1, C2):\n",
    "    \"\"\"RV coefficient tr(C1 C2) / sqrt(tr(C1 C1) tr(C2 C2)) of two square matrices.\n",
    "\n",
    "    Intended for symmetric matrices such as covariance or similarity matrices.\n",
    "    Traces are computed as elementwise sums, tr(A B) = sum(A.T * B), which is\n",
    "    O(d^2) instead of the O(d^3) full matrix products.\n",
    "    \"\"\"\n",
    "    C1 = np.asarray(C1)\n",
    "    C2 = np.asarray(C2)\n",
    "    num = np.sum(C1.T * C2)\n",
    "    den = np.sqrt(np.sum(C1.T * C1) * np.sum(C2.T * C2))\n",
    "    return num / den\n",
    "\n",
    "def rv_permutation_pvalue(C1, C2, num_permutations=1000, seed=None):\n",
    "    \"\"\"\n",
    "    Permutation test for RV coefficient significance.\n",
    "    Returns observed RV and p-value.\n",
    "\n",
    "    C2 is permuted jointly on rows and columns (a relabelling of its variables)\n",
    "    and the RV coefficient against C1 is recomputed; the p-value is\n",
    "    (#{RV_perm >= RV_obs} + 1) / (num_permutations + 1).\n",
    "\n",
    "    tr(C1 C1) and tr(C2 C2) do not change under a joint permutation, so the\n",
    "    denominator is computed once, and tr(A B) is evaluated as sum(A.T * B)\n",
    "    (O(d^2)) instead of forming O(d^3) matrix products in every iteration.\n",
    "    \"\"\"\n",
    "    C1 = np.asarray(C1)\n",
    "    C2 = np.asarray(C2)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    d = C1.shape[0]\n",
    "\n",
    "    C1_T = C1.T\n",
    "    den = np.sqrt(np.sum(C1_T * C1) * np.sum(C2.T * C2))\n",
    "    RV_obs = np.sum(C1_T * C2) / den  # tr(C1 @ C2) / den\n",
    "\n",
    "    count = 0\n",
    "    for _ in range(num_permutations):\n",
    "        perm = rng.permutation(d)\n",
    "        C2_perm = C2[np.ix_(perm, perm)]  # permute both rows and cols\n",
    "        RV_perm = np.sum(C1_T * C2_perm) / den\n",
    "        if RV_perm >= RV_obs:\n",
    "            count += 1\n",
    "\n",
    "    p_value = (count + 1) / (num_permutations + 1)\n",
    "    return RV_obs, p_value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "title_kwargs = dict(fontsize=10, fontweight=\"bold\", pad=6)\n",
    "cmap = 'magma'\n",
    "take_first_n = 4  # number of mask setups (figure rows) to show\n",
    "num_rv_permutations = 5000\n",
    "\n",
    "\n",
    "def _as_numpy(mat):\n",
    "    \"\"\"Return ``mat`` as a NumPy array, whether it is a torch tensor (any device) or not.\n",
    "\n",
    "    The notebook defines ``direction_similarity`` twice (one version returns a\n",
    "    tensor, the other a NumPy array), so this cell must accept either.\n",
    "    \"\"\"\n",
    "    if isinstance(mat, torch.Tensor):\n",
    "        return mat.detach().cpu().numpy()\n",
    "    return np.asarray(mat)\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(take_first_n, 5, figsize=(15, 3 * take_first_n), dpi=250)\n",
    "for j, set_to_vis in enumerate(trained_saes[:take_first_n]):\n",
    "    ae_selected = set_to_vis[\"ae\"]\n",
    "    topk_selected = set_to_vis[\"topk\"]\n",
    "    dataset_selected = set_to_vis[\"dataset\"]\n",
    "    kronsae_h2_m4_n32 = set_to_vis[\"kronsae_h2_m4_n32\"]\n",
    "    kronsae_h4_m2_n32 = set_to_vis[\"kronsae_h4_m2_n32\"]\n",
    "    # One-to-one matching of AE encoder rows to TopK decoder rows.\n",
    "    idxs = find_one_to_one_mapping(ae_selected.W_enc, topk_selected.W_dec)\n",
    "\n",
    "    # Compute every matrix once; reused by the panels and the RV tests below.\n",
    "    cov = _as_numpy(dataset_selected.L @ dataset_selected.L.T)\n",
    "    sim_topk = _as_numpy(direction_similarity(topk_selected.W_dec))\n",
    "    sim_topk_matched = _as_numpy(direction_similarity(topk_selected.W_dec[idxs]))\n",
    "    sim_kron_h2 = _as_numpy(direction_similarity(kronsae_h2_m4_n32.W_dec))\n",
    "    sim_kron_h4 = _as_numpy(direction_similarity(kronsae_h4_m2_n32.W_dec))\n",
    "\n",
    "    panels = [\n",
    "        (cov, 'Covariance Matrix'),\n",
    "        (sim_topk, 'TopkSAE'),\n",
    "        # plain-ASCII 'Matched' (the title previously contained a Cyrillic letter)\n",
    "        (sim_topk_matched, 'Matched $W_{dec}$ TopkSAE'),\n",
    "        (sim_kron_h2, 'KronSAE $(h=2,\\\\ m=4,\\\\ n=32)$'),\n",
    "        (sim_kron_h4, 'KronSAE $(h=4,\\\\ m=2,\\\\ n=32)$'),\n",
    "    ]\n",
    "    for ax, (data, title) in zip(axs[j], panels):\n",
    "        im = ax.imshow(data, cmap=cmap, interpolation=\"nearest\", aspect=\"equal\")\n",
    "        ax.axis('off')\n",
    "        if j == 0:\n",
    "            ax.set_title(title, **title_kwargs)\n",
    "\n",
    "    rv_inputs = [\n",
    "        (\"topksae\", sim_topk),\n",
    "        (\"topksae matched\", sim_topk_matched),\n",
    "        (\"kronsae_h2_m4_n32\", sim_kron_h2),\n",
    "        (\"kronsae_h4_m2_n32\", sim_kron_h4),\n",
    "    ]\n",
    "    for rv_label, sim in rv_inputs:\n",
    "        RV_obs, p_val = rv_permutation_pvalue(sim, cov, num_permutations=num_rv_permutations, seed=42)\n",
    "        print(f\"RV {rv_label} = {RV_obs:.3f}, p‐value ≈ {p_val:.3e}\")\n",
    "    print(\"----\")\n",
    "\n",
    "# NOTE(review): every panel is autoscaled independently, so this single colorbar\n",
    "# reflects only the last panel's value range.\n",
    "cbar_ax = fig.add_axes([0.94, 0.15, 0.02, 0.7])\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "cbar.set_label('Value', fontsize=12)\n",
    "\n",
    "fig.savefig('figure_synth_all_appendix.pdf', format='pdf', bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "\n",
    "import matplotlib.pyplot as plt \n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from matplotlib.colors import TwoSlopeNorm\n",
    "\n",
    "def _purple_orange_cmap(n_bins=100):\n",
    "    \"\"\"Diverging Purple -> White -> Orange colormap used by the weight/bias plots.\"\"\"\n",
    "    colors = [(.4, 0, 1), (1, 1, 1), (1, .4, 0)]\n",
    "    return LinearSegmentedColormap.from_list(\"\", colors, N=n_bins)\n",
    "\n",
    "\n",
    "def graph_biases(bias, ax_obj):\n",
    "    \"\"\"Draw ``bias`` on ``ax_obj`` with a diverging colormap centred at 0.\n",
    "\n",
    "    Values are mapped with TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1), so values\n",
    "    outside [-1, 1] saturate. A 1-D bias is drawn as a single row, since\n",
    "    ``imshow`` requires a 2-D array.\n",
    "    \"\"\"\n",
    "    b = bias.clone().detach().cpu()\n",
    "    if b.ndim == 1:\n",
    "        b = b.unsqueeze(0)\n",
    "    norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)\n",
    "    ax_obj.imshow(b, cmap=_purple_orange_cmap(), norm=norm)\n",
    "    ax_obj.set_xticks([])\n",
    "    ax_obj.set_yticks([])\n",
    "\n",
    "\n",
    "def graph_weights(weights, bias):\n",
    "    \"\"\"Show the Gram matrix ``weights.T @ weights`` next to the bias (1 x 2 panels).\n",
    "\n",
    "    Both panels use the same diverging colormap and a [-1, 1] TwoSlopeNorm.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(1, 2, figsize=(7, 3.5))  # 1 row, 2 columns\n",
    "\n",
    "    w = weights.clone().cpu().detach()\n",
    "    to_graph = w.T @ w\n",
    "    norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)\n",
    "\n",
    "    axs[0].imshow(to_graph, cmap=_purple_orange_cmap(), norm=norm)\n",
    "    axs[0].set_xticks([])\n",
    "    axs[0].set_yticks([])\n",
    "\n",
    "    graph_biases(bias, axs[1])\n",
    "    # tight_layout() recomputes every subplot margin, so the former\n",
    "    # plt.subplots_adjust(left=0.0, right=1.4) before it had no effect.\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def direction_similarity(W_dec):\n",
    "    \"\"\"Cosine similarity between every pair of rows of ``W_dec``, as a NumPy array.\n",
    "\n",
    "    NOTE: this redefines the tensor-returning ``direction_similarity`` from an\n",
    "    earlier cell; cells executed after this one receive NumPy arrays.\n",
    "    \"\"\"\n",
    "    W_norm = F.normalize(W_dec, p=2, dim=1)\n",
    "    return (W_norm @ W_norm.T).cpu().detach().numpy()\n",
    "\n",
    "def pearson_correlation(W_dec):\n",
    "    \"\"\"Pearson correlation between every pair of rows of ``W_dec``, as a NumPy array.\n",
    "\n",
    "    Pearson r is the cosine similarity of mean-centred rows. The previous\n",
    "    layer_norm-based version returned about d * r (d = row length), because\n",
    "    layer-normalised rows have squared norm about d rather than 1.\n",
    "    Zero-variance rows get correlation 0 (``F.normalize`` clamps the norm).\n",
    "    \"\"\"\n",
    "    centered = W_dec - W_dec.mean(dim=1, keepdim=True)\n",
    "    W_norm = F.normalize(centered, p=2, dim=1)\n",
    "    return (W_norm @ W_norm.T).cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "def _as_numpy(mat):\n",
    "    \"\"\"Return ``mat`` as a NumPy array, whether it is a torch tensor (any device) or not.\n",
    "\n",
    "    The notebook defines ``direction_similarity`` twice (one version returns a\n",
    "    tensor, the other a NumPy array), so this cell must accept either.\n",
    "    \"\"\"\n",
    "    if isinstance(mat, torch.Tensor):\n",
    "        return mat.detach().cpu().numpy()\n",
    "    return np.asarray(mat)\n",
    "\n",
    "\n",
    "# --- 1. Prepare your data ---\n",
    "# Match this run's AE encoder rows to its TopK decoder rows. Previously `idxs`\n",
    "# leaked from the last iteration of the sweep-plotting cell and belonged to a\n",
    "# different AE/SAE pair.\n",
    "idxs = find_one_to_one_mapping(ae.W_enc, topk_sae.W_dec)\n",
    "cov = _as_numpy(dataset.L @ dataset.L.T)\n",
    "sim1 = _as_numpy(direction_similarity(topk_sae.W_dec))\n",
    "sim2 = _as_numpy(direction_similarity(topk_sae.W_dec[idxs]))\n",
    "sim3 = _as_numpy(direction_similarity(kronsae.W_dec))\n",
    "data_list = [cov, sim1, sim2, sim3]\n",
    "\n",
    "# Global vmin/vmax over every plotted panel (sim3 used to be left out).\n",
    "all_data = np.concatenate([panel.ravel() for panel in data_list])\n",
    "vmin, vmax = all_data.min(), all_data.max()\n",
    "\n",
    "# --- 2. Set up figure ---\n",
    "fig, axs = plt.subplots(2, 2, figsize=(8, 8), dpi=250,\n",
    "                        gridspec_kw={'wspace': 0.05, 'right': 0.85})\n",
    "\n",
    "cmap = 'magma'\n",
    "title_kwargs = dict(fontsize=14, fontweight='bold')\n",
    "labels = ['', '', '', '']  # panel letters, e.g. '(a)'..'(d)'; blank for the paper figure\n",
    "titles = [\n",
    "    r'$Covariance\\ Matrix$',  # '\\ ' keeps the space in math mode\n",
    "    r'$TopkSAE$',\n",
    "    r'$Matched$ $TopkSAE$',\n",
    "    r'$KronSAE\\ (h=4,\\ m=2,\\ n=32)$',\n",
    "]\n",
    "\n",
    "# Draw each panel once with the shared vmin/vmax. (A second, autoscaled imshow\n",
    "# used to be drawn on top, which discarded the shared scale.)\n",
    "for ax_num, (data, title, lbl) in enumerate(zip(data_list, titles, labels)):\n",
    "    ax = axs[ax_num // 2, ax_num % 2]\n",
    "    im = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax,\n",
    "                   interpolation=\"nearest\", aspect=\"equal\")\n",
    "    ax.axis('off')\n",
    "    ax.set_title(title, **title_kwargs)\n",
    "    ax.text(-0.05, 1.05, lbl, transform=ax.transAxes,\n",
    "            fontsize=12, fontweight='bold', va='top')\n",
    "\n",
    "# --- 3. Add a shared colorbar ---\n",
    "cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "# NOTE(review): the panels show a covariance matrix and cosine similarities,\n",
    "# not Pearson correlations; consider relabelling.\n",
    "cbar.set_label('Pearson $r$', fontsize=12)\n",
    "\n",
    "# --- 4. Final touches and save ---\n",
    "fig.tight_layout(pad=2.0)\n",
    "fig.savefig('figure_synth_covariance_2_2.pdf', format='pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# RV coefficient (with permutation p-value) between the data covariance and the\n",
    "# decoder direction-similarity matrices of the single-run TopK SAE and KronSAE.\n",
    "cov_matrix = dataset.L @ dataset.L.T\n",
    "if isinstance(cov_matrix, torch.Tensor):\n",
    "    cov_matrix = cov_matrix.detach().cpu().numpy()\n",
    "\n",
    "for rv_label, W_dec in [(\"topksae\", topk_sae.W_dec), (\"kronsae\", kronsae.W_dec)]:\n",
    "    sim = direction_similarity(W_dec)\n",
    "    # The notebook has two direction_similarity definitions (tensor vs NumPy\n",
    "    # result); convert so the NumPy-based RV test works with either.\n",
    "    if isinstance(sim, torch.Tensor):\n",
    "        sim = sim.detach().cpu().numpy()\n",
    "    RV_obs, p_val = rv_permutation_pvalue(sim, cov_matrix, num_permutations=5000, seed=42)\n",
    "    print(f\"RV {rv_label} = {RV_obs:.3f}, p‐value ≈ {p_val:.3e}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peer_sae",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
