{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import collections\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm as cm\n",
    "from IPython.display import Audio, display, clear_output, Markdown, Image\n",
    "#import librosa\n",
    "#import librosa.display\n",
    "import ipywidgets as widgets\n",
    "# \n",
    "from tacotron2.text import text_to_sequence as text_to_sequence_internal\n",
    "from tacotron2.text.symbols import symbols\n",
    "# \n",
    "import tritonhttpclient as thc\n",
    "\n",
    "defaults = {\n",
    "    # settings\n",
    "    'sigma_infer': 0.6,        # don't touch this\n",
    "    'sampling_rate': 22050,    # don't touch this\n",
    "    'stft_hop_length': 256,    # don't touch this\n",
    "    'url': 'localhost:8000',   # don't touch this\n",
    "    'autoplay': True,          # autoplay\n",
    "    'character_limit_min': 4,  # don't touch this\n",
    "    'character_limit_max': 340 # don't touch this\n",
    "}\n",
    "\n",
    "\n",
    "# create args object\n",
    "class Struct:\n",
    "    def __init__(self, **entries):\n",
    "        self.__dict__.update(entries)\n",
    "\n",
    "args = Struct(**defaults)\n",
    "\n",
    "triton_client = thc.InferenceServerClient(args.url)\n",
    "\n",
    "def display_sound(signal, title, color):\n",
    "    ''' displays signal '''\n",
    "    clear_output(wait=True)\n",
    "    plt.figure(figsize=(10, 2.5))\n",
    "    plt.title(title)\n",
    "    plt.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    # librosa.display.waveplot(signal, color=color)\n",
    "    sig = signal[0]\n",
    "    hop = args.stft_hop_length\n",
    "    smoothed = []\n",
    "    for i in range(0, len(sig), hop):\n",
    "        smoothed.append(np.average(sig[i:i+hop]))\n",
    "    plt.plot(smoothed, color=color)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def display_spectrogram(mel, title):\n",
    "    ''' displays mel spectrogram '''\n",
    "    clear_output(wait=True)\n",
    "    fig = plt.figure(figsize=(10, 2.5))\n",
    "    ax = fig.add_subplot(111)\n",
    "    plt.title(title)\n",
    "    plt.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    plt.xlabel('Time')\n",
    "    cmap = cm.get_cmap('jet', 30)\n",
    "    cax = ax.imshow(mel[0].astype(np.float32), interpolation=\"nearest\", cmap=cmap)\n",
    "    ax.grid(True)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def text_to_sequence(text):\n",
    "    ''' preprocessor of tacotron2\n",
    "        ::text:: the input str\n",
    "        ::returns:: sequence, the preprocessed text\n",
    "    '''\n",
    "    sequence = text_to_sequence_internal(text, ['english_cleaners'])\n",
    "    sequence = np.array(sequence, dtype=np.int64)\n",
    "    return sequence\n",
    "\n",
    "\n",
    "def sequence_to_mel(sequence):\n",
    "    ''' calls tacotron2\n",
    "        ::sequence:: int64 numpy array, contains the preprocessed text\n",
    "        ::returns:: (mel, mel_lengths, alignments) tuple\n",
    "                     mel is the mel-spectrogram, np.array\n",
    "                     mel_lengths contains the length of the unpadded mel, np.array\n",
    "                     alignments contains attention weigths, np.array\n",
    "    '''\n",
    "    sequence = np.reshape(sequence, (1, -1))\n",
    "    input_lengths = np.array([[len(sequence[0])]], dtype=np.int64)\n",
    "    # prepare input/output\n",
    "    inputs = []\n",
    "    inputs.append(thc.InferInput('input__0', sequence.shape, 'INT64'))\n",
    "    inputs.append(thc.InferInput('input__1', input_lengths.shape, 'INT64'))\n",
    "    inputs[0].set_data_from_numpy(sequence, binary_data=True)\n",
    "    inputs[1].set_data_from_numpy(input_lengths, binary_data=True)\n",
    "    outputs = []\n",
    "    outputs.append(thc.InferRequestedOutput('output__0', binary_data=True))\n",
    "    outputs.append(thc.InferRequestedOutput('output__1', binary_data=True))\n",
    "    outputs.append(thc.InferRequestedOutput('output__2', binary_data=True))\n",
    "    # call tacotron2\n",
    "    result = triton_client.infer(model_name=\"tacotron2-ts-script\", inputs=inputs, outputs=outputs)\n",
    "    # get results\n",
    "    mel = result.as_numpy('output__0')\n",
    "    mel_lengths = result.as_numpy('output__1')\n",
    "    alignments = result.as_numpy('output__2')\n",
    "    return mel, mel_lengths, alignments\n",
    "\n",
    "\n",
    "def mel_to_signal(mel, mel_lengths):\n",
    "    ''' calls waveglow\n",
    "        ::mel:: mel spectrogram\n",
    "        ::mel_lengths:: original length of mel spectrogram\n",
    "        ::returns:: waveform\n",
    "    '''\n",
    "    # prepare input/output\n",
    "    mel = mel[:,:,:,None]\n",
    "    stride = 256\n",
    "    n_group = 8\n",
    "    z_size =  mel.shape[2]*stride//n_group\n",
    "    shape = (1, n_group, z_size, 1)\n",
    "    z = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
    "    \n",
    "    inputs = []\n",
    "    inputs.append(thc.InferInput('mel', mel.shape, 'FP16'))\n",
    "    inputs.append(thc.InferInput('z', z.shape, 'FP16'))\n",
    "    inputs[0].set_data_from_numpy(mel, binary_data=True)\n",
    "    inputs[1].set_data_from_numpy(z, binary_data=True)\n",
    "    outputs = []\n",
    "    outputs.append(thc.InferRequestedOutput('audio', binary_data=True))\n",
    "    # call waveglow\n",
    "    result = triton_client.infer(model_name=\"waveglow-tensorrt\", inputs=inputs, outputs=outputs)\n",
    "    # get the results\n",
    "    signal = result.as_numpy('audio')\n",
    "    # postprocessing of waveglow: trimming signal to its actual size\n",
    "    trimmed_length = mel.shape[2]*args.stft_hop_length\n",
    "    signal = signal[:trimmed_length] # trim\n",
    "    signal = signal.astype(np.float32)\n",
    "    return signal\n",
    "\n",
    "\n",
    "# widgets\n",
    "def get_output_widget(width, height):\n",
    "    ''' creates an output widget with default values and returns it '''\n",
    "    layout = widgets.Layout(width=width,\n",
    "                            height=height,\n",
    "                            object_fit='fill',\n",
    "                            object_position = '{center} {center}')\n",
    "    ret = widgets.Output(layout=layout)\n",
    "    return ret\n",
    "\n",
    "\n",
    "text_area = widgets.Textarea(\n",
    "    value='type here',\n",
    "    placeholder='',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    continuous_update=True,\n",
    "    layout=widgets.Layout(width='550px', height='80px')\n",
    ")\n",
    "\n",
    "\n",
    "plot_spectrogram = get_output_widget(width='10in',height='2.1in')\n",
    "plot_signal = get_output_widget(width='10in',height='2.1in')\n",
    "plot_play = get_output_widget(width='10in',height='1in')\n",
    "\n",
    "\n",
    "def text_area_change(change):\n",
    "    ''' this gets called each time text_area.value changes '''\n",
    "    text = change['new']\n",
    "    text = text.strip(' ')\n",
    "    length = len(text)\n",
    "    if length < args.character_limit_min: # too short text\n",
    "        return\n",
    "    if length > args.character_limit_max: # too long text\n",
    "        text_area.value = text[:args.character_limit_max]\n",
    "        return\n",
    "    # preprocess tacotron2\n",
    "    sequence = text_to_sequence(text)\n",
    "    # run tacotron2\n",
    "    mel, mel_lengths, alignments = sequence_to_mel(sequence)\n",
    "    with plot_spectrogram:\n",
    "        display_spectrogram(mel, change['new'])\n",
    "    # run waveglow\n",
    "    signal = mel_to_signal(mel, mel_lengths)\n",
    "    with plot_signal:\n",
    "        display_sound(signal, change['new'], 'green')\n",
    "    with plot_play:\n",
    "        clear_output(wait=True)\n",
    "        display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))\n",
    "        # related issue: https://github.com/ipython/ipython/issues/11316\n",
    "\n",
    "\n",
    "# setup callback\n",
    "text_area.observe(text_area_change, names='value')\n",
    "\n",
    "# decorative widgets\n",
    "empty = widgets.VBox([], layout=widgets.Layout(height='1in'))\n",
    "markdown_4 = Markdown('**tacotron2 input**')\n",
    "markdown_6 = Markdown('**tacotron2 output / waveglow input**')\n",
    "markdown_7 = Markdown('**waveglow output**')\n",
    "markdown_8 = Markdown('**play**')\n",
    "\n",
    "# display widgets\n",
    "display(\n",
    "    empty, \n",
    "    markdown_4, text_area, \n",
    "    markdown_6, plot_spectrogram, \n",
    "    markdown_7, plot_signal, \n",
    "    markdown_8, plot_play, \n",
    "    empty\n",
    ")\n",
    "\n",
    "# default text\n",
    "text_area.value = \"The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
