{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bd501698",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "568cf54d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "sys.path.append(\"../ALAE\")\n",
    "import random\n",
    "import string\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# from src.light_sb import LightSB\n",
    "from ofmsrc.alae_distributions import LoaderSampler, TensorSampler\n",
    "# import deeplake\n",
    "from tqdm import tqdm\n",
    "\n",
    "import wandb\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from alae_ffhq_inference import load_model, encode, decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "15f19259",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_all(seed=123):\n",
    "    OUTPUT_SEED = seed\n",
    "    torch.manual_seed(OUTPUT_SEED)\n",
    "    np.random.seed(OUTPUT_SEED)\n",
    "    random.seed(OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03da5cc2",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eec9e89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "DIM = 512\n",
    "assert DIM > 1\n",
    "\n",
    "INPUT_DATA = \"ADULT\" # MAN, WOMAN, ADULT, CHILDREN\n",
    "TARGET_DATA = \"CHILDREN\" # MAN, WOMAN, ADULT, CHILDREN\n",
    "\n",
    "OUTPUT_SEED = 0xBADBEEF\n",
    "BATCH_SIZE = 128\n",
    "D_LR = 1e-3 # 1e-3 for eps 0.1\n",
    "INV_TOLERANCE = 1e-2\n",
    "\n",
    "MAX_STEPS = 10002\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f1ea4108",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_all(seed=OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43d8a552",
   "metadata": {},
   "outputs": [],
   "source": [
    "CODE = 'MB128'\n",
    "EXP_NAME = f'ofm_ALAE_{INPUT_DATA}_TO_{TARGET_DATA}_{CODE}'\n",
    "OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    DIM=DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    INV_TOLERANCE=INV_TOLERANCE,\n",
    "    CODE=CODE\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b88351b8",
   "metadata": {},
   "source": [
    "# Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6f0658",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gdown\n",
    "import os\n",
    "\n",
    "if not os.path.isdir('../data'):\n",
    "    os.makedirs('../data')\n",
    "\n",
    "urls = {\n",
    "    \"../data/age.npy\": \"https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ\",\n",
    "    \"../data/gender.npy\": \"https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf\",\n",
    "    \"../data/latents.npy\": \"https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8\",\n",
    "    \"../data/test_images.npy\": \"https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i\",\n",
    "}\n",
    "\n",
    "for name, url in urls.items():\n",
    "    gdown.download(url, os.path.join(f\"{name}\"), quiet=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8202210a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To download data use\n",
    "\n",
    "train_size = 60000\n",
    "test_size = 10000\n",
    "\n",
    "latents = np.load(\"../data/latents.npy\")\n",
    "gender = np.load(\"../data/gender.npy\")\n",
    "age = np.load(\"../data/age.npy\")\n",
    "test_inp_images = np.load(\"../data/test_images.npy\")\n",
    "\n",
    "train_latents, test_latents = latents[:train_size], latents[train_size:]\n",
    "train_gender, test_gender = gender[:train_size], gender[train_size:]\n",
    "train_age, test_age = age[:train_size], age[train_size:]\n",
    "\n",
    "if INPUT_DATA == \"MAN\":\n",
    "    x_inds_train = np.arange(train_size)[(train_gender == \"male\").reshape(-1)]\n",
    "    x_inds_test = np.arange(test_size)[(test_gender == \"male\").reshape(-1)]\n",
    "elif INPUT_DATA == \"WOMAN\":\n",
    "    x_inds_train = np.arange(train_size)[(train_gender == \"female\").reshape(-1)]\n",
    "    x_inds_test = np.arange(test_size)[(test_gender == \"female\").reshape(-1)]\n",
    "elif INPUT_DATA == \"ADULT\":\n",
    "    x_inds_train = np.arange(train_size)[\n",
    "        (train_age >= 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
    "    ]\n",
    "    x_inds_test = np.arange(test_size)[\n",
    "        (test_age >= 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
    "    ]\n",
    "elif INPUT_DATA == \"CHILDREN\":\n",
    "    x_inds_train = np.arange(train_size)[\n",
    "        (train_age < 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
    "    ]\n",
    "    x_inds_test = np.arange(test_size)[\n",
    "        (test_age < 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
    "    ]\n",
    "x_data_train = train_latents[x_inds_train]\n",
    "x_data_test = test_latents[x_inds_test]\n",
    "\n",
    "if TARGET_DATA == \"MAN\":\n",
    "    y_inds_train = np.arange(train_size)[(train_gender == \"male\").reshape(-1)]\n",
    "    y_inds_test = np.arange(test_size)[(test_gender == \"male\").reshape(-1)]\n",
    "elif TARGET_DATA == \"WOMAN\":\n",
    "    y_inds_train = np.arange(train_size)[(train_gender == \"female\").reshape(-1)]\n",
    "    y_inds_test = np.arange(test_size)[(test_gender == \"female\").reshape(-1)]\n",
    "elif TARGET_DATA == \"ADULT\":\n",
    "    y_inds_train = np.arange(train_size)[\n",
    "        (train_age >= 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
    "    ]\n",
    "    y_inds_test = np.arange(test_size)[\n",
    "        (test_age >= 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
    "    ]\n",
    "elif TARGET_DATA == \"CHILDREN\":\n",
    "    y_inds_train = np.arange(train_size)[\n",
    "        (train_age < 18).reshape(-1)*(train_age != -1).reshape(-1)\n",
    "    ]\n",
    "    y_inds_test = np.arange(test_size)[\n",
    "        (test_age < 18).reshape(-1)*(test_age != -1).reshape(-1)\n",
    "    ]\n",
    "y_data_train = train_latents[y_inds_train]\n",
    "y_data_test = test_latents[y_inds_test]\n",
    "\n",
    "X_train = torch.tensor(x_data_train)\n",
    "Y_train = torch.tensor(y_data_train)\n",
    "\n",
    "X_test = torch.tensor(x_data_test)\n",
    "Y_test = torch.tensor(y_data_test)\n",
    "\n",
    "X_sampler = TensorSampler(X_train, device=\"cpu\")\n",
    "Y_sampler = TensorSampler(Y_train, device=\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96f8593e",
   "metadata": {},
   "source": [
    "# Model initialisation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d27c7528",
   "metadata": {},
   "source": [
    "## OFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f202183a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ofmsrc.icnn import (\n",
    "    ICNNCPF,\n",
    "    ICNN2CPF,\n",
    "    ICNN3CPF,\n",
    "    LseICNNCPF,\n",
    "    ResICNN2CPF,\n",
    "    DenseICNN2CPF,\n",
    "    ICNN2CPFnoSPact\n",
    ")\n",
    "from ofmsrc.icnn import (\n",
    "    LinActnormICNN,\n",
    "    DenseICNN\n",
    ")\n",
    "\n",
    "from ofmsrc.model_tools import (\n",
    "    id_pretrain_model,\n",
    "    ofm_forward,\n",
    "    ofm_inverse,\n",
    "    ofm_loss\n",
    ")\n",
    "from ofmsrc.model_tools import (\n",
    "    TorchlbfgsInvParams,\n",
    "    TorchoptimInvParams,\n",
    "    BruteforceIHVParams\n",
    ")\n",
    "\n",
    "from ofmsrc.alae_distributions import NormalSampler\n",
    "\n",
    "from ofmsrc.tools import EMA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9fae15b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SamplerProxy:\n",
    "    \n",
    "    def __init__(self, sampler, device='cuda'):\n",
    "        self.sampler = sampler\n",
    "        self.device = device\n",
    "    \n",
    "    def sample(self, tpl):\n",
    "        return self.sampler.sample(tpl[0]).to(self.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e4cdd1e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_all(seed=OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7f91b30d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dimh, num_hidl = 256, 3\n",
    "pretrain_sampler = NormalSampler(\n",
    "    np.array([0.,] * DIM),\n",
    "    cov=np.eye(DIM)*4.\n",
    ")\n",
    "D = DenseICNN(DIM, [1024, 1024]).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7dc8f9ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7bd675ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.001\n"
     ]
    }
   ],
   "source": [
    "print(D_LR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "201efb1e",
   "metadata": {},
   "source": [
    "## ALAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c023f8ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ofmsrc.discrete_ot import OTPlanSampler\n",
    "\n",
    "def get_discrete_ot_plan_sample_fn(sampler_x, sampler_y, device='cuda'):\n",
    "    \n",
    "    ot_plan_sampler = OTPlanSampler('exact')\n",
    "    \n",
    "    def ret_fn(batch_size):\n",
    "        \n",
    "        x_samples = sampler_x.sample(batch_size).to(device)\n",
    "        y_samples = sampler_y.sample(batch_size).to(device)\n",
    "        \n",
    "        return ot_plan_sampler.sample_plan(x_samples, y_samples)\n",
    "    \n",
    "    return ret_fn\n",
    "\n",
    "sampling_fn = get_discrete_ot_plan_sample_fn(X_sampler, Y_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cbf9dc89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To download the required model run, run training_artifacts/download_all.py in the ALAE folder.\n",
    "\n",
    "# model = load_model(\"../ALAE/configs/ffhq.yaml\", training_artifacts_dir=\"../ALAE/training_artifacts/ffhq/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f62840d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "EMA_BETAS = [0.999, 0.99]\n",
    "ema = EMA(0, betas=EMA_BETAS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cab2da87",
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_WANDB = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ecd05c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "if USE_WANDB:\n",
    "    wandb.init(name=EXP_NAME, project='ofm', config=config)\n",
    "    \n",
    "with tqdm(range(CONTINUE + 1, MAX_STEPS)) as tbar:\n",
    "    for step in tbar:\n",
    "        D_opt.zero_grad()\n",
    "\n",
    "        current_metrics = dict()\n",
    "        X, Y = sampling_fn(BATCH_SIZE)\n",
    "        X = X.cuda(); Y = Y.cuda()\n",
    "        t = (torch.rand(BATCH_SIZE) + 1e-8).cuda()\n",
    "\n",
    "        loss, true_loss = ofm_loss(D, X, Y, t, \n",
    "                TorchlbfgsInvParams(lbfgs_params=dict(tolerance_grad = INV_TOLERANCE), max_iter=10),\n",
    "                BruteforceIHVParams(),\n",
    "                tol_inverse_border = 0.5,\n",
    "                stats=current_metrics)\n",
    "\n",
    "        D_opt.zero_grad()\n",
    "        loss.backward()\n",
    "        D_opt.step(); D.convexify(); ema(D)\n",
    "        current_metrics['loss'] = loss.item(); current_metrics['true_loss'] = true_loss.item()\n",
    "\n",
    "        if USE_WANDB:\n",
    "            wandb.log(current_metrics, step=step)\n",
    "\n",
    "        if step % 100 == 1:\n",
    "            torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, 'D.pth'))\n",
    "            torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, f'D_opt.pt'))\n",
    "            for beta in EMA_BETAS:\n",
    "                beta_model_path = os.path.join(OUTPUT_PATH, 'D_{}.pth'.format(beta))\n",
    "                torch.save(ema.get_model(beta).state_dict(), beta_model_path)\n",
    "        tbar.set_postfix(loss=current_metrics['loss'], tloss = current_metrics['true_loss'])\n",
    "#         tbar.set_postfix(loss=current_metrics['loss'], tloss = current_metrics['true_loss'], ood=current_metrics['ood_ratio'])\n",
    "if USE_WANDB:\n",
    "    wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b77f58b",
   "metadata": {},
   "source": [
    "# Results plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30991af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To download the required model run, run training_artifacts/download_all.py in the ALAE folder.\n",
    "\n",
    "alae_model = load_model(\"../ALAE/configs/ffhq.yaml\", training_artifacts_dir=\"../ALAE/training_artifacts/ffhq/\")\n",
    "torch.manual_seed(OUTPUT_SEED); np.random.seed(OUTPUT_SEED)\n",
    "\n",
    "inds_to_map = np.random.choice(np.arange((x_inds_test < 300).sum()), size=10, replace=False)\n",
    "number_of_samples = 3\n",
    "\n",
    "mapped_all = []\n",
    "latent_to_map = torch.tensor(test_latents[x_inds_test[inds_to_map]])\n",
    "\n",
    "inp_images = test_inp_images[x_inds_test[inds_to_map]]\n",
    "\n",
    "with torch.no_grad():\n",
    "    for k in range(number_of_samples):\n",
    "        mapped = D.push_nograd(latent_to_map.cuda()).cpu()\n",
    "        mapped_all.append(mapped)\n",
    "    \n",
    "mapped = torch.stack(mapped_all, dim=1)\n",
    "\n",
    "decoded_all = []\n",
    "with torch.no_grad():\n",
    "    for k in range(number_of_samples):\n",
    "        decoded_img = decode(alae_model, mapped[:, k].cpu())\n",
    "        decoded_img = ((decoded_img * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255).cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy()\n",
    "        decoded_all.append(decoded_img)\n",
    "        \n",
    "decoded_all = np.stack(decoded_all, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f9eff6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "n_pictures = 2\n",
    "\n",
    "fig, axes = plt.subplots(n_pictures, number_of_samples+1, figsize=(number_of_samples+1, n_pictures), dpi=200)\n",
    "\n",
    "for i, ind in enumerate(range(n_pictures)):\n",
    "    ax = axes[i]\n",
    "    ax[0].imshow(inp_images[ind])\n",
    "    for k in range(number_of_samples):\n",
    "        ax[k+1].imshow(decoded_all[ind, k])\n",
    "        \n",
    "        ax[k+1].get_xaxis().set_visible(False)\n",
    "        ax[k+1].set_yticks([])\n",
    "        \n",
    "    ax[0].get_xaxis().set_visible(False)\n",
    "    ax[0].set_yticks([])\n",
    "\n",
    "fig.tight_layout(pad=0.05)\n",
    "fig.savefig('ofm_transfer_mb128.png')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
