{"cells":[{"cell_type":"markdown","source":["\n","This code was written for and run on Google Colab.\n","\n","\n","This Python notebook covers task 3: inter-class transport between Fashion-MNIST and MNIST.\n","\n","\n"],"metadata":{"id":"syiVWVlCchCY"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Z42NR37-LeQU","executionInfo":{"status":"ok","timestamp":1758768377316,"user_tz":240,"elapsed":19954,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}},"outputId":"c1acab89-cb95-4d31-904e-047ae01b48a6"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ul47PBw-QQ18","outputId":"e0eb2760-7a63-4cff-f292-33d54435fa44","executionInfo":{"status":"ok","timestamp":1758769762353,"user_tz":240,"elapsed":43,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["/content/drive/MyDrive/FashionMNIST_MNIST_OTHJ\n"]}],"source":["# Work from the project folder: the imports cell below loads `utils.general`, and later cells use relative paths.\n","# NOTE(review): the saved output shows a different folder (FashionMNIST_MNIST_OTHJ); confirm this path matches your Drive.\n","%cd /content/drive/MyDrive/OTHJ_FashionMNIST_MNIST"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"cEiUihg03b8r"},"outputs":[],"source":["# @title imports\n","import os, sys\n","import numpy as np\n","import tensorflow as tf\n","import torch\n","import torchvision\n","from torchvision import transforms, datasets\n","import torchvision.models as models\n","from torchvision.transforms import Compose, Resize, Normalize, ToTensor\n","from torch.utils.data.sampler import SubsetRandomSampler\n","import matplotlib.pyplot as plt\n","import pickle\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","import pandas as pd\n","from matplotlib.colors import ListedColormap\n","import seaborn as sns\n","import random\n","from 
PIL import Image\n","\n","from utils.general import mkdir_ifnotexists\n","\n","# Keras - Deep Learning API\n","import keras\n","from keras.datasets import mnist\n","from keras.datasets import fashion_mnist\n","from keras.layers import (\n","    Conv2D, Conv2DTranspose,\n","    Input, Flatten, Dense,\n","    Lambda, Reshape\n",")\n","from keras.models import Model\n","from keras.callbacks import (\n","    EarlyStopping, ModelCheckpoint\n",")\n","from keras import backend as K\n","from sklearn.manifold import TSNE\n","from keras.metrics import MeanSquaredError\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"VD-lNU7RlpJO","cellView":"form"},"outputs":[],"source":["# @title setup basic parameter\n","\n","dir = os.getcwd()\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","Dim = 35\n","beta_recon = 100.\n","beta_KL = 0.1\n","model_and_data_savepath = 'model_data_d{}_beta100_01'.format(Dim)\n","model0_name = os.path.join(model_and_data_savepath, 'best_model_dim_{}_fmnist.h5'.format(Dim))\n","model1_name = os.path.join(model_and_data_savepath, 'best_model_dim_{}_mnist.h5'.format(Dim))\n","example_savepath = 'example_d{}_beta100_01'.format(Dim)\n","rho0_dataset = 'fmnist'\n","rho1_dataset = 'mnist'\n","rho0_name = 'Fashion MNIST'\n","rho1_name = 'MNIST'"]},{"cell_type":"markdown","metadata":{"id":"L961ddlbCWcl"},"source":["\n","\n","---\n","\n","## Load the encoder-decoder, compute and save latent data (rho0: Fashion-MNIST)\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"9iQoJ28TCWcm"},"source":["\n","Define the VAE (CNN-based encoder and decoder)\n","\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"MFxWd7lfCWcm"},"outputs":[],"source":["# @title VAE Loss\n","\n","import keras\n","from keras import layers\n","from keras import backend as K\n","\n","class VAELossLayer(layers.Layer):\n","    \"\"\"\n","    Custom Keras layer that calculates the loss (reconstruction + KL divergence)\n",
"    of the Variational AutoEncoder (VAE) and registers it with `add_loss`.\n","\n","    The layer returns its first input unchanged; it exists only to attach the loss.\n","\n","    Args:\n","        beta:    weight of the reconstruction term (default 100).\n","        beta_kl: weight of the KL term (default 1.0, i.e. the unweighted KL term).\n","    \"\"\"\n","\n","    def __init__(self, beta=100, beta_kl=1.0, **kwargs):\n","        super().__init__(**kwargs)\n","        self.beta = beta\n","        self.beta_kl = beta_kl\n","\n","    def get_config(self):\n","        # Serialize the loss coefficients explicitly so a saved model restores them on load.\n","        config = super().get_config()\n","        config.update({\"beta\": self.beta, \"beta_kl\": self.beta_kl})\n","        return config\n","\n","    def compute_output_shape(self, input_shape):\n","        # `call` passes `original_input` through, so the output shape is that of inputs[0].\n","        return input_shape[0]\n","\n","    def calculate_loss(self, original_input, reconstructed_output, mu, sigma):\n","        \"\"\"\n","        Calculates the VAE loss: beta * reconstruction loss + beta_kl * KL-divergence loss.\n","\n","        `sigma` is treated as the log-variance of the latent Gaussian. The KL term is\n","        averaged (not summed) over the latent dimensions and has no 0.5 factor.\n","        Returns the mean of the per-sample loss over the batch.\n","        \"\"\"\n","        # Flatten to (batch, features) with a reshape instead of building a new\n","        # Flatten layer on every call.\n","        original_input = tf.reshape(original_input, [tf.shape(original_input)[0], -1])\n","        reconstructed_output = tf.reshape(reconstructed_output, [tf.shape(reconstructed_output)[0], -1])\n","        # Per-sample sum of squared errors over all pixels, scaled by beta.\n","        reconstruction_loss = tf.reduce_sum(tf.square(original_input - reconstructed_output), axis=1)\n","        reconstruction_loss = self.beta * reconstruction_loss\n","        kl_loss = - tf.reduce_mean(1 + sigma - tf.square(mu) - tf.exp(sigma), axis=-1)\n","        return tf.reduce_mean(reconstruction_loss + self.beta_kl * kl_loss)\n","\n","    def call(self, inputs):\n","        \"\"\"\n","        Computes the loss, adds it to the layer's losses, and returns `original_input`.\n","\n","        Expects `inputs` to be (original_input, reconstructed_output, mu, sigma).\n","        \"\"\"\n","        original_input, reconstructed_output, mu, sigma = inputs\n","\n","        loss = self.calculate_loss(original_input, reconstructed_output, mu, sigma)\n","        self.add_loss(loss)\n","        return original_input\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","colab":{"base_uri":"https://localhost:8080/","height":740},"collapsed":true,"id":"EgvxXsOACWcm","outputId":"38860a39-0a6d-46fd-d4d4-8a6aefe7c874","executionInfo":{"status":"ok","timestamp":1758768395088,"user_tz":240,"elapsed":1480,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Output shape of encoder: (None, 
64)\n"]},{"output_type":"display_data","data":{"text/plain":["\u001b[1mModel: \"encoder_model\"\u001b[0m\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"encoder_model\"</span>\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type)       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape     \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m   Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to     \u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n","│ encoder_input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m) │          \u001b[38;5;34m0\u001b[0m │ -                 │\n","│ (\u001b[38;5;33mInputLayer\u001b[0m)        │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d (\u001b[38;5;33mConv2D\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m,    │      \u001b[38;5;34m1,280\u001b[0m │ encoder_input_la… │\n","│                     │ \u001b[38;5;34m128\u001b[0m)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m,    │    \u001b[38;5;34m147,584\u001b[0m │ conv2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]      │\n","│                     │ \u001b[38;5;34m128\u001b[0m)              │            │                   
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m73,792\u001b[0m │ conv2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_3 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_4 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_5 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_6 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     
\u001b[38;5;34m36,928\u001b[0m │ conv2d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ flatten (\u001b[38;5;33mFlatten\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)     │          \u001b[38;5;34m0\u001b[0m │ conv2d_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense (\u001b[38;5;33mDense\u001b[0m)       │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)        │    \u001b[38;5;34m802,880\u001b[0m │ flatten[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]     │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (\u001b[38;5;33mDense\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │      \u001b[38;5;34m2,275\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │      \u001b[38;5;34m2,275\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n","│ (\u001b[38;5;33mDense\u001b[0m)             │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (\u001b[38;5;33mLambda\u001b[0m)          │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │          \u001b[38;5;34m0\u001b[0m │ latent_mu[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],  │\n","│                     │                   │            │ latent_sigma[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m 
│\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: bold\"> Layer (type)        </span>┃<span style=\"font-weight: bold\"> Output Shape      </span>┃<span style=\"font-weight: bold\">    Param # </span>┃<span style=\"font-weight: bold\"> Connected to      </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n","│ encoder_input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                 │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │      <span style=\"color: #00af00; text-decoration-color: #00af00\">1,280</span> │ encoder_input_la… │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_1 (<span 
style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │    <span style=\"color: #00af00; text-decoration-color: #00af00\">147,584</span> │ conv2d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]      │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">73,792</span> │ conv2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ 
conv2d_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_5 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_6 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ flatten (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)     │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_6[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)       │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)        │    <span style=\"color: #00af00; text-decoration-color: #00af00\">802,880</span> │ flatten[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]     
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │      <span style=\"color: #00af00; text-decoration-color: #00af00\">2,275</span> │ dense[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │      <span style=\"color: #00af00; text-decoration-color: #00af00\">2,275</span> │ dense[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)             │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Lambda</span>)          │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ latent_mu[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],  │\n","│                     │                   │            │ latent_sigma[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: 
#00af00; text-decoration-color: #00af00\">…</span> │\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,177,798\u001b[0m (4.49 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,177,798</span> (4.49 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,177,798\u001b[0m (4.49 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,177,798</span> (4.49 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n","</pre>\n"]},"metadata":{}},{"output_type":"stream","name":"stdout","text":["None\n"]}],"source":["# @title Encoder\n","latent_space_dim = Dim #  50\n","\n","input_dimensions = (28, 28, 1)\n","\n","# Input Layer\n","# Defines the shape of the input data for the encoder.\n","encoder_input_layer = Input(shape=input_dimensions, name='encoder_input_layer')\n","\n","# Convolution Layers\n","# Applies convolution operations to extract features from the input 
image.\n","encoder_layer = Conv2D(128, 3, padding='same', activation='relu')(encoder_input_layer)\n","encoder_layer = Conv2D(128, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', strides=2, activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","\n","enc_shape = tf.keras.backend.int_shape(encoder_layer)\n","\n","# Flattening Layer\n","# Converts the 3D output of convolution layers into a 1D tensor for dense layers.\n","encoder_layer = Flatten()(encoder_layer)\n","\n","# Dense Layer\n","# A fully connected layer that combines extracted features and performs further learning.\n","encoder_layer = Dense(64, activation='relu')(encoder_layer)\n","\n","# Output Layer for Encoder\n","# Prepares the encoded data for transition into the latent space, approximating a probability distribution.\n","mu_layer = Dense(latent_space_dim, name='latent_mu')(encoder_layer)\n","sigma_layer = Dense(latent_space_dim, name='latent_sigma')(encoder_layer)\n","\n","# Output shape of the last dense layer, derived from the graph instead of hard-coding (None, 64)\n","encoder_output_shape = tf.keras.backend.int_shape(encoder_layer)\n","print(f\"Output shape of encoder: {encoder_output_shape}\")\n","\n","\n","### Implementing the Reparameterization Trick\n","def sample_z(args):\n","    \"\"\"\n","    Generate a sample from the Gaussian distribution defined by args=(mu, sigma).\n","\n","    Args:\n","    mu_layer:    The mean of the Gaussian distribution.\n","    sigma_layer: The log-variance of the Gaussian distribution\n","                 (the standard deviation used below is exp(sigma_layer / 2)).\n","\n","    Returns:\n","    A sample from the Gaussian distribution, shaped like mu_layer.\n","    \"\"\"\n","    mu_layer, sigma_layer = args\n",
"    batch_size = tf.shape(mu_layer)[0]\n","    dim = tf.shape(mu_layer)[1]\n","\n","    # Draw epsilon ~ N(0, I) with shape (batch_size, dim)\n","    epsilon = tf.random.normal(shape=(batch_size, dim))\n","\n","    # Reparameterization: z = mu + std * epsilon, with std = exp(log_var / 2)\n","    return mu_layer + tf.exp(sigma_layer / 2) * epsilon\n","\n","# Creating the 'z' layer using the Lambda layer to apply the reparameterization trick\n","z = Lambda(sample_z, output_shape=(latent_space_dim,), name='z')([mu_layer, sigma_layer])\n","\n","# Building the encoder model\n","encoder_model = Model(encoder_input_layer, [mu_layer, sigma_layer, z], name='encoder_model')\n","\n","# Display the model summary (summary() prints it and returns None, so don't wrap it in print)\n","encoder_model.summary()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","colab":{"base_uri":"https://localhost:8080/","height":625},"collapsed":true,"id":"ZY3QQRABCWcm","outputId":"861915f4-b32d-4ece-e3f0-c96c69db2641","executionInfo":{"status":"ok","timestamp":1758768395203,"user_tz":240,"elapsed":114,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"display_data","data":{"text/plain":["\u001b[1mModel: \"decoder_model\"\u001b[0m\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"decoder_model\"</span>\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m 
\u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)             │             \u001b[38;5;34m0\u001b[0m │\n","│ (\u001b[38;5;33mInputLayer\u001b[0m)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_1 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)          │       \u001b[38;5;34m451,584\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ reshape (\u001b[38;5;33mReshape\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │             \u001b[38;5;34m0\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_1              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_2              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, 
\u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_3              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_4              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_5              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │        \u001b[38;5;34m73,856\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_6              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │       \u001b[38;5;34m409,728\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_7              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m)      │    
     \u001b[38;5;34m3,201\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)             │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)          │       <span style=\"color: #00af00; text-decoration-color: #00af00\">451,584</span> │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ reshape (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Reshape</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; 
text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose                │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_1              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_2              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; 
text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_3              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_4              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_5              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │        <span style=\"color: #00af00; 
text-decoration-color: #00af00\">73,856</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_6              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │       <span style=\"color: #00af00; text-decoration-color: #00af00\">409,728</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_7              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>)      │         <span style=\"color: #00af00; text-decoration-color: #00af00\">3,201</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,123,009\u001b[0m (4.28 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: 
#00af00; text-decoration-color: #00af00\">1,123,009</span> (4.28 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,123,009\u001b[0m (4.28 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,123,009</span> (4.28 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n","</pre>\n"]},"metadata":{}}],"source":["# @title Decoder\n","\n","# Decoder Input Layer\n","decoder_input_layer = Input(shape=(latent_space_dim,), name='decoder_input_layer')\n","\n","# Initial Dense Layer\n","# The number of units is derived from the last convolutional layer of the encoder\n","num_units = np.prod(enc_shape[1:]) # 14 * 14 * 64\n","decoder_dense_layer = Dense(num_units, activation='relu')(decoder_input_layer)\n","\n","# Reshaping Layer\n","# The dense layer's output is reshaped to match the last convolutional layer's output shape in the encoder\n","reshape_dims = (14, 14, 64)\n","decoder_layer = Reshape(reshape_dims)(decoder_dense_layer)\n","\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, 
padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 5, padding='same', activation='relu')(decoder_layer)\n","decoder_out_layer = Conv2DTranspose(1, 5, padding='same', activation='relu')(decoder_layer)\n","\n","# Building the Decoder Model\n","decoder_model = Model(decoder_input_layer, decoder_out_layer, name='decoder_model')\n","\n","decoder_model.summary()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"0sJfV7oPCWcm","outputId":"a6704b4c-2b45-4acc-b605-eb98ebb67ff4","executionInfo":{"status":"ok","timestamp":1758768395209,"user_tz":240,"elapsed":5,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["The type  of 'reconstructed_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'reconstructed_output' is: (None, 28, 28, 1)\n","The type of 'vae_loss_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'vae_loss_output' is: (None, 28, 28, 1)\n"]}],"source":["# @title set VAE model\n","from tensorflow.keras.models import load_model\n","\n","\n","# Reconstructed Output from Decoder\n","reconstructed_output = decoder_model(z)\n","\n","reconstructed_model = Model(encoder_input_layer, reconstructed_output, name='reconstructed_model')\n","\n","print(f\"The type  of 'reconstructed_output' is: {type(reconstructed_output)}\")\n","print(f\"The shape of 'reconstructed_output' is: {reconstructed_output.shape}\")\n","\n","\n","# Adding the loss computation layer to the model\n","vae_loss_output = VAELossLayer()([encoder_input_layer, reconstructed_output, mu_layer, sigma_layer])\n","\n","print(f\"The type 
of 'vae_loss_output' is: {type(vae_loss_output)}\")\n","print(f\"The shape of 'vae_loss_output' is: {vae_loss_output.shape}\")\n","\n","\n","# loaded_model = keras.saving.load_model(\"best_model_dim_15.h5\")\n","final_vae_model = Model(encoder_input_layer, vae_loss_output, name='vae_model')\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"9BbmQa7eCWcn","outputId":"80f5e536-d49a-4c83-b09a-36d2782f5f0b","cellView":"form","executionInfo":{"status":"ok","timestamp":1758768397848,"user_tz":240,"elapsed":2639,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"]}],"source":["# @title load trained encoder and decoder\n","from tensorflow.keras.models import load_model\n","\n","trained_model0 = load_model(model0_name, custom_objects={\"sample_z\": sample_z, \"VAELossLayer\": VAELossLayer})\n","\n","# Encoder_mu\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder0_input = trained_model0.input\n","mu0_layer = trained_model0.get_layer('latent_mu')\n","mu0_output = mu0_layer.output\n","\n","# Recreate the encoder model\n","encoder0_mu = Model(inputs=encoder0_input, outputs=mu0_output)\n","\n","# Encoder_sigma\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder0_input = trained_model0.input\n","sigma0_layer = trained_model0.get_layer('latent_sigma')\n","sigma0_output = sigma0_layer.output\n","# Recreate the encoder model\n","encoder0_sigma = Model(inputs=encoder0_input, outputs=sigma0_output)\n","\n","# Decoder\n","trained_decoder0 = 
trained_model0.get_layer('decoder_model')\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"8qxGz_C5zkJb","outputId":"413b515b-49fa-4090-d120-7935d4b255f8","cellView":"form","executionInfo":{"status":"ok","timestamp":1758768439625,"user_tz":240,"elapsed":41769,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n","\u001b[1m29515/29515\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n","Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n","\u001b[1m26421880/26421880\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 0us/step\n","Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n","\u001b[1m5148/5148\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n","Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n","\u001b[1m4422102/4422102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 0us/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 
3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 18ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m0s\u001b[0m 11ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n"]}],"source":["# @title Load fMNIST dataset & encode\n","\n","\n","# load the fashion mnist data\n","(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()\n","\n","# Group images by class label: list_of_images[k] holds every training image of class k\n","# (6000 per class for Fashion-MNIST).\n","list_of_images = []\n","\n","for i in range(10):\n","  list_of_images.append(x_train[y_train == i])\n","\n","# Encode each class and save one sampled latent code per image:\n","# z = mu + exp(log_var / 2) * eps with eps ~ N(0, I); latent_sigma is used as the log-variance.\n","for i in range(10):\n","    data = list_of_images[i]/255.\n","    data = tf.reshape(data, (list_of_images[i].shape[0], 28, 28, 1))\n","    encoded0_mu_all = encoder0_mu.predict(data)\n","    encoded0_sigma_all = encoder0_sigma.predict(data)\n","    epsilon0_all = tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded0_images = encoded0_mu_all + tf.exp(encoded0_sigma_all / 2) * epsilon0_all\n","    numpy_encoded0_images = encoded0_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_fmnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded0_images)\n","\n","\n","list_of_images_for_test = []\n","\n","for i in range(10):\n","    list_of_images_for_test.append(x_test[y_test == i]) # (1000, 28, 28): 1000 test images per class\n","\n","for i in range(10):\n","    data = list_of_images_for_test[i]/255.\n","    data = tf.reshape(data, (list_of_images_for_test[i].shape[0], 28, 28, 1))\n","    encoded0_mu_all = encoder0_mu.predict(data)\n","    encoded0_sigma_all = encoder0_sigma.predict(data)\n","    epsilon0_all = tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded0_images = encoded0_mu_all + tf.exp(encoded0_sigma_all / 2) * epsilon0_all\n","    numpy_encoded0_images = encoded0_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_for_test_fmnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, 
numpy_encoded0_images)\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"QxDHnElk4HV6"},"outputs":[],"source":["# @title save & load all digits\n","\n","\n","# Reload the cached fMNIST latent codes (one .npy file per class) as torch tensors on `device`,\n","# keeping the training and test splits in separate lists.\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","tensor_encoded0_list=[]\n","for i in range(10):\n","    encoded0_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_fmnist.npy'.format(i))\n","    encoded0_x = np.load(encoded0_data_path)\n","    tensor_encoded0_list.append(torch.tensor(encoded0_x).to(device))\n","\n","\n","tensor_encoded0_for_test_list=[]\n","for i in range(10):\n","    encoded0_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_for_test_fmnist.npy'.format(i))\n","    encoded0_x_for_test = np.load(encoded0_data_path)\n","    # Test codes go into the test list so the training list keeps exactly one tensor per class.\n","    tensor_encoded0_for_test_list.append(torch.tensor(encoded0_x_for_test).to(device))\n","\n"]},{"cell_type":"markdown","metadata":{"id":"AyQbnwjpGaLb"},"source":["\n","\n","---\n","\n","## Load the decoder-encoder, compute and save latent data (rho1:  MNIST)\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"o6ZZj5BvGaLc"},"source":["\n","VAE (using cnn)\n","\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"4ws2kdjzGaLc"},"outputs":[],"source":["# @title VAE Loss\n","\n","import keras\n","from keras import layers\n","from keras import backend as K\n","\n","class VAELossLayer(layers.Layer):\n","    \"\"\"\n","    Custom Keras layer that calculates the loss (reconstruction + KL divergence)\n","    of the Variational AutoEncoder (VAE)\n","    \"\"\"\n","\n","    def __init__(self, beta=100, **kwargs):\n","        super().__init__(**kwargs)\n","        self.beta = beta\n","\n","    def compute_output_shape(self, input_shape):\n","        return 
input_shape[0]  # or return ()\n","\n","    def calculate_loss(self, original_input, reconstructed_output, mu, sigma):\n","        \"\"\"\n","        Calculates VAE loss, which is 
the sum of the reconstruction loss and KL-divergence loss\n","        \"\"\"\n","        original_input = tf.keras.layers.Flatten(data_format='channels_last')(original_input)\n","\n","        reconstructed_output = tf.keras.layers.Flatten(data_format='channels_last')(reconstructed_output)\n","        print(original_input.shape) #(32, 784)\n","        print(reconstructed_output.shape) #(32, 784)\n","        reconstruction_loss = tf.reduce_sum(tf.square(original_input - reconstructed_output), axis=[1])\n","        reconstruction_loss = self.beta * reconstruction_loss\n","        kl_loss = - tf.reduce_mean(1 + sigma - tf.square(mu) - tf.exp(sigma), axis=-1)\n","        return tf.reduce_mean(reconstruction_loss + kl_loss)\n","\n","    def call(self, inputs):\n","        \"\"\"\n","        Computes the loss and adds it to the layer's losses\n","        \"\"\"\n","        original_input, reconstructed_output, mu, sigma = inputs\n","\n","        loss = self.calculate_loss(original_input, reconstructed_output, mu, sigma)\n","        self.add_loss(loss)\n","        return original_input\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"y_2LoaqGGaLd","colab":{"base_uri":"https://localhost:8080/","height":740},"outputId":"b2d4b024-2e99-44c2-dff7-e71b2302e82f","executionInfo":{"status":"ok","timestamp":1758768439822,"user_tz":240,"elapsed":92,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Output shape of encoder: (None, 64)\n"]},{"output_type":"display_data","data":{"text/plain":["\u001b[1mModel: \"encoder_model\"\u001b[0m\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: 
\"encoder_model\"</span>\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type)       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape     \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m   Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to     \u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n","│ encoder_input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m) │          \u001b[38;5;34m0\u001b[0m │ -                 │\n","│ (\u001b[38;5;33mInputLayer\u001b[0m)        │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_7 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m,    │      \u001b[38;5;34m3,328\u001b[0m │ encoder_input_la… │\n","│                     │ \u001b[38;5;34m128\u001b[0m)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_8 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m,    │    \u001b[38;5;34m147,584\u001b[0m │ conv2d_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m128\u001b[0m)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_9 (\u001b[38;5;33mConv2D\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m73,792\u001b[0m │ 
conv2d_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_10 (\u001b[38;5;33mConv2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_9[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]    │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_11 (\u001b[38;5;33mConv2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]   │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_12 (\u001b[38;5;33mConv2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_11[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]   │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_13 (\u001b[38;5;33mConv2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m,    │     \u001b[38;5;34m36,928\u001b[0m │ conv2d_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]   │\n","│                     │ \u001b[38;5;34m64\u001b[0m)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ 
flatten_1 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)     │          \u001b[38;5;34m0\u001b[0m │ conv2d_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense_2 (\u001b[38;5;33mDense\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)        │    \u001b[38;5;34m802,880\u001b[0m │ flatten_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (\u001b[38;5;33mDense\u001b[0m)   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │      \u001b[38;5;34m2,275\u001b[0m │ dense_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]     │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │      \u001b[38;5;34m2,275\u001b[0m │ dense_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]     │\n","│ (\u001b[38;5;33mDense\u001b[0m)             │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (\u001b[38;5;33mLambda\u001b[0m)          │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)        │          \u001b[38;5;34m0\u001b[0m │ latent_mu[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],  │\n","│                     │                   │            │ latent_sigma[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: 
bold\"> Layer (type)        </span>┃<span style=\"font-weight: bold\"> Output Shape      </span>┃<span style=\"font-weight: bold\">    Param # </span>┃<span style=\"font-weight: bold\"> Connected to      </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n","│ encoder_input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                 │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_7 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │      <span style=\"color: #00af00; text-decoration-color: #00af00\">3,328</span> │ encoder_input_la… │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_8 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>,    │    <span style=\"color: #00af00; 
text-decoration-color: #00af00\">147,584</span> │ conv2d_7[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)              │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_9 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">73,792</span> │ conv2d_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_10 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_9[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]    │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   
│\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_11 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_10[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]   │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_12 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_11[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]   │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ conv2d_13 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: 
#00af00\">14</span>,    │     <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_12[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]   │\n","│                     │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)               │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ flatten_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)     │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_13[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)        │    <span style=\"color: #00af00; text-decoration-color: #00af00\">802,880</span> │ flatten_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_mu (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │      <span style=\"color: #00af00; text-decoration-color: 
#00af00\">2,275</span> │ dense_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]     │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ latent_sigma        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │      <span style=\"color: #00af00; text-decoration-color: #00af00\">2,275</span> │ dense_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]     │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)             │                   │            │                   │\n","├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n","│ z (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Lambda</span>)          │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)        │          <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ latent_mu[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],  │\n","│                     │                   │            │ latent_sigma[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n","└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,179,846\u001b[0m (4.50 MB)\n"],"text/html":["<pre 
style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,179,846</span> (4.50 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,179,846\u001b[0m (4.50 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,179,846</span> (4.50 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n","</pre>\n"]},"metadata":{}},{"output_type":"stream","name":"stdout","text":["None\n"]}],"source":["# @title Encoder\n","latent_space_dim = Dim #  50\n","\n","input_dimensions = (28, 28, 1)\n","\n","# Input Layer\n","# Defines the shape of the input data for the encoder.\n","encoder_input_layer = Input(shape=input_dimensions, name='encoder_input_layer')\n","\n","# Convolution Layers\n","# Applies convolution operations to extract features from the input image.\n","encoder_layer = Conv2D(128, 5, padding='same', activation='relu')(encoder_input_layer)\n","encoder_layer = Conv2D(128, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', strides=2, activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', 
activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","\n","enc_shape = tf.keras.backend.int_shape(encoder_layer)\n","\n","# Flattening Layer\n","# Converts the 3D output of convolution layers into a 1D tensor for dense layers.\n","encoder_layer = Flatten()(encoder_layer)\n","\n","# Dense Layer\n","# A fully connected layer that combines extracted features and performs further learning.\n","encoder_layer = Dense(64, activation='relu')(encoder_layer)\n","\n","# Output Layer for Encoder\n","# Prepares the encoded data for transition into the latent space, approximating a probability distribution.\n","mu_layer = Dense(latent_space_dim, name='latent_mu')(encoder_layer)\n","sigma_layer = Dense(latent_space_dim, name='latent_sigma')(encoder_layer)\n","\n","# Output shape of the last dense layer, read from the graph instead of hard-coded so the\n","# printed value stays correct if the Dense width changes (the decoder cell uses enc_shape, not this).\n","encoder_output_shape = tf.keras.backend.int_shape(encoder_layer)  # (None, 64)\n","print(f\"Output shape of encoder: {encoder_output_shape}\")\n","\n","\n","### Implementing the Reparameterization Trick\n","def sample_z(args):\n","    \"\"\"\n","    Generate a sample from the Gaussian distribution defined by args=(mu, sigma).\n","\n","    Args:\n","    mu_layer:    The mean of the Gaussian distribution.\n","    sigma_layer: The log-variance log(sigma^2) of the Gaussian distribution;\n","                 the standard deviation used below is exp(sigma_layer / 2).\n","\n","    Returns:\n","    A sample mu + exp(sigma_layer / 2) * epsilon, with epsilon ~ N(0, I).\n","    \"\"\"\n","    mu_layer, sigma_layer = args\n","    # batch_size = K.shape(mu_layer)[0]\n","    batch_size = tf.shape(mu_layer)[0]\n","    # dim = K.int_shape(mu_layer)[1]\n","    dim = tf.shape(mu_layer)[1]\n","\n","    # Generate a random sample from a standard normal distribution with the same shape\n","    # epsilon = K.random_normal(shape=(batch_size, dim)).\n","    epsilon = 
tf.random.normal(shape=(batch_size, dim))\n","\n","    # Scale and shift the sample by mu and sigma\n","    return mu_layer + tf.exp(sigma_layer / 2) * epsilon\n","\n","\n","# Creating the 'z' layer using the Lambda layer to apply the reparameterization trick\n","z = Lambda(sample_z, output_shape=(latent_space_dim,), name='z')([mu_layer, sigma_layer])\n","\n","# Building the encoder model\n","encoder_model = Model(encoder_input_layer, [mu_layer, sigma_layer, z], name='encoder_model')\n","\n","# # Save the model for future use\n","# save_path=os.getcwd()\n","# encoder_model.save(save_path + '/encoder_model_MNIST_dim_15.h5')\n","\n","# Display the model summary\n","print(encoder_model.summary())\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"wYP4_rY8GaLd","colab":{"base_uri":"https://localhost:8080/","height":625},"outputId":"a4eaa92b-52b0-477e-ebad-5ec45b5e4adc","executionInfo":{"status":"ok","timestamp":1758768439931,"user_tz":240,"elapsed":106,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"display_data","data":{"text/plain":["\u001b[1mModel: \"decoder_model\"\u001b[0m\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"decoder_model\"</span>\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m)          
   │             \u001b[38;5;34m0\u001b[0m │\n","│ (\u001b[38;5;33mInputLayer\u001b[0m)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_3 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12544\u001b[0m)          │       \u001b[38;5;34m451,584\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ reshape_1 (\u001b[38;5;33mReshape\u001b[0m)             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │             \u001b[38;5;34m0\u001b[0m │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_8              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_9              │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_10             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_11             │ 
(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_12             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m64\u001b[0m)     │        \u001b[38;5;34m36,928\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_13             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │        \u001b[38;5;34m73,856\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_14             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m128\u001b[0m)    │       \u001b[38;5;34m409,728\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_15             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m)      │         \u001b[38;5;34m3,201\u001b[0m │\n","│ (\u001b[38;5;33mConv2DTranspose\u001b[0m)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n"],"text/html":["<pre 
style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n","┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n","│ decoder_input_layer             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">35</span>)             │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)                    │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ dense_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12544</span>)          │       <span style=\"color: #00af00; text-decoration-color: #00af00\">451,584</span> │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ reshape_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Reshape</span>)             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> 
│\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_8              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_9              │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_10             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               
│\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_11             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_12             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)     │        <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_13             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │        <span style=\"color: #00af00; text-decoration-color: #00af00\">73,856</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               
│\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_14             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)    │       <span style=\"color: #00af00; text-decoration-color: #00af00\">409,728</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","├─────────────────────────────────┼────────────────────────┼───────────────┤\n","│ conv2d_transpose_15             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>)      │         <span style=\"color: #00af00; text-decoration-color: #00af00\">3,201</span> │\n","│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2DTranspose</span>)               │                        │               │\n","└─────────────────────────────────┴────────────────────────┴───────────────┘\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,123,009\u001b[0m (4.28 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,123,009</span> (4.28 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Trainable params: 
\u001b[0m\u001b[38;5;34m1,123,009\u001b[0m (4.28 MB)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,123,009</span> (4.28 MB)\n","</pre>\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"],"text/html":["<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n","</pre>\n"]},"metadata":{}}],"source":["# @title Decoder\n","\n","# Decoder Input Layer\n","decoder_input_layer = Input(shape=(latent_space_dim,), name='decoder_input_layer')\n","\n","# Initial Dense Layer\n","# The number of units is derived from the last convolutional layer of the encoder\n","num_units = np.prod(enc_shape[1:]) # 14 * 14 * 64\n","decoder_dense_layer = Dense(num_units, activation='relu')(decoder_input_layer)\n","\n","# Reshaping Layer\n","# The dense layer's output is reshaped to match the last convolutional layer's output shape in the encoder.\n","# Derived from enc_shape (like num_units above) so the two can never disagree.\n","reshape_dims = enc_shape[1:]  # (14, 14, 64)\n","decoder_layer = Reshape(reshape_dims)(decoder_dense_layer)\n","\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = 
Conv2DTranspose(128, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 5, padding='same', activation='relu')(decoder_layer)\n","decoder_out_layer = Conv2DTranspose(1, 5, padding='same', activation='relu')(decoder_layer)\n","\n","# Building the Decoder Model\n","decoder_model = Model(decoder_input_layer, decoder_out_layer, name='decoder_model')\n","\n","decoder_model.summary()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"ivPNAYp8GaLd","colab":{"base_uri":"https://localhost:8080/"},"outputId":"2e313031-aa95-4252-bde1-38fadd8983b1","executionInfo":{"status":"ok","timestamp":1758768439934,"user_tz":240,"elapsed":3,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["The type  of 'reconstructed_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'reconstructed_output' is: (None, 28, 28, 1)\n","The type of 'vae_loss_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'vae_loss_output' is: (None, 28, 28, 1)\n"]}],"source":["# @title set VAE model\n","from tensorflow.keras.models import load_model\n","\n","\n","# Reconstructed Output from Decoder\n","reconstructed_output = decoder_model(z)\n","\n","reconstructed_model = Model(encoder_input_layer, reconstructed_output, name='reconstructed_model')\n","\n","print(f\"The type  of 'reconstructed_output' is: {type(reconstructed_output)}\")\n","print(f\"The shape of 'reconstructed_output' is: {reconstructed_output.shape}\")\n","\n","\n","# Adding the loss computation layer to the model\n","vae_loss_output = VAELossLayer()([encoder_input_layer, reconstructed_output, mu_layer, sigma_layer])\n","\n","print(f\"The type of 'vae_loss_output' is: {type(vae_loss_output)}\")\n","print(f\"The shape of 'vae_loss_output' is: {vae_loss_output.shape}\")\n","\n","\n","# loaded_model = 
keras.saving.load_model(\"best_model_dim_15.h5\")\n","final_vae_model = Model(encoder_input_layer, vae_loss_output, name='vae_model')\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"i8U3bUoBGaLe","colab":{"base_uri":"https://localhost:8080/"},"outputId":"f638ff18-1230-40aa-8e1b-e8625bbfbec5","executionInfo":{"status":"ok","timestamp":1758768449472,"user_tz":240,"elapsed":9535,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"]}],"source":["# @title load trained encoder and decoder\n","from tensorflow.keras.models import load_model\n","\n","\n","trained_model1 = load_model(model1_name, custom_objects={\"sample_z\": sample_z, \"VAELossLayer\": VAELossLayer})\n","\n","# Encoder_mu\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder1_input = trained_model1.input\n","mu1_layer = trained_model1.get_layer('latent_mu')\n","mu1_output = mu1_layer.output\n","# Recreate the encoder model\n","encoder1_mu = Model(inputs=encoder1_input, outputs=mu1_output)\n","\n","\n","# Encoder_sigma\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder1_input = trained_model1.input\n","sigma1_layer = trained_model1.get_layer('latent_sigma')\n","sigma1_output = sigma1_layer.output\n","# Recreate the encoder model\n","encoder1_sigma = Model(inputs=encoder1_input, outputs=sigma1_output)\n","\n","\n","# Decoder\n","trained_decoder1 = 
trained_model1.get_layer('decoder_model')\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"-T6degIPGaLe","colab":{"base_uri":"https://localhost:8080/"},"outputId":"ccb1bf6f-e406-4def-bc08-1de8b51f7d5c","executionInfo":{"status":"ok","timestamp":1758768491183,"user_tz":240,"elapsed":41712,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n","\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 0us/step\n","\u001b[1m186/186\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 7ms/step\n","\u001b[1m186/186\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m211/211\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m211/211\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m187/187\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m187/187\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m192/192\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m192/192\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m183/183\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m183/183\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 
6ms/step\n","\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m185/185\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m185/185\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step\n","\u001b[1m183/183\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m183/183\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m186/186\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 6ms/step\n","\u001b[1m186/186\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n","\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step\n"]}],"source":["# @title Load MNIST dataset & encode\n","\n","\n","# load the fashion mnist data\n","(x_train, y_train), (x_test, y_test) = mnist.load_data()\n","\n","\n","list_of_images = []\n","\n","for i in range(10):\n","  list_of_images.append(x_train[y_train == i]) # (7000, 28, 28)\n","\n","for i in range(10):\n","    data = list_of_images[i]/255.\n","    data = tf.reshape(data, (list_of_images[i].shape[0], 28, 28, 1))\n","    encoded1_mu_all = encoder1_mu.predict(data)\n","    encoded1_sigma_all = encoder1_sigma.predict(data)\n","    epsilon1_all = 
tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded1_images = encoded1_mu_all + tf.exp(encoded1_sigma_all / 2) * epsilon1_all\n","    numpy_encoded1_images = encoded1_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_mnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded1_images)\n","\n","# Initialise the per-class test list in this cell: without it the cell raises NameError on a fresh kernel,\n","# or appends after entries left by an earlier cell, so the loop below (indices 0-9) would read those instead.\n","list_of_images_for_test = []\n","for i in range(10):\n","    list_of_images_for_test.append(x_test[y_test == i]) # (n_class, 28, 28), roughly 1000 test images per class\n","\n","for i in range(10):\n","    data = list_of_images_for_test[i]/255.\n","    data = tf.reshape(data, (list_of_images_for_test[i].shape[0], 28, 28, 1))\n","    encoded1_mu_all = encoder1_mu.predict(data)\n","    encoded1_sigma_all = encoder1_sigma.predict(data)\n","    epsilon1_all = tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded1_images = encoded1_mu_all + tf.exp(encoded1_sigma_all / 2) * epsilon1_all\n","    numpy_encoded1_images = encoded1_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_for_test_mnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded1_images)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ck7NONCtGaLe"},"outputs":[],"source":["# @title save & load all digits\n","\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","tensor_encoded1_list=[]\n","for i in range(10):\n","    encoded1_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_mnist.npy'.format(i))\n","    encoded1_x = np.load(encoded1_data_path)\n","    tensor_encoded1_list.append(torch.tensor(encoded1_x).to(device))\n","\n","tensor_encoded1_for_test_list=[]\n","for i in range(10):\n","    encoded1_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_for_test_mnist.npy'.format(i))\n","    encoded1_x_for_test = np.load(encoded1_data_path)\n",
"    tensor_encoded1_for_test_list.append(torch.tensor(encoded1_x_for_test).to(device))\n"]},{"cell_type":"markdown","metadata":{"id":"FVqh5rEVMis9"},"source":["\n","\n","---\n","## preparing normalized latent samples (rho0:  fMNIST)\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"YI3cfLgTMis-","colab":{"base_uri":"https://localhost:8080/"},"outputId":"e411af8f-4e89-4475-9e30-8c6899adc503","executionInfo":{"status":"ok","timestamp":1758768491570,"user_tz":240,"elapsed":242,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["effective dimension = 22\n","effective dimension using std = 35\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-3725090064.py:24: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  U0_reduced = torch.tensor(U0[:, idx0]).to(device)\n","/tmp/ipython-input-3725090064.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  V0_reduced = torch.tensor(V0[idx0, :]).to(device)\n"]}],"source":["# @title compute mean value, SVD of the covariance matrix for the latent data\n","import pickle\n","\n","dim_latent = Dim\n","\n","encoded0_x_tot = tensor_encoded0_list[0]\n","for j in range(1, 10):\n","    encoded0_x_tot = torch.cat((encoded0_x_tot, tensor_encoded0_list[j]), 0).to(device)\n","\n","encoded0_x_tot_2 = tensor_encoded0_list[5]\n","for j in range(6, 10):\n","    encoded0_x_tot_2 = torch.cat((encoded0_x_tot_2, tensor_encoded0_list[j]), 0).to(device)\n","\n","# first_list: classes 0-4 (second_list above holds classes 5-9); start from class 0, then append classes 1-4\n","encoded0_x_tot_1 = tensor_encoded0_list[0]\n","for j in range(1, 5):\n","    encoded0_x_tot_1 = torch.cat((encoded0_x_tot_1, 
tensor_encoded0_list[j]), 0).to(device)\n","\n","mean_encoded0_x_tot = torch.mean(encoded0_x_tot, dim=0, keepdims=True).to(device)\n","cov_encoded0_x_tot = torch.matmul((encoded0_x_tot - mean_encoded0_x_tot).T, (encoded0_x_tot - mean_encoded0_x_tot)) / encoded0_x_tot.size(0)\n","U0, s0, V0 = np.linalg.svd(cov_encoded0_x_tot.cpu().detach().numpy(), full_matrices=True)\n","U0 = torch.tensor(U0, dtype=torch.float32)\n","V0 = torch.tensor(V0, dtype=torch.float32)\n","idx0 = s0 > 1e-4\n","U0_reduced = torch.tensor(U0[:, idx0]).to(device)\n","V0_reduced = torch.tensor(V0[idx0, :]).to(device)\n","s0_reduced = torch.tensor(s0[idx0]).to(device)\n","\n","sqr0 = encoded0_x_tot ** 2\n","std0 = torch.sqrt(torch.mean(sqr0, dim=0, keepdims=True) - mean_encoded0_x_tot ** 2)  # diagonal of Covariance matrix\n","idx_std0 = std0[0].cpu().numpy() > 0\n","\n","effective_dim0= sum(idx0)\n","print(\"effective dimension = {}\".format(effective_dim0))\n","\n","effective_dim_using_std0 = sum(idx_std0)\n","print(\"effective dimension using std = {}\".format(effective_dim_using_std0))\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"h2l98ucCMis-","colab":{"base_uri":"https://localhost:8080/"},"outputId":"8e8a58b4-bf60-45bf-b3d4-a51816418579","executionInfo":{"status":"ok","timestamp":1758768492861,"user_tz":240,"elapsed":1290,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["[2.3953995e+01 1.7651148e+01 1.0345057e+01 7.7660394e+00 6.1860533e+00\n"," 5.3431301e+00 4.3687687e+00 3.8708813e+00 3.4692185e+00 3.2967808e+00\n"," 2.6887658e+00 2.2653694e+00 1.9939806e+00 1.7932447e+00 1.5799804e+00\n"," 1.3917128e+00 1.2567911e+00 1.1261525e+00 9.7465360e-01 7.8684020e-01\n"," 7.0519114e-01 1.0056355e-04 9.7078053e-05 9.3978684e-05 8.9623849e-05\n"," 8.2004561e-05 8.1308834e-05 7.9811725e-05 7.9003039e-05 7.5538039e-05\n"," 6.9105838e-05 6.3406260e-05 6.2885250e-05 5.8906382e-05 
5.5472690e-05]\n"]}],"source":["plt.plot(np.log(s0)/np.log(10))\n","plt.savefig(model_and_data_savepath + '/singular_value_fmnist.pdf')\n","plt.close()\n","print(s0)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"IYoqaOL5Mis-"},"outputs":[],"source":["# @title overview of latent samples (func)\n","import numpy as np\n","import pickle\n","\n","\n","def normalize_latent_samples0(digit, savepath=model_and_data_savepath, dataset=rho0_dataset):\n","\n","    if digit == \"first_list\":\n","        encoded_x = encoded0_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded0_x_tot_2.to(device)\n","    else:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_{}.npy').format(digit, dataset))\n","        encoded_x = torch.tensor(encoded_x).to(device)\n","    # normalize:\n","    encoded_x = torch.matmul((encoded_x - mean_encoded0_x_tot), U0_reduced)/torch.sqrt(s0_reduced)\n","\n","    return encoded_x\n","\n","\n","norm_eps = 1e-6\n","def normalize_latent_samples_using_std0(digit, savepath=model_and_data_savepath, dataset=rho0_dataset, load_test=False):\n","\n","    if digit == \"first_list\":\n","        encoded_x = encoded0_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded0_x_tot_2.to(device)\n","    elif load_test:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_for_test_{}.npy').format(digit, dataset))\n","        encoded_x = torch.tensor(encoded_x).to(device)\n","    else:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_{}.npy').format(digit, dataset))\n","        encoded_x = torch.tensor(encoded_x).to(device)\n","    # normalized:\n","    normalized_encoded_x = (encoded_x[:, idx_std0] - mean_encoded0_x_tot[:, idx_std0])/(std0[0, idx_std0]) # use Trace of Cov of data to do normalization\n","\n","    return normalized_encoded_x\n","\n","\n","def 
normalize_latent_samples0_input(encoded_x):\n","    \"\"\"PCA-whiten raw fMNIST latent codes: center by mean_encoded0_x_tot, project onto U0_reduced, divide by sqrt(s0_reduced).\"\"\"\n","\n","    encoded_x = torch.tensor(encoded_x).to(device)\n","\n","    # normalized:\n","    normalized_encoded_x = torch.matmul((encoded_x - mean_encoded0_x_tot), U0_reduced)/torch.sqrt(s0_reduced) # PCA whitening, same transform as normalize_latent_samples0\n","\n","    return normalized_encoded_x\n","\n","\n","def normalize_latent_samples_using_std0_input(encoded_x):\n","    \"\"\"Standardize raw fMNIST latent codes per coordinate, keeping only coordinates with nonzero std (idx_std0).\"\"\"\n","\n","    encoded_x = torch.tensor(encoded_x).to(device)\n","\n","    # normalized: mask samples and mean with idx_std0 exactly like normalize_latent_samples_using_std0,\n","    # so the numerator always has the same width as the divisor std0[0, idx_std0]\n","    normalized_encoded_x = (encoded_x[:, idx_std0] - mean_encoded0_x_tot[:, idx_std0]) / std0[0, idx_std0]\n","\n","    return normalized_encoded_x\n","\n","\n","def recover_latent_samples_using_PCA0(samples):  # samples: N * d, d = number of retained PCA directions, sum(idx0)\n","    \"\"\"Map PCA-whitened samples back to the fMNIST latent space (inverse of the whitening up to the projection onto U0_reduced).\"\"\"\n","\n","    recovered_samples = torch.matmul(samples * torch.sqrt(s0_reduced), U0_reduced.T) + mean_encoded0_x_tot\n","\n","    return recovered_samples\n","\n","\n","def recover_latent_samples_using_std0(samples):  # samples: N * d, d = number of coordinates with nonzero std, sum(idx_std0)\n","    \"\"\"Undo per-coordinate std normalization; coordinates dropped by idx_std0 are filled with their mean value.\"\"\"\n","\n","    recovered_samples = torch.zeros(samples.size()[0], dim_latent).to(device)\n","    i = 0\n","    for d in range(dim_latent):\n","      if idx_std0[d]:\n","        recovered_samples[:, d] = samples[:, i] * std0[0, d] + mean_encoded0_x_tot[0, d]\n","        i+=1\n","      else:\n","        recovered_samples[:, d] = mean_encoded0_x_tot[0, d] * torch.ones(samples.size()[0]).to(mean_encoded0_x_tot.device)\n","\n","    return recovered_samples\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"7z04nGz4Mis-"},"outputs":[],"source":["# @title prepare initial data using SVD normalization\n","\n","data_all_gpu = normalize_latent_samples0(\"0\")\n","for i in range(1, 10):\n","    data_i = normalize_latent_samples0(\"{}\".format(i))\n","    data_i_gpu = data_i.to(device)\n","    data_all_gpu = torch.cat((data_all_gpu, data_i_gpu), 0)\n","\n","init_data_gpu = data_all_gpu\n","mean_init_data = torch.mean(init_data_gpu, dim=0).to(device)\n","# 
print(mean_init_data)\n","cov_init_data = torch.matmul((init_data_gpu - mean_init_data).T, (init_data_gpu - mean_init_data)) / init_data_gpu.size()[0]\n","cov_init_data = cov_init_data.to(device)\n","\n","init_data_9_svd = normalize_latent_samples0('9').to(device)\n","init_data_8_svd = normalize_latent_samples0('8').to(device)\n","init_data_7_svd = normalize_latent_samples0('7').to(device)\n","init_data_6_svd = normalize_latent_samples0('6').to(device)\n","init_data_5_svd = normalize_latent_samples0('5').to(device)\n","init_data_4_svd = normalize_latent_samples0('4').to(device)\n","init_data_3_svd = normalize_latent_samples0('3').to(device)\n","init_data_2_svd = normalize_latent_samples0('2').to(device)\n","init_data_1_svd = normalize_latent_samples0('1').to(device)\n","init_data_0_svd = normalize_latent_samples0('0').to(device)\n","init_data_list_svd = [init_data_0_svd, init_data_1_svd, init_data_2_svd, init_data_3_svd, init_data_4_svd, init_data_5_svd, init_data_6_svd, init_data_7_svd, init_data_8_svd, init_data_9_svd]\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DHlpicitMis-","cellView":"form"},"outputs":[],"source":["# @title prepare initial data with normalization (using std, i.e., the trace of covariance to normalize)\n","\n","data_all_gpu_std = normalize_latent_samples_using_std0(\"0\")\n","for i in range(1, 10):\n","    data_i_std = normalize_latent_samples_using_std0(\"{}\".format(i))\n","    data_i_gpu_std = data_i_std.to(device)\n","    data_all_gpu_std = torch.cat((data_all_gpu_std, data_i_gpu_std), 0)\n","\n","init_data_gpu_std = data_all_gpu_std\n","mean_init_data_std = torch.mean(init_data_gpu_std, dim=0).to(device)\n","cov_init_data_std = torch.matmul((init_data_gpu_std - mean_init_data_std).T, (init_data_gpu_std - mean_init_data_std)) / init_data_gpu_std.size()[0]\n","cov_init_data_std = cov_init_data_std.to(device)\n","\n","init_data_9 = normalize_latent_samples_using_std0('9').to(device)\n","init_data_8 = 
normalize_latent_samples_using_std0('8').to(device)\n","init_data_7 = normalize_latent_samples_using_std0('7').to(device)\n","init_data_6 = normalize_latent_samples_using_std0('6').to(device)\n","init_data_5 = normalize_latent_samples_using_std0('5').to(device)\n","init_data_4 = normalize_latent_samples_using_std0('4').to(device)\n","init_data_3 = normalize_latent_samples_using_std0('3').to(device)\n","init_data_2 = normalize_latent_samples_using_std0('2').to(device)\n","init_data_1 = normalize_latent_samples_using_std0('1').to(device)\n","init_data_0 = normalize_latent_samples_using_std0('0').to(device)\n","init_data_list_std = [init_data_0, init_data_1, init_data_2, init_data_3, init_data_4, init_data_5, init_data_6, init_data_7, init_data_8, init_data_9]\n","\n","\n","\n","data_all_gpu_std_for_test = normalize_latent_samples_using_std0(\"0\", load_test=True)\n","for i in range(1, 10):\n","    data_i_std_for_test = normalize_latent_samples_using_std0(\"{}\".format(i), load_test=True)\n","    data_i_gpu_std_for_test = data_i_std_for_test.to(device)\n","    data_all_gpu_std_for_test = torch.cat((data_all_gpu_std_for_test, data_i_gpu_std_for_test), 0)\n","\n","init_data_9_for_test = normalize_latent_samples_using_std0('9', load_test=True).to(device)\n","init_data_8_for_test = normalize_latent_samples_using_std0('8', load_test=True).to(device)\n","init_data_7_for_test = normalize_latent_samples_using_std0('7', load_test=True).to(device)\n","init_data_6_for_test = normalize_latent_samples_using_std0('6', load_test=True).to(device)\n","init_data_5_for_test = normalize_latent_samples_using_std0('5', load_test=True).to(device)\n","init_data_4_for_test = normalize_latent_samples_using_std0('4', load_test=True).to(device)\n","init_data_3_for_test = normalize_latent_samples_using_std0('3', load_test=True).to(device)\n","init_data_2_for_test = normalize_latent_samples_using_std0('2', load_test=True).to(device)\n","init_data_1_for_test = 
normalize_latent_samples_using_std0('1', load_test=True).to(device)\n","init_data_0_for_test = normalize_latent_samples_using_std0('0', load_test=True).to(device)\n","init_data_list_std_for_test = [init_data_0_for_test, init_data_1_for_test, init_data_2_for_test, init_data_3_for_test, init_data_4_for_test, init_data_5_for_test, init_data_6_for_test, init_data_7_for_test, init_data_8_for_test, init_data_9_for_test]\n"]},{"cell_type":"markdown","metadata":{"id":"PsJ3AVfSbIpf"},"source":["\n","\n","---\n","## preparing normalized latent samples (rho1:  MNIST)\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"FqGe8wqHbIpg","colab":{"base_uri":"https://localhost:8080/"},"outputId":"871f2627-c72d-4626-89f8-4c31578d7787","executionInfo":{"status":"ok","timestamp":1758768493201,"user_tz":240,"elapsed":44,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["effective dimension = 15\n","effective dimension using std = 35\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-3460867625.py:24: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  U1_reduced = torch.tensor(U1[:, idx1]).to(device)\n","/tmp/ipython-input-3460867625.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  V1_reduced = torch.tensor(V1[idx1, :]).to(device)\n"]}],"source":["# @title compute mean value, SVD of the covariance matrix for the latent data\n","import pickle\n","\n","dim_latent = Dim\n","\n","encoded1_x_tot = tensor_encoded1_list[0]\n","for j in range(1, 10):\n","    encoded1_x_tot = torch.cat((encoded1_x_tot, 
tensor_encoded1_list[j]), 0).to(device)\n","\n","encoded1_x_tot_2 = tensor_encoded1_list[5]\n","for j in range(6, 10):\n","    encoded1_x_tot_2 = torch.cat((encoded1_x_tot_2, tensor_encoded1_list[j]), 0).to(device)\n","\n","# first_list: classes 0-4 (second_list above holds classes 5-9); start from class 0, then append classes 1-4\n","encoded1_x_tot_1 = tensor_encoded1_list[0]\n","for j in range(1, 5):\n","    encoded1_x_tot_1 = torch.cat((encoded1_x_tot_1, tensor_encoded1_list[j]), 0).to(device)\n","\n","mean_encoded1_x_tot = torch.mean(encoded1_x_tot, dim=0, keepdims=True).to(device)\n","cov_encoded1_x_tot = torch.matmul((encoded1_x_tot - mean_encoded1_x_tot).T, (encoded1_x_tot - mean_encoded1_x_tot)) / encoded1_x_tot.size(0)\n","U1, s1, V1 = np.linalg.svd(cov_encoded1_x_tot.cpu().detach().numpy(), full_matrices=True)\n","U1 = torch.tensor(U1, dtype=torch.float32)\n","V1 = torch.tensor(V1, dtype=torch.float32)\n","idx1 = s1 > 1e-4\n","U1_reduced = torch.tensor(U1[:, idx1]).to(device)\n","V1_reduced = torch.tensor(V1[idx1, :]).to(device)\n","s1_reduced = torch.tensor(s1[idx1]).to(device)\n","\n","sqr1 = encoded1_x_tot ** 2\n","std1 = torch.sqrt(torch.mean(sqr1, dim=0, keepdims=True) - mean_encoded1_x_tot ** 2)  # per-coordinate std (sqrt of the covariance diagonal)\n","idx_std1 = std1[0].cpu().numpy() > 0\n","\n","effective_dim1= sum(idx1)\n","print(\"effective dimension = {}\".format(effective_dim1))\n","\n","effective_dim_using_std1 = sum(idx_std1)\n","print(\"effective dimension using std = {}\".format(effective_dim_using_std1))\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"c3PvY6dDbIpg","colab":{"base_uri":"https://localhost:8080/"},"outputId":"77d65242-909d-4f3c-e1d4-30be0b517f67","executionInfo":{"status":"ok","timestamp":1758768493710,"user_tz":240,"elapsed":509,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["[1.1046782e+01 9.0992794e+00 7.7155213e+00 6.6736026e+00 6.1251798e+00\n"," 4.7482872e+00 3.2156472e+00 3.0044379e+00 2.5746758e+00 2.2668998e+00\n"," 2.0362995e+00 1.6277325e+00 
1.3591386e+00 1.2713200e+00 1.0821454e-02\n"," 5.9059224e-05 5.4488271e-05 5.0164508e-05 4.7222438e-05 3.7595746e-05\n"," 3.4917197e-05 3.4206376e-05 3.0716117e-05 3.0261694e-05 2.8496894e-05\n"," 2.6637783e-05 2.5255764e-05 2.4762865e-05 2.4423480e-05 2.3145323e-05\n"," 2.2492828e-05 1.9787576e-05 1.7793898e-05 1.7621804e-05 1.4801024e-05]\n"]}],"source":["plt.plot(np.log(s1)/np.log(10))\n","plt.savefig(model_and_data_savepath + '/singular_value_mnist.pdf')\n","plt.close()\n","print(s1)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"i7HzCPhWbIph"},"outputs":[],"source":["# @title overview of latent samples (func)\n","import numpy as np\n","import pickle\n","\n","\n","def normalize_latent_samples1(digit, savepath=model_and_data_savepath, dataset=rho1_dataset):\n","\n","    if digit == \"first_list\":\n","        encoded_x = encoded1_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded1_x_tot_2.to(device)\n","    else:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_{}.npy').format(digit, dataset))\n","        encoded_x = torch.tensor(encoded_x).to(device)\n","    # normalize:\n","    encoded_x = torch.matmul((encoded_x - mean_encoded1_x_tot), U1_reduced)/torch.sqrt(s1_reduced)\n","\n","    return encoded_x\n","\n","\n","norm_eps = 1e-6\n","def normalize_latent_samples_using_std1(digit, savepath=model_and_data_savepath, dataset=rho1_dataset, load_test=False):\n","\n","    if digit == \"first_list\":\n","        encoded_x = encoded1_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded1_x_tot_2.to(device)\n","    elif load_test:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_for_test_{}.npy').format(digit, dataset))\n","        encoded_x = torch.tensor(encoded_x).to(device)\n","    else:\n","        encoded_x = np.load(os.path.join(savepath, 'encoded_image_{}_{}.npy').format(digit, dataset))\n","        
encoded_x = torch.tensor(encoded_x).to(device)\n","    # normalized:\n","    normalized_encoded_x = (encoded_x[:, idx_std1] - mean_encoded1_x_tot[:, idx_std1])/(std1[0, idx_std1]) # use Trace of Cov of data to do normalization\n","\n","    return normalized_encoded_x\n","\n","\n","def normalize_latent_samples1_input(encoded_x):\n","\n","    encoded_x = torch.tensor(encoded_x).to(device)\n","\n","    # normalized:\n","    normalized_encoded_x = torch.matmul((encoded_x - mean_encoded1_x_tot), U1_reduced)/torch.sqrt(s1_reduced)\n","\n","    return normalized_encoded_x\n","\n","\n","def normalize_latent_samples_using_std1_input(encoded_x):\n","\n","    encoded_x = torch.tensor(encoded_x).to(device)\n","\n","    # normalized:\n","    normalized_encoded_x = (encoded_x[:, idx_std1] - mean_encoded1_x_tot[:, idx_std1])/(std1[0, idx_std1]) # use Trace of Cov of data to do normalization\n","\n","    return normalized_encoded_x\n","\n","\n","def recover_latent_samples_using_PCA1(samples):  # samples:  N * d (d = reduced dimension 46)\n","\n","    recovered_samples = torch.matmul(samples * torch.sqrt(s1_reduced), U1_reduced.T) + mean_encoded1_x_tot\n","\n","    return recovered_samples\n","\n","\n","def recover_latent_samples_using_std1(samples):  # samples:  N * d (d = reduced dimension 46)\n","\n","    recovered_samples = torch.zeros(samples.size()[0], dim_latent).to(device)\n","    i = 0\n","    for d in range(dim_latent):\n","      if idx_std1[d]:\n","        recovered_samples[:, d] = samples[:, i] * std1[0, d] + mean_encoded1_x_tot[0, d]\n","        i+=1\n","      else:\n","        recovered_samples[:, d] = mean_encoded1_x_tot[0, d] * torch.ones(samples.size()[0]).to(mean_encoded1_x_tot.device)\n","\n","    return recovered_samples\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"JR3xr38-bIph"},"outputs":[],"source":["# @title prepare target data using SVD normalization\n","\n","data_all_gpu = 
normalize_latent_samples1(\"0\")\n","for i in range(1, 10):\n","    data_i = normalize_latent_samples1(\"{}\".format(i))\n","    data_i_gpu = data_i.to(device)\n","    data_all_gpu = torch.cat((data_all_gpu, data_i_gpu), 0)\n","\n","target_data_gpu = data_all_gpu\n","mean_target_data = torch.mean(target_data_gpu, dim=0).to(device)\n","# print(mean_target_data)\n","cov_target_data = torch.matmul((target_data_gpu - mean_target_data).T, (target_data_gpu - mean_target_data)) / target_data_gpu.size()[0]\n","cov_target_data = cov_target_data.to(device)\n","\n","\n","target_data_9_svd = normalize_latent_samples1('9').to(device)\n","target_data_8_svd = normalize_latent_samples1('8').to(device)\n","target_data_7_svd = normalize_latent_samples1('7').to(device)\n","target_data_6_svd = normalize_latent_samples1('6').to(device)\n","target_data_5_svd = normalize_latent_samples1('5').to(device)\n","target_data_4_svd = normalize_latent_samples1('4').to(device)\n","target_data_3_svd = normalize_latent_samples1('3').to(device)\n","target_data_2_svd = normalize_latent_samples1('2').to(device)\n","target_data_1_svd = normalize_latent_samples1('1').to(device)\n","target_data_0_svd = normalize_latent_samples1('0').to(device)\n","target_data_list_svd = [target_data_0_svd, target_data_1_svd, target_data_2_svd, target_data_3_svd, target_data_4_svd, target_data_5_svd, target_data_6_svd, target_data_7_svd, target_data_8_svd, target_data_9_svd]\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"KD_zPhShbIph"},"outputs":[],"source":["# @title prepare target data with normalization (using std, i.e., the trace of covariance to normalize)\n","\n","data_all_gpu_std = normalize_latent_samples_using_std1(\"0\")\n","for i in range(1, 10):\n","    data_i_std = normalize_latent_samples_using_std1(\"{}\".format(i))\n","    data_i_gpu_std = data_i_std.to(device)\n","    data_all_gpu_std = torch.cat((data_all_gpu_std, data_i_gpu_std), 
0)\n","\n","\n","target_data_gpu_std = data_all_gpu_std\n","mean_target_data_std = torch.mean(target_data_gpu_std, dim=0).to(device)\n","cov_target_data_std = torch.matmul((target_data_gpu_std - mean_target_data_std).T, (target_data_gpu_std - mean_target_data_std)) / target_data_gpu_std.size()[0]\n","cov_target_data_std = cov_target_data_std.to(device)\n","\n","\n","target_data_9 = normalize_latent_samples_using_std1('9').to(device)\n","target_data_8 = normalize_latent_samples_using_std1('8').to(device)\n","target_data_7 = normalize_latent_samples_using_std1('7').to(device)\n","target_data_6 = normalize_latent_samples_using_std1('6').to(device)\n","target_data_5 = normalize_latent_samples_using_std1('5').to(device)\n","target_data_4 = normalize_latent_samples_using_std1('4').to(device)\n","target_data_3 = normalize_latent_samples_using_std1('3').to(device)\n","target_data_2 = normalize_latent_samples_using_std1('2').to(device)\n","target_data_1 = normalize_latent_samples_using_std1('1').to(device)\n","target_data_0 = normalize_latent_samples_using_std1('0').to(device)\n","target_data_list_std = [target_data_0, target_data_1, target_data_2, target_data_3, target_data_4, target_data_5, target_data_6, target_data_7, target_data_8, target_data_9]\n","\n","\n","data_all_gpu_std_for_test = normalize_latent_samples_using_std1(\"0\", load_test=True)\n","for i in range(1, 10):\n","    data_i_std_for_test = normalize_latent_samples_using_std1(\"{}\".format(i), load_test=True)\n","    data_i_gpu_std_for_test = data_i_std_for_test.to(device)\n","    data_all_gpu_std_for_test = torch.cat((data_all_gpu_std_for_test, data_i_gpu_std_for_test), 0)\n","\n","\n","target_data_9_for_test = normalize_latent_samples_using_std1('9', load_test=True).to(device)\n","target_data_8_for_test = normalize_latent_samples_using_std1('8', load_test=True).to(device)\n","target_data_7_for_test = normalize_latent_samples_using_std1('7', load_test=True).to(device)\n","target_data_6_for_test = 
normalize_latent_samples_using_std1('6', load_test=True).to(device)\n","target_data_5_for_test = normalize_latent_samples_using_std1('5', load_test=True).to(device)\n","target_data_4_for_test = normalize_latent_samples_using_std1('4', load_test=True).to(device)\n","target_data_3_for_test = normalize_latent_samples_using_std1('3', load_test=True).to(device)\n","target_data_2_for_test = normalize_latent_samples_using_std1('2', load_test=True).to(device)\n","target_data_1_for_test = normalize_latent_samples_using_std1('1', load_test=True).to(device)\n","target_data_0_for_test = normalize_latent_samples_using_std1('0', load_test=True).to(device)\n","target_data_list_std_for_test = [target_data_0_for_test, target_data_1_for_test, target_data_2_for_test, target_data_3_for_test, target_data_4_for_test, target_data_5_for_test, target_data_6_for_test, target_data_7_for_test, target_data_8_for_test, target_data_9_for_test]\n","\n"]},{"cell_type":"markdown","metadata":{"id":"ng6raSYzMis_"},"source":["\n","\n","---\n","## Plotting functions  \n","\n","\n","\n","---\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"dKnVo5RYMis_"},"outputs":[],"source":["# @title plot samples on latent space (func)\n","\n","\n","def plot_on_latent_space(encoded_x_numpy, transported_samples, dim_0, dim_1, scale, iter, exp_dir=dir):\n","\n","    plt.figure(figsize=(12,12))\n","    plt.scatter( encoded_x_numpy[:1024, dim_0], encoded_x_numpy[:1024, dim_1], color='magenta', s=10, alpha=1)\n","    plt.scatter( transported_samples[:1024, dim_0], transported_samples[:1024, dim_1], color='blue', s=1, alpha=1)\n","    plt.title(\"plot on latent space on {}-{} coordinate plane\".format(dim_0, dim_1))\n","    filename = os.path.join(exp_dir, 'plot on latent space on {}-{} coordinate plane rescale={}iter={}.png'.format(dim_0, dim_1, scale, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    
plt.close()\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"vBtNJC1yMVRK"},"outputs":[],"source":["# @title plot MNIST figure (func)\n","\n","import math\n","\n","\n","def plot_two_row_src_target(x_original, x_generated, iter, ttl, flag = '0_to_T', exp_dir=example_savepath):\n","\n","    column = 10\n","    row_num = 2\n","\n","    nex = row_num * column\n","    fig, axs = plt.subplots(row_num, column, figsize=(2*column, 2*row_num))\n","    plt.subplots_adjust(wspace=0.05, hspace=0.05)  # tighter spacing\n","    fig.suptitle(ttl, fontsize=30, y=1.02)\n","\n","    labelsize_y = 20\n","    for j in range(column):\n","        if j == 0:\n","          if flag == '0_to_T':\n","            axs[0, j].set_ylabel(\"Source\", rotation='vertical', fontsize = labelsize_y)\n","          else:\n","            axs[0, j].set_ylabel(\"Source\", rotation='vertical', fontsize = labelsize_y)\n","        x0 = x_original[j]\n","        axs[0, j].imshow(x0[:, :, 0], cmap='Greys')\n","        if j == 0:\n","          if flag == '0_to_T':\n","            axs[1, j].set_ylabel(\"Transported\", rotation='vertical', fontsize = labelsize_y)\n","          else:\n","            axs[1, j].set_ylabel(\"Transported\", rotation='vertical', fontsize = labelsize_y)\n","        x1 = x_generated[j]\n","        axs[1, j].imshow(x1[:, :, 0], cmap='Greys')\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)\n","            axs[i, j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","def plot_mnist(images, output, n_plots, scale, iter, ttl, exp_dir=example_savepath):\n","\n","    images = images.cpu().detach().numpy()[:n_plots, :]\n","    output = output[:n_plots, :]\n","    plot_num = 
n_plots\n","    output = output.view(plot_num, 28, 28)\n","    output = output.cpu().detach().numpy()\n","\n","    # plot the first n_plots input images and then reconstructed images\n","    # squeeze=False keeps axes 2-D (2 x plot_num) even when plot_num == 1\n","    fig, axes = plt.subplots(nrows=2, ncols=plot_num, sharex=True, sharey=True, figsize=(25,4), squeeze=False)\n","    # input images on top row, reconstructions on bottom\n","    for images, row in zip([images, output], axes):\n","        for img, ax in zip(images, row):\n","            ax.imshow(np.squeeze(img), cmap='gray')\n","            ax.get_xaxis().set_visible(False)\n","            ax.get_yaxis().set_visible(False)\n","    fig.suptitle(ttl)  # plt.title would only title the last subplot\n","    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close(fig)\n","\n","\n","def plot_mnist_tab(rownum, colnum, x, iter, ttl, exp_dir=example_savepath):\n","    \"\"\"Plot the first rownum*colnum images of x (indexed as x[k, :, :, 0]) in a rownum x colnum grid and save it under exp_dir.\"\"\"\n","    # squeeze=False keeps axs 2-D so axs[i, j] also works for a single row or column\n","    fig, axs = plt.subplots(rownum, colnum, squeeze=False)\n","    fig.set_size_inches(4*colnum, 4*rownum)  # (width, height): width follows the number of columns\n","\n","    for i in range(rownum):\n","      for j in range(colnum):\n","        axs[i, j].imshow(x[i*colnum+j,:, :, 0])\n","\n","    fig.suptitle(ttl, fontsize=16)\n","\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].get_yaxis().set_visible(False)\n","            axs[i, j].get_xaxis().set_visible(False)\n","            axs[i ,j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close(fig)\n","\n","\n","import random\n","def generate_random_list(start, end, length):\n","  \"\"\"Return `length` random integers drawn from [start, end]; NOTE: random.randint includes `end`, so pass n - 1 to index an array of length n.\"\"\"\n","  random_list = []\n","  for _ in range(length):\n","    random_list.append(random.randint(start, end))\n","  return random_list\n","\n","def plot_mnist_init_target(x_a, x_b, x_c, x_d, x_e, x_f, x_g, x_h, x_i, x_j, iter, row_num, ttl, flag = '0_to_T', exp_dir=example_savepath):\n","\n","    column = 
10\n","    nex = row_num * column\n","    fig, axs = plt.subplots(row_num, column)\n","    fig.set_size_inches( 2*column, 2*row_num )\n","\n","    random_integers = generate_random_list(0, x_a.shape[0], row_num)\n","\n","    labelsize = 20\n","    for j in range(column):\n","        for i in range(row_num):\n","          if j == 0:\n","              axs[i, j].imshow(x_a[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'0', fontsize=labelsize)\n","          if j == 1:\n","              axs[i, j].imshow(x_b[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'1', fontsize=labelsize)\n","          if j == 2:\n","              axs[i, j].imshow(x_c[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'2', fontsize=labelsize)\n","          if j == 3:\n","              axs[i, j].imshow(x_d[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                 axs[i, j].set_title(f'3', fontsize=labelsize)\n","          if j == 4:\n","              axs[i, j].imshow(x_e[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'4', fontsize=labelsize)\n","          if j == 5:\n","              axs[i, j].imshow(x_f[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'5', fontsize=labelsize)\n","          if j == 6:\n","              axs[i, j].imshow(x_g[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'6', fontsize=labelsize)\n","          if j == 7:\n","              axs[i, j].imshow(x_h[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'7', fontsize=labelsize)\n","          if j == 8:\n","              axs[i, j].imshow(x_i[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'8', 
fontsize=labelsize)\n","          if j == 9:\n","              axs[i, j].imshow(x_j[i, :, :, 0],  cmap='Greys')\n","              if i == 0:\n","                  axs[i, j].set_title(f'9', fontsize=labelsize)\n","\n","    fig.suptitle(ttl, fontsize=30, y=0.94)\n","    plt.subplots_adjust(wspace=0.05, hspace=0.05)  # tighter spacing\n","\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].get_yaxis().set_visible(False)\n","            axs[i, j].get_xaxis().set_visible(False)\n","            axs[i ,j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","\n","\n","def plot_mnist_one_row(x, iter, ttl, exp_dir=example_savepath):\n","\n","    # assume square image\n","    s = int(math.sqrt(x.shape[1]))\n","\n","    nex = 8\n","    fig, axs = plt.subplots(2, nex//2)\n","    fig.set_size_inches(18, 9)\n","\n","    for i in range(nex//2):\n","        axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy())\n","        axs[1, i].imshow(x[ nex//2 + i , : ].reshape(s,s).detach().cpu().numpy())\n","\n","    fig.suptitle(ttl, fontsize=16)\n","\n","    for i in range(axs.shape[0]):\n","        for j in range(axs.shape[1]):\n","            axs[i, j].get_yaxis().set_visible(False)\n","            axs[i, j].get_xaxis().set_visible(False)\n","            axs[i ,j].set_aspect('equal')\n","    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"cTInIadBMis_"},"outputs":[],"source":["# @title plot MNIST latent samples (func)\n","import matplotlib.pyplot as plt\n","\n","def plot_PCA_normalized_data_in_latent_spc(dataset, list_of_digits, dim_0, dim_1, 
dir=model_and_data_savepath):\n","\n","    save_path = os.path.join(dir, 'plot_latent_dim10_{}'.format(dataset))\n","    save_path = os.path.join(save_path, 'PCA')\n","    mkdir_ifnotexists(save_path)\n","\n","    plt.figure(figsize=(12,12))\n","    for i in list_of_digits:\n","        if dataset == 'fmnist':\n","            encoded_x = normalize_latent_samples0(i).cpu()\n","        elif dataset == 'mnist':\n","            encoded_x = normalize_latent_samples1(i).cpu()\n","        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n","    plt.legend()\n","    plt.title(\"plot digits({}) on latent space {}-{} PCA dimensions (normalized)\".format(list_of_digits, dim_0, dim_1))\n","    filename = os.path.join(save_path, 'latent_SVD_{}_{}'.format(dim_0, dim_1))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","# Normalize entrywise, projected to ordinary directions\n","def plot_std_normalized_data_in_latent_spc(dataset, list_of_digits, dim_0, dim_1, dir=model_and_data_savepath):\n","\n","    save_path = os.path.join(dir, 'plot_latent_dim10_{}'.format(dataset))\n","    save_path = os.path.join(save_path, 'std')\n","    mkdir_ifnotexists(save_path)\n","\n","    plt.figure(figsize=(12,12))\n","    for i in list_of_digits:\n","        if dataset == 'fmnist':\n","            encoded_x = normalize_latent_samples_using_std0(i).cpu()\n","        elif dataset == 'mnist':\n","            encoded_x = normalize_latent_samples_using_std1(i).cpu()\n","        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n","    plt.legend()\n","    plt.title(\"plot digits({}) on latent space {}-{} dimensions (normalize using std)\".format(list_of_digits, dim_0, dim_1))\n","    filename = os.path.join(save_path, 'latent_std_{}_{}'.format(dim_0, dim_1))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    
plt.close()\n","\n","\n","# plot raw latent samples\n","def plot_raw_data_in_latent_spc(dataset, list_of_digits, dim_0, dim_1, dir=model_and_data_savepath):\n","\n","    save_path = os.path.join(dir, 'plot_latent_dim10_{}'.format(dataset))\n","    save_path = os.path.join(save_path, 'raw')\n","    mkdir_ifnotexists(save_path)\n","\n","    plt.figure(figsize=(12,12))\n","    for i in list_of_digits:\n","        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}_{}.npy'.format(i, dataset)))\n","        encoded_x = torch.tensor(encoded_x).cpu()\n","        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n","    plt.legend()\n","    plt.title(\"plot digits({}) on latent space {}-{} dimensions (raw data)\".format(list_of_digits, dim_0, dim_1))\n","    filename = os.path.join(save_path, 'latent_std_{}_{}'.format(dim_0, dim_1))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"nF8NMdqOMis_"},"outputs":[],"source":["# @title plot OT map and samples (func)\n","\n","def plot_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, iter, normalization='PCA', flag='0_to_T', digits='05', exp_dir=example_savepath):\n","\n","    target_data = target_data.detach().cpu().numpy()\n","    init_pnts = init_pnts.detach().cpu().numpy()\n","    transported_pnts = transported_pnts.detach().cpu().numpy()\n","\n","    nn = 200\n","\n","    plt.style.use('dark_background')\n","    plt.figure(figsize=(12,12))\n","    if flag == '0_to_T':\n","        plt.scatter(init_pnts[:nn, dim_0], init_pnts[:nn, dim_1], color='mediumspringgreen', s=5, alpha=1, label='origin')\n","        plt.scatter(transported_pnts[:nn, dim_0], transported_pnts[:nn, dim_1], color='magenta', s=5, alpha=1, label='transported')\n","    else:\n","        plt.scatter(init_pnts[:nn, dim_0], init_pnts[:nn, dim_1], 
color='mediumspringgreen', s=5, alpha=1, label='origin')\n","        plt.scatter(transported_pnts[:nn, dim_0], transported_pnts[:nn, dim_1], color='magenta', s=5, alpha=1, label='transported')\n","    for k in range(nn):\n","        plt.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], c='gray', alpha=0.5, linewidth=0.8)\n","    plt.tick_params(axis='both', which='major', labelsize=20)\n","    plt.legend(fontsize=20)\n","    if normalization == 'PCA':\n","        plt.title(\"Transport map on {}-{} PCA latent plane\".format(flag, dim_0, dim_1), fontsize=40)\n","    else:\n","        plt.title(\"Transport map on {}-{} latent plane\".format(flag, dim_0, dim_1)+flag, fontsize=30)\n","    if normalization == 'PCA':\n","        filename = os.path.join(exp_dir, '[{}, {}] plot OT maps on {}-{} PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    else:\n","        filename = os.path.join(exp_dir, '[{}, {}] plot OT maps on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","    plt.style.use('dark_background')\n","    plt.figure(figsize=(12,12))\n","    plt.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], color='blue', s=1, alpha=1, label='target')\n","    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='magenta', s=1, alpha=1, label='generate')\n","    plt.tick_params(axis='both', which='major', labelsize=20)\n","    plt.legend(fontsize=20)\n","    if normalization == 'PCA':\n","        plt.title(\"Latent samples on {}-{} PCA plane\".format(flag, digits, dim_0, dim_1), fontsize=30)\n","    else:\n","        plt.title(\"Latent samples on {}-{} plane\".format(flag, digits, dim_0, dim_1), fontsize=30)\n","    if normalization == 'PCA':\n","        filename = os.path.join(exp_dir, '[{}, {}] plot samples on {}-{} 
PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    else:\n","        filename = os.path.join(exp_dir, '[{}, {}] plot samples on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","    plt.style.use('dark_background')\n","    plt.figure(figsize=(12,12))\n","    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='deepskyblue', s=1, alpha=1, label='computed')\n","    plt.tick_params(axis='both', which='major', labelsize=20)\n","    plt.legend(fontsize=20)\n","    if normalization == 'PCA':\n","        plt.title(\"Latent samples (only generated) on {}-{} PCA latent plane\".format(flag, digits, dim_0, dim_1) + flag, fontsize=30)\n","    else:\n","        plt.title(\"Latent samples (only generated) on {}-{} latent plane\".format(flag, digits, dim_0, dim_1) + flag, fontsize=30)\n","    if normalization == 'PCA':\n","        filename = os.path.join(exp_dir, '[{}, {}] plot samples (ONLY COMPUTED) on {}-{} PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    else:\n","        filename = os.path.join(exp_dir, '[{}, {}] plot samples (ONLY COMPUTED) on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","\n","\n","def plot_3D_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, dim_2, iter, normalization='PCA', exp_dir=example_savepath):\n","\n","    target_data = target_data.detach().cpu().numpy()\n","    init_pnts = init_pnts.detach().cpu().numpy()\n","    transported_pnts = transported_pnts.detach().cpu().numpy()\n","\n","    plt.style.use('dark_background')\n","\n","    # 3D plot\n","    fig = plt.figure(figsize=(20, 20))\n","    ax = fig.add_subplot(projection='3d')\n","\n","    
ax.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], target_data[:1024, dim_2], color='magenta', s=10, label='target')\n","    ax.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], transported_pnts[:1024, dim_2], color='deepskyblue', s=10, label='computed')\n","\n","    ax.set_xlabel('dim {}'.format(dim_0))\n","    ax.set_ylabel('dim {}'.format(dim_1))\n","    ax.set_zlabel('dim {}'.format(dim_2))\n","\n","    plt.legend(fontsize=20)\n","    if normalization == 'PCA':\n","        plt.title(\"plot on {}-{}-{} PCA direction\".format(dim_0, dim_1, dim_2), fontsize=40)\n","    else:\n","        plt.title(\"plot on {}-{}-{} dimsional plane\".format(dim_0, dim_1, dim_2), fontsize=40)\n","    if normalization == 'PCA':\n","        filename = os.path.join(exp_dir, 'plot 3D samples on {}-{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n","    else:\n","        filename = os.path.join(exp_dir, 'plot 3D samples on {}-{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","    # 3D plot with OT map\n","    fig = plt.figure(figsize=(20, 20))\n","    ax = fig.add_subplot(projection='3d')\n","\n","    ax.scatter(transported_pnts[:200, dim_0], transported_pnts[:200, dim_1], transported_pnts[:200, dim_2], color='blue', s=5, label='transported')\n","    ax.scatter(init_pnts[:200, dim_0], init_pnts[:200, dim_1], init_pnts[:200, dim_2], color='mediumspringgreen', s=5, label='initial')\n","    for k in range(200):\n","        plt.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], [init_pnts[k, dim_2], transported_pnts[k, dim_2]], c='cyan', alpha=0.5, linewidth=0.8)\n","    plt.legend(fontsize=20)\n","    if normalization == 'PCA':\n","        plt.title(\"plot on {}-{}-{} PCA direction\".format(dim_0, dim_1, dim_2), fontsize=40)\n","    
else:\n","        plt.title(\"plot on {}-{}-{} dimensional plane\".format(dim_0, dim_1, dim_2), fontsize=40)\n","    if normalization == 'PCA':\n","        filename = os.path.join(exp_dir, 'plot 3D OT maps on {}-{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n","    else:\n","        filename = os.path.join(exp_dir, 'plot 3D OT maps on {}-{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n","    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n","    plt.close()\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"TGQ0Vf3ZMis-"},"source":["\n","\n","---\n","\n","## Define the Metric/Distance functions: MMD\n","\n","---\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"E3yQOhJTMis-"},"outputs":[],"source":["# @title MMD\n","\n","# Assume the kernel function is symmetric: k(x1, x2) = k(x2, x1)\n","# MMD(p0, p1) = \\iint k(x1, x2)(p1(x1) - p0(x1))(p1(x2) - p0(x2)) dx1dx2\n","#             = \\iint k(x1, x2)p1(x1)p1(x2) dx1dx2 - 2 \\iint k(x1, x2)p0(x1)p1(x2) dx1dx2 + \\iint k(x1, x2)p0(x1)p0(x2) dx1dx2\n","# try RBF kernel: k(x1, x2) = \\exp( - |x1-x2|^2/(2\\sigma^2))\n","# size of x: N by dim x=[ x_1^t ]\n","#                       [ ..... 
]\n","#                       [ x_N^t ]\n","def MMD(x, y, k_sigma=1):\n","    \"\"\"Biased (V-statistic) estimate of the squared MMD under an RBF kernel.\n","\n","    k(a, b) = exp(-|a - b|^2 / (2 * k_sigma^2)); cf. eq (5) of\n","    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf\n","\n","    Args:\n","        x: (N, dim) tensor of samples from p0.\n","        y: (M, dim) tensor of samples from p1 (same dim as x).\n","        k_sigma: kernel bandwidth.\n","\n","    Returns:\n","        0-dim tensor equal to mean(K_00) - 2 * mean(K_01) + mean(K_11).\n","    \"\"\"\n","    N = x.size()[0]\n","    M = y.size()[0]\n","\n","    # |x_i|^2 and |y_j|^2 as column vectors of shape (N, 1) and (M, 1)\n","    xsqr = torch.sum(x*x, 1).unsqueeze(1)\n","    ysqr = torch.sum(y*y, 1).unsqueeze(1)\n","\n","    def rbf_gram(a, asqr, b, bsqr):\n","        # (i, j) entry: k(a_i, b_j), via |a_i - b_j|^2 = |a_i|^2 + |b_j|^2 - 2 a_i^t b_j.\n","        # That expansion can round to a tiny negative number, so clamp it at 0.\n","        sqdist = asqr + torch.transpose(bsqr, 0, 1) - 2 * torch.matmul(a, torch.transpose(b, 0, 1))\n","        return torch.exp(-sqdist.clamp_min(0) / (2 * k_sigma * k_sigma))\n","\n","    K_00 = rbf_gram(x, xsqr, x, xsqr)  # approximates iint k(x1, x2) p0(x1) p0(x2) dx1 dx2\n","    K_11 = rbf_gram(y, ysqr, y, ysqr)  # approximates iint k(y1, y2) p1(y1) p1(y2) dy1 dy2\n","    K_01 = rbf_gram(x, xsqr, y, ysqr)  # approximates iint k(x, y) p0(x) p1(y) dx dy\n","\n","    # V-statistic: plain averages over all pairs, diagonal included\n","    MMDsqr = torch.sum(K_00)/(N*N) - 2 * torch.sum(K_01)/(N*M) + torch.sum(K_11)/(M*M)\n","\n","    return MMDsqr\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ksvXGVqBMis-"},"outputs":[],"source":["# @title MMD use k(z) = -sqrt{||z||^2+0.001}\n","# k(x, y) = k(x - y)\n","\n","# Assume the kernel function is symmetric: k(x1, x2) = k(x2, x1)\n","# MMD(p0, p1) = \\iint k(x1, x2)(p1(x1) - p0(x1))(p1(x2) - p0(x2)) dx1dx2\n","#             = \\iint k(x1, x2)p1(x1)p1(x2) dx1dx2 - 2 \\iint k(x1, x2)p0(x1)p1(x2) dx1dx2 + \\iint k(x1, x2)p0(x1)p0(x2) dx1dx2\n","# size of x: N by dim x=[ x_1^t ]\n","#                       [ ..... ]\n","#                       [ x_N^t ]\n","\n","\n","def dist_XY(X, Y):\n","    \"\"\"Pairwise smoothed Euclidean distances between the rows of X and Y.\n","\n","    Args:\n","        X: (N, dim) tensor.\n","        Y: (M, dim) tensor.\n","\n","    Returns:\n","        (N, M) tensor whose (i, j) entry is sqrt(||X_i - Y_j||^2 + eps), eps = 0.001.\n","        The eps keeps the gradient of sqrt finite where X_i == Y_j (e.g. the diagonal of dist_XY(x, x)).\n","    \"\"\"\n","    eps = 0.001\n","\n","    # Broadcast (N, 1, dim) - (1, M, dim) -> (N, M, dim) rather than materializing\n","    # repeated (dim, N, M) copies of both inputs.\n","    diff = X.unsqueeze(1) - Y.unsqueeze(0)\n","    return torch.sqrt(torch.sum(diff * diff, 2) + eps)\n","\n","\n","def MMD_negnorm(x, y, k_sigma=1):\n","    \"\"\"V-statistic MMD estimate under the kernel k(z) = -sqrt(||z||^2 + 0.001).\n","\n","    With d = dist_XY, the result is\n","        mean_ij d(x_i, y_j) - mean_ij d(x_i, x_j) / 2 - mean_ij d(y_i, y_j) / 2,\n","    i.e. half the energy distance measured with the smoothed norm.\n","\n","    Args:\n","        x: (M, dim) tensor of samples.\n","        y: (N, dim) tensor of samples.\n","        k_sigma: unused; kept so the signature mirrors MMD.\n","\n","    Returns:\n","        0-dim tensor.\n","    \"\"\"\n","    M = x.size()[0]\n","    N = y.size()[0]\n","\n","    B = -torch.sum(dist_XY(x, y)) / (M * N)  # cross term\n","    A = -torch.sum(dist_XY(x, x)) / (M * M)  # x-x term\n","    C = -torch.sum(dist_XY(y, y)) / (N * N)  # y-y term\n","\n","    return A/2 - B + C/2\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"-BtJMrnUMis-"},"outputs":[],"source":["# @title  sliced MMD 
(func)\n","\n","def slicedMMD(x, y, num):\n","    \"\"\"Sum of RBF-kernel MMDs (see MMD) over consecutive coordinate pairs.\n","\n","    For d = 0, ..., dim-2 it evaluates MMD on the 2-D projections of the first\n","    `num` rows of x and y onto coordinates (d, d+1), and adds the results.\n","    Returns the int 0 when x has fewer than 2 columns.\n","    \"\"\"\n","    pair_mmds = [MMD(x[:num, [d, d+1]], y[:num, [d, d+1]]) for d in range(x.size()[1]-1)]\n","    return sum(pair_mmds)\n","\n"]},{"cell_type":"markdown","metadata":{"id":"dyDUTgV2wfjG"},"source":["\n","\n","---\n","\n","\n","## set up classifier (MNIST)\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MOi-_JgBxu_l"},"outputs":[],"source":["def test_accuracy(classifier, list_of_imgs, target_digit):\n","    \"\"\"Percentage of images that `classifier` assigns to class `target_digit`.\n","\n","    Args:\n","        classifier: callable mapping the image batch to (N, number of classes) scores.\n","        list_of_imgs: batched image tensor; N = list_of_imgs.size()[0].\n","        target_digit: class label every image is expected to receive.\n","\n","    Returns:\n","        (accuracy_rate, predicted): accuracy in percent, and the (N,) tensor of argmax labels.\n","    \"\"\"\n","    with torch.no_grad():\n","        output = classifier(list_of_imgs)\n","        _, predicted = torch.max(output, 1)\n","        total = list_of_imgs.size()[0]\n","        correct = (predicted == target_digit).sum().item()\n","        accuracy_rate = 100 * correct / total\n","    return accuracy_rate, predicted\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"mg_24V7Ip3Q3","colab":{"base_uri":"https://localhost:8080/"},"outputId":"8e013c2c-425c-4857-9ef9-30d267ab1763","executionInfo":{"status":"ok","timestamp":1758769042261,"user_tz":240,"elapsed":3251,"user":{"displayName":"Shu Liu","userId":"06233144459630989636"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["<All keys matched successfully>"]},"metadata":{},"execution_count":49}],"source":["model_mnist = models.resnet18()\n","\n","# ResNet-18 adapted to 28x28 grayscale digits: 1-channel stem and a 10-way head\n","model_mnist.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n","model_mnist.fc =  nn.Linear(in_features=512, out_features=10, bias=True)\n","\n","save_path = os.getcwd()\n","file_name = os.path.join(save_path, 'model_classifier_28_by_28_MNIST.pt')\n","# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only runtime;\n","# load_state_dict copies the tensors onto the model's own (CPU) parameters either way.\n","state_dict = torch.load(file_name, map_location='cpu')\n","# strict=True: a partially matching checkpoint would otherwise leave layers randomly\n","# initialised and silently invalidate every classifier accuracy reported later.\n","model_mnist.load_state_dict(state_dict, strict=True)"]},{"cell_type":"markdown","metadata":{"id":"lEsefj0Ff559"},"source":["\n","\n","---\n","## Main algorithm 
(OTHJ)\n","---\n","\n"]},{"cell_type":"code","source":["# @title define the EMA\n","\n","\n","import copy\n","\n","class EMA:\n","    def __init__(self, model, beta=0.99):\n","        self.beta = beta\n","        self.shadow = copy.deepcopy(model).state_dict()\n","        for k in self.shadow.keys():\n","            self.shadow[k].requires_grad = False  # EMA is not trainable\n","\n","    def update(self, model):\n","        with torch.no_grad():\n","            model_state = model.state_dict()\n","            for k, v in model_state.items():\n","                if v.dtype.is_floating_point:  # only update float tensors\n","                    self.shadow[k] = self.beta * self.shadow[k] + (1 - self.beta) * v\n","\n","    def apply_shadow(self, model):\n","        model.load_state_dict(self.shadow)\n","\n"],"metadata":{"cellView":"form","id":"Lg3jcTYKyrFL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title OT HJ implicit solver\n","import torch\n","import numpy as np\n","import torch.nn.functional\n","import math\n","import os\n","from datetime import datetime\n","from models_Resnet import gradient, ImplicitNet\n","import utils.general as utils\n","import matplotlib.pyplot as plt\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","\n","normalization_type = 'std'\n","\n","effective_dim = effective_dim_using_std0\n","init_data_entire = init_data_gpu_std\n","init_data_list = init_data_list_std\n","cov_init = cov_init_data_std\n","mean_init = mean_init_data_std\n","target_data_entire = target_data_gpu_std\n","target_data_list = target_data_list_std\n","cov_target = cov_target_data_std\n","mean_target = mean_target_data_std\n","init_data_list_for_test = init_data_list_std_for_test\n","target_data_list_for_test = target_data_list_std_for_test\n","\n","def random_sampler(N, dim=effective_dim, T=1, fix_T=False, T0=0): # T0: 0 or T\n","    if fix_T:\n","        ts = T0 * torch.ones(N,1)\n","    else:\n","        ts = 
T * torch.rand(N,1)\n","    xs = torch.randn(N, dim)\n","    pnts = torch.cat((ts, xs), 1)\n","    return torch.tensor(pnts, dtype=torch.float32, requires_grad=True)\n","\n","## -------------------------------------------------------------------------------\n","## Configurations\n","## -------------------------------------------------------------------------------\n","gpu_id = 0\n","dim_latent = Dim\n","\n","\n","###############################################################################\n","problem_initial = \"class_OT_FashionMNIST\"\n","problem_initial = problem_initial.lower()\n","###############################################################################\n","\n","###############################################################################\n","problem = \"class_OT_MNIST\"\n","problem = problem.lower()\n","###############################################################################\n","\n","\n","regularizer_type = ['implicithjxt0t', 'mmd_negnorm_0t']\n","regularizer_coord = [1, 10000]  # [1, 800]\n","\n","with_OTloss = False\n","weight_OTloss = 0.01\n","\n","iter_0 = 0 # starting point\n","N = 2000 # collocation pnts for HJ\n","batch_size = 2000 # batch size for  sample transport\n","if target_data_list[0].shape[0] < batch_size:\n","    batch_size = target_data_list[0].shape[0]\n","\n","epochs = 70000\n","val_frequency =  2000\n","plot_detail_ot =  epochs\n","plot_detail_mnist =  10000\n","\n","num_mnist_digits_to_plt = 10\n","\n","dim = effective_dim # dimension of spatial domain\n","T = 1 # terminal time\n","\n","fixing_T_and_0_HJ_loss = False\n","\n","batch_size_OT = 4000\n","sub_batch_size = 400\n","\n","NN_dims = [128, 128, 128, 128, 128]\n","network_sol = ImplicitNet(d_in=dim+1, dims=NN_dims).to(device)\n","ema = EMA(network_sol, beta=0.99)\n","optimizer = torch.optim.Adam(params=network_sol.parameters(), lr=5 * 1E-4)\n","scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5*1E-4, gamma=0.99)\n","\n","network_sol_EMA = 
ImplicitNet(d_in=dim+1, dims=NN_dims).to(device)\n","\n","loss_min = 1E10\n","\n","save_model = True\n","\n","Plot_in_latent = True # False\n","Generate_images = True # False\n","\n","exp_dir = os.path.join(example_savepath, 'Classed_OT: {} to {}'.format(problem_initial, problem))\n","utils.mkdir_ifnotexists(exp_dir)\n","\n","# If need to load existing model, uncomment the following two lines:\n","# state_dict = torch.load(os.path.join(exp_dir, 'weight_sol_ep{}.pth'.format(iter_0)))\n","# network_sol.load_state_dict(state_dict[\"state_dict\"], strict=False)\n","\n","timestamp = 'Normalization_type={}, batchsize={}, NNsize={}, HJ_loss_num_colocation_pnts={}  {}'.format(normalization_type, batch_size, NN_dims, N, datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))\n","exp_dir = os.path.join(exp_dir, timestamp)\n","utils.mkdir_ifnotexists(exp_dir)\n","regularizer_type = list(map(str.lower,regularizer_type))\n","assert len(regularizer_type) == len(regularizer_coord), 'match regularizer coordinates'\n","regularizer_index = {t:i for i, t in enumerate(regularizer_type)}\n","\n","if save_model:\n","    utils.mkdir_ifnotexists(os.path.join(exp_dir,'model'))\n","if torch.cuda.is_available() and gpu_id > -1:\n","    device = torch.device(gpu_id)\n","else:\n","    device = torch.device('cpu')\n","utils.set_random_seed(5884)\n","\n","MMD_loss_list_0T = []\n","MMD_loss_list_T0 = []\n","slicedMMD_loss_list = []\n","OT_dist_0T_list = []\n","OT_dist_T0_list = []\n","FID_list_0T = []\n","FID_list_T0 = []\n","Implicit_HJ_loss_list = []\n","accuracy_0T_list = []\n","accuracy_0T_EMA_list = []\n","for epoch in range(iter_0, epochs+iter_0):\n","    if epoch % 50 == 0:\n","        print(\"Iteration: {}\".format(epoch))\n","    ## -------------------------------------------------------------------------------\n","    ## Compute losses\n","    ## -------------------------------------------------------------------------------\n","    loss = torch.tensor(0.).to(device)\n","    losses = {'Train 
loss' : None}\n","\n","    if 'implicithjxt0t' in regularizer_type:\n","        if fixing_T_and_0_HJ_loss:\n","            pnts = random_sampler(N, fix_T=True, T0=T).to(device)\n","        else:\n","            pnts = random_sampler(N, dim, T=T).to(device)\n","        pred_sol = network_sol(pnts)\n","        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n","        init_x = pnts[:,1:] - pnts[:,[0]]*grad_pred_sol\n","        init_xt = torch.cat((torch.zeros((N,1)).to(device), init_x), 1)\n","        fwd_loss_implicithj = ((pred_sol - 0.5*pnts[:,[0]]*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(init_xt))**2).mean()\n","\n","        if fixing_T_and_0_HJ_loss:\n","            pnts = random_sampler(N, fix_T=True, T0=0).to(device)\n","        else:\n","            pnts = random_sampler(N, dim, T=T).to(device)\n","        pred_sol = network_sol(pnts)\n","        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n","        terminal_x = pnts[:,1:] + (T - pnts[:,[0]])*grad_pred_sol\n","        terminal_xt = torch.cat((T * torch.ones((N,1)).to(device), terminal_x), 1)\n","        bckwd_loss_implicithj = ((pred_sol + 0.5*(T - pnts[:,[0]])*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(terminal_xt))**2).mean()\n","\n","        losses['Implicit HJ loss'] = fwd_loss_implicithj.item() + bckwd_loss_implicithj.item()\n","        loss_implicithj = fwd_loss_implicithj + bckwd_loss_implicithj\n","        loss += regularizer_coord[regularizer_index['implicithjxt0t']] * loss_implicithj\n","        Implicit_HJ_loss_list.append(loss_implicithj.cpu().detach().numpy())\n","\n","    if 'mmd_negnorm_0t' in regularizer_type:\n","        for i in range(10):\n","            init_data = init_data_list[i]\n","            target_data = target_data_list[i]\n","            init_data_indices_for_0T = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n","            init_data_pnts_for_0T = 
torch.tensor(init_data[init_data_indices_for_0T,:], dtype=torch.float32, requires_grad=True).to(device)\n","            init_spatialtemporal_pnts_for_0T = torch.tensor(torch.cat((torch.zeros((sub_batch_size,1)).to(device), init_data_pnts_for_0T), 1), requires_grad=True)\n","            pred_sol_for_0T = network_sol(init_spatialtemporal_pnts_for_0T)\n","            transported_pnts_to_T = init_spatialtemporal_pnts_for_0T[:, 1:] + T * gradient(init_spatialtemporal_pnts_for_0T, pred_sol_for_0T)[:,1:]\n","            data_indices_for_0T = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n","            data_pnts_for_0T = torch.tensor(target_data[data_indices_for_0T,:],dtype=torch.float32, requires_grad=True).to(device)\n","            loss_MMD_T = MMD_negnorm(transported_pnts_to_T, data_pnts_for_0T)\n","            if epoch % 10 == 0:\n","                MMD_loss_list_0T.append(loss_MMD_T.item())\n","\n","            data_indices_for_T0 = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n","            data_pnts_for_T0 = torch.tensor(target_data[data_indices_for_T0,:], dtype=torch.float32, requires_grad=True).to(device)\n","            spatialtemporal_pnts = torch.tensor(torch.cat((T * torch.ones((sub_batch_size,1)).to(device), data_pnts_for_T0), 1), requires_grad=True)\n","            pred_sol_for_T0 = network_sol(spatialtemporal_pnts)\n","            transported_pnts_to_0 = spatialtemporal_pnts[:,1:] - T * gradient(spatialtemporal_pnts, pred_sol_for_T0)[:,1:]\n","            init_data_indices_for_T0 = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n","            init_data_pnts_for_T0 = torch.tensor(init_data[init_data_indices_for_T0,:],dtype=torch.float32, requires_grad=True).to(device)\n","            loss_MMD_0 = MMD_negnorm(transported_pnts_to_0, init_data_pnts_for_T0)\n","            if epoch % 10 == 0:\n","                MMD_loss_list_T0.append(loss_MMD_0.item())\n","\n","         
   loss_MMD = loss_MMD_0 + loss_MMD_T\n","            loss += regularizer_coord[regularizer_index['mmd_negnorm_0t']] * loss_MMD\n","\n","    ## -------------------------------------------------------------------------------\n","    ## Update parameters\n","    ## -------------------------------------------------------------------------------\n","    optimizer.zero_grad()\n","    loss.backward()\n","    optimizer.step()\n","    ema.update(network_sol)\n","\n","    # ## -------------------------------------------------------------------------------\n","    # ## Comput OT distance. T * \\int |\\nabla u(x)|^2/2 \\rho_0(x) dx\n","    # ## -------------------------------------------------------------------------------\n","    # OT_loss_0T = torch.tensor([0.0]).to(device)\n","    # for i in range(5):\n","    #     init_data = init_data_list[i]\n","    #     init_data_indices_for_0T = torch.tensor(np.random.choice(init_data.shape[0], batch_size_OT, False))\n","    #     init_data_pnts_for_0T = torch.tensor(init_data[init_data_indices_for_0T,:], dtype=torch.float32, requires_grad=True)\n","    #     init_spatialtemporal_pnts_for_0T = torch.tensor(torch.cat((torch.zeros((batch_size_OT,1)).to(device), init_data_pnts_for_0T), 1), requires_grad=True).to(device)\n","    #     pred_sol_for_0T = network_sol(init_spatialtemporal_pnts_for_0T)\n","    #     displacement_0T = gradient(init_spatialtemporal_pnts_for_0T, pred_sol_for_0T)[:,1:]\n","    #     OT_loss_0T = OT_loss_0T + T * torch.sum(displacement_0T * displacement_0T, 1).mean() / 2\n","    # OT_dist_0T_list.append(OT_loss_0T.cpu().detach())\n","    # if epoch % 100 == 0:\n","    #     print(\"OT distance from 0 to T: {}\".format(OT_loss_0T))\n","\n","    # OT_loss_T0 = torch.tensor([0.0]).to(device)\n","    # for i in range(5):\n","    #     target_data = target_data_list[i]\n","    #     data_indices_for_T0 = torch.tensor(np.random.choice(target_data.shape[0], batch_size_OT, False))\n","    #     data_pnts_for_T0 = 
torch.tensor(target_data[data_indices_for_T0,:], dtype=torch.float32, requires_grad=True)\n","    #     spatialtemporal_pnts = torch.tensor(torch.cat((T * torch.ones((batch_size_OT,1)).to(device), data_pnts_for_T0), 1), requires_grad=True).to(device)\n","    #     pred_sol_for_T0 = network_sol(spatialtemporal_pnts)\n","    #     displacement_T0 = gradient(spatialtemporal_pnts, pred_sol_for_T0)[:,1:]\n","    #     OT_loss_T0 = OT_loss_T0 + T * torch.sum(displacement_T0 * displacement_T0, 1).mean() / 2\n","    # OT_dist_T0_list.append(OT_loss_T0.cpu().detach())\n","    # if epoch % 100 == 0:\n","    #     print(\"OT distance from T to 0: {}\".format(OT_loss_T0))\n","\n","    ## -------------------------------------------------------------------------------\n","    ## Validation (on EMA model)\n","    ## -------------------------------------------------------------------------------\n","    # load the EMA params to model\n","    ema.apply_shadow(network_sol_EMA)\n","    if epoch % val_frequency == 0:\n","        # Check accucy on test set (OT:0 ------> T)\n","        plot_latent_samples_list = []\n","        for i in range(10):\n","            init_data = init_data_list_for_test[i]\n","            sample_size = init_data.size()[0]\n","            init_data_pnts = init_data[:sample_size, :].detach().clone().requires_grad_(True)\n","            init_spatialtemporal_pnts = torch.cat((torch.zeros((sample_size,1)).to(device), init_data_pnts), 1).detach().clone().requires_grad_(True).to(device)\n","            pred_sol = network_sol_EMA(init_spatialtemporal_pnts)\n","            transported_pnts_to_T = init_spatialtemporal_pnts[:, 1:] + T * gradient(init_spatialtemporal_pnts, pred_sol)[:,1:]\n","            if normalization_type == 'PCA':\n","                recovered_generated_target_latent_samples = recover_latent_samples_using_PCA1(transported_pnts_to_T)\n","            else:\n","                recovered_generated_target_latent_samples = 
recover_latent_samples_using_std1(transported_pnts_to_T)\n","            plot_latent_samples_list.append(recovered_generated_target_latent_samples)\n","\n","        plot_img_list = []\n","        for latent_sample in plot_latent_samples_list:\n","            decoded_generated_target = trained_decoder1(latent_sample.cpu().detach().numpy())\n","            plot_img_list.append(decoded_generated_target)\n","\n","        # check accuracy\n","        ave_accuracy = 0.0\n","        model_mnist.eval()\n","        for idx in range(10):\n","            reconstructed_img_vae = plot_img_list[idx]\n","            reconstructed_img_vae = 2.0 * reconstructed_img_vae - 1.0\n","            reconstructed_img_vae = np.array(reconstructed_img_vae)\n","            digit_idx = torch.from_numpy(reconstructed_img_vae)\n","            digit_idx = digit_idx.squeeze()\n","            digit_idx = digit_idx.unsqueeze(1)\n","            accuracy, predicted = test_accuracy(model_mnist, digit_idx, idx)\n","            ave_accuracy += accuracy\n","        ave_accuracy /= 10\n","        print(\"=================================================================================================\")\n","        print(\"EMA: Epoch={} Average accuracy on transporting fMNIST to MNIST [on test dataset]: {}\".format(epoch, ave_accuracy))\n","        accuracy_0T_EMA_list.append(ave_accuracy)\n","        print(\"=================================================================================================\")\n","\n","    ## -------------------------------------------------------------------------------\n","    ## Validation (on original model)\n","    ## -------------------------------------------------------------------------------\n","    if epoch % val_frequency == 0:\n","\n","        # Check accucy on test set (OT:0 ------> T)\n","        plot_latent_samples_list = []\n","        for i in range(10):\n","            init_data = init_data_list_for_test[i]\n","            sample_size = 
init_data.size()[0]\n","            init_data_pnts = init_data[:sample_size, :].detach().clone().requires_grad_(True)\n","            init_spatialtemporal_pnts = torch.cat((torch.zeros((sample_size,1)).to(device), init_data_pnts), 1).detach().clone().requires_grad_(True).to(device)\n","            pred_sol = network_sol(init_spatialtemporal_pnts)\n","            transported_pnts_to_T = init_spatialtemporal_pnts[:, 1:] + T * gradient(init_spatialtemporal_pnts, pred_sol)[:,1:]\n","            if normalization_type == 'PCA':\n","                recovered_generated_target_latent_samples = recover_latent_samples_using_PCA1(transported_pnts_to_T)\n","            else:\n","                recovered_generated_target_latent_samples = recover_latent_samples_using_std1(transported_pnts_to_T)\n","            plot_latent_samples_list.append(recovered_generated_target_latent_samples)\n","\n","        plot_img_list = []\n","        for latent_sample in plot_latent_samples_list:\n","            decoded_generated_target = trained_decoder1(latent_sample.cpu().detach().numpy())\n","            plot_img_list.append(decoded_generated_target)\n","\n","        # check accuracy\n","        ave_accuracy = 0.0\n","        model_mnist.eval()\n","        for idx in range(10):\n","            reconstructed_img_vae = plot_img_list[idx]\n","            reconstructed_img_vae = 2.0 * reconstructed_img_vae - 1.0\n","            reconstructed_img_vae = np.array(reconstructed_img_vae)\n","            digit_idx = torch.from_numpy(reconstructed_img_vae)\n","            digit_idx = digit_idx.squeeze()\n","            digit_idx = digit_idx.unsqueeze(1)\n","            accuracy, predicted = test_accuracy(model_mnist, digit_idx, idx)\n","            ave_accuracy += accuracy\n","        ave_accuracy /= 10\n","        print(\"=================================================================================================\")\n","        print(\"Original model: Epoch={} Average accuracy on transporting fMNIST 
to MNIST [on test dataset]: {}\".format(epoch, ave_accuracy))\n","        accuracy_0T_list.append(ave_accuracy)\n","        print(\"=================================================================================================\")\n","\n","        # Some plots (OT:0 ------> T)\n","        plot_latent_samples_list = []\n","        for i in range(10):\n","            init_data = init_data_list[i]\n","            sample_size = 1000\n","            init_data_pnts = init_data[:sample_size, :].detach().clone().requires_grad_(True)\n","            init_spatialtemporal_pnts = torch.cat((torch.zeros((sample_size,1)).to(device), init_data_pnts), 1).detach().clone().requires_grad_(True).to(device)\n","            pred_sol = network_sol(init_spatialtemporal_pnts)\n","            transported_pnts_to_T = init_spatialtemporal_pnts[:, 1:] + T * gradient(init_spatialtemporal_pnts, pred_sol)[:,1:]\n","            if normalization_type == 'PCA':\n","                recovered_generated_target_latent_samples = recover_latent_samples_using_PCA1(transported_pnts_to_T)\n","            else:\n","                recovered_generated_target_latent_samples = recover_latent_samples_using_std1(transported_pnts_to_T)\n","            plot_latent_samples_list.append(recovered_generated_target_latent_samples)\n","\n","            if Plot_in_latent:\n","                target_data_for_test = target_data_list_for_test[i]\n","                data_pnts = target_data_for_test\n","                # plot in first 6 directions\n","                if i < 9:\n","                    digits_string = '{}{}'.format(i, i+1)\n","                else:\n","                    digits_string = '{}{}'.format(i, 0)\n","                if (epoch+1) % plot_detail_ot == 0:\n","                    for index in range(6):\n","                        plot_latent_samples_n_OT_map(torch.tensor(data_pnts), torch.tensor(init_data_pnts), transported_pnts_to_T, index, index+1, epoch, normalization_type, '0_to_T', digits_string, 
exp_dir)\n","\n","        plot_img_list = []\n","        for latent_sample in plot_latent_samples_list:\n","            decoded_generated_target = trained_decoder1(latent_sample.cpu().detach().numpy())\n","            plot_img_list.append(decoded_generated_target)\n","        if Generate_images:\n","            plot_mnist_init_target(plot_img_list[0], plot_img_list[1], plot_img_list[2], plot_img_list[3], plot_img_list[4], plot_img_list[5], plot_img_list[6], plot_img_list[7], plot_img_list[8], plot_img_list[9], epoch, 10, \"[0 to T] generated {} digits (conditioned on class, iter {})\".format(rho1_dataset, iter_0), '0_to_T', exp_dir=exp_dir)\n","\n","\n","    # save accuracy list\n","    if (epoch+1) % plot_detail_mnist == 0:\n","        np.save(os.path.join(exp_dir, 'accuracy_0T_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(accuracy_0T_list))\n","        np.save(os.path.join(exp_dir, 'accuracy_0T_EMA_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(accuracy_0T_EMA_list))\n","        if save_model:\n","            torch.save({'state_dict': network_sol.state_dict(),}, os.path.join(exp_dir, f'model/weight_sol_ep{epoch+1}.pth'))\n","            torch.save({'state_dict': network_sol_EMA.state_dict(),}, os.path.join(exp_dir, f'model/weight_sol_EMA_ep{epoch+1}.pth'))\n","\n","\n","# save recorded data\n","np.save(os.path.join(exp_dir, 'accuracy_0T_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(accuracy_0T_list))\n","np.save(os.path.join(exp_dir, 'accuracy_0T_EMA_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(accuracy_0T_EMA_list))\n","\n","np.save(os.path.join(exp_dir, 'MMD_loss_list_0T_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(MMD_loss_list_0T))\n","np.save(os.path.join(exp_dir, 'MMD_loss_list_T0_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(MMD_loss_list_T0))\n","\n","np.save(os.path.join(exp_dir, 'OT_dist_0T_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), 
np.array(OT_dist_0T_list))\n","np.save(os.path.join(exp_dir, 'OT_dist_T0_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(OT_dist_T0_list))\n","\n","np.save(os.path.join(exp_dir, 'Implicit_HJ_loss_list_epoch_from_{}_to_{}.npy').format(iter_0, epoch+1), np.array(Implicit_HJ_loss_list))\n","\n","# plot\n","plt.figure(figsize=(10,10))\n","plt.plot(MMD_loss_list_0T, label=\"OT from 0 to T\")\n","plt.plot(MMD_loss_list_T0, label=\"OT from T to 0\")\n","plt.legend( fontsize = 12 )\n","plt.title(\"MMD - iter\", fontsize=18)\n","plt.savefig(os.path.join(exp_dir,f'MMD_loss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","plt.close()\n","\n","# plt.figure(figsize=(10,10))\n","# plt.plot(torch.tensor(OT_dist_0T_list).cpu())\n","# plt.title(\"OT_dist_0toT - iter\", fontsize=18)\n","# plt.savefig(os.path.join(exp_dir,f'OT_dist_0T_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","# plt.close()\n","\n","# plt.figure(figsize=(10,10))\n","# plt.plot(torch.tensor(OT_dist_T0_list).cpu())\n","# plt.title(\"OT_dist_Tto0 - iter\", fontsize=18)\n","# plt.savefig(os.path.join(exp_dir,f'OT_dist_T0_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","# plt.close()\n","\n","plt.figure(figsize=(10,10))\n","plt.plot(Implicit_HJ_loss_list)\n","plt.title(\"Implicit HJ loss - iter\", fontsize=18)\n","plt.savefig(os.path.join(exp_dir,f'HJloss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n","plt.close()\n","\n","plt.figure(figsize=(10,10))\n","plt.plot(accuracy_0T_list, label='original trained NN')\n","plt.plot(accuracy_0T_EMA_list, label='EMA')\n","plt.title(\"Average_accuracy_for_OT\", fontsize=18)\n","plt.legend(fontsize=15)\n","plt.savefig(os.path.join(exp_dir,f'accuracy'), dpi=200, bbox_inches='tight', 
pad_inches=0)\n","plt.close()\n"],"metadata":{"id":"PhzfQuI_1pmk","cellView":"form"},"execution_count":null,"outputs":[]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["L961ddlbCWcl","AyQbnwjpGaLb","FVqh5rEVMis9","PsJ3AVfSbIpf","ng6raSYzMis_","TGQ0Vf3ZMis-","dyDUTgV2wfjG"],"gpuType":"T4","machine_shape":"hm","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}