{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AudioGen\n",
    "Welcome to AudioGen's demo Jupyter notebook. Here you will find a series of self-contained examples showing how to use AudioGen in different settings.\n",
    "\n",
    "First, we initialize AudioGen. For now, we provide only one model: `facebook/audiogen-medium`, a medium-sized 1.5B-parameter transformer decoder.\n",
    "\n",
    "**Important note:** This variant differs from the original AudioGen model presented in [\"AudioGen: Textually-guided audio generation\"](https://arxiv.org/abs/2209.15352): its architecture is similar to MusicGen, with a lower frame rate and multiple streams of tokens, which reduces generation time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import AudioGen\n",
    "\n",
    "model = AudioGen.get_pretrained('facebook/audiogen-medium')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
    "* `use_sampling` (bool, optional): use sampling if True, else do argmax (greedy) decoding. Defaults to True.\n",
    "* `top_k` (int, optional): number of most likely tokens considered when sampling. Defaults to 250.\n",
    "* `top_p` (float, optional): top_p used for sampling; when set to 0, top_k is used instead. Defaults to 0.0.\n",
    "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
    "* `duration` (float, optional): duration of the generated waveform, in seconds. Defaults to 10.0.\n",
    "* `cfg_coef` (float, optional): coefficient used for classifier-free guidance. Defaults to 3.0.\n",
    "\n",
    "Any parameter not passed to `set_generation_params` is reset to its default value."
   ]
  },
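  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for these parameters, the next cell sketches one decoding step over a batch of next-token logits, using only PyTorch and random logits. It is an illustration under stated assumptions, not AudioGen's actual implementation: the helper `sample_next_token` and the 1024-token vocabulary are hypothetical, classifier-free guidance is assumed to combine conditional and unconditional logits as `uncond + cfg_coef * (cond - uncond)` before temperature scaling, and `top_p > 0` is assumed to take precedence over `top_k`, as the list above describes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical helper illustrating one decoding step; not AudioGen's internal API.\n",
    "def sample_next_token(cond_logits, uncond_logits, cfg_coef=3.0, temperature=1.0,\n",
    "                      top_k=250, top_p=0.0, use_sampling=True):\n",
    "    # Classifier-free guidance (assumed form): move the conditional logits away from the unconditional ones.\n",
    "    logits = uncond_logits + cfg_coef * (cond_logits - uncond_logits)\n",
    "    if not use_sampling:\n",
    "        return logits.argmax(dim=-1)  # argmax (greedy) decoding\n",
    "    probs = torch.softmax(logits / temperature, dim=-1)\n",
    "    if top_p > 0.0:\n",
    "        # Nucleus sampling: keep the most likely tokens until their cumulative probability reaches top_p.\n",
    "        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)\n",
    "        mask = sorted_probs.cumsum(dim=-1) - sorted_probs > top_p\n",
    "        sorted_probs = sorted_probs.masked_fill(mask, 0.0)\n",
    "        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)\n",
    "        choice = torch.multinomial(sorted_probs, num_samples=1)\n",
    "        return sorted_idx.gather(-1, choice).squeeze(-1)\n",
    "    if top_k > 0:\n",
    "        # Top-k sampling: renormalize over the k most likely tokens only.\n",
    "        top_probs, top_idx = probs.topk(top_k, dim=-1)\n",
    "        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)\n",
    "        choice = torch.multinomial(top_probs, num_samples=1)\n",
    "        return top_idx.gather(-1, choice).squeeze(-1)\n",
    "    return torch.multinomial(probs, num_samples=1).squeeze(-1)\n",
    "\n",
    "# Usage with random logits: a batch of 2 over a hypothetical 1024-token vocabulary.\n",
    "cond, uncond = torch.randn(2, 1024), torch.randn(2, 1024)\n",
    "print(sample_next_token(cond, uncond).shape)  # torch.Size([2])"
   ]
  },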
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    use_sampling=True,\n",
    "    top_k=250,\n",
    "    duration=5\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now generate audio in one of the following modes:\n",
    "* Audio continuation using `model.generate_continuation`\n",
    "* Text-conditional samples using `model.generate`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Audio Continuation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torchaudio\n",
    "import torch\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "def get_bip_bip(bip_duration=0.125, frequency=440,\n",
    "                duration=0.5, sample_rate=16000, device=\"cuda\"):\n",
    "    \"\"\"Generates a series of bips at the given frequency.\"\"\"\n",
    "    t = torch.arange(\n",
    "        int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate\n",
    "    # Pure tone at `frequency` Hz, shape [1, T] (mono).\n",
    "    wav = torch.cos(2 * math.pi * frequency * t)[None]\n",
    "    # Square envelope: silent for the first half of each 2 * bip_duration period, on for the second half.\n",
    "    tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
    "    envelope = (tp >= 0.5).float()\n",
    "    return wav * envelope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
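    "# Here we use a synthetic signal to prompt the generated audio.\n",
    "# get_bip_bip returns a [1, T] mono waveform; expand turns it into a [2, 1, T] batch, one prompt per description.\n",
    "res = model.generate_continuation(\n",
    "    get_bip_bip(0.125).expand(2, -1, -1),\n",
    "    16000,  # sample rate of the prompt\n",
    "    ['Whistling with wind blowing',\n",
    "     'Typing on a typewriter'],\n",
    "    progress=True)\n",
    "display_audio(res, 16000)"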
   ]
  },
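  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`generate_continuation` receives a batch of prompts shaped `[batch, channels, time]`; above, the batch holds one copy of the prompt per text description. The following sketch uses only PyTorch and a silent placeholder waveform (an assumption made purely for illustration) to show how `expand` builds such a batch from a single `[1, T]` mono waveform without copying memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "sample_rate = 16000\n",
    "# Placeholder mono prompt of 0.5 s, shape [channels, time] = [1, 8000].\n",
    "wav = torch.zeros(1, int(0.5 * sample_rate))\n",
    "descriptions = ['Whistling with wind blowing', 'Typing on a typewriter']\n",
    "# Add a leading batch dimension sized to the number of descriptions: [2, 1, 8000].\n",
    "batch = wav.expand(len(descriptions), -1, -1)\n",
    "print(batch.shape)  # torch.Size([2, 1, 8000])\n",
    "# expand returns a view: every batch entry shares the same underlying storage.\n",
    "print(batch.data_ptr() == wav.data_ptr())  # True"
   ]
  },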
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can also prompt with audio from a file. Make sure to trim it: the prompt must be shorter than the generation `duration` set above.\n",
    "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/sirens_and_a_humming_engine_approach_and_pass.mp3\")\n",
    "prompt_duration = 2\n",
    "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
    "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n",
    "display_audio(output, sample_rate=16000)"
   ]
  },
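  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only trims the file. As a hedged sketch of a fuller preparation step, the next cell downmixes to mono, resamples to 16 kHz with `torchaudio.functional.resample`, and keeps the first `prompt_duration` seconds. A synthetic 44.1 kHz stereo signal stands in for the MP3 so it runs anywhere; the helper `prepare_prompt` and its duration check are hypothetical. Since `generate_continuation` is already given `prompt_sample_rate`, this manual step is optional and mainly makes the shapes involved explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchaudio.functional as F\n",
    "\n",
    "# Hypothetical helper: downmix, resample and trim a [channels, time] waveform.\n",
    "def prepare_prompt(wav, sr, target_sr=16000, prompt_duration=2.0, gen_duration=5.0):\n",
    "    # Assumption: the prompt has to be shorter than the audio to generate.\n",
    "    if prompt_duration >= gen_duration:\n",
    "        raise ValueError('prompt_duration must be shorter than the generation duration')\n",
    "    wav = wav.mean(dim=0, keepdim=True)  # stereo -> mono, shape [1, T]\n",
    "    wav = F.resample(wav, orig_freq=sr, new_freq=target_sr)\n",
    "    return wav[..., :int(prompt_duration * target_sr)]\n",
    "\n",
    "# Synthetic 3 s stereo signal at 44.1 kHz standing in for the MP3 file.\n",
    "stereo = torch.randn(2, 3 * 44100)\n",
    "prompt = prepare_prompt(stereo, 44100)\n",
    "print(prompt.shape)  # torch.Size([1, 32000])"
   ]
  },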
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "output = model.generate(\n",
    "    descriptions=[\n",
    "        'Subway train blowing its horn',\n",
    "        'A cat meowing',\n",
    "    ],\n",
    "    progress=True\n",
    ")\n",
    "display_audio(output, sample_rate=16000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
