{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed641e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "sys.path.append(\"../ALAE\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import ot\n",
    "\n",
    "from src.distributions import LoaderSampler, TensorSampler\n",
    "from src.ulight_ot import ULightOT\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import wandb\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from alae_ffhq_inference import load_model, encode, decode\n",
    "from src.plotters import fig2img\n",
    "from src.fid_score import calculate_frechet_distance\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "217c36a6",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea8df7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality handed to ULightOT(dim=...) in the training cell.\n",
    "DIM = 512\n",
    "assert DIM > 1\n",
    "\n",
    "# Subsets to translate between; TARGET_DATA additionally accepts ADULT-MAN.\n",
    "INPUT_DATA = \"ADULT\" # MAN, WOMAN, ADULT, YOUNG\n",
    "TARGET_DATA = \"YOUNG\" # MAN, WOMAN, ADULT, ADULT-MAN, YOUNG\n",
    "# Must match one of the divergence branches handled in the training cell.\n",
    "DIVERGENCE = 'KL'#\"Xi2\"\n",
    "\n",
    "OUTPUT_SEED = 42\n",
    "BATCH_SIZE = 128\n",
    "EPSILON = 0.1\n",
    "D_LR = 1\n",
    "D_GRADIENT_MAX_NORM = 1e5 #float(\"inf\")\n",
    "N_POTENTIALS = 50\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "INIT_BY_SAMPLES = True\n",
    "IS_DIAGONAL = True\n",
    "\n",
    "# NOTE(review): the training cell hard-codes its own step count and plotting\n",
    "# interval and does not read PLOT_EVERY / MAX_STEPS.\n",
    "PLOT_EVERY = 500\n",
    "MAX_STEPS = 10000\n",
    "# The training loop starts at step CONTINUE + 1, so -1 means \"from step 0\".\n",
    "CONTINUE = -1\n",
    "\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# Apply OUTPUT_SEED to the global NumPy and PyTorch generators so that a\n",
    "# Restart & Run All starts from the same random state every time.\n",
    "np.random.seed(OUTPUT_SEED)\n",
    "torch.manual_seed(OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b6dca7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The arrays below are read from ../data/ and must be downloaded beforehand.\n",
    "train_size, test_size = 60000, 10000\n",
    "\n",
    "latents = np.load(\"../data/latents.npy\")\n",
    "gender = np.load(\"../data/gender.npy\")\n",
    "age = np.load(\"../data/age.npy\")\n",
    "test_inp_images = np.load(\"../data/test_images.npy\")\n",
    "\n",
    "# The first `train_size` rows form the train split, the remainder the test split.\n",
    "train_latents = latents[:train_size]\n",
    "test_latents = latents[train_size:]\n",
    "\n",
    "train_gender = gender[:train_size]\n",
    "test_gender = gender[train_size:]\n",
    "\n",
    "train_age = age[:train_size]\n",
    "test_age = age[train_size:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2456f16a",
   "metadata": {},
   "source": [
    "## Define Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dad72a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BinaryClassifier(nn.Module):\n",
    "    \"\"\"MLP that maps a feature vector to a single probability in [0, 1].\n",
    "\n",
    "    Two hidden blocks (Linear -> BatchNorm1d -> Dropout -> ReLU) are followed\n",
    "    by a one-unit linear layer and a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        in_features: size of each input vector (default 512).\n",
    "        hidden_sizes: widths of the two hidden blocks (default (256, 128)).\n",
    "        dropout: dropout probability used in both hidden blocks (default 0.5).\n",
    "\n",
    "    The attribute names (layer1, layer2, layer3) and the module order inside\n",
    "    each block are unchanged, so with the default arguments the state-dict\n",
    "    keys are the same as before.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features=512, hidden_sizes=(256, 128), dropout=0.5):\n",
    "        super().__init__()\n",
    "        first_width, second_width = hidden_sizes\n",
    "        self.layer1 = self._hidden_block(in_features, first_width, dropout)\n",
    "        self.layer2 = self._hidden_block(first_width, second_width, dropout)\n",
    "        self.layer3 = nn.Linear(second_width, 1)\n",
    "\n",
    "    @staticmethod\n",
    "    def _hidden_block(n_in, n_out, dropout):\n",
    "        \"\"\"Build one Linear -> BatchNorm1d -> Dropout -> ReLU block.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Linear(n_in, n_out),\n",
    "            nn.BatchNorm1d(n_out),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return sigmoid probabilities of shape (batch, 1) for input (batch, in_features).\"\"\"\n",
    "        hidden = self.layer2(self.layer1(x))\n",
    "        return torch.sigmoid(self.layer3(hidden))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfe687b5",
   "metadata": {},
   "source": [
    "## ALAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83d30f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ALAE model built from the FFHQ config; `decode_and_plot` passes it to `decode`.\n",
    "model = load_model(\n",
    "    \"../ALAE/configs/ffhq.yaml\",\n",
    "    training_artifacts_dir=\"../ALAE/training_artifacts/ffhq/\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4335b865",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7ae58cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_and_plot(T, latent_to_map, inp_images, n_samples=None):\n",
    "    \"\"\"Map latents with `T`, decode them with the ALAE model and plot a grid.\n",
    "\n",
    "    Row 0 shows `inp_images`; every following row shows one decoded result of\n",
    "    `T(latent_to_map)`, with one column per input latent.\n",
    "\n",
    "    NOTE: the notebook defines `decode_and_plot` again in a later cell (just\n",
    "    before the training loop); that later definition is the one in effect there.\n",
    "\n",
    "    Args:\n",
    "        T: callable applied to `latent_to_map`, called once per mapped row.\n",
    "        latent_to_map: tensor whose first dimension indexes the inputs.\n",
    "        inp_images: images for the top row, indexed like `latent_to_map`.\n",
    "        n_samples: number of mapped rows; when None the notebook-level\n",
    "            `number_of_samples` is used.\n",
    "\n",
    "    Returns:\n",
    "        (fig, axes), where `axes` is always a 2-D array of shape\n",
    "        (n_samples + 1, number of inputs).\n",
    "\n",
    "    Uses the notebook globals `model` and `decode`.\n",
    "    \"\"\"\n",
    "    if n_samples is None:\n",
    "        n_samples = number_of_samples\n",
    "    n_inputs = latent_to_map.shape[0]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # All mapping calls happen before any decoding.\n",
    "        mapped_rows = [T(latent_to_map) for _ in range(n_samples)]\n",
    "\n",
    "        decoded_rows = []\n",
    "        for mapped in mapped_rows:\n",
    "            decoded_img = decode(model, mapped.detach().cpu())\n",
    "            # Rescale (presumably from [-1, 1]) to 0..255 and move channels last for imshow.\n",
    "            decoded_img = ((decoded_img * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)\n",
    "            decoded_rows.append(decoded_img.cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy())\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D even for a single column, so the\n",
    "    # [row, col] indexing below never fails.\n",
    "    fig, axes = plt.subplots(n_samples + 1, n_inputs, figsize=(n_inputs, n_samples + 1), dpi=200, squeeze=False)\n",
    "\n",
    "    for col in range(n_inputs):\n",
    "        axes[0, col].imshow(inp_images[col])\n",
    "        for row in range(n_samples):\n",
    "            axes[row + 1, col].imshow(decoded_rows[row][col])\n",
    "\n",
    "        # Hide ticks on every panel of this column.\n",
    "        for row in range(n_samples + 1):\n",
    "            axes[row, col].get_xaxis().set_visible(False)\n",
    "            axes[row, col].set_yticks([])\n",
    "\n",
    "    fig.tight_layout(pad=0.05)\n",
    "\n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1834d09a",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589ac6d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _subset_indices(name, gender, age):\n",
    "    \"\"\"Return the row indices whose labels fall into the subset `name`.\n",
    "\n",
    "    `gender` and `age` are flattened first, so both (N,) and (N, 1) label\n",
    "    arrays are handled the same way for every subset (including ADULT-MAN).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `name` is not one of the known subsets.\n",
    "    \"\"\"\n",
    "    gender, age = gender.reshape(-1), age.reshape(-1)\n",
    "    if name == \"MAN\":\n",
    "        mask = gender == \"male\"\n",
    "    elif name == \"WOMAN\":\n",
    "        mask = gender == \"female\"\n",
    "    elif name == \"ADULT\":\n",
    "        mask = (age > 44) & (age != -1)\n",
    "    elif name == \"ADULT-MAN\":\n",
    "        mask = (gender == \"male\") & (age > 44)\n",
    "    elif name == \"YOUNG\":\n",
    "        mask = (age > 16) & (age <= 44) & (age != -1)\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown data subset: {name!r}\")\n",
    "    return np.arange(mask.size)[mask]\n",
    "\n",
    "\n",
    "if INPUT_DATA not in (\"MAN\", \"WOMAN\", \"ADULT\", \"YOUNG\"):\n",
    "    raise ValueError(f\"Unsupported INPUT_DATA: {INPUT_DATA!r}\")\n",
    "\n",
    "x_inds_train = _subset_indices(INPUT_DATA, train_gender, train_age)\n",
    "x_inds_test = _subset_indices(INPUT_DATA, test_gender, test_age)\n",
    "y_inds_train = _subset_indices(TARGET_DATA, train_gender, train_age)\n",
    "y_inds_test = _subset_indices(TARGET_DATA, test_gender, test_age)\n",
    "\n",
    "# The evaluation in the training cell compares predictions with gender labels\n",
    "# when INPUT_DATA is ADULT/YOUNG and with age labels otherwise, so load the\n",
    "# classifier that predicts that same attribute.\n",
    "if INPUT_DATA in (\"ADULT\", \"YOUNG\"):\n",
    "    classifier_path = './checkpoints/male_female_classifier.pth'\n",
    "else:\n",
    "    classifier_path = './checkpoints/young_old_classifier.pth'\n",
    "mlp_classifier = BinaryClassifier()\n",
    "mlp_classifier.load_state_dict(torch.load(classifier_path, map_location=DEVICE))\n",
    "\n",
    "x_data_train = train_latents[x_inds_train]\n",
    "x_data_test = test_latents[x_inds_test]\n",
    "x_data_test_gender = test_gender[x_inds_test]\n",
    "x_data_test_age = test_age[x_inds_test]\n",
    "\n",
    "# Pick 10 distinct source samples among the test rows with index < 300\n",
    "# (presumably the rows covered by test_images.npy -- TODO confirm).\n",
    "n_candidates = (x_inds_test < 300).sum()\n",
    "inds_to_map = np.random.choice(np.arange(n_candidates), size=10, replace=False)\n",
    "number_of_samples = 1\n",
    "latent_to_map = torch.tensor(test_latents[x_inds_test[inds_to_map]])\n",
    "inp_images = test_inp_images[x_inds_test[inds_to_map]]\n",
    "\n",
    "y_data_train = train_latents[y_inds_train]\n",
    "y_data_test = test_latents[y_inds_test]\n",
    "\n",
    "X_train = torch.tensor(x_data_train)\n",
    "Y_train = torch.tensor(y_data_train)\n",
    "\n",
    "X_test = torch.tensor(x_data_test)\n",
    "Y_test = torch.tensor(y_data_test)\n",
    "\n",
    "X_sampler = TensorSampler(X_train, device=\"cpu\")\n",
    "Y_sampler = TensorSampler(Y_train, device=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1efd7932",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_and_plot(T, latent_to_map, inp_images, n_samples=None):\n",
    "    \"\"\"Map latents with `T`, decode them with the ALAE model and plot a grid.\n",
    "\n",
    "    Row 0 shows `inp_images`; every following row shows one decoded result of\n",
    "    `T(latent_to_map)`, with one column per input latent.\n",
    "\n",
    "    NOTE: this repeats the `decode_and_plot` from the \"Visualization\" section\n",
    "    and shadows it; this is the definition the training loop calls.\n",
    "\n",
    "    Args:\n",
    "        T: callable applied to `latent_to_map`, called once per mapped row.\n",
    "        latent_to_map: tensor whose first dimension indexes the inputs.\n",
    "        inp_images: images for the top row, indexed like `latent_to_map`.\n",
    "        n_samples: number of mapped rows; when None the notebook-level\n",
    "            `number_of_samples` is used.\n",
    "\n",
    "    Returns:\n",
    "        (fig, axes), where `axes` is always a 2-D array of shape\n",
    "        (n_samples + 1, number of inputs).\n",
    "\n",
    "    Uses the notebook globals `model` and `decode`.\n",
    "    \"\"\"\n",
    "    if n_samples is None:\n",
    "        n_samples = number_of_samples\n",
    "    n_inputs = latent_to_map.shape[0]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # All mapping calls happen before any decoding.\n",
    "        mapped_rows = [T(latent_to_map) for _ in range(n_samples)]\n",
    "\n",
    "        decoded_rows = []\n",
    "        for mapped in mapped_rows:\n",
    "            decoded_img = decode(model, mapped.detach().cpu())\n",
    "            # Rescale (presumably from [-1, 1]) to 0..255 and move channels last for imshow.\n",
    "            decoded_img = ((decoded_img * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)\n",
    "            decoded_rows.append(decoded_img.cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy())\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D even for a single column, so the\n",
    "    # [row, col] indexing below never fails.\n",
    "    fig, axes = plt.subplots(n_samples + 1, n_inputs, figsize=(n_inputs, n_samples + 1), dpi=200, squeeze=False)\n",
    "\n",
    "    for col in range(n_inputs):\n",
    "        axes[0, col].imshow(inp_images[col])\n",
    "        for row in range(n_samples):\n",
    "            axes[row + 1, col].imshow(decoded_rows[row][col])\n",
    "\n",
    "        # Hide ticks on every panel of this column.\n",
    "        for row in range(n_samples + 1):\n",
    "            axes[row, col].get_xaxis().set_visible(False)\n",
    "            axes[row, col].set_yticks([])\n",
    "\n",
    "    fig.tight_layout(pad=0.05)\n",
    "\n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dfa705d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and covariance of samples drawn from Y_sampler; the training cell\n",
    "# passes them to calculate_frechet_distance as the reference statistics.\n",
    "real_data = Y_sampler.sample(10000).detach().cpu().numpy()\n",
    "real_data = real_data.reshape(len(real_data), -1)\n",
    "mu_data = real_data.mean(axis=0)\n",
    "sigma_data = np.cov(real_data, rowvar=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b24ece2b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def _conjugate_term(potential, tau):\n",
    "    \"\"\"Return the per-sample divergence term of the loss for `potential`.\n",
    "\n",
    "    The formula is selected by the notebook-level DIVERGENCE: 'KL' (also\n",
    "    accepted under its former spelling 'UKL') or 'Xi2'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other DIVERGENCE value, instead of leaving the\n",
    "            loss terms undefined.\n",
    "    \"\"\"\n",
    "    if DIVERGENCE in ('KL', 'UKL'):\n",
    "        return tau * (torch.exp(-potential/tau) - 1)\n",
    "    if DIVERGENCE == 'Xi2':\n",
    "        clipped = -(F.relu(-potential + 2*tau) - (1+(-potential>-2*tau))*tau)\n",
    "        return 0.25 * clipped**2/tau - clipped\n",
    "    raise ValueError(f\"Unsupported DIVERGENCE {DIVERGENCE!r}; expected 'KL', 'UKL' or 'Xi2'\")\n",
    "\n",
    "\n",
    "taus = [100]\n",
    "EPSILONS = [0.1]\n",
    "\n",
    "# One summary entry (mean ± std over the repeated runs) per (EPSILON, Tau) pair.\n",
    "df_acc = pd.DataFrame(index=EPSILONS, columns=taus)\n",
    "df_fd = pd.DataFrame(index=EPSILONS, columns=taus)\n",
    "\n",
    "for Tau in taus:\n",
    "    # NOTE: this loop variable rebinds the config-level EPSILON.\n",
    "    for EPSILON in EPSILONS:\n",
    "        acc_list = []\n",
    "        fd_list = []\n",
    "        arr = []  # loss history, accumulated over the repeated runs\n",
    "        for _ in range(3):\n",
    "            D = ULightOT(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "                         sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL)\n",
    "\n",
    "            log_m = torch.zeros(1, requires_grad=True)\n",
    "\n",
    "            if INIT_BY_SAMPLES:\n",
    "                D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))\n",
    "\n",
    "            D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)\n",
    "            m_opt = torch.optim.Adam([log_m], lr=1e-3)\n",
    "            D_sch = MultiStepLR(D_opt, milestones=[500, 1000, 5000, 10000])\n",
    "\n",
    "            for step in tqdm(range(CONTINUE + 1, 3000)):\n",
    "                # training cycle\n",
    "                D_opt.zero_grad()\n",
    "                m_opt.zero_grad()\n",
    "                # The first 1000 steps use tau = 1000, then the swept value.\n",
    "                tau = 1000 if step < 1000 else Tau\n",
    "\n",
    "                X, Y = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)\n",
    "\n",
    "                log_V = D.get_potential(Y)\n",
    "                psi = EPSILON * log_V + torch.norm(Y, p=2, dim=-1)**2/2\n",
    "                f_psi = _conjugate_term(psi, tau)\n",
    "\n",
    "                log_C = D.get_C(X)\n",
    "                log_U = D.get_marginal(X)\n",
    "                phi = EPSILON * (log_U + log_m - log_C) + torch.norm(X, p=2, dim=-1)**2/2\n",
    "                f_phi = _conjugate_term(phi, tau)\n",
    "\n",
    "                D_loss = EPSILON * torch.exp(log_m) + f_phi.mean() + f_psi.mean()\n",
    "                arr.append(D_loss.item())\n",
    "                D_loss.backward()\n",
    "                D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)\n",
    "                D_opt.step()\n",
    "                m_opt.step()\n",
    "                D_sch.step()\n",
    "\n",
    "                if (step - 1) % 1000 == 0:\n",
    "                    clear_output(wait=True)\n",
    "                    # The code for plotting the results\n",
    "                    fig, axes = decode_and_plot(D, latent_to_map, inp_images)\n",
    "                    plt.show()\n",
    "                    # Release the figure so repeated plots do not pile up in memory.\n",
    "                    plt.close(fig)\n",
    "\n",
    "            # Evaluation needs no gradients; avoid building a graph over the test set.\n",
    "            mlp_classifier.eval()\n",
    "            with torch.no_grad():\n",
    "                D_test = D(X_test)\n",
    "                pred_labels = torch.round(mlp_classifier(D_test).reshape(-1))\n",
    "\n",
    "            pred_labels_np = pred_labels.cpu().numpy().astype(int)\n",
    "\n",
    "            if INPUT_DATA == 'ADULT' or INPUT_DATA == 'YOUNG':\n",
    "                actual_labels_np = np.where(x_data_test_gender == 'male', 1, 0).reshape(-1)\n",
    "            else:\n",
    "                actual_labels_np = (x_data_test_age.reshape(-1) > 44)*1\n",
    "            # accuracy_score takes (y_true, y_pred).\n",
    "            accuracy = accuracy_score(actual_labels_np, pred_labels_np)\n",
    "            print('Accuracy: ', accuracy)\n",
    "\n",
    "            D_test = D_test.detach().cpu().numpy().reshape(D_test.size(0), -1)\n",
    "            mu, sigma = np.mean(D_test, axis=0), np.cov(D_test, rowvar=False)\n",
    "            FD_T = calculate_frechet_distance(mu, sigma, mu_data, sigma_data)\n",
    "            print('FD: ', FD_T)\n",
    "\n",
    "            acc_list.append(accuracy*100)\n",
    "            fd_list.append(FD_T)\n",
    "\n",
    "        # Store under the swept value `Tau`; `tau` is only the per-step schedule value.\n",
    "        df_acc.loc[EPSILON, Tau] = f'{np.mean(acc_list):.2f} ± {np.std(acc_list):.2f}'\n",
    "        df_fd.loc[EPSILON, Tau] = f'{np.mean(fd_list):.2f} ± {np.std(fd_list):.2f}'\n",
    "\n",
    "        print('Results for tau: ', Tau, ' And EPSILON: ', EPSILON)\n",
    "        print('Accuracy: ', df_acc.loc[EPSILON, Tau], ' FD: ', df_fd.loc[EPSILON, Tau])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
