{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0WzFpOKDmURO"
      },
      "source": [
        "## PyTorch/TPU MNIST Demo\n",
        "\n",
        "This colab example corresponds to the implementation under [test_train_mp_mnist.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xOp9jBEumdvC"
      },
      "source": [
        "<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a></h3>\n",
        "\n",
        "* On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "* The cell below makes sure you have access to a TPU on Colab.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hx4YVNHametU"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YofXQrnxmf5r"
      },
      "source": [
        "### [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OApBOAe1fpH_"
      },
      "outputs": [],
      "source": [
        "!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nfSCdVlA8jFg"
      },
      "source": [
        "### If you're using GPU with this colab notebook, run the below commented code to install GPU compatible PyTorch wheel and dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J1Vfg-rH8bF4"
      },
      "outputs": [],
      "source": [
        "#!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl --force-reinstall "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cPrij_iPfqTV"
      },
      "source": [
        "### Only run the below commented cell if you would like a nightly release"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vJZrkoejQhxK"
      },
      "outputs": [],
      "source": [
        "# VERSION = \"1.13\"  #@param [\"1.13\", \"nightly\", \"20220315\"]  # or YYYYMMDD format\n",
        "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "# !python pytorch-xla-env-setup.py --version $VERSION\n",
        "# import os \n",
        "# os.environ['LD_LIBRARY_PATH']='/usr/local/lib'\n",
        "# !echo $LD_LIBRARY_PATH\n",
        "\n",
        "# !sudo ln -s /usr/local/lib/libmkl_intel_lp64.so /usr/local/lib/libmkl_intel_lp64.so.1\n",
        "# !sudo ln -s /usr/local/lib/libmkl_intel_thread.so /usr/local/lib/libmkl_intel_thread.so.1\n",
        "# !sudo ln -s /usr/local/lib/libmkl_core.so /usr/local/lib/libmkl_core.so.1\n",
        "\n",
        "# !ldconfig\n",
        "# !ldd /usr/local/lib/python3.7/dist-packages/torch/lib/libtorch.so"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "rMjTZHp-FJbY"
      },
      "outputs": [],
      "source": [
        "#@title PSGD code\n",
        "\"\"\"Created in May, 2018\n",
        "Pytorch functions for preconditioned SGD\n",
        "\n",
        "Updated in Dec, 2020: \n",
        "Wrapped Kronecker product preconditioner for easy use: the code will select the proper Kronecker product  \n",
        "preconditioner based on the formats of input left and right preconditioners.\n",
        "Add torch.jit.script decorator by default\n",
        "\"\"\"\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "###############################################################################\n",
        "@torch.jit.script\n",
        "def update_precond_dense(Q, dxs, dgs, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor, List[Tensor], List[Tensor], float, float) -> Tensor\n",
        "    \"\"\"\n",
        "    update dense preconditioner P = Q^T*Q\n",
        "    Q: Cholesky factor of preconditioner with positive diagonal entries \n",
        "    dxs: list of perturbations of parameters\n",
        "    dgs: list of perturbations of gradients\n",
        "    step: update step size normalized to range [0, 1] \n",
        "    _tiny: an offset to avoid division by zero \n",
        "    \"\"\"\n",
        "    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs])\n",
        "    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])\n",
        "    \n",
        "    a = Q.mm(dg)\n",
        "    #b = torch.triangular_solve(dx, Q, upper=True, transpose=True)[0]\n",
        "    b = torch.linalg.solve_triangular(Q.t(), dx, upper=False)\n",
        "\n",
        "    grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))\n",
        "    step0 = step/(grad.abs().max() + _tiny)        \n",
        "        \n",
        "    return Q - step0*grad.mm(Q)\n",
        "\n",
        "@torch.jit.script\n",
        "def precond_grad_dense(Q, grads):\n",
        "    # type: (Tensor, List[Tensor]) -> List[Tensor]\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using dense preconditioner\n",
        "    Q: Cholesky factor of preconditioner\n",
        "    grads: list of gradients\n",
        "    \"\"\"\n",
        "    grad = [torch.reshape(g, [-1, 1]) for g in grads]\n",
        "    lens = [g.shape[0] for g in grad]\n",
        "    grad = torch.cat(grad)\n",
        "    grad = Q.t().mm(Q.mm(grad))\n",
        "    \n",
        "    pre_grads = []\n",
        "    idx = 0\n",
        "    for i in range(len(grads)):\n",
        "        pre_grads.append(torch.reshape(grad[idx : idx + lens[i]], grads[i].shape))\n",
        "        idx = idx + lens[i]\n",
        "        \n",
        "    return pre_grads\n",
        "\n",
        "\n",
        "###############################################################################\n",
        "def update_precond_kron(Ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):\n",
        "    \"\"\"\n",
        "    Update Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)\n",
        "    Either Ql or Qr can be sparse, and the code can choose the right update rule.\n",
        "    dX: perturbation of (matrix) parameter\n",
        "    dG: perturbation of (matrix) gradient\n",
        "    step: update step size\n",
        "    _tiny: an offset to avoid division by zero \n",
        "    \"\"\"\n",
        "    m, n = Ql.shape\n",
        "    p, q = Qr.shape\n",
        "    if m==n: # left is dense\n",
        "        if p==q: #(dense, dense) format\n",
        "            return _update_precond_dense_dense(Ql, Qr, dX, dG, step, _tiny)\n",
        "        elif p==2: # (dense, normalization) format\n",
        "            return _update_precond_norm_dense(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]\n",
        "        elif p==1: # (dense, scaling) format\n",
        "            return _update_precond_dense_scale(Ql, Qr, dX, dG, step, _tiny)\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    elif m==2: # left is normalization\n",
        "        if p==q: # (normalization, dense) format\n",
        "            return _update_precond_norm_dense(Ql, Qr, dX, dG, step, _tiny)\n",
        "        elif p==1: # (normalization, scaling) format\n",
        "            return _update_precond_norm_scale(Ql, Qr, dX, dG, step, _tiny)\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    elif m==1: # left is scaling\n",
        "        if p==q: # (scaling, dense) format\n",
        "            return _update_precond_dense_scale(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]\n",
        "        elif p==2: # (scaling, normalization) format\n",
        "            return _update_precond_norm_scale(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    else:\n",
        "        raise Exception('Unknown Kronecker product preconditioner')\n",
        " \n",
        "       \n",
        "def precond_grad_kron(Ql, Qr, Grad):\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)\n",
        "    Either Ql or Qr can be sparse, and the code can choose the right way to precondition the gradient\n",
        "    Grad: (matrix) gradient\n",
        "    \"\"\"\n",
        "    m, n = Ql.shape\n",
        "    p, q = Qr.shape\n",
        "    if m==n: # left is dense\n",
        "        if p==q: #(dense, dense) format\n",
        "            return _precond_grad_dense_dense(Ql, Qr, Grad)\n",
        "        elif p==2: # (dense, normalization) format\n",
        "            return _precond_grad_norm_dense(Qr, Ql, Grad.t()).t()\n",
        "        elif p==1: # (dense, scaling) format\n",
        "            return _precond_grad_dense_scale(Ql, Qr, Grad)\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    elif m==2: # left is normalization\n",
        "        if p==q: # (normalization, dense) format\n",
        "            return _precond_grad_norm_dense(Ql, Qr, Grad)\n",
        "        elif p==1: # (normalization, scaling) format\n",
        "            return _precond_grad_norm_scale(Ql, Qr, Grad)\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    elif m==1: # left is scaling\n",
        "        if p==q: # (scaling, dense) format\n",
        "            return _precond_grad_dense_scale(Qr, Ql, Grad.t()).t()\n",
        "        elif p==2: # (scaling, normalization) format\n",
        "            return _precond_grad_norm_scale(Qr, Ql, Grad.t()).t()\n",
        "        else:\n",
        "            raise Exception('Unknown Kronecker product preconditioner')\n",
        "    else:\n",
        "        raise Exception('Unknown Kronecker product preconditioner')\n",
        "        \n",
        "\n",
        "###############################################################################\n",
        "@torch.jit.script\n",
        "def _update_precond_dense_dense(Ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]\n",
        "    \"\"\"\n",
        "    update Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)\n",
        "    Ql: (left side) Cholesky factor of preconditioner with positive diagonal entries\n",
        "    Qr: (right side) Cholesky factor of preconditioner with positive diagonal entries\n",
        "    dX: perturbation of (matrix) parameter\n",
        "    dG: perturbation of (matrix) gradient\n",
        "    step: update step size normalized to range [0, 1] \n",
        "    _tiny: an offset to avoid division by zero \n",
        "    \"\"\"\n",
        "    max_l = torch.max(torch.diag(Ql))\n",
        "    max_r = torch.max(torch.diag(Qr))\n",
        "    \n",
        "    rho = torch.sqrt(max_l/max_r)\n",
        "    Ql /= rho\n",
        "    Qr *= rho\n",
        "    \n",
        "    #A = Ql.mm( dG.mm( Qr.t() ) )\n",
        "    #Bt = torch.triangular_solve((torch.triangular_solve(dX.t(), Qr, upper=True, transpose=True))[0].t(), \n",
        "    #                 Ql, upper=True, transpose=True)[0]\n",
        "    A = torch.linalg.multi_dot([Ql, dG, Qr.t()])\n",
        "    Bt = torch.linalg.solve_triangular(Ql.t(), torch.linalg.solve_triangular(Qr, dX, upper=True, left=False), upper=False)\n",
        "    \n",
        "    grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))\n",
        "    grad2 = torch.triu(A.t().mm(A) - Bt.t().mm(Bt))\n",
        "    \n",
        "    step1 = step/(torch.max(torch.abs(grad1)) + _tiny)\n",
        "    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)\n",
        "        \n",
        "    return Ql - step1*grad1.mm(Ql), Qr - step2*grad2.mm(Qr)\n",
        "    \n",
        "@torch.jit.script\n",
        "def _precond_grad_dense_dense(Ql, Qr, Grad):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using Kronecker product preconditioner\n",
        "    Ql: (left side) Cholesky factor of preconditioner\n",
        "    Qr: (right side) Cholesky factor of preconditioner\n",
        "    Grad: (matrix) gradient\n",
        "    \"\"\"\n",
        "    #return torch.chain_matmul(Ql.t(), Ql, Grad, Qr.t(), Qr)\n",
        "    return torch.linalg.multi_dot([Ql.t(), Ql, Grad, Qr.t(), Qr])\n",
        "    \n",
        "\n",
        "###############################################################################\n",
        "# (normalization, dense) format Kronecker product preconditioner\n",
        "@torch.jit.script\n",
        "def _update_precond_norm_dense(ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]\n",
        "    \"\"\"\n",
        "    update (normalization, dense) Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where\n",
        "    dX and dG have shape (M, N)\n",
        "    ql has shape (2, M)\n",
        "    Qr has shape (N, N)\n",
        "    ql[0] is the diagonal part of Ql\n",
        "    ql[1,0:-1] is the last column of Ql, excluding the last entry\n",
        "    dX is perturbation of (matrix) parameter\n",
        "    dG is perturbation of (matrix) gradient\n",
        "    step: update step size normalized to range [0, 1] \n",
        "    _tiny: an offset to avoid division by zero  \n",
        "    \"\"\"\n",
        "    # make sure that Ql and Qr have similar dynamic range\n",
        "    max_l = torch.max(ql[0])\n",
        "    max_r = torch.max(torch.diag(Qr))  \n",
        "    rho = torch.sqrt(max_l/max_r)\n",
        "    ql /= rho\n",
        "    Qr *= rho\n",
        "    \n",
        "    # refer to https://arxiv.org/abs/1512.04202 for details\n",
        "    A = ql[0:1].t()*dG + ql[1:].t().mm( dG[-1:] ) # Ql*dG \n",
        "    A = A.mm(Qr.t())\n",
        "    \n",
        "    Bt = dX/ql[0:1].t()\n",
        "    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0,-1])).mm(dX)\n",
        "    #Bt = torch.triangular_solve(Bt.t(), Qr, upper=True, transpose=True)[0].t()\n",
        "    Bt = torch.linalg.solve_triangular(Qr, Bt, upper=True, left=False)\n",
        "    \n",
        "    grad1_diag = torch.sum(A*A, dim=1) - torch.sum(Bt*Bt, dim=1)\n",
        "    grad1_bias = A[:-1].mm(A[-1:].t()) - Bt[:-1].mm(Bt[-1:].t()) \n",
        "    grad1_bias = torch.cat([torch.squeeze(grad1_bias), grad1_bias.new_zeros(1)])  \n",
        "\n",
        "    step1 = step/(torch.max(torch.max(torch.abs(grad1_diag)), \n",
        "                            torch.max(torch.abs(grad1_bias))) + _tiny)\n",
        "    new_ql0 = ql[0] - step1*grad1_diag*ql[0]\n",
        "    new_ql1 = ql[1] - step1*(grad1_diag*ql[1] + ql[0,-1]*grad1_bias)\n",
        "    \n",
        "    grad2 = torch.triu(A.t().mm(A) - Bt.t().mm(Bt))\n",
        "    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)\n",
        "    \n",
        "    return torch.stack((new_ql0, new_ql1)), Qr - step2*grad2.mm(Qr)\n",
        "\n",
        "@torch.jit.script\n",
        "def _precond_grad_norm_dense(ql, Qr, Grad):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using (normalization, dense) Kronecker product preconditioner \n",
        "    Suppose Grad has shape (M, N)\n",
        "    ql[0] is the diagonal part of Ql\n",
        "    ql[1, 0:-1] is the last column of Ql, excluding the last entry\n",
        "    Qr: shape (N, N), Cholesky factor of right preconditioner\n",
        "    Grad: (matrix) gradient\n",
        "    \"\"\"\n",
        "    preG = ql[0:1].t()*Grad + ql[1:].t().mm(Grad[-1:]) # Ql*Grad \n",
        "    #preG = torch.chain_matmul(preG, Qr.t(), Qr)\n",
        "    preG = torch.linalg.multi_dot([preG, Qr.t(), Qr])\n",
        "    add_last_row = ql[1:].mm(preG) # use it to modify the last row\n",
        "    preG *= ql[0:1].t()\n",
        "    preG[-1:] += add_last_row\n",
        "    \n",
        "    return preG\n",
        "\n",
        "\n",
        "###############################################################################\n",
        "# (normalization, scaling) Kronecker product preconditioner \n",
        "# the left one is a normalization preconditioner; the right one is a scaling preconditioner\n",
        "@torch.jit.script\n",
        "def _update_precond_norm_scale(ql, qr, dX, dG, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]\n",
        "    \"\"\"\n",
        "    update (normalization, scaling) preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where\n",
        "    dX and dG have shape (M, N)\n",
        "    ql has shape (2, M)\n",
        "    qr has shape (1, N)\n",
        "    ql[0] is the diagonal part of Ql\n",
        "    ql[1, 0:-1] is the last column of Ql, excluding the last entry\n",
        "    qr is the diagonal part of Qr\n",
        "    dX is perturbation of (matrix) parameter\n",
        "    dG is perturbation of (matrix) gradient\n",
        "    step: update step size\n",
        "    _tiny: an offset to avoid division by zero  \n",
        "    \"\"\"\n",
        "    # make sure that Ql and Qr have similar dynamic range\n",
        "    max_l = torch.max(ql[0])\n",
        "    max_r = torch.max(qr) # qr always is positive\n",
        "    rho = torch.sqrt(max_l/max_r)\n",
        "    ql /= rho\n",
        "    qr *= rho\n",
        "    \n",
        "    # refer to https://arxiv.org/abs/1512.04202 for details\n",
        "    A = ql[0:1].t()*dG + ql[1:].t().mm( dG[-1:] ) # Ql*dG \n",
        "    A *= qr # Ql*dG*Qr \n",
        "    \n",
        "    Bt = dX/ql[0:1].t()\n",
        "    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0,-1])).mm(dX)\n",
        "    Bt /= qr # Ql^(-T)*dX*Qr^(-1) \n",
        "    \n",
        "    grad1_diag = torch.sum(A*A, dim=1) - torch.sum(Bt*Bt, dim=1)\n",
        "    grad1_bias = A[:-1].mm(A[-1:].t()) - Bt[:-1].mm(Bt[-1:].t()) \n",
        "    grad1_bias = torch.cat([torch.squeeze(grad1_bias), grad1_bias.new_zeros(1)])  \n",
        "\n",
        "    step1 = step/(torch.max(torch.max(torch.abs(grad1_diag)), \n",
        "                            torch.max(torch.abs(grad1_bias))) + _tiny)\n",
        "    new_ql0 = ql[0] - step1*grad1_diag*ql[0]\n",
        "    new_ql1 = ql[1] - step1*(grad1_diag*ql[1] + ql[0,-1]*grad1_bias)\n",
        "    \n",
        "    grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)\n",
        "    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)\n",
        "    \n",
        "    return torch.stack((new_ql0, new_ql1)), qr - step2*grad2*qr\n",
        "\n",
        "@torch.jit.script\n",
        "def _precond_grad_norm_scale(ql, qr, Grad):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using (normalization, scaling) Kronecker product preconditioner\n",
        "    Suppose Grad has shape (M, N)\n",
        "    ql has shape (2, M) \n",
        "    qr has shape (1, N) \n",
        "    ql[0] is the diagonal part of Ql\n",
        "    ql[1, 0:-1] is the last column of Ql, excluding the last entry\n",
        "    qr is the diagonal part of Qr\n",
        "    Grad: (matrix) gradient\n",
        "    \"\"\"\n",
        "    preG = ql[0:1].t()*Grad + ql[1:].t().mm(Grad[-1:]) # Ql*Grad \n",
        "    preG *= (qr*qr) # Ql*Grad*Qr^T*Qr\n",
        "    add_last_row = ql[1:].mm(preG) # use it to modify the last row\n",
        "    preG *= ql[0:1].t()\n",
        "    preG[-1:] += add_last_row\n",
        "    \n",
        "    return preG\n",
        "\n",
        "\n",
        "###############################################################################\n",
        "@torch.jit.script\n",
        "def _update_precond_dense_scale(Ql, qr, dX, dG, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]\n",
        "    \"\"\"\n",
        "    update (dense, scaling) preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where\n",
        "    dX and dG have shape (M, N)\n",
        "    Ql has shape (M, M)\n",
        "    qr has shape (1, N)\n",
        "    qr is the diagonal part of Qr\n",
        "    dX is perturbation of (matrix) parameter\n",
        "    dG is perturbation of (matrix) gradient\n",
        "    step: update step size\n",
        "    _tiny: an offset to avoid division by zero \n",
        "    \"\"\"\n",
        "    max_l = torch.max(torch.diag(Ql))\n",
        "    max_r = torch.max(qr)\n",
        "    \n",
        "    rho = torch.sqrt(max_l/max_r)\n",
        "    Ql /= rho\n",
        "    qr *= rho\n",
        "    \n",
        "    A = Ql.mm( dG*qr )\n",
        "    #Bt = torch.triangular_solve(dX/qr, Ql, upper=True, transpose=True)[0]\n",
        "    Bt = torch.linalg.solve_triangular(Ql.t(), dX/qr, upper=False)\n",
        "    \n",
        "    grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))\n",
        "    grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)\n",
        "    \n",
        "    step1 = step/(torch.max(torch.abs(grad1)) + _tiny)\n",
        "    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)\n",
        "        \n",
        "    return Ql - step1*grad1.mm(Ql), qr - step2*grad2*qr\n",
        "    \n",
        "@torch.jit.script\n",
        "def _precond_grad_dense_scale(Ql, qr, Grad):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    return preconditioned gradient using (dense, scaling) Kronecker product preconditioner\n",
        "    Suppose Grad has shape (M, N)\n",
        "    Ql: shape (M, M), (left side) Cholesky factor of preconditioner\n",
        "    qr: shape (1, N), defines a diagonal matrix for output feature scaling\n",
        "    Grad: (matrix) gradient\n",
        "    \"\"\"\n",
        "    #return torch.chain_matmul(Ql.t(), Ql, Grad*(qr*qr))\n",
        "    return torch.linalg.multi_dot([Ql.t(), Ql, Grad*(qr*qr)])\n",
        "\n",
        "\n",
        "\n",
        "###############################################################################   \n",
        "@torch.jit.script                     \n",
        "def update_precond_splu(L12, l3, U12, u3, dxs, dgs, step=0.01, _tiny=1.2e-38):\n",
        "    # type: (Tensor,Tensor,Tensor,Tensor, List[Tensor],List[Tensor], float,float) -> Tuple[Tensor,Tensor,Tensor,Tensor]\n",
        "    \"\"\"\n",
        "    update sparse LU preconditioner P = Q^T*Q, where \n",
        "    Q = L*U,\n",
        "    L12 = [L1; L2]\n",
        "    U12 = [U1, U2]\n",
        "    L = [L1, 0; L2, diag(l3)]\n",
        "    U = [U1, U2; 0, diag(u3)]\n",
        "    l3 and u3 are column vectors\n",
        "    dxs: a list of random perturbation on parameters\n",
        "    dgs: a list of resultant perturbation on gradients\n",
        "    step: update step size normalized to range [0, 1] \n",
        "    _tiny: an offset to avoid division by zero \n",
        "    \"\"\"\n",
        "    # make sure that L and U have similar dynamic range\n",
        "    max_l = torch.max(torch.max(torch.diag(L12)), torch.max(l3))\n",
        "    max_u = torch.max(torch.max(torch.diag(U12)), torch.max(u3))\n",
        "    rho = torch.sqrt(max_l/max_u)\n",
        "    L12 /= rho\n",
        "    l3 /= rho\n",
        "    U12 *= rho\n",
        "    u3 *= rho\n",
        "    \n",
        "    # extract the blocks\n",
        "    r = U12.shape[0]\n",
        "    L1 = L12[:r]\n",
        "    L2 = L12[r:]\n",
        "    U1 = U12[:, :r]\n",
        "    U2 = U12[:, r:]\n",
        "    \n",
        "    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs]) # a tall column vector\n",
        "    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs]) # a tall column vector\n",
        "    \n",
        "    # U*dg\n",
        "    Ug1 = U1.mm(dg[:r]) + U2.mm(dg[r:])\n",
        "    Ug2 = u3*dg[r:]\n",
        "    # Q*dg\n",
        "    Qg1 = L1.mm(Ug1)\n",
        "    Qg2 = L2.mm(Ug1) + l3*Ug2\n",
        "    # inv(U^T)*dx\n",
        "    #iUtx1 = torch.triangular_solve(dx[:r], U1, upper=True, transpose=True)[0]\n",
        "    iUtx1 = torch.linalg.solve_triangular(U1.t(), dx[:r], upper=False)\n",
        "    iUtx2 = (dx[r:] - U2.t().mm(iUtx1))/u3\n",
        "    # inv(Q^T)*dx\n",
        "    iQtx2 = iUtx2/l3\n",
        "    #iQtx1 = torch.triangular_solve(iUtx1 - L2.t().mm(iQtx2), L1, upper=False, transpose=True)[0]\n",
        "    iQtx1 = torch.linalg.solve_triangular(L1.t(), iUtx1 - L2.t().mm(iQtx2), upper=True)\n",
        "    # L^T*Q*dg\n",
        "    LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)\n",
        "    LtQg2 = l3*Qg2\n",
        "    # P*dg\n",
        "    Pg1 = U1.t().mm(LtQg1)\n",
        "    Pg2 = U2.t().mm(LtQg1) + u3*LtQg2\n",
        "    # inv(L)*inv(Q^T)*dx\n",
        "    #iLiQtx1 = torch.triangular_solve(iQtx1, L1, upper=False)[0]\n",
        "    iLiQtx1 = torch.linalg.solve_triangular(L1, iQtx1, upper=False)\n",
        "    iLiQtx2 = (iQtx2 - L2.mm(iLiQtx1))/l3\n",
        "    # inv(P)*dx\n",
        "    iPx2 = iLiQtx2/u3\n",
        "    #iPx1 = torch.triangular_solve(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]\n",
        "    iPx1 = torch.linalg.solve_triangular(U1, iLiQtx1 - U2.mm(iPx2), upper=True)\n",
        "    \n",
        "    # update L\n",
        "    grad1 = Qg1.mm(Qg1.t()) - iQtx1.mm(iQtx1.t())\n",
        "    grad1 = torch.tril(grad1)\n",
        "    grad2 = Qg2.mm(Qg1.t()) - iQtx2.mm(iQtx1.t())\n",
        "    grad3 = Qg2*Qg2 - iQtx2*iQtx2\n",
        "    max_abs_grad = torch.max(torch.abs(grad1))\n",
        "    max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad2)))\n",
        "    max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad3)))\n",
        "    step0 = step/(max_abs_grad + _tiny)\n",
        "    newL1 = L1 - step0*grad1.mm(L1)\n",
        "    newL2 = L2 - step0*grad2.mm(L1) - step0*grad3*L2\n",
        "    newl3 = l3 - step0*grad3*l3\n",
        "\n",
        "    # update U\n",
        "    grad1 = Pg1.mm(dg[:r].t()) - dx[:r].mm(iPx1.t())\n",
        "    grad1 = torch.triu(grad1)\n",
        "    grad2 = Pg1.mm(dg[r:].t()) - dx[:r].mm(iPx2.t())\n",
        "    grad3 = Pg2*dg[r:] - dx[r:]*iPx2\n",
        "    max_abs_grad = torch.max(torch.abs(grad1))\n",
        "    max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad2)))\n",
        "    max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad3)))\n",
        "    step0 = step/(max_abs_grad + _tiny)\n",
        "    newU1 = U1 - U1.mm(step0*grad1)\n",
        "    newU2 = U2 - U1.mm(step0*grad2) - step0*grad3.t()*U2\n",
        "    newu3 = u3 - step0*grad3*u3\n",
        "\n",
        "    return torch.cat([newL1, newL2], dim=0), newl3, torch.cat([newU1, newU2], dim=1), newu3\n",
        "\n",
        "@torch.jit.script\n",
        "def precond_grad_splu(L12, l3, U12, u3, grads):\n",
        "    # type: (Tensor,Tensor,Tensor,Tensor, List[Tensor]) -> List[Tensor]\n",
        "    \"\"\"\n",
        "    return preconditioned gradient with sparse LU preconditioner\n",
        "    where P = Q^T*Q, \n",
        "    Q = L*U,\n",
        "    L12 = [L1; L2]\n",
        "    U12 = [U1, U2]\n",
        "    L = [L1, 0; L2, diag(l3)]\n",
        "    U = [U1, U2; 0, diag(u3)]\n",
        "    l3 and u3 are column vectors\n",
        "    grads: a list of gradients to be preconditioned\n",
        "    \"\"\"\n",
        "    grad = [torch.reshape(g, [-1, 1]) for g in grads] # a list of column vector\n",
        "    lens = [g.shape[0] for g in grad] # length of each column vector\n",
        "    grad = torch.cat(grad)  # a tall column vector\n",
        "    \n",
        "    r = U12.shape[0]\n",
        "    L1 = L12[:r]\n",
        "    L2 = L12[r:]\n",
        "    U1 = U12[:, :r]\n",
        "    U2 = U12[:, r:]    \n",
        "    \n",
        "    # U*g\n",
        "    Ug1 = U1.mm(grad[:r]) + U2.mm(grad[r:])\n",
        "    Ug2 = u3*grad[r:]\n",
        "    # Q*g\n",
        "    Qg1 = L1.mm(Ug1)\n",
        "    Qg2 = L2.mm(Ug1) + l3*Ug2\n",
        "    # L^T*Q*g\n",
        "    LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)\n",
        "    LtQg2 = l3*Qg2\n",
        "    # P*g\n",
        "    pre_grad = torch.cat([U1.t().mm(LtQg1),\n",
        "                          U2.t().mm(LtQg1) + u3*LtQg2])\n",
        "    \n",
        "    pre_grads = [] # restore pre_grad to its original shapes\n",
        "    idx = 0\n",
        "    for i in range(len(grads)):\n",
        "        pre_grads.append(torch.reshape(pre_grad[idx : idx + lens[i]], grads[i].shape))\n",
        "        idx = idx + lens[i]\n",
        "    \n",
        "    return pre_grads\n",
        "\n",
        "\n",
        "\n",
        "##############################################################################\n",
        "#\n",
        "# The low-rank approximation (UVd) preconditioner is defined by\n",
        "#\n",
        "#   Q = (I + U*V')*diag(d)\n",
        "#\n",
        "# which, after reparameterization, is equivalent to form\n",
        "#\n",
        "#   diag(d) + U*V'\n",
        "# \n",
        "# It relates to the LM-BFGS and conjugate gradient methods. \n",
        "#\n",
        "# The JIT decorator can be enabled if helps. \n",
        "# \n",
        "\n",
        "#@torch.jit.script\n",
        "def IpUVtmatvec(U, V, x):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    Returns (I + U*V')*x. All variables are either matrices or column vectors. \n",
        "    \"\"\"\n",
        "    return x + U.mm(V.t().mm(x))\n",
        "\n",
        "# def IpUVtsolve(U, V, x):\n",
        "#     \"\"\"\n",
        "#     Returns inv(I + U*V')*x. All variables are either matrices or column vectors.\n",
        "#     \"\"\"\n",
        "#     VtU = V.t().mm(U)\n",
        "#     I = torch.eye(VtU.size(dim=0), dtype=VtU.dtype, device=VtU.device)\n",
        "#     return x - U.mm(torch.linalg.solve(I + VtU, V.t().mm(x))) # torch.solve is slow\n",
        "\n",
        "# def norm_UVt(U, V):\n",
        "#     \"\"\"\n",
        "#     Returns ||U*V'||_fro = sqrt(tr(U'*U*V'*V)) = sqrt(sum((U'*U)*(V'*V))) \n",
        "#     \"\"\"\n",
        "#     return torch.sqrt(torch.abs(torch.sum( (U.t().mm(U))*(V.t().mm(V)) )))\n",
        "\n",
        "#@torch.jit.script\n",
        "def update_precond_UVd_math_(U, V, d, v, h, step, tiny):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, float) -> None\n",
        "    \"\"\"\n",
        "    Update preconditioner Q = (I + U*V')*diag(d) with (vector, Hessian-vector product) = (v, h).\n",
        "    State variables U, V and d are updated inplace; nothing is returned.\n",
        "\n",
        "    As constructed by the UVd class, U and V are (n, r) matrices and d, v, h are (n, 1) column vectors.\n",
        "    step is the normalized learning rate of the preconditioner; tiny guards the step-size divisions.\n",
        "\n",
        "    Each call draws random numbers from torch's global generator: one to decide whether to\n",
        "    rebalance the scales of U and V, and one to choose whether U or V is updated.\n",
        "    Half-precision (float16/bfloat16) states are supported: the small r-by-r linear systems are\n",
        "    solved in float32 because torch.linalg.solve does not accept half dtypes.\n",
        "    \"\"\"\n",
        "    # balance the numerical dynamic ranges of U and V; optional \n",
        "    if torch.rand([]) < 0.01:\n",
        "        normU = torch.linalg.vector_norm(U)\n",
        "        normV = torch.linalg.vector_norm(V)\n",
        "        rho = torch.sqrt(normU/normV)\n",
        "        U.div_(rho)\n",
        "        V.mul_(rho)\n",
        "\n",
        "    # P*h with P = Q'*Q = diag(d)*(I + V*U')*(I + U*V')*diag(d)\n",
        "    Qh = IpUVtmatvec(U, V, d*h)\n",
        "    Ph = d*IpUVtmatvec(V, U, Qh)\n",
        "    \n",
        "    # inv(Q')*v and inv(P)*v via the Woodbury identity, which needs only r-by-r solves:\n",
        "    # inv(I + V*U')*x = x - V*inv(I + U'*V)*U'*x and inv(I + U*V')*x = x - U*inv(I + V'*U)*V'*x\n",
        "    VtU = V.t().mm(U)\n",
        "    I = torch.eye(VtU.size(dim=0), dtype=VtU.dtype, device=VtU.device)\n",
        "    IpVtU = I + VtU\n",
        "    # torch.linalg.solve only supports float32/float64 (and complex) inputs, so half-precision systems\n",
        "    # are promoted to float32 and the solutions cast back; float32/float64 pass through unchanged\n",
        "    solve_dtype = torch.float32 if IpVtU.dtype in (torch.float16, torch.bfloat16) else IpVtU.dtype\n",
        "    IpVtU_s = IpVtU.to(solve_dtype)\n",
        "    invQtv = v/d\n",
        "    # torch's linalg.solve is slow for small matrix\n",
        "    invQtv = invQtv - V.mm(torch.linalg.solve(IpVtU_s.t(), U.t().mm(invQtv).to(solve_dtype)).to(V.dtype))\n",
        "    invPv  = invQtv - U.mm(torch.linalg.solve(IpVtU_s,     V.t().mm(invQtv).to(solve_dtype)).to(U.dtype))\n",
        "    invPv = invPv/d\n",
        "\n",
        "    nablaD = Ph*h - v*invPv\n",
        "    mu = step/(torch.max(torch.abs(nablaD)) + tiny)\n",
        "    #d = d - mu*d*nablaD\n",
        "    d.sub_(mu*d*nablaD)\n",
        "    \n",
        "    # update either U or V, not both at the same time\n",
        "    a, b = Qh, invQtv\n",
        "    if torch.rand([]) < 0.5:\n",
        "        # nablaU = Qh.mm(Qh.t().mm(V)) - invQtv.mm(invQtv.t().mm(V))\n",
        "        # mu = step/(norm_UVt(nablaU, V) + _tiny)\n",
        "        # U = U - mu*(nablaU + nablaU.mm(V.t().mm(U)))\n",
        "        atV = a.t().mm(V)\n",
        "        atVVt = atV.mm(V.t())\n",
        "        btV = b.t().mm(V)\n",
        "        btVVt = btV.mm(V.t())\n",
        "        norm = torch.sqrt(torch.abs( (a.t().mm(a))*(atVVt.mm(atVVt.t())) # abs to avoid sqrt(-0.0) \n",
        "                                    +(b.t().mm(b))*(btVVt.mm(btVVt.t())) \n",
        "                                  -2*(a.t().mm(b))*(atVVt.mm(btVVt.t())) ))\n",
        "        mu = step/(norm + tiny)\n",
        "        # U = U - mu*( a.mm(atV.mm(IpVtU)) \n",
        "        #             -b.mm(btV.mm(IpVtU)) )\n",
        "        U.sub_(mu*( a.mm(atV.mm(IpVtU)) \n",
        "                   -b.mm(btV.mm(IpVtU)) ))\n",
        "    else:\n",
        "        # nablaV = Qh.mm(Qh.t().mm(U)) - invQtv.mm(invQtv.t().mm(U))\n",
        "        # mu = step/(norm_UVt(U, nablaV) + _tiny)\n",
        "        # V = V - mu*(nablaV + V.mm(U.t().mm(nablaV)))\n",
        "        atU = a.t().mm(U)\n",
        "        btU = b.t().mm(U)\n",
        "        UUta = U.mm(atU.t())\n",
        "        UUtb = U.mm(btU.t())\n",
        "        norm = torch.sqrt(torch.abs( (UUta.t().mm(UUta))*(a.t().mm(a)) # abs to avoid sqrt(-0.0)\n",
        "                                    +(UUtb.t().mm(UUtb))*(b.t().mm(b))\n",
        "                                  -2*(UUta.t().mm(UUtb))*(a.t().mm(b)) ))\n",
        "        mu = step/(norm + tiny)\n",
        "        # V = V - mu*( (a + V.mm(atU.t())).mm(atU) \n",
        "        #             -(b + V.mm(btU.t())).mm(btU) )\n",
        "        V.sub_(mu*( (a + V.mm(atU.t())).mm(atU) \n",
        "                   -(b + V.mm(btU.t())).mm(btU) ))\n",
        "\n",
        "    # U, V and d have been updated in place\n",
        "\n",
        "#@torch.jit.script\n",
        "def precond_grad_UVd_math(U, V, d, g):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    Returns the preconditioned gradient P*g, where P = Q'*Q and Q = (I + U*V')*diag(d).\n",
        "\n",
        "    Expanded, P*g = diag(d)*(I + V*U')*(I + U*V')*diag(d)*g. The factors are applied right to\n",
        "    left with thin matrix products only, so no n-by-n matrix is formed. g may be a column\n",
        "    vector or a matrix with as many rows as U and V.\n",
        "    \"\"\"\n",
        "    dg = d*g\n",
        "    Qg = dg + U.mm(V.t().mm(dg))          # (I + U*V')*diag(d)*g\n",
        "    return d*(Qg + V.mm(U.t().mm(Qg)))    # diag(d)*(I + V*U')*Qg\n",
        "\n",
        "\n",
        "class UVd:\n",
        "    \"\"\"\n",
        "    Implements the low-rank approximation (UVd) preconditioner, Q = (I + U*V')*diag(d), as a class.\n",
        "\n",
        "    Args for initialization:\n",
        "        params_with_grad: a tensor, or an iterable of tensors, to optimize; only those with requires_grad=True are kept;\n",
        "        rank_of_approximation: rank of approximation, i.e., rank of U or V;\n",
        "        preconditioner_init_scale: initial scale of Q, or roughly, Q = preconditioner_init_scale*eye();\n",
        "        lr_params: normalized learning rate for parameters in range [0, 1];\n",
        "        lr_preconditioner: normalized learning rate for preconditioner in range [0, 1];\n",
        "        momentum: momentum factor in range [0,1); values outside (0, 1) are treated as 0, i.e., no momentum;\n",
        "        grad_clip_max_norm: maximum allowable norm of the preconditioned gradient after clipping, None for no clipping;\n",
        "        preconditioner_update_probability: probability on updating Q, 1 for updating at every step, and 0 for never;\n",
        "        exact_hessian_vector_product: True for exact Hessian-vector product via 2nd derivative,\n",
        "                                    and False for approximate one via finite-difference formulae.\n",
        "\n",
        "    Notes:\n",
        "        Note 1: The Hessian-vector product can be approximated using the finite-difference formulae by setting \n",
        "        exact_hessian_vector_product = False when the 2nd derivatives are not available.\n",
        "        In this case, make sure that the closure produces the same outputs given the same inputs, \n",
        "        except for numerical errors due to non-deterministic behaviors.\n",
        "        Random numbers, if any, used inside the closure should be generated starting from the same state, where the rng state can be\n",
        "        read and set by, e.g., `torch.cuda.get_rng_state' and `torch.cuda.set_rng_state', respectively.\n",
        "        \n",
        "        Note 2: Momentum here is the moving average of gradient so that its setting is decoupled from the learning rate.\n",
        "        This is necessary as the learning rate in PSGD is normalized. \n",
        "\n",
        "        Note 3: `torch.linalg.solve' is called twice in function `update_precond_UVd_math_'.\n",
        "        Certain solvers could be orders of magnitude faster than others, especially for small matrices (see the pdf file).\n",
        "        Consider replacing it with a faster one if the default solver is too slow.\n",
        "\n",
        "        Note 4: Currently, no support of sparse and mixed-precision gradients. \n",
        "        Half precision is supported except that torch.linalg.solve (v1.12) requires casting float16 to float32.    \n",
        "        \n",
        "        Note 5: lr_params, lr_preconditioner, momentum, grad_clip_max_norm, preconditioner_update_probability, and \n",
        "        exact_hessian_vector_product (bool) all can be reset on the fly. \n",
        "    \"\"\"\n",
        "    def __init__(self,  params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale=1.0,\n",
        "                        lr_params=0.01, lr_preconditioner=0.01, momentum=0.0,\n",
        "                        grad_clip_max_norm=None, preconditioner_update_probability=1.0,\n",
        "                        exact_hessian_vector_product:bool=True):\n",
        "        # mutable members\n",
        "        self.lr_params = lr_params\n",
        "        self.lr_preconditioner = lr_preconditioner\n",
        "        # momentum outside the open interval (0, 1) disables momentum\n",
        "        self.momentum = momentum if (0<momentum<1) else 0.0\n",
        "        self.grad_clip_max_norm = grad_clip_max_norm\n",
        "        self.preconditioner_update_probability = preconditioner_update_probability\n",
        "        self.exact_hessian_vector_product = exact_hessian_vector_product\n",
        "        # protected members\n",
        "        params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad\n",
        "        self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag\n",
        "        # preconditioner state takes the dtype/device of the first trainable tensor (IndexError if there is none)\n",
        "        dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device\n",
        "        # tiny guards divisions against zero; sqrt(eps) is the size of the finite-difference perturbation\n",
        "        self._tiny = torch.finfo(dtype).tiny\n",
        "        self._delta_param_scale = torch.finfo(dtype).eps**0.5\n",
        "        # all parameters are handled as one flattened vector; _param_cumsizes[k] is the end offset of parameter k\n",
        "        self._param_sizes = [torch.numel(param) for param in self._params_with_grad]\n",
        "        self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)\n",
        "        num_params = self._param_cumsizes[-1]\n",
        "        # U, V: (num_params, rank) with entries of std 1/sqrt(num_params*rank), so U*V' starts small and\n",
        "        # Q is roughly preconditioner_init_scale*I; d: (num_params, 1) column of preconditioner_init_scale\n",
        "        self._U = torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device) / (num_params*rank_of_approximation)**0.5\n",
        "        self._V = torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device) / (num_params*rank_of_approximation)**0.5\n",
        "        self._d = torch.ones( num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale\n",
        "        self._m = None # momentum buffer \n",
        "\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def step(self, closure):\n",
        "        \"\"\"\n",
        "        Performs a single step of PSGD with low-rank approximation (UVd) preconditioner, i.e., \n",
        "        updating the trainable parameters once, and returning what closure returns.\n",
        "\n",
        "        Args:\n",
        "            closure (callable): a closure that evaluates the function of self._params_with_grad,\n",
        "                                and returns the loss, or an iterable with the first one being loss.\n",
        "                                Random numbers, if any, used inside the closure should be generated starting \n",
        "                                from the same rng state if self.exact_hessian_vector_product = False; otherwise doesn't matter. \n",
        "\n",
        "        Returns:\n",
        "            Whatever the first (unperturbed) closure call returns. In finite-difference mode the closure is\n",
        "            called a second time at perturbed parameters, but that result only feeds the Hessian-vector product.\n",
        "        \"\"\"\n",
        "        if torch.rand([]) < self.preconditioner_update_probability:\n",
        "            # evaluates gradients, Hessian-vector product, and updates the preconditioner\n",
        "            if self.exact_hessian_vector_product:\n",
        "                # exact Hessian-vector product\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)\n",
        "                    # random probe vectors; Hvs = H*vs via a second backward pass through grads\n",
        "                    vs = [torch.randn_like(param) for param in self._params_with_grad]\n",
        "                    Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)\n",
        "            else:\n",
        "                # approximate Hessian-vector product via finite-difference formulae. Use it with caution.\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "                # perturb the parameters in place by small random vs; the perturbation is subtracted\n",
        "                # again in the parameter update at the end of this method\n",
        "                vs = [self._delta_param_scale * torch.randn_like(param) for param in self._params_with_grad]\n",
        "                [param.add_(v) for (param, v) in zip(self._params_with_grad, vs)]\n",
        "                with torch.enable_grad():\n",
        "                    perturbed_returns = closure()\n",
        "                    perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]\n",
        "                    perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)\n",
        "                # H*vs is approximated by the gradient difference grad(params + vs) - grad(params)\n",
        "                Hvs = [perturbed_g - g for (perturbed_g, g) in zip(perturbed_grads, grads)]\n",
        "            # update preconditioner\n",
        "            # flatten the per-parameter vs/Hvs into single vectors; [:,None] turns them into column vectors\n",
        "            v = torch.cat([torch.flatten(v) for v in vs])\n",
        "            h = torch.cat([torch.flatten(h) for h in Hvs])\n",
        "            if self.exact_hessian_vector_product:\n",
        "                update_precond_UVd_math_(self._U, self._V, self._d,\n",
        "                                         v[:,None], h[:,None], step=self.lr_preconditioner, tiny=self._tiny)\n",
        "            else: # compensate the levels of v and h; helpful to reduce numerical errors in half-precision training\n",
        "                update_precond_UVd_math_(self._U, self._V, self._d,\n",
        "                                         v[:,None]/self._delta_param_scale, h[:,None]/self._delta_param_scale, step=self.lr_preconditioner, tiny=self._tiny)\n",
        "        else:\n",
        "            # only evaluates the gradients\n",
        "            with torch.enable_grad():\n",
        "                closure_returns = closure()\n",
        "                loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "            vs = None # no vs and Hvs\n",
        "\n",
        "        # preconditioned gradients; momentum is optional\n",
        "        grad = torch.cat([torch.flatten(g) for g in grads])\n",
        "        if self.momentum > 0:\n",
        "            # exponential moving average of the gradient, started at (1 - momentum)*grad\n",
        "            if self._m is None:\n",
        "                self._m = (1 - self.momentum)*grad\n",
        "            else:\n",
        "                self._m.mul_(self.momentum).add_((1 - self.momentum)*grad)\n",
        "            pre_grad = precond_grad_UVd_math(self._U, self._V, self._d, self._m[:, None])\n",
        "        else:\n",
        "            self._m = None # clean the buffer when momentum is set to zero \n",
        "            pre_grad = precond_grad_UVd_math(self._U, self._V, self._d, grad[:, None])\n",
        "            \n",
        "        # gradient clipping is optional\n",
        "        if self.grad_clip_max_norm is None:\n",
        "            lr = self.lr_params\n",
        "        else:\n",
        "            # shrink the step so that lr*||pre_grad|| never exceeds lr_params*grad_clip_max_norm\n",
        "            grad_norm = torch.linalg.vector_norm(pre_grad) + self._tiny\n",
        "            lr = self.lr_params * min(self.grad_clip_max_norm/grad_norm, 1.0)\n",
        "            \n",
        "        # update the parameters\n",
        "        # pre_grad[j - i:j] is the slice of the flattened vector for a parameter of size i ending at offset j;\n",
        "        # the list comprehensions are used only for their in-place side effects\n",
        "        if self.exact_hessian_vector_product or (vs is None):\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param))\n",
        "             for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]\n",
        "        else: # in this case, do not forget to remove the perturbation on parameters\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param) + v)\n",
        "             for (param, i, j, v) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes, vs)]\n",
        "        # return whatever closure returns\n",
        "        return closure_returns\n",
        "\n",
        "################## end of UVd preconditioner #################################\n",
        "\n",
        "\n",
        "##############################################################################\n",
        "# An Xmat (X-matrix) preconditioner is defined by\n",
        "#\n",
        "#   Q = diag(a) + adiag(b)\n",
        "#\n",
        "# where adiag(b) is the anti-diagonal matrix carrying b on its anti-diagonal, i.e., Q[i, n-1-i] = b[i].\n",
        "# It's slightly more complicated than a diagonal preconditioner, but performs better.\n",
        "#\n",
        "\n",
        "#@torch.jit.script\n",
        "def update_precond_Xmat_math_(a, b, v, h, step, tiny):\n",
        "    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> None\n",
        "    \"\"\"\n",
        "    Update preconditioner Q = diag(a) + adiag(b) with (vector, Hessian-vector product) = (v, h).\n",
        "    State variables a and b are updated inplace.\n",
        "    \"\"\"\n",
        "    Qh = a*h + b*torch.flip(h, [0])\n",
        "    aflip, bflip = torch.flip(a, [0]), torch.flip(b, [0])\n",
        "    invQtv = (aflip*v - bflip*torch.flip(v, [0]))/(a*aflip - b*bflip)\n",
        "    nablaA = Qh*Qh - invQtv*invQtv\n",
        "    nablaB = Qh*torch.flip(Qh, [0]) - invQtv*torch.flip(invQtv, [0])\n",
        "    q, r = divmod(len(nablaB), 2)\n",
        "    if r == 1:\n",
        "        nablaB[q] = 0\n",
        "\n",
        "    mu = step/(torch.maximum(torch.max(torch.abs(nablaA)), torch.max(torch.abs(nablaB))) + tiny)\n",
        "    a.sub_(mu*(nablaA*a + nablaB*bflip))\n",
        "    b.sub_(mu*(nablaA*b + nablaB*aflip))\n",
        "\n",
        "#@torch.jit.script\n",
        "def precond_grad_Xmat_math(a, b, g):\n",
        "    # type: (Tensor, Tensor, Tensor) -> Tensor\n",
        "    \"\"\"\n",
        "    Preconditioning gradient g with Q = diag(a) + adiag(b).\n",
        "    \"\"\"\n",
        "    ab = a * b\n",
        "    return (a*a + torch.flip(b*b, [0]))*g + (ab + torch.flip(ab, [0]))*torch.flip(g, [0])\n",
        "\n",
        "from torch.optim.optimizer import Optimizer\n",
        "class XMat(Optimizer):\n",
        "    \"\"\"\n",
        "    Implements the Xmat preconditioner, Q = diag(a) + adiag(b), as a class.\n",
        "    Args for initialization:\n",
        "        params_with_grad: a list of parameters or variables requiring gradients;\n",
        "        preconditioner_init_scale: initial scale of Q, i.e., Q = preconditioner_init_scale*eye();\n",
        "        lr_params: normalized learning rate for parameters in range [0, 1];\n",
        "        lr_preconditioner: normalized learning rate for preconditioner in range [0, 1];\n",
        "        momentum: momentum factor in range [0,1);\n",
        "        grad_clip_max_norm: maximum allowable gradient norm after clipping, None for no clipping;\n",
        "        preconditioner_update_probability: probability on updating Q, 1 for updating at every step, and 0 for never, i.e., SGD;\n",
        "        exact_hessian_vector_product: True for exact Hessian-vector product via 2nd derivative,\n",
        "                                    and False for approximate one via finite-difference formulae.\n",
        "    Notes:\n",
        "        Note 1: The Hessian-vector product can be approximated using the finite-difference formulae by setting\n",
        "        exact_hessian_vector_product = False when the 2nd derivatives is not available.\n",
        "        In this case, make sure that the closure produces the same outputs given the same inputs,\n",
        "        except for numerical errors due to non-deterministic behaviors.\n",
        "        Random numbers, if any, used inside the closure should be generated starting from the same state, where the rng state can be\n",
        "        read and set by, e.g., `torch.cuda.get_rng_state' and `torch.cuda.set_rng_state', respectively.\n",
        "        \n",
        "        Note 2: Momentum here is the moving average of gradient so that its setting is decoupled from the learning rate.\n",
        "        This is necessary as the learning rate in PSGD is normalized.\n",
        "\n",
        "        Note 3: Currently, no support of sparse and mixed-precision gradients.\n",
        "\n",
        "        Note 4: lr_params, lr_preconditioner, momentum, grad_clip_max_norm, preconditioner_update_probability, and\n",
        "        exact_hessian_vector_product (bool) all can be reset on the fly.\n",
        "    \"\"\"\n",
        "    def __init__(self, params_with_grad, preconditioner_init_scale=1.0,\n",
        "                 lr_params=0.01, lr_preconditioner=0.01, momentum=0.0, \n",
        "                 grad_clip_max_norm=None, preconditioner_update_probability=1.0,\n",
        "                 exact_hessian_vector_product: bool = True):\n",
        "        # mutable members\n",
        "        self.lr_params = lr_params\n",
        "        self.lr_preconditioner = lr_preconditioner\n",
        "        self.momentum = momentum if (0<momentum<1) else 0.0\n",
        "        self.grad_clip_max_norm = grad_clip_max_norm\n",
        "        self.preconditioner_update_probability = preconditioner_update_probability\n",
        "        self.exact_hessian_vector_product = exact_hessian_vector_product\n",
        "        # protected members\n",
        "        params_with_grad = [params_with_grad, ] if isinstance(params_with_grad, torch.Tensor) else params_with_grad\n",
        "        self._params_with_grad = [param for param in params_with_grad if param.requires_grad]  # double check requires_grad flag\n",
        "        dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device\n",
        "        self._tiny = torch.finfo(dtype).tiny\n",
        "        self._delta_param_scale = torch.finfo(dtype).eps ** 0.5\n",
        "        self._param_sizes = [torch.numel(param) for param in self._params_with_grad]\n",
        "        self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)\n",
        "        num_params = self._param_cumsizes[-1]\n",
        "        self._a = torch.ones(num_params, dtype=dtype, device=device)*preconditioner_init_scale\n",
        "        self._b = torch.zeros(num_params, dtype=dtype, device=device)\n",
        "        self._m = None # buffer for momentum \n",
        "        defaults = dict(lr=lr_params)\n",
        "        super(XMat, self).__init__(self._params_with_grad, defaults)        \n",
        "\n",
        "    @torch.no_grad()\n",
        "    def step(self, closure):\n",
        "        \"\"\"\n",
        "        Performs a single step of PSGD with Xmat preconditioner, i.e.,\n",
        "        updating the trainable parameters once, and returning what closure returns.\n",
        "        Args:\n",
        "            closure (callable): a closure that evaluates the function of self._params_with_grad,\n",
        "                                and returns the loss, or an iterable with the first one being loss.\n",
        "                                Random numbers, if any, used inside the closure should be generated starting\n",
        "                                from the same rng state if self.exact_hessian_vector_product = False; otherwise doesn't matter.\n",
        "        \"\"\"\n",
        "        if torch.rand([]) < self.preconditioner_update_probability:\n",
        "            # evaluates gradients, Hessian-vector product, and updates the preconditioner\n",
        "            if self.exact_hessian_vector_product:\n",
        "                # exact Hessian-vector product\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)\n",
        "                    vs = [torch.randn_like(param) for param in self._params_with_grad]\n",
        "                    Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)\n",
        "            else:\n",
        "                # approximate Hessian-vector product via finite-difference formulae. Use it with cautions.\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "                vs = [self._delta_param_scale * torch.randn_like(param) for param in self._params_with_grad]\n",
        "                [param.add_(v) for (param, v) in zip(self._params_with_grad, vs)]\n",
        "                with torch.enable_grad():\n",
        "                    perturbed_returns = closure()\n",
        "                    perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]\n",
        "                    perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)\n",
        "                Hvs = [perturbed_g - g for (perturbed_g, g) in zip(perturbed_grads, grads)]\n",
        "            # update preconditioner\n",
        "            v = torch.cat([torch.flatten(v) for v in vs])\n",
        "            h = torch.cat([torch.flatten(h) for h in Hvs])\n",
        "            if self.exact_hessian_vector_product:\n",
        "                update_precond_Xmat_math_(self._a, self._b,\n",
        "                                         v, h, step=self.lr_preconditioner, tiny=self._tiny)\n",
        "            else:  # compensate the levels of v and h; helpful to reduce numerical errors in half-precision training\n",
        "                update_precond_Xmat_math_(self._a, self._b,\n",
        "                                         v/self._delta_param_scale, h/self._delta_param_scale,\n",
        "                                         step=self.lr_preconditioner, tiny=self._tiny)\n",
        "        else:\n",
        "            # only evaluates the gradients\n",
        "            with torch.enable_grad():\n",
        "                closure_returns = closure()\n",
        "                loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "            vs = None  # no vs and Hvs\n",
        "\n",
        "        # preconditioned gradients; momentum is optional        \n",
        "        grad = torch.cat([torch.flatten(g) for g in grads])\n",
        "        if self.momentum > 0:\n",
        "            if self._m is None:\n",
        "                self._m = (1 - self.momentum)*grad\n",
        "            else:\n",
        "                self._m.mul_(self.momentum).add_((1 - self.momentum)*grad)\n",
        "            pre_grad = precond_grad_Xmat_math(self._a, self._b, self._m)\n",
        "        else:\n",
        "            self._m = None # clean the buffer when momentum is set to zero again \n",
        "            pre_grad = precond_grad_Xmat_math(self._a, self._b, grad)\n",
        "        \n",
        "        # gradient clipping is optional\n",
        "        if self.grad_clip_max_norm is None:\n",
        "            lr = self.lr_params\n",
        "        else:\n",
        "            grad_norm = torch.linalg.vector_norm(pre_grad) + self._tiny\n",
        "            lr = self.lr_params * min(self.grad_clip_max_norm / grad_norm, 1.0)\n",
        "\n",
        "        # update the parameters\n",
        "        if self.exact_hessian_vector_product or (vs is None):\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param))\n",
        "             for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]\n",
        "        else:  # in this case, do not forget to remove the perturbation on parameters\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param) + v)\n",
        "             for (param, i, j, v) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes, vs)]\n",
        "        # return whatever closure returns\n",
        "        return closure_returns\n",
        "\n",
        "################## end of Xmat preconditioner #################################\n",
        "\n",
        "\n",
        "###############################################################################\n",
        "# The classic Newton–Raphson type preconditioner.\n",
        "# It stores a dense num_params-by-num_params matrix Q, so it is applicable only to small-scale problems.\n",
        "#\n",
        "\n",
        "# @torch.jit.script\n",
        "def update_precond_newton_math_(Q, v, h, step, tiny):\n",
        "    # type: (Tensor, Tensor, Tensor, float, float) -> None\n",
        "    \"\"\"\n",
        "    Update the classic Newton–Raphson type preconditioner P = Q'*Q with (v, h).\n",
        "    \"\"\"\n",
        "    a = Q.mm(h)\n",
        "    b = torch.linalg.solve_triangular(Q.t(), v, upper=False)\n",
        "    grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))\n",
        "    mu = step/(grad.abs().max() + tiny)      \n",
        "    Q.sub_(mu*grad.mm(Q))\n",
        "\n",
        "class Newton:\n",
        "    \"\"\"\n",
        "    Implements the classic Newton–Raphson type preconditioner for SGD as a class.\n",
        "    Args for initialization:\n",
        "        params_with_grad: a list of parameters or variables requiring gradients;\n",
        "        preconditioner_init_scale: initial scale of Q, i.e., Q = preconditioner_init_scale*eye();\n",
        "        lr_params: normalized learning rate for parameters in range [0, 1];\n",
        "        lr_preconditioner: normalized learning rate for preconditioner in range [0, 1];\n",
        "        momentum: momentum factor in range [0,1);\n",
        "        grad_clip_max_norm: maximum allowable gradient norm after clipping, None for no clipping;\n",
        "        preconditioner_update_probability: probability on updating Q, 1 for updating at every step, and 0 for never, i.e., SGD;\n",
        "        exact_hessian_vector_product: True for exact Hessian-vector product via 2nd derivative,\n",
        "                                    and False for approximate one via finite-difference formulae.\n",
        "    Notes:\n",
        "        Note 1: The Hessian-vector product can be approximated using the finite-difference formulae by setting\n",
        "        exact_hessian_vector_product = False when the 2nd derivatives is not available.\n",
        "        In this case, make sure that the closure produces the same outputs given the same inputs,\n",
        "        except for numerical errors due to non-deterministic behaviors.\n",
        "        Random numbers, if any, used inside the closure should be generated starting from the same state, where the rng state can be\n",
        "        read and set by, e.g., `torch.cuda.get_rng_state' and `torch.cuda.set_rng_state', respectively.\n",
        "        \n",
        "        Note 2: Momentum here is the moving average of gradient so that its setting is decoupled from the learning rate.\n",
        "        This is necessary as the learning rate in PSGD is normalized.\n",
        "        Note 3: Currently, no support of sparse and mixed-precision gradients.\n",
        "        Note 4: lr_params, lr_preconditioner, momentum, grad_clip_max_norm, preconditioner_update_probability, and\n",
        "        exact_hessian_vector_product (bool) all can be reset on the fly.\n",
        "    \"\"\"\n",
        "    def __init__(self, params_with_grad, preconditioner_init_scale=1.0,\n",
        "                 lr_params=0.01, lr_preconditioner=0.01, momentum=0.0, \n",
        "                 grad_clip_max_norm=None, preconditioner_update_probability=1.0,\n",
        "                 exact_hessian_vector_product: bool = True):\n",
        "        # Hyper-parameters are plain public attributes so they can be changed between steps.\n",
        "        self.lr_params = lr_params\n",
        "        self.lr_preconditioner = lr_preconditioner\n",
        "        # A momentum outside the open interval (0, 1) disables momentum.\n",
        "        self.momentum = momentum if 0 < momentum < 1 else 0.0\n",
        "        self.grad_clip_max_norm = grad_clip_max_norm\n",
        "        self.preconditioner_update_probability = preconditioner_update_probability\n",
        "        self.exact_hessian_vector_product = exact_hessian_vector_product\n",
        "\n",
        "        # A single tensor is accepted as shorthand for a one-element list.\n",
        "        if isinstance(params_with_grad, torch.Tensor):\n",
        "            params_with_grad = [params_with_grad]\n",
        "        # Only tensors flagged requires_grad are optimized.\n",
        "        self._params_with_grad = [p for p in params_with_grad if p.requires_grad]\n",
        "        # The first trainable tensor decides the dtype/device of the preconditioner.\n",
        "        first_param = self._params_with_grad[0]\n",
        "        dtype, device = first_param.dtype, first_param.device\n",
        "        float_info = torch.finfo(dtype)\n",
        "        self._tiny = float_info.tiny\n",
        "        # Finite-difference perturbation scale: sqrt(machine epsilon) of dtype.\n",
        "        self._delta_param_scale = float_info.eps ** 0.5\n",
        "        # Flat layout: parameter k occupies [cumsizes[k] - sizes[k], cumsizes[k]).\n",
        "        self._param_sizes = [p.numel() for p in self._params_with_grad]\n",
        "        self._param_cumsizes = torch.tensor(self._param_sizes).cumsum(0)\n",
        "        total_numel = self._param_cumsizes[-1]\n",
        "        # Dense total_numel x total_numel preconditioner factor, started as a scaled identity.\n",
        "        self._Q = preconditioner_init_scale * torch.eye(total_numel, dtype=dtype, device=device)\n",
        "        self._m = None  # momentum buffer; created lazily on the first momentum step\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def step(self, closure):\n",
        "        \"\"\"\n",
        "        Performs a single step of PSGD with Newton–Raphson preconditioner, i.e.,\n",
        "        updating the trainable parameters once, and returning what closure returns.\n",
        "        Args:\n",
        "            closure (callable): a closure that evaluates the function of self._params_with_grad,\n",
        "                                and returns the loss, or an iterable with the first one being loss.\n",
        "                                Random numbers, if any, used inside the closure should be generated starting\n",
        "                                from the same rng state if self.exact_hessian_vector_product = False; otherwise doesn't matter.\n",
        "        Returns:\n",
        "            Whatever the first (unperturbed) call to closure returned.\n",
        "        Notes:\n",
        "            The closure is called once per step, except when the preconditioner is updated with\n",
        "            exact_hessian_vector_product = False: then it is called a second time at perturbed\n",
        "            parameters, and both results are differentiated. The closure must therefore recompute\n",
        "            the loss on every call; returning a loss tensor computed outside it is only safe in exact mode.\n",
        "            Gradients are obtained with torch.autograd.grad and are never written to param.grad.\n",
        "        \"\"\"\n",
        "        # torch.rand([]) is a scalar in [0, 1): Q is refreshed with probability preconditioner_update_probability.\n",
        "        if torch.rand([]) < self.preconditioner_update_probability:\n",
        "            # evaluates gradients, Hessian-vector product, and updates the preconditioner\n",
        "            if self.exact_hessian_vector_product:\n",
        "                # exact Hessian-vector product\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    # create_graph=True keeps the graph of grads so it can be differentiated again\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)\n",
        "                    # random probe vectors v, one per trainable tensor\n",
        "                    vs = [torch.randn_like(param) for param in self._params_with_grad]\n",
        "                    # differentiating grads with grad_outputs=v gives H @ v (the Hessian is symmetric)\n",
        "                    Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)\n",
        "            else:\n",
        "                # approximate Hessian-vector product via finite-difference formulae. Use it with cautions.\n",
        "                with torch.enable_grad():\n",
        "                    closure_returns = closure()\n",
        "                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                    grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "                # small random perturbation: standard normal noise scaled by sqrt(eps)\n",
        "                vs = [self._delta_param_scale * torch.randn_like(param) for param in self._params_with_grad]\n",
        "                # perturb the parameters in place; the final update below subtracts v again\n",
        "                [param.add_(v) for (param, v) in zip(self._params_with_grad, vs)]\n",
        "                with torch.enable_grad():\n",
        "                    perturbed_returns = closure()\n",
        "                    perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]\n",
        "                    perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)\n",
        "                # g(theta + v) - g(theta) approximates H @ v\n",
        "                Hvs = [perturbed_g - g for (perturbed_g, g) in zip(perturbed_grads, grads)]\n",
        "            # update preconditioner\n",
        "            # flatten the per-parameter tensors into single vectors in the same order as Q's rows\n",
        "            v = torch.cat([torch.flatten(v) for v in vs])\n",
        "            h = torch.cat([torch.flatten(h) for h in Hvs])\n",
        "            # v[:,None] / h[:,None] are column vectors; update_precond_newton_math_ is defined elsewhere and,\n",
        "            # going by its trailing underscore, presumably updates self._Q in place -- confirm.\n",
        "            if self.exact_hessian_vector_product:\n",
        "                update_precond_newton_math_(self._Q,\n",
        "                                            v[:,None], h[:,None], step=self.lr_preconditioner, tiny=self._tiny)\n",
        "            else:  # compensate the levels of v and h; helpful to reduce numerical errors in half-precision training\n",
        "                update_precond_newton_math_(self._Q,\n",
        "                                            v[:,None]/self._delta_param_scale, h[:,None]/self._delta_param_scale,\n",
        "                                            step=self.lr_preconditioner, tiny=self._tiny)\n",
        "        else:\n",
        "            # only evaluates the gradients\n",
        "            with torch.enable_grad():\n",
        "                closure_returns = closure()\n",
        "                loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]\n",
        "                grads = torch.autograd.grad(loss, self._params_with_grad)\n",
        "            vs = None  # no vs and Hvs\n",
        "\n",
        "        # preconditioned gradients; momentum is optional        \n",
        "        # pre_grad = Q^T Q g, where g is the raw gradient or its momentum average m\n",
        "        grad = torch.cat([torch.flatten(g) for g in grads])\n",
        "        if self.momentum > 0:\n",
        "            # exponential moving average: m <- momentum * m + (1 - momentum) * grad\n",
        "            if self._m is None:\n",
        "                self._m = (1 - self.momentum)*grad\n",
        "            else:\n",
        "                self._m.mul_(self.momentum).add_((1 - self.momentum)*grad)\n",
        "            pre_grad = self._Q.t() @ (self._Q @ self._m)\n",
        "        else:\n",
        "            self._m = None # clean the buffer when momentum is set to zero again \n",
        "            pre_grad = self._Q.t() @ (self._Q @ grad)\n",
        "        \n",
        "        # gradient clipping is optional\n",
        "        if self.grad_clip_max_norm is None:\n",
        "            lr = self.lr_params\n",
        "        else:\n",
        "            # scale lr down so that ||lr * pre_grad|| never exceeds lr_params * grad_clip_max_norm\n",
        "            grad_norm = torch.linalg.vector_norm(pre_grad) + self._tiny\n",
        "            lr = self.lr_params * min(self.grad_clip_max_norm / grad_norm, 1.0)\n",
        "\n",
        "        # update the parameters\n",
        "        # parameter k owns the slice [cumsize_k - size_k, cumsize_k) of the flat pre_grad\n",
        "        if self.exact_hessian_vector_product or (vs is None):\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param))\n",
        "             for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]\n",
        "        else:  # in this case, do not forget to remove the perturbation on parameters\n",
        "            [param.subtract_(lr * pre_grad[j - i:j].view_as(param) + v)\n",
        "             for (param, i, j, v) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes, vs)]\n",
        "        # return whatever closure returns\n",
        "        return closure_returns\n",
        "\n",
        "################## end of Newton–Raphson preconditioner #################################"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHpYeQrOLcmZ"
      },
      "source": [
        "### Define Parameters and Helpers, and start Training\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LNLegt4jLkRD"
      },
      "outputs": [],
      "source": [
        "# Result Visualization Helper\n",
        "import math\n",
        "from matplotlib import pyplot as plt\n",
        "from torchvision import transforms\n",
        "\n",
        "M, N = 4, 6  # the result figure is an M x N grid of test images\n",
        "RESULT_IMG_PATH = '/tmp/test_result.png'\n",
        "\n",
        "def plot_results(images, labels, preds):\n",
        "  \"\"\"Plot up to M*N test images titled with their true and predicted labels.\n",
        "\n",
        "  Correct predictions get a blue check mark; wrong ones get a red\n",
        "  'X label/prediction' title. The figure is saved to RESULT_IMG_PATH.\n",
        "\n",
        "  Args:\n",
        "    images: batch of images normalized with mean 0.1307 / std 0.3081, indexed\n",
        "      along dim 0; each image is presumably [1, H, W] -- TODO confirm.\n",
        "    labels: ground-truth labels, one element per image (must support .item()).\n",
        "    preds: predicted labels, one element per image (must support .item()).\n",
        "  \"\"\"\n",
        "  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]\n",
        "  # Inverse of Normalize((0.1307,), (0.3081,)): maps x back to x * 0.3081 + 0.1307.\n",
        "  inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))\n",
        "\n",
        "  num_images = images.shape[0]\n",
        "  fig, axes = plt.subplots(M, N, figsize=(11, 9))\n",
        "  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')\n",
        "\n",
        "  for i, ax in enumerate(fig.axes):\n",
        "    ax.axis('off')\n",
        "    if i >= num_images:\n",
        "      continue  # fewer images than grid cells: leave the remaining axes blank\n",
        "    img, label, prediction = images[i], labels[i], preds[i]\n",
        "    # imshow needs host memory, so bring the de-normalized image to the CPU\n",
        "    # (a no-op when it is already there).\n",
        "    img = inv_norm(img).cpu()\n",
        "    img = img.squeeze() # [1,Y,X] -> [Y,X]\n",
        "    label, prediction = label.item(), prediction.item()\n",
        "    if label == prediction:\n",
        "      ax.set_title(u'\\u2713', color='blue', fontsize=22)\n",
        "    else:\n",
        "      ax.set_title(\n",
        "          'X {}/{}'.format(label, prediction), color='red')\n",
        "    ax.imshow(img)\n",
        "  plt.savefig(RESULT_IMG_PATH, transparent=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kNh-oEmHmorI"
      },
      "outputs": [],
      "source": [
        "# Define Parameters (read by train_mnist and its helpers)\n",
        "FLAGS = dict(\n",
        "    datadir='/tmp/mnist',    # MNIST download / cache directory\n",
        "    batch_size=128,          # batch size of each process's DataLoader\n",
        "    num_workers=2,           # DataLoader worker processes\n",
        "    learning_rate=0.1,       # base learning rate, multiplied by the world size\n",
        "    momentum=0.5,            # only used by the commented-out SGD optimizer\n",
        "    num_epochs=10,\n",
        "    num_cores=8,\n",
        "    log_steps=20,            # log training progress every log_steps batches\n",
        "    metrics_debug=False,     # print the XLA metrics report after each epoch\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pTmxZL5ymp8P"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "\n",
        "SERIAL_EXEC = xmp.MpSerialExecutor()\n",
        "\n",
        "class MNIST(nn.Module):\n",
        "  \"\"\"Small CNN for 1x28x28 digits that returns log-probabilities over 10 classes.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(MNIST, self).__init__()\n",
        "    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
        "    self.bn1 = nn.BatchNorm2d(10)\n",
        "    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
        "    self.bn2 = nn.BatchNorm2d(20)\n",
        "    # 320 = 20 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4.\n",
        "    self.fc1 = nn.Linear(320, 50)\n",
        "    self.fc2 = nn.Linear(50, 10)\n",
        "\n",
        "  def forward(self, x):\n",
        "    # Each stage: 5x5 conv -> 2x2 max-pool -> ReLU -> batch norm.\n",
        "    x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
        "    x = self.bn1(x)\n",
        "    x = F.relu(F.max_pool2d(self.conv2(x), 2))\n",
        "    x = self.bn2(x)\n",
        "    x = torch.flatten(x, 1)\n",
        "    x = F.relu(self.fc1(x))\n",
        "    x = self.fc2(x)\n",
        "    return F.log_softmax(x, dim=1)\n",
        "\n",
        "# Only instantiate model weights once in memory.\n",
        "WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())\n",
        "\n",
        "def train_mnist():\n",
        "  \"\"\"Train and evaluate the MNIST model on this process's XLA device.\n",
        "\n",
        "  Intended to run in every replica process: the training set is sharded with\n",
        "  xm.get_ordinal() / xm.xrt_world_size(), and each replica evaluates on the\n",
        "  full test set after every epoch.\n",
        "\n",
        "  Returns:\n",
        "    (accuracy, data, pred, target): the final test accuracy in percent, plus the\n",
        "    last evaluated test batch, its predictions and its labels.\n",
        "  \"\"\"\n",
        "  torch.manual_seed(1)\n",
        "  \n",
        "  def get_dataset():\n",
        "    norm = transforms.Normalize((0.1307,), (0.3081,))\n",
        "    train_dataset = datasets.MNIST(\n",
        "        FLAGS['datadir'],\n",
        "        train=True,\n",
        "        download=True,\n",
        "        transform=transforms.Compose(\n",
        "            [transforms.ToTensor(), norm]))\n",
        "    test_dataset = datasets.MNIST(\n",
        "        FLAGS['datadir'],\n",
        "        train=False,\n",
        "        download=True,\n",
        "        transform=transforms.Compose(\n",
        "            [transforms.ToTensor(), norm]))\n",
        "    \n",
        "    return train_dataset, test_dataset\n",
        "  \n",
        "  # Using the serial executor avoids multiple processes to\n",
        "  # download the same data.\n",
        "  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)\n",
        "\n",
        "  train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "    train_dataset,\n",
        "    num_replicas=xm.xrt_world_size(),\n",
        "    rank=xm.get_ordinal(),\n",
        "    shuffle=True)\n",
        "  train_loader = torch.utils.data.DataLoader(\n",
        "      train_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      sampler=train_sampler,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "      test_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      shuffle=False,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "\n",
        "  # Scale learning rate to world size\n",
        "  lr = FLAGS['learning_rate'] * xm.xrt_world_size()\n",
        "\n",
        "  # Get loss function, optimizer, and model\n",
        "  device = xm.xla_device()\n",
        "  model = WRAPPED_MODEL.to(device)\n",
        "  # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])\n",
        "  # The PSGD optimizer XMat has its own momentum (0.9 here); FLAGS['momentum'] only feeds the SGD line above.\n",
        "  optimizer = XMat(model.parameters(),lr_params=lr,momentum=0.9,preconditioner_update_probability=0.1)\n",
        "  loss_fn = nn.NLLLoss()\n",
        "\n",
        "  def train_loop_fn(loader):\n",
        "    tracker = xm.RateTracker()\n",
        "    model.train()\n",
        "    for x, (data, target) in enumerate(loader):\n",
        "      # XMat.step evaluates the loss itself through the closure (twice, the second time\n",
        "      # at perturbed parameters, when exact_hessian_vector_product=False), so the forward\n",
        "      # pass has to live inside the closure rather than being computed once beforehand.\n",
        "      def closure():\n",
        "        return loss_fn(model(data), target)\n",
        "      # With a single TPU core, optimizer.step(closure) followed by xm.mark_step() also works.\n",
        "      # NOTE(review): XMat obtains gradients with torch.autograd.grad and never fills\n",
        "      # param.grad, so the gradient all-reduce in xm.optimizer_step has nothing to average\n",
        "      # and each replica updates its weights independently -- confirm this is intended.\n",
        "      loss = xm.optimizer_step(optimizer, optimizer_args={'closure': closure})\n",
        "      tracker.add(FLAGS['batch_size'])\n",
        "      if x % FLAGS['log_steps'] == 0:\n",
        "        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(\n",
        "            xm.get_ordinal(), x, loss.item(), tracker.rate(),\n",
        "            tracker.global_rate(), time.asctime()), flush=True)\n",
        "\n",
        "  def test_loop_fn(loader):\n",
        "    total_samples = 0\n",
        "    correct = 0\n",
        "    model.eval()\n",
        "    data, pred, target = None, None, None\n",
        "    # Inference only: no autograd graph is needed.\n",
        "    with torch.no_grad():\n",
        "      for data, target in loader:\n",
        "        output = model(data)\n",
        "        pred = output.max(1, keepdim=True)[1]\n",
        "        # Keep the running count on the device and read it back once after the loop,\n",
        "        # instead of forcing a device-to-host transfer on every batch.\n",
        "        correct += pred.eq(target.view_as(pred)).sum()\n",
        "        total_samples += data.size(0)\n",
        "\n",
        "    correct = int(correct)\n",
        "    # max(..., 1) guards against an empty loader (possible with drop_last=True).\n",
        "    accuracy = 100.0 * correct / max(total_samples, 1)\n",
        "    print('[xla:{}] Accuracy={:.2f}%'.format(\n",
        "        xm.get_ordinal(), accuracy), flush=True)\n",
        "    return accuracy, data, pred, target\n",
        "\n",
        "  # Train and eval loops\n",
        "  accuracy = 0.0\n",
        "  data, pred, target = None, None, None\n",
        "  for epoch in range(1, FLAGS['num_epochs'] + 1):\n",
        "    # Reseed the sampler's shuffle; without this every epoch sees the same sample order.\n",
        "    train_sampler.set_epoch(epoch)\n",
        "    para_loader = pl.ParallelLoader(train_loader, [device])\n",
        "    train_loop_fn(para_loader.per_device_loader(device))\n",
        "    xm.master_print(\"Finished training epoch {}\".format(epoch))\n",
        "\n",
        "    para_loader = pl.ParallelLoader(test_loader, [device])\n",
        "    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))\n",
        "    if FLAGS['metrics_debug']:\n",
        "      xm.master_print(met.metrics_report(), flush=True)\n",
        "\n",
        "  return accuracy, data, pred, target"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Afwo4H7kSd8P",
        "outputId": "3711fb7a-213e-49a2-83ed-8ecdbe17ce19"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[xla:0](0) Loss=2.33622 Rate=360.25 GlobalRate=360.23 Time=Tue Mar 21 22:05:38 2023\n",
            "[xla:2](0) Loss=2.36655 Rate=170.40 GlobalRate=170.39 Time=Tue Mar 21 22:05:40 2023\n",
            "[xla:3](0) Loss=2.35576 Rate=77.56 GlobalRate=77.56 Time=Tue Mar 21 22:05:41 2023\n",
            "[xla:4](0) Loss=2.38184 Rate=64.69 GlobalRate=64.69 Time=Tue Mar 21 22:05:42 2023\n",
            "[xla:0](20) Loss=0.36951 Rate=519.97 GlobalRate=605.15 Time=Tue Mar 21 22:05:42 2023\n",
            "[xla:1](0) Loss=2.36546 Rate=44.22 GlobalRate=44.22 Time=Tue Mar 21 22:05:43 2023\n",
            "[xla:7](0) Loss=2.33757 Rate=36.86 GlobalRate=36.86 Time=Tue Mar 21 22:05:44 2023\n",
            "[xla:6](0) Loss=2.39576 Rate=29.96 GlobalRate=29.96 Time=Tue Mar 21 22:05:46 2023\n",
            "[xla:2](20) Loss=0.38797 Rate=241.12 GlobalRate=279.08 Time=Tue Mar 21 22:05:49 2023\n",
            "[xla:5](0) Loss=2.35329 Rate=27.54 GlobalRate=27.54 Time=Tue Mar 21 22:05:50 2023\n",
            "[xla:3](20) Loss=0.40883 Rate=162.62 GlobalRate=201.76 Time=Tue Mar 21 22:05:53 2023\n",
            "[xla:4](20) Loss=0.19045 Rate=148.03 GlobalRate=184.71 Time=Tue Mar 21 22:05:54 2023\n",
            "[xla:0](40) Loss=0.27515 Rate=328.07 GlobalRate=304.52 Time=Tue Mar 21 22:05:55 2023\n",
            "[xla:1](20) Loss=0.20308 Rate=134.13 GlobalRate=167.11 Time=Tue Mar 21 22:05:56 2023\n",
            "[xla:7](20) Loss=0.29599 Rate=132.59 GlobalRate=162.84 Time=Tue Mar 21 22:05:57 2023\n",
            "[xla:6](20) Loss=0.17262 Rate=138.84 GlobalRate=164.09 Time=Tue Mar 21 22:05:59 2023\n",
            "[xla:2](40) Loss=0.19393 Rate=232.29 GlobalRate=250.63 Time=Tue Mar 21 22:06:00 2023\n",
            "[xla:5](20) Loss=0.33367 Rate=154.72 GlobalRate=175.27 Time=Tue Mar 21 22:06:01 2023\n",
            "Finished training epoch 1\n",
            "[xla:3](40) Loss=0.21777 Rate=227.97 GlobalRate=230.68 Time=Tue Mar 21 22:06:02 2023\n",
            "[xla:4](40) Loss=0.08850 Rate=225.24 GlobalRate=220.46 Time=Tue Mar 21 22:06:04 2023\n",
            "[xla:1](40) Loss=0.14228 Rate=240.80 GlobalRate=216.03 Time=Tue Mar 21 22:06:05 2023\n",
            "[xla:7](40) Loss=0.14473 Rate=252.79 GlobalRate=216.89 Time=Tue Mar 21 22:06:05 2023\n",
            "[xla:6](40) Loss=0.18559 Rate=252.49 GlobalRate=217.04 Time=Tue Mar 21 22:06:06 2023\n",
            "[xla:5](40) Loss=0.14162 Rate=246.82 GlobalRate=221.98 Time=Tue Mar 21 22:06:09 2023\n",
            "[xla:0] Accuracy=95.95%\n",
            "[xla:0](0) Loss=0.18850 Rate=43.19 GlobalRate=43.19 Time=Tue Mar 21 22:06:44 2023\n",
            "[xla:2] Accuracy=96.64%\n",
            "[xla:3] Accuracy=93.13%\n",
            "[xla:4] Accuracy=95.04%\n",
            "[xla:2](0) Loss=0.10780 Rate=61.48 GlobalRate=61.48 Time=Tue Mar 21 22:06:49 2023\n",
            "[xla:3](0) Loss=0.24743 Rate=60.59 GlobalRate=60.59 Time=Tue Mar 21 22:06:50 2023\n",
            "[xla:1] Accuracy=94.62%\n",
            "[xla:7] Accuracy=95.84%\n",
            "[xla:4](0) Loss=0.15213 Rate=56.46 GlobalRate=56.46 Time=Tue Mar 21 22:06:51 2023\n",
            "[xla:6] Accuracy=95.15%\n",
            "[xla:0](20) Loss=0.13302 Rate=225.15 GlobalRate=259.64 Time=Tue Mar 21 22:06:52 2023\n",
            "[xla:1](0) Loss=0.21882 Rate=62.83 GlobalRate=62.82 Time=Tue Mar 21 22:06:53 2023\n",
            "[xla:7](0) Loss=0.21765 Rate=46.04 GlobalRate=46.04 Time=Tue Mar 21 22:06:53 2023\n",
            "[xla:5] Accuracy=96.33%\n",
            "[xla:6](0) Loss=0.17167 Rate=32.39 GlobalRate=32.39 Time=Tue Mar 21 22:06:56 2023\n",
            "[xla:2](20) Loss=0.12297 Rate=153.15 GlobalRate=191.59 Time=Tue Mar 21 22:07:00 2023\n",
            "[xla:5](0) Loss=0.14098 Rate=26.22 GlobalRate=26.22 Time=Tue Mar 21 22:07:01 2023\n",
            "[xla:3](20) Loss=0.14321 Rate=147.93 GlobalRate=185.00 Time=Tue Mar 21 22:07:02 2023\n",
            "[xla:4](20) Loss=0.09682 Rate=131.42 GlobalRate=164.10 Time=Tue Mar 21 22:07:05 2023\n",
            "[xla:0](40) Loss=0.16615 Rate=201.29 GlobalRate=217.20 Time=Tue Mar 21 22:07:06 2023\n",
            "[xla:1](20) Loss=0.05411 Rate=134.70 GlobalRate=167.42 Time=Tue Mar 21 22:07:07 2023\n",
            "[xla:7](20) Loss=0.08718 Rate=130.34 GlobalRate=162.87 Time=Tue Mar 21 22:07:07 2023\n",
            "[xla:6](20) Loss=0.02695 Rate=138.96 GlobalRate=166.52 Time=Tue Mar 21 22:07:08 2023\n",
            "[xla:2](40) Loss=0.03967 Rate=217.87 GlobalRate=220.15 Time=Tue Mar 21 22:07:10 2023\n",
            "[xla:5](20) Loss=0.18459 Rate=168.21 GlobalRate=183.86 Time=Tue Mar 21 22:07:10 2023\n",
            "Finished training epoch 2\n",
            "[xla:3](40) Loss=0.13654 Rate=228.49 GlobalRate=222.36 Time=Tue Mar 21 22:07:11 2023\n",
            "[xla:4](40) Loss=0.03001 Rate=248.31 GlobalRate=216.61 Time=Tue Mar 21 22:07:13 2023\n",
            "[xla:1](40) Loss=0.06658 Rate=251.89 GlobalRate=220.38 Time=Tue Mar 21 22:07:14 2023\n",
            "[xla:7](40) Loss=0.03073 Rate=251.44 GlobalRate=216.76 Time=Tue Mar 21 22:07:15 2023\n",
            "[xla:6](40) Loss=0.04376 Rate=246.72 GlobalRate=217.06 Time=Tue Mar 21 22:07:16 2023\n",
            "[xla:5](40) Loss=0.03663 Rate=234.31 GlobalRate=220.35 Time=Tue Mar 21 22:07:20 2023\n",
            "[xla:0] Accuracy=97.80%\n",
            "[xla:0](0) Loss=0.08700 Rate=50.21 GlobalRate=50.21 Time=Tue Mar 21 22:07:52 2023\n",
            "[xla:2] Accuracy=97.83%\n",
            "[xla:3] Accuracy=96.95%\n",
            "[xla:4] Accuracy=96.88%\n",
            "[xla:2](0) Loss=0.06367 Rate=56.78 GlobalRate=56.78 Time=Tue Mar 21 22:07:56 2023\n",
            "[xla:3](0) Loss=0.05326 Rate=57.95 GlobalRate=57.95 Time=Tue Mar 21 22:07:58 2023\n",
            "[xla:7] Accuracy=96.45%\n",
            "[xla:1] Accuracy=97.33%\n",
            "[xla:4](0) Loss=0.06497 Rate=46.38 GlobalRate=46.38 Time=Tue Mar 21 22:07:59 2023\n",
            "[xla:6] Accuracy=97.05%\n",
            "[xla:0](20) Loss=0.11849 Rate=231.57 GlobalRate=273.95 Time=Tue Mar 21 22:07:59 2023\n",
            "[xla:7](0) Loss=0.12760 Rate=54.93 GlobalRate=54.93 Time=Tue Mar 21 22:08:00 2023\n",
            "[xla:1](0) Loss=0.04367 Rate=41.00 GlobalRate=41.00 Time=Tue Mar 21 22:08:02 2023\n",
            "[xla:5] Accuracy=97.52%[xla:6](0) Loss=0.12088 Rate=41.61 GlobalRate=41.61 Time=Tue Mar 21 22:08:02 2023\n",
            "\n",
            "[xla:2](20) Loss=0.04601 Rate=182.56 GlobalRate=226.57 Time=Tue Mar 21 22:08:06 2023\n",
            "[xla:5](0) Loss=0.05954 Rate=34.74 GlobalRate=34.74 Time=Tue Mar 21 22:08:06 2023\n",
            "[xla:3](20) Loss=0.13148 Rate=161.44 GlobalRate=201.83 Time=Tue Mar 21 22:08:09 2023\n",
            "[xla:4](20) Loss=0.06540 Rate=161.35 GlobalRate=198.87 Time=Tue Mar 21 22:08:09 2023\n",
            "[xla:0](40) Loss=0.09040 Rate=215.37 GlobalRate=235.06 Time=Tue Mar 21 22:08:12 2023\n",
            "[xla:7](20) Loss=0.03470 Rate=142.31 GlobalRate=178.08 Time=Tue Mar 21 22:08:13 2023\n",
            "[xla:1](20) Loss=0.01737 Rate=139.65 GlobalRate=172.48 Time=Tue Mar 21 22:08:14 2023\n",
            "[xla:6](20) Loss=0.03595 Rate=145.59 GlobalRate=179.34 Time=Tue Mar 21 22:08:14 2023\n",
            "[xla:2](40) Loss=0.02014 Rate=221.31 GlobalRate=236.17 Time=Tue Mar 21 22:08:16 2023\n",
            "[xla:5](20) Loss=0.07198 Rate=161.11 GlobalRate=190.39 Time=Tue Mar 21 22:08:17 2023\n",
            "Finished training epoch 3\n",
            "[xla:3](40) Loss=0.05570 Rate=231.30 GlobalRate=232.93 Time=Tue Mar 21 22:08:18 2023\n",
            "[xla:4](40) Loss=0.01860 Rate=236.75 GlobalRate=233.91 Time=Tue Mar 21 22:08:18 2023\n",
            "[xla:7](40) Loss=0.02674 Rate=263.94 GlobalRate=233.10 Time=Tue Mar 21 22:08:20 2023\n",
            "[xla:1](40) Loss=0.03008 Rate=258.80 GlobalRate=226.67 Time=Tue Mar 21 22:08:22 2023\n",
            "[xla:6](40) Loss=0.02904 Rate=256.24 GlobalRate=230.73 Time=Tue Mar 21 22:08:22 2023\n",
            "[xla:5](40) Loss=0.01557 Rate=241.53 GlobalRate=230.25 Time=Tue Mar 21 22:08:25 2023\n",
            "[xla:0] Accuracy=97.90%\n",
            "[xla:0](0) Loss=0.07209 Rate=64.40 GlobalRate=64.40 Time=Tue Mar 21 22:09:04 2023\n",
            "[xla:2] Accuracy=97.85%\n",
            "[xla:2](0) Loss=0.05928 Rate=55.63 GlobalRate=55.62 Time=Tue Mar 21 22:09:07 2023\n",
            "[xla:4] Accuracy=97.34%\n",
            "[xla:3] Accuracy=97.07%\n",
            "[xla:7] Accuracy=97.48%\n",
            "[xla:4](0) Loss=0.04034 Rate=51.65 GlobalRate=51.65 Time=Tue Mar 21 22:09:10 2023\n",
            "[xla:3](0) Loss=0.01812 Rate=54.48 GlobalRate=54.48 Time=Tue Mar 21 22:09:10 2023\n",
            "[xla:1] Accuracy=97.99%\n",
            "[xla:6] Accuracy=97.57%\n",
            "[xla:0](20) Loss=0.06191 Rate=220.40 GlobalRate=272.09 Time=Tue Mar 21 22:09:12 2023\n",
            "[xla:7](0) Loss=0.09014 Rate=34.26 GlobalRate=34.26 Time=Tue Mar 21 22:09:14 2023\n",
            "[xla:5] Accuracy=97.49%\n",
            "[xla:1](0) Loss=0.02456 Rate=32.49 GlobalRate=32.49 Time=Tue Mar 21 22:09:15 2023\n",
            "[xla:6](0) Loss=0.04913 Rate=31.73 GlobalRate=31.73 Time=Tue Mar 21 22:09:16 2023\n",
            "[xla:2](20) Loss=0.04067 Rate=168.61 GlobalRate=210.07 Time=Tue Mar 21 22:09:18 2023\n",
            "[xla:5](0) Loss=0.05233 Rate=33.90 GlobalRate=33.90 Time=Tue Mar 21 22:09:20 2023\n",
            "[xla:4](20) Loss=0.03176 Rate=148.53 GlobalRate=185.50 Time=Tue Mar 21 22:09:22 2023\n",
            "[xla:3](20) Loss=0.10608 Rate=144.70 GlobalRate=181.05 Time=Tue Mar 21 22:09:23 2023\n",
            "[xla:0](40) Loss=0.06689 Rate=215.23 GlobalRate=238.90 Time=Tue Mar 21 22:09:24 2023\n",
            "[xla:7](20) Loss=0.03182 Rate=151.41 GlobalRate=180.52 Time=Tue Mar 21 22:09:25 2023\n",
            "[xla:1](20) Loss=0.00348 Rate=155.76 GlobalRate=182.88 Time=Tue Mar 21 22:09:26 2023[xla:6](20) Loss=0.01131 Rate=158.71 GlobalRate=184.70 Time=Tue Mar 21 22:09:26 2023\n",
            "\n",
            "[xla:2](40) Loss=0.00976 Rate=214.99 GlobalRate=226.15 Time=Tue Mar 21 22:09:28 2023\n",
            "[xla:5](20) Loss=0.02958 Rate=185.75 GlobalRate=211.72 Time=Tue Mar 21 22:09:28 2023\n",
            "Finished training epoch 4\n",
            "[xla:4](40) Loss=0.03273 Rate=253.14 GlobalRate=234.09 Time=Tue Mar 21 22:09:30 2023\n",
            "[xla:3](40) Loss=0.07277 Rate=251.83 GlobalRate=230.52 Time=Tue Mar 21 22:09:31 2023\n",
            "[xla:7](40) Loss=0.03693 Rate=249.77 GlobalRate=228.09 Time=Tue Mar 21 22:09:33 2023\n",
            "[xla:6](40) Loss=0.02676 Rate=240.75 GlobalRate=226.03 Time=Tue Mar 21 22:09:35 2023\n",
            "[xla:1](40) Loss=0.01250 Rate=235.66 GlobalRate=222.77 Time=Tue Mar 21 22:09:35 2023\n",
            "[xla:5](40) Loss=0.00496 Rate=232.01 GlobalRate=233.92 Time=Tue Mar 21 22:09:38 2023\n",
            "[xla:0] Accuracy=98.03%\n",
            "[xla:0](0) Loss=0.02842 Rate=54.06 GlobalRate=54.06 Time=Tue Mar 21 22:10:12 2023\n",
            "[xla:2] Accuracy=98.06%\n",
            "[xla:4] Accuracy=97.46%\n",
            "[xla:2](0) Loss=0.02857 Rate=53.25 GlobalRate=53.25 Time=Tue Mar 21 22:10:16 2023\n",
            "[xla:3] Accuracy=97.29%\n",
            "[xla:4](0) Loss=0.00612 Rate=47.93 GlobalRate=47.92 Time=Tue Mar 21 22:10:19 2023[xla:7] Accuracy=97.52%\n",
            "\n",
            "[xla:3](0) Loss=0.01115 Rate=46.93 GlobalRate=46.93 Time=Tue Mar 21 22:10:19 2023\n",
            "[xla:6] Accuracy=97.38%\n",
            "[xla:1] Accuracy=97.89%\n",
            "[xla:0](20) Loss=0.01925 Rate=191.95 GlobalRate=236.09 Time=Tue Mar 21 22:10:21 2023\n",
            "[xla:7](0) Loss=0.04916 Rate=37.35 GlobalRate=37.35 Time=Tue Mar 21 22:10:22 2023\n",
            "[xla:6](0) Loss=0.05503 Rate=33.86 GlobalRate=33.86 Time=Tue Mar 21 22:10:23 2023\n",
            "[xla:5] Accuracy=97.60%\n",
            "[xla:1](0) Loss=0.01686 Rate=32.18 GlobalRate=32.18 Time=Tue Mar 21 22:10:26 2023\n",
            "[xla:2](20) Loss=0.03497 Rate=167.68 GlobalRate=208.42 Time=Tue Mar 21 22:10:27 2023\n",
            "[xla:5](0) Loss=0.01175 Rate=30.01 GlobalRate=30.01 Time=Tue Mar 21 22:10:28 2023\n",
            "[xla:4](20) Loss=0.01793 Rate=138.76 GlobalRate=173.25 Time=Tue Mar 21 22:10:31 2023\n",
            "[xla:3](20) Loss=0.08693 Rate=134.89 GlobalRate=168.47 Time=Tue Mar 21 22:10:32 2023\n",
            "[xla:0](40) Loss=0.06259 Rate=197.14 GlobalRate=217.33 Time=Tue Mar 21 22:10:34 2023\n",
            "[xla:7](20) Loss=0.01855 Rate=138.24 GlobalRate=169.23 Time=Tue Mar 21 22:10:35 2023\n",
            "[xla:6](20) Loss=0.01825 Rate=141.92 GlobalRate=170.72 Time=Tue Mar 21 22:10:35 2023\n",
            "[xla:1](20) Loss=0.00893 Rate=151.99 GlobalRate=178.98 Time=Tue Mar 21 22:10:37 2023\n",
            "[xla:2](40) Loss=0.01231 Rate=208.71 GlobalRate=221.05 Time=Tue Mar 21 22:10:37 2023\n",
            "[xla:5](20) Loss=0.02507 Rate=168.25 GlobalRate=190.69 Time=Tue Mar 21 22:10:38 2023\n",
            "[xla:4](40) Loss=0.00263 Rate=239.72 GlobalRate=220.01 Time=Tue Mar 21 22:10:40 2023\n",
            "Finished training epoch 5\n",
            "[xla:3](40) Loss=0.03834 Rate=240.84 GlobalRate=217.09 Time=Tue Mar 21 22:10:41 2023\n",
            "[xla:7](40) Loss=0.02768 Rate=246.25 GlobalRate=219.33 Time=Tue Mar 21 22:10:43 2023\n",
            "[xla:6](40) Loss=0.01049 Rate=238.08 GlobalRate=216.71 Time=Tue Mar 21 22:10:44 2023\n",
            "[xla:1](40) Loss=0.00548 Rate=233.77 GlobalRate=219.59 Time=Tue Mar 21 22:10:45 2023\n",
            "[xla:5](40) Loss=0.00476 Rate=218.52 GlobalRate=216.38 Time=Tue Mar 21 22:10:48 2023\n",
            "[xla:0] Accuracy=98.11%\n",
            "[xla:0](0) Loss=0.03750 Rate=50.32 GlobalRate=50.32 Time=Tue Mar 21 22:11:21 2023\n",
            "[xla:2] Accuracy=97.13%\n",
            "[xla:2](0) Loss=0.02862 Rate=64.55 GlobalRate=64.55 Time=Tue Mar 21 22:11:24 2023\n",
            "[xla:4] Accuracy=97.88%\n",
            "[xla:3] Accuracy=97.32%\n",
            "[xla:4](0) Loss=0.00991 Rate=52.77 GlobalRate=52.77 Time=Tue Mar 21 22:11:27 2023\n",
            "[xla:3](0) Loss=0.01191 Rate=46.68 GlobalRate=46.68 Time=Tue Mar 21 22:11:27 2023\n",
            "[xla:6] Accuracy=97.40%\n",
            "[xla:7] Accuracy=97.64%\n",
            "[xla:0](20) Loss=0.01501 Rate=208.38 GlobalRate=251.15 Time=Tue Mar 21 22:11:29 2023\n",
            "[xla:1] Accuracy=98.16%\n",
            "[xla:6](0) Loss=0.05222 Rate=33.92 GlobalRate=33.92 Time=Tue Mar 21 22:11:32 2023\n",
            "[xla:5] Accuracy=97.40%\n",
            "[xla:7](0) Loss=0.01621 Rate=33.59 GlobalRate=33.59 Time=Tue Mar 21 22:11:32 2023\n",
            "[xla:1](0) Loss=0.00675 Rate=35.39 GlobalRate=35.39 Time=Tue Mar 21 22:11:33 2023\n",
            "[xla:2](20) Loss=0.01816 Rate=177.64 GlobalRate=222.15 Time=Tue Mar 21 22:11:34 2023\n",
            "[xla:5](0) Loss=0.00972 Rate=30.12 GlobalRate=30.12 Time=Tue Mar 21 22:11:36 2023\n",
            "[xla:4](20) Loss=0.00903 Rate=151.90 GlobalRate=189.71 Time=Tue Mar 21 22:11:38 2023\n",
            "[xla:3](20) Loss=0.08010 Rate=143.14 GlobalRate=178.21 Time=Tue Mar 21 22:11:40 2023\n",
            "[xla:0](40) Loss=0.02207 Rate=216.04 GlobalRate=235.56 Time=Tue Mar 21 22:11:41 2023\n",
            "[xla:6](20) Loss=0.01248 Rate=154.50 GlobalRate=183.21 Time=Tue Mar 21 22:11:43 2023\n",
            "[xla:1](20) Loss=0.01204 Rate=168.76 GlobalRate=198.35 Time=Tue Mar 21 22:11:43 2023\n",
            "[xla:7](20) Loss=0.00779 Rate=155.82 GlobalRate=184.14 Time=Tue Mar 21 22:11:43 2023[xla:2](40) Loss=0.00253 Rate=232.96 GlobalRate=243.11 Time=Tue Mar 21 22:11:43 2023\n",
            "\n",
            "[xla:5](20) Loss=0.03686 Rate=178.41 GlobalRate=199.36 Time=Tue Mar 21 22:11:45 2023\n",
            "Finished training epoch 6\n",
            "[xla:4](40) Loss=0.00277 Rate=244.96 GlobalRate=233.16 Time=Tue Mar 21 22:11:47 2023\n",
            "[xla:3](40) Loss=0.01157 Rate=240.15 GlobalRate=223.50 Time=Tue Mar 21 22:11:48 2023\n",
            "[xla:6](40) Loss=0.00877 Rate=237.21 GlobalRate=224.00 Time=Tue Mar 21 22:11:51 2023\n",
            "[xla:1](40) Loss=0.01079 Rate=232.21 GlobalRate=229.40 Time=Tue Mar 21 22:11:52 2023\n",
            "[xla:7](40) Loss=0.01102 Rate=224.75 GlobalRate=218.17 Time=Tue Mar 21 22:11:53 2023\n",
            "[xla:5](40) Loss=0.00438 Rate=208.51 GlobalRate=212.62 Time=Tue Mar 21 22:11:57 2023\n",
            "[xla:0] Accuracy=98.40%\n",
            "[xla:0](0) Loss=0.03031 Rate=56.65 GlobalRate=56.64 Time=Tue Mar 21 22:12:28 2023\n",
            "[xla:2] Accuracy=97.52%\n",
            "[xla:2](0) Loss=0.04642 Rate=54.60 GlobalRate=54.60 Time=Tue Mar 21 22:12:30 2023\n",
            "[xla:4] Accuracy=98.21%\n",
            "[xla:3] Accuracy=97.52%\n",
            "[xla:4](0) Loss=0.00290 Rate=60.69 GlobalRate=60.69 Time=Tue Mar 21 22:12:33 2023\n",
            "[xla:3](0) Loss=0.00214 Rate=61.75 GlobalRate=61.75 Time=Tue Mar 21 22:12:34 2023\n",
            "[xla:0](20) Loss=0.00943 Rate=232.81 GlobalRate=280.91 Time=Tue Mar 21 22:12:35 2023\n",
            "[xla:6] Accuracy=97.94%\n",
            "[xla:7] Accuracy=97.92%\n",
            "[xla:1] Accuracy=98.18%\n",
            "[xla:5] Accuracy=97.69%\n",
            "[xla:2](20) Loss=0.00921 Rate=201.24 GlobalRate=246.47 Time=Tue Mar 21 22:12:39 2023\n",
            "[xla:6](0) Loss=0.02555 Rate=36.71 GlobalRate=36.71 Time=Tue Mar 21 22:12:39 2023\n",
            "[xla:7](0) Loss=0.02138 Rate=35.49 GlobalRate=35.49 Time=Tue Mar 21 22:12:41 2023\n",
            "[xla:1](0) Loss=0.00708 Rate=35.91 GlobalRate=35.91 Time=Tue Mar 21 22:12:41 2023\n",
            "[xla:5](0) Loss=0.05943 Rate=33.76 GlobalRate=33.76 Time=Tue Mar 21 22:12:43 2023\n",
            "[xla:4](20) Loss=0.01013 Rate=166.21 GlobalRate=207.87 Time=Tue Mar 21 22:12:44 2023\n",
            "[xla:3](20) Loss=0.08120 Rate=150.95 GlobalRate=188.77 Time=Tue Mar 21 22:12:47 2023\n",
            "[xla:0](40) Loss=0.00481 Rate=220.16 GlobalRate=242.30 Time=Tue Mar 21 22:12:47 2023\n",
            "[xla:2](40) Loss=0.01732 Rate=224.76 GlobalRate=243.49 Time=Tue Mar 21 22:12:50 2023\n",
            "[xla:6](20) Loss=0.00557 Rate=155.49 GlobalRate=186.73 Time=Tue Mar 21 22:12:50 2023\n",
            "[xla:7](20) Loss=0.05074 Rate=163.38 GlobalRate=193.34 Time=Tue Mar 21 22:12:51 2023\n",
            "[xla:1](20) Loss=0.00102 Rate=159.52 GlobalRate=190.01 Time=Tue Mar 21 22:12:51 2023\n",
            "Finished training epoch 7\n",
            "[xla:5](20) Loss=0.00985 Rate=173.21 GlobalRate=200.46 Time=Tue Mar 21 22:12:53 2023\n",
            "[xla:4](40) Loss=0.00138 Rate=233.32 GlobalRate=237.06 Time=Tue Mar 21 22:12:53 2023\n",
            "[xla:3](40) Loss=0.00386 Rate=261.11 GlobalRate=239.73 Time=Tue Mar 21 22:12:54 2023\n",
            "[xla:6](40) Loss=0.00614 Rate=233.01 GlobalRate=224.39 Time=Tue Mar 21 22:12:59 2023\n",
            "[xla:7](40) Loss=0.00439 Rate=228.02 GlobalRate=224.80 Time=Tue Mar 21 22:13:00 2023\n",
            "[xla:1](40) Loss=0.00650 Rate=226.49 GlobalRate=222.49 Time=Tue Mar 21 22:13:01 2023\n",
            "[xla:5](40) Loss=0.00394 Rate=234.10 GlobalRate=230.89 Time=Tue Mar 21 22:13:02 2023\n",
            "[xla:0] Accuracy=98.35%\n",
            "[xla:0](0) Loss=0.01925 Rate=48.34 GlobalRate=48.34 Time=Tue Mar 21 22:13:33 2023\n",
            "[xla:2] Accuracy=97.84%\n",
            "[xla:2](0) Loss=0.01460 Rate=61.19 GlobalRate=61.18 Time=Tue Mar 21 22:13:35 2023\n",
            "[xla:4] Accuracy=98.11%\n",
            "[xla:3] Accuracy=98.04%\n",
            "[xla:4](0) Loss=0.00113 Rate=63.34 GlobalRate=63.34 Time=Tue Mar 21 22:13:39 2023\n",
            "[xla:3](0) Loss=0.00097 Rate=54.88 GlobalRate=54.87 Time=Tue Mar 21 22:13:40 2023\n",
            "[xla:0](20) Loss=0.00551 Rate=227.84 GlobalRate=268.41 Time=Tue Mar 21 22:13:40 2023\n",
            "[xla:6] Accuracy=97.91%\n",
            "[xla:7] Accuracy=97.92%\n",
            "[xla:2](20) Loss=0.00232 Rate=253.83 GlobalRate=305.84 Time=Tue Mar 21 22:13:42 2023\n",
            "[xla:1] Accuracy=98.24%\n",
            "[xla:6](0) Loss=0.01625 Rate=44.88 GlobalRate=44.88 Time=Tue Mar 21 22:13:44 2023\n",
            "[xla:5] Accuracy=97.65%\n",
            "[xla:7](0) Loss=0.01750 Rate=42.87 GlobalRate=42.87 Time=Tue Mar 21 22:13:45 2023\n",
            "[xla:1](0) Loss=0.00513 Rate=33.44 GlobalRate=33.44 Time=Tue Mar 21 22:13:46 2023\n",
            "[xla:5](0) Loss=0.01159 Rate=33.93 GlobalRate=33.92 Time=Tue Mar 21 22:13:48 2023\n",
            "[xla:4](20) Loss=0.00176 Rate=186.02 GlobalRate=232.12 Time=Tue Mar 21 22:13:49 2023\n",
            "[xla:3](20) Loss=0.02186 Rate=173.74 GlobalRate=215.87 Time=Tue Mar 21 22:13:50 2023\n",
            "[xla:0](40) Loss=0.00163 Rate=248.92 GlobalRate=265.73 Time=Tue Mar 21 22:13:50 2023\n",
            "[xla:2](40) Loss=0.00136 Rate=225.31 GlobalRate=247.56 Time=Tue Mar 21 22:13:54 2023\n",
            "[xla:6](20) Loss=0.00634 Rate=148.47 GlobalRate=183.85 Time=Tue Mar 21 22:13:56 2023\n",
            "[xla:7](20) Loss=0.01028 Rate=149.93 GlobalRate=184.69 Time=Tue Mar 21 22:13:57 2023\n",
            "[xla:1](20) Loss=0.00112 Rate=152.20 GlobalRate=180.50 Time=Tue Mar 21 22:13:57 2023\n",
            "Finished training epoch 8\n",
            "[xla:4](40) Loss=0.00975 Rate=230.90 GlobalRate=245.29 Time=Tue Mar 21 22:13:58 2023\n",
            "[xla:5](20) Loss=0.00381 Rate=165.96 GlobalRate=194.05 Time=Tue Mar 21 22:13:58 2023\n",
            "[xla:3](40) Loss=0.00848 Rate=228.36 GlobalRate=237.25 Time=Tue Mar 21 22:14:00 2023\n",
            "[xla:6](40) Loss=0.00527 Rate=272.24 GlobalRate=240.33 Time=Tue Mar 21 22:14:03 2023\n",
            "[xla:7](40) Loss=0.00300 Rate=261.24 GlobalRate=236.55 Time=Tue Mar 21 22:14:04 2023\n",
            "[xla:1](40) Loss=0.00230 Rate=257.75 GlobalRate=231.25 Time=Tue Mar 21 22:14:05 2023\n",
            "[xla:5](40) Loss=0.00231 Rate=254.27 GlobalRate=238.25 Time=Tue Mar 21 22:14:07 2023\n",
            "[xla:0] Accuracy=98.38%\n",
            "[xla:2] Accuracy=98.11%\n",
            "[xla:0](0) Loss=0.00121 Rate=40.39 GlobalRate=40.39 Time=Tue Mar 21 22:14:36 2023\n",
            "[xla:2](0) Loss=0.00297 Rate=37.52 GlobalRate=37.52 Time=Tue Mar 21 22:14:38 2023\n",
            "[xla:4] Accuracy=98.17%\n",
            "[xla:3] Accuracy=97.99%\n",
            "[xla:4](0) Loss=0.01881 Rate=59.95 GlobalRate=59.95 Time=Tue Mar 21 22:14:43 2023\n",
            "[xla:0](20) Loss=0.00377 Rate=209.04 GlobalRate=241.45 Time=Tue Mar 21 22:14:44 2023\n",
            "[xla:3](0) Loss=0.00238 Rate=54.86 GlobalRate=54.85 Time=Tue Mar 21 22:14:44 2023\n",
            "[xla:6] Accuracy=97.84%\n",
            "[xla:2](20) Loss=0.00170 Rate=229.96 GlobalRate=254.61 Time=Tue Mar 21 22:14:45 2023\n",
            "[xla:7] Accuracy=97.93%\n",
            "[xla:1] Accuracy=98.39%\n",
            "[xla:6](0) Loss=0.00295 Rate=56.03 GlobalRate=56.03 Time=Tue Mar 21 22:14:47 2023\n",
            "[xla:5] Accuracy=97.56%\n",
            "[xla:7](0) Loss=0.02591 Rate=58.66 GlobalRate=58.66 Time=Tue Mar 21 22:14:48 2023\n",
            "[xla:1](0) Loss=0.00127 Rate=52.91 GlobalRate=52.91 Time=Tue Mar 21 22:14:49 2023\n",
            "[xla:4](20) Loss=0.00300 Rate=212.04 GlobalRate=260.90 Time=Tue Mar 21 22:14:51 2023\n",
            "[xla:5](0) Loss=0.00081 Rate=38.16 GlobalRate=38.16 Time=Tue Mar 21 22:14:51 2023\n",
            "[xla:0](40) Loss=0.00203 Rate=263.55 GlobalRate=266.81 Time=Tue Mar 21 22:14:52 2023\n",
            "[xla:3](20) Loss=0.02036 Rate=197.53 GlobalRate=242.57 Time=Tue Mar 21 22:14:53 2023\n",
            "[xla:2](40) Loss=0.00128 Rate=249.37 GlobalRate=258.31 Time=Tue Mar 21 22:14:55 2023\n",
            "[xla:6](20) Loss=0.00097 Rate=160.68 GlobalRate=200.69 Time=Tue Mar 21 22:14:58 2023\n",
            "[xla:7](20) Loss=0.00525 Rate=152.19 GlobalRate=190.45 Time=Tue Mar 21 22:15:00 2023\n",
            "Finished training epoch 9\n",
            "[xla:1](20) Loss=0.00042 Rate=148.11 GlobalRate=185.14 Time=Tue Mar 21 22:15:01 2023\n",
            "[xla:4](40) Loss=0.00613 Rate=215.07 GlobalRate=237.52 Time=Tue Mar 21 22:15:03 2023\n",
            "[xla:5](20) Loss=0.00079 Rate=147.22 GlobalRate=179.27 Time=Tue Mar 21 22:15:03 2023\n",
            "[xla:3](40) Loss=0.00113 Rate=222.06 GlobalRate=240.53 Time=Tue Mar 21 22:15:03 2023\n",
            "[xla:6](40) Loss=0.00075 Rate=242.50 GlobalRate=238.42 Time=Tue Mar 21 22:15:07 2023\n",
            "[xla:7](40) Loss=0.00449 Rate=260.13 GlobalRate=240.48 Time=Tue Mar 21 22:15:08 2023\n",
            "[xla:1](40) Loss=0.00041 Rate=258.07 GlobalRate=235.93 Time=Tue Mar 21 22:15:09 2023\n",
            "[xla:5](40) Loss=0.00110 Rate=278.15 GlobalRate=238.55 Time=Tue Mar 21 22:15:10 2023\n",
            "[xla:0] Accuracy=98.38%\n",
            "[xla:2] Accuracy=98.22%\n",
            "[xla:0](0) Loss=0.00096 Rate=40.21 GlobalRate=40.21 Time=Tue Mar 21 22:15:36 2023\n",
            "[xla:2](0) Loss=0.00254 Rate=35.64 GlobalRate=35.64 Time=Tue Mar 21 22:15:40 2023\n",
            "[xla:3] Accuracy=98.03%\n",
            "[xla:4] Accuracy=98.29%\n",
            "[xla:0](20) Loss=0.00299 Rate=163.37 GlobalRate=197.47 Time=Tue Mar 21 22:15:47 2023\n",
            "[xla:3](0) Loss=0.00063 Rate=52.72 GlobalRate=52.72 Time=Tue Mar 21 22:15:47 2023\n",
            "[xla:4](0) Loss=0.00078 Rate=56.86 GlobalRate=56.86 Time=Tue Mar 21 22:15:47 2023\n",
            "[xla:7] Accuracy=97.55%\n",
            "[xla:2](20) Loss=0.00160 Rate=186.86 GlobalRate=215.21 Time=Tue Mar 21 22:15:49 2023\n",
            "[xla:6] Accuracy=98.00%\n",
            "[xla:1] Accuracy=98.47%\n",
            "[xla:7](0) Loss=0.00890 Rate=64.95 GlobalRate=64.95 Time=Tue Mar 21 22:15:51 2023\n",
            "[xla:6](0) Loss=0.00848 Rate=58.78 GlobalRate=58.78 Time=Tue Mar 21 22:15:52 2023\n",
            "[xla:5] Accuracy=97.90%\n",
            "[xla:1](0) Loss=0.00130 Rate=57.16 GlobalRate=57.16 Time=Tue Mar 21 22:15:52 2023\n",
            "[xla:4](20) Loss=0.01603 Rate=258.71 GlobalRate=306.82 Time=Tue Mar 21 22:15:54 2023\n",
            "[xla:0](40) Loss=0.00113 Rate=279.39 GlobalRate=252.45 Time=Tue Mar 21 22:15:54 2023\n",
            "[xla:3](20) Loss=0.00318 Rate=240.31 GlobalRate=284.92 Time=Tue Mar 21 22:15:54 2023\n",
            "[xla:5](0) Loss=0.00074 Rate=55.31 GlobalRate=55.31 Time=Tue Mar 21 22:15:55 2023\n",
            "[xla:2](40) Loss=0.00078 Rate=270.77 GlobalRate=258.19 Time=Tue Mar 21 22:15:57 2023\n",
            "[xla:7](20) Loss=0.01852 Rate=183.25 GlobalRate=229.01 Time=Tue Mar 21 22:16:01 2023\n",
            "[xla:6](20) Loss=0.00022 Rate=179.25 GlobalRate=223.25 Time=Tue Mar 21 22:16:02 2023\n",
            "Finished training epoch 10\n",
            "[xla:1](20) Loss=0.00035 Rate=171.61 GlobalRate=213.92 Time=Tue Mar 21 22:16:03 2023\n",
            "[xla:4](40) Loss=0.00038 Rate=241.21 GlobalRate=263.54 Time=Tue Mar 21 22:16:05 2023\n",
            "[xla:3](40) Loss=0.00089 Rate=229.28 GlobalRate=250.27 Time=Tue Mar 21 22:16:06 2023\n",
            "[xla:5](20) Loss=0.00131 Rate=153.05 GlobalRate=191.37 Time=Tue Mar 21 22:16:06 2023\n",
            "[xla:7](40) Loss=0.00130 Rate=229.78 GlobalRate=243.49 Time=Tue Mar 21 22:16:11 2023\n",
            "[xla:6](40) Loss=0.00162 Rate=225.04 GlobalRate=237.93 Time=Tue Mar 21 22:16:12 2023\n",
            "[xla:1](40) Loss=0.00026 Rate=239.55 GlobalRate=243.49 Time=Tue Mar 21 22:16:12 2023\n",
            "[xla:5](40) Loss=0.00070 Rate=251.46 GlobalRate=237.25 Time=Tue Mar 21 22:16:15 2023\n",
            "[xla:0] Accuracy=98.52%\n",
            "[xla:2] Accuracy=98.31%\n",
            "[xla:3] Accuracy=98.11%\n",
            "[xla:4] Accuracy=98.26%\n",
            "[xla:7] Accuracy=97.97%\n",
            "[xla:6] Accuracy=98.32%\n",
            "[xla:1] Accuracy=98.55%\n",
            "[xla:5] Accuracy=98.02%\n"
          ]
        }
      ],
      "source": [
        "# Start training processes: xmp.spawn launches one process per TPU core.\n",
        "def _mp_fn(rank, flags):\n",
        "  \"\"\"Per-process entry point invoked by xmp.spawn.\n",
        "\n",
        "  Args:\n",
        "    rank: process index assigned by xmp.spawn; rank 0 plots the results.\n",
        "    flags: the FLAGS dict forwarded through spawn's `args`.\n",
        "  \"\"\"\n",
        "  # Publish this process's flags as the module-level FLAGS\n",
        "  # (presumably what train_mnist() reads, since it takes no arguments).\n",
        "  global FLAGS\n",
        "  FLAGS = flags\n",
        "  # Keep float32 as the default tensor dtype. torch.set_default_tensor_type\n",
        "  # is deprecated as of PyTorch 2.1; set_default_dtype is the supported API.\n",
        "  torch.set_default_dtype(torch.float32)\n",
        "  accuracy, data, pred, target = train_mnist()\n",
        "  if rank == 0:\n",
        "    # Retrieve tensors that are on TPU core 0 (move them to the CPU) and plot.\n",
        "    plot_results(data.cpu(), pred.cpu(), target.cpu())\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],\n",
        "          start_method='fork')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MznTE72_mthI"
      },
      "source": [
        "## Visualize Predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 665
        },
        "id": "X9VAwyUnI7Sb",
        "outputId": "7da921b7-9a58-4eb2-884c-7669befd0ec8"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxgAAAKICAYAAADzbFyTAABn9UlEQVR4nO3dd3hU1dbH8bWT0HvvHQIoCspFwd6wYcGCHcVyLVhAwXrtvWDFXq79qi/Xgr1esYIoglKlI733Ekhy3j8yrnVOnJC2k0wy38/z5Hl+Z86ZMzuzM5PsnDV7uyAIBAAAAAB8SCnrBgAAAACoOBhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQaACsk519Y5Fzjn0mLbnzjnzimFx73VOfea53NGvpdSvO8uzrlfnHOusPctwLm9P09F5Zyr5pz7wDm33jk3ysP59nfO/eGjbTDOufHOuV3Luh0A8scAAygHnHNnxP7Q2+ScWxr7Y3m/BGjXfOfcYQU47nrn3N1xbh/knMuKfV8bnHOTnHPHlERbgyA4KgiClwvQ1gJ9T0XhnDvIObeoJM5dQu4QkRFBEAQi+txsjfXXMufcS865mr4f1OfzVMBznSwiTUSkQRAEA4r7mEEQfBcEQefinqesFGdQWsJGiMjtZd0IAPljgAEkOOfcVSLyiIjcLTl/BLUWkSdF5PginOtvfzCU0h8R/UTk4zz2jQ2CoKaI1BWRF0Tk/5xz9XIflIB/7FRozrlmInKwiLyXa9exsf7qISJ7iMj1pduyEtFGRGYGQZBZ2Dsmys9l7na4HBXtd/z7InKwc65pWTcEwM5VtDcfoEJxztWRnP/YXRoEwTtBEGwOgmBHEAQfBEFwdeyYKs65R5xzS2JfjzjnqsT2HeScW+Scu9Y5t0xEXoyVpvzXOfeac26DiAxyztVxzr0Quzqy2Dl3p3MuNdSOfzrnpjvnNjrnpjnn9nTOvSo5g50PYv/RviaP76GeiKSLyNidfa9BEGSLyL9FpJqIdChsO51zqc65Ec65Vc65uZIzqAm3Y4xz7oKifE/Oud7OuR+dc+ucc7855w4Knaedc+6b2Hm+EJGG+fVrHs9TP+fcxNiVnIXOuVvjHHZerI+XOueGh+6b4py7zjk3xzm32jn3f865+nk8ziDn3NxYe+c5587Mo0l9ReTXIAi2xdsZBMEyEflMcgYaf5272M+Tc66GiHwiIs1jfbDJOdd8Z9+jc+4p59zboXPc55z7Kq9z5Xq820TkZhE5Nbb//Nhj3eicW+CcW+GceyX2Wgz/d/9859yfIvK/ON9D5KqJy7nyM9w597vLKcN6yzlXNbT/eJdz9W5D7Ps7MnZ7c+fc+865Nc652c65f4buE+/1McY5d5dz7gcR2SIi7Z1zXZxzX8TO8Ydz7pTQOao55x6MfZ/rnXPfO+eqici3sUPWxZ6TPnG+x4K87wyLPX9LnXPn5rrvCOfcn8655c65p2OPK865hs65D2M/Q2ucc9+52EAp9rM4QUSOiPezAyCBBEHAF198JeiXiBwpIpkikraTY24XkXEi0lhEGonIjyJyR2zfQbH73yciVSTnj/dbRWSHiPSXnH8yVBORd0XkGRGpETvPeBG5KHaOASKyWER6iYgTkY4i0ia2b76IHJbP93CaiLyRx75BIvJ9LKeJyBAR2SgidYrQzotFZIaItBKR+iLytYgEfz13IjJGRC4o7PckIi1EZLWIHB1rR9/YdqPY/rEi8lDs+T0g1v7X8vh+DxKRRTvZt1vsMXYXkeUi0j+2r23se3kj9r3vJiIr/2pn7HkbJyItY+145q/nPHTftNh9N4hI59i+ZiKyax7teUBEnsh12/zQY7YUkcki8mhpPE/5fI/VRWSm5Pw87S8iq0SkZX7Peejct4bbIiLnichsEWkvIjVF5B0ReTXX8/lK7Pmsll/7Y8/beBFpLjk/m9NF5OLYvr1EZH3s+UqJPY9dYvu+lZyrlVUlZyC3UkQOCbU59+tjjIj8KSK7xvq7jo
