{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fbbf09cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook setup: autoreload, plotly renderer, and import path\n",
    "import sys\n",
    "\n",
    "import plotly.io as pio\n",
    "from IPython import get_ipython\n",
    "\n",
    "# Enable the autoreload extension (setting 2) so that edited modules are\n",
    "# picked up without restarting the kernel\n",
    "ipython = get_ipython()\n",
    "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "ipython.run_line_magic(\"autoreload\", \"2\")\n",
    "\n",
    "# Choose the plotly renderer used for figures in this notebook\n",
    "pio.renderers.default = \"notebook_connected\"\n",
    "\n",
    "# Make modules in the parent directory importable\n",
    "sys.path.append('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-23T02:26:33.633638600Z",
     "start_time": "2023-09-23T02:26:33.628639800Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_from_disk, load_dataset\n",
    "from wav2vec import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
    "from torch.utils.data import DataLoader\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import numpy as np\n",
    "from HIB import random_sample_data, HIB\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad4e560e",
   "metadata": {},
   "source": [
    "# Load datasets, model, processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcc78be6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder paths: replace them with your own dataset, processor, and model locations\n",
    "dataset_path = \"path/to/your/dataset\"\n",
    "processor_path = \"path/to/your/processor\"\n",
    "model_path = \"path/to/your/model\"\n",
    "\n",
    "# load_from_disk can be swapped for another loading function\n",
    "datasets = load_from_disk(dataset_path)\n",
    "data_loader = DataLoader(datasets, batch_size=1)\n",
    "\n",
    "processor = Wav2Vec2Processor.from_pretrained(processor_path)\n",
    "model = Wav2Vec2ForCTC.from_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d87192ce",
   "metadata": {},
   "source": [
    "### Randomly sample one example from the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "841a3db769f3bacf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-23T02:17:58.440362800Z",
     "start_time": "2023-09-23T02:17:56.949162200Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Draw a single example from the dataset\n",
    "sample = random_sample_data(datasets, 1)[0]\n",
    "input_data = sample[\"audio\"][\"array\"]\n",
    "\n",
    "# Encode the sample's text into label ids\n",
    "text_encoding = processor(text=sample[\"text\"], return_tensors=\"pt\")\n",
    "labels = text_encoding.input_ids\n",
    "\n",
    "# Turn the audio array into the model's input values\n",
    "audio_encoding = processor(input_data, sampling_rate=16000, return_tensors=\"pt\", padding=\"longest\")\n",
    "wav_data = audio_encoding.input_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96200cc3",
   "metadata": {},
   "source": [
    "### Set your device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "83d4962b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU with index `gpu_index` when CUDA is available, otherwise the CPU\n",
    "gpu_index = 0\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:%d\" % gpu_index)\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5268371a",
   "metadata": {},
   "source": [
    "# Regular Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "884c3d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index of the encoder layer to analyze (-1 selects the last layer)\n",
    "analysis_layer_index = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5b6031f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the unmodified model once with attention outputs enabled to obtain its regular attention maps.\n",
    "model = model.to(device)\n",
    "wav_on_device = wav_data.to(device)\n",
    "with torch.no_grad():\n",
    "    outputs = model(wav_on_device, output_attentions=True)\n",
    "\n",
    "# Concatenate the per-layer attention tensors along dim 0 and bring the result back to the CPU\n",
    "att = torch.cat(outputs.attentions, dim=0).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e0fb0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the attention of the selected layer, one facet per entry along axis 0\n",
    "px.imshow(att[analysis_layer_index], facet_col=0, facet_col_wrap=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "645a7426",
   "metadata": {},
   "source": [
    "# HIB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "6ad25175",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the HIB object on the att_v_bmm module of the selected encoder layer's attention\n",
    "encoder_layer = model.wav2vec2.encoder.layers[analysis_layer_index]\n",
    "analysis_layer = encoder_layer.attention.att_v_bmm\n",
    "hib = HIB(analysis_layer, lamb=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac41e727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimate mean, std\n",
    "hib.estimate(model, data_loader, processor, n_samples=10000*12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae2a76f0e7a369bf",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# build alpha\n",
    "# Inputs are moved to `device` (the device the model was moved to) rather than via a hard-coded .cuda(),\n",
    "# so this cell also works on CPU-only machines and with a non-default GPU index.\n",
    "with torch.no_grad(), hib.enable_build_alpha(), hib.interrupt_execution():\n",
    "    model(wav_data.to(device))\n",
    "\n",
    "# analyze\n",
    "# The labels are moved once here instead of on every call of the loss closure.\n",
    "labels_on_device = labels.to(device)\n",
    "ib_attr = hib.analyze(\n",
    "    wav_data,\n",
    "    lambda x: model(x.to(device), labels=labels_on_device).loss,\n",
    "    optimization_steps=200,\n",
    "    batch_size=1,\n",
    "    beta=10,\n",
    "    lr=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e1a0ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the model loss while HIB's constraint_flow context is active.\n",
    "# Tensors go to `device` (where the model lives) instead of a hard-coded .cuda().\n",
    "with hib.constraint_flow():\n",
    "    print(model(wav_data.to(device), labels=labels.to(device)).loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98da8198",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot information loss\n",
    "px.line(hib._information_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87ff04b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot HIB attribution\n",
    "px.imshow(ib_attr, facet_col=0, facet_col_wrap=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80a740eb",
   "metadata": {},
   "source": [
    "# Attention Map Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66daf6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlate the HIB attribution with the regular attention map, head by head.\n",
    "# The number of heads is read from the attention tensor instead of being hard-coded to 12.\n",
    "layer_att = att[analysis_layer_index]\n",
    "n_heads = layer_att.shape[0]  # assumes att is (layers, heads, seq, seq), i.e. batch size 1 - TODO confirm\n",
    "\n",
    "# Flatten each head's map into one row\n",
    "x = ib_attr.reshape(n_heads, -1)\n",
    "y = layer_att.reshape(n_heads, -1)\n",
    "\n",
    "# Correlation coefficient between the two flattened maps of each head\n",
    "sim = [np.corrcoef(x[i], y[i])[0, 1] for i in range(n_heads)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "300ee946",
   "metadata": {},
   "outputs": [],
   "source": [
    "px.line(sim, title=\"Similarity between HIB and Attention\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
