{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sounddevice as sd\n",
    "# list the available audio devices; the index of the recording device is used in the next cell\n",
    "print(sd.query_devices())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# machine-specific: index of the audio device to use, taken from the list printed above\n",
    "AUDIO_DEVICE_INDEX = 11\n",
    "sd.default.device = AUDIO_DEVICE_INDEX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import collections\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm as cm\n",
    "from IPython.display import Audio, display, clear_output, Markdown, Image\n",
    "import librosa\n",
    "import librosa.display\n",
    "import ipywidgets as widgets\n",
    "# \n",
    "# import tacotron2 preprocessing utilities\n",
    "from utils.tacotron2.symbols import symbols\n",
    "from utils.tacotron2 import text_to_sequence as text_to_sequence_internal\n",
    "# import bert pre- and postprocessing utilities\n",
    "from utils.bert.preprocessing import convert_example_to_feature, read_squad_example, get_predictions\n",
    "from utils.bert.tokenization import BertTokenizer\n",
    "# import jasper pre- and postprocessing utilities\n",
    "from utils.jasper.speech_utils import AudioSegment, SpeechClient\n",
    "# import trtis api\n",
    "from tensorrtserver.api import *\n",
    "\n",
    "\n",
    "# runtime settings, exposed as attributes of the args object created below\n",
    "# NOTE(review): 'sigma_infer' is not read anywhere in this notebook\n",
    "defaults = {\n",
    "    # settings\n",
    "    'sigma_infer': 0.6,                         # don't touch this\n",
    "    'sampling_rate': 22050,                     # don't touch this\n",
    "    'stft_hop_length': 256,                     # don't touch this\n",
    "    'url': 'localhost:8000',                    # don't touch this\n",
    "    'protocol': 0,                              # 0: http, 1: grpc \n",
    "    'autoplay': True,                           # autoplay\n",
    "    'character_limit_min': 4,                   # don't touch this\n",
    "    'character_limit_max': 124,                 # don't touch this\n",
    "    'vocab_file': \"./utils/bert/vocab.txt\",     # don't touch this\n",
    "    'do_lower_case': True,                      # don't touch this\n",
    "    'version_2_with_negative': False,           # if true, the model may give 'i don't know' as an answer. the model has to be trained for it. \n",
    "    'max_seq_length': 384,                      # the maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded. \n",
    "    'doc_stride': 128,                          # when splitting up a long document into chunks, how much stride to take between chunks\n",
    "    'max_query_length': 64,                     # the maximum number of tokens in the question. Questions longer than this will be truncated to this length\n",
    "    'n_best_size': 10,                          # don't touch this\n",
    "    'max_answer_length': 30,                    # don't touch this\n",
    "    'null_score_diff_threshold': 0.0,           # don't touch this\n",
    "    'jasper_batch_size': 1,                     # don't touch this\n",
    "    'jasper_sampling_rate': 44100,              # don't touch this\n",
    "    'record_maximum_seconds': 4.0               # maximum number of seconds to record\n",
    "}\n",
    "\n",
    "\n",
    "# attribute container used to build the args object below\n",
    "class Struct:\n",
    "    ''' exposes every keyword argument as an instance attribute '''\n",
    "    def __init__(self, **entries):\n",
    "        for key, value in entries.items():\n",
    "            setattr(self, key, value)\n",
    "\n",
    "\n",
    "# settings object: every key of defaults becomes an attribute, e.g. args.url\n",
    "args = Struct(**defaults)\n",
    "\n",
    "\n",
    "# create the inference context for the models\n",
    "# each client targets a model served at args.url over args.protocol; the trailing -1 is\n",
    "# presumably the model version (latest) - confirm against the tensorrtserver client API\n",
    "infer_ctx_bert = InferContext(args.url, args.protocol, 'bertQA-ts-script', -1)\n",
    "infer_ctx_tacotron2 = InferContext(args.url, args.protocol, 'tacotron2', -1)\n",
    "infer_ctx_waveglow = InferContext(args.url, args.protocol, 'waveglow-trt', -1)\n",
    "# speech recognition client used by the RECORD button\n",
    "infer_jasper = SpeechClient(args.url, args.protocol, 'jasper-trt-ensemble', -1, \n",
    "                            args.jasper_batch_size, 'pyt', verbose=False, \n",
    "                            mode='asynchronous', from_features=False)\n",
    "\n",
    "\n",
    "def display_sequences(sequences, labels, colors):\n",
    "    ''' plots several sequences as dotted/colored lines on one axis without ticks\n",
    "        ::sequences:: iterable of 1-D sequences to plot\n",
    "        ::labels:: legend label for each sequence\n",
    "        ::colors:: matplotlib format string for each sequence (e.g. 'r.' or 'black')\n",
    "    '''\n",
    "    fig, ax = plt.subplots(figsize=(10, 2.5))\n",
    "    # hide ticks and tick labels on every side\n",
    "    ax.tick_params(axis='both', which='both',\n",
    "                   bottom=False, top=False, left=False, right=False,\n",
    "                   labelbottom=False, labelleft=False)\n",
    "    for seq, fmt, name in zip(sequences, colors, labels):\n",
    "        ax.plot(seq, fmt, label=name)\n",
    "    ax.legend(loc='upper right')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def display_heatmap(sequence, title='preprocessed text'):\n",
    "    ''' clears the current output and shows a sequence as a one-row heatmap\n",
    "        ::sequence:: numpy array, e.g. the output of text_to_sequence\n",
    "        ::title:: plot title\n",
    "    '''\n",
    "    clear_output(wait=True)\n",
    "    row = sequence[None, :]  # add a leading axis so imshow draws a single row\n",
    "    fig, ax = plt.subplots(figsize=(10, 2.5))\n",
    "    ax.set_title(title)\n",
    "    ax.tick_params(axis='both', which='both',\n",
    "                   bottom=False, top=False, left=False, right=False,\n",
    "                   labelbottom=False, labelleft=False)\n",
    "    ax.imshow(row, cmap='BrBG_r', interpolation='nearest')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def display_sound(signal, title, color, sr=22050):\n",
    "    ''' clears the current output and plots a waveform\n",
    "        ::signal:: audio samples, a 1-D float array\n",
    "        ::title:: plot title\n",
    "        ::color:: matplotlib color of the waveform\n",
    "        ::sr:: sampling rate of signal, only used for the time axis; 22050 is librosa's own\n",
    "               default, so calls without sr keep their previous axis. Pass the real rate for\n",
    "               audio recorded at another rate (e.g. args.jasper_sampling_rate)\n",
    "    '''\n",
    "    clear_output(wait=True)\n",
    "    plt.figure(figsize=(10, 2.5))\n",
    "    plt.title(title)\n",
    "    plt.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    # librosa 0.10 removed waveplot; prefer its replacement waveshow when it exists\n",
    "    plot_waveform = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot\n",
    "    plot_waveform(signal, sr=sr, color=color)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def display_spectrogram(mel, title):\n",
    "    ''' clears the current output and shows a mel spectrogram as an image\n",
    "        ::mel:: 2-D array, presumably (mel bins, frames) since the x axis is labeled 'Time' - confirm\n",
    "        ::title:: not drawn: the only caller passes the whole response text, presumably too long\n",
    "                  for a plot title; kept so the signature stays unchanged\n",
    "    '''\n",
    "    clear_output(wait=True)\n",
    "    fig = plt.figure(figsize=(10, 2.5))\n",
    "    ax = fig.add_subplot(111)\n",
    "    plt.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    plt.xlabel('Time')\n",
    "    # pyplot.get_cmap works on all matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9\n",
    "    cmap = plt.get_cmap('jet', 30)\n",
    "    # origin='lower' draws row 0 (presumably the lowest mel bin) at the bottom, as spectrograms usually are\n",
    "    ax.imshow(mel.astype(np.float32), interpolation='nearest', cmap=cmap, origin='lower')\n",
    "    ax.grid(True)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def text_to_sequence(text):\n",
    "    ''' tacotron2 preprocessing: runs the text through the tacotron2 utility with the\n",
    "        'english_cleaners' cleaner list\n",
    "        ::text:: the input str\n",
    "        ::returns:: int64 numpy array with the preprocessed text\n",
    "    '''\n",
    "    symbol_ids = text_to_sequence_internal(text, ['english_cleaners'])\n",
    "    return np.array(symbol_ids, dtype=np.int64)\n",
    "\n",
    "\n",
    "def sequence_to_mel(sequence):\n",
    "    ''' calls tacotron2 on one preprocessed text\n",
    "        ::sequence:: int64 numpy array, contains the preprocessed text\n",
    "        ::returns:: (mel, mel_lengths) pair for the first (only) item of the output batch\n",
    "                     mel is the mel-spectrogram, np.array\n",
    "                     mel_lengths contains the length of the unpadded mel, np.array\n",
    "    '''\n",
    "    lengths = np.array([len(sequence)], dtype=np.int64)\n",
    "    inputs = {\n",
    "        'sequence__0': (sequence,),\n",
    "        'input_lengths__1': (lengths,),\n",
    "    }\n",
    "    outputs = {\n",
    "        'mel_outputs_postnet__0': InferContext.ResultFormat.RAW,\n",
    "        'mel_lengths__1': InferContext.ResultFormat.RAW,\n",
    "    }\n",
    "    # single-item batch\n",
    "    result = infer_ctx_tacotron2.run(inputs, outputs, 1)\n",
    "    return result['mel_outputs_postnet__0'][0], result['mel_lengths__1'][0]\n",
    "\n",
    "\n",
    "def mel_to_signal(mel, mel_lengths):\n",
    "    ''' calls waveglow\n",
    "        ::mel:: mel spectrogram of one utterance (as returned by sequence_to_mel)\n",
    "        ::mel_lengths:: unpadded mel length as returned by sequence_to_mel; element [0] is used\n",
    "        ::returns:: float32 waveform trimmed to mel_lengths[0] * args.stft_hop_length samples\n",
    "    '''\n",
    "    mel = np.expand_dims(mel, axis=0)  # add the batch dimension\n",
    "    # audio samples per mel frame; shared by the noise size and the final trim\n",
    "    hop_length = args.stft_hop_length\n",
    "    n_group = 8  # channel dimension of the noise input z expected by the served model\n",
    "    z_size = mel.shape[2] * hop_length // n_group\n",
    "    # NOTE(review): noise is drawn with std 1.0 although args.sigma_infer is 0.6; presumably\n",
    "    # sigma is applied inside the exported model - confirm\n",
    "    z = np.random.normal(0.0, 1.0, (1, n_group, z_size)).astype(mel.dtype)\n",
    "    input_dict = {'mel': (mel,), 'z': (z,)}\n",
    "    output_dict = {'audio': InferContext.ResultFormat.RAW}\n",
    "    # call waveglow\n",
    "    result = infer_ctx_waveglow.run(input_dict, output_dict)\n",
    "    signal = result['audio'][0] # take only the first instance in the output batch\n",
    "    # postprocessing of waveglow: trimming signal to its actual size\n",
    "    trimmed_length = mel_lengths[0] * hop_length\n",
    "    return signal[:trimmed_length].astype(np.float32)\n",
    "\n",
    "\n",
    "# BertTokenizer instances keyed by (vocab_file, do_lower_case), so the vocab file is read once\n",
    "_bert_tokenizers = {}\n",
    "\n",
    "\n",
    "def _get_bert_tokenizer():\n",
    "    ''' returns a BertTokenizer for the current args, building it only on first use\n",
    "        (assumes the tokenizer keeps no per-call state - confirm) '''\n",
    "    key = (args.vocab_file, args.do_lower_case)\n",
    "    if key not in _bert_tokenizers:\n",
    "        _bert_tokenizers[key] = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large\n",
    "    return _bert_tokenizers[key]\n",
    "\n",
    "\n",
    "def question_and_context_to_feature(question_text, context):\n",
    "    ''' bert preprocessing of one question/context pair\n",
    "        ::question_text:: the question str\n",
    "        ::context:: the context str the answer is searched in\n",
    "        ::returns:: (example, feature) from read_squad_example / convert_example_to_feature\n",
    "    '''\n",
    "    tokenizer = _get_bert_tokenizer()\n",
    "    example = read_squad_example(question_text, \n",
    "                                 context, \n",
    "                                 version_2_with_negative=args.version_2_with_negative)\n",
    "    feature = convert_example_to_feature(\n",
    "        example=example, \n",
    "        tokenizer=tokenizer, \n",
    "        max_seq_length=args.max_seq_length, \n",
    "        doc_stride=args.doc_stride, \n",
    "        max_query_length=args.max_query_length)\n",
    "    return example, feature\n",
    "\n",
    "\n",
    "def button_rec_clicked(change):\n",
    "    ''' records record_seconds.value seconds with sounddevice, transcribes them with jasper\n",
    "        and writes the transcript followed by '? ' into question_text\n",
    "        ::change:: the argument ipywidgets passes to on_click callbacks (unused)\n",
    "    '''\n",
    "    duration = record_seconds.value\n",
    "    if duration <= 0.0:\n",
    "        return\n",
    "    try:\n",
    "        with plot_jasper_audio:\n",
    "            clear_output(wait=True)\n",
    "            recording = sd.rec(int(duration*args.jasper_sampling_rate), samplerate=args.jasper_sampling_rate, channels=1)\n",
    "            # count the slider down as a visual timer while recording\n",
    "            while record_seconds.value > 0:\n",
    "                time.sleep(0.01)\n",
    "                record_seconds.value -= 0.01\n",
    "            sd.wait()  # block until the recording is complete\n",
    "            recording = recording.squeeze()\n",
    "            display_sound(recording,'recorded audio','orange')\n",
    "            audio = AudioSegment(recording, args.jasper_sampling_rate).samples\n",
    "    finally:\n",
    "        # restore the requested duration so RECORD works again without re-dragging the slider\n",
    "        record_seconds.value = duration\n",
    "    hypotheses = infer_jasper.recognize([audio], ['audio recording'])\n",
    "    question_text.value = str(hypotheses[0]) + '? '\n",
    "\n",
    "\n",
    "button_rec = widgets.Button(description=\"RECORD\")\n",
    "button_rec.on_click(button_rec_clicked)\n",
    "# recording length; button_rec_clicked also uses it as a countdown\n",
    "record_seconds = widgets.FloatSlider(\n",
    "    min=0.0,\n",
    "    max=args.record_maximum_seconds,\n",
    "    value=args.record_maximum_seconds,\n",
    "    step=0.1,\n",
    "    continuous_update=True,\n",
    "    description=\"seconds\",\n",
    ")\n",
    "buttons = widgets.HBox([button_rec, record_seconds])\n",
    "\n",
    "\n",
    "def _make_textarea(value, height):\n",
    "    ''' editable 550px-wide textarea with the given initial text and CSS height '''\n",
    "    return widgets.Textarea(\n",
    "        value=value,\n",
    "        placeholder='',\n",
    "        description='',\n",
    "        disabled=False,\n",
    "        continuous_update=True,\n",
    "        layout=widgets.Layout(width='550px', height=height),\n",
    "    )\n",
    "\n",
    "\n",
    "# jasper writes its transcript into question_text; bert reads question_text and context\n",
    "question_text = _make_textarea('jasper output / bert input question', '40px')\n",
    "context = _make_textarea('bert input context', '80px')\n",
    "question_context = widgets.HBox([question_text, context])\n",
    "\n",
    "# bert's answer lands here and is then spoken by tacotron2 + waveglow\n",
    "response_text = _make_textarea('', '40px')\n",
    "\n",
    "\n",
    "def text_to_logits(input_ids_data, segment_ids_data, input_mask_data):\n",
    "    ''' calls bert on one preprocessed feature\n",
    "        ::input_ids_data:: token ids (numpy array, sent as int64)\n",
    "        ::segment_ids_data:: segment ids (numpy array, sent as int64)\n",
    "        ::input_mask_data:: input mask (numpy array, sent as int64)\n",
    "        ::returns:: (start_logits, end_logits), each a list of python floats\n",
    "    '''\n",
    "    request = {\n",
    "        'input__0': (input_ids_data.astype(np.int64),),\n",
    "        'input__1': (segment_ids_data.astype(np.int64),),\n",
    "        'input__2': (input_mask_data.astype(np.int64),),\n",
    "    }\n",
    "    raw = InferContext.ResultFormat.RAW\n",
    "    response = infer_ctx_bert.run(request, {'output__0': raw, 'output__1': raw}, 1)\n",
    "    # first (only) item of each output batch, flattened into python floats\n",
    "    start_logits = list(map(float, response['output__0'][0].flat))\n",
    "    end_logits = list(map(float, response['output__1'][0].flat))\n",
    "    return start_logits, end_logits\n",
    "\n",
    "\n",
    "def question_text_change(change):\n",
    "    ''' answers the question in question_text from the text in context and writes the\n",
    "        answer into response_text; does nothing until the question is at least\n",
    "        args.character_limit_min characters long and ends with '?'\n",
    "        ::change:: traitlets change dict; change['new'] is the new question text\n",
    "    '''\n",
    "    # strip all surrounding whitespace (not only spaces) so a newline typed into the\n",
    "    # textarea does not hide the final '?'\n",
    "    text = change['new'].strip()\n",
    "    if len(text) < args.character_limit_min: # too short text\n",
    "        return\n",
    "    if text[-1] != '?':\n",
    "        return\n",
    "    # preprocess bert\n",
    "    example, feature = question_and_context_to_feature(text, context.value)\n",
    "    input_ids_data = np.array(feature.input_ids, dtype=np.int64)\n",
    "    input_mask_data = np.array(feature.input_mask, dtype=np.int64)\n",
    "    segment_ids_data = np.array(feature.segment_ids, dtype=np.int64)\n",
    "    # plotted length: step back from the end in chunks of 20 while the position 20 earlier\n",
    "    # still has segment id 0; presumably trims trailing padding from the plots - confirm\n",
    "    L = segment_ids_data.shape[0] - 1\n",
    "    while L > 20 and segment_ids_data[L-20] == 0:\n",
    "        L -= 20\n",
    "    with plot_tensor:\n",
    "        clear_output(wait=True)\n",
    "        C = input_ids_data.max()\n",
    "        # scale mask and segment ids so they are visible next to the token ids\n",
    "        sequences = (input_ids_data[:L],C//2*input_mask_data[:L],C*segment_ids_data[:L])\n",
    "        display_sequences(sequences, ('input','mask','segment'), ('r.','b.','g.'))\n",
    "        \n",
    "    # call bert\n",
    "    start_logits, end_logits = text_to_logits(input_ids_data, segment_ids_data, input_mask_data)\n",
    "    with plot_logits:\n",
    "        clear_output(wait=True)\n",
    "        start = np.array(start_logits, dtype=np.float32)\n",
    "        end = np.array(end_logits, dtype=np.float32)\n",
    "        sequences = (start[:L], end[:L])\n",
    "        display_sequences(sequences, ('start_logits', 'end_logits'), ('black', 'violet'))\n",
    "    # postprocess bert\n",
    "    prediction = get_predictions(example, feature, start_logits, end_logits, \n",
    "                                 args.n_best_size, args.max_answer_length, args.do_lower_case, \n",
    "                                 args.version_2_with_negative, args.null_score_diff_threshold)\n",
    "    response_text.value = prediction[0][\"text\"] + '. \\n'\n",
    "\n",
    "\n",
    "def context_change(change):\n",
    "    ''' re-answers the current question whenever the context text changes\n",
    "        ::change:: traitlets change dict; change['new'] is the new context text\n",
    "    '''\n",
    "    text = change['new']\n",
    "    if len(text) < args.character_limit_min: # too short text\n",
    "        return\n",
    "    # inference: call the question handler directly with the current question; appending a\n",
    "    # space to question_text to fire its observer would lengthen the question on every edit\n",
    "    question_text_change({'new': question_text.value})\n",
    "\n",
    "def response_text_change(change):\n",
    "    ''' speaks the text in response_text: tacotron2 then waveglow, plots each stage and\n",
    "        plays the audio; called each time response_text.value changes\n",
    "        ::change:: traitlets change dict; change['new'] is the new response text\n",
    "    '''\n",
    "    text = change['new']\n",
    "    text = text.strip(' ')\n",
    "    length = len(text)\n",
    "    if length < args.character_limit_min: # too short text\n",
    "        return\n",
    "    if length > args.character_limit_max: # too long text\n",
    "        # writing the truncated text back re-fires this observer, which then synthesizes it\n",
    "        response_text.value = text[:args.character_limit_max]\n",
    "        return\n",
    "    # preprocess tacotron2\n",
    "    sequence = text_to_sequence(text)\n",
    "    with plot_response_text_preprocessed:\n",
    "        display_heatmap(sequence)\n",
    "    # run tacotron2\n",
    "    mel, mel_lengths = sequence_to_mel(sequence)\n",
    "    with plot_spectrogram:\n",
    "        display_spectrogram(mel, change['new'])\n",
    "    # run waveglow\n",
    "    signal = mel_to_signal(mel, mel_lengths)\n",
    "    with plot_signal:\n",
    "        display_sound(signal, change['new'], 'green')\n",
    "    with plot_play:\n",
    "        clear_output(wait=True)\n",
    "        display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))\n",
    "\n",
    "def get_output_widget(width, height, object_fit='fill', object_position='center center'):\n",
    "    ''' creates an output widget with default values and returns it\n",
    "        ::width:: CSS width, e.g. '5in'\n",
    "        ::height:: CSS height, e.g. '1.75in'\n",
    "        ::object_fit:: CSS object-fit value\n",
    "        ::object_position:: CSS object-position value; must be valid CSS (e.g. 'center center'),\n",
    "                            otherwise the browser ignores it\n",
    "        ::returns:: widgets.Output with that layout\n",
    "    '''\n",
    "    layout = widgets.Layout(width=width,\n",
    "                            height=height,\n",
    "                            object_fit=object_fit,\n",
    "                            object_position=object_position)\n",
    "    return widgets.Output(layout=layout)\n",
    "\n",
    "\n",
    "# output areas the callbacks above draw into\n",
    "plot_tensor = get_output_widget(width='5in',height='1.75in')\n",
    "plot_logits = get_output_widget(width='5in',height='1.75in')\n",
    "plot_response_text_preprocessed = get_output_widget(width='10in',height='1in')\n",
    "plot_spectrogram = get_output_widget(width='10in',height='2.0in', object_fit='scale-down')\n",
    "plot_jasper_audio = get_output_widget(width='10in',height='2.0in')\n",
    "plot_signal = get_output_widget(width='10in',height='2.0in')\n",
    "plot_play = get_output_widget(width='4in',height='1in')\n",
    "\n",
    "# vertical spacer and section headings\n",
    "empty = widgets.VBox([], layout=widgets.Layout(height='1in'))\n",
    "markdown_z0 = Markdown('**Jasper input**')\n",
    "markdown_m0 = Markdown('**Jasper output / BERT input**')\n",
    "markdown_bert = Markdown('**BERT**')\n",
    "markdown_tacotron2 = Markdown('**Tacotron 2**')\n",
    "markdown_3 = Markdown('**WaveGlow**')\n",
    "\n",
    "bert_widgets = widgets.HBox([plot_tensor, plot_logits])\n",
    "tacotron2_widgets = widgets.HBox([response_text, plot_spectrogram])\n",
    "\n",
    "# lay out the interface top to bottom\n",
    "# NOTE(review): plot_jasper_audio, plot_response_text_preprocessed and plot_signal are drawn\n",
    "# into by the callbacks but are not part of this display() call, so those plots stay hidden\n",
    "display(\n",
    "    empty, \n",
    "    markdown_z0, \n",
    "    buttons, \n",
    "    markdown_m0, question_context,\n",
    "    markdown_bert,\n",
    "    bert_widgets,\n",
    "    markdown_tacotron2,\n",
    "    tacotron2_widgets,\n",
    "    markdown_3, \n",
    "    plot_play, \n",
    "    empty\n",
    ")\n",
    "\n",
    "\n",
    "def fill_initial_values():\n",
    "    ''' draws an empty placeholder waveform, loads the demo context and clears the question '''\n",
    "    placeholder = np.zeros(100)\n",
    "    with plot_jasper_audio:\n",
    "        display_sound(placeholder, \"input audio\", 'orange')\n",
    "    # alternative demo context:\n",
    "    # context.value = \"William Shakespeare was an English poet, playwright and actor, widely regarded as the greatest writer in the English language and the world's greatest dramatist. He is often called England's national poet and the \\\"Bard of Avon\\\".\"\n",
    "    context.value = \"The man holding the telescope went into a shop to purchase some flowers on the occasion of all saints day. \"\n",
    "    question_text.value = \"\"\n",
    "\n",
    "\n",
    "fill_initial_values()\n",
    "\n",
    "# observers are attached only after the initial values above are set,\n",
    "# so filling them in does not trigger any inference\n",
    "response_text.observe(response_text_change, names='value')\n",
    "question_text.observe(question_text_change, names='value')\n",
    "context.observe(context_change, names='value')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
