{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_AoxDl52uYVd"
      },
      "source": [
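        "# Transformation Coding [Double-Bump World]\n",
        "\n",
        "Trains an MLP encoder on 1-D signals of length `signal_len` that contain two bumps (a rectangle and a triangle), each placed at a random cyclic shift. The *symmetry loss* asks the pairwise distances between codes of a batch to stay the same when the whole batch is shifted by a common offset of each bump, while a *barrier loss* penalizes codes that come too close to each other. The learned codes are projected to 2-D/3-D and plotted, colored by the triangle's shift."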
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tdtcedbh2ot2"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qsn5NtvYldkJ"
      },
      "source": [
        "import os\n",
        "import io\n",
        "import cv2\n",
        "import glob\n",
        "import time\n",
        "import math\n",
        "import json\n",
        "import imageio\n",
        "import colorsys\n",
        "import multiprocessing as mp\n",
        "from os import path\n",
        "from datetime import datetime\n",
        "from functools import partial\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm import tqdm\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch import autograd\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "%matplotlib inline"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a12JiTew2q9q"
      },
      "source": [
        "## Prepare Experiment Folders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x_FH7sPhIQVT"
      },
      "source": [
        "# Filesystem-safe timestamp (no spaces, colons or dots) so the folder name works on every OS and in shell commands.\n",
        "timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')\n",
        "EXPERIMENT_DIR = 'doublebumpworld_%s' % timestamp\n",
        "os.makedirs(EXPERIMENT_DIR)\n",
        "%cd $EXPERIMENT_DIR\n",
        "for subdir in ['gifs', 'plots', 'videos', 'saved_models', 'figs_final', 'figs_progress']:\n",
        "    os.makedirs(subdir, exist_ok=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoIkwOnZKsP9"
      },
      "source": [
        "## Logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eCx_rhuiKr9v"
      },
      "source": [
        "# Close the log handles left over from a previous run of this cell (they do not exist on the first run).\n",
        "try:\n",
        "    LOG_STR.close()\n",
        "    LOG_FILE.close()\n",
        "except NameError:\n",
        "    pass\n",
        "\n",
        "timestamp = str(datetime.now())\n",
        "LOG_STR = io.StringIO()\n",
        "LOG_FILE = open('log.txt', 'w')\n",
        "\n",
        "# Capture the builtin `print` only once, so re-running this cell does not wrap the wrapper.\n",
        "try:\n",
        "    old_print\n",
        "except NameError:\n",
        "    old_print = print\n",
        "\n",
        "def print(*args, **kwargs):\n",
        "    '''Print as requested (stdout by default), then mirror the same output to `LOG_STR` and to log.txt, flushing each time.'''\n",
        "    kwargs['flush'] = True\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_STR\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_FILE\n",
        "    old_print(*args, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EVOc_tr3K4CU"
      },
      "source": [
        "## Arguments"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BEjSsrvSKiwM"
      },
      "source": [
        "class Args():\n",
        "    def __init__(self):\n",
        "        self.signal_len = 64\n",
        "        self.bump_len = 16\n",
        "        self.learning_rate = 1e-3\n",
        "        self.weight_decay = 1e-7\n",
        "        self.batch_size = 64\n",
        "        self.num_epochs = 100\n",
        "        self.steps_per_epoch = 100\n",
        "        self.num_actions = 16\n",
        "        self.barrier_type = 'log' # 'inv' or 'log' or 'hinge'\n",
        "        self.barrier_coef = 1e-7\n",
        "        self.cosine_sim = False\n",
        "        self.conformal_map = True\n",
        "        self.mlp_hidden_dims = 128\n",
        "        self.code_size = 4\n",
        "        self.proj_size = 3\n",
        "        self.decompositions = 1\n",
        "        self.decomposed_action = False\n",
        "\n",
        "\n",
        "args = Args()\n",
        "assert args.code_size % args.decompositions == 0 # TODO: the size could actually be different for different groups\n",
        "subcode_size = args.code_size // args.decompositions\n",
        "print(json.dumps(vars(args), sort_keys=True, indent=4))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-RN2xmEDwHmp"
      },
      "source": [
        "## Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZYehiaT5mRoN"
      },
      "source": [
        "def triangle(n, w):\n",
        "    x1 = np.linspace(start=0, stop=1, num=w // 2, endpoint=False)\n",
        "    x2 = np.linspace(start=1, stop=0, num=w - w // 2, endpoint=False)    \n",
        "    x = np.concatenate([x1, x2, np.zeros(n - w)])\n",
        "    assert len(x) == n\n",
        "    return x\n",
        "\n",
        "\n",
        "def rectangle(n, w):\n",
        "    x = np.ones(w)\n",
        "    x = np.concatenate([x, np.zeros(n - w)])\n",
        "    assert len(x) == n\n",
        "    return x\n",
        "\n",
        "\n",
        "def half_circle(n, w):\n",
        "    x = np.sin(np.linspace(start=0, stop=np.pi, num=w, endpoint=False))\n",
        "    x = np.concatenate([x, np.zeros(n - w)])\n",
        "    assert len(x) == n\n",
        "    return x\n",
        "\n",
        "\n",
        "def cyclic_shift(x, d):\n",
        "    x = np.concatenate([x[d:], x[:d]])\n",
        "    return x\n",
        "\n",
        "\n",
        "def random_cyclic_shift(x):\n",
        "    '''Cyclically shift `x` (via `cyclic_shift`) by an offset drawn uniformly from [0, len(x)) with the global NumPy RNG.'''\n",
        "    offset = np.random.randint(len(x))\n",
        "    return cyclic_shift(x, offset)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tue5khQzx0sD"
      },
      "source": [
        "def sample(batch_size, num_actions, action_type=None):\n",
        "    '''\n",
        "    Draw `num_actions` transformed copies of one random batch of double-bump signals.\n",
        "\n",
        "    Each signal is 0.5 * rectangle + 0.5 * triangle, each bump at its own random cyclic shift.\n",
        "    Copy j additionally shifts every rectangle by rect_shifts[j] and every triangle by tri_shifts[j]\n",
        "    (the same offsets for the whole batch). With action_type=0 only the rectangle moves,\n",
        "    with action_type=1 only the triangle moves, and with None both move.\n",
        "\n",
        "    Returns a list of `num_actions` arrays of shape (batch_size, args.signal_len).\n",
        "    '''\n",
        "    assert num_actions >= 2\n",
        "    assert action_type is None or action_type in [0, 1]\n",
        "    n = args.signal_len\n",
        "    move_rect = action_type is None or action_type == 0\n",
        "    move_tri = action_type is None or action_type == 1\n",
        "    rect_shifts = [np.random.randint(n) if move_rect else 0 for _ in range(num_actions)]\n",
        "    tri_shifts = [np.random.randint(n) if move_tri else 0 for _ in range(num_actions)]\n",
        "\n",
        "    views = [np.empty((batch_size, n)) for _ in range(num_actions)]\n",
        "\n",
        "    for b in range(batch_size):\n",
        "        rect = 0.5 * random_cyclic_shift(rectangle(n, args.bump_len))\n",
        "        tri = 0.5 * random_cyclic_shift(triangle(n, args.bump_len))\n",
        "        # NOTE: if this part changes, visualization might also need to change!\n",
        "        for view, d_rect, d_tri in zip(views, rect_shifts, tri_shifts):\n",
        "            view[b] = cyclic_shift(rect, d_rect) + cyclic_shift(tri, d_tri)\n",
        "\n",
        "    return views"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5dvE1VCgnTvC"
      },
      "source": [
        "## Test Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EUiZMIaw1DB2"
      },
      "source": [
        "x1, x2 = sample(args.batch_size, num_actions=2)\n",
        "print(x1.shape, x1.dtype)\n",
        "print(x2.shape, x2.dtype)\n",
        "# Both transformed copies must be (batch_size, signal_len) batches.\n",
        "assert x1.shape == x2.shape == (args.batch_size, args.signal_len)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e2tXK14umWhN"
      },
      "source": [
        "plt.style.use('default')\n",
        "\n",
        "idx = 2  # batch element to plot\n",
        "fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
        "for ax, x, name in zip(axes, (x1, x2), ('x1', 'x2')):\n",
        "    ax.plot(x[idx])\n",
        "    ax.set_ylim([-0.1, 1.1])\n",
        "    ax.set(title='%s[%d]' % (name, idx), xlabel='position', ylabel='value')\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sJidR4-HkNvT"
      },
      "source": [
        "## Visualization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TjXBdXCnkN1B"
      },
      "source": [
        "# viz_dataset[i, j] = 0.5 * rectangle shifted by i + 0.5 * triangle shifted by j (same composition as in `sample`).\n",
        "viz_dataset = np.empty((args.signal_len, args.signal_len, args.signal_len))\n",
        "# The base bumps and the rectangle shifted by i do not depend on the inner loop, so build them once.\n",
        "viz_rect = 0.5 * rectangle(args.signal_len, args.bump_len)\n",
        "viz_tri = 0.5 * triangle(args.signal_len, args.bump_len)\n",
        "\n",
        "for i in tqdm(range(args.signal_len)):\n",
        "    rect_i = cyclic_shift(viz_rect, i)\n",
        "    for j in range(args.signal_len):\n",
        "        viz_dataset[i, j] = rect_i + cyclic_shift(viz_tri, j)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IeDXpcqMYryq"
      },
      "source": [
        "# Linear map from a sub-code (subcode_size dims) to the plotted space (args.proj_size dims).\n",
        "if subcode_size > args.proj_size:\n",
        "    # Random orthonormal columns taken from the SVD of a Gaussian matrix.\n",
        "    t = np.random.randn(subcode_size, subcode_size)\n",
        "    PROJ_MAT = np.linalg.svd(t)[0][:, :args.proj_size]\n",
        "    # Unit L2 norm per column. Dividing by the column *sum* (random sign, possibly near zero)\n",
        "    # could blow up or mirror the plotted axes.\n",
        "    PROJ_MAT /= np.linalg.norm(PROJ_MAT, axis=0, keepdims=True)\n",
        "else:\n",
        "    # (subcode_size, proj_size) identity; zero-pads when the sub-code has fewer dims than the plot.\n",
        "    PROJ_MAT = np.eye(subcode_size, args.proj_size)\n",
        "\n",
        "with np.printoptions(linewidth=300, formatter={'float': lambda x: '%12f' % x}):\n",
        "    print('PROJ_MAT =\\n', PROJ_MAT)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nJRwJCZYyHvz"
      },
      "source": [
        "def scatter_3d(ax, x, y, z, c, elev=30, azim=-60, **kwargs):\n",
        "    elev_rad = elev / 180 * np.pi\n",
        "    azim_rad = azim / 180 * np.pi\n",
        "    ax.view_init(elev, azim)\n",
        "    idx_sort = np.argsort(np.cos(azim_rad) * x + np.sin(azim_rad) * y + np.sin(elev_rad) * z, kind='heapsort')\n",
        "    ax.scatter(x[idx_sort], y[idx_sort], z[idx_sort], c=c[idx_sort], **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DYxIDwtKkiQx"
      },
      "source": [
        "def visualize(ax, enc, data, s=1, **kwargs):\n",
        "    '''\n",
        "    Scatter-plot the projected codes of a grid of signals, colored by the second grid index.\n",
        "\n",
        "    Args:\n",
        "        ax: matplotlib axis; a 3-D axis when `args.proj_size == 3`.\n",
        "        enc: network mapping a (n, signal_len) float tensor to (n, d) codes, with d == PROJ_MAT.shape[0].\n",
        "        data: array of shape (n_i, n_j, signal_len), e.g. `viz_dataset`.\n",
        "        s: marker size.\n",
        "        **kwargs: forwarded to `scatter_3d` (e.g. `azim`) or to `ax.scatter`.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `args.proj_size` is neither 2 nor 3.\n",
        "    '''\n",
        "    if args.proj_size not in (2, 3):\n",
        "        raise ValueError('proj_size must be 2 or 3, got %r' % (args.proj_size,))\n",
        "    n_i, n_j = data.shape[:2]\n",
        "    device = get_device(enc)\n",
        "\n",
        "    code = np.empty((n_i, n_j, args.proj_size))\n",
        "    with torch.no_grad():\n",
        "        for i in range(n_i):\n",
        "            code[i] = enc(torch.tensor(data[i], dtype=torch.float32, device=device)).cpu().numpy() @ PROJ_MAT\n",
        "\n",
        "    # Hue encodes the second grid index j (for `viz_dataset`, the triangle shift); it is the same for every row i.\n",
        "    row_colors = np.array([colorsys.hls_to_rgb(j / n_j, 0.5, 1) for j in range(n_j)])\n",
        "    colors = np.broadcast_to(row_colors, (n_i, n_j, 3))\n",
        "\n",
        "    code = code.reshape(-1, args.proj_size)\n",
        "    colors = colors.reshape(-1, 3)\n",
        "\n",
        "    if args.proj_size == 3:\n",
        "        scatter_3d(ax, code[:, 0], code[:, 1], code[:, 2], c=colors, s=s, **kwargs)\n",
        "    else:\n",
        "        ax.scatter(code[:, 0], code[:, 1], c=colors, s=s, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I40eSpQk4Msy"
      },
      "source": [
        "## Encoder Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BvTiDiQayKFp"
      },
      "source": [
        "def count_parameters(net):\n",
        "    return sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
        "\n",
        "\n",
        "def get_device(net):\n",
        "    '''\n",
        "    Returns the `torch.device` on which the network resides.\n",
        "    This method only makes sense when all module parameters reside on the **same** device.\n",
        "    '''\n",
        "    return list(net.parameters())[0].device"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8GGdOd71_scf"
      },
      "source": [
        "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# MLP encoder: signal (signal_len) -> three ReLU hidden layers -> code (code_size); the output layer has no bias.\n",
        "enc = nn.Sequential(\n",
        "    nn.Linear(args.signal_len, args.mlp_hidden_dims), nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dims, args.mlp_hidden_dims), nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dims, args.mlp_hidden_dims), nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dims, args.code_size, bias=False),\n",
        ").to(DEVICE)\n",
        "\n",
        "print(enc)\n",
        "print('Device:', DEVICE)\n",
        "print(f'{count_parameters(enc)} parameters')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tLiqP3cGzI3F"
      },
      "source": [
        "class SubEncoder(nn.Module):\n",
        "    def __init__(self, enc, le, ri):\n",
        "        super().__init__()\n",
        "        self.enc = enc\n",
        "        self.le = le\n",
        "        self.ri = ri\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.enc(x)[:, self.le: self.ri]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UiRpYgiHiPOJ"
      },
      "source": [
        "## Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yh2gKlJhfTaU"
      },
      "source": [
        "opt = optim.Adam(enc.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)\n",
        "# Halves the learning rate every 1000 `scheduler.step()` calls. The training loop steps once per epoch,\n",
        "# so with fewer than 1000 epochs this never fires (it does nothing for now).\n",
        "# `verbose=True` is not passed: it is deprecated in recent PyTorch; use `scheduler.get_last_lr()` to read the rate.\n",
        "scheduler = optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vy_YGvhWCYTO"
      },
      "source": [
        "## Loss Function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TgRR3TW3m0Qg"
      },
      "source": [
        "def cdist_mean(x1, x2, p=2.0, *args, **kwargs):\n",
        "    dim = x1.shape[1]\n",
        "    return torch.cdist(x1.contiguous(), x2.contiguous(), p, *args, **kwargs) * (dim ** (-1 / p))\n",
        "\n",
        "\n",
        "def loss_fn(enc, x_list, barrier_type, cosine_sim=False, conformal_map=False, decompositions=1, action_type=None):\n",
        "    '''\n",
        "    Symmetry (equivariance) loss and barrier loss for a list of transformed batches.\n",
        "\n",
        "    Args:\n",
        "        enc: encoder; `enc(x)` returns codes of shape (batch, code_size).\n",
        "        x_list: list of `num_actions >= 2` arrays, e.g. the output of `sample`\n",
        "            (presumably element j is the same batch under the j-th transformation).\n",
        "        barrier_type: 'log', 'inv' or 'hinge'.\n",
        "        cosine_sim: use cosine dissimilarity instead of `cdist_mean` between codes.\n",
        "        conformal_map: compare pairwise code differences with cosine dissimilarity.\n",
        "        decompositions: number of equally sized sub-codes the code is split into.\n",
        "        action_type: if not None, only sub-code `action_type` must preserve distances;\n",
        "            every other sub-code must stay unchanged across the batches.\n",
        "\n",
        "    Returns:\n",
        "        (loss_equiv, loss_barrier), scalar tensors on the encoder's device.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `barrier_type` is unknown (checked before any computation).\n",
        "    '''\n",
        "    if barrier_type not in ('log', 'inv', 'hinge'):\n",
        "        raise ValueError('Unknown `barrier_type`: %r' % (barrier_type,))\n",
        "\n",
        "    device = get_device(enc)\n",
        "    num_actions = len(x_list)\n",
        "    assert num_actions >= 2\n",
        "    x_list = [torch.tensor(x, dtype=torch.float32, device=device) for x in x_list]\n",
        "    z_list = [enc(x) for x in x_list]\n",
        "    code_size = z_list[0].shape[1]\n",
        "    num_pairs = num_actions * (num_actions - 1) / 2\n",
        "\n",
        "    #-- symmetry loss --#\n",
        "    subcode_size = code_size // decompositions\n",
        "    loss_equiv = 0\n",
        "    for k in range(decompositions):\n",
        "        h_list = [z[:, k * subcode_size: (k + 1) * subcode_size] for z in z_list]\n",
        "\n",
        "        if conformal_map:\n",
        "            # Replace each batch of sub-codes by all its pairwise differences h[a] - h[b].\n",
        "            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]\n",
        "\n",
        "        if cosine_sim or conformal_map:\n",
        "            D_list = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]\n",
        "        else:\n",
        "            D_list = [cdist_mean(h, h, p=2) for h in h_list]\n",
        "\n",
        "        preserve_distances = action_type is None or k == action_type\n",
        "        pair_losses = []\n",
        "        for i in range(num_actions):\n",
        "            for j in range(i + 1, num_actions):\n",
        "                if preserve_distances:\n",
        "                    pair_losses.append(torch.mean((D_list[i] - D_list[j]) ** 2))\n",
        "                else:\n",
        "                    pair_losses.append(torch.mean((h_list[i] - h_list[j]) ** 2))\n",
        "        # Stacking keeps the pair losses on the encoder's device (no per-pair copy into a CPU matrix).\n",
        "        cur_loss_equiv = torch.stack(pair_losses).sum() / num_pairs\n",
        "        loss_equiv += cur_loss_equiv / decompositions\n",
        "\n",
        "    #-- barrier loss --#\n",
        "    z_all = torch.cat(z_list, dim=0)\n",
        "    D_all = cdist_mean(z_all, z_all, p=2)\n",
        "    # Off-diagonal mask built on the same device as the distances.\n",
        "    off_diag = ~torch.eye(D_all.shape[0], dtype=torch.bool, device=D_all.device)\n",
        "    D_off = D_all[off_diag]\n",
        "    if barrier_type == 'log':\n",
        "        loss_barrier = torch.mean(-torch.log(D_off + 1e-9))\n",
        "    elif barrier_type == 'inv':\n",
        "        loss_barrier = torch.mean(1.0 / (D_off + 1e-9))\n",
        "    else:  # 'hinge': penalize distances below 5.0\n",
        "        loss_barrier = torch.mean(torch.clamp(5.0 - D_off, min=0.0))\n",
        "\n",
        "    return loss_equiv, loss_barrier"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PKZOSNYX_f-x"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-prwPRYR_kFt"
      },
      "source": [
        "def check_weights_nan(parameters):\n",
        "    for p in parameters:\n",
        "        if torch.isnan(p).any():\n",
        "            return True\n",
        "    return False\n",
        "\n",
        "\n",
        "def get_weights_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p, norm_type) for p in parameters]), norm_type).item()\n",
        "\n",
        "\n",
        "def get_grads_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p.grad, norm_type) for p in parameters]), norm_type).item()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VJcYn9Vn_kK7"
      },
      "source": [
        "def _plot_progress_codes(epoch):\n",
        "    '''Plot every sub-code of `enc` on `viz_dataset` and save the figure as figs_progress/codes_epoch_NNN.png.'''\n",
        "    fig = plt.figure(figsize=(6 * args.decompositions, 6))\n",
        "    enc.eval()\n",
        "    for i in range(args.decompositions):\n",
        "        # 3-argument form: the packed 3-digit code (e.g. 121) breaks once there are 10 or more sub-plots.\n",
        "        ax = fig.add_subplot(1, args.decompositions, i + 1, projection='3d' if args.proj_size == 3 else None)\n",
        "        visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset)\n",
        "    fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % epoch), bbox_inches='tight', dpi=200)\n",
        "    plt.show()\n",
        "    plt.close(fig)\n",
        "\n",
        "\n",
        "torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % 0))\n",
        "_plot_progress_codes(0)\n",
        "\n",
        "avg_loss_list = []\n",
        "avg_loss_equiv_list = []\n",
        "avg_loss_barrier_list = []\n",
        "for t in range(args.num_epochs):\n",
        "    enc.train()\n",
        "    loss_list = []\n",
        "    loss_equiv_list = []\n",
        "    loss_barrier_list = []\n",
        "\n",
        "    time_start = time.time()\n",
        "    progress = tqdm(range(args.steps_per_epoch), desc='Loss: None | Loss Equiv: None | Loss barrier: None | L2 Weights: %12g | L2 Grads: 0' % (\n",
        "        get_weights_norm(enc.parameters(), norm_type=2.0)\n",
        "    ), total=args.steps_per_epoch, position=0, leave=True)\n",
        "    for k in progress:\n",
        "        if not args.decomposed_action:\n",
        "            # Every action moves both bumps.\n",
        "            x_list = sample(args.batch_size, args.num_actions)\n",
        "            loss_equiv, loss_barrier = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map, args.decompositions)\n",
        "        else:\n",
        "            # Average over two half-size sets of actions, each moving only one of the two bumps.\n",
        "            loss_equiv = 0\n",
        "            loss_barrier = 0\n",
        "            for action_type in [0, 1]:\n",
        "                x_list = sample(args.batch_size, args.num_actions // 2, action_type)\n",
        "                loss_equiv_cur, loss_barrier_cur = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map, args.decompositions, action_type)\n",
        "                loss_equiv += loss_equiv_cur\n",
        "                loss_barrier += loss_barrier_cur\n",
        "            loss_equiv = loss_equiv / 2\n",
        "            loss_barrier = loss_barrier / 2\n",
        "        loss = loss_equiv + args.barrier_coef * loss_barrier\n",
        "        opt.zero_grad()\n",
        "        loss.backward()\n",
        "        opt.step()\n",
        "        # I multiply the losses by 1e4 to get nicer numbers\n",
        "        loss_list.append(1e4 * loss.item())\n",
        "        loss_equiv_list.append(1e4 * loss_equiv.item())\n",
        "        loss_barrier_list.append(1e4 * args.barrier_coef * loss_barrier.item())\n",
        "        progress.set_description('Loss: %12g | Loss Equiv: %12g | Loss barrier: %12g | L2 Weights: %12g | L2 Grads: %12g' % (\n",
        "            loss_list[-1],\n",
        "            loss_equiv_list[-1],\n",
        "            loss_barrier_list[-1],\n",
        "            get_weights_norm(enc.parameters(), norm_type=2.0),\n",
        "            get_grads_norm(enc.parameters(), norm_type=2.0)\n",
        "        ))\n",
        "    time_end = time.time()\n",
        "\n",
        "    # NaN weights never recover; stop instead of saving and plotting NaN models for the remaining epochs.\n",
        "    if check_weights_nan(enc.parameters()):\n",
        "        print('\\nEpoch %3d | NaN detected in encoder weights, stopping training' % (t + 1))\n",
        "        break\n",
        "\n",
        "    torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % (t + 1)))\n",
        "\n",
        "    avg_loss = np.mean(loss_list)\n",
        "    avg_loss_equiv = np.mean(loss_equiv_list)\n",
        "    avg_loss_barrier = np.mean(loss_barrier_list)\n",
        "    avg_loss_list.append(avg_loss)\n",
        "    avg_loss_equiv_list.append(avg_loss_equiv)\n",
        "    avg_loss_barrier_list.append(avg_loss_barrier)\n",
        "    print('\\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | Loss barrier: %12g | Time: %6.1f sec' % (\n",
        "        t + 1, avg_loss, avg_loss_equiv, avg_loss_barrier, time_end - time_start))\n",
        "\n",
        "    scheduler.step()\n",
        "\n",
        "    _plot_progress_codes(t + 1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RDumfR4N_kPg"
      },
      "source": [
        "save_path = os.path.join('saved_models', 'model_final.tar')\n",
        "print('saving model to {}'.format(save_path))\n",
        "torch.save(enc.state_dict(), save_path)\n",
        "# To reload later: enc.load_state_dict(torch.load(save_path, map_location=DEVICE))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W6jvbi3p_kVo"
      },
      "source": [
        "## Error Curve"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d_WcS1xg_kaf"
      },
      "source": [
        "fig, ax = plt.subplots(figsize=(5, 5))\n",
        "# Epochs are numbered from 1 in the training log, so the x axis starts at 1 too.\n",
        "ax.plot(range(1, len(avg_loss_list) + 1), avg_loss_list)\n",
        "ax.set(title='Training loss (x1e4)', xlabel='Epoch', ylabel='Loss')\n",
        "fig.savefig(path.join('plots', 'error_curve.png'), bbox_inches='tight', dpi=300)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oPJFA0mLNr48"
      },
      "source": [
        "## Visualize Progress"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3QY7b5HnNZwO"
      },
      "source": [
        "image_list = []\n",
        "files = sorted(glob.glob('figs_progress/*.png'))\n",
        "for file in files:\n",
        "    image_list.append(cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB))\n",
        "gif_path = path.join('gifs', 'codes_progress.gif')\n",
        "try:\n",
        "    imageio.mimsave(gif_path, image_list, fps=16)\n",
        "except TypeError:\n",
        "    # Newer imageio (Pillow plugin) rejects `fps`; it takes the per-frame `duration` in milliseconds.\n",
        "    imageio.mimsave(gif_path, image_list, duration=1000 / 16)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C1pm-Nt5O5A-"
      },
      "source": [
        "# `-v error` rather than `quiet`, so a failed encode still says why.\n",
        "! ffmpeg -v error -r 16 -i ./figs_progress/codes_epoch_%03d.png -y videos/codes_progress.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2pxdU9EZ_kgN"
      },
      "source": [
        "## Visualize Final Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g-myzU_3N8B2"
      },
      "source": [
        "fig = plt.figure(figsize=(8 * args.decompositions, 8))\n",
        "enc.eval()\n",
        "for i in range(args.decompositions):\n",
        "    # 3-argument form: the packed 3-digit code (e.g. 121) breaks once there are 10 or more sub-plots.\n",
        "    ax = fig.add_subplot(1, args.decompositions, i + 1, projection='3d' if args.proj_size == 3 else None)\n",
        "    visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset, s=5)\n",
        "fig.savefig(path.join('plots', 'codes_final.png'), bbox_inches='tight', dpi=600)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sgHFyL32lXMK"
      },
      "source": [
        "! rm -rf figs_final\n",
        "! mkdir figs_final\n",
        "\n",
        "enc.eval()\n",
        "\n",
        "# 72 frames, 5 degrees of azimuth apart: one full turn around the 3-D code plot.\n",
        "if args.proj_size == 3:\n",
        "    for t in tqdm(range(72)):\n",
        "        fig = plt.figure(figsize=(6 * args.decompositions, 6))\n",
        "        for i in range(args.decompositions):\n",
        "            # 3-argument form: the packed 3-digit code breaks once there are 10 or more sub-plots.\n",
        "            ax = fig.add_subplot(1, args.decompositions, i + 1, projection='3d')\n",
        "            ax.axis('off')\n",
        "            visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset, azim=t * 5)\n",
        "        fig.savefig(path.join('figs_final', 'final_codes_angle_%03d.png' % t), bbox_inches='tight', dpi=200)\n",
        "        plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iemaX8DzmbmF"
      },
      "source": [
        "if args.proj_size == 3:\n",
        "    image_list = []\n",
        "    files = sorted(glob.glob('figs_final/*.png'))\n",
        "    for file in files:\n",
        "        image_list.append(cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB))\n",
        "    gif_path = path.join('gifs', 'codes_final.gif')\n",
        "    try:\n",
        "        imageio.mimsave(gif_path, image_list, fps=16)\n",
        "    except TypeError:\n",
        "        # Newer imageio (Pillow plugin) rejects `fps`; it takes the per-frame `duration` in milliseconds.\n",
        "        imageio.mimsave(gif_path, image_list, duration=1000 / 16)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qtVbIAb5PIkC"
      },
      "source": [
        "# `-v error` rather than `quiet`, so a failed encode still says why.\n",
        "! ffmpeg -v error -r 16 -i ./figs_final/final_codes_angle_%03d.png -y videos/final_codes.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9lVoqE2oFyh5"
      },
      "source": [
        "## Zip Results"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y79gBkAARNvm"
      },
      "source": [
        "# -q instead of -v: listing every archived frame would flood the cell output.\n",
        "!zip -r9q $(basename \"$PWD\").zip figs_progress/ figs_final/ saved_models/ log.txt plots/ gifs/ videos/"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}