{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference and Permutation Importance for GBDT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we run inference with a trained gradient-boosted decision tree (GBDT) classifier and then calculate and compare its permutation feature importances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports, grouped as standard library / third-party / project-local\n",
    "import os\n",
    "\n",
    "import librosa\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from azureml.fsspec import AzureMachineLearningFileSystem\n",
    "\n",
    "from features import Features\n",
    "\n",
    "# Shared feature extractor; its make_features(y, sr) is called on every clip/segment loaded below\n",
    "feature_generator = Features()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the audio listed in the metadata CSV from Azure and build the feature matrix.\n",
    "max_files = 100000  # cap on the number of metadata rows (files) to load\n",
    "slice_size = 3      # segment length in seconds\n",
    "slice_audio = True  # True: one feature row per segment; False: one feature row per file\n",
    "\n",
    "folder = \"azureml:/\"\n",
    "fs = AzureMachineLearningFileSystem(folder)\n",
    "test_dir = os.path.join(folder, 'flac_T/')\n",
    "\n",
    "# Slicing both lists to max_files is the only limiter needed; labels stay aligned with file names\n",
    "df = pd.read_csv(folder+ 'metadata/train_metadata.csv')\n",
    "raw_labels = df['KEY'].to_list()[:max_files]\n",
    "file_names = [file + '.flac' for file in df['FLAC_FILE_NAME'].to_list()][:max_files]\n",
    "\n",
    "test_features = []\n",
    "test_labels = []\n",
    "\n",
    "for file, label in zip(file_names, raw_labels):\n",
    "    file_path = os.path.join(test_dir, file)\n",
    "\n",
    "    # Load files from Azure filesystem (fs); the audio is fully decoded before the handle closes\n",
    "    with fs.open(file_path) as f:\n",
    "        y, sr = librosa.load(f)\n",
    "\n",
    "    if not slice_audio:\n",
    "        # One feature row for the whole clip\n",
    "        test_features.append(feature_generator.make_features(y, sr))\n",
    "        test_labels.append(label)\n",
    "        continue\n",
    "\n",
    "    segment_length_samples = int(slice_size * sr)\n",
    "\n",
    "    # Determine the number of segments\n",
    "    num_segments = int(np.ceil(len(y) / segment_length_samples))\n",
    "\n",
    "    # The final segment is skipped, so a trailing partial segment never reaches the\n",
    "    # feature extractor; clips no longer than one segment therefore contribute no rows.\n",
    "    for i in range(num_segments - 1):\n",
    "        start_sample = i * segment_length_samples\n",
    "        end_sample = start_sample + segment_length_samples\n",
    "\n",
    "        # Extract the segment\n",
    "        segment = y[start_sample:end_sample]\n",
    "        test_features.append(feature_generator.make_features(segment, sr))\n",
    "        test_labels.append(label)\n",
    "\n",
    "print(\"loaded test audio\")\n",
    "X_test = np.array(test_features)\n",
    "y_test = np.array(test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deserialize the model used for the predictions below from a fixed path on disk.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load model files you trust.\n",
    "from pickle import load\n",
    "\n",
    "model_path = \"/home/azureuser/model.pkl\"\n",
    "with open(model_path, \"rb\") as model_file:\n",
    "    booster = load(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the labels are correctly mapped to the logits (spoof -> 1, bonafide -> 0).\n",
    "# Only string labels are translated, so re-running this cell after y_test already holds\n",
    "# integers is a no-op instead of a KeyError. Unknown string labels still raise KeyError.\n",
    "d = {'spoof':1, 'bonafide':0}\n",
    "y_test = [d[label] if isinstance(label, str) else label for label in y_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the extracted features and print the classification report\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "y_pred = booster.predict(X_test)\n",
    "report = classification_report(y_test, y_pred)\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature Importances"
   ]
  },
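  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Graphs for slice lengths of 6, 3, and 1 seconds. The cells below handle one slice length per run: set `slice_length` to the desired value and re-run them to produce each graph."
   ]
  },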
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset through the XGBooster class (data only, no training requested)\n",
    "# NOTE(review): XGBooster looks like a project class living in a module named xgboost,\n",
    "# not something shipped by the upstream xgboost package -- confirm the import target.\n",
    "from xgboost import XGBooster\n",
    "\n",
    "slice_length = 3\n",
    "compressed = False\n",
    "xg_data = XGBooster(\n",
    "    data_only=True,\n",
    "    splitvoice=False,\n",
    "    compressed=compressed,\n",
    "    rerecorded=False,\n",
    "    max=1000,\n",
    "    slice_size_seconds=slice_length,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled model whose file name is the current slice length\n",
    "# NOTE: unpickling can execute arbitrary code -- only load model files you trust.\n",
    "from pickle import load\n",
    "\n",
    "filename = \"/home/azureuser/{}.pkl\".format(slice_length)\n",
    "with open(filename, \"rb\") as model_file:\n",
    "    model = load(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Permutation importance of every feature on the XGBooster test split\n",
    "# (30 shuffles per feature, fixed seed so the result is reproducible)\n",
    "from sklearn.inspection import permutation_importance\n",
    "\n",
    "r = permutation_importance(\n",
    "    model,\n",
    "    xg_data.X_test,\n",
    "    xg_data.y_test,\n",
    "    n_repeats=30,\n",
    "    random_state=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print feature index, mean and std for every feature whose mean importance is more than\n",
    "# two standard deviations above zero, most important first\n",
    "for i in r.importances_mean.argsort()[::-1]:\n",
    "    mean_i = r.importances_mean[i]\n",
    "    std_i = r.importances_std[i]\n",
    "    if mean_i - 2 * std_i > 0:\n",
    "        print(f\"{i} {mean_i:.3f} +/- {std_i:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the importances\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# rcParams only affect artists created afterwards, so set the font size before plotting\n",
    "plt.rcParams.update({'font.size': 10})\n",
    "\n",
    "# Display names, in the order used to label the importance vector:\n",
    "# 20 MFCCs, 'cr', 12 chroma bins, then 'sc', 'sb', 'rolloff', 'rms'\n",
    "mfccs = [f'mfcc{i}' for i in range(1, 21)]\n",
    "chroma = [f'chroma{i}' for i in range(1, 13)]\n",
    "feature_names = np.array(mfccs + ['cr'] + chroma + ['sc', 'sb', 'rolloff', 'rms'])\n",
    "\n",
    "forest_importances = pd.Series(r.importances_mean, index=feature_names)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "forest_importances.plot.bar(yerr=r.importances_std, ax=ax)\n",
    "ax.set_title(\"Feature importances using permutation on full model\")\n",
    "ax.set_ylabel(\"Mean accuracy decrease\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger version of the permutation-importance bar chart, saved to disk as a PNG.\n",
    "# (The correlations between features are visualized in the next cell.)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Scale font size and pick the font family BEFORE creating the figure:\n",
    "# rcParams only affect artists created afterwards\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "plt.rcParams.update({'font.family': 'serif', 'font.serif': 'Times New Roman'})\n",
    "\n",
    "# Display names, in the order used to label the importance vector\n",
    "mfccs = [f'mfcc{i}' for i in range(1, 21)]\n",
    "chroma = [f'chroma{i}' for i in range(1, 13)]\n",
    "feature_names = np.array(mfccs + ['cr'] + chroma + ['sc', 'sb', 'rolloff', 'rms'])\n",
    "\n",
    "forest_importances = pd.Series(r.importances_mean, index=feature_names)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(16, 10))\n",
    "forest_importances.plot(kind='bar', yerr=r.importances_std, ax=ax)\n",
    "ax.set_ylim(-0.01, 0.065)\n",
    "\n",
    "ax.set_title(\"Feature importances using permutation on full model\")\n",
    "ax.set_ylabel(\"Mean accuracy decrease\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('feature_importances_full_model_.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dendrogram of feature clusters (left) and the correlation matrix reordered to match it (right)\n",
    "from scipy.cluster import hierarchy\n",
    "from scipy.spatial.distance import squareform\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\n",
    "\n",
    "# Spearman rank correlation between the feature columns of the test split\n",
    "corr = spearmanr(xg_data.X_test).correlation\n",
    "\n",
    "# Ensure the correlation matrix is symmetric with a unit diagonal\n",
    "corr = (corr + corr.T) / 2\n",
    "np.fill_diagonal(corr, 1)\n",
    "\n",
    "# Turn correlations into distances (1 - |corr|), then cluster hierarchically\n",
    "# with Ward's linkage on the condensed distance vector.\n",
    "distance_matrix = 1 - np.abs(corr)\n",
    "dist_linkage = hierarchy.ward(squareform(distance_matrix))\n",
    "dendro = hierarchy.dendrogram(\n",
    "    dist_linkage,\n",
    "    labels=feature_names,\n",
    "    ax=ax1,\n",
    "    leaf_rotation=90,\n",
    ")\n",
    "\n",
    "leaf_order = dendro[\"leaves\"]\n",
    "leaf_labels = dendro[\"ivl\"]\n",
    "dendro_idx = np.arange(0, len(leaf_labels))\n",
    "\n",
    "# Reorder rows and columns so correlated features sit next to each other\n",
    "ax2.imshow(corr[leaf_order, :][:, leaf_order])\n",
    "ax2.set_xticks(dendro_idx)\n",
    "ax2.set_yticks(dendro_idx)\n",
    "ax2.set_xticklabels(leaf_labels, rotation=\"vertical\")\n",
    "ax2.set_yticklabels(leaf_labels)\n",
    "_ = fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrain our model with a subset of the most important features\n",
    "# NOTE(review): the subset uses 'mfcc_3'/'mfcc_10' while the plots above label features\n",
    "# 'mfcc3'/'mfcc10' -- confirm which spelling XGBooster expects.\n",
    "xg_data_subset = XGBooster(\n",
    "    subset=['mfcc_3', 'mfcc_10'],\n",
    "    to_save=False,\n",
    "    data_only=False,\n",
    "    splitvoice=False,\n",
    "    compressed=compressed,\n",
    "    rerecorded=False,\n",
    "    max=1000,\n",
    "    slice_size_seconds=slice_length,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
