{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c8ede4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torchvision.datasets as datasets\n",
    "from time import time\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "import matplotlib.pyplot as plt\n",
    "import import_ipynb\n",
    "import seaborn as sns\n",
    "import Step_1_Representation_learning\n",
    "import torchvision.transforms as tvt\n",
    "import torchvision\n",
    "import warnings\n",
    "from tqdm import trange\n",
    "import random\n",
    "import pandas as pd\n",
    "import math\n",
    "import torch.nn.functional as F\n",
    "import torchvision.models as models\n",
    "from torch.utils.data import DataLoader, Subset, Dataset\n",
    "from torchvision.utils import make_grid\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.\n",
    "\n",
    "    Turning cuDNN benchmarking off trades some speed for run-to-run\n",
    "    reproducibility.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "seed_everything(128)\n",
    "\n",
    "# Fall back to the CPU when no GPU is present instead of failing in torch.load below.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "image_size = 64\n",
    "batch_size = 12\n",
    "# n_flow / n_block define the Glow architecture and must agree with the checkpoint loaded below.\n",
    "n_flow = 32\n",
    "n_block = 4\n",
    "\n",
    "root_dir = '../UTKFace/'\n",
    "csv_file = r'metadata_spu.csv'\n",
    "# Full metadata table (all splits); each CustomDataset re-reads and filters the CSV itself.\n",
    "data_frame = pd.read_csv(csv_file)\n",
    "\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Images listed in the metadata CSV for one split, together with their labels.\n",
    "\n",
    "    Each item is ``(image, target, sensitive)``: ``target`` comes from the\n",
    "    ``gender`` column and ``sensitive`` from the ``race`` column.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, csv_file, root_dir, split, transform):\n",
    "        metadata = pd.read_csv(csv_file)\n",
    "        # Keep only this split and re-index 0..n-1 so label lookups by ``idx``\n",
    "        # line up with the positional file-name lookup in __getitem__.\n",
    "        self.data_frame = metadata[metadata['split'] == split].reset_index(drop=True)\n",
    "        self.root_dir = root_dir\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data_frame)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Assumes the first CSV column holds the image path relative to root_dir.\n",
    "        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])\n",
    "        # Decode inside a context manager so the file handle is closed, and force\n",
    "        # 3-channel RGB because the Glow model below is built with 3 input channels.\n",
    "        with Image.open(img_name) as img:\n",
    "            image = img.convert('RGB')\n",
    "        target = torch.tensor(self.data_frame.at[idx, 'gender'])  # class index\n",
    "        sensitive = torch.tensor(self.data_frame.at[idx, 'race'])\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        return image, target, sensitive\n",
    "\n",
    "\n",
    "transform = tvt.Compose([\n",
    "    tvt.Resize((image_size, image_size)),\n",
    "    tvt.ToTensor()])\n",
    "\n",
    "# Metadata 'split' codes: 0 = train, 1 = valid, 2 = test. No shuffling, so\n",
    "# batches follow the CSV row order.\n",
    "training_set = CustomDataset(csv_file=csv_file, root_dir=root_dir, split=0, transform=transform)\n",
    "training_data_loader = DataLoader(training_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "valid_set = CustomDataset(csv_file=csv_file, root_dir=root_dir, split=1, transform=transform)\n",
    "valid_data_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "test_set = CustomDataset(csv_file=csv_file, root_dir=root_dir, split=2, transform=transform)\n",
    "test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Load the trained Glow model (3 input channels) and switch it to inference mode.\n",
    "weight_dir = r'UTK_GLOW/checkpoint/model_033001.pt'\n",
    "net = Step_1_Representation_learning.Glow(3, n_flow, n_block)\n",
    "net.load_state_dict(torch.load(weight_dir, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "732b6911",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder, relative to the current working directory, for the saved representations.\n",
    "save_dir = 'UTK_representation'\n",
    "current_directory = os.getcwd()\n",
    "folder_path = os.path.join(current_directory, save_dir)\n",
    "\n",
    "# Report an existing folder; otherwise create it (including missing parents).\n",
    "if os.path.exists(folder_path):\n",
    "    print(f'Directory {folder_path} already exists.')\n",
    "else:\n",
    "    os.makedirs(folder_path)\n",
    "    print(f'Directory {folder_path} created.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f3a186",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Pixel quantisation used before encoding: keep the top n_bits of each 8-bit value.\n",
    "n_bits = 5\n",
    "n_bins = 2.0 ** n_bits\n",
    "\n",
    "\n",
    "def _preprocess(images, n_bits):\n",
    "    \"\"\"Quantise a batch of images to ``n_bits`` and centre it around zero.\n",
    "\n",
    "    Assumes ``images`` holds values in [0, 1] (as produced by ``ToTensor``).\n",
    "    Values are scaled to 0..255, the low ``8 - n_bits`` bits are dropped when\n",
    "    ``n_bits < 8`` (no-op otherwise), and the result is mapped to [-0.5, 0.5).\n",
    "    \"\"\"\n",
    "    pixels = images * 255\n",
    "    if n_bits < 8:\n",
    "        pixels = torch.floor(pixels / 2 ** (8 - n_bits))\n",
    "    return pixels / 2.0 ** n_bits - 0.5\n",
    "\n",
    "\n",
    "def _stack(chunks):\n",
    "    \"\"\"Concatenate per-batch arrays along axis 0; no batches gives an empty array.\"\"\"\n",
    "    return np.concatenate(chunks) if chunks else np.array([])\n",
    "\n",
    "\n",
    "def _encode_split(model, loader, device, n_bits, desc=None):\n",
    "    \"\"\"Encode every batch of ``loader`` with ``model`` and collect the results.\n",
    "\n",
    "    ``loader`` must yield ``(images, labels, sensitive)`` batches. Returns\n",
    "    ``(latents, labels, sensitive)`` as numpy arrays with one row per sample;\n",
    "    each sample's latents (the third output of ``model``) are flattened and\n",
    "    concatenated in order along the feature axis.\n",
    "    \"\"\"\n",
    "    levels = 2.0 ** n_bits\n",
    "    latents, labels, sensitive = [], [], []\n",
    "    # Inference only: skip building the autograd graph to save memory and time.\n",
    "    with torch.no_grad():\n",
    "        for images, batch_labels, batch_sensitive in tqdm(loader, desc=desc):\n",
    "            model_input = _preprocess(images.to(device), n_bits)\n",
    "            # NOTE(review): the preprocessed input is divided by n_bins once more\n",
    "            # before encoding, whereas a training-style call would add\n",
    "            # torch.rand_like(x) / n_bins noise instead. Kept unchanged so the saved\n",
    "            # representations do not change -- confirm this scaling is intended.\n",
    "            _, _, z = model(model_input / levels)\n",
    "            # Flatten each latent per sample (for 64x64 inputs and 4 blocks these\n",
    "            # hold 6144, 3072, 1536 and 1536 values) and join them.\n",
    "            flat = torch.cat([z_i.reshape(z_i.shape[0], -1) for z_i in z], dim=1)\n",
    "            latents.append(flat.cpu().numpy())\n",
    "            labels.append(batch_labels.cpu().numpy())\n",
    "            sensitive.append(batch_sensitive.cpu().numpy())\n",
    "    return _stack(latents), _stack(labels), _stack(sensitive)\n",
    "\n",
    "\n",
    "data_np, target_np = _encode_split(net, training_data_loader, device, n_bits, desc='train')[:2]\n",
    "print('Mapping training set finished')\n",
    "\n",
    "test_data_np, test_label_np, test_sensitive_np = _encode_split(\n",
    "    net, test_data_loader, device, n_bits, desc='test')\n",
    "print('Mapping test set finished')\n",
    "\n",
    "valid_data_np, valid_label_np, valid_sensitive_np = _encode_split(\n",
    "    net, valid_data_loader, device, n_bits, desc='valid')\n",
    "print('Mapping valid set finished')\n",
    "\n",
    "# Sanity check: the per-sample latent size equals the input size\n",
    "# (6144 + 3072 + 1536 + 1536 = 3 * 64 * 64 for image_size = 64).\n",
    "expected_dim = 3 * image_size * image_size\n",
    "for split_name, split_latents in (('train', data_np), ('test', test_data_np), ('valid', valid_data_np)):\n",
    "    assert split_latents.size == 0 or split_latents.shape[1] == expected_dim, (\n",
    "        f'{split_name}: latent size {split_latents.shape[1]} != {expected_dim}')\n",
    "\n",
    "# Aliases for the arrays saved below; bs is the number of training samples.\n",
    "training = data_np\n",
    "label = target_np\n",
    "bs = training.shape[0]\n",
    "\n",
    "test = test_data_np\n",
    "test_label = test_label_np\n",
    "test_sensitive = test_sensitive_np\n",
    "\n",
    "valid = valid_data_np\n",
    "valid_label = valid_label_np\n",
    "valid_sensitive = valid_sensitive_np\n",
    "\n",
    "# np.save appends the .npy extension to each of these paths.\n",
    "test_input_save_path = os.path.join('.', save_dir, 'test')\n",
    "test_label_path = os.path.join('.', save_dir, 'test_label')\n",
    "test_sensitive_path = os.path.join('.', save_dir, 'test_sensitive')\n",
    "np.save(test_input_save_path, test)\n",
    "np.save(test_label_path, test_label)\n",
    "np.save(test_sensitive_path, test_sensitive)\n",
    "\n",
    "train_input_save_path = os.path.join('.', save_dir, 'train')\n",
    "train_label_path = os.path.join('.', save_dir, 'train_label')\n",
    "np.save(train_input_save_path, training)\n",
    "np.save(train_label_path, label)\n",
    "\n",
    "valid_input_save_path = os.path.join('.', save_dir, 'valid')\n",
    "valid_label_path = os.path.join('.', save_dir, 'valid_label')\n",
    "valid_sensitive_path = os.path.join('.', save_dir, 'valid_sensitive')\n",
    "np.save(valid_input_save_path, valid)\n",
    "np.save(valid_label_path, valid_label)\n",
    "np.save(valid_sensitive_path, valid_sensitive)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
