{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only GPU index 1 to this kernel; this must be set before CUDA is initialized.\n",
    "%env CUDA_VISIBLE_DEVICES=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference process of WaveGrad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import IPython.display as ipd\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import torch\n",
    "from pypesq import pesq\n",
    "from pystoi import stoi\n",
    "from scipy.signal import resample_poly\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Make the repository root (one level up) importable for the project modules below.\n",
    "sys.path.insert(0, '..')\n",
    "\n",
    "import benchmark\n",
    "import utils\n",
    "from data import AudioDataset, MelSpectrogramFixed\n",
    "from model import WaveGrad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Load configuration**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = '../configs/default.json'\n",
    "\n",
    "with open(CONFIG_PATH) as f:\n",
    "    config = utils.ConfigWrapper(**json.load(f))\n",
    "\n",
    "# Paths stored in the config are relative to the parent directory (cf. CONFIG_PATH), so re-root them.\n",
    "for path_key in ('logdir', 'train_filelist_path', 'test_filelist_path'):\n",
    "    relative_path = getattr(config.training_config, path_key)\n",
    "    setattr(config.training_config, path_key, f'../{relative_path}')\n",
    "\n",
    "# Override theta_0 from the config file for this run.\n",
    "config.training_config.theta_0 = 0.001\n",
    "config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder: set this to the path of a trained WaveGrad checkpoint.\n",
    "ckpt_path = '.pt'\n",
    "model = WaveGrad(config).cuda()\n",
    "print(f'Number of parameters: {model.nparams}')\n",
    "\n",
    "# strict=False tolerates missing/unexpected keys; the returned key lists are shown as the cell output.\n",
    "checkpoint = torch.load(ckpt_path)\n",
    "model.load_state_dict(checkpoint['model'], strict=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize the dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = AudioDataset(config, training=False)\n",
    "\n",
    "# Mel-spectrogram front end built from the same data_config, placed on the GPU.\n",
    "data_cfg = config.data_config\n",
    "mel_fn = MelSpectrogramFixed(\n",
    "    sample_rate=data_cfg.sample_rate,\n",
    "    n_fft=data_cfg.n_fft,\n",
    "    win_length=data_cfg.win_length,\n",
    "    hop_length=data_cfg.hop_length,\n",
    "    f_min=data_cfg.f_min,\n",
    "    f_max=data_cfg.f_max,\n",
    "    n_mels=data_cfg.n_mels,\n",
    "    window_fn=torch.hann_window,\n",
    ").cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_BATCH_SIZE = 2\n",
    "\n",
    "# Sample a small batch from the test set and play back the ground-truth audio for reference.\n",
    "test_batch = dataset.sample_test_batch(TEST_BATCH_SIZE)\n",
    "\n",
    "for test_sample in test_batch:\n",
    "    # Play back at the configured sample rate rather than a hard-coded value.\n",
    "    ipd.display(ipd.Audio(test_sample.squeeze(), rate=config.data_config.sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
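    "**Grid search for the best noise schedule (optional; otherwise set the betas by hand in the next section)**\n",
    "\n",
    "Note: the lower the `step` argument, the more accurate the search.\n",
    "\n",
    "The grid search is only needed for 6-iteration sampling; the 25-, 50-, 100- and 1000-iteration schedules are loaded from `../schedules/pretrained/`."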
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PERFORM_GRID_SEARCH = True\n",
    "\n",
    "if PERFORM_GRID_SEARCH:\n",
    "    # Search a beta schedule for few-step sampling; presumably the best one is written to path_to_store_schedule.\n",
    "    n_iter = 6\n",
    "    path_to_store_schedule = f'../schedules/theta_0_0001/{n_iter}iters.pt'\n",
    "    grid_search_kwargs = dict(\n",
    "        n_iter=n_iter,\n",
    "        betas_range=(1e-06, 0.01),\n",
    "        test_batch_size=4,\n",
    "        step=2,\n",
    "        path_to_store_schedule=path_to_store_schedule,\n",
    "        save_stats_for_grid=True,\n",
    "        verbose=True,\n",
    "        n_jobs=4,\n",
    "    )\n",
    "    iters_best_schedule, stats = benchmark.iters_schedule_grid_search(\n",
    "        model, config, **grid_search_kwargs\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Set noise schedule**\n",
    "\n",
    "Note: `init_kwargs` should always contain the key `steps`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SCHEDULE_PATHS = {\n",
    "    6: '../schedules/theta_0_0001/6iters.pt',\n",
    "    25: '../schedules/pretrained/25iters.pt',\n",
    "    50: '../schedules/pretrained/50iters.pt',\n",
    "    100: '../schedules/pretrained/100iters.pt',\n",
    "    1000: '../schedules/pretrained/1000iters.pt',\n",
    "}\n",
    "\n",
    "\n",
    "def _load_schedule(**kwargs):\n",
    "    \"\"\"Load a saved beta schedule from kwargs['path'] and return it as a FloatTensor.\"\"\"\n",
    "    return torch.FloatTensor(torch.load(kwargs['path']))\n",
    "\n",
    "\n",
    "# Map each step count to its loader and the kwargs passed to it ('steps' is always required).\n",
    "SCHEDULES = {}\n",
    "for n_steps, schedule_path in SCHEDULE_PATHS.items():\n",
    "    SCHEDULES[n_steps] = {\n",
    "        'init': _load_schedule,\n",
    "        'init_kwargs': {'steps': n_steps, 'path': schedule_path},\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SCHEDULE_TYPE_TO_SET = 25\n",
    "\n",
    "# Install the chosen schedule on the model; the value must be a key of SCHEDULES.\n",
    "selected_schedule = SCHEDULES[SCHEDULE_TYPE_TO_SET]\n",
    "model.set_new_noise_schedule(\n",
    "    init=selected_schedule['init'],\n",
    "    init_kwargs=selected_schedule['init_kwargs'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Listen to samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_preds = []\n",
    "# Inference only: disable autograd so no computation graph is built during sampling.\n",
    "with torch.no_grad():\n",
    "    for test_sample in tqdm(test_batch):\n",
    "        # test_sample[None] adds a leading batch dimension of size 1.\n",
    "        mel = mel_fn(test_sample[None].cuda())\n",
    "        outputs = model.forward(\n",
    "            mel, store_intermediate_states=False\n",
    "        )\n",
    "        test_preds.append(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play back each generated waveform at the configured sample rate.\n",
    "sample_rate = config.data_config.sample_rate\n",
    "for pred in test_preds:\n",
    "    waveform = pred.squeeze().cpu()\n",
    "    ipd.display(ipd.Audio(waveform, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sampling and metrics\n",
    "\n",
    "Ground-truth and generated samples are saved as `.wav` files so that MCD can be computed with [MattShannon/mcd](https://github.com/MattShannon/mcd); PESQ and STOI are computed directly in the notebook below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = DataLoader(dataset, batch_size=1)\n",
    "sample_rate = config.data_config.sample_rate\n",
    "\n",
    "path_to_save = 'SOMEPATH'  # placeholder: directory that receives the .wav files\n",
    "path_to_save_GT = os.path.join(path_to_save, 'GroundTruth')\n",
    "path_to_save_samples = os.path.join(path_to_save, 'Sampled')\n",
    "# makedirs(exist_ok=True) creates missing parents and lets this cell be re-run.\n",
    "os.makedirs(path_to_save_GT, exist_ok=True)\n",
    "os.makedirs(path_to_save_samples, exist_ok=True)\n",
    "GT = []\n",
    "Samples = []\n",
    "\n",
    "# Inference only: disable autograd so no computation graph is built during sampling.\n",
    "with torch.no_grad():\n",
    "    for i, test_sample in enumerate(tqdm(test_loader)):\n",
    "        file_name = f'sample_{i}.wav'\n",
    "        gt = test_sample[0].cpu().numpy()\n",
    "        sf.write(os.path.join(path_to_save_GT, file_name), gt, sample_rate)\n",
    "        GT.append(gt)\n",
    "\n",
    "        mel = mel_fn(test_sample.cuda())\n",
    "        outputs = model.forward(mel, store_intermediate_states=False)\n",
    "        pred = outputs[0].cpu().numpy()\n",
    "\n",
    "        # Zero-pad (or truncate) the prediction to the ground-truth length so that the\n",
    "        # saved files and the in-memory GT/Samples pairs line up sample-for-sample.\n",
    "        aligned = np.zeros(gt.shape)\n",
    "        n = min(len(pred), len(gt))\n",
    "        aligned[:n] = pred[:n]\n",
    "        Samples.append(aligned)\n",
    "        sf.write(os.path.join(path_to_save_samples, file_name), aligned, sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PESQ (ITU-T P.862) is defined only for 8 kHz / 16 kHz audio and pypesq defaults to fs=16000,\n",
    "# so both signals are resampled to 16 kHz for PESQ. pystoi takes the native rate directly.\n",
    "PESQ_SAMPLE_RATE = 16000\n",
    "sample_rate = config.data_config.sample_rate\n",
    "n_pairs = len(GT)\n",
    "assert n_pairs > 0 and len(Samples) == n_pairs, 'run the sampling cell first'\n",
    "\n",
    "P = 0\n",
    "S = 0\n",
    "for gt, sample in tqdm(zip(GT, Samples), total=n_pairs):\n",
    "    # STOI requires equal-length signals; score only the overlapping part.\n",
    "    n = min(len(gt), len(sample))\n",
    "    gt, sample = gt[:n], sample[:n]\n",
    "    P += pesq(\n",
    "        resample_poly(gt, PESQ_SAMPLE_RATE, sample_rate),\n",
    "        resample_poly(sample, PESQ_SAMPLE_RATE, sample_rate),\n",
    "        PESQ_SAMPLE_RATE,\n",
    "    )\n",
    "    S += stoi(gt, sample, sample_rate)\n",
    "\n",
    "# Mean over the evaluated pairs (not a fixed count).\n",
    "print(\"PESQ = {}, STOI = {}\".format(P / n_pairs, S / n_pairs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
