{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "UViM depth task",
      "provenance": [],
      "collapsed_sections": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Fetch big_vision repository and move it into the current workdir (import path).\n",
        "# Every step is safe to re-run: the clone is skipped when the checkout already\n",
        "# exists, and copying into `.` merges into ./big_vision instead of nesting a\n",
        "# second copy at ./big_vision/big_vision.\n",
        "!test -d big_vision_repo || git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n",
        "!cp -R big_vision_repo/big_vision .\n",
        "# %pip installs into the environment of the running kernel.\n",
        "%pip install -qr big_vision/requirements.txt"
      ],
      "metadata": {
        "id": "sKZK6_QpVI_O"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "from big_vision.models.proj.uvim import vtt  # stage-II model\n",
        "from big_vision.models.proj.uvim import vit  # stage-I model\n",
        "\n",
        "from big_vision.models.proj.uvim import decode\n",
        "from big_vision.trainers.proj.uvim import depth_task as task\n",
        "from big_vision.configs.proj.uvim import train_nyu_depth_pretrained as config_module\n",
        "\n",
        "import big_vision.pp.ops_image\n",
        "import big_vision.pp.ops_general\n",
        "import big_vision.pp.proj.uvim.pp_ops\n",
        "from big_vision.pp import builder as pp_builder\n",
        "\n",
        "config = config_module.get_config()\n",
        "res = 512  # Side length passed to the resize op below.\n",
        "seq_len = config.model.seq_len\n",
        "\n",
        "# The stage-II model samples the code; the stage-I model encodes/decodes it.\n",
        "lm_model = vtt.Model(**config.model)\n",
        "oracle_model = vit.Model(**config.oracle.model)\n",
        "\n",
        "# Build the resize op from `res` so the resolution is defined in one place.\n",
        "# `image_ctx` is a copy of the preprocessed image stored under a second key.\n",
        "preprocess_fn = pp_builder.get_preprocess_fn(\n",
        "    f'resize({res})|value_range(-1,1)|'\n",
        "    'copy(inkey=\"image\",outkey=\"image_ctx\")')\n",
        "\n",
        "@jax.jit\n",
        "def predict_code(params, x, rng, temperature):\n",
        "  \"\"\"Samples one code sequence per image with the stage-II model.\n",
        "\n",
        "  Args:\n",
        "    params: variables handed to `decode.temperature_sampling` for `lm_model`.\n",
        "    x: batch dict; only `x[\"image\"]` is read.\n",
        "    rng: PRNG key forwarded as the sampling seed.\n",
        "    temperature: sampling temperature.\n",
        "\n",
        "  Returns:\n",
        "    The sampled sequences with the `num_samples` axis removed and every\n",
        "    token value decremented by one.\n",
        "  \"\"\"\n",
        "  images = x[\"image\"]\n",
        "  batch_size = images.shape[0]\n",
        "  # All-zero prompts of length `seq_len`, one row per image.\n",
        "  empty_prompts = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)\n",
        "  sampled, _, _ = decode.temperature_sampling(\n",
        "      params=params,\n",
        "      model=lm_model,\n",
        "      seed=rng,\n",
        "      inputs=images,\n",
        "      prompts=empty_prompts,\n",
        "      temperature=temperature,\n",
        "      num_samples=1,\n",
        "      eos_token=-1,\n",
        "      prefill=False)\n",
        "  # num_samples == 1, so that axis carries no information.\n",
        "  sampled = jnp.squeeze(sampled, axis=1)\n",
        "  # Shift token ids down by one (presumably id 0 is reserved by the decoder;\n",
        "  # TODO confirm against the stage-II vocabulary).\n",
        "  return sampled - 1\n",
        "  \n",
        "@jax.jit\n",
        "def labels2code(params, x, ctx):\n",
        "  \"\"\"Runs the stage-I encoder on `x` (with context `ctx`) and returns its code.\"\"\"\n",
        "  # The first return value of `encode` is not needed; the code lives in aux.\n",
        "  _, aux = oracle_model.apply(\n",
        "      params,\n",
        "      x,\n",
        "      ctx=ctx,\n",
        "      train=False,\n",
        "      method=oracle_model.encode)\n",
        "  return aux[\"code\"]\n",
        "\n",
        "@jax.jit\n",
        "def code2labels(params, code, ctx):\n",
        "  \"\"\"Decodes a discrete `code` with the stage-I model into task outputs.\"\"\"\n",
        "  # `discrete_input=True` tells the decoder that `code` holds discrete tokens.\n",
        "  logits, _ = oracle_model.apply(\n",
        "      params,\n",
        "      code,\n",
        "      ctx=ctx,\n",
        "      train=False,\n",
        "      discrete_input=True,\n",
        "      method=oracle_model.decode)\n",
        "  return task.predict_outputs(logits, config.oracle)"
      ],
      "metadata": {
        "id": "QzThueWDzc7I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load checkpoints (`-n` skips files that were already downloaded).\n",
        "!gsutil cp -n gs://big_vision/uvim/depth_stageI_params.npz gs://big_vision/uvim/depth_stageII_params.npz .\n",
        "\n",
        "# Stage-I: parameters and state are bundled into a single variables dict.\n",
        "oracle_params, oracle_state = vit.load(None, \"depth_stageI_params.npz\")\n",
        "oracle_params = jax.device_put(dict(params=oracle_params, state=oracle_state))\n",
        "\n",
        "# Stage-II: only parameters are loaded.\n",
        "lm_params = jax.device_put(dict(params=vtt.load(None, \"depth_stageII_params.npz\")))"
      ],
      "metadata": {
        "id": "AEjRgshLa6Fp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prepare dataset of images from NYU Depth V2:\n",
        "#  - https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html\n",
        "import os\n",
        "import h5py\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Download the labeled NYU Depth V2 file once, then open it read-only.\n",
        "nyu_mat_path = \"nyu_depth_v2_labeled.mat\"\n",
        "if not os.path.exists(nyu_mat_path):\n",
        "  !wget --no-clobber http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat\n",
        "\n",
        "dataset_file = h5py.File(nyu_mat_path, mode=\"r\")\n",
        "\n",
        "def nyu_depth_examples():\n",
        "  \"\"\"Yields each entry of `dataset_file[\"images\"]` as an `{\"image\": array}` dict.\"\"\"\n",
        "  images = dataset_file[\"images\"]\n",
        "  for idx in range(images.shape[0]):\n",
        "    # Reverse the axis order of the stored array; the tf.data signature that\n",
        "    # consumes this generator declares images as (480, 640, 3).\n",
        "    yield {\"image\": np.transpose(images[idx], (2, 1, 0))}\n",
        "\n",
        "# Stream the raw uint8 images from the generator, then preprocess each one.\n",
        "image_spec = tf.TensorSpec(shape=(480, 640, 3), dtype=tf.uint8)\n",
        "dataset = tf.data.Dataset.from_generator(\n",
        "    nyu_depth_examples, output_signature={\"image\": image_spec})\n",
        "dataset = dataset.map(preprocess_fn)"
      ],
      "metadata": {
        "id": "BKifDDRnH_Ll"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on a few examples:\n",
        "from matplotlib import pyplot as plt\n",
        "from matplotlib import patches\n",
        "\n",
        "num_examples = 4\n",
        "# Batches of a single image each, limited to the first `num_examples`.\n",
        "single_image_batches = dataset.batch(1).take(num_examples)\n",
        "data = single_image_batches.as_numpy_iterator()\n",
        "key = jax.random.PRNGKey(0)\n",
        "# Near-zero temperature (presumably makes sampling effectively greedy).\n",
        "temperature = jnp.array(1e-7)\n",
        "\n",
        "def to_depth(x, nbins=256, mind=1e-3, maxd=10):\n",
        "  \"\"\"Maps bin indices linearly onto depth values between `mind` and `maxd`.\n",
        "\n",
        "  Args:\n",
        "    x: array of bin indices.\n",
        "    nbins: number of bins the depth range is divided into.\n",
        "    mind: depth at the lower edge of bin 0.\n",
        "    maxd: depth at the upper edge of the last bin.\n",
        "\n",
        "  Returns:\n",
        "    float32 array with the depth at the centre of each bin.\n",
        "  \"\"\"\n",
        "  # Adding 0.5 moves from the bin's lower edge to its centre, which undoes the\n",
        "  # floor applied during quantization in expectation.\n",
        "  bin_centers = x.astype(np.float32) + 0.5\n",
        "  fraction = bin_centers/nbins\n",
        "  return fraction * (maxd - mind) + mind\n",
        "\n",
        "def render_example(image, prediction, with_legend=True):\n",
        "  \"\"\"Shows the input image next to its predicted depth map.\n",
        "\n",
        "  Args:\n",
        "    image: image to display; rescaled with `image*0.5 + 0.5`, which matches\n",
        "      inputs preprocessed with `value_range(-1,1)`.\n",
        "    prediction: model output that is passed through `to_depth` for display.\n",
        "    with_legend: accepted but currently has no effect.\n",
        "  \"\"\"\n",
        "  _, (image_ax, depth_ax) = plt.subplots(1, 2, figsize=(10, 10))\n",
        "  image_ax.imshow(image*0.5 + 0.5)\n",
        "  image_ax.axis(\"off\")\n",
        "  depth_ax.imshow(to_depth(prediction))\n",
        "  depth_ax.axis(\"off\")\n",
        "\n",
        "for idx, batch in enumerate(data):\n",
        "  # Derive a distinct PRNG key per example; passing the base `key` would feed\n",
        "  # every example the same randomness and leave `subkey` unused.\n",
        "  subkey = jax.random.fold_in(key, idx)\n",
        "  code = predict_code(lm_params, batch, subkey, temperature)\n",
        "  # Build the context input that the stage-I decoder consumes.\n",
        "  aux_inputs = task.input_pp(batch, config.oracle)\n",
        "  prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n",
        "  # Batches hold a single image, so index 0 selects it.\n",
        "  render_example(batch[\"image\"][0], prediction[\"depth\"][0])"
      ],
      "metadata": {
        "id": "TuevCy33nuv3"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}