{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0884259",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically re-import the local helper modules (sound_tools, utils, model) whenever\n",
    "# their source files change on disk, so edits take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e88d8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import librosa\n",
    "import librosa.display\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Local project modules\n",
    "import sound_tools\n",
    "import utils\n",
    "\n",
    "# Plot styling: take two accent colors from the active style's color cycle.\n",
    "%matplotlib inline\n",
    "plt.style.use('Solarize_Light2')\n",
    "prop_cycle = plt.rcParams['axes.prop_cycle']\n",
    "colors = prop_cycle.by_key()['color']\n",
    "blue, red = colors[1], colors[5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb65557",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths definitions\n",
    "\n",
    "# Point DATA at the slider audio folder of the MIMII dataset (download 0_dB_slider first),\n",
    "# e.g. '/.../0_dB_slider/slider'.\n",
    "DATA           = 'path_to_slider_audio_files'\n",
    "PROCESSED_DATA = os.path.join(DATA, 'processed')\n",
    "\n",
    "# Fail fast with a clear message if DATA still holds the placeholder or points nowhere,\n",
    "# instead of hitting an obscure error several cells later.\n",
    "assert os.path.isdir(DATA), f'DATA directory not found: {DATA!r} - set it to your 0_dB_slider/slider folder'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2577513",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-extraction parameters, shared by training-set generation and test-set evaluation.\n",
    "n_mels     = 64     # number of mel bands (librosa-style n_mels)\n",
    "frames     = 5      # passed as `frames` to the sound_tools feature helpers\n",
    "n_fft      = 1024   # FFT window length, in samples\n",
    "hop_length = 512    # hop between successive STFT frames, in samples\n",
    "power      = 2.0    # NOTE(review): not passed to any sound_tools call in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc1fd486",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the list of normal and abnormal files:\n",
    "normal_files, abnormal_files = utils.build_files_list(root_dir=DATA)\n",
    "\n",
    "# Keep only files for one machine. For machines 02, 04, 06 set MACHINE_ID to 'id_02', 'id_04', 'id_06'.\n",
    "MACHINE_ID = 'id_00'\n",
    "normal_files = [path for path in normal_files if MACHINE_ID in path]\n",
    "abnormal_files = [path for path in abnormal_files if MACHINE_ID in path]\n",
    "# A stratified split needs samples of both classes; catch an empty filter (wrong DATA/MACHINE_ID) early.\n",
    "assert normal_files and abnormal_files, f'no normal/abnormal files found for {MACHINE_ID} under {DATA!r}'\n",
    "\n",
    "# Concatenate them to obtain features and label datasets that we can split (label 0 = normal, 1 = abnormal):\n",
    "X = np.concatenate((normal_files, abnormal_files), axis=0)\n",
    "y = np.concatenate((np.zeros(len(normal_files)), np.ones(len(abnormal_files))), axis=0)\n",
    "\n",
    "train_files, test_files, train_labels, test_labels = train_test_split(\n",
    "    X, y,\n",
    "    train_size=0.8,\n",
    "    random_state=42,\n",
    "    shuffle=True,\n",
    "    stratify=y,\n",
    ")\n",
    "\n",
    "# Persist the split so the next notebook can reuse exactly the same train/test partition.\n",
    "# The 'processed' folder is not guaranteed to exist yet, so create it before writing.\n",
    "dataset = {\n",
    "    'train_files': train_files,\n",
    "    'test_files': test_files,\n",
    "    'train_labels': train_labels,\n",
    "    'test_labels': test_labels,\n",
    "}\n",
    "os.makedirs(PROCESSED_DATA, exist_ok=True)\n",
    "for key, values in dataset.items():\n",
    "    with open(os.path.join(PROCESSED_DATA, key + '.txt'), 'w') as out:\n",
    "        out.writelines(f'{item}\\n' for item in values)\n",
    "\n",
    "# Keep only the normal signals from the train split to train the autoencoder.\n",
    "# Filtering on the split labels avoids an O(n*m) list-membership scan over abnormal_files.\n",
    "train_files = [path for path, label in zip(train_files, train_labels) if label == 0]\n",
    "train_labels = np.zeros(len(train_files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf12a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the training features once and cache them on disk as a pickle.\n",
    "# For machines 02, 04, 06 replace 00 in the file name with 02, 04, 06 respectively.\n",
    "# NOTE: the cache is keyed only on file existence - delete the .pkl after changing the split or feature params.\n",
    "train_data_location = os.path.join(DATA, 'train_data_id_00_autoencoder.pkl')\n",
    "\n",
    "if not os.path.exists(train_data_location):\n",
    "    train_data = sound_tools.generate_dataset(\n",
    "        train_files,\n",
    "        n_mels=n_mels,\n",
    "        frames=frames,\n",
    "        n_fft=n_fft,\n",
    "        hop_length=hop_length,\n",
    "    )\n",
    "    print('Saving training data to disk...')\n",
    "    with open(train_data_location, 'wb') as pkl_file:\n",
    "        pickle.dump(train_data, pkl_file)\n",
    "    print('Done.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7b36451",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the autoencoder on the cached training features.\n",
    "import model\n",
    "\n",
    "# Training params\n",
    "lr = 0.001\n",
    "batch_size = 128\n",
    "epochs = 30\n",
    "gpu_count = 1\n",
    "\n",
    "# Derive the training-file location from the cell that generated it, so the machine id\n",
    "# is changed in one place only (train_data_location = DATA/train_data_id_XX_autoencoder.pkl).\n",
    "training_dir = os.path.dirname(train_data_location)\n",
    "trainingfilename = os.path.basename(train_data_location)\n",
    "model_dir = training_dir\n",
    "\n",
    "# User-defined subdir label to store the trained model; replace 0 by 2, 4, 6 for machines 02, 04, 06.\n",
    "model_data_store_id = '0'\n",
    "\n",
    "model.train(\n",
    "    training_dir=training_dir,\n",
    "    model_dir=model_dir,\n",
    "    n_mels=n_mels,\n",
    "    frame=frames,\n",
    "    lr=lr,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    gpu_count=gpu_count,\n",
    "    trainingfilename=trainingfilename,\n",
    "    model_data_store_id=model_data_store_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb137f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained autoencoder on the test set: one mean reconstruction error per file.\n",
    "\n",
    "# Trained only on machine id_00; replace the final '0' by 2, 4, 6 for machines 02, 04, 06.\n",
    "# NOTE(review): assumes model.train saved the model under <DATA>/model/<model_data_store_id>,\n",
    "# matching the original example path '.../0_dB_slider/slider/model/0' - confirm against model.py.\n",
    "trained_model = os.path.join(DATA, 'model', '0')\n",
    "model_trained = tf.keras.models.load_model(trained_model)\n",
    "\n",
    "y_true = test_labels\n",
    "reconstruction_errors = []\n",
    "\n",
    "for eval_filename in tqdm(test_files):\n",
    "    signal, sr = sound_tools.load_sound_file(eval_filename)\n",
    "    eval_features = sound_tools.extract_signal_features(\n",
    "        signal,\n",
    "        sr,\n",
    "        n_mels=n_mels,\n",
    "        frames=frames,\n",
    "        n_fft=n_fft,\n",
    "        hop_length=hop_length,\n",
    "    )\n",
    "    # verbose=0: avoid printing a Keras progress bar for every single file.\n",
    "    prediction = model_trained.predict(eval_features, verbose=0)\n",
    "    # Mean squared reconstruction error over all feature values of this file\n",
    "    # (equal to the original mean-over-axis-1-then-mean for uniformly shaped features).\n",
    "    reconstruction_errors.append(np.mean(np.square(eval_features - prediction)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25aaa84a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set reconstruction error analysis: compare the error distributions of normal and abnormal files.\n",
    "errors = np.asarray(reconstruction_errors)\n",
    "bin_width = 0.25\n",
    "bins = np.arange(errors.min(), errors.max() + bin_width, bin_width)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "ax.hist(errors[y_true == 0], bins=bins, color=blue, alpha=0.6, label='Normal signals', edgecolor='#FFFFFF')\n",
    "ax.hist(errors[y_true == 1], bins=bins, color=red, alpha=0.6, label='Abnormal signals', edgecolor='#FFFFFF')\n",
    "ax.set_xlabel('Testing reconstruction error')\n",
    "ax.set_ylabel('# Samples')\n",
    "ax.set_title('Reconstruction error distribution on the testing set', fontsize=16)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
