{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "6uNrFWq5BRba"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "# Copyright 2018 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "7tB9m_fw9Xkl"
      },
      "outputs": [],
      "source": [
        "!pip install -qq tensorflow\n",
        "!pip install -qq tensor2tensor\n",
        "!pip install -qq pydub\n",
        "!apt-get -qq update\n",
        "!apt-get -qq install -y ffmpeg\n",
        "!apt-get -qq install -y sox"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "hF_ZmvGjEyJd"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import os\n",
        "import collections\n",
        "import base64\n",
        "import cStringIO\n",
        "import pydub\n",
        "import shutil\n",
        "from scipy.io import wavfile\n",
        "\n",
        "import IPython\n",
        "import google.colab\n",
        "\n",
        "from tensor2tensor import models\n",
        "from tensor2tensor import problems\n",
        "from tensor2tensor.layers import common_layers\n",
        "from tensor2tensor.utils import trainer_lib\n",
        "from tensor2tensor.utils import t2t_model\n",
        "from tensor2tensor.utils import registry\n",
        "from tensor2tensor.utils import metrics\n",
        "\n",
        "# Enable TF Eager execution\n",
        "tfe = tf.contrib.eager\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "# Other setup\n",
        "Modes = tf.estimator.ModeKeys\n",
        "\n",
        "# Setup some directories\n",
        "data_dir = os.path.expanduser(\"~/t2t/data\")\n",
        "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n",
        "train_dir = os.path.expanduser(\"~/t2t/train\")\n",
        "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n",
        "tf.gfile.MakeDirs(data_dir)\n",
        "tf.gfile.MakeDirs(tmp_dir)\n",
        "tf.gfile.MakeDirs(train_dir)\n",
        "tf.gfile.MakeDirs(checkpoint_dir)\n",
        "\n",
        "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LwPvdJJ4xN6y"
      },
      "source": [
        "\n",
        "### Define problem, hparams, model, encoder and decoder\n",
        "Definition of this model (as well as many more) can be found on tensor2tensor github [page](https://github.com/tensorflow/tensor2tensor)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "hH0FEHhDIGjM"
      },
      "outputs": [],
      "source": [
        "problem_name = \"librispeech_clean\"\n",
        "asr_problem = problems.problem(problem_name)\n",
        "encoders = asr_problem.feature_encoders(None)\n",
        "\n",
        "model_name = \"transformer\"\n",
        "hparams_set = \"transformer_librispeech_tpu\"\n",
        "\n",
        "hparams = trainer_lib.create_hparams(hparams_set,data_dir=data_dir, problem_name=problem_name)\n",
        "asr_model = registry.model(model_name)(hparams, Modes.PREDICT)\n",
        "\n",
        "def encode(x):\n",
        "  waveforms = encoders[\"waveforms\"].encode(x)\n",
        "  encoded_dict = asr_problem.preprocess_example({\"waveforms\":waveforms, \"targets\":[]}, Modes.PREDICT, hparams)\n",
        "  \n",
        "  return {\"inputs\" : tf.expand_dims(encoded_dict[\"inputs\"], 0), \"targets\" : tf.expand_dims(encoded_dict[\"targets\"], 0)}\n",
        "\n",
        "def decode(integers):\n",
        "  integers = list(np.squeeze(integers))\n",
        "  if 1 in integers:\n",
        "    integers = integers[:integers.index(1)]\n",
        "  return encoders[\"targets\"].decode(np.squeeze(integers))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pGhUGptixYBd"
      },
      "source": [
        "### Define path to checkpoint\n",
        "In this demo we are using a pretrained model.\n",
        "Instructions for training your own model can be found in the [tutorial](https://github.com/tensorflow/tensor2tensor/blob/master/docs/tutorials/asr_with_transformer.md) on tensor2tensor page."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "p9D8OJdFezsJ"
      },
      "outputs": [],
      "source": [
        "# Copy the pretrained checkpoint locally\n",
        "ckpt_name = \"transformer_asr_180214\"\n",
        "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n",
        "print(gs_ckpt)\n",
        "!gsutil cp -R {gs_ckpt} {checkpoint_dir} \n",
        "ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))\n",
        "ckpt_path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "arS1sXFPxvde"
      },
      "source": [
        "### Define transcribe function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "od7ZPT3wfkZs"
      },
      "outputs": [],
      "source": [
        "# Restore and transcribe!\n",
        "def transcribe(inputs):\n",
        "  encoded_inputs = encode(inputs)\n",
        "  with tfe.restore_variables_on_create(ckpt_path): \n",
        "    model_output = asr_model.infer(encoded_inputs, beam_size=2, alpha=0.6, decode_length=1)[\"outputs\"]\n",
        "  return decode(model_output)\n",
        "\n",
        "def play_and_transcribe(inputs):\n",
        "  waveforms = encoders[\"waveforms\"].encode(inputs)\n",
        "  IPython.display.display(IPython.display.Audio(data=waveforms, rate=16000))\n",
        "  return transcribe(inputs) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Qz5u2O5LvShm"
      },
      "source": [
        "# Decoding prerecorded examples\n",
        "\n",
        "You can upload any .wav files. They will be transcribed if frame rate matches Librispeeche's frame rate (16kHz)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "xAstJTeyvXMf"
      },
      "outputs": [],
      "source": [
        "uploaded = google.colab.files.upload()\n",
        "prerecorded_messages = []\n",
        "\n",
        "for fn in uploaded.keys():\n",
        "  print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
        "      name=fn, length=len(uploaded[fn])))\n",
        "  mem_file = cStringIO.StringIO(uploaded[fn])\n",
        "    \n",
        "  save_filename = os.path.join(tmp_dir, fn)\n",
        "  with open(save_filename, 'w') as fd:\n",
        "    mem_file.seek(0)\n",
        "    shutil.copyfileobj(mem_file, fd)\n",
        "  prerecorded_messages.append(save_filename)\n",
        "         \n",
        "              \n",
        "for inputs in prerecorded_messages:\n",
        "  outputs = play_and_transcribe(inputs)\n",
        "\n",
        "  print(\"Inputs: %s\" % inputs)\n",
        "  print(\"Outputs: %s\" % outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mJvRjlHUrr65"
      },
      "source": [
        "# Recording your own examples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "oirqsdqVoElk"
      },
      "outputs": [],
      "source": [
        "# Records webm file and converts\n",
        "def RecordNewAudioSample(filename=None, webm_filename=None):\n",
        "  \"\"\"Args:\n",
        "          filename - string, path for storing wav file\n",
        "          webm_filename - string, path for storing webm file\n",
        "     Returns:\n",
        "          string - path where wav file was saved. (=filename if specified)\n",
        "    \n",
        "  \"\"\"\n",
        "  # Create default filenames in tmp_dir if not specified.\n",
        "  if not filename:\n",
        "    filename = os.path.join(tmp_dir, \"recording.wav\")\n",
        "  if not webm_filename:\n",
        "    webm_filename = os.path.join(tmp_dir, \"recording.webm\")\n",
        "    \n",
        "  # Record webm file form colab.\n",
        "  \n",
        "  audio = google.colab._message.blocking_request('user_media', {\"audio\":True, \"video\":False, \"duration\":-1}, timeout_sec=600)\n",
        "  #audio = frontend.RecordMedia(True, False)\n",
        "  \n",
        "  # Convert the recording into in_memory file.\n",
        "  music_mem_file = cStringIO.StringIO(\n",
        "      base64.decodestring(audio[audio.index(',')+1:]))\n",
        "  \n",
        "  # Store webm recording in webm_filename. Storing is necessary for conversion.\n",
        "  with open(webm_filename, 'w') as fd:\n",
        "    music_mem_file.seek(0)\n",
        "    shutil.copyfileobj(music_mem_file, fd)\n",
        "    \n",
        "  # Open stored file and save it as wav with sample_rate=16000.\n",
        "  pydub.AudioSegment.from_file(webm_filename, codec=\"opus\"\n",
        "                              ).set_frame_rate(16000).export(out_f=filename,\n",
        "                                                             format=\"wav\")\n",
        "  return filename"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "90BjliFTCQm9"
      },
      "outputs": [],
      "source": [
        "# Record the sample\n",
        "my_sample_filename = RecordNewAudioSample()\n",
        "print my_sample_filename"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "PdBfEik0-pMv"
      },
      "outputs": [],
      "source": [
        "print play_and_transcribe(my_sample_filename)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "default_view": {},
      "name": "ASR with Transformer example notebook",
      "provenance": [
        {
          "file_id": "notebooks/SR_with_Transformer_example_notebook.ipynb",
          "timestamp": 1525703542020
        },
        {
          "file_id": "1hEMwW8LgaQPLngfka0tbobYB-ZTVqy34",
          "timestamp": 1525702247248
        },
        {
          "file_id": "1Pp4aSAceJRNpxtSrTevUKpHKudMxHyBF",
          "timestamp": 1518630927690
        }
      ],
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
