{"cells":[{"cell_type":"markdown","source":["\n","The code is composed on Google Colab.\n","\n","\n","This Python notebook is devoted to task 2: In-class shift on MNIST dataset.\n","\n","\n"],"metadata":{"id":"nQn1ZuaVGiUB"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1MibEOp7kTRn","outputId":"1ba82d9e-a583-4960-b551-08a5defd1005"},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ul47PBw-QQ18","outputId":"7d654eb4-5116-425d-c2fd-acfeaf1d7ff3"},"outputs":[{"output_type":"stream","name":"stdout","text":["/content/drive/MyDrive/OTHJ_FashionMNIST_MNIST\n"]}],"source":["cd/content/drive/MyDrive/OTHJ_FashionMNIST_MNIST"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"cEiUihg03b8r"},"outputs":[],"source":["# @title imports\n","import os, sys\n","import numpy as np\n","import tensorflow as tf\n","import torch\n","import torchvision\n","from torchvision import transforms, datasets\n","import torchvision.models as models\n","from torchvision.transforms import Compose, Resize, Normalize, ToTensor\n","from torch.utils.data.sampler import SubsetRandomSampler\n","import matplotlib.pyplot as plt\n","import pickle\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","from matplotlib.colors import ListedColormap\n","import seaborn as sns\n","import random\n","from PIL import Image\n","\n","from utils.general import mkdir_ifnotexists\n","\n","# Keras - Deep Learning API\n","import keras\n","from keras.datasets import mnist\n","from keras.datasets import fashion_mnist\n","from keras.layers import (\n","    Conv2D, Conv2DTranspose,\n","    Input, Flatten, Dense,\n","    Lambda, 
Reshape\n",")\n","from keras.models import Model\n","from keras.callbacks import (\n","    EarlyStopping, ModelCheckpoint\n",")\n","from keras import backend as K\n","from sklearn.manifold import TSNE\n","from keras.metrics import MeanSquaredError\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n"]},{"cell_type":"code","source":["# @title setup basic parameter\n","\n","dir = os.getcwd()\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","Dim = 10\n","beta_recon = 100.\n","beta_KL = 0.1\n","model_and_data_savepath = 'model_data_d{}_beta100_01'.format(Dim)\n","model_name = os.path.join(model_and_data_savepath, 'best_model_dim_{}_mnist.h5'.format(Dim))\n","example_savepath = 'example_d{}_beta100_01'.format(Dim)"],"metadata":{"id":"VD-lNU7RlpJO","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"L961ddlbCWcl"},"source":["\n","\n","---\n","\n","## Load the decoder-encoder, compute and save latent data (new)\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"9iQoJ28TCWcm"},"source":["\n","VAE (using cnn)\n","\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MFxWd7lfCWcm","cellView":"form"},"outputs":[],"source":["# @title VAE Loss\n","\n","import keras\n","from keras import layers\n","from keras import backend as K\n","\n","class VAELossLayer(layers.Layer):\n","    \"\"\"\n","    Custom Keras layer that calculates the loss (reconstruction + KL divergence)\n","    of the Variational AutoEncoder (VAE)\n","    \"\"\"\n","\n","    def __init__(self, beta=100, **kwargs):\n","        super().__init__(**kwargs)\n","        self.beta = beta\n","\n","    def compute_output_shape(self, input_shape):\n","        return input_shape[0]  # or return ()\n","\n","    def calculate_loss(self, original_input, reconstructed_output, mu, sigma):\n","        \"\"\"\n","        Calculates VAE loss, which is the sum of the reconstruction loss and 
KL-divergence loss\n","        \"\"\"\n","        original_input = tf.keras.layers.Flatten(data_format='channels_last')(original_input)\n","\n","        reconstructed_output = tf.keras.layers.Flatten(data_format='channels_last')(reconstructed_output)\n","        print(original_input.shape) #(32, 784)\n","        print(reconstructed_output.shape) #(32, 784)\n","        reconstruction_loss = tf.reduce_sum(tf.square(original_input - reconstructed_output), axis=[1])\n","        reconstruction_loss = self.beta * reconstruction_loss\n","        kl_loss = - tf.reduce_mean(1 + sigma - tf.square(mu) - tf.exp(sigma), axis=-1)\n","        return tf.reduce_mean(reconstruction_loss + kl_loss)\n","\n","    def call(self, inputs):\n","        \"\"\"\n","        Computes the loss and adds it to the layer's losses\n","        \"\"\"\n","        original_input, reconstructed_output, mu, sigma = inputs\n","\n","        loss = self.calculate_loss(original_input, reconstructed_output, mu, sigma)\n","        self.add_loss(loss)\n","        return original_input\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","collapsed":true,"id":"EgvxXsOACWcm"},"outputs":[],"source":["# @title Encoder\n","latent_space_dim = Dim\n","\n","input_dimensions = (28, 28, 1)\n","\n","# Input Layer\n","# Defines the shape of the input data for the encoder.\n","encoder_input_layer = Input(shape=input_dimensions, name='encoder_input_layer')\n","\n","# Convolution Layers\n","# Applies convolution operations to extract features from the input image.\n","encoder_layer = Conv2D(128, 5, padding='same', activation='relu')(encoder_input_layer)\n","encoder_layer = Conv2D(128, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', strides=2, activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', 
activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","encoder_layer = Conv2D(64, 3, padding='same', activation='relu')(encoder_layer)\n","\n","enc_shape = tf.keras.backend.int_shape(encoder_layer)\n","\n","# Flattening Layer\n","# Converts the 3D output of convolution layers into a 1D tensor for dense layers.\n","encoder_layer = Flatten()(encoder_layer)\n","\n","# Dense Layer\n","# A fully connected layer that combines extracted features and performs further learning.\n","encoder_layer = Dense(64, activation='relu')(encoder_layer)\n","\n","# Output Layer for Encoder\n","# Prepares the encoded data for transition into the latent space, approximating a probability distribution.\n","mu_layer = Dense(latent_space_dim, name='latent_mu')(encoder_layer)\n","sigma_layer = Dense(latent_space_dim, name='latent_sigma')(encoder_layer)\n","\n","# Storing the output shape for use in the decoder\n","encoder_output_shape = (None, 64) # K.int_shape(encoder_layer)\n","print(f\"Output shape of encoder: {encoder_output_shape}\")\n","\n","\n","### Implementing the Reparameterization Trick\n","def sample_z(args):\n","    \"\"\"\n","    Generate a sample from the Gaussian distribution defined by args=(mu, sigma).\n","\n","    Args:\n","    mu_layer:    The mean of the Gaussian distribution.\n","    sigma_layer: The log standard deviation of the Gaussian distribution.\n","\n","    Returns:\n","    A sample from the Gaussian distribution.\n","    \"\"\"\n","    mu_layer, sigma_layer = args\n","    batch_size = tf.shape(mu_layer)[0]\n","    dim = tf.shape(mu_layer)[1]\n","\n","    # Generate a random sample from a standard normal distribution with the same shape\n","    epsilon = tf.random.normal(shape=(batch_size, dim))\n","\n","    # Scale and shift the sample by mu and sigma\n","    return mu_layer + tf.exp(sigma_layer / 2) * epsilon\n","\n","# Creating the 'z' layer using the Lambda layer to apply the reparameterization 
trick\n","z = Lambda(sample_z, output_shape=(latent_space_dim,), name='z')([mu_layer, sigma_layer])\n","\n","# Building the encoder model\n","encoder_model = Model(encoder_input_layer, [mu_layer, sigma_layer, z], name='encoder_model')\n","\n","# Display the model summary\n","print(encoder_model.summary())\n"]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"ZY3QQRABCWcm","cellView":"form"},"outputs":[],"source":["# @title Decoder\n","\n","# Decoder Input Layer\n","decoder_input_layer = Input(shape=(latent_space_dim,), name='decoder_input_layer')\n","\n","# Initial Dense Layer\n","# The number of units is derived from the last convolutional layer of the encoder\n","num_units = np.prod(enc_shape[1:]) # 14 * 14 * 64\n","decoder_dense_layer = Dense(num_units, activation='relu')(decoder_input_layer)\n","\n","# Reshaping Layer\n","# The dense layer's output is reshaped to match the last convolutional layer's output shape in the encoder\n","reshape_dims = (14, 14, 64)\n","decoder_layer = Reshape(reshape_dims)(decoder_dense_layer)\n","\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 3, padding='same', activation='relu')(decoder_layer)\n","decoder_layer = Conv2DTranspose(128, 5, padding='same', activation='relu')(decoder_layer)\n","decoder_out_layer = Conv2DTranspose(1, 5, padding='same', activation='relu')(decoder_layer)\n","\n","# Building the Decoder Model\n","decoder_model = Model(decoder_input_layer, decoder_out_layer, 
name='decoder_model')\n","\n","decoder_model.summary()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"0sJfV7oPCWcm","outputId":"abe0fd3c-cb13-45de-9363-c4d949f13072"},"outputs":[{"output_type":"stream","name":"stdout","text":["The type  of 'reconstructed_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'reconstructed_output' is: (None, 28, 28, 1)\n","The type of 'vae_loss_output' is: <class 'keras.src.backend.common.keras_tensor.KerasTensor'>\n","The shape of 'vae_loss_output' is: (None, 28, 28, 1)\n"]}],"source":["# @title set VAE model\n","from tensorflow.keras.models import load_model\n","\n","\n","# Reconstructed Output from Decoder\n","reconstructed_output = decoder_model(z)\n","\n","reconstructed_model = Model(encoder_input_layer, reconstructed_output, name='reconstructed_model')\n","\n","print(f\"The type  of 'reconstructed_output' is: {type(reconstructed_output)}\")\n","print(f\"The shape of 'reconstructed_output' is: {reconstructed_output.shape}\")\n","\n","\n","# Adding the loss computation layer to the model\n","vae_loss_output = VAELossLayer()([encoder_input_layer, reconstructed_output, mu_layer, sigma_layer])\n","\n","print(f\"The type of 'vae_loss_output' is: {type(vae_loss_output)}\")\n","print(f\"The shape of 'vae_loss_output' is: {vae_loss_output.shape}\")\n","\n","# loaded_model = keras.saving.load_model(\"best_model_dim_15.h5\")\n","final_vae_model = Model(encoder_input_layer, vae_loss_output, name='vae_model')\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"9BbmQa7eCWcn","outputId":"4ca5815f-77b7-406f-cd57-e210e39278ff","cellView":"form"},"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. 
`model.compile_metrics` will be empty until you train or evaluate the model.\n"]}],"source":["# @title load trained encoder and decoder\n","\n","from tensorflow.keras.models import load_model\n","\n","trained_model = load_model(model_name, custom_objects={\"sample_z\": sample_z, \"VAELossLayer\": VAELossLayer})\n","\n","encoder_input = trained_model.input\n","\n","# Encoder_mu\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder_input = trained_model.input\n","mu_layer = trained_model.get_layer('latent_mu')\n","mu_output = mu_layer.output\n","# Recreate the encoder model\n","encoder_mu = Model(inputs=encoder_input, outputs=mu_output)\n","\n","# Encoder_sigma\n","# Assume 'encoder_input' is the first input and 'z_mean' is an intermediate layer\n","encoder_input = trained_model.input\n","sigma_layer = trained_model.get_layer('latent_sigma')\n","sigma_output = sigma_layer.output\n","# Recreate the encoder model\n","encoder_sigma = Model(inputs=encoder_input, outputs=sigma_output)\n","\n","# Decoder\n","trained_decoder = trained_model.get_layer('decoder_model')\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"8qxGz_C5zkJb","cellView":"form"},"outputs":[],"source":["# @title Load MNIST dataset & encode\n","\n","\n","# load the fashion mnist data\n","(x_train, y_train), (x_test, y_test) = mnist.load_data()\n","concatenated_x = tf.concat([x_train, x_test], axis=0)\n","concatenated_y = tf.concat([y_train, y_test], axis=0)\n","\n","list_of_images = []\n","\n","for i in range(10):\n","  list_of_images.append(concatenated_x[concatenated_y == i]) # (7000, 28, 28)\n","\n","for i in range(10):\n","    data = list_of_images[i].numpy()/255.\n","    data = tf.reshape(data, (list_of_images[i].shape[0], 28, 28, 1))\n","    encoded_mu_all = encoder_mu.predict(data)\n","    encoded_sigma_all = encoder_sigma.predict(data)\n","    epsilon_all = tf.random.normal(shape=(data.shape[0], 
latent_space_dim))\n","    encoded_images = encoded_mu_all + tf.exp(encoded_sigma_all / 2) * epsilon_all\n","    numpy_encoded_images = encoded_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_{}.npy'.format(i, 'mnist'))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded_images)\n","\n","\n","list_of_images_for_test = []\n","for i in range(10):\n","    list_of_images_for_test.append(x_test[y_test == i]) # (7000, 28, 28)\n","\n","for i in range(10):\n","    data = list_of_images_for_test[i]/255.\n","    data = tf.reshape(data, (list_of_images_for_test[i].shape[0], 28, 28, 1))\n","    encoded1_mu_all = encoder_mu.predict(data)\n","    encoded1_sigma_all = encoder_sigma.predict(data)\n","    epsilon1_all = tf.random.normal(shape=(data.shape[0], latent_space_dim))\n","    encoded1_images = encoded1_mu_all + tf.exp(encoded1_sigma_all / 2) * epsilon1_all\n","    numpy_encoded1_images = encoded1_images.numpy()\n","    save_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_for_test_mnist.npy'.format(i))\n","    with open(save_path, 'wb') as f:\n","        np.save(f, numpy_encoded1_images)\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QxDHnElk4HV6","cellView":"form"},"outputs":[],"source":["# @title Load all digits\n","\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","tensor_encoded_list=[]\n","for i in range(10):\n","    encoded_data_path = os.path.join(model_and_data_savepath, 'encoded_image_{}_{}.npy'.format(i, 'mnist'))\n","    encoded_x = np.load(encoded_data_path)\n","    tensor_encoded_list.append(torch.tensor(encoded_x).to(device))\n"]},{"cell_type":"markdown","metadata":{"id":"dyDUTgV2wfjG"},"source":["\n","\n","---\n","## check accuracy on 28 $\\times$ 28 images\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"h-820URmwlWe"},"outputs":[],"source":["import tensorflow as 
tf\n","# Keras - Deep Learning API\n","import keras                                # High-level neural networks API\n","from keras.datasets import mnist            # MNIST dataset of hand-written digits\n","from keras.layers import (                  # Neural network layers\n","    Conv2D, Conv2DTranspose,\n","    Input, Flatten, Dense,\n","    Lambda, Reshape\n",")\n","from keras.models import Model              # Model definition and training\n","from keras.callbacks import (               # Training callbacks\n","    EarlyStopping, ModelCheckpoint\n",")\n","from keras import backend as K\n","from sklearn.manifold import TSNE\n","\n","# Metrics\n","from keras.metrics import MeanSquaredError\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# load the MNIST (hand-written digits) data\n","(x_train, y_train), (x_test, y_test) = mnist.load_data()\n","assert x_train.shape == (60000, 28, 28)\n","assert x_test.shape == (10000, 28, 28)\n","assert y_train.shape == (60000,)\n","assert y_test.shape == (10000,)\n","\n","concatenated_x = tf.concat([x_train, x_test], axis=0)\n","concatenated_y = tf.concat([y_train, y_test], axis=0)\n","\n","list_of_images = []\n","for i in range(10):\n","  list_of_images.append(concatenated_x[concatenated_y == i]) # (7000, 28, 28)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mg_24V7Ip3Q3","outputId":"f4dcd64e-2c39-4545-d12a-8266fed9c654"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["<All keys matched successfully>"]},"metadata":{},"execution_count":33}],"source":["model = models.resnet18()\n","\n","model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n","model.fc =  nn.Linear(in_features=512, out_features=10, bias=True)\n","\n","save_path = os.getcwd()\n","file_name = os.path.join(save_path, 
'model_classifier_28_by_28_MNIST.pt')\n","# map_location lets a checkpoint saved on GPU load on a CPU-only runtime\n","state_dict = torch.load(file_name, map_location=device)\n","# strict loading (the default) raises on missing/unexpected keys instead of silently skipping them\n","load_report = model.load_state_dict(state_dict)\n","# the classifier is only evaluated, never trained, here: eval mode makes BatchNorm use its running statistics\n","model.eval()\n","load_report"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LLAqzPtQwnzQ"},"outputs":[],"source":["def imshow(img):\n","    \"\"\"Display a CHW image tensor, mapping values from [-1, 1] back to [0, 1].\"\"\"\n","    img = img / 2 + 0.5\n","    npimg = img.numpy()\n","    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n","    plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FkwLPqxywwXv"},"outputs":[],"source":["def test_accuracy_fashion_MNIST(classifier, list_of_imgs, target_digit):\n","  \"\"\"Return (percentage of list_of_imgs classified as target_digit, predicted labels).\"\"\"\n","  with torch.no_grad():\n","    output = classifier(list_of_imgs)\n","    _, predicted = torch.max(output, 1)\n","    total = list_of_imgs.size()[0]\n","    correct = (predicted == target_digit).sum().item()\n","    accuracy_rate = 100 * correct / total\n","  return accuracy_rate, predicted"]},{"cell_type":"markdown","metadata":{"id":"_WmQYYP-gcpj"},"source":["\n","\n","---\n","## preparing normalized latent samples\n","\n","\n","---\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"collapsed":true,"id":"Jj3NHr-NVv43","outputId":"1a016a8c-a53a-4a3f-edaa-1015ec2890fa","cellView":"form"},"outputs":[{"output_type":"stream","name":"stdout","text":["effective dimension = 10\n","effective dimension using std = 10\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-337636696.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  U_reduced = torch.tensor(U[:, idx]).to(device)\n","/tmp/ipython-input-337636696.py:26: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n","  V_reduced = torch.tensor(V[idx, 
:]).to(device)\n"]}],"source":["# @title compute mean value, SVD of the covariance matrix for the latent data\n","import pickle\n","\n","dim_latent = Dim\n","\n","# all ten digits stacked along the sample axis, in digit order 0..9\n","encoded_x_tot = torch.cat(tensor_encoded_list, 0).to(device)\n","\n","# second_list: digits 5-9\n","encoded_x_tot_2 = torch.cat(tensor_encoded_list[5:10], 0).to(device)\n","\n","# first_list: digits 0-4, each exactly once (mirrors second_list)\n","encoded_x_tot_1 = torch.cat(tensor_encoded_list[0:5], 0).to(device)\n","\n","\n","mean_encoded_x_tot = torch.mean(encoded_x_tot, dim=0, keepdims=True).to(device)\n","cov_encoded_x_tot = torch.matmul((encoded_x_tot - mean_encoded_x_tot).T, (encoded_x_tot - mean_encoded_x_tot)) / encoded_x_tot.size(0)\n","U, s, V = np.linalg.svd(cov_encoded_x_tot.cpu().detach().numpy(), full_matrices=True)\n","U = torch.tensor(U, dtype=torch.float32)\n","V = torch.tensor(V, dtype=torch.float32)\n","# keep only directions whose singular value exceeds 1e-4\n","idx = s > 1e-4\n","# boolean indexing already returns a new tensor; re-wrapping it in torch.tensor() only raised a UserWarning\n","U_reduced = U[:, idx].to(device)\n","V_reduced = V[idx, :].to(device)\n","s_reduced = torch.tensor(s[idx]).to(device)\n","\n","# per-dimension standard deviation sqrt(E[x^2] - E[x]^2); clamping at 0 keeps tiny negative\n","# variances caused by floating-point cancellation from turning into NaN\n","sqr = encoded_x_tot ** 2\n","std = torch.sqrt(torch.clamp(torch.mean(sqr, dim=0, keepdims=True) - mean_encoded_x_tot ** 2, min=0))\n","idx_std = std[0].cpu().numpy() > 0\n","\n","# with open(os.path.join(model_and_data_savepath, 'SVD_U.pkl'), 'wb') as file:\n","#     pickle.dump(U_reduced.cpu().detach().numpy(), file)\n","# with open(os.path.join(model_and_data_savepath, 'SVD_V.pkl'), 'wb') as file:\n","#     pickle.dump(V_reduced.cpu().detach().numpy(), file)\n","# with open(os.path.join(model_and_data_savepath, 'SVD_s.pkl'), 'wb') as file:\n","#     pickle.dump(s_reduced.cpu().detach().numpy(), file)\n","# with open(os.path.join(model_and_data_savepath, 'mean_encoded_x_all.pkl'), 'wb') as file:\n","#     
pickle.dump(mean_encoded_x_tot.cpu().detach().numpy(), file)\n","\n","# with open(os.path.join(model_and_data_savepath, 'std.pkl'), 'wb') as file:\n","#     pickle.dump(std.cpu().detach().numpy(), file)\n","\n","effective_dim = sum(idx)  # number of SVD directions kept by the 1e-4 cutoff\n","print(\"effective dimension = {}\".format(effective_dim))\n","\n","effective_dim_using_std = sum(idx_std)  # number of latent dimensions with non-zero std\n","print(\"effective dimension using std = {}\".format(effective_dim_using_std))\n","\n"]},{"cell_type":"code","source":["# log10 spectrum of the latent covariance, to compare against the 1e-4 cutoff used for idx\n","fig, ax = plt.subplots()\n","ax.plot(np.log10(s), marker='o')\n","ax.set(title='Singular values of the latent covariance', xlabel='index', ylabel='log10(singular value)')\n","fig.savefig(os.path.join(model_and_data_savepath, 'singular_value.pdf'))\n","plt.close(fig)\n","print(s)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rRrsqrgh8mhP","outputId":"5eb419c2-e988-4d74-826e-cf6e02c60de8"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[10.85222    9.951967   8.215828   6.4507594  5.200374   3.8695521\n","  3.1955023  2.44244    2.1457496  1.3165425]\n"]}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"AVillvqKs-Xg","cellView":"form"},"outputs":[],"source":["# @title overview of latent samples (func)\n","import numpy as np\n","import pickle\n","\n","\n","def normalize_latent_samples(digit, savepath=model_and_data_savepath, dataset='mnist'):\n","    \"\"\"Project the latent codes of `digit` onto the retained SVD directions and whiten them.\n","\n","    digit: a label used to build the file name encoded_image_{digit}_{dataset}.npy, or\n","    first_list / second_list for the pooled tensors built in the covariance cell.\n","    Returns a tensor of shape (N, effective_dim).\n","    \"\"\"\n","    if digit == \"first_list\":\n","        encoded_x = encoded_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded_x_tot_2.to(device)\n","    else:\n","        # format only the file name, so braces in savepath cannot break str.format\n","        npy_name = 'encoded_image_{}_{}.npy'.format(digit, dataset)\n","        encoded_x = torch.tensor(np.load(os.path.join(savepath, npy_name))).to(device)\n","    # center with the pooled mean, rotate onto U_reduced, scale each direction to unit variance\n","    encoded_x = torch.matmul((encoded_x - mean_encoded_x_tot), U_reduced)/torch.sqrt(s_reduced)\n","\n","    return encoded_x\n","\n","\n","norm_eps = 1e-6\n","def normalize_latent_samples_using_std(digit, savepath=model_and_data_savepath, dataset='mnist', load_test=False):\n","    \"\"\"Z-score the latent codes of `digit` per dimension, dropping dimensions whose std is zero.\n","\n","    load_test=True reads the test-split file encoded_image_{digit}_for_test_{dataset}.npy.\n","    Returns a tensor of shape (N, effective_dim_using_std).\n","    \"\"\"\n","    if digit == \"first_list\":\n","        encoded_x = 
encoded_x_tot_1.to(device)\n","    elif digit == \"second_list\":\n","        encoded_x = encoded_x_tot_2.to(device)\n","    else:\n","        # format only the file name, so braces in savepath cannot break str.format\n","        if load_test:\n","            npy_name = 'encoded_image_{}_for_test_{}.npy'.format(digit, dataset)\n","        else:\n","            npy_name = 'encoded_image_{}_{}.npy'.format(digit, dataset)\n","        encoded_x = torch.tensor(np.load(os.path.join(savepath, npy_name))).to(device)\n","    # per-dimension z-score with the pooled mean and std; zero-std dimensions are dropped via idx_std\n","    normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x_tot[:, idx_std])/(std[0, idx_std])\n","\n","    return normalized_encoded_x\n","\n","\n","def recover_latent_samples_using_PCA(samples):  # samples: N x effective_dim, whitened SVD coordinates\n","    \"\"\"Map whitened SVD coordinates back to the latent space (inverse of normalize_latent_samples when no direction was dropped).\"\"\"\n","    recovered_samples = torch.matmul(samples * torch.sqrt(s_reduced), U_reduced.T) + mean_encoded_x_tot\n","\n","    return recovered_samples\n","\n","\n","def recover_latent_samples_using_std(samples):  # samples: N x effective_dim_using_std, z-scored coordinates\n","    \"\"\"Undo normalize_latent_samples_using_std; dimensions dropped for zero std are filled with their mean.\"\"\"\n","    recovered_samples = torch.zeros(samples.size()[0], dim_latent).to(device)\n","    i = 0\n","    for d in range(dim_latent):\n","      if idx_std[d]:\n","        recovered_samples[:, d] = samples[:, i] * std[0, d] + mean_encoded_x_tot[0, d]\n","        i+=1\n","      else:\n","        recovered_samples[:, d] = mean_encoded_x_tot[0, d] * torch.ones(samples.size()[0]).to(mean_encoded_x_tot.device)\n","\n","    return recovered_samples\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"collapsed":true,"id":"Y80G7m-h07Pa","cellView":"form"},"outputs":[],"source":["# @title prepare data using SVD normalization\n","\n","data_all_gpu = normalize_latent_samples(\"0\")\n","for i in range(1, 10):\n","    data_i = normalize_latent_samples(\"{}\".format(i))\n","    data_i_gpu = data_i.to(device)\n","    data_all_gpu = torch.cat((data_all_gpu, data_i_gpu), 0)\n","\n","init_data_gpu = 
data_all_gpu\n","mean_init_data = torch.mean(init_data_gpu, dim=0).to(device)\n","cov_init_data = torch.matmul((init_data_gpu - mean_init_data).T, (init_data_gpu - mean_init_data)) / init_data_gpu.size()[0]\n","cov_init_data = cov_init_data.to(device)\n","\n","target_data_gpu = data_all_gpu\n","mean_target_data = torch.mean(target_data_gpu, dim=0).to(device)\n","cov_target_data = torch.matmul((target_data_gpu - mean_target_data).T, (target_data_gpu - mean_target_data)) / target_data_gpu.size()[0]\n","cov_target_data = cov_target_data.to(device)\n","\n","\n","target_data_9_svd = normalize_latent_samples('9')[:6310, :].to(device)\n","target_data_8_svd = normalize_latent_samples('8')[:6310, :].to(device)\n","target_data_7_svd = normalize_latent_samples('7')[:6310, :].to(device)\n","target_data_6_svd = normalize_latent_samples('6')[:6310, :].to(device)\n","target_data_5_svd = normalize_latent_samples('5')[:6310, :].to(device)\n","target_data_4_svd = normalize_latent_samples('4')[:6310, :].to(device)\n","target_data_3_svd = normalize_latent_samples('3')[:6310, :].to(device)\n","target_data_2_svd = normalize_latent_samples('2')[:6310, :].to(device)\n","target_data_1_svd = normalize_latent_samples('1')[:6310, :].to(device)\n","target_data_0_svd = normalize_latent_samples('0')[:6310, :].to(device)\n","target_data_list_svd = [target_data_0_svd, target_data_1_svd, target_data_2_svd, target_data_3_svd, target_data_4_svd, target_data_5_svd, target_data_6_svd, target_data_7_svd, target_data_8_svd, target_data_9_svd]\n","\n","init_data_9_svd = normalize_latent_samples('9')[:6310, :].to(device)\n","init_data_8_svd = normalize_latent_samples('8')[:6310, :].to(device)\n","init_data_7_svd = normalize_latent_samples('7')[:6310, :].to(device)\n","init_data_6_svd = normalize_latent_samples('6')[:6310, :].to(device)\n","init_data_5_svd = normalize_latent_samples('5')[:6310, :].to(device)\n","init_data_4_svd = normalize_latent_samples('4')[:6310, :].to(device)\n","init_data_3_svd = 
normalize_latent_samples('3')[:6310, :].to(device)\n","init_data_2_svd = normalize_latent_samples('2')[:6310, :].to(device)\n","init_data_1_svd = normalize_latent_samples('1')[:6310, :].to(device)\n","init_data_0_svd = normalize_latent_samples('0')[:6310, :].to(device)\n","init_data_list_svd = [init_data_0_svd, init_data_1_svd, init_data_2_svd, init_data_3_svd, init_data_4_svd, init_data_5_svd, init_data_6_svd, init_data_7_svd, init_data_8_svd, init_data_9_svd]\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Og88dGcbEFbk","cellView":"form"},"outputs":[],"source":["# @title prepare data with normalization (using std, i.e., the trace of covariance to normalize)\n","data_all_gpu_std = normalize_latent_samples_using_std(\"0\")\n","for i in range(1, 10):\n","    data_i_std = normalize_latent_samples_using_std(\"{}\".format(i))\n","    data_i_gpu_std = data_i_std.to(device)\n","    data_all_gpu_std = torch.cat((data_all_gpu_std, data_i_gpu_std), 0)\n","\n","init_data_gpu_std = data_all_gpu_std\n","mean_init_data_std = torch.mean(init_data_gpu_std, dim=0).to(device)\n","cov_init_data_std = torch.matmul((init_data_gpu_std - mean_init_data_std).T, (init_data_gpu_std - mean_init_data_std)) / init_data_gpu_std.size()[0]\n","cov_init_data_std = cov_init_data_std.to(device)\n","\n","target_data_gpu_std = data_all_gpu_std\n","mean_target_data_std = torch.mean(target_data_gpu_std, dim=0).to(device)\n","cov_target_data_std = torch.matmul((target_data_gpu_std - mean_target_data_std).T, (target_data_gpu_std - mean_target_data_std)) / target_data_gpu_std.size()[0]\n","cov_target_data_std = cov_target_data_std.to(device)\n","\n","target_data_9 = normalize_latent_samples_using_std('9')[:6310, :].to(device)\n","target_data_8 = normalize_latent_samples_using_std('8')[:6310, :].to(device)\n","target_data_7 = normalize_latent_samples_using_std('7')[:6310, :].to(device)\n","target_data_6 = normalize_latent_samples_using_std('6')[:6310, :].to(device)\n","target_data_5 
= normalize_latent_samples_using_std('5')[:6310, :].to(device)\n","target_data_4 = normalize_latent_samples_using_std('4')[:6310, :].to(device)\n","target_data_3 = normalize_latent_samples_using_std('3')[:6310, :].to(device)\n","target_data_2 = normalize_latent_samples_using_std('2')[:6310, :].to(device)\n","target_data_1 = normalize_latent_samples_using_std('1')[:6310, :].to(device)\n","target_data_0 = normalize_latent_samples_using_std('0')[:6310, :].to(device)\n","target_data_list_std = [target_data_0, target_data_1, target_data_2, target_data_3, target_data_4, target_data_5, target_data_6, target_data_7, target_data_8, target_data_9]\n","\n","init_data_9 = normalize_latent_samples_using_std('9')[:6310, :].to(device)\n","init_data_8 = normalize_latent_samples_using_std('8')[:6310, :].to(device)\n","init_data_7 = normalize_latent_samples_using_std('7')[:6310, :].to(device)\n","init_data_6 = normalize_latent_samples_using_std('6')[:6310, :].to(device)\n","init_data_5 = normalize_latent_samples_using_std('5')[:6310, :].to(device)\n","init_data_4 = normalize_latent_samples_using_std('4')[:6310, :].to(device)\n","init_data_3 = normalize_latent_samples_using_std('3')[:6310, :].to(device)\n","init_data_2 = normalize_latent_samples_using_std('2')[:6310, :].to(device)\n","init_data_1 = normalize_latent_samples_using_std('1')[:6310, :].to(device)\n","init_data_0 = normalize_latent_samples_using_std('0')[:6310, :].to(device)\n","init_data_list_std = [init_data_0, init_data_1, init_data_2, init_data_3, init_data_4, init_data_5, init_data_6, init_data_7, init_data_8, init_data_9]\n","\n","data_all_gpu_std_for_test = normalize_latent_samples_using_std(\"0\", load_test=True)\n","for i in range(1, 10):\n","    data_i_std_for_test = normalize_latent_samples_using_std(\"{}\".format(i), load_test=True)\n","    data_i_gpu_std_for_test = data_i_std_for_test.to(device)\n","    data_all_gpu_std_for_test = torch.cat((data_all_gpu_std_for_test, data_i_gpu_std_for_test), 
0)\n","\n","init_data_9_for_test = normalize_latent_samples_using_std('9', load_test=True).to(device)\n","init_data_8_for_test = normalize_latent_samples_using_std('8', load_test=True).to(device)\n","init_data_7_for_test = normalize_latent_samples_using_std('7', load_test=True).to(device)\n","init_data_6_for_test = normalize_latent_samples_using_std('6', load_test=True).to(device)\n","init_data_5_for_test = normalize_latent_samples_using_std('5', load_test=True).to(device)\n","init_data_4_for_test = normalize_latent_samples_using_std('4', load_test=True).to(device)\n","init_data_3_for_test = normalize_latent_samples_using_std('3', load_test=True).to(device)\n","init_data_2_for_test = normalize_latent_samples_using_std('2', load_test=True).to(device)\n","init_data_1_for_test = normalize_latent_samples_using_std('1', load_test=True).to(device)\n","init_data_0_for_test = normalize_latent_samples_using_std('0', load_test=True).to(device)\n","init_data_list_std_for_test = [init_data_0_for_test, init_data_1_for_test, init_data_2_for_test, init_data_3_for_test, init_data_4_for_test, init_data_5_for_test, init_data_6_for_test, init_data_7_for_test, init_data_8_for_test, init_data_9_for_test]\n","\n","# data_all_gpu_std_for_test is built once at the top of this cell; the target_*_for_test tensors\n","# use the same call as init_*_for_test and are kept under separate names for downstream cells\n","target_data_9_for_test = normalize_latent_samples_using_std('9', load_test=True).to(device)\n","target_data_8_for_test = normalize_latent_samples_using_std('8', load_test=True).to(device)\n","target_data_7_for_test = normalize_latent_samples_using_std('7', load_test=True).to(device)\n","target_data_6_for_test = normalize_latent_samples_using_std('6', 
load_test=True).to(device)\n",
"target_data_5_for_test = normalize_latent_samples_using_std('5', load_test=True).to(device)\n",
"target_data_4_for_test = normalize_latent_samples_using_std('4', load_test=True).to(device)\n",
"target_data_3_for_test = normalize_latent_samples_using_std('3', load_test=True).to(device)\n",
"target_data_2_for_test = normalize_latent_samples_using_std('2', load_test=True).to(device)\n",
"target_data_1_for_test = normalize_latent_samples_using_std('1', load_test=True).to(device)\n",
"target_data_0_for_test = normalize_latent_samples_using_std('0', load_test=True).to(device)\n",
"target_data_list_std_for_test = [target_data_0_for_test, target_data_1_for_test, target_data_2_for_test, target_data_3_for_test, target_data_4_for_test, target_data_5_for_test, target_data_6_for_test, target_data_7_for_test, target_data_8_for_test, target_data_9_for_test]\n",
"\n"]},{"cell_type":"markdown","metadata":{"id":"tKNgKlzzsnQU"},"source":["\n",
"\n",
"---\n",
"\n",
"## Define the Metric/Distance functions: MMD\n",
"\n",
"---\n",
"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"oa9NUc3DXmFq"},"outputs":[],"source":["# @title MMD\n",
"\n",
"# Assume the kernel function is symmetric: k(x1, x2) = k(x2, x1)\n",
"# MMD(p0, p1) = \\iint k(x1, x2)(p1(x1) - p0(x1))(p1(x2) - p0(x2)) dx1dx2\n",
"#             = \\iint k(x1, x2)p1(x1)p1(x2) dx1dx2 - 2 \\iint k(x1, x2)p0(x1)p1(x2) dx1dx2 + \\iint k(x1, x2)p0(x1)p0(x2) dx1dx2\n",
"# try RBF kernel: k(x1, x2) = \\exp( - |x1-x2|^2/(2\\sigma^2))\n",
"# size of x: N by dim x=[ x_1^t ]\n",
"#                       [ ..... ]\n",
"#                       [ x_N^t ]\n",
"def _rbf_gram(a, b, k_sigma=1):\n",
"    \"\"\"Return the Gram matrix K_ij = exp(-|a_i - b_j|^2 / (2 k_sigma^2)).\n",
"\n",
"    a is N x dim and b is M x dim; the result is N x M.\n",
"    \"\"\"\n",
"    # |a_i - b_j|^2 = |a_i|^2 + |b_j|^2 - 2 a_i^t b_j, evaluated with a single matmul\n",
"    a_sqr = torch.sum(a*a, 1).unsqueeze(1)\n",
"    b_sqr = torch.sum(b*b, 1).unsqueeze(1)\n",
"    sqr_dist = a_sqr + torch.transpose(b_sqr, 0, 1) - 2 * torch.matmul(a, torch.transpose(b, 0, 1))\n",
"    return torch.exp(-sqr_dist / (2 * k_sigma * k_sigma))\n",
"\n",
"\n",
"def MMD(x, y, k_sigma=1):\n",
"    \"\"\"Squared MMD between samples x (N x dim) and y (M x dim) with an RBF kernel.\n",
"\n",
"    Biased V-statistic estimate, cf. eq (5) of Gretton et al. (2012):\n",
"        MMD^2 = mean(K_00) - 2 * mean(K_01) + mean(K_11),\n",
"    where K_00 = k(x, x), K_01 = k(x, y), K_11 = k(y, y) and\n",
"    k(a, b) = exp(-|a - b|^2 / (2 k_sigma^2)). Returns a 0-dim tensor.\n",
"    \"\"\"\n",
"    N = x.size()[0]\n",
"    M = y.size()[0]\n",
"\n",
"    # K_00 \\approx \\iint k(x1, x2) p0(x1)p0(x2) dx1dx2\n",
"    K_00 = _rbf_gram(x, x, k_sigma)\n",
"    # K_11 \\approx \\iint k(y1, y2) p1(y1)p1(y2) dy1dy2\n",
"    K_11 = _rbf_gram(y, y, k_sigma)\n",
"    # K_01 \\approx \\iint k(x, y) p0(x)p1(y) dxdy\n",
"    K_01 = _rbf_gram(x, y, k_sigma)\n",
"\n",
"    # use V-statistics, i.e., average value (diagonal terms included)\n",
"    # cf. eq (5) https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf\n",
"    MMDsqr = torch.sum(K_00)/(N*N) - 2 * torch.sum(K_01)/(N*M) + torch.sum(K_11)/(M*M)\n",
"\n",
"    return MMDsqr\n",
"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"Q_dBnBxobgCp"},"outputs":[],"source":["# @title MMD use k(z) = -sqrt{||z||^2+0.001}\n",
"# k(x, y) = k(x - y)\n",
"\n",
"# Assume the kernel function is symmetric: k(x1, x2) = k(x2, x1)\n",
"# MMD(p0, p1) = \\iint k(x1, x2)(p1(x1) - p0(x1))(p1(x2) - p0(x2)) dx1dx2\n",
"#             = \\iint k(x1, x2)p1(x1)p1(x2) dx1dx2 - 2 \\iint k(x1, x2)p0(x1)p1(x2) dx1dx2 + \\iint k(x1, x2)p0(x1)p0(x2) dx1dx2\n",
"# size of x: N by dim x=[ x_1^t ]\n",
"#                       [ ..... ]\n",
"#                       [ x_N^t ]\n",
"\n",
"\n",
"# x size:N * dim; Y size:M * dim\n",
"# the function returns N * M tensor whose i-j entry: sqrt(||x_i - y_j||^2 + eps)\n",
"def dist_XY(X, Y):\n",
"    \"\"\"Pairwise smoothed distances sqrt(||X_i - Y_j||^2 + 0.001) as an N x M tensor.\n",
"\n",
"    The eps inside the square root keeps the gradient finite at zero distance\n",
"    (e.g. on the diagonal of dist_XY(x, x)).\n",
"    \"\"\"\n",
"    eps = 0.001\n",
"\n",
"    # broadcast to an (N, M, dim) difference tensor instead of materialising\n",
"    # repeated copies of X and Y\n",
"    x_y_matrix = X.unsqueeze(1) - Y.unsqueeze(0)\n",
"    dist_matrix_x_y = torch.sqrt(torch.sum(x_y_matrix * x_y_matrix, 2) + eps)\n",
"\n",
"    return dist_matrix_x_y\n",
"\n",
"\n",
"def MMD_negnorm(x, y, k_sigma=1):\n",
"    \"\"\"MMD with the kernel k(z) = -sqrt(||z||^2 + 0.001), with d(., .) computed by dist_XY.\n",
"\n",
"    Returns A/2 - B + C/2 with A = -mean d(x, x'), B = -mean d(x, y) and\n",
"    C = -mean d(y, y'), i.e. mean d(x, y) - (mean d(x, x') + mean d(y, y')) / 2.\n",
"    k_sigma is unused; it only keeps the signature aligned with MMD.\n",
"    \"\"\"\n",
"    M = x.size()[0]\n",
"    N = y.size()[0]\n",
"\n",
"    dist_x_y = dist_XY(x, y)\n",
"    B = -torch.sum(dist_x_y) / (M * N)\n",
"\n",
"    dist_x_x = dist_XY(x, x)\n",
"    A = -torch.sum(dist_x_x) / (M * M)\n",
"\n",
"    dist_y_y = dist_XY(y, y)\n",
"    C = -torch.sum(dist_y_y) / (N * N)\n",
"\n",
"    return A/2 - B + C/2\n",
"\n",
"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"x19mrwfUSOfM"},"outputs":[],"source":["# @title  sliced MMD 
(func)\n",
"\n",
"def slicedMMD(x, y, num):\n",
"  \"\"\"Sum of MMD(x, y) over the 2-D coordinate slices (d, d+1), using the first num rows.\"\"\"\n",
"  sum_mmd0 = 0\n",
"  for d in range(x.size()[1]-1):\n",
"    sum_mmd0 += MMD(x[:num, [d, d+1]], y[:num, [d, d+1]])\n",
"\n",
"  return sum_mmd0\n",
"\n"]},{"cell_type":"markdown","metadata":{"id":"IBGEhneOs0MU"},"source":["\n",
"\n",
"---\n",
"\n",
"\n",
"## Plotting functions\n",
"\n",
"\n",
"---\n",
"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"3WtiwdHsUKYi"},"outputs":[],"source":["# @title plot samples on latent space (func)\n",
"\n",
"\n",
"def plot_on_latent_space(encoded_x_numpy, transported_samples, dim_0, dim_1, scale, iter, exp_dir=dir):\n",
"    \"\"\"Scatter the first 1024 encoded (magenta) and transported (blue) samples on the dim_0-dim_1 plane; save as PNG in exp_dir.\"\"\"\n",
"    plt.figure(figsize=(12,12))\n",
"    plt.scatter( encoded_x_numpy[:1024, dim_0], encoded_x_numpy[:1024, dim_1], color='magenta', s=10, alpha=1)\n",
"    plt.scatter( transported_samples[:1024, dim_0], transported_samples[:1024, dim_1], color='blue', s=1, alpha=1)\n",
"    plt.title(\"plot on latent space on {}-{} coordinate plane\".format(dim_0, dim_1))\n",
"    filename = os.path.join(exp_dir, 'plot on latent space on {}-{} coordinate plane rescale={}iter={}.png'.format(dim_0, dim_1, scale, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n"]},{"cell_type":"code","source":["# @title plot MNIST figure (func)\n",
"\n",
"import math\n",
"\n",
"def plot_mnist(images, output, n_plots, scale, iter, ttl, exp_dir=example_savepath):\n",
"    \"\"\"Save a 2-row figure: the first n_plots `images` on top, `output` (reshaped to 28x28) below.\n",
"\n",
"    `scale` is accepted for call-site compatibility but unused. The PNG is written to\n",
"    os.path.join(exp_dir, ttl + 'plot_mnist_iter_<iter>.png').\n",
"    \"\"\"\n",
"    images = images.cpu().detach().numpy()[:n_plots, :]\n",
"    output = output[:n_plots, :]\n",
"\n",
"    plot_num = n_plots\n",
"\n",
"    output = output.view(plot_num, 28, 28)\n",
"    output = output.cpu().detach().numpy()\n",
"\n",
"    # squeeze=False keeps `axes` 2-D (2 x plot_num) even when plot_num == 1\n",
"    fig, axes = plt.subplots(nrows=2, ncols=plot_num, sharex=True, sharey=True, figsize=(25,4), squeeze=False)\n",
"    # input images on top row, reconstructions on bottom\n",
"    for row_images, row in zip([images, output], axes):\n",
"        for img, ax in zip(row_images, row):\n",
"            ax.imshow(np.squeeze(img), cmap='gray')\n",
"            ax.get_xaxis().set_visible(False)\n",
"            ax.get_yaxis().set_visible(False)\n",
"    plt.title(ttl)\n",
"    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close(fig)\n",
"\n",
"\n",
"def plot_mnist_tab(rownum, colnum, x, iter, ttl, exp_dir=example_savepath):\n",
"    \"\"\"Save a rownum x colnum grid showing x[k, :, :, 0] for k = 0 .. rownum*colnum - 1 (row-major).\"\"\"\n",
"    # squeeze=False keeps `axs` 2-D even when rownum or colnum is 1\n",
"    fig, axs = plt.subplots(rownum, colnum, squeeze=False)\n",
"    # set_size_inches takes (width, height): width follows the number of columns\n",
"    fig.set_size_inches(4*colnum, 4*rownum)\n",
"\n",
"    for i in range(rownum):\n",
"      for j in range(colnum):\n",
"        axs[i, j].imshow(x[i*colnum+j,:, :, 0])\n",
"\n",
"    fig.suptitle(ttl, fontsize=16)\n",
"\n",
"    for i in range(axs.shape[0]):\n",
"        for j in range(axs.shape[1]):\n",
"            axs[i, j].get_yaxis().set_visible(False)\n",
"            axs[i, j].get_xaxis().set_visible(False)\n",
"            axs[i ,j].set_aspect('equal')\n",
"    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close(fig)\n",
"\n",
"\n",
"import random\n",
"def generate_random_list(start, end, length):\n",
"  \"\"\"Return `length` integers drawn with random.randint(start, end); both ends are inclusive.\"\"\"\n",
"  return [random.randint(start, end) for _ in range(length)]\n",
"\n",
"def plot_mnist_init_target(x_a, x_b, x_c, x_d, x_e, x_f, x_g, x_h, x_i, x_j, iter, row_num, ttl, flag = '0_to_T', exp_dir=example_savepath):\n",
"    \"\"\"Save a row_num x 10 grid where column k shows row_num random images x[idx, :, :, 0] of the k-th batch.\n",
"\n",
"    Column titles are the digit labels 1, 2, ..., 9, 0 when flag == '0_to_T'\n",
"    and 9, 0, 1, ..., 8 otherwise.\n",
"    \"\"\"\n",
"    column = 10\n",
"\n",
"    # squeeze=False keeps `axs` 2-D (axs[i, j] indexing) even when row_num == 1\n",
"    fig, axs = plt.subplots(row_num, column, squeeze=False)\n",
"    fig.set_size_inches( 2*column, 2*row_num )\n",
"\n",
"    
# random.randint includes both ends, so indices are drawn from [0, n_avail - 1],\n",
"    # where n_avail is the size of the smallest batch, so every column can be indexed.\n",
"    batches = [x_a, x_b, x_c, x_d, x_e, x_f, x_g, x_h, x_i, x_j]\n",
"    n_avail = min(batch.shape[0] for batch in batches)\n",
"    random_integers = generate_random_list(0, n_avail - 1, row_num)\n",
"\n",
"    # column titles (digit labels) for each transport direction\n",
"    if flag == '0_to_T':\n",
"        titles = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0']\n",
"    else:\n",
"        titles = ['9', '0', '1', '2', '3', '4', '5', '6', '7', '8']\n",
"\n",
"    for j in range(column):\n",
"        for i in range(row_num):\n",
"            axs[i, j].imshow(batches[j][random_integers[i], :, :, 0])\n",
"            axs[i, j].set_title(titles[j])\n",
"\n",
"    fig.suptitle(ttl, fontsize=16)\n",
"\n",
"    for i in range(axs.shape[0]):\n",
"        for j in range(axs.shape[1]):\n",
"            axs[i, j].get_yaxis().set_visible(False)\n",
"            axs[i, j].get_xaxis().set_visible(False)\n",
"            axs[i ,j].set_aspect('equal')\n",
"    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close(fig)\n",
"\n",
"\n",
"\n",
"\n",
"def plot_mnist_one_row(x, iter, ttl, exp_dir=example_savepath):\n",
"    \"\"\"Save a 2 x 4 grid of the first 8 rows of x, each reshaped to a square s x s image.\"\"\"\n",
"    # assume square image\n",
"    s = int(math.sqrt(x.shape[1]))\n",
"\n",
"    nex = 8\n",
"    fig, axs = plt.subplots(2, nex//2)\n",
"    fig.set_size_inches(18, 9)\n",
"\n",
"    for i in range(nex//2):\n",
"        axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy())\n",
"        axs[1, i].imshow(x[ nex//2 + i , : ].reshape(s,s).detach().cpu().numpy())\n",
"\n",
"    fig.suptitle(ttl, fontsize=16)\n",
"\n",
"    for i in range(axs.shape[0]):\n",
"        for j in range(axs.shape[1]):\n",
"            axs[i, j].get_yaxis().set_visible(False)\n",
"            axs[i, 
j].get_xaxis().set_visible(False)\n",
"            axs[i ,j].set_aspect('equal')\n",
"    filename = os.path.join(exp_dir, ttl+'plot_mnist_iter_{}.png'.format(iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n"],"metadata":{"cellView":"form","id":"vBtNJC1yMVRK"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"q8pssibfep3x","cellView":"form"},"outputs":[],"source":["# @title plot MNIST latent samples (func)\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"\n",
"def overview_of_latent_samples(digit='All', flag=1, dir=model_and_data_savepath):\n",
"    \"\"\"Return (encoded_x, mean, U, singular values, V) for a set of latent codes.\n",
"\n",
"    digit != 'All': codes are loaded from dir/encoded_image_<digit>.npy.\n",
"    digit == 'All': the notebook globals encoded_x_tot_1 (flag == 1) or\n",
"    encoded_x_tot_2 (otherwise) are used.\n",
"    U, singular values and V come from np.linalg.svd of the 1/N-normalised sample\n",
"    covariance; note that V is numpy's Vh (its rows are the right singular vectors).\n",
"    \"\"\"\n",
"    if digit != \"All\":\n",
"        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(digit)))\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"        mean_encoded_x = torch.mean(encoded_x, dim=0).cpu()\n",
"    else:\n",
"        if flag == 1:\n",
"            loaded_data = encoded_x_tot_1\n",
"        else:\n",
"            loaded_data = encoded_x_tot_2\n",
"        encoded_x = torch.tensor(loaded_data)\n",
"        mean_encoded_x = torch.mean(encoded_x, dim=0)\n",
"\n",
"    centered_x = encoded_x - mean_encoded_x\n",
"    cov_encoded_x = torch.matmul(centered_x.T, centered_x) / encoded_x.size(0)\n",
"    cov_np = cov_encoded_x.cpu().detach().numpy()\n",
"    # One SVD also yields the singular values (np.linalg.svdvals needs NumPy >= 2.0).\n",
"    # The former sqrtm / condition-number diagnostics were never used or returned and\n",
"    # relied on a `sqrtm` name that the imports cell does not import.\n",
"    U, singularv_cov, V = np.linalg.svd(cov_np, full_matrices=True)\n",
"    U = torch.tensor(U, dtype=torch.float32)\n",
"    V = torch.tensor(V, dtype=torch.float32)\n",
"\n",
"    return encoded_x, mean_encoded_x, U, singularv_cov, V\n",
"\n",
"\n",
"# Using SVD to normalize, projected to PCA directions\n",
"# flag = 1: 0,1,2,3,4\n",
"# flag = 2: 5,6,7,8,9\n",
"def plot_in_latent_spc_along_PCA_directions_with_common_normalization(list_of_digits, dim_0, dim_1,  save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Scatter each digit's latent codes in the PCA coordinates of the whole dataset.\n",
"\n",
"    Codes from dir/encoded_image_<digit>.npy are centred with mean_encoded_x_tot,\n",
"    projected on U_reduced and scaled by 1/sqrt(s_reduced) (notebook globals); the\n",
"    first 1000 points per digit are drawn on the dim_0-dim_1 plane and saved to save_dir.\n",
"    \"\"\"\n",
"    mean_encoded_x_all = mean_encoded_x_tot.cpu()\n",
"    U_all = U_reduced.cpu()\n",
"    s_all = s_reduced.cpu()\n",
"\n",
"    plt.figure(figsize=(12,12))\n",
"    for i in list_of_digits:\n",
"        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"\n",
"        centered_pca_encoded_x = torch.matmul((encoded_x - mean_encoded_x_all), U_all)/torch.sqrt(s_all)\n",
"        plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1],  s=5, alpha=1, label='{}'.format(i))\n",
"\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) on latent space {}-{} PCA dimensions (using SVD on whole MNIST dataset)\".format(list_of_digits, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"plot digits({}) on latent space {}-{} PCA dimensions\".format(list_of_digits, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"# Using standard deviation to normalize\n",
"# flag = 1: 0,1,2,3,4\n",
"# flag = 2: 5,6,7,8,9\n",
"def plot_in_latent_spc_std_with_common_normalization(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Scatter each digit's latent codes after per-coordinate standardisation.\n",
"\n",
"    Uses the notebook globals mean_encoded_x_tot, std and idx_std; the first 1000\n",
"    points per digit are drawn on the dim_0-dim_1 plane and saved to save_dir.\n",
"    \"\"\"\n",
"    mean_encoded_x = mean_encoded_x_tot.cpu()\n",
"    std_all = std.cpu()\n",
"\n",
"    plt.figure(figsize=(12,12))\n",
"    for i in list_of_digits:\n",
"        encoded_x = 
np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i ) ) )\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"        # per-coordinate standardisation (idx_std coordinates) with the global mean / std\n",
"        normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x[:, idx_std])/(std_all[0, idx_std])\n",
"        plt.scatter(normalized_encoded_x[:1000, dim_0], normalized_encoded_x[:1000, dim_1],  s=5, alpha=1, label='{}'.format(i))\n",
"\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) on latent space {}-{} dimensions (normalized using std on whole MNIST dataset)\".format(list_of_digits, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"plot digits({}) on std latent space {}-{} dimensions\".format(list_of_digits, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"def plot_normalized_MNIST_data_in_latent_spc_along_PCA_directions(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Scatter normalize_latent_samples(digit) (first 1000 rows) on the dim_0-dim_1 plane; `dir` is unused.\"\"\"\n",
"    plt.figure(figsize=(12,12))\n",
"    for i in list_of_digits:\n",
"        encoded_x = normalize_latent_samples(i).cpu()\n",
"        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) on latent space {}-{} PCA dimensions (normalized)\".format(list_of_digits, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"[PCA] plot digits({}) on latent space {}-{} PCA dimensions\".format(list_of_digits, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"# Normalize entrywise, projected to ordinary directions\n",
"def plot_std_normalized_MNIST_data_in_latent_spc_ordinary_coord(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Scatter normalize_latent_samples_using_std(digit) (first 1000 rows) on the dim_0-dim_1 plane; `dir` is unused.\"\"\"\n",
"    plt.figure(figsize=(12,12))\n",
"    for i in list_of_digits:\n",
"        encoded_x = normalize_latent_samples_using_std(i).cpu()\n",
"        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) on latent space {}-{} dimensions (normalize using std)\".format(list_of_digits, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"[std] plot digits({}) on latent space {}-{} dimensions\".format(list_of_digits, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"# plot raw latent samples\n",
"def plot_raw_MNIST_data_in_latent_spc_ordinary_coord(list_of_digits, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Scatter the unnormalised codes from dir/encoded_image_<digit>.npy (first 1000 rows) on the dim_0-dim_1 plane.\"\"\"\n",
"    plt.figure(figsize=(12,12))\n",
"    for i in list_of_digits:\n",
"        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i )))\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"        plt.scatter(encoded_x[:1000, dim_0], encoded_x[:1000, dim_1], s=5, alpha=1, label='{}'.format(i))\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) on latent space {}-{} dimensions (raw data)\".format(list_of_digits, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"[RAW] plot digits({}) on latent space {}-{} dimensions\".format(list_of_digits, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"def plot_in_latent_spc_along_PCA_directions_with_common_normalization_cmp_groups(list_of_digits_1, list_of_digits_2, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"PCA-coordinate scatter (as in the common-normalisation plot) of two digit groups:\n",
"    list_of_digits_1 in blue, list_of_digits_2 in green; saved to save_dir.\n",
"    \"\"\"\n",
"    mean_encoded_x_all = mean_encoded_x_tot.cpu()\n",
"    U_all = U_reduced.cpu()\n",
"    s_all = s_reduced.cpu()\n",
"\n",
"    plt.figure(figsize=(12,12))\n",
"    l = 0\n",
"    for i in list_of_digits_1:\n",
"        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"        centered_pca_encoded_x = torch.matmul((encoded_x - 
mean_encoded_x_all), U_all)/torch.sqrt(s_all)\n",
"        if l == 0:\n",
"            plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1], color=\"blue\", s=5, alpha=1, label='{}'.format(list_of_digits_1))\n",
"        else:\n",
"            plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1], color=\"blue\", s=5, alpha=1)\n",
"        l=l+1\n",
"    l = 0\n",
"    for i in list_of_digits_2:\n",
"        encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))\n",
"        encoded_x = torch.tensor(encoded_x).cpu()\n",
"        centered_pca_encoded_x = torch.matmul((encoded_x - mean_encoded_x_all), U_all)/torch.sqrt(s_all)\n",
"        if l == 0:\n",
"            # the green group is labelled with its own digit list\n",
"            plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1], color=\"green\", s=5, alpha=1, label='{}'.format(list_of_digits_2))\n",
"        else:\n",
"            plt.scatter(centered_pca_encoded_x[:1000, dim_0], centered_pca_encoded_x[:1000, dim_1], color=\"green\", s=5, alpha=1)\n",
"        l=l+1\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) and digits({}) on latent space {}-{} PCA dimensions (using SVD on whole MNIST dataset)\".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"plot digits({}) and ({}) on latent space {}-{} PCA dimensions\".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"# Using standard deviation to normalize\n",
"# flag = 1: 0,1,2,3,4\n",
"# flag = 2: 5,6,7,8,9\n",
"def plot_in_latent_spc_std_with_common_normalization_cmp_groups(list_of_digits_1, list_of_digits_2, dim_0, dim_1, save_dir, dir=model_and_data_savepath):\n",
"    \"\"\"Per-coordinate standardised scatter of two digit groups (first blue, second green).\n",
"\n",
"    Uses the notebook globals mean_encoded_x_tot, std and idx_std; the first scatter of\n",
"    each group carries the group's legend label; the figure is saved to save_dir.\n",
"    \"\"\"\n",
"    mean_encoded_x = mean_encoded_x_tot.cpu()\n",
"    std_all = std.cpu()\n",
"\n",
"    plt.figure(figsize=(12,12))\n",
"    for group_digits, colour in ((list_of_digits_1, 'b'), (list_of_digits_2, 'green')):\n",
"        # the counter restarts for every group so both groups get a legend entry\n",
"        for l, i in enumerate(group_digits):\n",
"            encoded_x = np.load(os.path.join(dir, 'encoded_image_{}.npy'.format(i)))\n",
"            encoded_x = torch.tensor(encoded_x).cpu()\n",
"            # per-coordinate standardisation (idx_std coordinates) with the global mean / std\n",
"            normalized_encoded_x = (encoded_x[:, idx_std] - mean_encoded_x[:, idx_std])/(std_all[0, idx_std])\n",
"            if l == 0:\n",
"                plt.scatter(normalized_encoded_x[:1000, dim_0], normalized_encoded_x[:1000, dim_1], color=colour, s=5, alpha=1, label='{}'.format(group_digits))\n",
"            else:\n",
"                plt.scatter(normalized_encoded_x[:1000, dim_0], normalized_encoded_x[:1000, dim_1], color=colour, s=5, alpha=1)\n",
"\n",
"    plt.legend()\n",
"    plt.title(\"plot digits({}) and digits ({}) on latent space {}-{} dimensions (normalized using std on whole MNIST dataset)\".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))\n",
"    filename = os.path.join(save_dir,\"plot digits({}) and ({}) on std latent space {}-{} dimensions\".format(list_of_digits_1, list_of_digits_2, dim_0, dim_1))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9i9qoulVQVEu","cellView":"form"},"outputs":[],"source":["# @title plot OT map and samples (func)\n",
"\n",
"def 
plot_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, iter, normalization='PCA', flag='0_to_T', digits='05', exp_dir=example_savepath):\n",
"    \"\"\"Save three 2-D figures on the dim_0-dim_1 plane to exp_dir.\n",
"\n",
"    1. up to 300 initial points, their transported images and the segments joining them;\n",
"    2. the first 1024 target points vs. transported points;\n",
"    3. the first 1024 transported points only.\n",
"    normalization == 'PCA' selects the PCA wording of titles / file names, anything else the\n",
"    std wording; flag and digits only appear in legends, titles and file names.\n",
"    \"\"\"\n",
"    target_data = target_data.detach().cpu().numpy()\n",
"    init_pnts = init_pnts.detach().cpu().numpy()\n",
"    transported_pnts = transported_pnts.detach().cpu().numpy()\n",
"\n",
"    plt.style.use('dark_background')\n",
"    plt.figure(figsize=(12,12))\n",
"    if flag == '0_to_T':\n",
"        plt.scatter(transported_pnts[:300, dim_0], transported_pnts[:300, dim_1], color='blue', s=5, alpha=1, label='transported to T')\n",
"        plt.scatter(init_pnts[:300, dim_0], init_pnts[:300, dim_1], color='mediumspringgreen', s=5, alpha=1, label='0')\n",
"    else:\n",
"        plt.scatter(transported_pnts[:300, dim_0], transported_pnts[:300, dim_1], color='blue', s=5, alpha=1, label='transported to 0')\n",
"        plt.scatter(init_pnts[:300, dim_0], init_pnts[:300, dim_1], color='mediumspringgreen', s=5, alpha=1, label='T')\n",
"    # one segment per (initial, transported) pair; min() avoids an IndexError with < 300 points\n",
"    for k in range(min(300, init_pnts.shape[0], transported_pnts.shape[0])):\n",
"        plt.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], c='cyan', alpha=0.5, linewidth=0.8)\n",
"    plt.legend(fontsize=20)\n",
"    if normalization == 'PCA':\n",
"        plt.title(\"[{}] Plot on {}-{} PCA direction\".format(flag, dim_0, dim_1), fontsize=40)\n",
"    else:\n",
"        plt.title(\"[{}] Plot on {}-{} dimensional plane\".format(flag, dim_0, dim_1), fontsize=40)\n",
"    if normalization == 'PCA':\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot OT maps on {}-{} PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    else:\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot OT maps on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"    plt.style.use('dark_background')\n",
"    plt.figure(figsize=(12,12))\n",
"    plt.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], color='magenta', s=1, alpha=1, label='target')\n",
"    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='blue', s=1, alpha=1, label='computed')\n",
"    plt.legend(fontsize=20)\n",
"    if normalization == 'PCA':\n",
"        plt.title(\"[{}, {}] Plot on {}-{} PCA direction\".format(flag, digits, dim_0, dim_1), fontsize=40)\n",
"    else:\n",
"        plt.title(\"[{}, {}] Plot on {}-{} dimensional plane\".format(flag, digits, dim_0, dim_1), fontsize=40)\n",
"    if normalization == 'PCA':\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot samples on {}-{} PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    else:\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot samples on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"    plt.style.use('dark_background')\n",
"    plt.figure(figsize=(12,12))\n",
"    plt.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], color='deepskyblue', s=1, alpha=1, label='computed')\n",
"    plt.legend(fontsize=20)\n",
"    if normalization == 'PCA':\n",
"        plt.title(\"[{}, {}] plot on {}-{} PCA direction\".format(flag, digits, dim_0, dim_1), fontsize=40)\n",
"    else:\n",
"        plt.title(\"[{}, {}] plot on {}-{} dimensional plane\".format(flag, digits, dim_0, dim_1), fontsize=40)\n",
"    if normalization == 'PCA':\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot samples (ONLY COMPUTED) on {}-{} PCA directions iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    else:\n",
"        filename = os.path.join(exp_dir, '[{}, {}] plot samples (ONLY COMPUTED) on {}-{} dims (normalized using std) iter={}.png'.format(flag, digits, dim_0, dim_1, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"\n",
"\n",
"def plot_3D_latent_samples_n_OT_map(target_data, init_pnts, transported_pnts, dim_0, dim_1, dim_2, iter, normalization='PCA', exp_dir=example_savepath):\n",
"    \"\"\"Save two 3-D figures on the dim_0-dim_1-dim_2 axes to exp_dir: target vs. transported\n",
"    samples (first 1024 each), and the first 200 initial/transported pairs joined by segments.\n",
"    \"\"\"\n",
"    target_data = target_data.detach().cpu().numpy()\n",
"    init_pnts = init_pnts.detach().cpu().numpy()\n",
"    transported_pnts = transported_pnts.detach().cpu().numpy()\n",
"\n",
"    plt.style.use('dark_background')\n",
"\n",
"    # 3D plot\n",
"    fig = plt.figure(figsize=(20, 20))\n",
"    ax = fig.add_subplot(projection='3d')\n",
"\n",
"    ax.scatter(target_data[:1024, dim_0], target_data[:1024, dim_1], target_data[:1024, dim_2], color='magenta', s=10, label='target')\n",
"    ax.scatter(transported_pnts[:1024, dim_0], transported_pnts[:1024, dim_1], transported_pnts[:1024, dim_2], color='deepskyblue', s=10, label='computed')\n",
"\n",
"    ax.set_xlabel('dim {}'.format(dim_0))\n",
"    
ax.set_ylabel('dim {}'.format(dim_1))\n",
"    ax.set_zlabel('dim {}'.format(dim_2))\n",
"\n",
"    plt.legend(fontsize=20)\n",
"    if normalization == 'PCA':\n",
"        plt.title(\"plot on {}-{}-{} PCA direction\".format(dim_0, dim_1, dim_2), fontsize=40)\n",
"    else:\n",
"        plt.title(\"plot on {}-{}-{} dimensional plane\".format(dim_0, dim_1, dim_2), fontsize=40)\n",
"    if normalization == 'PCA':\n",
"        filename = os.path.join(exp_dir, 'plot 3D samples on {}-{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n",
"    else:\n",
"        filename = os.path.join(exp_dir, 'plot 3D samples on {}-{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"    # 3D plot with OT map\n",
"    fig = plt.figure(figsize=(20, 20))\n",
"    ax = fig.add_subplot(projection='3d')\n",
"\n",
"    ax.scatter(transported_pnts[:200, dim_0], transported_pnts[:200, dim_1], transported_pnts[:200, dim_2], color='blue', s=5, label='transported')\n",
"    ax.scatter(init_pnts[:200, dim_0], init_pnts[:200, dim_1], init_pnts[:200, dim_2], color='mediumspringgreen', s=5, label='initial')\n",
"    # one 3-D segment per (initial, transported) pair; min() avoids an IndexError with < 200 points\n",
"    for k in range(min(200, init_pnts.shape[0], transported_pnts.shape[0])):\n",
"        ax.plot([init_pnts[k, dim_0], transported_pnts[k, dim_0]], [init_pnts[k, dim_1], transported_pnts[k, dim_1]], [init_pnts[k, dim_2], transported_pnts[k, dim_2]], c='cyan', alpha=0.5, linewidth=0.8)\n",
"    plt.legend(fontsize=20)\n",
"    if normalization == 'PCA':\n",
"        plt.title(\"plot on {}-{}-{} PCA direction\".format(dim_0, dim_1, dim_2), fontsize=40)\n",
"    else:\n",
"        plt.title(\"plot on {}-{}-{} dimensional plane\".format(dim_0, dim_1, dim_2), fontsize=40)\n",
"    if normalization == 'PCA':\n",
"        filename = os.path.join(exp_dir, 'plot 3D OT maps on {}-{}-{} PCA directions iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n",
"    else:\n",
"        filename = os.path.join(exp_dir, 'plot 3D OT maps on {}-{}-{} dims (normalized using std) iter={}.png'.format(dim_0, dim_1, dim_2, iter))\n",
"    plt.savefig(filename, bbox_inches='tight', pad_inches=0.2, dpi=500)\n",
"    plt.close()\n",
"\n",
"\n",
"\n",
"\n",
"\n"]},{"cell_type":"markdown","metadata":{"id":"SWq4QpZHrv66"},"source":["\n",
"\n",
"---\n",
"## Main algorithm (OTHJ)\n",
"---\n",
"\n"]},{"cell_type":"code","source":["# @title OT HJ implicit solver\n",
"import torch\n",
"import numpy as np\n",
"import torch.nn.functional\n",
"import math\n",
"import os\n",
"from datetime import datetime\n",
"from models_Resnet import gradient, ImplicitNet\n",
"import utils.general as utils\n",
"import matplotlib.pyplot as plt\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"\n",
"normalization_type = 'std'\n",
"\n",
"if normalization_type == 'PCA':\n",
"    effective_dim = effective_dim\n",
"    init_data_entire = init_data_gpu\n",
"    init_data_list = init_data_list_svd\n",
"    cov_init = cov_init_data\n",
"    mean_init = mean_init_data\n",
"    target_data_entire = target_data_gpu\n",
"    target_data_list = target_data_list_svd\n",
"    cov_target = cov_target_data\n",
"    mean_target = mean_target_data\n",
"else:\n",
"    effective_dim = effective_dim_using_std\n",
"    init_data_entire = init_data_gpu_std\n",
"    init_data_list = init_data_list_std\n",
"    cov_init = cov_init_data_std\n",
"    mean_init = mean_init_data_std\n",
"    target_data_entire = target_data_gpu_std\n",
"    target_data_list = target_data_list_std\n",
"    cov_target = cov_target_data_std\n",
"    mean_target = mean_target_data_std\n",
"\n",
"def random_sampler(N, dim=effective_dim, T=1, fix_T=False, T0=0):\n",
"    \"\"\"Sample N space-time points (t, x) as an N x (dim + 1) float32 tensor with requires_grad=True.\n",
"\n",
"    t = T0 for every row when fix_T, otherwise t = T * U[0, 1); x ~ N(0, I_dim).\n",
"    \"\"\"\n",
"    if fix_T:\n",
"        ts = T0 * torch.ones(N,1)\n",
"    else:\n",
"        ts = T * torch.rand(N,1)\n",
"    xs = torch.randn(N, dim)\n",
"    pnts = torch.cat((ts, xs), 1)\n",
"    # pnts is a fresh leaf tensor: enable grad on it directly instead of\n",
"    # torch.tensor(tensor), which copies and emits a UserWarning\n",
"    return pnts.to(torch.float32).requires_grad_(True)\n",
"\n",
"\n",
"## 
-------------------------------------------------------------------------------\n",
"## Configurations\n",
"## -------------------------------------------------------------------------------\n",
"gpu_id = 0\n",
"\n",
"dim_latent = Dim\n",
"\n",
"\n",
"iter_0 = 0\n",
"\n",
"regularizer_type = ['implicithjxt0t', 'mmd_negnorm_0t']\n",
"regularizer_coord = [1, 500]\n",
"\n",
"# The OT-loss term is currently disabled: with_OTloss, weight_OTloss and batch_size_OT\n",
"# are not used anywhere in this cell.\n",
"with_OTloss = False\n",
"weight_OTloss = 0.01\n",
"\n",
"N = 2000 # collocation pnts for HJ\n",
"batch_size = 2000 # batch size for sample transport (only recorded in the run-directory name)\n",
"if target_data_list[0].shape[0] < batch_size:\n",
"    batch_size = target_data_list[0].shape[0]\n",
"\n",
"epochs = 100000\n",
"val_frequency  = 2000\n",
"plot_detail_ot = epochs\n",
"plot_detail_mnist = 10000\n",
"\n",
"num_mnist_digits_to_plt = 10\n",
"\n",
"dim = effective_dim # dimension of spatial domain\n",
"T = 1 # terminal time\n",
"\n",
"fixing_T_and_0_HJ_loss = True\n",
"\n",
"batch_size_OT = 4000\n",
"sub_batch_size = 400 # samples per class pair and direction in each MMD term\n",
"\n",
"# Pick the device and seed the RNGs *before* building the network, so that its\n",
"# weight initialisation is reproducible and it is created on `device`.\n",
"if torch.cuda.is_available() and gpu_id > -1:\n",
"    device = torch.device(gpu_id)\n",
"else:\n",
"    device = torch.device('cpu')\n",
"utils.set_random_seed(5884)\n",
"\n",
"NN_dims = [128, 128, 128, 128]\n",
"network_sol = ImplicitNet(d_in=dim+1, dims=NN_dims).to(device)\n",
"optimizer = torch.optim.Adam(params=network_sol.parameters(), lr=1E-4)\n",
"\n",
"loss_min = 1E10\n",
"\n",
"save_model = True\n",
"\n",
"exp_dir = os.path.join(example_savepath, 'Classed_OT_shift')\n",
"utils.mkdir_ifnotexists(exp_dir)\n",
"\n",
"# To resume training, uncomment the two lines below and set `ckpt_dir` to the\n",
"# timestamped directory of an earlier run (checkpoints are saved in <run dir>/model/):\n",
"# state_dict = torch.load(os.path.join(ckpt_dir, 'model', 'weight_sol_ep{}.pth'.format(iter_0)))\n",
"# network_sol.load_state_dict(state_dict[\"state_dict\"], strict=False)\n",
"\n",
"timestamp = 'Normalization_type={}, batchsize={}, NNsize={}, HJ_loss_num_colocation_pnts={}  {}'.format(normalization_type, batch_size, NN_dims, N, datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))\n",
"exp_dir = os.path.join(exp_dir, timestamp)\n",
"utils.mkdir_ifnotexists(exp_dir)\n",
"regularizer_type = list(map(str.lower,regularizer_type))\n",
"assert len(regularizer_type) == len(regularizer_coord), 'match regularizer coordinates'\n",
"regularizer_index = {t:i for i, t in enumerate(regularizer_type)}\n",
"\n",
"if save_model:\n",
"    utils.mkdir_ifnotexists(os.path.join(exp_dir,'model'))\n",
"\n",
"\n",
"def _average_shifted_accuracy(img_list, shift):\n",
"    \"\"\"Print the classifier accuracy of each decoded batch and return their mean.\n",
"\n",
"    img_list[idx] is scored against label (idx + shift) % 10: shift=1 for the\n",
"    0 -> T direction (class idx is sent to idx+1), shift=-1 for T -> 0.\n",
"    Decoded images are mapped with 2*x - 1 before being passed to the classifier.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    total = 0.0\n",
"    for idx, decoded_imgs in enumerate(img_list):\n",
"        digit_idx = torch.from_numpy(np.array(2.0 * decoded_imgs - 1.0)).squeeze().unsqueeze(1)\n",
"        accuracy, predicted = test_accuracy_fashion_MNIST(model, digit_idx, (idx + shift) % 10)\n",
"        print(accuracy)\n",
"        total += accuracy\n",
"    return total / len(img_list)\n",
"\n",
"\n",
"MMD_loss_list_0T = []\n",
"MMD_loss_list_T0 = []\n",
"Implicit_HJ_loss_list = []\n",
"accuracy_0T_list = []\n",
"accuracy_T0_list = []\n",
"for epoch in range(iter_0, epochs+iter_0):\n",
"\n",
"    if epoch % 100 == 0:\n",
"        print(\"Iteration: {}\".format(epoch))\n",
"\n",
"    ## -------------------------------------------------------------------------------\n",
"    ## Compute losses\n",
"    ## -------------------------------------------------------------------------------\n",
"    loss = torch.tensor(0.).to(device)\n",
"    losses = {'Train loss' : None}\n",
"\n",
"    if 'implicithjxt0t' in regularizer_type:\n",
"        # Squared residuals of the implicit formulas, with g = gradient(pnts, u)[:, 1:]:\n",
"        #   forward : u(t, x) = u(0, x - t*g) + t*|g|^2/2\n",
"        #   backward: u(t, x) = u(T, x + (T-t)*g) - (T-t)*|g|^2/2\n",
"        # sampled at t = T (forward) and t = 0 (backward) when fixing_T_and_0_HJ_loss.\n",
"        if fixing_T_and_0_HJ_loss:\n",
"            pnts = random_sampler(N, fix_T=True, T0=T).to(device)\n",
"        else:\n",
"            pnts = random_sampler(N, dim, T=T).to(device)\n",
"        pred_sol = network_sol(pnts)\n",
"        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n",
"        init_x = pnts[:,1:] - pnts[:,[0]]*grad_pred_sol\n",
"        init_xt = torch.cat((torch.zeros((N,1)).to(device), init_x), 1)\n",
"        fwd_loss_implicithj = ((pred_sol - 0.5*pnts[:,[0]]*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(init_xt))**2).mean()\n",
"\n",
"        if fixing_T_and_0_HJ_loss:\n",
"            pnts = random_sampler(N, fix_T=True, T0=0).to(device)\n",
"        else:\n",
"            pnts = random_sampler(N, dim, T=T).to(device)\n",
"        pred_sol = network_sol(pnts)\n",
"        grad_pred_sol = gradient(pnts,pred_sol)[:,1:]\n",
"        terminal_x = pnts[:,1:] + (T - pnts[:,[0]])*grad_pred_sol\n",
"        terminal_xt = torch.cat((T * torch.ones((N,1)).to(device), terminal_x), 1)\n",
"        bckwd_loss_implicithj = ((pred_sol + 0.5*(T - pnts[:,[0]])*torch.sum(grad_pred_sol*grad_pred_sol,dim=1,keepdim=True) - network_sol(terminal_xt))**2).mean()\n",
"\n",
"        losses['Implicit HJ loss'] = fwd_loss_implicithj.item() + bckwd_loss_implicithj.item()\n",
"        loss_implicithj = fwd_loss_implicithj + bckwd_loss_implicithj\n",
"        loss += regularizer_coord[regularizer_index['implicithjxt0t']] * loss_implicithj\n",
"        Implicit_HJ_loss_list.append(loss_implicithj.item())\n",
"\n",
"    if 'mmd_negnorm_0t' in regularizer_type:\n",
"        # Class-conditional shift: class i at t = 0 is matched with class (i + 1) % 10 at t = T.\n",
"        # Data points are plain tensors (no grad); only the space-time inputs of the\n",
"        # network need requires_grad, for the gradient that moves the samples.\n",
"        for i in range(10):\n",
"            init_data = init_data_list[i]\n",
"            target_data = target_data_list[(i + 1) % 10]\n",
"\n",
"            # 0 -> T: move init samples by +T * gradient and compare with target samples\n",
"            init_data_indices_for_0T = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n",
"            init_data_pnts_for_0T = init_data[init_data_indices_for_0T,:].to(device=device, dtype=torch.float32)\n",
"            init_spatialtemporal_pnts_for_0T = torch.cat((torch.zeros((sub_batch_size,1), device=device), init_data_pnts_for_0T), 1).requires_grad_(True)\n",
"            pred_sol_for_0T = network_sol(init_spatialtemporal_pnts_for_0T)\n",
"            transported_pnts_to_T = init_spatialtemporal_pnts_for_0T[:, 1:] + T * gradient(init_spatialtemporal_pnts_for_0T, pred_sol_for_0T)[:,1:]\n",
"            data_indices_for_0T = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n",
"            data_pnts_for_0T = target_data[data_indices_for_0T,:].to(device=device, dtype=torch.float32)\n",
"            loss_MMD_T = MMD_negnorm(transported_pnts_to_T, data_pnts_for_0T)\n",
"            if epoch % 10 == 0:\n",
"                MMD_loss_list_0T.append(loss_MMD_T.item())\n",
"\n",
"            # T -> 0: move target samples by -T * gradient and compare with init samples\n",
"            data_indices_for_T0 = torch.tensor(np.random.choice(target_data.shape[0], sub_batch_size, False))\n",
"            data_pnts_for_T0 = target_data[data_indices_for_T0,:].to(device=device, dtype=torch.float32)\n",
"            spatialtemporal_pnts = torch.cat((T * torch.ones((sub_batch_size,1), device=device), data_pnts_for_T0), 1).requires_grad_(True)\n",
"            pred_sol_for_T0 = network_sol(spatialtemporal_pnts)\n",
"            transported_pnts_to_0 = spatialtemporal_pnts[:,1:] - T * gradient(spatialtemporal_pnts, pred_sol_for_T0)[:,1:]\n",
"            init_data_indices_for_T0 = torch.tensor(np.random.choice(init_data.shape[0], sub_batch_size, False))\n",
"            init_data_pnts_for_T0 = init_data[init_data_indices_for_T0,:].to(device=device, dtype=torch.float32)\n",
"            loss_MMD_0 = MMD_negnorm(transported_pnts_to_0, init_data_pnts_for_T0)\n",
"            if epoch % 10 == 0:\n",
"                MMD_loss_list_T0.append(loss_MMD_0.item())\n",
"            loss_MMD = loss_MMD_0 + loss_MMD_T\n",
"            loss += regularizer_coord[regularizer_index['mmd_negnorm_0t']] * loss_MMD\n",
"\n",
"    ## -------------------------------------------------------------------------------\n",
"    ## Update parameters\n",
"    ## -------------------------------------------------------------------------------\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    ## -------------------------------------------------------------------------------\n",
"    ## Validation\n",
"    ## -------------------------------------------------------------------------------\n",
"    if (epoch+1) % val_frequency == 0:\n",
"        # NOTE(review): held-out samples always come from the *_std_for_test lists, even\n",
"        # when normalization_type == 'PCA' -- confirm before using the PCA path.\n",
"\n",
"        # accuracy & plots 0 ------> T\n",
"        plot_latent_samples_list = []\n",
"        for i in range(10):\n",
"            init_data = init_data_list_std_for_test[i]\n",
"            sample_size = init_data.size()[0]\n",
"            init_data_indices = torch.tensor(np.random.choice(init_data.shape[0], sample_size, False))\n",
"            init_data_pnts = init_data[init_data_indices,:].to(dtype=torch.float32)\n",
"            init_spatialtemporal_pnts = torch.cat((torch.zeros((sample_size,1), device=device), init_data_pnts.to(device)), 1).requires_grad_(True)\n",
"            pred_sol = network_sol(init_spatialtemporal_pnts)\n",
"            transported_pnts_to_T = init_spatialtemporal_pnts[:, 1:] + T * gradient(init_spatialtemporal_pnts, pred_sol)[:,1:]\n",
"            if normalization_type == 'PCA':\n",
"                recovered_generated_target_latent_samples = recover_latent_samples_using_PCA(transported_pnts_to_T)\n",
"            else:\n",
"                recovered_generated_target_latent_samples = recover_latent_samples_using_std(transported_pnts_to_T)\n",
"            plot_latent_samples_list.append(recovered_generated_target_latent_samples)\n",
"\n",
"            target_data = target_data_list[(i + 1) % 10]\n",
"            data_indices = torch.tensor(np.random.choice(target_data.shape[0], sample_size, False))\n",
"            data_pnts = target_data[data_indices,:]\n",
"            digits_string = '{}{}'.format(i, (i + 1) % 10)\n",
"            if (epoch+1) % plot_detail_ot == 0:\n",
"                # 2D plots on consecutive coordinate pairs (0,1), (1,2), ..., (5,6)\n",
"                for index in range(6):\n",
"                    plot_latent_samples_n_OT_map(data_pnts, init_data_pnts, transported_pnts_to_T, index, index+1, epoch, normalization_type, '0_to_T', digits_string, exp_dir)\n",
"\n",
"        plot_img_list = [trained_decoder(latent_sample.cpu().detach().numpy()) for latent_sample in plot_latent_samples_list]\n",
"        if (epoch+1) % plot_detail_mnist == 0:\n",
"            plot_mnist_init_target(*plot_img_list, epoch, 10, \"[0 to T] generated MNIST digits (conditioned on class)\", '0_to_T', exp_dir=exp_dir)\n",
"\n",
"        # check accuracy: decoded batch i should be classified as (i + 1) % 10\n",
"        ave_accuracy = _average_shifted_accuracy(plot_img_list, 1)\n",
"        print(\"=================================================================================================\")\n",
"        print(\"[iter {}] average accuracy on transporting 0,1,2,3,4,5,6,7,8,9 to 1,2,3,4,5,6,7,8,9,0: {}\".format(epoch, ave_accuracy))\n",
"        accuracy_0T_list.append(ave_accuracy)\n",
"        print(\"=================================================================================================\")\n",
"\n",
"        # accuracy & plots T ------> 0\n",
"        plot_latent_samples_list = []\n",
"        for i in range(10):\n",
"            target_data = target_data_list_std_for_test[i]\n",
"            sample_size = target_data.size()[0]\n",
"            target_data_indices = torch.tensor(np.random.choice(target_data.shape[0], sample_size, False))\n",
"            target_data_pnts = target_data[target_data_indices,:].to(dtype=torch.float32)\n",
"            target_spatialtemporal_pnts = torch.cat((T*torch.ones((sample_size,1), device=device), target_data_pnts.to(device)), 1).requires_grad_(True)\n",
"            pred_sol = network_sol(target_spatialtemporal_pnts)\n",
"            transported_pnts_to_0 = target_spatialtemporal_pnts[:, 1:] - T * gradient(target_spatialtemporal_pnts, pred_sol)[:,1:]\n",
"            if normalization_type == 'PCA':\n",
"                recovered_generated_init_latent_samples = recover_latent_samples_using_PCA(transported_pnts_to_0)\n",
"            else:\n",
"                recovered_generated_init_latent_samples = recover_latent_samples_using_std(transported_pnts_to_0)\n",
"            plot_latent_samples_list.append(recovered_generated_init_latent_samples)\n",
"\n",
"            data_pnts = init_data_list[(i - 1) % 10]\n",
"            data_indices = torch.tensor(np.random.choice(data_pnts.shape[0], sample_size, False))\n",
"            data_pnts = data_pnts[data_indices,:]\n",
"            digits_string = '{}{}'.format(i, (i - 1) % 10)\n",
"            if (epoch+1) % plot_detail_ot == 0:\n",
"                # 2D plots on consecutive coordinate pairs (0,1), (1,2), ..., (5,6)\n",
"                for index in range(6):\n",
"                    plot_latent_samples_n_OT_map(data_pnts, target_data_pnts, transported_pnts_to_0, index, index+1, epoch, normalization_type, 'T_to_0', digits_string, exp_dir)\n",
"\n",
"        plot_img_list = [trained_decoder(latent_sample.cpu().detach().numpy()) for latent_sample in plot_latent_samples_list]\n",
"        if (epoch+1) % plot_detail_mnist == 0:\n",
"            plot_mnist_init_target(*plot_img_list, epoch, 10, \"[T to 0] generated MNIST digits (conditioned on class)\", 'T_to_0', exp_dir=exp_dir)\n",
"\n",
"        # check accuracy: decoded batch i should be classified as (i - 1) % 10\n",
"        ave_accuracy = _average_shifted_accuracy(plot_img_list, -1)\n",
"        print(\"=================================================================================================\")\n",
"        print(\"[iter {}] average accuracy on transporting 1,2,3,4,5,6,7,8,9,0 to 0,1,2,3,4,5,6,7,8,9: {}\".format(epoch, ave_accuracy))\n",
"        accuracy_T0_list.append(ave_accuracy)\n",
"        print(\"=================================================================================================\")\n",
"\n",
"        if save_model:\n",
"            torch.save({'state_dict': network_sol.state_dict(),}, os.path.join(exp_dir, f'model/weight_sol_ep{epoch}.pth'))\n",
"\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"plt.plot(MMD_loss_list_0T, label=\"OT from 0 to T\")\n",
"plt.plot(MMD_loss_list_T0, label=\"OT from T to 0\")\n",
"plt.legend( fontsize = 12 )\n",
"plt.title(\"MMD - iter\", fontsize=18)\n",
"plt.savefig(os.path.join(exp_dir,'MMD_loss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n",
"plt.close()\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"plt.plot(Implicit_HJ_loss_list)\n",
"plt.title(\"Implicit HJ loss - iter\", fontsize=18)\n",
"plt.savefig(os.path.join(exp_dir,'HJloss_list'), dpi=200, bbox_inches='tight', pad_inches=0)\n",
"plt.close()\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"plt.plot(accuracy_0T_list, label='Accuracy @ T')\n",
"plt.plot(accuracy_T0_list, label='Accuracy @ 0')\n",
"plt.title(\"Average_accuracy_for_OT\", fontsize=18)\n",
"plt.legend(fontsize=15)\n",
"plt.savefig(os.path.join(exp_dir,'accuracy'), dpi=200, bbox_inches='tight', pad_inches=0)\n",
"plt.close()\n",
"\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qKGKT_HfNOrO","outputId":"4dd76941-1911-442e-c8d8-2b2516add1e8","collapsed":true,"cellView":"form"},"execution_count":null,"outputs":[]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["L961ddlbCWcl","dyDUTgV2wfjG","_WmQYYP-gcpj","tKNgKlzzsnQU","IBGEhneOs0MU"],"gpuType":"A100","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}