{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.benchmark import evaluate_sufficiency_and_necessity, cluster_cross_sections, evaluate_sparse_controllability, compute_sweep_cluster\n",
    "from utils.setup import init_model\n",
    "from utils.data_generator import init_ind,init_ioi\n",
    "from utils.vizualization_handler import create_chart, detect_and_visualize_outliers, viz_sweep_results\n",
    "from utils.data_generator import evaluate_example_generator\n",
    "\n",
    "from sae_lens import SAE\n",
    "from collections import defaultdict\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = init_model(\"gemma-2-2b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SAEs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "saes = []\n",
    "\n",
    "# edit thje pre of the downstream layer\n",
    "# use activations from post from upstream to downstream - 1 ==> pre from upstream+1 to downstream \n",
    "\n",
    "sizes = (\"16k\",\"65k\")\n",
    "\n",
    "for size in sizes:\n",
    "    sae_layers = []\n",
    "    for layer_id in range(model.cfg.n_layers):\n",
    "        sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "            release = f\"gemma-scope-2b-pt-res-canonical\",\n",
    "            sae_id = f\"layer_{layer_id}/width_{size}/canonical\",\n",
    "        )\n",
    "\n",
    "        sae_layers.append(sae)\n",
    "    saes.append(sae_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Supervised dictionaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test_batches=5\n",
    "batch_size=50\n",
    "n_edits = [0,4,8,16]\n",
    "m = 1\n",
    "\n",
    "correct = \"ind2\"\n",
    "incorrect = \"ind1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_sections_features = []\n",
    "path = \"./feature_dicts/gemma\"\n",
    "for filename in os.listdir(path):\n",
    "    if(\"ind\" in filename):\n",
    "        ind_cross_sections_features.extend(torch.load(path+\"/\"+filename))\n",
    "\n",
    "ind_cross_section_clusters = cluster_cross_sections(ind_cross_sections_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_name_candidates = []\n",
    "\n",
    "for i in range(len(ind_cross_sections_features)):\n",
    "    for attr in ind_cross_sections_features[i]['feature-dict'].keys():\n",
    "        if(attr == \"order\"):\n",
    "            continue\n",
    "        else:\n",
    "            ind_name_candidates.append(ind_cross_sections_features[i]['feature-dict'][attr].keys())\n",
    "ind_name_candidates = list(set.intersection(*map(set,ind_name_candidates)))\n",
    "print(ind_name_candidates)\n",
    "\n",
    "get_ind_test_examples, n_ind_test_examples = init_ind(model, ind_name_candidates, batch_size, train=False, cross_entropy_threshhold=5)\n",
    "print(f\"Test Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_test_examples, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sweep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_section_clusters.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sweep_results_list = []\n",
    "labels = []\n",
    "for key in tqdm(sorted(list(ind_cross_section_clusters.keys()))):\n",
    "\n",
    "    sweep_results_list.append(compute_sweep_cluster(model, get_ind_test_examples, correct, incorrect, n_test_batches, ind_cross_section_clusters[key], step=5))\n",
    "\n",
    "    labels.append(\"\\n\".join(key))\n",
    "\n",
    "viz_sweep_results(sweep_results_list,labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relative_outliers, max_ablation_scores = detect_and_visualize_outliers(sweep_results_list,labels,percentile=0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_section_clusters_ = defaultdict()\n",
    "for i, sweep_results in enumerate(sweep_results_list):\n",
    "    # Find the optimal subset (largest normalization score) for the current cluster\n",
    "    optimal_subset = max(sweep_results, key=sweep_results.get)\n",
    "    optimal_length = int(optimal_subset.split('_')[1])  # Extract the length from the subset name (e.g., 'subset_26' -> 26)\n",
    "\n",
    "    # Select the corresponding subset of the cross-section cluster\n",
    "    key = sorted(list(ind_cross_section_clusters.keys()))[i]\n",
    "    ind_cross_section_clusters_[key] = ind_cross_section_clusters[key][:optimal_length]\n",
    "\n",
    "del_offset = 0\n",
    "for outlier_idx in relative_outliers:\n",
    "    key = sorted(list(ind_cross_section_clusters_.keys()))[outlier_idx-del_offset]\n",
    "    del ind_cross_section_clusters_[key]\n",
    "    del_offset += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sufficiency and Necessity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "import gc\n",
    "gc.collect()\n",
    "\n",
    "# ind-specific code\n",
    "ind_sufficiencies = []\n",
    "ind_necessities = []\n",
    "                           \n",
    "for key in tqdm(sorted(list(ind_cross_section_clusters_.keys()))):\n",
    "\n",
    "    res = evaluate_sufficiency_and_necessity(model, get_ind_test_examples, correct, incorrect, n_test_batches, ind_cross_section_clusters_[key], saes, sizes, m)\n",
    "    \n",
    "    ind_sufficiencies.append({key:res[\"sufficiency\"]})\n",
    "    ind_necessities.append({key:res[\"necessity\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_chart([ind_sufficiencies,ind_necessities])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Controlability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "import gc\n",
    "gc.collect()\n",
    "\n",
    "ind_controllability = evaluate_sparse_controllability(model, get_ind_test_examples, n_test_batches, ind_cross_section_clusters_, saes, n_edits, sizes, m=1)\n",
    "\n",
    "flattened = [{key: value} for key, value in ind_controllability.items()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_chart([flattened], controlability=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dissenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