gsFJFzY9t7xJ6bXWLneCJ2nxYikioi+8Se27++x+K+79wuIpViPxNbRKRebP/DknM1or6I1BKRD0Tknti+e0Tk6dj9KsX604Ue9zEReags35f54ouv/L/KvAF88cVX3l8icqaILMvnmDkicnRo+wgRmR/LB4nIdhGpGtp/q4h8G9puIiIZ4T+UROR0Efk6lj8TkSF5PPZ8yX+A8aqIDMxj36DYHyLrYn/4jBP7A7aw7fyfxP5oi20fLnkPMAr8PYnItRL74zJ022cico7kXO3IFJEaoX3/kSIMMOIc+4iIPBzLbWPfS5fQ/vtF5IVYni4ih4b2NZOcPz7T5O8DjHUicpLE+cM41+M/JyL3xnluNknO4CAQka9EpG5pPE87+x5j23uLyBoRWSAipxfmOZe/DzC+EpHBoe3OcZ7P9js5X+QxY8/bWbn67ulYfuavfs51jlYikiUitUK33SMiL8V7fYR+xm8PbZ8qIt/lOuYZEblFcgYlW0Wke5zH1p+ZnXyP+b3vbA3fX0RWiEhvyRnQbxaRDqF9fURkXizfLiKjRaRjHo97l4j8uyCvIb744qvsviiRAhLbahFp6HZe591ccv6o+suC2G1/WRn8vcxlYSi3kZz/FC6NlSWsk5w/QhrH9reSnD8mCi1W2tBXRD7dyWHjgiCoGwRBwyAIegdB8GUR29k81/Hh5yS3wnxPbURkwF+PGXvc/STnD9zmIrI2CILNBXzcPDnn9nbOfe2cW+mcWy85V2RylxHl/v7+6uc2IvJuqH3TJeeP0ybhO8faeWrs3Eudcx8557rk0aS1kvPf5dz6B0FQS3L+iOwSamNJP087/R6DIPhJROZKzh+w/1fIc+cW7zWVJtHnc6EUzrJQ3iI5V0ZE8v5ZbC4ia4Ig2JirHS3yaUPu18zeufrkTBFpKjn9VjWPxy6I/N53VgfRz7T89T03kpwrThNCbfo0drtIzpWz2SLyucsp5bsu1+PWkpxBMoAExgADSGxjJee/9v13cswSyflD4i+tY7f9JYhzn/BtC2OP0TD2h37dIAhqB0Gwa2h/hzweO965w3qJyIIgCFbmc1xeCtPOpZLzx9pfWu/kvIX5nhZKzn/m64a+agRBcG/sMevFav0L8rg78x/JKRtpFQRBHckpE8k9PWzu7++vfl4oIkflamPVIAgW/+2bC4LPgiDoKzl/+M+QnCsV8fwuOZ+diSsIgm9E5CXJmdnnrzb4ep7i/Vzt9Ht0zl0qOeU9S0TkmnzOlZ94r6lMySlbK85548nrZ3GJiNR3zoUHea0lp7RvZ23I/Zr5JtdzVjMIgksk54rhtjweuyDfW37vO3lZJTlXN3YNtalOkDNxgARBsDEIgmFBELQXkeNE5Crn3KGh+3cVkd8K8DgAyhADDCCBBUGwXnI+gPqEc66/c666c66Sc+4o59z9scPeEJEbnXONnHMNY8cXeH2BIAiWisjnIvKgc6527AOuHZxzB8YOeV5EhjvnerocHZ1zf/1hsVxy6tTzcrSIfFTw77hY7fw/EbnCOdfS5XywPPd/PsMK8z29JiLHOueOcDkfJK8a+xBryyAIFojILyJym3OussuZOvjY/L6X2DnCX05y/jO7JgiCbc65vUTkjDh3vSn2M7Cr5NTVvxW7/WkRueuv7yH2s3B8nMdt4nI+UFxDcgZrm0QkO49mfiEie7rQh5HjeERE+jrnuovf52m5iDRwsQ9W5/c9OufSReROETlLRAaKyDXOuR47OVd+3hCRK13OB9NrSs4Mbm8FRZhlqgBeEJFznXOHxn6mWzjnugRBsFByPtdwT+y53F1EzpdCvLZF5EMRSXfODYy9b1RyzvVyznUNbFKFh1zOh8lTnXN9XM4HtVdKzs/Fzl7bRXrfiT3ucyLysHOusYhI7Hs+IpaPib0eneR8NiUr1haJ/Sz2lJyfTQAJjAEGkOCCIHhQRK4SkRsl5xf/QhG5TGz60Dsl54+33yXnQ7e/xm4rjLNFpLKITJ
Oc0pj/Ss5/uCUIglGSU/f8H8mpvX9Pcj6cKZJTE35jrNRhuPzdzqanLYo82yk5f7R8Jjn/3fxVcj6YG1dhvqfYH3rHi8gNYs//1WLvn2eI1f/fIjkf/t2ZFpLzH9zwVwcRGSwitzvnNkrOH2vxyny+kZzyka8kZ32Kz2O3Pyo5Vz8+j91/XKxNuaVIzs/Sklh7DxSRS+I1MgiC5ZLzuZa/DVRCx6yUnO/3Zp/PUxAEMyTnD9i5sX5ontf36HLKB18TkfuCIPgtCIJZsTa86pyrkse58vNvyfns0LciMk9y/tN/eQHuV2hBEIyXnMHiw5LzB/U3YlcGTpecz0MskZwJDm7JVUKY37k3Ss5nkU6LnWOZ2IQPIiLDJec942fJ6Zf7RCQlCIItkvP6+CH2nPWOc/rivO9cKzk/x+NczgxYX0rO51xERDrFtjdJzhXcJ4Mg+Dq271gRGRMEQUGulAAoQy4IfF3lBQDjnGsiIhNFpEXAG0255JzbRUReFpG96EOUNefcTyJyfhAEU8q6LQB2jgEGgBIRK1vpGQTBG2XdFgAAUHoYYAAAAADwhs9gAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQa8cE7OcU4+cU6uKOu2ID76KPHRR+UD/ZT46KPERx9VbC4IgrJuA8o552SYiIwQkXki0jMIZG0ZNwm50EeJjz4qH+inxEcfJT76qOJjgIFicU76isgnIrJDRPYJAplYxk1CLvRR4qOPygf6KfHRR4mPPkoOaWXdAJRfzkkzEXldRFJF5J+8SSQe+ijx0UflA/2U+OijxEcfJQ8+g4EicU6ciLwmIo1E5NkgkBfLuEnIhT5KfPRR+UA/JT76KPHRR8mFAQaKaqiIHCIiP4vwAa0ENVToo0Q3VOij8mCo0E+JbqjQR4luqNBHSYPPYKDQnJPOIjJJcuonuweBzCvbFiE3+ijx0UflA/2U+OijxEcfJR+uYCBfzslhzskC52Rg7BLn8yJSVUSu4k0iMdBHiY8+Kh/op8RHHyU++gh8yBsFcbaItJacy5v1RWQ/Efk4COT5smwUIuijxEcflQ/0U+KjjxIffZTkKJFCvpyTc0TkpdBNa0SkWxDI0rJpEXKjjxIffVQ+0E+Jjz5KfPQRKJFCQYwWkYzQ9mDeJBIOfZT46KPygX5KfPRR4qOPkhwDDOQrCGSdiDwqIstFZGQQyFtl2yLkRh8lPvqofKCfEh99lPjoI1AiBQAAAMAbrmAAAAAA8IYBBgAAAABvGGAAAAAA8IYBBgAAAABvGGAAAAAA8IYBBgAAAABvGGAAAAAA8CatrBsA0zdlAIuSlKAvske54p6DPipZ9FHi89FHIvRTSeO1lPjoo8Tn6/0uGXEFAwAAAIA3DDAAAAAAeMMAAwAAAIA3DDAAAAAAeMMAAwAAAIA3DDAAAAAAeMMAAwAAAIA3DDAAAAAAeMMAAwAAAIA3DDAAAAAAeMMAAwAAAIA3DDAAAAAAeJNW1g0AUPp2HNZTc5s7Z0b2XdR4jOZeVZzlOy/V3OipsSXXOPxNWptWke1Fj9bU3L3JYs3L+2wotTYBpSmtVUvNK/ra62H7ces0/9rrdc2pLv7/T7OC7Mj2TSt6aP7vZ/tqbn
cd73FAcXAFAwAAAIA3DDAAAAAAeEOJFFCBhcsKZo+or3ncPiM1106pGrnPuX8eHPdc2xq6uLejZKQ2sP6q859NkX2vtX5L855fXq45XSaUfMOAUrDlxL0j23eNeEbzvlWycx/+N7lLofJyR+NJmm886xfN/cZYSWjlT38u0LlQdlylyprPm/qH5lQJND+b3r5U25TsuIIBAAAAwBsGGAAAAAC8oUQKqGA2DbDSgqvuekNz/xrrNB8782TNWdc3jJ5g3O9xz9tKfvTTQOQptXZtzW0/3aL50eY/RI47dMqZmtPPpSwKFc+2utH/fxakLOrTrdU1Xzf5RM2VPqujufqq6HmWHGB51klPaW7wr3maN36af3tRtv4Y2UPzKTXHa/6/TXXiHI3SwBUMAAAAAN4wwAAAAADgDSVS8C6lRg3NG4/sFtm35Pgdmv+970ua96+aqTm88NHnT9jCRw2eZ+GjCGezOq2+oLfm9256QPOCTCsZ6PHwZZpbjAyV1WQsKaEGorDmPNdW8+jmL2p+YX3ryHE1z96sOavEW5W8Np9s5YZL+2/XfGzXyZofbDpe8rI2e6vd55phmmu/Mc5XEyus+i9Gn6P9tg7WvKmV/W+09VsLNWctXqq5eea0Aj1O2sm7F7WJ2ImU7l01Z/82vUQeY8sJ9vr88qgHQ3tsIdLrx1upXEeZWCLtQHxcwQAAAADgDQMMAAAAAN5QIoUiS6lqC7QtvXBPzY8NfVJznyrfRu6zNnub5msXH6n56lXNLKd/pvmjY3a1Oz9fvPZWNJtO3kvz2Fsf13zlkkM1z7o0XXPz8TYLlC09hLK27Rjrx097P6T53c0tNL/fu0PkPlkbVpR8wyqY8MKFW3I9nwtOsFfE4we9pvnIar8W6zHrpVTTPOIue1+8c/R+mrO3bBHEEUTfpcJlZbVDt2dKAYTKSeff0Tuy69f9Htb8Z6adbcUIW5StmqwqyKMkvX1+szLCk+v8W/M1fU7QnLl0mbfH63LdFM0dKllZ1NPr7L2z85AFmiknLV1cwQAAAADgDQMMAAAAAN5QIoVCCfbtobndIzM0v998pOZXNtjlycHP2gwOIiKt31muOWvmHM31ZaPmF7sdpbleK7vsiajTbrXVn0au7aR5bl8rXZN1k6Ww0lq11Jy5cFHRGoedCj/HVzz0puZ2ocv8A4edprnGhp9Kp2HllEuzX2Up7WzGrRn/qqf5ur0/0Xx+7S8L/RgrsqyU6YQpgzSvXFPLzrt7dDHKaxuUzOw5ySg13cralh3SWHPduTYz4br2leLeNzhqreZp/3gi197Kmvr+cLHmDqPznh0smaXWs9fU7Cejs9t93NBmvpu6w8rSMls1soOKWSK18TQrcXu85cjQHuv7J58/XnOz1SwQW1a4ggEAAADAGwYYAAAAALyhRApxuSpVNM+91WaI+v6sEZo3Z9ssH13eGq45/SYry2mxOXp5siCzOGRPsdKrKlN2cmCSu7yezY5x9B9Ha85al//Ceal162jesVv7yL5g6briNw47Ne1fzTUfV8PKNzp8dYHmTu/ZYojM+vV34bKoJf+1EsGJvV4v1nmHLOmjeexz9t7X5M2pmutsmG05dN/3Poou2hYukbpy+qma622ZVaw2JoNwKY6IyGnvf6P5zFo2i9rWwGYuquYqS2E9uraj5vRbNmhmxiGT1tLKnqfdbO9d8w58LteR9j/r0567SnOr8cUrU0qpbgvG3nv305qrOCuLOnF2X83NHqIsKhFwBQMAAACANwwwAAAAAHjDAAMAAACAN3wGA/F90lDj1M62SvSNK/bVPOFKq0/uOMZWWc0u4aYhx9kLDtA8qMUPml9utb/myDSzva0+/KQXP9fct8ZH0fNedKXmKrPneWkrRFI7W633T/1s9eDXN7bT3OZl+5/PrJd203xU52mRc306xl576XfYZwOyNmyQZLFkiK2APrHX43GPCdfnP7janrNXJu8dOa7zPTYFbfbs+ZobZozVnF
dNftCnu+YPd38ysm9rkKp5y3c2VWc94TMY+XFVq0S2j63xZ2jLpuIuyucuwuqkWt8H1Yp3ropq2g32GYx5xzyb53E3rrD3rHb/nqu5QKut78Tys+01dkBV+3zFqqzNmjeG2pgiK4v5iPCBKxgAAAAAvGGAAQAAAMAbSqQgIiJz/9Mjsj2ry0ua95poKwo3PG+95tTlv5Z0s7ATqy5opvnGm9pqrntoDc0Nv7dp/I56waZ5HFTbprLtNeH8yHkbffKzz2YmNVfJSi5mXG8TmtZLsRKPE2qGpht+8bG4x+T28BlWJpBe/yLL5/9S9MaWM63etvK/m87qoblDVZvC9JVrjtNc9QNbmbmjTIycqzhTkq7Z1abQbJBSLbJv0nYrDmlxL1NnFkZmrhWfD7tjmOb/3fyQ5pouWkr1l+4/DdRc473amlfsGy3Y+enoRzSPHtlDc8aBhWpuhTP3XpuueV7/p+Iec9/qTpHtice20Zy5dFHuwwssPC2uiMjdw/8d97jDJ56nudH3k4r8eCgZXMEAAAAA4A0DDAAAAADeUCKVxLIOsllVfjngici+lzbYpc5GF2zUHIRW7956vM3isnoX+1Hatd8fms9parMbiYjc9a9Bmmu9NU5QdFlT7XluZ1VsMvu1PTT/9/ZXNbdItVKO3Z66THOrOyndKCkp7Vppfnq/V+IeUz00C859K63vPnx1P831Z+yI3Ofge+x1NfPIZzT3b2klQZmLFhehxeVH5nybVWjCHva/sgnSVHNVGS8lIVz6ts/FeZelvb3uHyXy+Mmo4bM2o9fpb/ezHc7FPb7F6tDMa4H93qqb62W4733DNf925qOajz30Es1pX00obHPLpbRWLTWPOvWR0B4rQ5u+3Wbd+vaYLpH7Zy5c6KUdcy9oE9k+snpG3OOaDdlmj+3lkeETVzAAAAAAeMMAAwAAAIA3lEglscrLrfRpyvboTBwDa9kMHmf+El2ILT8pYpesv9sW/RGrO9EWwCnOzC3I2wEdZ2sOl0UNWWKzglAWVTrmn9pE86HVwpf57TUybJmVGs463o5vtijvPvpxVU/NKe9O0byph82+UrWCl0iVpcVXWunTR81skb/FWVsix/3vwX001xFKQn3JWr3G27nSn7TZjsaebLOAzetvv7s6feXt4RLazEutpLNHFfubYGnmJs3n3HK15noLrGytuIJ9bDG9j8+9P9femprav22z5nVakDyz5pVHXMEAAAAA4A0DDAAAAADeUCKVxLKmz9J81zGnRfZt7FJP8+pdUjVvbWlzNbQZbTNzDHjwU81n1Z6p+cqHbbYiEZHGMynNKQmr/2nlTx+3shnBxoWqcu5qNkbz6bucozlrmvUXiselRd9Sdz9yhuZssdfLkCX7ap67n90eZBSsrClIs/8NbQriz7CCkrNlt61xb39x7V6R7TqvURaV6DIX2MxH766x0rc9e8zRvFGSQ91uq+PefsOSozTXe7l4ZVEptWppdlVtMdGZJ1g5b7tKNSP3WRsqPezy1FrNQVUr48reEi1PRNnjCgYAAAAAbxhgAAAAAPCGEimIyN/LZKqH1iiq/k7+9894oJLmL7fYTDhNX/wtclx20ZqHOFKbNNZ8ztCPNfeacLrmptfb8ee/+4nmtQ/aHF617eo3imnbEXtEtl9t+3Tc4+ad31ZzkDEj7jE7M/tUm+3msy2hheU+LJmF5SCS2qiR5lt6fRD3mFH/OSiy3UIoCUX5cV36p3Fv/2ZmJ81d6s3TnLV2beS49Wf11rx6N5spr1kPm5Xyorbfaj6zVrgk64s821UvNBvix1/+n+aT5xymeeP+lEglGq5gAAAAAPCGAQYAAAAAbyiRQpFtON0uhw6p96Tm9FGDNXfczCwqJeXPp6xkY89qdmn7k+utX7Km/qH55snHap7U+xXNx4gt2obiWXB83vt6PHG55paTCzcTS1rb1pHtL/qP0HzYR1dpThdKpErKjJvbaz6z1mea71m9i+ZWIydF7k
NJaDmQYrMkpqUk9/KvU7a21HxSTauTnnvYvzXPnLRZ87bAnjsRkd0rTyq5xsVM326lUJnZqTs5EmWNKxgAAAAAvGGAAQAAAMAbBhgAAAAAvOEzGCiyXYZM0RxeTTj9X5M1U4Psz7Ih+0S2f9v7cc1dX7lUc7up8ev7G71gU/2l9HZxj0HxjDjorcj2qixb8bntm0s0ZwaBFMbGZ6K1xu9s7K55l7sW2XkLdVYURpduC+PePmquTU3cdMv00mpOuRP+zF69T+x5ylq3viyaoxZfs7fmj5vae+qpcw8vi+aUqdc+OVDz0LMmaK6TYtNip1eqUejzrsiyz23UdDalffWUynGPP3TacZHt4F77vGGVZXauYNY8QeLiCgYAAAAAbxhgAAAAAPCGEikUyppz+2h+u+XDmv/xwjDNbTazem1JuOLi6JLqN6zYU3P7W3/VXJDim+wCHYWCyDzUpvntX+PXyL4z59ul/sy58wt13lkvW/8+1/HlyL57B52tOWXxxEKdFwW34jIrS/xfR5saeGtg/5tr9Gg1QXyznrDyoz/6P6G5e3ubsrntE1M1l0W51NZdrYwxI7Aiw7W3ttGcJqtKtU1lpf11Vl57+ivnaJ5/YkPNDfdfqnnzO03zPFfKdsv1X7Tzbv7Upnv+fnf7nXbjit00Vz4u+nxnb/nTcp6PiETDFQwAAAAA3jDAAAAAAOANJVLIV2rDBprPv/p9zY+s6aG5/UhbMTq510L1a9MAKzEYVPupyL6jzw3NcpLxe9z7p7WylVmb3zxD8xPrOnhqIdZ0rZLnvpmvdtbcSOLP7uUq2UwqK99pq/m93Z7UfMUll4fvIlW++7mwzUQRbD/YSnZqp1TV/MIGe12lfh0ti4Nxta1OJkVs5rrJg222pn6HHqs5uKmd3ffH36InK+TMa2EpNWzmoyWvt47s+6WnlW4dOfVMzdW+miDJLGvaTM2tQjmsmhRsFqeMfr00f9TtMc1rs6xPf7r8H5pTtlD2WRFwBQMAAACANwwwAAAAAHhDiRTyNfO6TprPr/O55v2uu0xz3VXxyz9QPNnn2Wwaf5v5aVz8sijpvbvGPs+O13xYTZut5fbDTw7dgcWKSkratvi3bz7JSt/qXm4zpDzR+g3Nl141RHP1T37y3zjk6/u9ngttWYnUvV9aWU8noW/ykn6fzdD0aR9b6PPIals0f9T5A7vD/1k8csbxkXPNmdZcc7Nv7fbsSlZ6tfmU+LNQvdjdZmHrUTn6Z89NK2wBwNSHGoT28L7oS5WrbOap8KJ9Hf7vYs0dvxtXqm1CyeMKBgAAAABvGGAAAAAA8IYSKcQV9Omu+cdTH9Tc+b2rNHd6lbKo0hSehUVEZMVgWwRswz5WijDz4Bc0HzvzGM1j+3XUnLWQy/++pG7Le3abg4fYa2TWeY00/1/7RzTftOwAzVfdcKnmWu9QMlAWVl1oi4nWTrEZoqbusBmRujy1VjOz5uUte4rNXPfkfgdqvvKRJpr/t4/NltYs1cqoPu0yOnqyLqF8YmFbYn/q/JAR/b/qN3dbf9f8jHI3X1Ib1Nc8tM1nlpfabFFd7rXfQ7bEISoKrmAAAAAA8IYBBgAAAABvKJFCXGe9+JHm/22xRaW63r9EM5c0S17Nu2tr/vnlaCnOL/+yxarGZdjtfUJlNvXftBKPICN0ELxp8LyVQXXuMjiy77tTRmheWc/ebvccdaXmLiNsFqlaiymLKmudzrFFQ1Od/Q/uhO8u0dxxGguBFVbmsuWa251m+bx9bTbC+ZfZe9zxnaOz5N3bpHAL3123vKfm0Z/aTFGdnloYOa7mQsqiSsLs4bbI6G6VbYHeu2/eQ3OVZSwYWpFxBQMAAACANwwwAAAAAHhDiRRERCTz0J6R7TNrWWnNHg/ZJexmC34stTZBJOU7K8W4pX3PnRxp6omV7OQ9vxFKQofh0RKnQcP3i3tcR7HjKDUse2lNbVajQU3GaM4KsjVXn1xN4J/7YZLmdj/Y7bmXET
1a9izkme3dr13oPZHXW+m4/LiPNR8+8hrNzT/ib4hkwRUMAAAAAN4wwAAAAADgDQMMAAAAAN7wGYwkltaiueZLnn4rsu/IGcdrbvYw0/gBqLiCOrU09622VfN7m+tqbvXiLM2s3g3s3Ie71tPcXPjcRTLiCgYAAAAAbxhgAAAAAPCGEqkkNufhBpr7VV8f2ffsyZs0Z2VTEAAg+Vz/9pma260cu5MjAQBhXMEAAAAA4A0DDAAAAADeUCKVxNqcMlnzMZJ7lei1pdsYACgjWX/M1nx0C1sxOrwCNACg4LiCAQAAAMAbBhgAAAAAvHFBEJR1GwAAAABUEFzBAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwx44Zyc45x84pxcUdZtQXz0UeKjj8oH+inx0UeJjz6q2FwQBGXdBpRzzskwERkhIvNEpGcQyNoybhJyoY8SH31UPtBPiY8+Snz0UcXHAAPF4pz0FZFPRGSHiOwTBDKxjJuEXOijxEcflQ/0U+KjjxIffZQc0sq6ASi/nJNmIvK6iKSKyD95k0g89FHio4/KB/op8dFHiY8+Sh58BgNF4pw4EXlNRBqJyLNBIC+WcZOQC32U+Oij8oF+Snz0UeKjj5ILAwwU1VAROUREfhbhA1oJaqjQR4luqNBH5cFQoZ8S3VChjxLdUKGPkgafwUChOSedRWSS5NRPdg8CmVe2LUJu9FHio4/KB/op8dFHiY8+Sj5cwUC+nJPDnJMFzsnA2CXO50WkqohcxZtEYqCPEh99VD7QT4mPPkp89BH4kDcK4mwRaS05lzfri8h+IvJxEMjzZdkoRNBHiY8+Kh/op8RHHyU++ijJUSKFfDkn54jIS6Gb1ohItyCQpWXTIuRGHyU++qh8oJ8SH32U+OgjUCKFghgtIhmh7cG8SSQc+ijx0UflA/2U+OijxEcfJTkGGMhXEMg6EXlURJaLyMggkLfKtkXIjT5KfPRR+UA/JT76KPHRR6BECgAAAIA3XMEAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4E1aWTcApm/KABYlKUFfZI9yxT0HfVSy6KPE56OPROinksZrKfHRR4nP1/tdMuIKBgAAAABvGGAAAAAA8IYBBgAAAABvGGAAAAAA8IYBBgAAAABvGGAAAAAA8IYBBgAAAABvWAcDJctFp5BeeGMfzdMueVLzXtdforney2NLvl0AAAAoEVzBAAAAAOANAwwAAAAA3jDAAAAAAOANn8FAiQp/5kJEZPLFj2vekp2peUet6Gc1UML22k3jnAE1I7vOO+J/mm9o+IfmB9Z00DzmqK6aMxcuKokWAkDCSaleXfP8q3torjszO3Jc7TfGlVaTgITEFQwAAAAA3jDAAAAAAOANJVJJJq1lC83Za9Za3rLF22MsvGkfzb9dPDKyb2nWVs39HrxGc9PHf/T2+DCpjRppnjW8o+a3T31Y866VKud5/6zA8s/r2oZOzP8mgNxSO3eMbM8Y3FDzJYd+oXlovZmaV4XeEwdcNUxzjf/+VBJNRBGEy6JWj7Lfob/3sN9v72+uF7nPs2+0L/mGlXe9d9e4Zpcamrf02xA57PbdPtDcv8Y6zYdP7695/uTmmjtd/YvmINNKsVG6+CsBAAAAgDcMMAAAAAB4Q4lUkslctLhEzrvlhL01j7nwAc0ZQWrkuGPvDZVFPUlZVElYeYnN3PXvax7RvFvlSqGjrCyq8zfnRe5fY7yVA7T4aJnmrDkL7KDs1cVvKFCObDytt+Zab9oMQZmH9tR883PPRe7zjypZmlNC/8/LFptxqH5qFc1PjnhU83
XTz9GcNdVmc0Ppm3WHlfJM7/GE5vXZ2zTf+vxZkfs0l+T+/ZZSq5bmlad103zo4LGaL2/wpOZmqfZ7p6A+7/qebdjEhtKx5kWa0//5c6HPCz+4ggEAAADAGwYYAAAAALyhRApFlnXQnppfeeRBzQ1S7FJnt2cvi9ynNWVRJSLj6F6aX732Ic1dKln5xS0ru2v+4Tor9+j45cTIucKzbmQJSlrGUdZ3C46zBSfvOWSU5gE18y5J6/7TQM0tTpzquXXJbfZre2j+7aBHLN9tJYbNU7+3nGavNxGRlVkZmk+eMkjzpm123Pi9XtTcuZKVlGbVip4rGbieu+a5L5hQuj/bqy+wUtMxAx4I7ammacCMMzQ3vz8Jf7e56AK56wba75Vjh3+t+foGT0g8m0Il1IMX76v5y5ldIsfV+b6q5fmh309V7fGH3P+m5n49ftc8K+/Wo4RxBQMAAACANwwwAAAAAHhDiRQKJbWeLSZ05tPva26dFr8sqvXtNmMESs7aCzdpDpdFdXz/Yrv9mumaK2+0mTVCa+mhmNLatIpsb29tC63NPsvebvt0m635qTaPaa7urPQmRezyf/ZOeimgA73aeKqVeXy4n5UbVnH2ugrPDiVit0/fni1hlw23hfPqvG0L59UJHTNxrv1cRM+bHNLat9V87ajXNadK9Lm8+9ATNGfOWyAlIa1ta82XDHtXc3iGo/Csex3OmFQi7Uhkqbt21rz1kW2RfT/uEr8U6qUNtgjenT8eo7nrQxs1h2dK6yjRst0829Kgvua6qZsLdB+UHq5gAAAAAPCGAQYAAAAAbyiRQqHMutZmdziz1leaT5h9tOa29/2qOZv6jVLx0G7/p/mDLbU1d715nuasjRslnqyD94xsV/59vu1bvcZTCyuWcJnAjIvran76qH9Hjju4mpUQ5F3yZGVRy7O2an5urS1euSkzOqPQvU1ZPMqntGZNNb9wn5VFta9UKd7hEVO2W1/e1G9gZF+NaT/lPlxEoov2dascLiNNjl/JKdWt5GjFY/bz3ydUIvbGxibRO2WWTPmYq2SPX+UVe/0Nqr1E8+TtOzR3eCxaupUMgj42A+G/XrdZz3rvZKKzcClZ5xvs90j6/F80F7dHl55mf48cVPVLzcNetcUvmyT5godliSsYAAAAALxhgAEAAADAm+S4HpvkUju205w1e95Ojoxv+5G2ENh7p1n5wNbASj4yz7fFh1xVu8y87KJo+U2LF6dYWzZsKHRbkL/6qTajVHYbKzNIrW9z16x52I4/pfWYyP2/OqabbVAiFdf0y+y5nHnck4W+/40r7BL+29N7aG7/uJXbuB9/07zp07bRE4RKpLb9WavQj5/scpcF7v6gzVrTsVL8uo9widvsHVYyc+XQoZqrTRtfoMfPqGPnCs8cdu6CQ+2gcb9LRbVgWA/Nv+0xMu4xb55wcGQ7a2HJLJm24UT7WfigQ3gWJOujM5+7UnPLcclXcrOptf1+71bZFo88de6xkePW32Cz6HX4wX5+M7P9lbdl79dD88fX3q95bbb1V7MX7L0z+QraEgdXMAAAAAB4wwADAAAAgDcMMAAAAAB4w2cwkkBRPnchKakad73DainDq0T3vulSzRnHW/3jnRd/prlf9a8jp337UlsJ/MUBNrVt9m/TBUX3z08v0Dyr/1OaWzw+X3O1VKsbP67GIs3vD9gvcq6s+X8Idq7SOnt9DAhN0fzbrOhK3m3etddFlY/iTy3bQSbFvT2tha1++1SX/0T2pYSmtm06VlBIK4dsjWzf2cQ+O5FXzfZtK61W/+u79tFcc3T8qWh3ZoN9LE6yQ4844dNdNLeuYNNrbjnBpl0edsY7mtdn21TOB4+8WnPz6aXz/d9814txb79ntfVF2xfnas4s8RYlnlpvjdPcf/0Vmit/Gn1PS5FVJfL4Ow7/h+YLR76tuZKz99ejbxyuud5m3hQTAVcwAAAAAHjDAAMAAACAN5RIIa55d++l+ePmNnXfYdNO0L
x6z9Cl/eNt3tPaKVU1P7c+WjLyzzoLNb/wUGhl6UMFxdDlkZWa3+5rZWjPtvpW8+psKws59ga7nFx3KpeTC6vd9fachYtt0mW5t8eYdXkbzV1zrSgdfl3Vfn+SZqZkzNvMp+w97eeeD+faW1niWZllU3L+ONzuX/PLwpdFhVeMPuSQSZonZtj/+dqOsNsrWl8uP8VKoc6uvVjz0+usFKn5/SVfFrXt2L0i23tXCT+mlQB/ccMBmqsuLdj0w8kgd1lUSUjt1D6yfdKjH2seUHO15l1/uFhzm5f5PZZouIIBAAAAwBsGGAAAAAC8oUQKIiKS1qplZPuz0x8IbVXX1KehzUj15S7vap6TaSsQH3mDzS5V/z1buVtE5J8zrESqd+hc4yRaAoLCcVlWULEuq0Zoz1pNs3bYaqwNv7VZpJJxVpRElVLLVuU++Ygf8jzugc9tBd2O28bleVyyC5da3HLwe5prpcQviRIRmb7dXkuXDR+muUYRyqLC5t9kq7ePbvGY5p8y7L0ve8uWYj1Gosno10vz6D6Pak4RK6N9/rl+mpuW0MxZqQ3qa77yoeiMbOGS3k7vXmL5w+L1NwonPMvYCXd8Edl3fm37fXXTih6a2w2ao7milRRWBFzBAAAAAOANAwwAAAAA3lAiBRER2fZidKzZOq163OPuaDxJ85ubGml+9cyjNNf9JTSbQ6jkI7dXfrLFqtKl5GemqGhSa9fWXOfV9ZrDl5O/2GplUX2r2XxHsy+wkri2N1nZGsrWzNt21fxh4yc1T94eLWTr/JQtaJVV8s0qt/64pY7mM2st1Zy7nCI8W9TpL12jufXb/kp23hz4SGjL3m8vfXKw5uYVbHG9pWfb89oxtEjrqXMP19zscZuhyQpt/Vp7RLrmftWj5TcTQiVxXZ60klJeVyUvrX1bzSfdaQv0Xl53buS4IUv6aJ57ls2gl71ldsk1DsXGFQwAAAAA3jDAAAAAAOANJVJJbOOpvTV/2eWxXHvj/2jcsrK75onHtNYcLJoS73DZekCXyPaGbLsM2vFV5i8qjukP2WX/2W2f1bzvb6dorn91quZt732kecCx32uecE+0jK2izWRTnnT/x5y4t5874srIduM/KlYpjU+pda0s6pZ/fKA5RVzoqOj/1k6eMkhz69uK/tyGH1tE5LRxUzX3qGIlQh2/PldzhxEVty/v2/OduLev2GLvOdUyV8U9prjSWjTXfOPtL+V53Gnf2GJtnaZNKJG2wGQdvKfms58erfnMWraAXrdxAyP3aXHi1NAWZVHlBVcwAAAAAHjDAAMAAACAN5RIJbHVJ1opTBWX94/C/Ew7bsI53TRnL5oe93iXZue6/OG3IvseWWMLL6V8N7HgjYWk1KgR2f7voTbL0KsbbVaoetfZ8581dYbmYR+dpXnmALvvoQdeFDlvlU+Y0as0uZ42c9Sz7azUbX5o8crm782P3IfiwrwtvMCez1Nrfak5O/T/tI+2REuZ6t5is60Vdiaj8Oty7RsNI/tODc1cdfIcm2kv/arFmivybEWdKq0Mbdnihu/s8prmj2e00XzL1yfa4eGKNpECdUyr9vZ4z3R+XXN4BqvcxhxiCwAum2fHpYYeMOtvjclx7q+DNLc8aWrcYyCy9hybBWr4v2yhw5Nq2KxdXb4/R3N4AT0RFtErr7iCAQAAAMAbBhgAAAAAvKFEKsm40Ewmb+31XGhPpchxmaEL96fefrXmBr+Nlfz88dQemo+o/n1k38M3nK65hvyU77lg/ri3W2S7R+XvNJ/18smaW/8ef1aaFmNCF5oHWPyzX/T/DJ0+KUYjUWiNH7eFDuukVNV8wLOXa261uOLONORDePamcwd9GveYjGCH5utfPTuyr/XPhXt+gz42m56728pyxqS/GTnu/AV9Na+5zUqBKi1PjtmKTnjjKs1Tz35cc/jn/PRayy0f95TmlFxlSdmFLF5LEXuMnd23WWq1UA7f3x4/fP93N9fXvHVxzUK1qaJzla
wMbt5rNoPkmD4PaG6caov4dv/JZosKl0Uxk2HFwBUMAAAAAN4wwAAAAADgDQMMAAAAAN7wGYwks/5E+3zEbpXz/jxFl48Ga05/If/PXSy7ch/Nk496SPONy/eLHFfjv3zuoqhS6m/Pc1+d2flP5Le1Xmq+x6B0pDVtovnEhj/EPabxrzvi3o6/c3Vqa7603h9xj9ntY/tMS/rthf9MS3g64Q03b9I8Jt1WI56+Pfo6XHNBY82VknCV6HbX2++Ooz86X/Pyva0Of5cTbSrtOpW2ak510c9NZAXxp4oNG9bEpiVOr2TTB7++waYPHvHcKXnev9XbizRnzv8z38frlOSfI0xtUD+y3fGzjZo/avZSaI/1d4/xNl16eIVupqKteLiCAQAAAMAbBhgAAAAAvKFEKsksOyT+urGTt0fLMbr+a77mvFaa3XB6b83fDntQ85Isu9g565joyrYiywrUThRO/XE21WNe/bWme/ypGmvOo3SqtE2/vp3mY6vbvMB9Jp2mud5HrKheUHPPaZXvMa0+yr/EJrdg3x6a73vtGc1dK9v/5lZmZWi+bPiwyP1rTEvuEpqwlO8naW4Wmr187YOhXITzpna019Kkj5trnrXD+uWN0w63x56Ud3lcZhEeP9mkdmqvueoLGyP7Hm5mP++Ls2yq2ZNutanuW7z+q+bCTTyM8oYrGAAAAAC8YYABAAAAwBtKpCAiIgOfvDKy3XylXUZOqW4zQCwY1kPzDxeO0PzJ5haanxpuy0RXXTreZzORl0rxX8qbTrEyttuPGqV5wnYrpGr54YrIffIqsULxhGdcGXXcY5o3ZFthRo2RdUuzSRXGHkdM15ySx//Nqo3O+70orY2VWM05v6XmNwc+orlHlSqap2+38o+Lhthq1TVGUxJV2pY9VDnu7dc/dZ7mnZVFIX8pu9uq3Ae+brOhDa8fnbFtTqbNAnbZaZdqrj/WZhOjLCp5cAUDAAAAgDcMMAAAAAB4Q4lUkqk9pZJtHG1xc5toYcycB6205rIjP9V8eV2b/uPoGbZgUdrZNnNU1cWURZW2OWfZbF3Nv6urud+NX2s+uabN4LXXiCGam/5B+UBpqDnaZjHavbLN3HXMjBM1V/6UmaMKassJe2t+vbUt7pkt8UtmVl3UR/OGdtF9d5zwpuYTaoZLBu1/cDeu2E3z13fZwqI1KYsqU9vGN9D84o1HaW72O+9rxeEq2eto7b1WxhkuiwqXRImIXHmAzYLnFvxWgq1DecAVDAAAAADeMMAAAAAA4A0lUkmm2RNWvtRtv3M0zzrhqTzv8/rGxpr3HHGZ5qaPWmlAZjZzD5W0lIVVI9vZofk4pg16wnYMspgR2KXt3V61sqh2D1M+UNreaPeF5uzQ7dm3WHlbiiwqxRaVb4uOsJ//Winxy6LCxt38uObsSA9Ezd5hr5ljvrpcc5ehVhpScyNlUYmi1Z32XpZ3r6Kw5tzZU/OM7k/EPebDjbtFtv+4O7ywruXD022Wt5uafKk51VnZ6IwdNTTf099KrUREsn+fUbBGI6FwBQMAAACANwwwAAAAAHhDiVSSCTLt8n/rAZM1Hy17Fuj+TYXSmrLS7rqxke2uabaQ0Xen2KKHL6z7h+YxQ2y2m3ZfR++PkjfvHpu5KNVN0tzpS1sErNP3v5Zmk8q11Lp1NA/qYzPabczernnkmr0039BwUtzz/JRRKbJ99/x+mlMutlLE9Jm/aKb8Bslkj31n5nvMkHqzo9sHzc7jyDBbuHdtts1C9cuW9ppdRqag/OMKBgAAAABvGGAAAAAA8IYSKaCc6jB8nOZBw/eLe0yqUH5TmtLatIps/3iWla5lBdU0d7l3k91e8s2qMDK7ttV8XUObjabLpzZDWtdhszQf1+X8uOdJXbM5sh3MnKOZ/gBENp1r5YjpgwdrbrvbEs2fd30vcp9XNzbV/OD0wzRnZdn/srfPr2nn+tBKG1PHhH9XFaTUComOKxgAAAAAvGGAAQAAAMAbBhgAAA
AAvOEzGADgyZ8Dop/BqJNiU552+uKfmtNnTRYUnhv7m+bjWvTSnC42nWzkMxTjfo97Hj5nAexc1qy5mjteOTfuMTub3r65TPPeJpQvXMEAAAAA4A0DDAAAAADeUCIFAJ5s6b41sv3GxiaaOw2aoDkotRYBAFD6uIIBAAAAwBsGGAAAAAC8oUQKADzpOHBiZPt1aVlGLQEAoOxwBQMAAACANwwwAAAAAHjjgoD5TAAAAAD4wRUMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwAAAAADgDQMMAAAAAN4wwIAXzsk5zsknzskVZd0WxEcfJT76qHygnxIffZT46KOKzQVBUNZtQDnnnAwTkREiMk9EegaBrC3jJiEX+ijx0UflA/2U+OijxEcfVXwMMFAszklfEflERHaIyD5BIBPLuEnIhT5KfPRR+UA/JT76KPHRR8khrawbgPLLOWkmIq+LSKqI/JM3icRDHyU++qh8oJ8SH32U+Oij5MFnMFAkzokTkddEpJGIPBsE8mIZNwm50EeJjz4qH+inxEcfJT76KLkwwEBRDRWRQ0TkZxE+oJWghgp9lOiGCn1UHgwV+inRDRX6KNENFfooafAZDBSac9JZRCZJTv1k9yCQeWXbIuRGHyU++qh8oJ8SH32U+Oij5MMVDOTLOTnMOVngnAyMXeJ8XkSqishVvEkkBvoo8dFH5QP9lPjoo8RHH4EPeaMgzhaR1pJzebO+iOwnIh8HgTxflo1CBH2U+Oij8oF+Snz0UeKjj5IcJVLIl3Nyjoi8FLppjYh0CwJZWjYtQm70UeKjj8oH+inx0UeJjz4CJVIoiNEikhHaHsybRMKhjxIffVQ+0E+Jjz5KfPRRkmOAgXwFgawTkUdFZLmIjAwCeatsW4Tc6KPERx+VD/RT4qOPEh99BEqkAAAAAHjDFQwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3jDAAAAAAOANAwwAAAAA3qSVdQNg+qYMYFGSEvRF9ihX3HPQRyWLPkp8PvpIhH4qabyWEh99lPh8vd8lI65gAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAbxhgAAAAAPCGAQYAAAAAb9LKugEAACSTzEN6aj7g4bGaf1zV3g46dFFpNikhrD6/j+aqA5ZH9n292yjNKeI0P7muneZHfztYc+P3q2qu9eY4r+1E4kmpXl3zH/fuprlhhzWa1/3eMHKfdtePFZQcrmAAAAAA8IYBBgAAAABvKJECgDK08MZ9NGfttklz24Z2af/jLu8X+rypzv5/tPv40zW3umSt5sylywp9XhRNaoP6mhvcMVfzjQ2naB6RskPz/6RG6TSsjKW1aqn5+mtf13xcjbWR47IlO7RlP9sX1p2t+eID7XnNPtCO77bnFZrbX0NZTEU054VOmqcf8Hj8g3pEN4+7vlfJNQhcwQAAAADgDwMMAAAAAN5QIgUApezPW60s6ucLHtJcxVXSHJ4pJ1uCQj/Gpuytmt/d4znNfe8ZornTIEqkcktr00pz5oKF3s67+Owumj9oayUcM3ZkaP7m1D1C95jp7bET2bb0JprDZVHhn38RkafXddS8dHuduOe6o/Gk0Jb9/3TGmU9o3v/XwZH7MMNU+bXkGnsfnXrASM3Z8Q4WkZ6PDYlst5AfS6JZiOEKBgAAAABvGGAAAAAA8IYSKRRKasMGmpcNSNccHBm6tJ1iFyjH7/lm9P6hmW2uXmblAL/vWfgSEBTO9iP+oXlLEyvF2XLC+shxHeqv1vxep880t3/nIs2dLvupJJ
pYoa0/s7fmxwc+ozlcFpWXFze0imy/9ufemldttNmGan5QS3O11Vmaa8yx12en6RMK2OLkMfc+W+DthQFPab77lLM0B79MkcJK7Woz2zx02TNxjzllwgWaW06bWujHKO+q/DxL80GTB2iudl/dyHGVJ9hsUVkbNsQ91/6nWfnTA3c/qXmvKvb7ZcWxGZH71Ir+ikIhhBe3E4nO5HRyl4maJ+zh73/Z247ZS/N7l96vuZKrqXlV1jbNBz92teYWD1ASVZq4ggEAAADAGwYYAAAAALyhRCrJpKZ30LxpFyt32tA6+qOwqfcWzW/1eVZzo9TtmlukRi+PFkRWYOVT2UF4lhBKpOIJL0K1sWfzyL4tjVI1rzvYZgx6dC+75t+l0irNTVJttpRqrnKBHj8r1C1PHvmS5oela4HuDzPwXx9pPqDq9rjHTN9hC60NfOQqzS3+HS3PqbZhnuZo8VR8WfkfklTWnd0nsj32jBGa66VU07yhg5Vd1Pql8I+zsF8jzQdVtb59b3Ndza3Pnq85r9lvKrJwuVPNI+OXPokU7Ge42gp7jlNDv1MiM7KtLth7H+ILl0XNfLZzZN+0A57NfbiIiBww0BY6rPtq8RY63O2W3zS3TKuieVXWZs0HPUFZVCLgCgYAAAAAbxhgAAAAAPCGEqkkEL6kueYRu1T8Q/enC3gG+zFZkWWlHbes7B736NfH22w5lVdEf8TGD3oo9+FJJdi3h+ZZA+1SvatqBQBP7/eK5sapNlvTbpXzn23o7+KXsYUX93pj3V6RfVc3HK+5prNL0A8uOFxzivhbgKwiW32BleIcX/OB0B4rw7luWS/NE27qqbnpx3ZpnxKn4gsvoHf+9aMj+8JlUZcvscW76n7+h+aC9EFKrVqR7c7Hx18s7+YXbXaqlpsp4fCl7d3WX3tUsYKzJ0OL9HW5cXrkPry2Cse1slLdaQfHL4kSEXl+fXvNDX9YqjmzCI8Z/r15ReMnQnvsd+jgBcdpbnEvr6lEwBUMAAAAAN4wwAAAAADgDSVSSSClXl3Ng9raDA5fbLWygGWZdSL3GTGtr+bsibavzbtr7PYpM+I+Xrr8rNn12i2yL3WQy314Uql813LNszt9tJMj/5J3WdTPGTZLyoztzTTf9qNdKk5bZffvMGqT5pR1NuNGUC06q8r6D+1npKZNVCUZI+0xqlEiVSCXDHtXc7M8Zl378lUrKQyXRaH4UqpW1TxjSAvN79eOlkiFX0t/XLer5rS1hVuUcH2/XSPbH7a3xd4eWmuLkLV+fLLmZJw5qrgW3GZlbC8NHKm5V5XQbFGh/58+NNZ+n6VvKMJ0YFBZT27L/yARGfHdkZrT5/68kyP/LrVB/ch2i4dsMcY2afFnAdt0PDNRJhquYAAAAADwhgEGAAAAAG8YYAAAAADwhs9gJIHMxUs0f9DvH5qDjVaTn7VqdeQ+LWRq3HMVqF44xQr36z6yOLIrvIL0+3/srrm9TCrImcu9zTvs+/8hI//x/TUzTtZc7dG6kX3V/1ihOXP+n5rTJX6NcbhCNTw1Y+aXrSPHhVdoP35WP801xxRuys5k5HpGa/AH1f5V84qsLZoP+fc1mts8btMCU0XsV/bu9rmHL08aEdoT/TzMGR8P1tzpq5+kqFb135Lnvmd+319zh42TivwYyWLr8dHps+teZe9xkzva5y6yQ7+Vwp+7CN9es4H1y7Zjo+et+sF4wc4tuN2m257axaaJzc71P+rBiw7Q3PWGOZoL+/tiyRldItujW44Mbdljpn90seXVhfucB0oeVzAAAAAAeMMAAwAAAIA3lEglmcx5C0r8MYK9u2l+ve2LkX0Zga3jmX7tKmtXibcqMVTua8//XdIj3+PryOw89xXnOUvp3lXzO13+neu89raw6b6Wmqus4xJ0flK27shzXyVnU2g2+3G75iAzWX76S9/Mi2wl+tZpVhb19uZ6keO63DZXc2HLObYfYWWnP+83MrLvu201NLd7UpCPxdfa9L
O/XfF4ZF92qIAwRcLTnafke/ukvV6z8+wVLUS8+F8Hap5/Q2fNaf8r3BTFFVmPQ608NtwP2bmKpmffsovmysUoWRo8+L3IdvhxJoZKi9tEZ5tGguEKBgAAAABvGGAAAAAA8IYSKXgRXnnzutdfzvO43d+6QnOHheNKtE2IcpVsBqshb7+juaarEjlu399O0VznE8qiCiNYsDjPfdWdraq+tYG99ea9VjuKIq1NK82jDw2X2djP/0O3nR65T52VRX8vWnZBhubcr6Vzx5ynOf17VpAujOxcc6qFy2QuXHiI5u/mdMz3XEN6/M/uWzdadvp0q280j3/+W82D/nup5vbXjC1AiyuWRddbudqrrR8I7amq6cDfoq+j+j8XfeaoFYPt8c6tHS01DBdiXX7HZfZ4HyVfv5QnXMEAAAAA4A0DDAAAAADeUCIFL2bcYgta7VvlS825F5PrdPNkzQVatA/ebDm6h+a+1fIuCWkw2GY4Yn6jwgm6tM11y3ea1mfb81pzSYagZMy6xGY+2zVUFjg/0xZbq/9rdGHRwpZzpLVvq/n63T7J87gqtayfN5+0t+Yabxd9Mb+KrMV9P2reQy6P7Gvzus3Al7nIShE7yMR8z/uh2KxhL59/RWTfucM+1HxhnfmaXx/wmObrP7UF3ZJldqmDT7Tvs06KvY6+3molUvVOXR65T9bGjUV+vK1N8t739daamhv9ZK9dFnxNbFzBAAAAAOANAwwAAAAA3lAihSJL7WplUb+f+KjmjMDGrXedcUH0Tpt/L/F2waTWraP5vPvei3vMfr8PiGzX/nN+CbYouaS6lFC2RcAWH1BNc2gSGxRRWssWmtv0WhT3mP5PXKO5+fQf4x6TW+ahPTVvbWTzfbW+fKbmM2utyPP+U/e1GfUGND1C8+a3C/TwSS1cLiXir1yzwQvRmYc+GtVO846xqZovrWszIi043wp6O9iEVBVCeHbBWffvqfnD5k+EjrL3sSv+Y7/T22ws3ixO2QfuofmjQfdr3hFUjhx32Xvnau4wjdknywuuYAAAAADwhgEGAAAAAG8okUKRzb3NZpOo5uySZtfvBmluN46SqLK04dAumgfW+jruMfUv2BrZzsxmbo6iCiZMjWzfvaqz5qsbTNP8wQVWDnD5sydqzlqed7kNdiLV/lf2bMc3Qjuqa8qoawu3hRcRExHZtou9Bq79x2eaT61lC/XlXkSvIN7c1EjzqoesFKearCz0uVAysjZs0Pz8K0drvvyKx+MdXuGktG+tedoptsBdeJbH8CxObUfb8xVdCrFgwovybrp+veaWafb6+ikjuvxoy6/sd1LGUb00V//T2pI19Y8itAYliSsYAAAAALxhgAEAAADAG0qkUCiu566av+3ztOavttbV3PFSm8WFYpvSl9rIyjKOvmVM3GMOm3aC5spL48+6g+L78O6DNZ96ny1c1TrNZpFadkIHzY2epkSqKNb2sVmkWqdVj3vM9HOeiHv7zsUviwrPDjZ9uy3gd8HVV0aOq/OlzTZVbe34Ijw+StPmTrYYZnaRCoDKnw3dGuR7zAPzbQa0df+orXnbDbtGjju49ey4909xVnDVospCzUPrfxbvcOlTJfqXw2fPPxn3uG+22mv96SUHWbsusDZmzZwjKBtcwQAAAADgDQMMAAAAAN4wwAAAAADgDZ/BQL5cFatDnjPMpo9rkGJ15ENf/qfmVqsLtkouSsa6Q6ym/9oGVuO6NbD64uqDdmhmWtqSU2/cEs2fbu6q+cI68zXvcvZ0zcun2sq2Kd9NLNnGVSB1P7YpgEf8y6YGHl6/YFNXvr6xseZbfuhv551g0293ON0+T/FW+881nz1lkOb6o36KnJdXVuLbcHpvzW8cYp/TSRGnOXt1dGXpiqT98On5HvNxV1t6PuVm+790dmQy27ylSOHvUxAHVrPPP3Vo867mi5tcbo89U1BGuIIBAAAAwBsGGAAAAAC8oUQK+drat7vmGQfa1LRjtlm5VNtnbX
o6ygJKX1rTJpr3vfanuMecMae/5szFS+IeA78y5/+p+f3zbMraC99+UfOLbb7SvOul52pu910JN64CCa/G/M3+zTV/tcs+8Q7/m9TJczWnb/zFdqSkWj69Xtz7BqPzn+YThZfWqmVke8EZtuJ0i6832o7xkwt13tXn94lsj73dVuwOl+8cMPlUzV1utDKiivb77bbmH2tOkWo7OfKvY1xkqyDyus/yrK2az5pxluZl45pFTxA4iady97V2yI/2+mz+HWXaiYArGAAAAAC8YYABAAAAwBtKpBBXWlu7HH3CfTZjyswd2zTfPuRizVWXs0ptWdrRwS4p39vkk7jHrHqqreZasqykm4RcUibbirJ3r9pN840Np2ge3v1Lza8fe4zmqh/w+iqorHXrNbsffyvQffKa12bt2Xtp/ri9zTB0wcIDNTd4gb4piAW3WbnaSwNHau5d1crQdgRWgJQiv0buH15Zu3PLwZrb1expB4Uqaeb1tz9vhh1qZUAX1rGSqJzHiV++U/PmGpqzNsyVimrQpVdp3vWW3/M9fsmWOpoXvda+QI9R77RFmj/s8o7m/T8fqjn9AitNbCPzC3ReJDauYAAAAADwhgEGAAAAAG8okUJcM66wmVjer/ue5q7fXqK53YeUBiSK2adWiXv75O22oF7t/9ol6CDewdi50IxC8+600pkh/T/M8y6poeKbf99ns9d8do/1wIX32+vo3NoLNb9z5WLNwQdFaC+KJtTPm1rGn73mm1920dwpO/6sbYiafIGVRYVna9oRejOKLsIW/f9neN+Vh3yq+cITZ4fuEX9Bt50t9NZ1zIWaG79v76O1xo+L+31UNFVDv8fn5P1WFmJl0g1keYEeY1l1K4/rMs/+huh6rZWeVbTZucAVDAAAAAAeMcAAAAAA4A0lUhARkbXn5Fp8aMAIzaM22YJHna5ZrTmz5JuFnUjp1kXzf455MrTHyjpOGWeX/9tlFmxGHcSX1ryp5qnnPL6TI014hppK172ruW3llZobpsZf3Oredjbbyg2dTo/sy5pVcWe1KWuukv1a7H/S95rf3NRIc6fLKIsqrEouPFtU/rdHZ3fKueUvg+vO05wduj2vGaFOmt1P8/oRrSWsAzO0lbimj9rCd01Dt1MWVbFxBQMAAACANwwwAAAAAHhDiVQyC82WUvWM6MJrm7PtWvVjt5yqudbC5JhZozzY2qaW5l5VrDRgU5Chudl/4s8uhcLLXr9B88i1nTRfXm9Wge5/du3F+R8U8uHG7prdpi2Fui+KLsiw18+0DbaA5Rtje2tOF8pqCmufq2xh1hXH2nM8/aDnNRd0FqnwvgN+P0Xz8lkNNbd/12bQS/3aFu2ryiKjQKngCgYAAAAAbxhgAAAAAPCGEqkktvpcWyzsp25PRPYdOPlszbXepCwqES04Pv7tr6y32aWqMkOKN9kbN2r+7Lx9NY/qeLjm1cflXcp0bKcpmu9t+nPcY15YbzPcfHHDAZqrLqUfy8LWA20hsfQCLiqG+MK/R2q9abcfIz2Ldd7aMiduBlC2uIIBAAAAwBsGGAAAAAC8YYABAAAAwBs+g5HE1uxp0/69vrFxZF+dS2yNTVbsTkyjDx8Z2qqs6ZHPj9LcUfj8TIkYP1lj7dDHI2r/J++7TAnlgtSdV2UqVABAOcUVDAAAAADeMMAAAAAA4A0lUkkmpUYNzRfsP0bz4/cOiBxXb97Y0moSiujt9VZmk1rnF82dn16lOUsAAABKF1cwAAAAAHjDAAMAAACAN5RIJZnszZs1f7N7Nc31hJKo8mZc90qWpU9oz+zSbwwAAEAMVzAAAAAAeMMAAwAAAIA3LgiCsm4DAAAAgAqCKxgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvG
GAAQAAAMAbBhgAAAAAvGGAAQAAAMAbBhgAAAAAvGGAAS+ck3Ock0+ckyvKui2Ijz5KfPRR+UA/JT76KPHRRxWbC4KgrNuAcs45GSYiI0Rknoj0DAJZW8ZNQi70UeKjj8oH+inx0UeJjz6q+BhgoFick74i8omI7BCRfYJAJpZxk5ALfZT46KPygX5KfPRR4qOPkkNaWTcA5Zdz0kxEXheRVBH5J28SiYc+Snz0UflAPyU++ijx0UfJg89goEicEycir4lIIxF5NgjkxTJuEnKhjxIffVQ+0E+Jjz5KfPRRcmGAgaIaKiKHiMjPInxAK0ENFfoo0Q0V+qg8GCr0U6IbKvRRohsq9FHS4DMYKDTnpLOITJKc+snuQSDzyrZFyI0+Snz0UflAPyU++ijx0UfJhysYyJdzcphzssA5GRi7xPm8iFQVkat4k0gM9FHio4/KB/op8dFHiY8+Ah/yRkGcLSKtJefyZn0R2U9EPg4Ceb4sG4UI+ijx0UflA/2U+OijxEcfJTlKpJAv5+QcEXkpdNMaEekWBLK0bFqE3OijxEcflQ/0U+KjjxIffQRKpFAQo0UkI7Q9mDeJhEMfJT76qHygnxIffZT46KMkxwAD+QoCWScij4rIchEZGQTyVtm2CLnRR4mPPiof6KfERx8lPvoIlEgBAAAA8IYrGAAAAAC8YYABAAAAwBsGGAAAAAC8YYABAAAAwBsGGAAAAAC8YYABAAAAwBsGGAAAAAC8SSvrBsD0TRnAoiQl6IvsUa6456CPShZ9lPh89JEI/VTSeC0lPvoo8fl6v0tGXMEAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADeMMAAAAAA4A0DDAAAAADepJV1AwAAAAARkWDfHpHtfZ8cr/mTxbtobnDhVs2ZixaXeLtQOFzBAAAAAOANAwwAAAAA3lAihbhSGzXSvPqojpo3HrdRc5fGyzWP6vBZgc573fKemidcZ7nS578UqZ0omuz999C8ZOj2yL7Pej2juVlqdc2HnXeh5sqf0V9lafnl+2j++bqRmi9ceJAdc5S9vWetXVsq7QLKm3lv7q555gGvaD58wCDN7odJpdii5LTh9N6aR903IrIv/HvoxoZTNHcecqnmDldTIpVouIIBAAAAwBsGGAAAAAC8oUQqyaS1aql5xWGtNGccvy5y3OO7v6G5T5VP454rRZzmbAk0T9+xQ/OPWzpE7nNL47GaL7ylrubVn+fTcBRJau3amudcu6vmUWc+rLlrpUq57lVN032ru2qutCnTfwNRIEuH7RPZ/nDI/ZqzpYrmZ1uN0dz1ifM1dziDEqlEknF0L80retrrL62n9dOHez6n+b8brIxHROSzbrUFRbPplN6R7V/3e0TzjsD6YuUeVpbT+IcSb1ZSmvmcvQ6+PNzKouqnVI4ct/vjl2luedifdtwuq0qwdSgurmAAAAAA8IYBBgAAAABvKJFKMkd/9pvmC+uMzvO4HUGW5qVZNsvQGdPO1rz5w6aa6862sqiqi2ymqewpMyLnffydgzT/e4+XNN8kvQR+hGeIanL/HM0ftH5C87gMuwS91y9nRO5f5+lammtMWqjZLZ3ks5kohANOmxDZbpJaJY8jTcvXcpe+obBS69XTvOhcKxfc2GVHvMOlRevVke3bOsV/jz2gqvVnJZca95gdgZUqPjfqyMi+1vJjHi1GPGnN7HdV9rkrI/uqOHudvL2poeZmb83SnCXwZck1Vu753RFW6hkui9r33qGR+7R83H7eF1Tro7l2T0qkEhlXMAAAAAB4wwADAAAAgDeUSCWZ//xppUi12m7V/PKi6Cw1255trrnm/43TXEPmxs1h2Tt5/M2rq+9kL4pq9fl22fjVmx7U3LGSldLcvspmovn65n01Nx49Ps/zMm9U2dlwhs12c3bDkTs50hw9o7/mat9ZeeLOXp
PI29IzrCzq9+FPejyzlUW9v9neE4d8dZbmzs9t1tx6AiVRxfHnme01v971oci+pVn2Lnf91wM0d9k8ReDHjsNsUd1JQx7XvDbbZqLs88BQzU0f5+e9IuAKBgAAAABvGGAAAAAA8IYSqSRT5xSbdeFN2UVz2oY/I8fVlOh2UaU2bBDZfvlQWzzq1vnHh/Ys8fJ4FZmrEp05aPZdNlvUpNNs4bwqzo7b5ZvQYmvnTtdcLSPvsiiUna3999L88j1W6tYmrXK8w/8mY2QzzSkbF+7kSOQlpbqVLFU9ZnncY5ZmbtK8INNme9ou0Rmhzvn6As1tR1k5iMu0orUqY62ULX2zvS5t6VIUReYhVpZz1jlfaM69sOgJs06wfY/aQodZW7aUYOsqvtQG9TV3u29i3GMOefhqzc0ezbssylWy97/9j/hd82+rmsc7HAmCKxgAAAAAvGGAAQAAAMAbSqSSTNaGDSX/IM5KAeY9Fb2E2aeKLVk0Y2orzZ0okYortWsnzb3enBbZ91FDm9Xmh4yqmu8cOEhz+x8maS5oyUX4cvQfT3TXnN5hqR106KICng2FkXqpleS0S6ua53HzMrdpvmDIlZqrvUfpW3GtPdFmWxvX/WnNW7JtwdEBVw7TXOPtn/I8V7r8ku/jMcOXP6mNGmmeeab9HrqqvpWhfb01+rpa+p+2mhtOH1tyjUsym/t01Hx/U3sdjcuwY1q9NV/zzmYsXHOmlbt91NIWjO048WLN9QSJhisYAAAAALxhgAEAAADAGwYYAAAAALzhMxjwr1c3jb/v81Jk18D5fTV3uXW25izBX1Jq1NDc9bU5mm9uODly3Ig1nTV/c9qemt3USYV6vNQmjSPbC59uqPmNUA36zQNtylsnfAajOFJr19a8/LRdNY/bxVa5zd5Jdf5FM8/QzOcu/Np8UvzPqXV/bYjm9m9Tq5+ItuzVVvPMI5+Oe8zNt5wf2W74On1ZEipfuTTu7VfdeqnmuosL9txvPX695qfXt9Hc+bJfNTOtc+LhCgYAAAAAbxhgAAAAAPCGEil4kVKrlubez03I87g/3uiiufGqvFfuTGabDrcSs3ub2lS0p849InLcluOssCxr7R+Feoy0Vi01b34h+jbwXudnNV94npWFpP2Qd7+ikJpZWdqPtzwW2hH/fz5zd+yIbGc+1VRzZVngtWnJbvv2+L8Wqy93cW9H2Urt1F5z1WHxpztfmrVVc60/M+Ieg+LLOshKdUelW7nn4iwr96z7Sv5lUesG9ols/9TrUc27fmwlVumZPxepnSgdXMEAAAAA4A0DDAAAAADeUCIFLxa/aiU3NzQcozn904six3V+2ma8YdaH+KqusRWDtwSWO9RcFTlu1B17a27xld1e8/MpmrM3b9YcXqF7zbNVNH+76/9FzrvfdVdrrvsVM6z4klK9uuYZlzYo1H0HPDU8st3iHcoLS8rv+70Q2qqkqcVn9vpj1rvEsX4PKzcc0/mJuMcc+qa9p7X/jve0krLsCis/q51iK6b3emOw5g4yLt/zBKdFf9fN3GF/LbT7L385lBdcwQAAAADgDQMMAAAAAN5QIoUiW/1Pm+nhy54jNK8NrQ+2y+3LI/fJzMws8XaVdynfTNR80YJ+ml9v+2XkuDtPCM3qdILFeZnbNN+86FjNac465oM2Vha1+4+DIudt/ToLt5WEGSNsdrAZxz++kyNz3LHSZmRpMyo6Ow6vIiSzbcfspfnZ+x8O7bEy0AmhyaI6/MveKymw8WfVRdHZnqb0fkrz1ct6au4wPP+yqGVD9tH8a4/o++M/7rpSc+PPKQ8tL7iCAQAAAMAbBhgAAAAAvKFECoUSLov65Ta7HLoo0y48n3+KzRghC34vlXZVVGv3XaO5X5Xo5ejVZ1gJzfY6tghYq/7zNI/u9JHmVGf/T8gK7PjaH9SMPmg2c+T44nruqvm+w97SnJLH/3ZuW9lD8699bTG9rJXzvbcN8f13kz3vXSov1ZzRzBYTTZtWqk
1CLhsvXK85PTQ7XtiV/7IF2WrvyL9EBwXj0uzPxkan/hnZ995m+10y9bT2oT1z4p5r6/FW6jZmuJVZX7Nsv8hxTV+drDlbUF5wBQMAAACANwwwAAAAAHhDiRTy5Xrtpvnjm+0yZlZQTXP/220howbjWMioJAQZGZHt+i/a85xS1RY1ml3PSqeyO1np2g3Le2huV2Wl5qOHfRM5708fN9OctWp10RsMOfRlK804voYtHpXXZf4PXtxfc9OVzJZSFvpWt7KPURu7aK7y8yzNFBGWvpRaVqJ2cafv4h5z32orSaz/42LNzLrmz7LBVtY0oXN0tqduP56jufXMyRJP6i7pmh9/9DHNO0Lze/02vEf0Pht/LVJbUba4ggEAAADAGwYYAAAAALyhRApxZR+4h+ZLnx+luV6KleJ0ftNmi+r0mi0OxywPpW/hUCuLmnzBSM1Pr7OZPCYf3shyI7tMPeyDtyPnemfgQZqbPkyZTmGsGLxPZPucOg+EtuLPdhOetajpozzfZe1fSw7XfHHjrzV/XHt3O2jDhtJsEkRkyflWqntu7TGas0O/cT675UDN1Rf8VCrtSjb1jrXSs6+2Vonsa3Ov9UVeCxqe/96nmncNzQDW6w6b9avR15RZVwRcwQAAAADgDQMMAAAAAN4wwAAAAADgDZ/BQHw325Sa/arbqql7jh+oucMwm4KTz12UvtT0DprHXGq1/tN3pGp+Z5jVk1de+bPmlC1bNC/c0SBy3u21vTazwktt0ljzSRf/L7KvTkr8z11MzLD/7bx62pGhPVO9tg2FN25JG83PtbL+C2rXKIvmIKbe0Uvi3j544cGaq7/L5y5KQnjF7Y+72tSyh/w2MHJcvQn2/pWyu03x3Pc/4zX3r7FOc6evLtCc/sIEzXl9fgPlC1cwAAAAAHjDAAMAAACAN5RIQUREll4VnV7z5y6Pan5ozS6aW5w2RzOXMcvWjMsaaq4Tmj5435cv09z20/jT/W09wPr0zFrfRvbd56uBSWL6HVZSM7rBJ5F9eZUOnv7NhZrTJ07I4yggea26sE9k+8Mu4Smfq2ma/kg3zbVknMC/hcfYb/tqzsoGt3/aKHLctmPrab7tkec171/V1lJ/dG1HzV2uXqQ5a8d2P41FwuAKBgAAAABvGGAAAAAA8IYSqSSWcXQvzW9dMSKy7+utdqnz60E2g0SQEX+Wm9TaNvVQdkZG6PiMeIejCMKzFYmIzDrpKc0vbGilucP91kdZ4Ts4p3HBKVa8MzYjNXyUtH9hgeZMQTypDepr3n+3PzRXctHnckeojrDzu4M1p1/GbDfFldLDyvwyGlfXXOXbKZqzt20r9Hl7N7ef/wkZVrbhNmwu9LlQOC7N/iRpcea8yL6GqVYWdeUSK+mt/favminbLV19z42W4N7W2N7Xqjjry67fDdLc7rTfQ/dYUVJNQwLgCgYAAAAAbxhgAAAAAPCGEqkktuBku6DcsVKVyL5+o8/V3GlC/Jk5Vl1ks3yccOnXmhdus/Kqr77rHblPeHE+FE92qCDgwXeO19x2Q/yZo/682fpr5uEjNXd76bLIcW0Xxb8/QhpaidRzrd/QvCOI/s/m/c32Wkh/YZNmSjmKr9PzszQ/1twWkXxvc03Nm7Oj72sFsX81K5FamRV/oUSUjO2H9ND8QcdnIvu2BFau9u2bPTU32/FjibcrGaU1baL53cMeD+2ppOneJtEZ8OZnWh8d9vmlmjtfMlEz733JgysYAAAAALxhgAEAAADAG0qkkkz2fj00Tz38Sc1vb2oSOa7FGLuQueA2m7Hjy3Pv19ws1WbvyM7jwucumT3i3g6/aiy0nFrPynJmXd9F8/un2ExhZ847VnP7e23WHZG8F4dD/r7eWjWyffUnZ2juNJHyQJ9mb7RFvpZmWvlZ/xrhozZJ4VmJ1f3L97abM5lTraS5a/KeVWiP/1nJTacHKYsqaTOvbK+5TZrNR/jqRpux8PbPT4jcp3lozdb0/9qMUpRFJS
euYAAAAADwhgEGAAAAAG8okUo2qbbYWnhRsJNqroocdtLjT0p81eLe+vQ6u5z6wjP9NHd6fUbkuMjCbyiUrJWrI9tHz+iv+aXrHta85mpbdGzfql9qvmzxoZo3D7CZQLI3Rs+Lorvof4Mi2+lDKIsqKVkHL9F8bs+LNG9paTVSq7rZr7jtu24p9GPU+cbe7xouY3a1knZPh7c1T98R/W3R8m3+XClN9brZ3wR7fjhUc/ol4zV3EhYMRd64ggEAAADAGwYYAAAAALzhmmOSqTTFFpE6abaVMr3d8aM877Moc6vmfj9frLn+G1aKUOuTyZqbbLEZPiiJ8ig7+mymDq+tueuHVvIklXZo3OWb8zV3vC/DTrV0egk0MHnMP7lx3Ntbfezi3o6SFUyYqrlaaO2vVqPLoDHwYuCjV0W2m45m5qjSVK+fLWRZbyfHAXnhCgYAAAAAbxhgAAAAAPCGAQYAAAAAb/gMRpLJWr3G8oF2+zHSs0D3byVT4t7O6s+lL5hodefHtIjff+1lkmb6yJ9Wd1k9+HF39dJcTcbHOxxAAdzUzl5LTYXPXADlGVcwAAAAAHjDAAMAAACANwwwAAAAAHjDAAMAAACANwwwAAAAAHjDAAMAAACANwwwAAAAAHjDAAMAAACANy4IgrJuAwAAAIAKgisYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALxhgAEAAADAGwYYAAAAALz5f8bVRe6QvGRjAAAAAElFTkSuQmCC",
            "text/plain": [
              "<PIL.Image.Image image mode=RGBA size=792x648 at 0x7FA6BB6418E0>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from google.colab.patches import cv2_imshow\n",
        "import cv2\n",
        "img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)\n",
        "cv2_imshow(img)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z8kp_R22QCJQ"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
