{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c306bca0",
   "metadata": {},
   "source": [
    "# IMPORT STATEMENTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12d3f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow_examples.models.pix2pix import pix2pix\n",
    "\n",
    "import os\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import clear_output\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tensorflow.keras.layers import Activation, Dense, Input\n",
    "from tensorflow.keras.layers import Conv2D, Flatten\n",
    "from tensorflow.keras.layers import Conv2DTranspose\n",
    "from tensorflow.keras.layers import LeakyReLU\n",
    "from tensorflow.keras.layers import concatenate\n",
    "from tensorflow.keras.optimizers import RMSprop\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow_addons.layers import InstanceNormalization\n",
    "from tensorflow.python.ops.numpy_ops import np_config\n",
    "import tensorflow_addons as tfa\n",
    "from __future__ import print_function, division\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise, Layer, LeakyReLU\n",
    "from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, AveragePooling2D\n",
    "from tensorflow.keras.layers import MaxPooling2D, concatenate, Concatenate, Lambda, Add\n",
    "from tensorflow.keras.layers import UpSampling2D, Conv2D, Conv2DTranspose\n",
    "from tensorflow.keras.models import Sequential, Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras import losses\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow.keras.losses import BinaryCrossentropy\n",
    "from tensorflow.keras.initializers import TruncatedNormal, Constant, Zeros, Initializer\n",
    "from tensorflow.keras.activations import softplus\n",
    "\n",
    "import pandas as pd\n",
    "from csv import writer\n",
    "from functools import partial\n",
    "from tensorflow.keras.models import load_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58430833",
   "metadata": {},
   "source": [
    "# PREPARE DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "249f579d",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = list(np.load(\"./index_train.npy\"))\n",
    "idx_test = list(np.load(\"./index_test.npy\"))\n",
    "\n",
    "# provide correct path here\n",
    "real_images = np.load('../../celeba/data/celebA_64x64.npy').reshape((-1,64,64,3))/255.\n",
    "X_train = real_images[idx]\n",
    "X_test = real_images[idx_test]\n",
    "\n",
    "attr_df = pd.read_csv('../../celeba/data/attributes_list_full.csv').iloc[:,1:]\n",
    "attr_df = attr_df.replace(to_replace = -1, value = 0)\n",
    "attr_train = attr_df.values[idx]\n",
    "attr_test = attr_df.values[idx_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c19a7c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "blond_male_train_idx = np.where(np.logical_and(attr_train[:,9]==1, attr_train[:,20]==1))[0]\n",
    "blond_male_test_idx = np.where(np.logical_and(attr_test[:,9]==1, attr_test[:,20]==1))[0]\n",
    "\n",
    "X_blond_male_train = X_train[blond_male_train_idx]\n",
    "X_blond_male_test = X_test[blond_male_test_idx]\n",
    "\n",
    "y_blond_male_train = attr_train[blond_male_train_idx]\n",
    "y_blond_male_test = attr_test[blond_male_test_idx]\n",
    "\n",
    "nonblond_male_train_idx = np.where(np.logical_and(attr_train[:,9]==0, attr_train[:,20]==1))[0]\n",
    "nonblond_male_test_idx = np.where(np.logical_and(attr_test[:,9]==0, attr_test[:,20]==1))[0]\n",
    "\n",
    "X_nonblond_male_train = X_train[nonblond_male_train_idx]\n",
    "X_nonblond_male_test = X_test[nonblond_male_test_idx]\n",
    "\n",
    "y_nonblond_male_train = attr_train[nonblond_male_train_idx]\n",
    "y_nonblond_male_test = attr_test[nonblond_male_test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28256ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "BUFFER_SIZE = 1000\n",
    "BATCH_SIZE = 16\n",
    "IMG_WIDTH = 64\n",
    "IMG_HEIGHT = 64\n",
    "img_shape = (64,64,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef7b7fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_blond_male = next(iter(X_blond_male_train))\n",
    "sample_nonblond_male = next(iter(X_nonblond_male_train))\n",
    "\n",
    "plt.subplot(121)\n",
    "plt.title('Blond Male')\n",
    "plt.imshow(sample_blond_male)\n",
    "\n",
    "plt.subplot(122)\n",
    "plt.title('Non-Blond Male')\n",
    "plt.imshow(sample_nonblond_male)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b637650",
   "metadata": {},
   "source": [
    "# CYCLEGAN SETUP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1147e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import Embedding\n",
    "from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply\n",
    "\n",
    "def encoder_layer(inputs,\n",
    "                  filters=16,\n",
    "                  kernel_size=3,\n",
    "                  strides=2,\n",
    "                  activation='relu',\n",
    "                  instance_norm=True):\n",
    "    \"\"\"Builds a generic encoder layer made of Conv2D-IN-LeakyReLU\n",
    "    IN is optional, LeakyReLU may be replaced by ReLU\n",
    "    \"\"\"\n",
    "\n",
    "    conv = Conv2D(filters=filters,\n",
    "                  kernel_size=kernel_size,\n",
    "                  strides=strides,\n",
    "                  padding='same')\n",
    "\n",
    "    x = inputs\n",
    "    if instance_norm:\n",
    "        x = InstanceNormalization()(x)\n",
    "    if activation == 'relu':\n",
    "        x = Activation('relu')(x)\n",
    "    else:\n",
    "        x = LeakyReLU(alpha=0.2)(x)\n",
    "    x = conv(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def decoder_layer(inputs,\n",
    "                  paired_inputs,\n",
    "                  filters=16,\n",
    "                  kernel_size=3,\n",
    "                  strides=2,\n",
    "                  activation='relu',\n",
    "                  instance_norm=True):\n",
    "    \"\"\"Builds a generic decoder layer made of Conv2D-IN-LeakyReLU\n",
    "    IN is optional, LeakyReLU may be replaced by ReLU\n",
    "    Arguments: (partial)\n",
    "    inputs (tensor): the decoder layer input\n",
    "    paired_inputs (tensor): the encoder layer output \n",
    "          provided by U-Net skip connection &\n",
    "          concatenated to inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    conv = Conv2DTranspose(filters=filters,\n",
    "                           kernel_size=kernel_size,\n",
    "                           strides=strides,\n",
    "                           padding='valid')\n",
    "    x = inputs\n",
    "    print(x)\n",
    "    if instance_norm:\n",
    "        x = InstanceNormalization()(x)\n",
    "    if activation == 'relu':\n",
    "        x = Activation('relu')(x)\n",
    "    else:\n",
    "        x = LeakyReLU(alpha=0.2)(x)\n",
    "    x = conv(x)\n",
    "    x = concatenate([x, paired_inputs])\n",
    "    return x\n",
    "\n",
    "\n",
    "def build_generator(input_shape,\n",
    "                    output_shape=None,\n",
    "                    kernel_size=3,\n",
    "                    name=None):\n",
    "    \"\"\"The generator is a U-Network made of a 4-layer encoder\n",
    "    and a 4-layer decoder. Layer n-i is connected to layer i.\n",
    "    Arguments:\n",
    "    input_shape (tuple): input shape\n",
    "    output_shape (tuple): output shape\n",
    "    kernel_size (int): kernel size of encoder & decoder layers\n",
    "    name (string): name assigned to generator model\n",
    "    Returns:\n",
    "    generator (Model):\n",
    "    \"\"\"\n",
    "\n",
    "    inputs = Input(shape=input_shape)\n",
    "    label = Input(shape=(1,))\n",
    "    channels = int(output_shape[-1])\n",
    "    label_embedding = Flatten()(Embedding(2, np.prod(input_shape))(label))\n",
    "\n",
    "    flat_img = Flatten()(inputs) \n",
    "    model_input = multiply([flat_img, label_embedding])\n",
    "    d0 = Reshape(input_shape)(model_input)\n",
    "        \n",
    "    e1 = encoder_layer(d0,\n",
    "                       32,\n",
    "                       kernel_size=kernel_size,\n",
    "                       activation='leaky_relu',\n",
    "                       strides=1)\n",
    "    e2 = encoder_layer(e1,\n",
    "                       64,\n",
    "                       activation='leaky_relu',\n",
    "                       kernel_size=kernel_size)\n",
    "    e3 = encoder_layer(e2,\n",
    "                       128,\n",
    "                       activation='leaky_relu',\n",
    "                       kernel_size=kernel_size)\n",
    "    e4 = encoder_layer(e3,\n",
    "                       256,\n",
    "                       activation='leaky_relu',\n",
    "                       kernel_size=kernel_size)\n",
    "    e5 = encoder_layer(e4,\n",
    "                       512,\n",
    "                       activation='leaky_relu',\n",
    "                       kernel_size=kernel_size)\n",
    "\n",
    "    d1 = decoder_layer(e5,\n",
    "                       e4,\n",
    "                       256,\n",
    "                       kernel_size=5, strides = 1)\n",
    "    d2 = decoder_layer(d1,\n",
    "                       e3,\n",
    "                       128,\n",
    "                       kernel_size=9, strides = 1)\n",
    "    d3 = decoder_layer(d2,\n",
    "                       e2,\n",
    "                       64,\n",
    "                       kernel_size=2, strides = 2)\n",
    "    d4 = decoder_layer(d3,\n",
    "                       e1,\n",
    "                       32,\n",
    "                       kernel_size=2, strides = 2)\n",
    "#     d5 = decoder_layer(d4,\n",
    "#                        e1,\n",
    "#                        32,\n",
    "#                        kernel_size=2, strides = 1)\n",
    "    \n",
    "    outputs = Conv2DTranspose(channels,\n",
    "                              kernel_size=2,\n",
    "                              strides=2,\n",
    "                              activation='sigmoid',\n",
    "                              padding='same', dtype = 'float64')(d3)\n",
    "\n",
    "    generator = Model([inputs, label], outputs, name=name)\n",
    "\n",
    "    return generator\n",
    "\n",
    "def build_discriminator(input_shape,output_shape=1,\n",
    "                        kernel_size=3,\n",
    "                        patchgan=True,\n",
    "                        name=None):\n",
    "    \"\"\"The discriminator is a 4-layer encoder that outputs either\n",
    "    a 1-dim or a n x n-dim patch of probability that input is real \n",
    "    Arguments:\n",
    "    input_shape (tuple): input shape\n",
    "    kernel_size (int): kernel size of decoder layers\n",
    "    patchgan (bool): whether the output is a patch \n",
    "        or just a 1-dim\n",
    "    name (string): name assigned to discriminator model\n",
    "    Returns:\n",
    "    discriminator (Model):\n",
    "    \"\"\"\n",
    "\n",
    "    inputs = Input(shape=input_shape)\n",
    "    x = encoder_layer(inputs,\n",
    "                      32,\n",
    "                      kernel_size=kernel_size,\n",
    "                      activation='leaky_relu',\n",
    "                      instance_norm=False)\n",
    "    x = encoder_layer(x,\n",
    "                      64,\n",
    "                      kernel_size=kernel_size,\n",
    "                      activation='leaky_relu',\n",
    "                      instance_norm=False)\n",
    "    x = encoder_layer(x,\n",
    "                      128,\n",
    "                      kernel_size=kernel_size,\n",
    "                      activation='leaky_relu',\n",
    "                      instance_norm=False)\n",
    "    x = encoder_layer(x,\n",
    "                      256,\n",
    "                      kernel_size=kernel_size,\n",
    "                      strides=1,\n",
    "                      activation='leaky_relu',\n",
    "                      instance_norm=False)\n",
    "    x = encoder_layer(x,\n",
    "                      512,\n",
    "                      kernel_size=kernel_size,\n",
    "                      strides=1,\n",
    "                      activation='leaky_relu',\n",
    "                      instance_norm=False)\n",
    "\n",
    "    # if patchgan=True use nxn-dim output of probability\n",
    "    # else use 1-dim output of probability\n",
    "    if patchgan:\n",
    "        x = LeakyReLU(alpha=0.2)(x)\n",
    "        outputs = Conv2D(output_shape,\n",
    "                         kernel_size=kernel_size,\n",
    "                         strides=2,\n",
    "                         padding='same', dtype = 'float64')(x)\n",
    "    else:\n",
    "        x = Flatten()(x)\n",
    "        x = Dense(output_shape)(x)\n",
    "        outputs = Activation('linear', dtype = 'float64')(x)\n",
    "\n",
    "\n",
    "    discriminator = Model(inputs, outputs, name=name)\n",
    "\n",
    "    return discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67f2da34",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_g = build_generator(img_shape, img_shape, 5, name = \"gen_g\")\n",
    "generator_f = build_generator(img_shape, img_shape, 5, name = \"gen_f\")\n",
    "discriminator_x = build_discriminator(img_shape, 1, 5, True, name = \"disc_x\")\n",
    "discriminator_y = build_discriminator(img_shape, 1, 5, True, name = \"disc_y\")\n",
    "\n",
    "generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
    "generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
    "\n",
    "discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
    "discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41da154f",
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator_hair = build_discriminator(img_shape, 1, 5, False, name='disc_hair')\n",
    "discriminator_gender = build_discriminator(img_shape, 1, 5, False, name='disc_gender')\n",
    "\n",
    "discriminator_hair_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
    "discriminator_gender_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09eb6e26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train discs\n",
    "np_config.enable_numpy_behavior()\n",
    "batch_size = 256\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((X_train, attr_train))\n",
    "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
    "epochs = 1\n",
    "for epoch in range(epochs):\n",
    "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            output1 = discriminator_hair(x_batch_train, training=True)\n",
    "            contrastive_loss1 = tfa.losses.ContrastiveLoss()(y_batch_train[:,9], output1)\n",
    "        grads1 = tape.gradient(contrastive_loss1, discriminator_hair.trainable_weights)\n",
    "        discriminator_hair_optimizer.apply_gradients(zip(grads1, discriminator_hair.trainable_weights))\n",
    "        if step % 20 == 0:\n",
    "            print(\"Training loss (for one batch) at step %d: %.4f\"% (step, float(contrastive_loss1)))\n",
    "            print(\"Seen so far: %s samples\" % ((step + 1) * batch_size))\n",
    "            \n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            output2 = discriminator_gender(x_batch_train, training=True)\n",
    "            contrastive_loss2 = tfa.losses.ContrastiveLoss()(y_batch_train[:,20], output2)\n",
    "        grads2 = tape.gradient(contrastive_loss2, discriminator_gender.trainable_weights)\n",
    "        discriminator_gender_optimizer.apply_gradients(zip(grads2, discriminator_gender.trainable_weights))\n",
    "        if step % 20 == 0:\n",
    "            print(\"Training loss (for one batch) at step %d: %.4f\"% (step, float(contrastive_loss2)))\n",
    "            print(\"Seen so far: %s samples\" % ((step + 1) * batch_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7570868",
   "metadata": {},
   "outputs": [],
   "source": [
    "LAMBDA = 10\n",
    "loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "def discriminator_loss(real, generated):\n",
    "    \n",
    "    real_loss = loss_obj(tf.ones_like(real), real)\n",
    "    generated_loss = loss_obj(tf.zeros_like(generated), generated)\n",
    "    \n",
    "    total_disc_loss = real_loss + generated_loss\n",
    "    return total_disc_loss * 0.5\n",
    "\n",
    "def generator_loss(generated):\n",
    "    return loss_obj(tf.ones_like(generated), generated)\n",
    "\n",
    "def calc_cycle_loss(real_image, cycled_image):\n",
    "    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))\n",
    "    return LAMBDA * loss1\n",
    "def identity_loss(real_image, same_image):\n",
    "    loss = tf.reduce_mean(tf.abs(real_image - same_image))\n",
    "    return LAMBDA * 0.5 * loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37c5ec4",
   "metadata": {},
   "outputs": [],
   "source": [
    "to_nonblond = generator_g([sample_blond_male.reshape((1,64,64,3)),[0]])\n",
    "to_blond = generator_f([sample_nonblond_male.reshape((1,64,64,3)),[1]])\n",
    "plt.figure(figsize=(8, 8))\n",
    "contrast = 8\n",
    "\n",
    "imgs = [sample_blond_male, to_nonblond, sample_nonblond_male, to_blond]\n",
    "title = ['Blond', 'To Non Blond', 'Non Blond', 'To Blond']\n",
    "\n",
    "for i in range(len(imgs)):\n",
    "    plt.subplot(2, 2, i+1)\n",
    "    plt.title(title[i])\n",
    "    if i % 2 == 0:\n",
    "        plt.imshow(imgs[i])\n",
    "    else:\n",
    "        plt.imshow(imgs[i][0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410f983b",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 10\n",
    "def generate_images(model, test_input):\n",
    "    prediction = model([test_input, [1]])\n",
    "    #return prediction[0]\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    display_list = [test_input[0], prediction[0]]\n",
    "    title = ['Input Image', 'Predicted Image']\n",
    "\n",
    "    for i in range(2):\n",
    "        plt.subplot(1, 2, i+1)\n",
    "        plt.title(title[i])\n",
    "        # getting the pixel values between [0, 1] to plot it.\n",
    "        plt.imshow(display_list[i])\n",
    "        plt.axis('off')\n",
    "    plt.show()\n",
    "    return prediction[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d3cc3fb",
   "metadata": {},
   "source": [
    "## TRAINING LOOP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59c86bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(real_x, real_y, label_x, label_y):\n",
    "    # persistent is set to True because the tape is used more than\n",
    "    # once to calculate the gradients.\n",
    "    with tf.GradientTape(persistent=True) as tape:\n",
    "        # Generator G translates X -> Y\n",
    "        # Generator F translates Y -> X.\n",
    "    \n",
    "        fake_y = generator_g([real_x,[0]], training=True)\n",
    "        cycled_x = generator_f([fake_y,[1]], training=True)\n",
    "\n",
    "        fake_x = generator_f([real_y, [1]], training=True)\n",
    "        cycled_y = generator_g([fake_x, [0]], training=True)\n",
    "\n",
    "        # same_x and same_y are used for identity loss.\n",
    "        same_x = generator_f([real_x,[1]], training=True)\n",
    "        same_y = generator_g([real_y, [0]], training=True)\n",
    "\n",
    "        disc_real_x = discriminator_x(real_x, training=True)\n",
    "        disc_real_y = discriminator_y(real_y, training=True)\n",
    "\n",
    "        disc_fake_x = discriminator_x(fake_x, training=True)\n",
    "        disc_fake_y = discriminator_y(fake_y, training=True)\n",
    "        \n",
    "        disc_gender1 = discriminator_gender(real_x, training=False)\n",
    "        disc_gender2 = discriminator_gender(fake_y, training=False)\n",
    "        disc_gender3 = discriminator_gender(real_y, training=False)\n",
    "        disc_gender4 = discriminator_gender(fake_x, training=False)\n",
    "        \n",
    "        disc_hair1 = discriminator_hair(real_x, training=False)\n",
    "        disc_hair2 = discriminator_hair(fake_y, training=False)\n",
    "        disc_hair3 = discriminator_hair(real_y, training=False)\n",
    "        disc_hair4 = discriminator_hair(fake_x, training=False)\n",
    "\n",
    "        # calculate the loss\n",
    "        gen_g_loss = generator_loss(disc_fake_y)\n",
    "        gen_f_loss = generator_loss(disc_fake_x)\n",
    "        \n",
    "        gender_loss1 = tfa.losses.ContrastiveLoss()([label_x[:,20], label_x[:,20]], [disc_gender1[0], disc_gender2[0]])\n",
    "        gender_loss2 = tfa.losses.ContrastiveLoss()([label_y[:,20], label_y[:,20]], [disc_gender3[0], disc_gender4[0]])\n",
    "\n",
    "        hair_loss1 = -tfa.losses.ContrastiveLoss()([label_x[:,9], 1-label_y[:,9]], [disc_hair1[0], disc_hair2[0]])\n",
    "        hair_loss2 = -tfa.losses.ContrastiveLoss()([label_y[:,9], 1-label_x[:,9]], [disc_hair3[0], disc_hair4[0]])\n",
    "\n",
    "    \n",
    "        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)\n",
    "    \n",
    "        lam=0.0005\n",
    "        # Total generator loss = adversarial loss + cycle loss\n",
    "        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)+lam*(gender_loss1+hair_loss1)\n",
    "        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)+lam*(gender_loss2+hair_loss2)\n",
    "\n",
    "        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n",
    "        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n",
    "    \n",
    "    # Calculate the gradients for generator and discriminator\n",
    "    generator_g_gradients = tape.gradient(total_gen_g_loss, \n",
    "                                        generator_g.trainable_variables)\n",
    "    generator_f_gradients = tape.gradient(total_gen_f_loss, \n",
    "                                        generator_f.trainable_variables)\n",
    "    \n",
    "    discriminator_x_gradients = tape.gradient(disc_x_loss, \n",
    "                                            discriminator_x.trainable_variables)\n",
    "    discriminator_y_gradients = tape.gradient(disc_y_loss, \n",
    "                                            discriminator_y.trainable_variables)\n",
    "  \n",
    "    # Apply the gradients to the optimizer\n",
    "    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, \n",
    "                                            generator_g.trainable_variables))\n",
    "\n",
    "    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, \n",
    "                                            generator_f.trainable_variables))\n",
    "    \n",
    "    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,\n",
    "                                                discriminator_x.trainable_variables))\n",
    "  \n",
    "    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,\n",
    "                                                discriminator_y.trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4a034e",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE=1\n",
    "x_blond_train = tf.data.Dataset.from_tensor_slices(X_blond_male_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
    "x_nonblond_train = tf.data.Dataset.from_tensor_slices(X_nonblond_male_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
    "y_blond_train = tf.data.Dataset.from_tensor_slices(y_blond_male_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
    "y_blond_test = tf.data.Dataset.from_tensor_slices(y_nonblond_male_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c158c4b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    \n",
    "    start = time.time()\n",
    "\n",
    "    n = 0\n",
    "    for image_x, image_y, label_x, label_y in tf.data.Dataset.zip((x_blond_train, x_nonblond_train, y_blond_train, y_blond_test)):\n",
    "        train_step(image_x, image_y, label_x, label_y)\n",
    "        if n % 10 == 0:\n",
    "            print ('.', end='')\n",
    "        n += 1\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    # Using a consistent image (sample_horse) so that the progress of the model\n",
    "    # is clearly visible.\n",
    "    generate_images(generator_f, sample_nonblond_male.reshape((1,64,64,3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "892fa87f",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfs = []\n",
    "for i in range(10000):#range(X_nonblond_male_train.shape[0]):\n",
    "    cfs.append(generate_images(generator_f, X_nonblond_male_train[i].reshape((1,64,64,3))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e484982",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_classifier(latent_dim_img, img_shape = (64,64,3), kernel_initializer = TruncatedNormal(mean = 0.0, stddev = 0.01)):\n",
    "\n",
    "    model = Sequential()\n",
    "\n",
    "    model.add(Conv2D(64,(2,2), strides=(1,1), input_shape = img_shape, kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(BatchNormalization(momentum = 0.9, epsilon = 1e-5))\n",
    "    model.add(LeakyReLU(alpha = 0.02))\n",
    "    model.add(Dropout(0.2))\n",
    "    model.add(Conv2D(128, (7,7), strides = (2,2), kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(BatchNormalization(momentum = 0.9, epsilon = 1e-5))\n",
    "    model.add(LeakyReLU(alpha = 0.02))\n",
    "    model.add(Dropout(0.2))\n",
    "    model.add(Conv2D(256, (5,5), strides = (2,2), kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(BatchNormalization(momentum = 0.9, epsilon = 1e-5))\n",
    "    model.add(LeakyReLU(alpha = 0.02))\n",
    "    model.add(Dropout(0.2))\n",
    "    model.add(Conv2D(256, (7,7), strides = (2,2), kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(BatchNormalization(momentum = 0.9, epsilon = 1e-5))\n",
    "    model.add(LeakyReLU(alpha = 0.02))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Conv2D(512, (4,4), strides = (1,1), kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(BatchNormalization(momentum = 0.9, epsilon = 1e-5))\n",
    "    model.add(LeakyReLU(alpha = 0.02))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Conv2D(latent_dim_img, (1,1), strides = (1,1), kernel_initializer=kernel_initializer, use_bias = False))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(128, activation = 'relu'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(128, activation = 'relu'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(1, activation = None))\n",
    "    model.add(Activation('sigmoid'))\n",
    "    \n",
    "    model.compile(loss=losses.binary_crossentropy, optimizer=Adam(1e-4, 0.5),metrics=['accuracy'])\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d60e125c",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1):\n",
    "    conf_model = build_classifier(256)\n",
    "    conf_model.fit(np.vstack([X_train, cfs]), np.hstack([attr_train[:,9].astype(float), np.ones_like(attr_train[:,9][:10000]).astype(float)]),\n",
    "                   epochs = 20, batch_size = 256, shuffle = True)\n",
    "    \n",
    "    print(conf_model.evaluate(X_test,attr_test[:,9].astype(float)))\n",
    "    print(conf_model.evaluate(X_blond_male_test, attr_test[blond_male_test_idx,9].astype(float)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8043871",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(conf_model.evaluate(X_blond_male_test, attr_test[blond_male_test_idx,9].astype(float)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c83262",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.mean([81, 80.54, 77.80, 78.90]))\n",
    "print(np.std([81, 80.54, 77.80, 78.90]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "305321e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx = np.random.randint(0,X_nonblond_male_train.shape[0], 10)\n",
    "cfs = []\n",
    "for i in range(10000):#range(X_nonblond_male_train.shape[0]):\n",
    "    cfs.append(generate_images(generator_f, X_nonblond_male_train[i].reshape((1,64,64,3))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5755fea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
