{"cells":[{"cell_type":"markdown","source":["The code is composed on Google Colab.\n","\n","This Python notebook is devoted to task 1: In-class transfer on MNIST dataset."],"metadata":{"id":"YTyZGtDbJlbd"}},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1MibEOp7kTRn","outputId":"6442969d-607a-4e97-9669-f57171d2576c","executionInfo":{"status":"ok","timestamp":1758776900827,"user_tz":240,"elapsed":13709,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","source":["cd/content/drive/MyDrive/OTHJ_FashionMNIST_MNIST"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"j4LJQeUS8LYL","executionInfo":{"status":"ok","timestamp":1758776901702,"user_tz":240,"elapsed":865,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}},"outputId":"60f0077b-4a4d-4124-bf6c-4dfa21354904"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["/content/drive/MyDrive/OTHJ_FashionMNIST_MNIST\n"]}]},{"cell_type":"code","execution_count":3,"metadata":{"cellView":"form","id":"cEiUihg03b8r","executionInfo":{"status":"ok","timestamp":1758776920783,"user_tz":240,"elapsed":18220,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["# @title imports\n","import os, sys\n","import numpy as np\n","import tensorflow as tf\n","import torch\n","import torchvision\n","from torchvision import transforms, datasets\n","import torchvision.models as models\n","from torchvision.transforms import Compose, Resize, Normalize, ToTensor\n","from torch.utils.data.sampler import SubsetRandomSampler\n","import matplotlib.pyplot as plt\n","import pickle\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","import pandas as pd\n","import matplotlib.pyplot as 
# @title setup basic parameter

# Working directory the notebook was started from, and the torch device
# (GPU when one is available, CPU otherwise).
dir = os.getcwd()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Latent dimension of the VAE and the loss weights used for this experiment.
Dim = 10
beta_recon = 100.
beta_KL = 0.1

# Output locations; the directory names encode the latent dimension.
model_and_data_savepath = f'model_data_d{Dim}_beta100_01'
model_name = os.path.join(model_and_data_savepath, f'best_model_dim_{Dim}_mnist.h5')
example_savepath = f'example_d{Dim}_beta100_01'
# @title VAE Loss

import keras
from keras import layers
from keras import backend as K


class VAELossLayer(layers.Layer):
    """
    Custom Keras layer that calculates the loss (reconstruction + KL divergence)
    of the Variational AutoEncoder (VAE) and registers it with ``add_loss``.

    The layer is a pass-through: ``call`` returns the original input unchanged,
    its only purpose is to attach the loss to the model.

    Args:
        beta: Weight multiplying the per-sample reconstruction loss. The KL
            term is added unweighted. Defaults to 100.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, beta=100, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta

    def get_config(self):
        """
        Return the layer configuration including ``beta``, so the constructor
        argument is preserved when a model containing this layer is serialized.
        """
        config = super().get_config()
        config.update({"beta": self.beta})
        return config

    def compute_output_shape(self, input_shape):
        # The layer returns its first input (the original image batch) untouched.
        return input_shape[0]  # or return ()

    def calculate_loss(self, original_input, reconstructed_output, mu, sigma):
        """
        Calculates VAE loss, which is the sum of the reconstruction loss and KL-divergence loss.

        Args:
            original_input: Batch of input images.
            reconstructed_output: Batch of decoder reconstructions with the same
                number of elements per sample as ``original_input``.
            mu: Latent means.
            sigma: Latent log-variances (the formula below uses ``exp(sigma)``
                as the variance).

        Returns:
            Scalar tensor: the batch mean of
            ``beta * sum((x - x_hat)^2) + kl``, where ``kl`` is
            ``-mean(1 + sigma - mu^2 - exp(sigma))`` over the latent axis.
            Note the KL term is averaged over latent dimensions and has no 0.5
            factor, i.e. it is the textbook KL scaled by ``2 / latent_dim``.
        """
        # Flatten each sample to a vector with a plain reshape; this avoids
        # instantiating new Keras layers every time the loss is evaluated and
        # lets inputs such as (B, 28, 28) and (B, 28, 28, 1) be compared.
        batch = tf.shape(original_input)[0]
        flat_input = tf.reshape(original_input, (batch, -1))
        flat_reconstruction = tf.reshape(reconstructed_output, (batch, -1))

        # Per-sample sum of squared errors, weighted by beta.
        reconstruction_loss = self.beta * tf.reduce_sum(
            tf.square(flat_input - flat_reconstruction), axis=[1]
        )

        # Per-sample KL term against a standard normal prior.
        kl_loss = - tf.reduce_mean(1 + sigma - tf.square(mu) - tf.exp(sigma), axis=-1)

        return tf.reduce_mean(reconstruction_loss + kl_loss)

    def call(self, inputs):
        """
        Computes the loss and adds it to the layer's losses.

        Args:
            inputs: Sequence ``(original_input, reconstructed_output, mu, sigma)``.

        Returns:
            ``original_input``, unchanged.
        """
        original_input, reconstructed_output, mu, sigma = inputs

        loss = self.calculate_loss(original_input, reconstructed_output, mu, sigma)
        self.add_loss(loss)
        return original_input
         │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m,    │    \u001b[38;5;34m147,584\u001b[0m │ conv2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]      │\n","│                     │ \u001b[38;5;34m128\u001b[0m)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m73,792\u001b[0m │ conv2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_3 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_4 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_5 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, 
\u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_6 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ flatten (\u001b[38;5;33mFlatten\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)     │          \u001b[38;5;34m0\u001b[0m │ conv2d_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense (\u001b[38;5;33mDense\u001b[0m)       │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)        │    \u001b[38;5;34m802,880\u001b[0m │ flatten[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]     │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (\u001b[38;5;33mDense\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)        │        \u001b[38;5;34m650\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)        │        \u001b[38;5;34m650\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n","│ (\u001b[38;5;33mDense\u001b[0m)             │                   │            │                   
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (\u001b[38;5;33mLambda\u001b[0m)          │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)        │          \u001b[38;5;34m0\u001b[0m │ latent_mu[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],  │\n","│                     │                   │            │ latent_sigma[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: bold\"> Layer (type)        </span>┃<span style=\"font-weight: bold\"> Output Shape      </span>┃<span style=\"font-weight: bold\">    Param # </span>┃<span style=\"font-weight: bold\"> Connected to      </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n","│ encoder_input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                 │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span 
style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │      <span style=\"color: #00af00; text-decoration-color: #00af00\">3,328</span> │ encoder_input_la… │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │    <span style=\"color: #00af00; text-decoration-color: #00af00\">147,584</span> │ conv2d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]      │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">73,792</span> │ conv2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_5 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: 
#00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_6 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ flatten (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)     │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_6[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)       │ (<span 
style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)        │    <span style=\"color: #00af00; text-decoration-color: #00af00\">802,880</span> │ flatten[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]     │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)        │        <span style=\"color: #00af00; text-decoration-color: #00af00\">650</span> │ dense[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)        │        <span style=\"color: #00af00; text-decoration-color: #00af00\">650</span> │ dense[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)             │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Lambda</span>)          │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)        │          <span 
style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ latent_mu[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],  │\n","│                     │                   │            │ latent_sigma[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,176,596\u001b[0m (4.49 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,176,596</span> (4.49 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,176,596\u001b[0m (4.49 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,176,596</span> (4.49 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 
# @title Encoder
latent_space_dim = Dim #  50

input_dimensions = (28, 28, 1)

# Input Layer
# Defines the shape of the input data for the encoder.
encoder_input_layer = Input(shape=input_dimensions, name='encoder_input_layer')

# Convolution Layers
# Applies convolution operations to extract features from the input image.
# The single strided convolution halves the spatial resolution (28x28 -> 14x14).
encoder_layer = Conv2D(128, 5, padding='same', activation='relu')(encoder_input_layer)
encoder_layer = Conv2D(128, 3, padding='same', activation='relu')(encoder_layer)
encoder_layer = Conv2D(64, 3, padding='same', strides=2, activation='relu')(encoder_layer)
encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)
encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)
encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)
encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)

# Shape of the last convolutional feature map, recorded before flattening.
enc_shape = tf.keras.backend.int_shape(encoder_layer)

# Flattening Layer
# Converts the 3D output of convolution layers into a 1D tensor for dense layers.
encoder_layer = Flatten()(encoder_layer)

# Dense Layer
# A fully connected layer that combines extracted features and performs further learning.
encoder_layer = Dense(64, activation='relu')(encoder_layer)

# Output Layer for Encoder
# Prepares the encoded data for transition into the latent space, approximating a probability distribution.
mu_layer = Dense(latent_space_dim, name='latent_mu')(encoder_layer)
sigma_layer = Dense(latent_space_dim, name='latent_sigma')(encoder_layer)

# Storing the output shape for use in the decoder.
# Derived from the tensor itself so it cannot drift from the Dense width above.
encoder_output_shape = tf.keras.backend.int_shape(encoder_layer)
print(f"Output shape of encoder: {encoder_output_shape}")


### Implementing the Reparameterization Trick
def sample_z(args):
    """
    Draw a latent sample z = mu + exp(log_var / 2) * epsilon, epsilon ~ N(0, I).

    Args:
    args: Pair ``(mu, log_var)`` of tensors with identical shape.
        mu:      The mean of the Gaussian distribution.
        log_var: The log-variance of the Gaussian distribution; it is halved
                 before exponentiation to obtain the standard deviation.

    Returns:
    A sample from the Gaussian distribution, with the same shape as ``mu``.
    """
    mu, log_var = args

    # Standard-normal noise with the same (dynamic) shape as the mean.
    epsilon = tf.random.normal(shape=tf.shape(mu))

    # Scale and shift the noise by the standard deviation and the mean.
    return mu + tf.exp(log_var / 2) * epsilon

# Creating the 'z' layer using the Lambda layer to apply the reparameterization trick
z = Lambda(sample_z, output_shape=(latent_space_dim,), name='z')([mu_layer, sigma_layer])

# Building the encoder model
encoder_model = Model(encoder_input_layer, [mu_layer, sigma_layer, z], name='encoder_model')

# Display the model summary.
# summary() prints the table itself and returns None, so it is not wrapped in print().
encoder_model.summary()
\"decoder_model\"</span>\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)             │             \u001b[38;5;34m0\u001b[0m │\n","│ (\u001b[38;5;33mInputLayer\u001b[0m)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_1 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)          │       \u001b[38;5;34m137,984\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ reshape (\u001b[38;5;33mReshape\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │             \u001b[38;5;34m0\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_1              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m 
│\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_2              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_3              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_4              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_5              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │        \u001b[38;5;34m73,856\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_6              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │       \u001b[38;5;34m409,728\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               
│                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_7              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m)      │         \u001b[38;5;34m3,201\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)             │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)          │       <span style=\"color: #00af00; text-decoration-color: #00af00\">137,984</span> │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ 
reshape (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Reshape</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose                │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_1              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_2              │ (<span style=\"color: #00d7ff; text-decoration-color: 
#00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_3              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_4              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_5              │ (<span style=\"color: #00d7ff; text-decoration-color: 
#00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │        <span style=\"color: #00af00; text-decoration-color: #00af00\">73,856</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_6              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │       <span style=\"color: #00af00; text-decoration-color: #00af00\">409,728</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_7              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>)      │         <span style=\"color: #00af00; text-decoration-color: #00af00\">3,201</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total 
params: \u001b[0m\u001b[38;5;34m809,409\u001b[0m (3.09 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">809,409</span> (3.09 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m809,409\u001b[0m (3.09 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">809,409</span> (3.09 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n","</pre>\n"]},"metadata":{}}],"source":["# @title Decoder\n","\n","# Decoder Input Layer\n","decoder_input_layer = Input(shape=(latent_space_dim,), name='decoder_input_layer')\n","\n","# Initial Dense Layer\n","# The number of units is derived from the last convolutional layer of the encoder\n","num_units = np.prod(enc_shape[1:]) # 14 * 14 * 64\n","decoder_dense_layer = Dense(num_units, activation='relu')(decoder_input_layer)\n","\n","# Reshaping Layer\n","# The dense layer's output is reshaped to match the last convolutional layer's output shape in the encoder\n","reshape_dims = (14, 14, 64)\n","decoder_layer = Reshape(reshape_dims)(decoder_dense_layer)\n","\n","decoder_layer = Conv2DTranspose(64, 3, 
padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 5, padding='same', activation='relu')(decoder_layer)\n","decoder_out_layer = Conv2DTranspose(1, 5, padding='same', activation='relu')(decoder_layer)\n","\n","# Building the Decoder Model\n","decoder_model = Model(decoder_input_layer, decoder_out_layer, name='decoder_model')\n","\n","decoder_model.summary()\n"]},{"cell_type":"code","execution_count":8,"metadata":{"cellView":"form","colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"0sJfV7oPCWcm","outputId":"53a369d8-95ba-4837-ee86-55d2758a005f","executionInfo":{"status":"ok","timestamp":1758776922929,"user_tz":240,"elapsed":15,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["The type  of 'reconstructed_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'reconstructed_output' is: (None, 28, 28, 1)\n","The type of 'vae_loss_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'vae_loss_output' is: (None, 28, 28, 1)\n"]}],"source":["# @title set VAE model\n","from tensorflow.keras.models import load_model\n","\n","\n","# Reconstructed Output from Decoder\n","reconstructed_output = decoder_model(z)\n","\n","reconstructed_model = Model(encoder_input_layer, reconstructed_output, name='reconstructed_model')\n","\n","print(f\"The type  of 'reconstructed_output' is: 
{type(reconstructed_output)}\")\n","print(f\"The shape of 'reconstructed_output' is: {reconstructed_output.shape}\")\n","\n","\n","# Adding the loss computation layer to the model\n","vae_loss_output = VAELossLayer()([encoder_input_layer, reconstructed_output, mu_layer, sigma_layer])\n","\n","print(f\"The type of 'vae_loss_output' is: {type(vae_loss_output)}\")\n","print(f\"The shape of 'vae_loss_output' is: {vae_loss_output.shape}\")\n","\n","\n","# loaded_model = keras.saving.load_model(\"best_model_dim_15.h5\")\n","final_vae_model = Model(encoder_input_layer, vae_loss_output, name='vae_model')\n","\n"]},{"cell_type":"code","execution_count":11,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"9BbmQa7eCWcn","outputId":"b07ee446-30da-4057-d6c3-1f5adda4d6ac","cellView":"form","executionInfo":{"status":"ok","timestamp":1758776969834,"user_tz":240,"elapsed":1092,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. 
`model.compile_metrics` will be empty until you train or evaluate the model.\n"]}],"source":["# @title load trained encoder and decoder\n","\n","from tensorflow.keras.models import load_model\n","\n","# loaded_model = keras.saving.load_model(\"best_model_dim_15.h5\")\n","trained_model = load_model(model_name, custom_objects={\"sample_z\": sample_z, \"VAELossLayer\": VAELossLayer})\n","\n","encoder_input = trained_model.input\n","\n","\n","# Encoder_mu\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder_input = trained_model.input\n","mu_layer = trained_model.get_layer('latent_mu')\n","mu_output = mu_layer.output\n","# Recreate the encoder model\n","encoder_mu = Model(inputs=encoder_input, outputs=mu_output)\n","\n","\n","# Encoder_sigma\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder_input = trained_model.input\n","sigma_layer = trained_model.get_layer('latent_sigma')\n","sigma_output = sigma_layer.output\n","# Recreate the encoder model\n","encoder_sigma = Model(inputs=encoder_input, outputs=sigma_output)\n","\n","\n","# Decoder\n","trained_decoder = trained_model.get_layer('decoder_model')\n","\n"]},{"cell_type":"code","execution_count":12,"metadata":{"collapsed":true,"id":"8qxGz_C5zkJb","colab":{"base_uri":"https://localhost:8080/"},"outputId":"dad6c3d5-7fb0-4a73-be08-89cb47ff7159","cellView":"form","executionInfo":{"status":"ok","timestamp":1758776998717,"user_tz":240,"elapsed":26033,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n","\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n","\u001b[1m216/216\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 
5ms/step\n","\u001b[1m216/216\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m247/247\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m247/247\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m219/219\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m219/219\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m224/224\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step\n","\u001b[1m224/224\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step\n","\u001b[1m214/214\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m214/214\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m198/198\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m198/198\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m215/215\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m215/215\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m228/228\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m228/228\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m214/214\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m214/214\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step\n","\u001b[1m218/218\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step\n","\u001b[1m218/218\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step\n"]}],"source":["# @title Load MNIST dataset & encode\n","\n","\n","# load the MNIST data (train and test splits are merged below)\n","(x_train, y_train), (x_test, y_test) = mnist.load_data()\n","concatenated_x = tf.concat([x_train, x_test], axis=0)\n","concatenated_y = tf.concat([y_train, y_test], axis=0)\n","\n","list_of_images = []\n","\n","for i in range(10):\n","  list_of_images.append(concatenated_x[concatenated_y == i]) # (n_i, 28, 28), about 7000 images per digit\n","\n","for i in range(10):\n","    data = list_of_images[i].numpy()/255.\n","    data = tf.reshape(data, (list_of_images[i].shape[0], 28, 28, 1))\n","    encoded_mu_all = encoder_mu.predict(data)\n","    encoded_sigma_all = encoder_sigma.predict(data)\n","    epsilon_all = tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded_images = encoded_mu_all + tf.exp(encoded_sigma_all / 2) * epsilon_all\n","    numpy_encoded_images = encoded_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_mnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded_images)\n","\n"]},{"cell_type":"code","execution_count":13,"metadata":{"id":"QxDHnElk4HV6","cellView":"form","executionInfo":{"status":"ok","timestamp":1758776998795,"user_tz":240,"elapsed":77,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["# @title Load all digits\n","\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","tensor_encoded_list=[]\n","for i in range(10):\n","    encoded_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_mnist.npy'.format(i))\n","    encoded_x = 
np.load(encoded_data_path)\n","    tensor_encoded_list.append(torch.tensor(encoded_x).to(device))\n"]},{"cell_type":"markdown","metadata":{"id":"dyDUTgV2wfjG"},"source":["\n","\n","---\n","## check accuracy on 28 $\\times$ 28 images\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":14,"metadata":{"id":"h-820URmwlWe","executionInfo":{"status":"ok","timestamp":1758776999698,"user_tz":240,"elapsed":901,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["import tensorflow as tf\n","# Keras - Deep Learning API\n","import keras                                # High-level neural networks API\n","from keras.datasets import mnist            # MNIST dataset of hand-written digits\n","from keras.layers import (                  # Neural network layers\n","    Conv2D, Conv2DTranspose,\n","    Input, Flatten, Dense,\n","    Lambda, Reshape\n",")\n","from keras.models import Model              # Model definition and training\n","from keras.callbacks import (               # Training callbacks\n","    EarlyStopping, ModelCheckpoint\n",")\n","from keras import backend as K\n","from sklearn.manifold import TSNE\n","\n","import os\n","import sys\n","# Optional local checkout of helper sources: register it only when the directory exists\n","# on this machine and is not already on sys.path.\n","LOCAL_SRC_DIR = '/Volumes/D/GitHub-Portfolio/DeepLearning-MNIST-VAE/src/'\n","if os.path.isdir(LOCAL_SRC_DIR) and LOCAL_SRC_DIR not in sys.path:\n","    sys.path.append(LOCAL_SRC_DIR)\n","\n","# Metrics\n","from keras.metrics import MeanSquaredError\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# load the MNIST data; the asserts pin the expected split sizes\n","(x_train, y_train), (x_test, y_test) = mnist.load_data()\n","assert x_train.shape == (60000, 28, 28)\n","assert x_test.shape == (10000, 28, 28)\n","assert y_train.shape == (60000,)\n","assert y_test.shape == (10000,)\n","\n","# merge the train and test splits\n","concatenated_x = tf.concat([x_train, x_test], axis=0)\n","concatenated_y = tf.concat([y_train, y_test], axis=0)\n","\n","# list_of_images[i]: all images labelled i, shape (n_i, 28, 
28); 70000 images over 10 digits, so n_i is about 7000\n","list_of_images = [concatenated_x[concatenated_y == i] for i in range(10)]"]},{"cell_type":"code","execution_count":15,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mg_24V7Ip3Q3","outputId":"745b8fdb-7e82-406e-b1e1-421b3e5963db","executionInfo":{"status":"ok","timestamp":1758777008720,"user_tz":240,"elapsed":9021,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["<All keys matched successfully>"]},"metadata":{},"execution_count":15}],"source":["# ResNet-18 with a 1-channel input stem and a 10-way output layer\n","model = models.resnet18()\n","\n","model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n","model.fc =  nn.Linear(in_features=512, out_features=10, bias=True)\n","\n","# evaluation mode: BatchNorm layers use their stored running statistics\n","# (call model.train() again before any fine-tuning)\n","model.eval()\n","\n","save_path = os.getcwd()\n","file_name = os.path.join(save_path, 'model_classifier_28_by_28_MNIST.pt')\n","# map_location lets a checkpoint that was saved on a GPU load on a CPU-only runtime\n","state_dict = torch.load(file_name, map_location='cpu')\n","model.load_state_dict(state_dict, strict=False)"]},{"cell_type":"code","execution_count":16,"metadata":{"id":"LLAqzPtQwnzQ","executionInfo":{"status":"ok","timestamp":1758777008723,"user_tz":240,"elapsed":4,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["def imshow(img):\n","    \"\"\"Show a channel-first (C, H, W) image tensor with matplotlib.\n","\n","    `img / 2 + 0.5` maps values from [-1, 1] to [0, 1] (presumably undoing a\n","    Normalize(0.5, 0.5) transform -- confirm against the data pipeline).\n","    `.numpy()` is called directly, so `img` must be a CPU tensor.\n","    \"\"\"\n","    img = img / 2 + 0.5\n","    npimg = img.numpy()\n","    # matplotlib expects channel-last (H, W, C)\n","    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n","    plt.show()"]},{"cell_type":"code","execution_count":17,"metadata":{"id":"FkwLPqxywwXv","executionInfo":{"status":"ok","timestamp":1758777008726,"user_tz":240,"elapsed":1,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["def test_accuracy_fashion_MNIST(classifier, list_of_imgs, target_digit):\n","  \"\"\"Classify a batch and report how often the prediction equals `target_digit`.\n","\n","  classifier:   callable returning one row of class scores per image\n","  list_of_imgs: batch tensor handed to `classifier` as-is (first axis = images)\n","  target_digit: label compared element-wise with the arg-max prediction\n","  Returns (accuracy in percent, predicted labels); runs under torch.no_grad().\n","  \"\"\"\n","  with torch.no_grad():\n","    output = classifier(list_of_imgs)\n","    # arg-max over the class axis\n","    _, predicted = torch.max(output, 1)\n","    total = list_of_imgs.size()[0]\n","    correct = (predicted == target_digit).sum().item()\n","    accuracy_rate = 100 * correct / total\n","  return accuracy_rate, 
predicted"]},{"cell_type":"markdown","metadata":{"id":"_WmQYYP-gcpj"},"source":["\n","\n","---\n","## preparing normalized latent samples\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":18,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"Jj3NHr-NVv43","outputId":"6555ed17-465f-4e6a-caf1-a2c2c776c8da","executionInfo":{"status":"ok","timestamp":1758777009068,"user_tz":240,"elapsed":341,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}},"cellView":"form"},"outputs":[{"output_type":"stream","name":"stdout","text":["effective dimension = 10\n","effective dimension using std = 10\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-1238408412.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  U_reduced = torch.tensor(U[:, idx]).to(device)\n","/tmp/ipython-input-1238408412.py:26: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  V_reduced = torch.tensor(V[idx, :]).to(device)\n"]}],"source":["# @title compute mean value, SVD of the covariance matrix for the latent data\n","import pickle\n","\n","dim_latent = Dim\n","\n","# latent codes of all ten digits, stacked along the sample axis\n","encoded_x_tot = torch.cat(tensor_encoded_list, 0).to(device)\n","\n","# second group: digits 5-9\n","encoded_x_tot_2 = torch.cat(tensor_encoded_list[5:10], 0).to(device)\n","\n","# first group: digits 0-4, each exactly once\n","# (the previous loop started from digit 0 and then concatenated range(0, 4),\n","#  which duplicated digit 0 and left out digit 4)\n","encoded_x_tot_1 = torch.cat(tensor_encoded_list[0:5], 0).to(device)\n","\n","\n","mean_encoded_x_tot = torch.mean(encoded_x_tot, dim=0, 
keepdims=True).to(device)\n","# population covariance (divides by the number of samples)\n","cov_encoded_x_tot = torch.matmul((encoded_x_tot - mean_encoded_x_tot).T, (encoded_x_tot - mean_encoded_x_tot)) / encoded_x_tot.size(0)\n","U, s, V = np.linalg.svd(cov_encoded_x_tot.cpu().detach().numpy(), full_matrices=True)\n","U = torch.tensor(U, dtype=torch.float32)\n","V = torch.tensor(V, dtype=torch.float32)\n","# keep only the directions whose singular value is not numerically zero\n","idx = s > 1e-4\n","# indexing with a boolean mask already returns a copy, so no torch.tensor(...) re-wrap is\n","# needed (the re-wrap is what raised the copy-construct UserWarning)\n","U_reduced = U[:, idx].to(device)\n","V_reduced = V[idx, :].to(device)\n","s_reduced = torch.tensor(s[idx]).to(device)\n","\n","sqr = encoded_x_tot ** 2\n","# per-coordinate standard deviation, i.e. square roots of the diagonal of the covariance;\n","# E[x^2] - E[x]^2 is clamped at 0 so round-off cannot produce NaN\n","std = torch.sqrt(torch.clamp(torch.mean(sqr, dim=0, keepdims=True) - mean_encoded_x_tot ** 2, min=0))\n","idx_std = std[0].cpu().numpy() > 0\n","\n","# persist the normalisation statistics next to the encoded data\n","for pkl_name, pkl_tensor in [('SVD_U.pkl', U_reduced), ('SVD_V.pkl', V_reduced), ('SVD_s.pkl', s_reduced),\n","                             ('mean_encoded_x_all.pkl', mean_encoded_x_tot), ('std.pkl', std)]:\n","    with open(os.path.join(model_and_data_savepath, pkl_name), 'wb') as file:\n","        pickle.dump(pkl_tensor.cpu().detach().numpy(), file)\n","\n","# number of retained singular directions\n","effective_dim = int(np.sum(idx))\n","print(\"effective dimension = {}\".format(effective_dim))\n","\n","# number of coordinates with non-zero standard deviation\n","effective_dim_using_std = int(np.sum(idx_std))\n","print(\"effective dimension using std = {}\".format(effective_dim_using_std))\n","\n"]},{"cell_type":"code","source":["plt.plot(np.log(s)/np.log(10))\n","plt.savefig(model_and_data_savepath + 
'/singular_value.pdf')\n","plt.close()\n","print(s)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rRrsqrgh8mhP","outputId":"fc5aced8-a9c7-4e91-d387-e26f347bd62d","executionInfo":{"status":"ok","timestamp":1758777010256,"user_tz":240,"elapsed":1187,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["[10.867341   9.965664   8.227304   6.45979    5.206647   3.8749127\n","  3.2000525  2.4458566  2.1483896  1.31872  ]\n"]}]},{"cell_type":"code","execution_count":20,"metadata":{"cellView":"form","executionInfo":{"status":"ok","timestamp":1758777010257,"user_tz":240,"elapsed":6,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}},"id":"0CGmk186LDjZ"},"outputs":[],"source":["# @title overview of latent samples (func)\n","import numpy as np\n","import pickle\n","\n","\n","def _load_encoded_samples(digit, savepath, dataset, load_test=False):\n","    \"\"\"Load the saved latent codes of one digit as a tensor on `device`.\n","\n","    Reads `encoded_image_{digit}_{dataset}.npy` from `savepath`, or\n","    `encoded_image_{digit}_for_test_{dataset}.npy` when `load_test` is True.\n","    \"\"\"\n","    file_template = 'encoded_image_{}_for_test_{}.npy' if load_test else 'encoded_image_{}_{}.npy'\n","    # fill in the file name before joining, so braces inside `savepath` cannot confuse str.format\n","    encoded_x = 
np.load(os.path.join(savepath, file_template.format(digit, dataset)))\n","    return torch.tensor(encoded_x).to(device)\n","\n","\n","def normalize_latent_samples(digit, savepath=model_and_data_savepath, dataset='mnist'):\n","    \"\"\"Whiten latent samples with the SVD of the covariance of all latent data.\n","\n","    digit: \"first_list\" / \"second_list\" select the pre-stacked groups encoded_x_tot_1 /\n","           encoded_x_tot_2; any other value names the digit whose saved codes are loaded.\n","    Returns (x - mean_encoded_x_tot) @ U_reduced / sqrt(s_reduced).\n","    \"\"\"\n","    if digit == \"first_list\":\n","        encoded_x = encoded_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded_x_tot_2.to(device)\n","    else:\n","        encoded_x = _load_encoded_samples(digit, savepath, dataset)\n","    # normalize:\n","    encoded_x = torch.matmul((encoded_x - mean_encoded_x_tot), U_reduced)/torch.sqrt(s_reduced)\n","\n","    return encoded_x\n","\n","\n","# NOTE: norm_eps is not referenced by the functions defined in this cell\n","norm_eps = 1e-6\n","def normalize_latent_samples_using_std(digit, savepath=model_and_data_savepath, dataset='mnist', load_test=False):\n","    \"\"\"Standardise latent samples coordinate-wise: (x - mean) / std.\n","\n","    Only the coordinates selected by idx_std (non-zero standard deviation) are returned.\n","    digit:     as in normalize_latent_samples; the two group names take precedence over load_test\n","    load_test: load the `for_test` encoding of the digit instead of the regular one\n","    \"\"\"\n","    if digit == \"first_list\":\n","        encoded_x = encoded_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded_x_tot_2.to(device)\n","    else:\n","        encoded_x = _load_encoded_samples(digit, savepath, dataset, load_test=load_test)\n","    # normalized with the per-coordinate mean and standard deviation of all latent data\n","    normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x_tot[:, idx_std])/(std[0, idx_std])\n","\n","    return normalized_encoded_x\n","\n","\n","def recover_latent_samples_using_PCA(samples):  # samples: N * d (d = number of retained singular directions)\n","    \"\"\"Inverse of normalize_latent_samples: scale by sqrt(s_reduced), rotate back with U_reduced^T, add the mean.\"\"\"\n","\n","    recovered_samples = torch.matmul(samples * torch.sqrt(s_reduced), U_reduced.T) + mean_encoded_x_tot\n","\n","    return recovered_samples\n","\n","\n","def recover_latent_samples_using_std(samples):  # samples: N * d (d = number of coordinates kept by idx_std)\n","    \"\"\"Inverse of normalize_latent_samples_using_std.\n","\n","    Returns an N * dim_latent tensor: kept coordinates are rescaled by std and shifted by the\n","    mean, dropped coordinates are filled with their mean value.\n","    \"\"\"\n","\n","    recovered_samples = torch.zeros(samples.size()[0], dim_latent).to(device)\n","    i = 0  # column of `samples` that corresponds to the next kept coordinate\n","    for d in range(dim_latent):\n","      if idx_std[d]:\n","        recovered_samples[:, d] = samples[:, i] * std[0, d] + mean_encoded_x_tot[0, d]\n","        i+=1\n","      else:\n","        recovered_samples[:, d] = mean_encoded_x_tot[0, d] * torch.ones(samples.size()[0]).to(mean_encoded_x_tot.device)\n","\n","    return recovered_samples\n","\n"]},{"cell_type":"code","execution_count":21,"metadata":{"collapsed":true,"id":"Y80G7m-h07Pa","cellView":"form","executionInfo":{"status":"ok","timestamp":1758777010347,"user_tz":240,"elapsed":89,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["# @title prepare data using SVD normalization\n","\n","data_firsthalf = normalize_latent_samples(\"first_list\")\n","init_data_gpu = data_firsthalf.to(device)\n","mean_init_data = torch.mean(init_data_gpu, dim=0).to(device)\n","cov_init_data = torch.matmul((init_data_gpu - mean_init_data).T, (init_data_gpu - mean_init_data)) / init_data_gpu.size()[0]\n","cov_init_data = cov_init_data.to(device)\n","\n","data_secondhalf = 
normalize_latent_samples(\"second_list\")\n","target_data_gpu = data_secondhalf.to(device)\n","mean_target_data = torch.mean(target_data_gpu, dim=0).to(device)\n","cov_target_data = torch.matmul((target_data_gpu - mean_target_data).T, (target_data_gpu - mean_target_data)) / target_data_gpu.size()[0]\n","cov_target_data = cov_target_data.to(device)\n","\n","\n","# per-digit whitened samples; the individual names are kept alongside the lists\n","target_data_list_svd = [normalize_latent_samples(str(digit)).to(device) for digit in range(5, 10)]\n","target_data_5_svd, target_data_6_svd, target_data_7_svd, target_data_8_svd, target_data_9_svd = target_data_list_svd\n","\n","init_data_list_svd = [normalize_latent_samples(str(digit)).to(device) for digit in range(0, 5)]\n","init_data_0_svd, init_data_1_svd, init_data_2_svd, init_data_3_svd, init_data_4_svd = init_data_list_svd\n","\n"]},{"cell_type":"code","execution_count":22,"metadata":{"id":"Og88dGcbEFbk","executionInfo":{"status":"ok","timestamp":1758777012727,"user_tz":240,"elapsed":2379,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["# @title prepare data with normalization (using std, i.e., the trace of covariance to normalize)\n","\n","\n","# mean and population covariance of the standardised first group \n
(\"first_list\")\n","data_firsthalf_std = normalize_latent_samples_using_std(\"first_list\")\n","init_data_gpu_std = data_firsthalf_std.to(device)\n","mean_init_data_std = torch.mean(init_data_gpu_std, dim=0).to(device)\n","cov_init_data_std = torch.matmul((init_data_gpu_std - mean_init_data_std).T, (init_data_gpu_std - mean_init_data_std)) / init_data_gpu_std.size()[0]\n","cov_init_data_std = cov_init_data_std.to(device)\n","\n","# same statistics for the second group (\"second_list\")\n","data_secondhalf_std = normalize_latent_samples_using_std(\"second_list\")\n","target_data_gpu_std = data_secondhalf_std.to(device)\n","mean_target_data_std = torch.mean(target_data_gpu_std, dim=0).to(device)\n","cov_target_data_std = torch.matmul((target_data_gpu_std - mean_target_data_std).T, (target_data_gpu_std - mean_target_data_std)) / target_data_gpu_std.size()[0]\n","cov_target_data_std = cov_target_data_std.to(device)\n","\n","# per-digit standardised samples; the individual names are kept alongside the lists\n","target_data_list_std = [normalize_latent_samples_using_std(str(digit)).to(device) for digit in range(5, 10)]\n","target_data_5, target_data_6, target_data_7, target_data_8, target_data_9 = target_data_list_std\n","\n","init_data_list_std = [normalize_latent_samples_using_std(str(digit)).to(device) for digit in range(0, 5)]\n","init_data_0, init_data_1, init_data_2, init_data_3, init_data_4 = init_data_list_std\n","\n","\n","# `for_test` encodings: every digit file is loaded once and reused for both the per-digit \n
lists and the stacked tensor\n","init_data_list_std_for_test = [normalize_latent_samples_using_std(str(digit), load_test=True).to(device) for digit in range(0, 5)]\n","init_data_0_for_test, init_data_1_for_test, init_data_2_for_test, init_data_3_for_test, init_data_4_for_test = init_data_list_std_for_test\n","\n","target_data_list_std_for_test = [normalize_latent_samples_using_std(str(digit), load_test=True).to(device) for digit in range(5, 10)]\n","target_data_5_for_test, target_data_6_for_test, target_data_7_for_test, target_data_8_for_test, target_data_9_for_test = target_data_list_std_for_test\n","\n","# all ten digits stacked in digit order 0..9\n","data_all_gpu_std_for_test = torch.cat(init_data_list_std_for_test + target_data_list_std_for_test, 0)\n","\n"]},{"cell_type":"markdown","metadata":{"id":"tKNgKlzzsnQU"},"source":["\n","\n","---\n","\n","## Define the Metric/Distance functions: MMD\n","\n","---\n","\n"]},{"cell_type":"code","execution_count":23,"metadata":{"cellView":"form","id":"oa9NUc3DXmFq","executionInfo":{"status":"ok","timestamp":1758777012729,"user_tz":240,"elapsed":1,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[],"source":["# @title MMD\n","\n","# Assume the 
def _rbf_gram(a, b, k_sigma):
    """Gram matrix of the RBF kernel between the rows of ``a`` and ``b``.

    Entry (i, j) equals exp(-|a_i - b_j|^2 / (2 * k_sigma^2)), where the
    squared distance is expanded as |a_i|^2 + |b_j|^2 - 2 a_i . b_j.
    """
    a_sqr = torch.sum(a * a, 1).unsqueeze(1)   # column of |a_i|^2
    b_sqr = torch.sum(b * b, 1).unsqueeze(0)   # row of |b_j|^2
    sqr_dist = a_sqr + b_sqr - 2 * torch.matmul(a, torch.transpose(b, 0, 1))
    # Cancellation in the expansion can leave tiny negative values for
    # (nearly) coincident points, which would push the kernel above 1.
    sqr_dist = torch.clamp(sqr_dist, min=0)
    return torch.exp(-sqr_dist / (2 * k_sigma * k_sigma))


def MMD(x, y, k_sigma=1):
    """Squared maximum mean discrepancy between two sample sets (RBF kernel).

    For a symmetric kernel k,
        MMD^2(p0, p1) = E k(x, x') - 2 E k(x, y) + E k(y, y'),
    with k(x1, x2) = exp(-|x1 - x2|^2 / (2 * k_sigma^2)).  The expectations
    are estimated with V-statistics (plain averages over all pairs, diagonal
    included), cf. eq. (5) of
    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf

    Args:
        x: tensor whose rows are the samples from p0 (N rows).
        y: tensor whose rows are the samples from p1 (M rows), same number
            of columns as ``x``.
        k_sigma: bandwidth of the RBF kernel.

    Returns:
        Scalar tensor holding the estimate of MMD^2.
    """
    N = x.size()[0]
    M = y.size()[0]

    # K_00 ~ \iint k(x1, x2) p0(x1) p0(x2) dx1 dx2
    K_00 = _rbf_gram(x, x, k_sigma)
    # K_11 ~ \iint k(y1, y2) p1(y1) p1(y2) dy1 dy2
    K_11 = _rbf_gram(y, y, k_sigma)
    # K_01 ~ \iint k(x, y) p0(x) p1(y) dx dy
    K_01 = _rbf_gram(x, y, k_sigma)

    MMDsqr = torch.sum(K_00)/(N*N) - 2 * torch.sum(K_01)/(N*M) + torch.sum(K_11)/(M*M)

    return MMDsqr


# MMD with the kernel k(z) = -sqrt(||z||^2 + eps), k(x, y) = k(x - y)

def dist_XY(X, Y, eps=0.001):
    """Smoothed pairwise Euclidean distances between the rows of X and Y.

    Args:
        X: tensor with N rows.
        Y: tensor with M rows and the same number of columns as ``X``.
        eps: added under the square root so the result stays differentiable
            where two points coincide.

    Returns:
        N x M tensor whose (i, j) entry is sqrt(|x_i - y_j|^2 + eps).
    """
    # Broadcasting builds the (N, M, dim) difference tensor directly instead
    # of materialising two repeated copies of the inputs.
    diff = X.unsqueeze(1) - Y.unsqueeze(0)
    return torch.sqrt(torch.sum(diff * diff, 2) + eps)


def MMD_negnorm(x, y, k_sigma=1):
    """MMD-type discrepancy for the kernel k(z) = -sqrt(||z||^2 + 0.001).

    Returns A/2 - B + C/2 where A, B, C are the averages of k over the
    (x, x), (x, y) and (y, y) pairs respectively.

    Args:
        x: tensor whose rows are samples of the first distribution.
        y: tensor whose rows are samples of the second distribution.
        k_sigma: unused; kept so the signature matches ``MMD``.

    Returns:
        Scalar tensor.
    """
    M = x.size()[0]
    N = y.size()[0]

    B = -torch.sum(dist_XY(x, y)) / (M * N)
    A = -torch.sum(dist_XY(x, x)) / (M * M)
    C = -torch.sum(dist_XY(y, y)) / (N * N)

    return A/2 - B + C/2


def slicedMMD(x, y, num):
    """Sum of ``MMD`` over the 2-D slices spanned by consecutive coordinates.

    Only the first ``num`` rows of ``x`` and ``y`` are used.  For each d the
    columns (d, d + 1) are compared with ``MMD`` (default bandwidth).

    Returns:
        The accumulated value; the plain integer 0 when the inputs have fewer
        than two columns, because no slice exists.
    """
    sum_mmd0 = 0
    for d in range(x.size()[1]-1):
        sum_mmd0 += MMD(x[:num, [d, d+1]], y[:num, [d, d+1]])

    return sum_mmd0
def plot_on_latent_space(encoded_x_numpy, transported_samples, dim_0, dim_1, scale, iter, exp_dir=dir):
    """Scatter reference and transported samples on one coordinate plane and save it.

    Args:
        encoded_x_numpy: 2-D array of reference samples (drawn in magenta).
        transported_samples: 2-D array of transported samples (drawn in blue).
        dim_0, dim_1: column indices used as the horizontal / vertical axes.
        scale, iter: only used to build the output file name.
        exp_dir: directory the PNG is written to.

    At most the first 1024 rows of each array are drawn.
    """
    plt.figure(figsize=(12,12))
    plt.scatter( encoded_x_numpy[:1024, dim_0], encoded_x_numpy[:1024, dim_1], color='magenta', s=10, alpha=1)
    plt.scatter( transported_samples[:1024, dim_0], transported_samples[:1024, dim_1], color='blue', s=1, alpha=1)
    plt.title("plot on latent space on {}-{} coordinate plane".format(dim_0, dim_1))
    filename = os.path.join(exp_dir, 'plot on latent space on {}-{} coordinate plane rescale={}iter={}.png'.format(dim_0, dim_1, scale, iter))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()


import math


def plot_mnist(images, output, n_plots, scale, iter, ttl, exp_dir=example_savepath):
    """Save a two-row figure: input images on top, reconstructions below.

    Args:
        images: tensor of input images; the first ``n_plots`` are shown.
        output: tensor of reconstructions; the first ``n_plots`` rows are
            reshaped to 28 x 28 before plotting.
        n_plots: number of columns of the figure.
        scale: unused.
        iter: iteration number, used in the file name.
        ttl: figure title, also used as file-name prefix.
        exp_dir: directory the PNG is written to.
    """
    images = images.cpu().detach().numpy()[:n_plots, :]
    output = output[:n_plots, :].view(n_plots, 28, 28).cpu().detach().numpy()

    # squeeze=False keeps `axes` two-dimensional even when n_plots == 1, so
    # each row can always be iterated.
    fig, axes = plt.subplots(nrows=2, ncols=n_plots, sharex=True, sharey=True,
                             figsize=(25,4), squeeze=False)
    # input images on top row, reconstructions on bottom
    for row_images, row_axes in zip([images, output], axes):
        for img, ax in zip(row_images, row_axes):
            ax.imshow(np.squeeze(img), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    # Title the whole figure (as the sibling plotting helpers do) rather than
    # only the last panel.
    fig.suptitle(ttl)
    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()


def plot_mnist_tab(rownum, colnum, x, iter, ttl, exp_dir=example_savepath):
    """Save a ``rownum`` x ``colnum`` table of images.

    Args:
        rownum, colnum: grid shape.
        x: array indexed as ``x[k, :, :, 0]``; image ``i * colnum + j`` is
            drawn in cell (i, j).
        iter: iteration number, used in the file name.
        ttl: figure title, also used as file-name prefix.
        exp_dir: directory the PNG is written to.
    """
    # squeeze=False: a single row or column must still be indexable as axs[i, j].
    fig, axs = plt.subplots(rownum, colnum, squeeze=False)
    # set_size_inches takes (width, height): width follows the column count.
    fig.set_size_inches(4*colnum, 4*rownum)

    for i in range(rownum):
        for j in range(colnum):
            axs[i, j].imshow(x[i*colnum+j,:, :, 0])

    fig.suptitle(ttl, fontsize=16)

    for ax in axs.flat:
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_aspect('equal')
    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
import random


def generate_random_list(start, end, length):
    """Return ``length`` random integers drawn from [start, end], both ends inclusive."""
    return [random.randint(start, end) for _ in range(length)]


def plot_mnist_init_target(x_a, x_b, x_c, x_d, x_e, iter, row_num, ttl, flag = '0_to_T',exp_dir=example_savepath):
    """Save a grid with one column per digit class and ``row_num`` random samples each.

    Args:
        x_a, x_b, x_c, x_d, x_e: image arrays indexed as ``x[k, :, :, 0]``,
            one per column (left to right).
        iter: iteration number, used in the file name.
        row_num: number of rows (random samples per column).
        ttl: figure title, also used as file-name prefix.
        flag: '0_to_T' labels the columns 5..9, anything else labels them 0..4.
        exp_dir: directory the PNG is written to.

    The same random row indices are used for every column.
    """
    columns = [x_a, x_b, x_c, x_d, x_e]
    if flag == '0_to_T':
        labels = ['5', '6', '7', '8', '9']
    else:
        labels = ['0', '1', '2', '3', '4']

    n_cols = len(columns)
    # squeeze=False keeps axs two-dimensional when row_num == 1.
    fig, axs = plt.subplots(row_num, n_cols, squeeze=False)
    fig.set_size_inches( 4*n_cols, 4*row_num )

    # random.randint includes its upper bound, so the largest admissible index
    # is (length - 1); the indices are shared by all columns, hence the bound
    # is the shortest array.  An empty array makes randint raise ValueError.
    n_available = min(col.shape[0] for col in columns)
    random_integers = generate_random_list(0, n_available - 1, row_num)

    for j, (col, label) in enumerate(zip(columns, labels)):
        for i in range(row_num):
            axs[i, j].imshow(col[random_integers[i], :, :, 0], cmap='gray')
            axs[i, j].set_title(label)

    fig.suptitle(ttl, fontsize=16)

    for ax in axs.flat:
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_aspect('equal')
    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
  else:\n","                  axs[i, j].set_title(f'3')\n","          if j == 4:\n","              axs[i, j].imshow(x_e[random_integers[i], :, :, 0], cmap='gray')\n","              if flag == '0_to_T':\n","                  axs[i, j].set_title(f'9')\n","              else:\n","                  axs[i, j].set_title(f'4')\n","\n","    fig.suptitle(ttl, fontsize=16)\n","\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].get_yaxis().set_visible(False)\n","            axs[i, j].get_xaxis().set_visible(False)\n","            axs[i ,j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","\n","\n","def plot_mnist_one_row(x, iter, ttl, exp_dir=example_savepath):\n","\n","    # assume square image\n","    s = int(math.sqrt(x.shape[1]))\n","\n","    nex = 8\n","    fig, axs = plt.subplots(2, nex//2)\n","    fig.set_size_inches(18, 9)\n","\n","    for i in range(nex//2):\n","        # axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy(), cmap='gray')\n","        axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy())\n","        # axs[1, i].imshow(x[ nex//2 + i , : ].reshape(s,s).detach().cpu().numpy(), cmap='gray')\n","        axs[1, i].imshow(x[ nex//2 + i , : ].reshape(s,s).detach().cpu().numpy())\n","\n","    fig.suptitle(ttl, fontsize=16)\n","\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].get_yaxis().set_visible(False)\n","            axs[i, j].get_xaxis().set_visible(False)\n","            axs[i ,j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    
import matplotlib.pyplot as plt


def overview_of_latent_samples(digit='All', flag=1, dir=model_and_data_savepath):
    """Load latent samples and return them with the SVD of their covariance.

    Args:
        digit: "All" to use the in-memory arrays ``encoded_x_tot_1`` /
            ``encoded_x_tot_2``; any other value loads
            ``encoded_image_<digit>.npy`` from ``dir``.
        flag: with ``digit == "All"``, 1 selects ``encoded_x_tot_1`` and any
            other value selects ``encoded_x_tot_2``.
        dir: directory holding the ``.npy`` files.

    Returns:
        (encoded_x, mean_encoded_x, U, singularv_cov, V) where ``U`` and ``V``
        are float32 tensors and ``singularv_cov`` is a NumPy array of the
        singular values of the (biased, 1/N) covariance.  ``V`` is the third
        output of ``np.linalg.svd``, i.e. the transposed right singular
        vectors.
    """
    if digit != "All":
        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(digit)))
        encoded_x = torch.tensor(encoded_x).cpu()
        mean_encoded_x = torch.mean(encoded_x, dim=0).cpu()
    else:
        loaded_data = encoded_x_tot_1 if flag == 1 else encoded_x_tot_2
        # NOTE(review): assumes the in-memory arrays are NumPy arrays; if they
        # are already tensors torch.tensor() will emit a copy-construct warning.
        encoded_x = torch.tensor(loaded_data)
        mean_encoded_x = torch.mean(encoded_x, dim=0)

    centered = encoded_x - mean_encoded_x
    cov_encoded_x = (torch.matmul(centered.T, centered) / encoded_x.size(0)).cpu()

    # One SVD yields both the singular vectors and the singular values; the
    # matrix square root and condition numbers computed previously were never
    # used or returned.
    U, singularv_cov, V = np.linalg.svd(cov_encoded_x.detach().numpy(), full_matrices=True)
    U = torch.tensor(U, dtype=torch.float32)
    V = torch.tensor(V, dtype=torch.float32)

    return encoded_x, mean_encoded_x, U, singularv_cov, V
# Normalize with the SVD of the whole data set and project onto PCA directions.
def plot_in_latent_spc_along_PCA_directions_with_common_normalization(list_of_digits, dim_0, dim_1,  save_dir, dir=model_and_data_savepath):
    """Scatter several digits in common PCA coordinates and save the figure.

    Every digit is centred with ``mean_encoded_x_tot``, projected with
    ``U_reduced`` and divided by ``sqrt(s_reduced)`` (all module-level
    values), so the digits share one normalization.

    Args:
        list_of_digits: digits whose ``encoded_image_<digit>.npy`` is loaded.
        dim_0, dim_1: PCA coordinates used as horizontal / vertical axes.
        save_dir: directory the figure is written to.
        dir: directory holding the ``.npy`` files.

    At most the first 1000 samples of each digit are drawn.
    """
    mean_encoded_x_all = mean_encoded_x_tot.cpu()
    U_all = U_reduced.cpu()
    s_all = s_reduced.cpu()

    plt.figure(figsize=(12,12))
    for i in list_of_digits:
        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))
        encoded_x = torch.tensor(encoded_x).cpu()

        centered_pca_encoded_x = torch.matmul((encoded_x - mean_encoded_x_all), U_all)/torch.sqrt(s_all)
        plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1],  s=5, alpha=1, label='{}'.format(i))

    plt.legend()
    plt.title("plot digits({}) on latent space {}-{} PCA dimensions (using SVD on whole MNIST dataset)".format(list_of_digits, dim_0, dim_1))
    filename = os.path.join(save_dir,"plot digits({}) on latent space {}-{} PCA dimensions".format(list_of_digits, dim_0, dim_1))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()


# Normalize coordinate-wise with the standard deviation.
def plot_in_latent_spc_std_with_common_normalization(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Scatter several digits after a common std normalization and save the figure.

    Only the coordinates listed in the module-level ``idx_std`` are kept; each
    is centred with ``mean_encoded_x_tot`` and divided by the matching entry
    of the module-level ``std``.

    Args:
        list_of_digits: digits whose ``encoded_image_<digit>.npy`` is loaded.
        dim_0, dim_1: indices (into the kept coordinates) used as axes.
        save_dir: directory the figure is written to.
        dir: directory holding the ``.npy`` files.

    At most the first 1000 samples of each digit are drawn.
    """
    mean_encoded_x = mean_encoded_x_tot.cpu()
    std_all = std.cpu()

    plt.figure(figsize=(12,12))
    for i in list_of_digits:
        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i ) ) )
        encoded_x = torch.tensor(encoded_x).cpu()
        # per-coordinate standardization on the selected coordinates
        normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x[:, idx_std])/(std_all[0, idx_std])
        plt.scatter(normalized_encoded_x[:1000, dim_0], normalized_encoded_x[:1000, dim_1],  s=5, alpha=1, label='{}'.format(i))

    plt.legend()
    # The title previously claimed an SVD normalization and had an unbalanced
    # parenthesis; this plot is std-normalized.
    plt.title("plot digits({}) on latent space {}-{} dimensions (normalize using std)".format(list_of_digits, dim_0, dim_1))
    filename = os.path.join(save_dir,"plot digits({}) on std latent space {}-{} dimensions".format(list_of_digits, dim_0, dim_1))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
def _scatter_normalized_digits(list_of_digits, normalizer, dim_0, dim_1, title, filename):
    """Draw one scatter per digit from ``normalizer(digit)``, then title, save and close."""
    plt.figure(figsize=(12,12))
    for digit in list_of_digits:
        samples = normalizer(digit).cpu()
        plt.scatter(samples[:1000, dim_0], samples[:1000, dim_1], s=5, alpha=1, label='{}'.format(digit))
    plt.legend()
    plt.title(title)
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()


def plot_normalized_MNIST_data_in_latent_spc_along_PCA_directions(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Scatter digits normalized by ``normalize_latent_samples`` and save the figure.

    Args:
        list_of_digits: digits passed one by one to ``normalize_latent_samples``.
        dim_0, dim_1: coordinates used as horizontal / vertical axes.
        save_dir: directory the figure is written to.
        dir: unused.

    At most the first 1000 samples of each digit are drawn.
    """
    _scatter_normalized_digits(
        list_of_digits,
        normalize_latent_samples,
        dim_0,
        dim_1,
        "plot digits({}) on latent space {}-{} PCA dimensions (normalized)".format(list_of_digits, dim_0, dim_1),
        os.path.join(save_dir,"[PCA] plot digits({}) on latent space {}-{} PCA dimensions".format(list_of_digits, dim_0, dim_1)),
    )


# Normalize entrywise, projected to ordinary directions
def plot_std_normalized_MNIST_data_in_latent_spc_ordinary_coord(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Scatter digits normalized by ``normalize_latent_samples_using_std`` and save the figure.

    Args:
        list_of_digits: digits passed one by one to
            ``normalize_latent_samples_using_std``.
        dim_0, dim_1: coordinates used as horizontal / vertical axes.
        save_dir: directory the figure is written to.
        dir: unused.

    At most the first 1000 samples of each digit are drawn.
    """
    _scatter_normalized_digits(
        list_of_digits,
        normalize_latent_samples_using_std,
        dim_0,
        dim_1,
        "plot digits({}) on latent space {}-{} dimensions (normalize using std)".format(list_of_digits, dim_0, dim_1),
        os.path.join(save_dir,"[std] plot digits({}) on latent space {}-{} dimensions".format(list_of_digits, dim_0, dim_1)),
    )
# plot raw latent samples
def plot_raw_MNIST_data_in_latent_spc_ordinary_coord(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Scatter the un-normalized latent samples of several digits and save the figure.

    Args:
        list_of_digits: digits whose ``encoded_image_<digit>.npy`` is loaded.
        dim_0, dim_1: latent coordinates used as horizontal / vertical axes.
        save_dir: directory the figure is written to.
        dir: directory holding the ``.npy`` files.

    At most the first 1000 samples of each digit are drawn.
    """
    plt.figure(figsize=(12,12))

    for digit in list_of_digits:
        npy_path = os.path.join(dir, 'encoded_image_{}.npy'.format(digit ))
        latent = torch.tensor(np.load(npy_path)).cpu()
        xs = latent[:1000, dim_0]
        ys = latent[:1000, dim_1]
        plt.scatter(xs, ys, s=5, alpha=1, label='{}'.format(digit))

    plt.legend()

    caption = "plot digits({}) on latent space {}-{} dimensions (raw data)".format(list_of_digits, dim_0, dim_1)
    plt.title(caption)

    out_path = os.path.join(save_dir,"[RAW] plot digits({}) on latent space {}-{} dimensions".format(list_of_digits, dim_0, dim_1))
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
def plot_in_latent_spc_along_PCA_directions_with_common_normalization_cmp_groups(list_of_digits_1, list_of_digits_2, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Compare two groups of digits in common PCA coordinates and save the figure.

    Every digit is centred with ``mean_encoded_x_tot``, projected with
    ``U_reduced`` and divided by ``sqrt(s_reduced)`` (module-level values).
    The first group is drawn in blue, the second in green; each group gets a
    single legend entry showing its list of digits.

    Args:
        list_of_digits_1: digits of the first group.
        list_of_digits_2: digits of the second group.
        dim_0, dim_1: PCA coordinates used as horizontal / vertical axes.
        save_dir: directory the figure is written to.
        dir: directory holding the ``encoded_image_<digit>.npy`` files.

    At most the first 1000 samples of each digit are drawn.
    """
    mean_encoded_x_all = mean_encoded_x_tot.cpu()
    U_all = U_reduced.cpu()
    s_all = s_reduced.cpu()

    plt.figure(figsize=(12,12))
    for group, color in ((list_of_digits_1, "blue"), (list_of_digits_2, "green")):
        for position, i in enumerate(group):
            encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))
            encoded_x = torch.tensor(encoded_x).cpu()
            centered_pca_encoded_x = torch.matmul((encoded_x - mean_encoded_x_all), U_all)/torch.sqrt(s_all)
            # Only the first scatter of a group carries the legend label, and
            # that label is the group's own digit list (the second group used
            # to be labelled with the first group's digits).
            label = '{}'.format(group) if position == 0 else None
            plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1], color=color, s=5, alpha=1, label=label)
    plt.legend()
    plt.title("plot digits({}) and digits({}) on latent space {}-{} PCA dimensions (using SVD on whole MNIST dataset)".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))
    filename = os.path.join(save_dir,"plot digits({}) and ({}) on latent space {}-{} PCA dimensions".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
# Normalize coordinate-wise with the standard deviation.
def plot_in_latent_spc_std_with_common_normalization_cmp_groups(list_of_digits_1, list_of_digits_2, dim_0, dim_1, save_dir, dir=model_and_data_savepath):
    """Compare two groups of digits after a common std normalization and save the figure.

    Only the coordinates listed in the module-level ``idx_std`` are kept; each
    is centred with ``mean_encoded_x_tot`` and divided by the matching entry
    of the module-level ``std``.  The first group is drawn in blue ('b'), the
    second in green; each group gets a single legend entry showing its list
    of digits.

    Args:
        list_of_digits_1: digits of the first group.
        list_of_digits_2: digits of the second group.
        dim_0, dim_1: indices (into the kept coordinates) used as axes.
        save_dir: directory the figure is written to.
        dir: directory holding the ``encoded_image_<digit>.npy`` files.

    At most the first 1000 samples of each digit are drawn.
    """
    mean_encoded_x = mean_encoded_x_tot.cpu()
    std_all = std.cpu()

    plt.figure(figsize=(12,12))
    for group, color in ((list_of_digits_1, 'b'), (list_of_digits_2, 'green')):
        # enumerate restarts at 0 for every group, so the second group also
        # receives its legend entry (the shared counter was never reset).
        for position, i in enumerate(group):
            encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i )))
            encoded_x = torch.tensor(encoded_x).cpu()
            # per-coordinate standardization on the selected coordinates
            normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x[:, idx_std])/(std_all[0, idx_std])
            label = '{}'.format(group) if position == 0 else None
            plt.scatter(normalized_encoded_x[:1000, dim_0], normalized_encoded_x[:1000, dim_1], color=color, s=5, alpha=1, label=label)

    plt.legend()
    # The title previously claimed an SVD normalization and had an unbalanced
    # parenthesis; this plot is std-normalized.
    plt.title("plot digits({}) and digits ({}) on latent space {}-{} dimensions (normalize using std)".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))
    filename = os.path.join(save_dir,"plot digits({}) and ({}) on std latent space {}-{} dimensions".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
def plot_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, iter, normalization='PCA', flag='0_to_T', digits='05', exp_dir=example_savepath):
    """Save three 2-D diagnostic figures of a transport map on one coordinate plane.

    1. "OT maps": up to 300 initial and transported points, joined by segments.
    2. "samples": up to 1024 target points against the transported points.
    3. "samples (ONLY COMPUTED)": the transported points alone.

    Args:
        target_data, init_pnts, transported_pnts: tensors whose rows are
            points; ``init_pnts[k]`` is taken to be mapped to
            ``transported_pnts[k]``.
        dim_0, dim_1: coordinates used as horizontal / vertical axes.
        iter: iteration number, used in the file names.
        normalization: 'PCA' selects the PCA wording in titles and file
            names; anything else selects the std wording.
        flag: '0_to_T' or anything else; chooses the legend wording and is
            shown in titles and file names.
        digits: tag shown in titles and file names.
        exp_dir: directory the PNGs are written to.

    Note: switches matplotlib to the 'dark_background' style globally.
    """
    target_data = target_data.detach().cpu().numpy()
    init_pnts = init_pnts.detach().cpu().numpy()
    transported_pnts = transported_pnts.detach().cpu().numpy()

    if normalization == 'PCA':
        plane = "{}-{} PCA direction".format(dim_0, dim_1)
        file_tail = '{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, iter)
    else:
        plane = "{}-{} dimensional plane".format(dim_0, dim_1)
        file_tail = '{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, iter)
    tag = '[{}, {}]'.format(flag, digits)

    def save_and_close(kind):
        # kind is the figure name used in the file name, e.g. 'OT maps'
        filename = os.path.join(exp_dir, tag + ' plot ' + kind + ' on ' + file_tail)
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
        plt.close()

    plt.style.use('dark_background')

    # ---- figure 1: transport segments --------------------------------------
    if flag == '0_to_T':
        moved_label, start_label = 'transported to T', '0'
    else:
        moved_label, start_label = 'transported to 0', 'T'
    plt.figure(figsize=(12,12))
    plt.scatter(transported_pnts[:300, dim_0], transported_pnts[:300, dim_1], color='blue', s=5, alpha=1, label=moved_label)
    plt.scatter(init_pnts[:300, dim_0], init_pnts[:300, dim_1], color='mediumspringgreen', s=5, alpha=1, label=start_label)
    # Never index past the available points (batches may hold fewer than 300).
    n_segments = min(300, init_pnts.shape[0], transported_pnts.shape[0])
    for k in range(n_segments):
        plt.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], c='cyan', alpha=0.5, linewidth=0.8)
    plt.legend(fontsize=20)
    plt.title("[{}] Plot on ".format(flag) + plane, fontsize=40)
    save_and_close('OT maps')

    # ---- figure 2: target vs. transported samples --------------------------
    plt.figure(figsize=(12,12))
    plt.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], color='magenta', s=1, alpha=1, label='target')
    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='blue', s=1, alpha=1, label='computed')
    plt.legend(fontsize=20)
    plt.title(tag + " Plot on " + plane, fontsize=40)
    save_and_close('samples')

    # ---- figure 3: transported samples only --------------------------------
    plt.figure(figsize=(12,12))
    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='deepskyblue', s=1, alpha=1, label='computed')
    plt.legend(fontsize=20)
    plt.title(tag + " plot on " + plane, fontsize=40)
    save_and_close('samples (ONLY COMPUTED)')
def plot_3D_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, dim_2, iter, normalization='PCA', exp_dir=example_savepath):
    """Save two 3-D diagnostic figures of a transport map.

    1. "3D samples": up to 1024 target points against the transported points.
    2. "3D OT maps": up to 200 initial and transported points, joined by
       segments.

    Args:
        target_data, init_pnts, transported_pnts: tensors whose rows are
            points; ``init_pnts[k]`` is taken to be mapped to
            ``transported_pnts[k]``.
        dim_0, dim_1, dim_2: coordinates used as the x / y / z axes.
        iter: iteration number, used in the file names.
        normalization: 'PCA' selects the PCA wording in titles and file
            names; anything else selects the std wording.
        exp_dir: directory the PNGs are written to.

    Note: switches matplotlib to the 'dark_background' style globally.
    """
    target_data = target_data.detach().cpu().numpy()
    init_pnts = init_pnts.detach().cpu().numpy()
    transported_pnts = transported_pnts.detach().cpu().numpy()

    if normalization == 'PCA':
        title = "plot on {}-{}-{} PCA direction".format(dim_0, dim_1, dim_2)
        file_tail = '{}-{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, dim_2, iter)
    else:
        title = "plot on {}-{}-{} dimensional plane".format(dim_0, dim_1, dim_2)
        file_tail = '{}-{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, dim_2, iter)

    plt.style.use('dark_background')

    # ---- figure 1: target vs. transported samples --------------------------
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(projection='3d')

    ax.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], target_data[:1024, dim_2], color='magenta', s=10, label='target')
    ax.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], transported_pnts[:1024, dim_2], color='deepskyblue', s=10, label='computed')

    ax.set_xlabel('dim {}'.format(dim_0))
    ax.set_ylabel('dim {}'.format(dim_1))
    ax.set_zlabel('dim {}'.format(dim_2))

    plt.legend(fontsize=20)
    plt.title(title, fontsize=40)
    filename = os.path.join(exp_dir, 'plot 3D samples on ' + file_tail)
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()

    # ---- figure 2: transport segments --------------------------------------
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(projection='3d')

    ax.scatter(transported_pnts[:200, dim_0], transported_pnts[:200, dim_1], transported_pnts[:200, dim_2], color='blue', s=5, label='transported')
    ax.scatter(init_pnts[:200, dim_0], init_pnts[:200, dim_1], init_pnts[:200, dim_2], color='mediumspringgreen', s=5, label='initial')
    # Never index past the available points (batches may hold fewer than 200).
    n_segments = min(200, init_pnts.shape[0], transported_pnts.shape[0])
    for k in range(n_segments):
        ax.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], [init_pnts[k, dim_2], transported_pnts[k, dim_2]], c='cyan', alpha=0.5, linewidth=0.8)
    plt.legend(fontsize=20)
    plt.title(title, fontsize=40)
    filename = os.path.join(exp_dir, 'plot 3D OT maps on ' + file_tail)
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)
    plt.close()
dim_1, dim_2, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"SWq4QpZHrv66"},"source":["\n","\n","---\n","## Main algorithm\n","---\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"p8A-AsutqF40","cellView":"form"},"outputs":[],"source":["# @title OT HJ implicit solver\n","import torch\n","import numpy as np\n","import torch.nn.functional\n","import math\n","import os\n","from datetime import datetime\n","from models_Resnet import gradient, ImplicitNet\n","import utils.general as utils\n","import matplotlib.pyplot as plt\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","\n","normalization_type = 'std'\n","\n","if normalization_type == 'PCA':\n","    # effective_dim_1 should EQUAL to effective_dim_2\n","    effective_dim = effective_dim\n","    init_data_entire = init_data_gpu\n","    init_data_list = init_data_list_svd\n","    cov_init = cov_init_data\n","    mean_init = mean_init_data\n","    target_data_entire = target_data_gpu\n","    target_data_list = target_data_list_svd\n","    cov_target = cov_target_data\n","    mean_target = mean_target_data\n","else:\n","    effective_dim = effective_dim_using_std\n","    init_data_entire = init_data_gpu_std\n","    init_data_list = init_data_list_std\n","    cov_init = cov_init_data_std\n","    mean_init = mean_init_data_std\n","    target_data_entire = target_data_gpu_std\n","    target_data_list = target_data_list_std\n","    cov_target = cov_target_data_std\n","    mean_target = mean_target_data_std\n","\n","def random_sampler(N, dim=effective_dim, T=1, fix_T=False, T0=0):\n","    if fix_T:\n","        ts = T0 * torch.ones(N,1)\n","    else:\n","        ts = T * torch.rand(N,1)\n","    xs = torch.randn(N, dim)\n","    pnts = torch.cat((ts, xs), 1)\n","    return torch.tensor(pnts, dtype=torch.float32, 
requires_grad=True)\n","\n","## -------------------------------------------------------------------------------\n","## Configurations\n","## -------------------------------------------------------------------------------\n","gpu_id = 0\n","\n","dim_latent = Dim\n","\n","\n","iter0 = 0\n","\n","regularizer_type = ['implicithjxt0t', 'mmd_negnorm_0t']\n","regularizer_coord = [1, 500]\n","\n","with_OTloss = False\n","weight_OTloss = 0.01\n","\n","N = 2000 # collocation pnts for HJ\n","batch_size = 2000 # batch size for  sample transport\n","if target_data_list[0].shape[0] < batch_size:\n","    batch_size = target_data_list[0].shape[0]\n","\n","\n","epochs = 100000\n","val_frequency  = 2000\n","plot_detail_ot =  epochs\n","plot_detail_mnist = 10000\n","\n","num_mnist_digits_to_plt = 10\n","\n","dim = effective_dim # dimension of spatial domain\n","T = 1 # terminal time\n","\n","fixing_T_and_0_HJ_loss = True\n","\n","batch_size_OT = 4000\n","\n","NN_dims = [128, 128, 128, 128]\n","network_sol = ImplicitNet(d_in=dim+1, dims=NN_dims).to(device)\n","optimizer = torch.optim.Adam(params=network_sol.parameters(), lr=1E-4)\n","\n","loss_min = 1E10\n","\n","save_model = True\n","\n","exp_dir = os.path.join(example_savepath, 'Classed_OT_transfer')\n","utils.mkdir_ifnotexists(exp_dir)\n","\n","# If you need to load an existing model, uncomment the following two lines:\n","# state_dict = torch.load(os.path.join(exp_dir, 'weight_sol_ep{}.pth'.format(iter0)))\n","# network_sol.load_state_dict(state_dict[\"state_dict\"], strict=False)\n","\n","timestamp = 'Normalization_type={}, batchsize={}, NNsize={}, HJ_loss_num_colocation_pnts={}  {}'.format(normalization_type, batch_size, NN_dims, N, datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))\n","exp_dir = os.path.join(exp_dir, timestamp)\n","utils.mkdir_ifnotexists(exp_dir)\n","regularizer_type = list(map(str.lower,regularizer_type))\n","assert len(regularizer_type) == len(regularizer_coord), 'match regularizer coordinates'\n","regularizer_index = 
{t:i for i, t in enumerate(regularizer_type)}\n","\n","if save_model:\n","    utils.mkdir_ifnotexists(os.path.join(exp_dir,'model'))\n","if torch.cuda.is_available() and gpu_id > -1:\n","    device = torch.device(gpu_id)\n","else:\n","    device = torch.device('cpu')\n","utils.set_random_seed(5884)\n","\n","\n","MMD_loss_list_0T = []\n","MMD_loss_list_T0 = []\n","# OT_dist_0T_list = []\n","# OT_dist_T0_list = []\n","Implicit_HJ_loss_list = []\n","accuracy_0T_list = []\n","accuracy_T0_list = []\n","for epoch in range(iter0, epochs+iter0):\n","\n","    if epoch % 100 == 0:\n","        print(\"Iteration: {}\".format(epoch))\n","\n","    ## -------------------------------------------------------------------------------\n","    ## Compute losses\n","    ## -------------------------------------------------------------------------------\n","    loss = torch.tensor(0.).to(device)\n","    losses = {'Train loss' : None}\n","\n","    if 'implicithjxt0t' in regularizer_type:\n","        if fixing_T_and_0_HJ_loss:\n","            pnts = random_sampler(N, fix_T=True, T0=T).to(device)\n","        else:\n","            pnts = random_sampler(N, dim, T=T).to(device)\n","        pred_sol = network_sol(pnts)\n","        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n","        init_x = pnts[:,1:] - pnts[:,[0]]*grad_pred_sol\n","        init_xt = torch.cat((torch.zeros((N,1)).to(device), init_x), 1)\n","        fwd_loss_implicithj = ((pred_sol - 0.5*pnts[:,[0]]*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(init_xt))**2).mean()\n","\n","        if fixing_T_and_0_HJ_loss:\n","            pnts = random_sampler(N, fix_T=True, T0=0).to(device)\n","        else:\n","            pnts = random_sampler(N, dim, T=T).to(device)\n","        pred_sol = network_sol(pnts)\n","        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n","        terminal_x = pnts[:,1:] + (T - pnts[:,[0]])*grad_pred_sol\n","        terminal_xt = torch.cat((T * torch.ones((N,1)).to(device), 
terminal_x), 1)\n","        bckwd_loss_implicithj = ((pred_sol + 0.5*(T - pnts[:,[0]])*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(terminal_xt))**2).mean()\n","\n","        losses['Implicit HJ loss'] = fwd_loss_implicithj.item() + bckwd_loss_implicithj.item()\n","        loss_implicithj = fwd_loss_implicithj + bckwd_loss_implicithj\n","        loss += regularizer_coord[regularizer_index['implicithjxt0t']] * loss_implicithj\n","        Implicit_HJ_loss_list.append(loss_implicithj.cpu().detach().numpy())\n","\n","    if 'mmd_negnorm_0t' in regularizer_type:\n","        sub_batch_size = int(batch_size / 5)\n","        for i in range(5):\n","            init_data = init_data_list[i]\n","            target_data = target_data_list[i]\n","            init_data_indices_for_0T = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n","            init_data_pnts_for_0T = torch.tensor(init_data[init_data_indices_for_0T,:], dtype=torch.float32, requires_grad=True).to(device)\n","            init_spatialtemporal_pnts_for_0T = torch.tensor(torch.cat((torch.zeros((sub_batch_size,1)).to(device), init_data_pnts_for_0T), 1), requires_grad=True)\n","            pred_sol_for_0T = network_sol(init_spatialtemporal_pnts_for_0T)\n","            transported_pnts_to_T = init_spatialtemporal_pnts_for_0T[:, 1:] + T * gradient(init_spatialtemporal_pnts_for_0T, pred_sol_for_0T)[:,1:]\n","            data_indices_for_0T = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n","            data_pnts_for_0T = torch.tensor(target_data[data_indices_for_0T,:],dtype=torch.float32, requires_grad=True).to(device)\n","            loss_MMD_T = MMD_negnorm(transported_pnts_to_T, data_pnts_for_0T)\n","            if epoch % 10 == 0:\n","                MMD_loss_list_0T.append(loss_MMD_T.item())\n","\n","            data_indices_for_T0 = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n","            
data_pnts_for_T0 = torch.tensor(target_data[data_indices_for_T0,:], dtype=torch.float32, requires_grad=True).to(device)\n","            spatialtemporal_pnts = torch.tensor(torch.cat((T * torch.ones((sub_batch_size,1)).to(device), data_pnts_for_T0), 1), requires_grad=True)\n","            pred_sol_for_T0 = network_sol(spatialtemporal_pnts)\n","            transported_pnts_to_0 = spatialtemporal_pnts[:,1:] - T * gradient(spatialtemporal_pnts, pred_sol_for_T0)[:,1:]\n","            init_data_indices_for_T0 = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n","            init_data_pnts_for_T0 = torch.tensor(init_data[init_data_indices_for_T0,:],dtype=torch.float32, requires_grad=True).to(device)\n","            loss_MMD_0 = MMD_negnorm(transported_pnts_to_0, init_data_pnts_for_T0)\n","            if epoch % 10 == 0:\n","                MMD_loss_list_T0.append(loss_MMD_0.item())\n","            loss_MMD = loss_MMD_0 + loss_MMD_T\n","            loss += regularizer_coord[regularizer_index['mmd_negnorm_0t']] * loss_MMD\n","\n","    # if with_OTloss:\n","    #     # OT loss functional\n","    #     displacement_0T = (transported_pnts_to_T - init_data_pnts_for_0T)/T\n","    #     OT_loss_0T = torch.sum(displacement_0T * displacement_0T, 1).mean() / 2\n","    #     # OT_dist_0T_list.append(OT_loss_0T.cpu().detach().numpy())\n","    #     displacement_T0 = (transported_pnts_to_0 - init_data_pnts_for_T0)/T\n","    #     OT_loss_T0 = torch.sum(displacement_T0 * displacement_T0, 1).mean() / 2\n","    #     # OT_dist_T0_list.append(OT_loss_T0.cpu().detach().numpy())\n","    #     loss += weight_OTloss * (OT_loss_0T + OT_loss_T0)\n","\n","    ## -------------------------------------------------------------------------------\n","    ## Update parameters\n","    ## -------------------------------------------------------------------------------\n","    optimizer.zero_grad()\n","    loss.backward()\n","    optimizer.step()\n","\n","    ## 
-------------------------------------------------------------------------------\n","    ## Compute OT distance. T * \\int |\\nabla u(x)|^2/2 \\rho_0(x) dx\n","    ## -------------------------------------------------------------------------------\n","    # OT_loss_0T = torch.tensor([0.0]).to(device)\n","    # for i in range(5):\n","    #     init_data = init_data_list[i]\n","    #     init_data_indices_for_0T = torch.tensor(np.random.choice(init_data.shape[0], batch_size_OT, False))\n","    #     init_data_pnts_for_0T = torch.tensor(init_data[init_data_indices_for_0T,:], dtype=torch.float32, requires_grad=True)\n","    #     init_spatialtemporal_pnts_for_0T = torch.tensor(torch.cat((torch.zeros((batch_size_OT,1)).to(device), init_data_pnts_for_0T), 1), requires_grad=True).to(device)\n","    #     pred_sol_for_0T = network_sol(init_spatialtemporal_pnts_for_0T)\n","    #     displacement_0T = gradient(init_spatialtemporal_pnts_for_0T, pred_sol_for_0T)[:,1:]\n","    #     OT_loss_0T = OT_loss_0T + T * torch.sum(displacement_0T * displacement_0T, 1).mean() / 2\n","    # OT_dist_0T_list.append(OT_loss_0T.cpu().detach())\n","\n","    # OT_loss_T0 = torch.tensor([0.0]).to(device)\n","    # for i in range(5):\n","    #     target_data = target_data_list[i]\n","    #     data_indices_for_T0 = torch.tensor(np.random.choice(target_data.shape[0], batch_size_OT, False))\n","    #     data_pnts_for_T0 = torch.tensor(target_data[data_indices_for_T0,:], dtype=torch.float32, requires_grad=True)\n","    #     spatialtemporal_pnts = torch.tensor(torch.cat((T * torch.ones((batch_size_OT,1)).to(device), data_pnts_for_T0), 1), requires_grad=True).to(device)\n","    #     pred_sol_for_T0 = network_sol(spatialtemporal_pnts)\n","    #     displacement_T0 = gradient(spatialtemporal_pnts, pred_sol_for_T0)[:,1:]\n","    #     OT_loss_T0 = OT_loss_T0 + T * torch.sum(displacement_T0 * displacement_T0, 1).mean() / 2\n","    # OT_dist_T0_list.append(OT_loss_T0.cpu().detach())\n","\n","    ## 
-------------------------------------------------------------------------------\n","    ## Validation\n","    ## -------------------------------------------------------------------------------\n","    if (epoch+1) % val_frequency == 0:\n","      # accuracy & plots 0 ------> T\n","      plot_latent_samples_list = []\n","      for i in range(5):\n","          init_data = init_data_list_std_for_test[i]\n","          sample_size = init_data.size()[0]\n","          init_data_indices = torch.tensor(np.random.choice(init_data.shape[0], sample_size, False))\n","          init_data_pnts = torch.tensor(init_data[init_data_indices,:], dtype=torch.float32, requires_grad=True)\n","          init_spatialtemporal_pnts = torch.tensor(torch.cat((torch.zeros((sample_size,1)).to(device), init_data_pnts), 1), requires_grad=True).to(device)\n","          pred_sol = network_sol(init_spatialtemporal_pnts)\n","          transported_pnts_to_T = init_spatialtemporal_pnts[:, 1:] + T * gradient(init_spatialtemporal_pnts, pred_sol)[:,1:]\n","          if normalization_type == 'PCA':\n","              recovered_generated_target_latent_samples = recover_latent_samples_using_PCA(transported_pnts_to_T)\n","          else:\n","              recovered_generated_target_latent_samples = recover_latent_samples_using_std(transported_pnts_to_T)\n","          plot_latent_samples_list.append(recovered_generated_target_latent_samples)\n","\n","          target_data = target_data_list[i]\n","          data_indices = torch.tensor(np.random.choice(target_data.shape[0], sample_size, False))\n","          data_pnts = target_data[data_indices,:]\n","\n","          #plot in first 6 PCA directions\n","          digits_string = '{}{}'.format(i, i+5)\n","          if (epoch+1) % plot_detail_ot == 0:\n","              for index in range(6):\n","                  plot_latent_samples_n_OT_map(torch.tensor(data_pnts), torch.tensor(init_data_pnts), transported_pnts_to_T, index, index+1, epoch, normalization_type, 
'0_to_T', digits_string, exp_dir)\n","\n","      plot_img_list = []\n","      for latent_sample in plot_latent_samples_list:\n","          decoded_generated_target = trained_decoder(latent_sample.cpu().detach().numpy())\n","          plot_img_list.append(decoded_generated_target)\n","      if (epoch+1) % plot_detail_mnist == 0:\n","          plot_mnist_init_target(plot_img_list[0], plot_img_list[1], plot_img_list[2], plot_img_list[3], plot_img_list[4], epoch, num_mnist_digits_to_plt, \"[0 to T] generated MNIST digits (conditioned on class)\", '0_to_T', exp_dir)\n","\n","      # check accuracy\n","      ave_accuracy = 0.0\n","      model.eval()\n","      for idx in range(5):\n","          reconstructed_img_vae = plot_img_list[idx]\n","          reconstructed_img_vae = 2.0 * reconstructed_img_vae - 1.0\n","          reconstructed_img_vae = np.array(reconstructed_img_vae)\n","          digit_idx = torch.from_numpy(reconstructed_img_vae)\n","          digit_idx = digit_idx.squeeze()\n","          digit_idx = digit_idx.unsqueeze(1)\n","          accuracy, predicted = test_accuracy_fashion_MNIST(model, digit_idx, idx+5)\n","          print(accuracy)\n","          ave_accuracy += accuracy\n","      ave_accuracy /= 5\n","      print(\"=================================================================================================\")\n","      print(\"[iter {}] average accuracy on transporting 0,1,2,3,4 to 5,6,7,8,9: {}\".format(epoch, ave_accuracy))\n","      accuracy_0T_list.append(ave_accuracy)\n","      print(\"=================================================================================================\")\n","\n","      # accuracy & plots T ------> 0\n","      plot_latent_samples_list = []\n","      for i in range(5):\n","          target_data = target_data_list_std_for_test[i]\n","          sample_size = target_data.size()[0]\n","          target_data_indices = torch.tensor(np.random.choice(target_data.shape[0], sample_size, False))\n","          target_data_pnts 
= torch.tensor(target_data[target_data_indices,:], dtype=torch.float32, requires_grad=True)\n","          target_spatialtemporal_pnts = torch.tensor(torch.cat((T*torch.ones((sample_size,1)).to(device), target_data_pnts), 1), requires_grad=True).to(device)\n","          pred_sol = network_sol(target_spatialtemporal_pnts)\n","          transported_pnts_to_0 = target_spatialtemporal_pnts[:, 1:] - T * gradient(target_spatialtemporal_pnts, pred_sol)[:,1:]\n","          if normalization_type == 'PCA':\n","              recovered_generated_init_latent_samples = recover_latent_samples_using_PCA(transported_pnts_to_0)\n","          else:\n","              recovered_generated_init_latent_samples = recover_latent_samples_using_std(transported_pnts_to_0)\n","          plot_latent_samples_list.append(recovered_generated_init_latent_samples)\n","\n","          init_data = init_data_list[i]\n","          init_indices = torch.tensor(np.random.choice(init_data.shape[0], sample_size, False))\n","          init_pnts = init_data[init_indices,:]\n","          #plot in first 6 PCA directions\n","          digits_string = '{}{}'.format(i+5, i)\n","          if (epoch+1) % plot_detail_ot == 0:\n","              for index in range(6):\n","                  plot_latent_samples_n_OT_map(torch.tensor(init_pnts), torch.tensor(target_data_pnts), transported_pnts_to_0, index, index+1, epoch, normalization_type, 'T_to_0', digits_string, exp_dir)\n","\n","      plot_img_list = []\n","      for latent_sample in plot_latent_samples_list:\n","          decoded_generated_init = trained_decoder(latent_sample.cpu().detach().numpy())\n","          plot_img_list.append(decoded_generated_init)\n","      if (epoch+1) % plot_detail_mnist == 0:\n","          plot_mnist_init_target(plot_img_list[0], plot_img_list[1], plot_img_list[2], plot_img_list[3], plot_img_list[4], epoch, num_mnist_digits_to_plt, \"[T to 0] generated MNIST digits (conditioned on class)\", 'T_to_0', exp_dir)\n","\n","      # check 
accuracy\n","      ave_accuracy = 0.0\n","      model.eval()\n","      for idx in range(5):\n","          reconstructed_img_vae = plot_img_list[idx]\n","          reconstructed_img_vae = 2.0 * reconstructed_img_vae - 1.0\n","          reconstructed_img_vae = np.array(reconstructed_img_vae)\n","          digit_idx = torch.from_numpy(reconstructed_img_vae)\n","          digit_idx = digit_idx.squeeze()\n","          digit_idx = digit_idx.unsqueeze(1)\n","          accuracy, predicted = test_accuracy_fashion_MNIST(model, digit_idx, idx)\n","          print(accuracy)\n","          ave_accuracy += accuracy\n","      ave_accuracy /= 5\n","      print(\"=================================================================================================\")\n","      print(\"[iter {}] average accuracy on transporting 5,6,7,8,9 to 0,1,2,3,4: {}\".format(epoch, ave_accuracy))\n","      accuracy_T0_list.append(ave_accuracy)\n","      print(\"=================================================================================================\")\n","\n","      if save_model:\n","            torch.save({'state_dict': network_sol.state_dict(),}, os.path.join(exp_dir, f'model/weight_sol_ep{epoch}.pth'))\n","\n","plt.figure(figsize=(10,10))\n","plt.plot(MMD_loss_list_0T, label=\"OT from 0 to T\")\n","plt.plot(MMD_loss_list_T0, label=\"OT from T to 0\")\n","plt.legend( fontsize = 12 )\n","plt.title(\"MMD - iter\", fontsize=18)\n","plt.savefig(os.path.join(exp_dir,f'MMD_loss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","plt.close()\n","\n","# plt.figure(figsize=(10,10))\n","# plt.plot(torch.tensor(OT_dist_0T_list).cpu())\n","# plt.title(\"OT_dist_0toT - iter\", fontsize=18)\n","# plt.savefig(os.path.join(exp_dir,f'OT_dist_0T_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","# plt.close()\n","\n","# plt.figure(figsize=(10,10))\n","# plt.plot(torch.tensor(OT_dist_T0_list).cpu())\n","# plt.title(\"OT_dist_Tto0 - iter\", fontsize=18)\n","# 
plt.savefig(os.path.join(exp_dir,f'OT_dist_T0_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","# plt.close()\n","\n","plt.figure(figsize=(10,10))\n","plt.plot(Implicit_HJ_loss_list)\n","plt.title(\"Implicit HJ loss - iter\", fontsize=18)\n","plt.savefig(os.path.join(exp_dir,f'HJloss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","plt.close()\n","\n","plt.figure(figsize=(10,10))\n","plt.plot(accuracy_0T_list, label='Accuracy @ T')\n","plt.plot(accuracy_T0_list, label='Accuracy @ 0')\n","plt.title(\"Average_accuracy_for_OT\", fontsize=18)\n","plt.legend(fontsize=15)\n","plt.savefig(os.path.join(exp_dir,f'accuracy'), dpi=200, bbox_inches='tight', pad_inches=0)\n","plt.close()\n","\n"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["L961ddlbCWcl","dyDUTgV2wfjG","_WmQYYP-gcpj","tKNgKlzzsnQU","IBGEhneOs0MU","SWq4QpZHrv66"],"gpuType":"T4","machine_shape":"hm","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}