{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cfbae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torchvision.datasets as datasets\n",
    "from time import time\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "import matplotlib.pyplot as plt\n",
    "import import_ipynb\n",
    "import seaborn as sns\n",
    "import Step_1_Representation_learning\n",
    "import torchvision.transforms as tvt\n",
    "import torchvision\n",
    "import warnings\n",
    "from tqdm import trange\n",
    "import random\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"\n",
    "    Changes the seed for reproducibility. \n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "seed_everything(128)\n",
    "\n",
    "device = torch.device(\"cuda:2\")\n",
    "#step: load model\n",
    "image_size = 64\n",
    "batch_size = 32\n",
    "n_flow = 32\n",
    "n_block = 4\n",
    "\n",
    "dataset = torchvision.datasets.CelebA(\"../celeba/datasets/\",split='train', transform=tvt.Compose([\n",
    "                                   tvt.Resize((image_size,image_size)),\n",
    "                                   tvt.CenterCrop(image_size),\n",
    "                                  tvt.ToTensor()\n",
    "                              ]))\n",
    "testset =  torchvision.datasets.CelebA(\"../celeba/datasets/\",split='test', transform=tvt.Compose([\n",
    "                                   tvt.Resize((image_size,image_size)),\n",
    "                                   tvt.CenterCrop(image_size),\n",
    "                                   tvt.ToTensor()\n",
    "                               ]))\n",
    "\n",
    "training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)\n",
    "test_data_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)\n",
    "weight_dir = r'train/checkpoint/model_495001.pt'\n",
    "net = Step_1_Representation_learning.Glow(3, n_flow, n_block)\n",
    "net.load_state_dict(torch.load(weight_dir, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0850243a",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = 'celebA_representation'\n",
    "current_directory = os.getcwd()\n",
    "\n",
    "folder_path = os.path.join(current_directory, save_dir)\n",
    "\n",
    "if not os.path.exists(folder_path):\n",
    "    os.makedirs(folder_path)\n",
    "    print(f'Directory {folder_path} created.')\n",
    "else:\n",
    "    print(f'Directory {folder_path} already exists.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bcd0ffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "train_data = []\n",
    "attribute_list = []\n",
    "\n",
    "examples = iter(training_data_loader)\n",
    "n_batches = len(examples)\n",
    "n_bits = 5\n",
    "n_bins = 2.0 ** n_bits\n",
    "for i in trange(n_batches):\n",
    "    data , label = next(examples)\n",
    "    data = data.to(device)\n",
    "    data_ = data * 255\n",
    "    if n_bits < 8:\n",
    "        data__ = torch.floor(data_ / 2 ** (8 - n_bits))\n",
    "    data_f = data__ / n_bins - 0.5\n",
    "    log_p, logdet, z = net(data_f/ n_bins)\n",
    "    latnet_0 = z[0].view(-1,6144)\n",
    "    latnet_1 = z[1].view(-1,3072)\n",
    "    latnet_2 = z[2].view(-1,1536)\n",
    "    latnet_3 = z[3].view(-1,1536)\n",
    "    input_latent = [latnet_0, latnet_1, latnet_2, latnet_3]\n",
    "    latent_tensor = torch.cat(input_latent, 1) \n",
    "    train_data.append(latent_tensor.detach().cpu().numpy())\n",
    "    attribute_list.append(label.detach().cpu().numpy())\n",
    "\n",
    "data_np = np.array(train_data)\n",
    "attribute_np = np.array(attribute_list)\n",
    "print(f'Mapping training-set finished')\n",
    "\n",
    "test_data = []\n",
    "test_attribute_list = []\n",
    "test_examples = iter(test_data_loader)\n",
    "n_batches = len(test_examples)\n",
    "for i in trange(n_batches):\n",
    "    data , label = next(test_examples)\n",
    "    data = data.to(device)\n",
    "    data_ = data * 255\n",
    "    if n_bits < 8:\n",
    "        data__ = torch.floor(data_ / 2 ** (8 - n_bits))\n",
    "    data_f = data__ / n_bins - 0.5\n",
    "    log_p, logdet, z = net(data_f/ n_bins)\n",
    "    latnet_0 = z[0].view(-1,6144)\n",
    "    latnet_1 = z[1].view(-1,3072)\n",
    "    latnet_2 = z[2].view(-1,1536)\n",
    "    latnet_3 = z[3].view(-1,1536)\n",
    "    input_latent = [latnet_0, latnet_1, latnet_2, latnet_3]\n",
    "    latent_tensor = torch.cat(input_latent, 1) \n",
    "    test_data.append(latent_tensor.detach().cpu().numpy())\n",
    "    test_attribute_list.append(label.detach().cpu().numpy())\n",
    "\n",
    "\n",
    "test_data_np = np.array(test_data)\n",
    "test_attribute_np = np.array(test_attribute_list)\n",
    "print(f'Mapping test set finished')\n",
    "\n",
    "training = data_np\n",
    "att = attribute_np\n",
    "bs = training.shape[0]\n",
    "\n",
    "training = training.reshape(bs*batch_size , 12288)\n",
    "att = att.reshape(bs*batch_size , 40)\n",
    "\n",
    "\n",
    "test = test_data_np.reshape(len(test_examples)*batch_size, 12288)\n",
    "test_attribute = test_attribute_np.reshape(len(test_examples)*batch_size , 40)\n",
    "\n",
    "test_input_save_path = os.path.join('./',save_dir+'/test')\n",
    "test_attribute_path = os.path.join('./',save_dir+'/test_attribute')\n",
    "\n",
    "np.save(test_input_save_path, test)\n",
    "np.save(test_attribute_path, test_attribute)\n",
    "\n",
    "x_train, x_valid, y_train, y_valid = train_test_split(training, att, test_size=0.02, train_size=0.98)\n",
    "\n",
    "train_input_save_path = os.path.join('./',save_dir+'/train')\n",
    "train_attribute_path = os.path.join('./',save_dir+'/train_attribute')\n",
    "valid_input_save_path = os.path.join('./',save_dir+'/valid')\n",
    "valid_attribute_path = os.path.join('./',save_dir+'/valid_attribute')\n",
    "\n",
    "\n",
    "np.save(train_input_save_path, x_train)\n",
    "np.save(train_attribute_path, y_train)\n",
    "np.save(valid_input_save_path, x_valid)\n",
    "np.save(valid_attribute_path, y_valid)\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
