{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer Explainability Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required modules\n",
    "import torch\n",
    "from transformers import AutoFeatureExtractor\n",
    "from transformers import AutoModelForAudioClassification\n",
    "import librosa\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "from azureml.fsspec import AzureMachineLearningFileSystem\n",
    "\n",
    "# Global font settings for every figure in this notebook\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "plt.rcParams.update({'font.family': 'serif', 'font.serif': 'Times New Roman'})\n",
    "\n",
    "# Seconds of audio kept from each file; also selects the wav2vec2 checkpoint directory\n",
    "slice_length = 6.0\n",
    "\n",
    "# Which class of example to explain: bonafide (True) or spoof (False)\n",
    "bonafide = True\n",
    "spoof = not bonafide"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data from Azure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the Azure ML datastore holding metadata/ and flac_T/\n",
    "# NOTE(review): this URI looks like a placeholder — confirm before running\n",
    "root_dir = \"azureml:/\"\n",
    "fs = AzureMachineLearningFileSystem(root_dir)\n",
    "\n",
    "# NOTE(review): joined with a leading '/', unlike the audio path below — confirm which is right\n",
    "metadata = pd.read_csv(root_dir+'/metadata/train_metadata.csv')\n",
    "\n",
    "# File name of the first metadata row whose KEY matches the requested class\n",
    "target_key = 'bonafide' if bonafide else 'spoof'\n",
    "fname = metadata.loc[metadata['KEY'] == target_key, 'FLAC_FILE_NAME'].iloc[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected file resampled to 16 kHz and keep the first slice_length seconds.\n",
    "# FLAC is binary, so the remote file must be opened with 'rb' (text mode would try to decode it);\n",
    "# the context manager closes the handle once librosa has read it.\n",
    "with fs.open(root_dir+'flac_T/'+fname+'.flac', 'rb') as audio_file:\n",
    "    audio, sr = librosa.load(audio_file, sr=16000)\n",
    "audio = audio[:int(sr*slice_length)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: listen to the clip being explained\n",
    "from IPython.display import Audio\n",
    "Audio(audio, rate=sr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Occlusion with AST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle to the Azure ML location of the fine-tuned AST checkpoint (downloaded in a later cell)\n",
    "model_dir = \"azureml://checkpoint/\"\n",
    "model_fs = AzureMachineLearningFileSystem(model_dir)\n",
    "\n",
    "# AST feature extractor from the public MIT/ast checkpoint\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(\"MIT/ast-finetuned-audioset-10-10-0.4593\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the waveform into the spectrogram-like input_values tensor that AST consumes\n",
    "inputs = feature_extractor(\n",
    "    audio,\n",
    "    sampling_rate=sr,\n",
    "    return_tensors=\"pt\",\n",
    "    padding=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the fine-tuned AST checkpoint from Azure ML into a local directory\n",
    "local_dir = '../temp_model'\n",
    "os.makedirs(local_dir, exist_ok=True)\n",
    "\n",
    "# List the *remote* checkpoint location (model_dir), not the local target directory.\n",
    "# NOTE(review): confirm whether this filesystem expects the full URI or a path relative to it.\n",
    "for remote_path in model_fs.ls(model_dir, detail=False, recursive=True):\n",
    "    if model_fs.isfile(remote_path):  # skip directory entries\n",
    "        # Flat copy: every file lands directly in local_dir under its basename\n",
    "        file_name = os.path.basename(remote_path)\n",
    "        model_fs.get(remote_path, os.path.join(local_dir, file_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the fine-tuned AST classifier from the local copy, exposing attentions and hidden states\n",
    "model = AutoModelForAudioClassification.from_pretrained(\n",
    "    local_dir,\n",
    "    output_attentions=True,\n",
    "    output_hidden_states=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the AST input features, transposed so axis 0 runs along x and axis 1 along y\n",
    "features = inputs['input_values'].squeeze().numpy()\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "image = ax.imshow(features.T, aspect='auto', origin='lower', cmap='jet')\n",
    "fig.colorbar(image, ax=ax, label='Amplitude')\n",
    "ax.set_title(\"AST Feature Extractor Output\")\n",
    "ax.set_xlabel(\"Time (frames)\")\n",
    "ax.set_ylabel(\"Frequency bins\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def occlusion_with_probability_change(inputs, model, occlusion_size=(200, 50), stride=(100, 25), occlusion_value=0.0):\n",
    "    \"\"\"Occlusion-sensitivity map for the model's originally predicted class.\n",
    "\n",
    "    A patch of size ``occlusion_size`` slides over ``inputs['input_values'][0]`` with step\n",
    "    ``stride``. Each patch is filled with ``occlusion_value``, and |p_orig - p_occluded| is\n",
    "    recorded for the predicted class. Every cell of the returned map is the MEAN score over all\n",
    "    patches covering it, so overlapping patches do not inflate interior cells. Cells that no\n",
    "    patch covers (trailing remainder) stay 0.\n",
    "\n",
    "    Args:\n",
    "        inputs: feature-extractor output holding a batch-first 'input_values' tensor; only the\n",
    "            first example is explained. Axes 1/2 are assumed (time, frequency) as in the plots.\n",
    "        model: classifier whose output exposes ``.logits``.\n",
    "        occlusion_size: patch extent along axes 1 and 2.\n",
    "        stride: step between patch origins along axes 1 and 2.\n",
    "        occlusion_value: fill value for occluded cells (0.0 reproduces the previous behavior).\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray with the same 2-D shape as ``inputs['input_values'][0]``.\n",
    "    \"\"\"\n",
    "    # Inference only: avoid building autograd graphs for the many forward passes below\n",
    "    with torch.no_grad():\n",
    "        original_probs = F.softmax(model(**inputs).logits, dim=-1)\n",
    "        original_prediction = torch.argmax(original_probs, dim=-1)[0].item()\n",
    "        original_prob = original_probs[0, original_prediction].item()\n",
    "\n",
    "        # Work on a copy of the first example, keeping a batch dimension of 1\n",
    "        spectrogram = inputs['input_values'][0].unsqueeze(0).clone()\n",
    "        n_frames, n_bins = spectrogram.size(1), spectrogram.size(2)\n",
    "        score_sum = np.zeros((n_frames, n_bins))\n",
    "        coverage = np.zeros((n_frames, n_bins))\n",
    "\n",
    "        for i in range(0, n_frames - occlusion_size[0] + 1, stride[0]):\n",
    "            for j in range(0, n_bins - occlusion_size[1] + 1, stride[1]):\n",
    "                occluded = spectrogram.clone()\n",
    "                occluded[:, i:i+occlusion_size[0], j:j+occlusion_size[1]] = occlusion_value\n",
    "\n",
    "                occluded_probs = F.softmax(model(input_values=occluded).logits, dim=-1)\n",
    "                prob_diff = original_prob - occluded_probs[0, original_prediction].item()\n",
    "\n",
    "                # Larger absolute change in probability = more important region\n",
    "                score_sum[i:i+occlusion_size[0], j:j+occlusion_size[1]] += abs(prob_diff)\n",
    "                coverage[i:i+occlusion_size[0], j:j+occlusion_size[1]] += 1\n",
    "\n",
    "    # Average over overlapping patches; cells never covered keep a score of 0\n",
    "    return np.divide(score_sum, coverage, out=np.zeros_like(score_sum), where=coverage > 0)\n",
    "\n",
    "heatmap = occlusion_with_probability_change(inputs, model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize resultant heatmap\n",
    "def visualize_occlusion(heatmap, inputs, title=None):\n",
    "    \"\"\"Overlay an occlusion heatmap on the first example's input spectrogram.\n",
    "\n",
    "    Both arrays are transposed, so axis 0 is drawn along x (time frames) and axis 1\n",
    "    along y (frequency bins). No title is drawn when ``title`` is None.\n",
    "    \"\"\"\n",
    "    spectrogram = inputs['input_values'][0].squeeze().numpy()\n",
    "    heatmap = np.asarray(heatmap).squeeze()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    ax.imshow(spectrogram.T, aspect='auto', origin='lower', cmap='Greys')\n",
    "    overlay = ax.imshow(heatmap.T, aspect='auto', origin='lower', cmap='jet', alpha=0.5)\n",
    "    # The colorbar describes the heatmap overlay, not the grayscale spectrogram\n",
    "    fig.colorbar(overlay, ax=ax, label='Importance')\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "    ax.set_xlabel(\"Time Frames\")\n",
    "    ax.set_ylabel(\"Frequency Bins\")\n",
    "    plt.show()\n",
    "\n",
    "# Visualize\n",
    "visualize_occlusion(heatmap, inputs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention Visualization with Wav2Vec2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuned wav2vec2 classifier for this slice length, from a local checkpoint directory\n",
    "model_name = f\"{int(slice_length)}/checkpoint/\"\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(\"facebook/wav2vec2-base\")\n",
    "\n",
    "# Prepare the raw waveform for wav2vec2 (unlike AST, the input here is not a spectrogram).\n",
    "# NOTE: this rebinds `inputs`, `feature_extractor` and `model`; the AST objects are replaced.\n",
    "inputs = feature_extractor(audio, sampling_rate=sr, return_tensors=\"pt\", padding=False)\n",
    "\n",
    "# Instantiate model\n",
    "model = AutoModelForAudioClassification.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single forward pass; gradients are not needed for attention visualization\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "attentions = outputs.attentions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Head-averaged attention of one layer for the first example.\n",
    "# Each attentions[layer] is indexed [batch, head, query, key], so indexing with head_num\n",
    "# selects the *batch* item and the mean over dim 0 averages across all heads.\n",
    "layer_num = -1  # last layer\n",
    "head_num = 0    # batch index (name kept because later cells use it)\n",
    "attention_weights = attentions[layer_num][head_num].mean(dim=0).detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of raw head-averaged attention weights for the selected layer\n",
    "plt.hist(attention_weights.flatten())\n",
    "plt.xlabel(\"Attention weight\")\n",
    "plt.ylabel(\"Count\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Min-max normalize to [0, 1] (a constant matrix maps to all zeros instead of dividing by 0)\n",
    "weight_range = np.ptp(attention_weights)\n",
    "normalized_attention_weights = (attention_weights - np.min(attention_weights)) / (weight_range if weight_range > 0 else 1.0)\n",
    "plt.hist(normalized_attention_weights.flatten())\n",
    "plt.xlabel(\"Normalized attention weight\")\n",
    "plt.ylabel(\"Count\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where is attention concentrated? Print (query, key) positions whose normalized weight exceeds\n",
    "# the threshold; the final count shows how distributed vs. concentrated the attention is.\n",
    "threshold = 0.8\n",
    "above_threshold = normalized_attention_weights > threshold\n",
    "for query_pos, key_pos in np.argwhere(above_threshold):\n",
    "    print((int(query_pos), int(key_pos)))\n",
    "\n",
    "int(above_threshold.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the head-averaged, min-max-normalized attention of the selected layer\n",
    "layer_attention = attentions[layer_num][head_num].mean(dim=0).detach().numpy()\n",
    "attention_span = np.ptp(layer_attention)\n",
    "layer_attention = (layer_attention - np.min(layer_attention)) / (attention_span if attention_span > 0 else 1.0)\n",
    "\n",
    "# Resolve negative indices (e.g. -1) to a 1-based layer number for the title\n",
    "display_layer = layer_num % len(attentions) + 1\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "image = ax.imshow(layer_attention, cmap='viridis', aspect='auto')\n",
    "fig.colorbar(image, ax=ax, label='Normalized attention weight')\n",
    "ax.set_title(f\"Average Attention Weights (Layer {display_layer})\")\n",
    "ax.set_xlabel(\"Key position\")\n",
    "ax.set_ylabel(\"Query position\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Head-averaged attention of every layer, each min-max-normalized to [0, 1], in one grid\n",
    "n_layers = len(attentions)\n",
    "n_cols = 3\n",
    "n_rows = (n_layers + n_cols - 1) // n_cols\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 6.75 * n_rows), squeeze=False)\n",
    "\n",
    "for i in range(n_layers):\n",
    "    attention_weights = torch.mean(attentions[i][0], dim=0).detach().numpy()\n",
    "    span = np.ptp(attention_weights)\n",
    "    attention_weights = (attention_weights - np.min(attention_weights)) / (span if span > 0 else 1.0)\n",
    "\n",
    "    # Map layer i to its subplot (row-major order)\n",
    "    ax = axes[i // n_cols, i % n_cols]\n",
    "    # Fixed [0, 1] color range so the single shared colorbar is valid for every panel\n",
    "    im = ax.imshow(attention_weights, cmap='viridis', aspect='equal', vmin=0.0, vmax=1.0)\n",
    "    ax.set_title(f\"Layer {i+1}\")\n",
    "    ax.set_xlabel(\"Input Sequence Position\")\n",
    "    ax.set_ylabel(\"Input Sequence Position\")\n",
    "\n",
    "# Hide unused panels when the layer count is not a multiple of n_cols\n",
    "for ax in axes.flat[n_layers:]:\n",
    "    ax.axis('off')\n",
    "\n",
    "cbar = fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.05, pad=0.05)\n",
    "cbar.set_label('Attention Weight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_attention_rollout(attentions):\n",
    "    \"\"\"Attention rollout across all layers.\n",
    "\n",
    "    For each layer in order, attention is averaged over heads and the identity is added to\n",
    "    model the residual connection. Each row is renormalized to sum to 1, and the result is\n",
    "    left-multiplied onto the running product, so rollout = A_last @ ... @ A_first.\n",
    "\n",
    "    Args:\n",
    "        attentions: sequence of per-layer tensors indexed [batch, head, query, key].\n",
    "    Returns:\n",
    "        Tensor of shape (batch, seq, seq).\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        first = attentions[0]\n",
    "        # Identity in the attentions' own dtype/device, so matmul also works for fp16/GPU\n",
    "        identity = torch.eye(first.size(-1), dtype=first.dtype, device=first.device)\n",
    "        # Start from the identity: before any layer, each position maps only to itself\n",
    "        rollout = identity\n",
    "        for attention in attentions:\n",
    "            # Average attention across all heads -> (batch, seq, seq)\n",
    "            fused = attention.mean(dim=1)\n",
    "            # Add the residual connection and renormalize each row\n",
    "            fused = fused + identity\n",
    "            fused = fused / fused.sum(dim=-1, keepdim=True)\n",
    "            # Later layers multiply from the left\n",
    "            rollout = torch.matmul(fused, rollout)\n",
    "    return rollout\n",
    "\n",
    "# Compute attention rollout\n",
    "attention_map = compute_attention_rollout(outputs.attentions)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize attention roll-out.\n",
    "# wav2vec2 has no [CLS] token. Its sequence-classification head pools over all frames, so rather\n",
    "# than reading row 0 (a ViT/BERT convention) we average the rollout over all query positions.\n",
    "# This gives each input frame's overall contribution.\n",
    "m = attention_map[0].mean(dim=0).detach().numpy()\n",
    "m_range = np.ptp(m)\n",
    "n = (m - np.min(m)) / (m_range if m_range > 0 else 1.0)\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.bar(range(len(n)), n)\n",
    "plt.title(\"Attention rollout (mean over query positions)\")\n",
    "plt.xlabel('Frame index')\n",
    "plt.ylabel('Normalized attention weight')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wav2vec",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
