{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1018839b",
   "metadata": {},
   "source": [
    "# Ave celeba in latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b2bf134",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.distributions as TD\n",
    "\n",
    "import sys \n",
    "sys.path.append(\"../../\")\n",
    "import dnnlib\n",
    "import legacy\n",
    "\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "import itertools\n",
    "import sys \n",
    "sys.path.append(\"../../\")\n",
    "from src.utils import Config, Distrib2Sampler , normalize_out_to_0_1, make_f_pot, freeze, unfreeze, plot_barycenter_map_in_data_space_more\n",
    "from src.resnet2 import weights_init_D, ResNet_D\n",
    "from src.distributions import DatasetSampler\n",
    "from src.cost import cond_score, cost_grad_image_latent\n",
    "from src.eot import sample_langevin_batch\n",
    "from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler\n",
    "from src.eot_utils import computePotGrad, evaluating\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "from typing import Callable, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47254871",
   "metadata": {},
   "source": [
    "## 1. Parameters for papermill"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68ec4aa7",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "GPU_DEVICES = [0]\n",
    "EPS = 0.0001\n",
    "LR = 1e-4\n",
    "ENERGY_ITRS = 1000\n",
    "BATCH_SIZE = 64\n",
    "ALPHAS=[0.25, 0.5, 0.25]\n",
    "FLAG_F_G_LATENT=False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49c5da7f",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d33f000c",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = Config()\n",
    "\n",
    "CONFIG.DATASET = 'ave_celeba'\n",
    "CONFIG.DATASET_PATH  = './data/ave_celeba_green_v2/' \n",
    "CONFIG.BATCH_SIZE = BATCH_SIZE\n",
    "CONFIG.CLASSES=[0,1,2]\n",
    "CONFIG.IMG_SIZE=64\n",
    "CONFIG.NC=3\n",
    "\n",
    "CONFIG.ALPHAS_BARYCENTER = ALPHAS\n",
    "CONFIG.MAX_STEPS = 1000\n",
    "CONFIG.K = len(CONFIG.ALPHAS_BARYCENTER)\n",
    "CONFIG.HREG = EPS\n",
    "\n",
    "CONFIG.LR = LR\n",
    "CONFIG.CLIP_GRADS_NORM = False\n",
    "CONFIG.BETAS = (0.2, 0.99)\n",
    "\n",
    "CONFIG.LANGEVIN_THRESH = None\n",
    "CONFIG.LANGEVIN_SAMPLING_NOISE = 0.1\n",
    "CONFIG.ENERGY_SAMPLING_ITERATIONS = ENERGY_ITRS\n",
    "CONFIG.LANGEVIN_DECAY = 1.0\n",
    "CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0\n",
    "CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0\n",
    " \n",
    "CONFIG.BASIC_NOISE_VAR = 2.0\n",
    "CONFIG.DEVICE =  f\"cuda:{GPU_DEVICES[0]}\"\n",
    "CONFIG.DEVICES_IDS = GPU_DEVICES\n",
    "\n",
    "CONFIG.NUM_TEST_RUNS = 4\n",
    "CONFIG.FLAG_F_G_LATENT = FLAG_F_G_LATENT  \n",
    "CONFIG.LATENT_SIZE =512\n",
    "CONFIG.GENERATOR_PATH =  \" ./training-runs/00011-aligned_celeba-stylegan2/network-snapshot-008800.pkl\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "948a1297",
   "metadata": {},
   "source": [
    "## 3. Data sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af40ffc3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(CONFIG.IMG_SIZE),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Lambda(lambda x: torch.clip(x,0,1))\n",
    "])\n",
    "\n",
    "data_samplers=[]\n",
    "for k in tqdm(range(CONFIG.K)):\n",
    "    dataset = torchvision.datasets.ImageFolder(os.path.join(CONFIG.DATASET_PATH,f\"ave_celeba_{k}/\"), transform=transform)\n",
    "    data_samplers.append(DatasetSampler(dataset, flag_label=True, batch_size=256 ,num_workers=40))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eafaac8",
   "metadata": {},
   "source": [
    "## 4. Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efbad018",
   "metadata": {},
   "outputs": [],
   "source": [
    "f1 = []\n",
    "for idx,alpha in enumerate(CONFIG.ALPHAS_BARYCENTER):\n",
    "    \n",
    "    if CONFIG.FLAG_F_G_LATENT:\n",
    "        f1.append( ResNet_D(size=CONFIG.IMG_SIZE,\n",
    "                  nc=CONFIG.NC, nfilter=64, nfilter_max=512, res_ratio=0.1).to(CONFIG.DEVICE))\n",
    "        weights_init_D(f1[idx])\n",
    "    else:\n",
    "        model = []\n",
    "        hiddens = [CONFIG.LATENT_SIZE, CONFIG.LATENT_SIZE//2, CONFIG.LATENT_SIZE//4, CONFIG.LATENT_SIZE//8,1]\n",
    "        for ins,out in zip(hiddens[:-1],hiddens[1:]):\n",
    "            model.append(torch.nn.Linear(ins,out,bias=True))\n",
    "            model.append(torch.nn.ReLU())\n",
    "        model.pop()\n",
    "        f1.append(torch.nn.Sequential(*model).to(CONFIG.DEVICE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e70876",
   "metadata": {},
   "outputs": [],
   "source": [
    "f1_opt = torch.optim.Adam(itertools.chain(f1[0].parameters(),f1[1].parameters(),f1[2].parameters()),\n",
    "                              CONFIG.LR, betas=CONFIG.BETAS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4efdd68f",
   "metadata": {},
   "outputs": [],
   "source": [
    "f_pots = [make_f_pot(i,f1,CONFIG) for i in range(len(CONFIG.CLASSES))]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79dfc630",
   "metadata": {},
   "source": [
    "## 5. StyleGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6397eb33",
   "metadata": {},
   "outputs": [],
   "source": [
    "with dnnlib.util.open_url(CONFIG.GENERATOR_PATH) as f:\n",
    "    G =  legacy.load_network_pkl(f)['G_ema'].to(CONFIG.DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "464b9f62",
   "metadata": {},
   "source": [
    "## 6. Preleminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0c7b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_langevin_mu_f(\n",
    "        f: Callable[[torch.Tensor], torch.Tensor], \n",
    "        x: torch.Tensor, \n",
    "        y_init: torch.Tensor, \n",
    "        config: Config\n",
    "    ) -> torch.Tensor:\n",
    "    \n",
    "    def score(y, ret_stats=False):\n",
    "        return cond_score(f, cost_grad_image_latent , y, x, config, flag_grayscale=False,\n",
    "        flag_f_G_latent=CONFIG.FLAG_F_G_LATENT, latent2data_gen=G, ret_stats=ret_stats)\n",
    "    \n",
    "    y, r_t, cost_r_t, score_r_t, noise_norm = sample_langevin_batch(\n",
    "        score, \n",
    "        y_init,\n",
    "        n_steps=config.ENERGY_SAMPLING_ITERATIONS, \n",
    "        decay=config.LANGEVIN_DECAY, \n",
    "        thresh=config.LANGEVIN_THRESH, \n",
    "        noise=config.LANGEVIN_SAMPLING_NOISE, \n",
    "        data_projector=lambda x: x, \n",
    "        compute_stats=True)\n",
    "    \n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1f5056",
   "metadata": {},
   "outputs": [],
   "source": [
    "name_exp = f\"mean_FLAG_F_G_{FLAG_F_G_LATENT}_EPS_{EPS}_LR_{LR}_BS_{BATCH_SIZE}_NS_{ENERGY_ITRS}_ALPHAS_{[0.25, 0.5, 0.25]}\"\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1952cf1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "wandb.init(project=\"Ave_celeba_in_latent_space\" ,\n",
    "           name=name_exp,\n",
    "           config=CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cecf10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "init_noise_sampler = Distrib2Sampler(TD.Normal(\n",
    "    torch.zeros( CONFIG.LATENT_SIZE).to('cpu'), \n",
    "    torch.ones(CONFIG.LATENT_SIZE).to('cpu') * CONFIG.BASIC_NOISE_VAR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5a2dda2",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=20)\n",
    "\n",
    "for step in tqdm(range(CONFIG.MAX_STEPS)):\n",
    "    \n",
    "    \n",
    "    for idx in range(len(CONFIG.CLASSES)):\n",
    "        unfreeze(f1[idx])\n",
    "        \n",
    "    X1 = data_samplers[0].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    X2 = data_samplers[1].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    X3 = data_samplers[2].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    \n",
    "    Y1_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y2_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y3_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    \n",
    "    \n",
    "    for idx in range(len(CONFIG.CLASSES)):\n",
    "        freeze(f1[idx])  \n",
    "    with torch.no_grad():\n",
    "        Y1 = sample_langevin_mu_f(lambda x:  f_pots[0](x), X1, Y1_init, CONFIG)\n",
    "        Y3 = sample_langevin_mu_f(lambda x:  f_pots[2](x), X3, Y3_init, CONFIG)\n",
    "        Y2 = sample_langevin_mu_f(lambda x:  f_pots[1](x), X2, Y2_init, CONFIG)\n",
    "         \n",
    "    for idx in range(len(CONFIG.CLASSES)):\n",
    "        unfreeze(f1[idx]) \n",
    "    \n",
    "    if CONFIG.FLAG_F_G_LATENT:\n",
    "        loss = CONFIG.ALPHAS_BARYCENTER[0]*f_pots[0]( normalize_out_to_0_1(G(Y1,c=None)) ).mean() + CONFIG.ALPHAS_BARYCENTER[1]*f_pots[1](normalize_out_to_0_1(G(Y2,c=None)) ).mean() +\\\n",
    "    + CONFIG.ALPHAS_BARYCENTER[2]*f_pots[2](normalize_out_to_0_1(G(Y3,c=None))  ).mean()\n",
    "    else:\n",
    "        loss = CONFIG.ALPHAS_BARYCENTER[0]*f_pots[0]( Y1 ).mean() + CONFIG.ALPHAS_BARYCENTER[1]*f_pots[1](Y2).mean() +\\\n",
    "    + CONFIG.ALPHAS_BARYCENTER[2]*f_pots[2]( Y3  ).mean()\n",
    "        \n",
    "    \n",
    "    f1_opt.zero_grad(); loss.backward(); f1_opt.step()\n",
    "    SMDS.SM.upd('loss', loss.item())\n",
    "    SMDS.epoch()\n",
    "    wandb.log({\"loss train\":loss.item()},step=step)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        \n",
    "        if step % 10 == 0:\n",
    "            cost_0 = 0.5 * torch.flatten( normalize_out_to_0_1(G(Y1, c=None)) - X1,\n",
    "                                   start_dim=1).pow(2).sum(dim=1, keepdim=True)\n",
    "            cost_1 = 0.5 * torch.flatten( normalize_out_to_0_1(G(Y2, c=None)) - X2,\n",
    "                                   start_dim=1).pow(2).sum(dim=1, keepdim=True)\n",
    "            cost_2 = 0.5 * torch.flatten( normalize_out_to_0_1(G(Y3, c=None)) - X3,\n",
    "                                   start_dim=1).pow(2).sum(dim=1, keepdim=True)\n",
    "            \n",
    "            wandb.log({\"Transport by cost 0\": cost_0.mean().item()},step=step)\n",
    "            wandb.log({\"Transport by cost 1\": cost_1.mean().item()},step=step)\n",
    "            wandb.log({\"Transport by cost 2\": cost_2.mean().item()},step=step)\n",
    "        \n",
    "        if step % 50 == 0:\n",
    "            N_ESTIMATE_POINTS = 8\n",
    "            X1 = data_samplers[0].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X2 = data_samplers[1].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X3 = data_samplers[2].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            \n",
    "            for idx in range(len(CONFIG.CLASSES)):\n",
    "                    freeze(f1[idx]) \n",
    "                    \n",
    "            map_1 = []\n",
    "            map_2 = []\n",
    "            map_3 = []\n",
    "            for run in range(CONFIG.NUM_TEST_RUNS):\n",
    "                Y1_init  = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE) \n",
    "                Y2_init =  init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                Y3_init =  init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                Y1 = sample_langevin_mu_f(lambda x:  f_pots[0](x), X1, Y1_init, CONFIG)\n",
    "                Y3 = sample_langevin_mu_f(lambda x:  f_pots[2](x), X3, Y3_init, CONFIG)\n",
    "                Y2 = sample_langevin_mu_f(lambda x:  f_pots[1](x), X2, Y2_init, CONFIG)\n",
    "\n",
    "                map_1.append(normalize_out_to_0_1(G(Y1,c=None)));map_2.append(normalize_out_to_0_1(G(Y2,c=None)));map_3.append(normalize_out_to_0_1(G(Y3,c=None)))\n",
    "            \n",
    "                    \n",
    "            plot_barycenter_map_in_data_space_more(X1,X2,X3,map_1,map_2,map_3,step=step,n_estimate_points=8)\n",
    "            \n",
    "        if step % 50 == 0:\n",
    "            \n",
    "            for idx in range(len(CONFIG.CLASSES)):\n",
    "                    freeze(f1[idx]) \n",
    "                    \n",
    "            for k in range(CONFIG.K):\n",
    "                torch.save(f1[k].cpu().state_dict(), os.path.join(\"./\",\n",
    "           \"Averaging-with-Energy-A-Generic-Algorithm-for-Continuous-Entropic-Barycenter-Estimation/\",\n",
    "           \"ckpts/\",\n",
    "           f\"Ave_celeba_barycenter_in_latent_space/_net_{k}_{name_exp}.pth\")\n",
    "          )\n",
    "            \n",
    "            for k in range(CONFIG.K):\n",
    "                f1[k].to(CONFIG.DEVICE)\n",
    "              \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d5dea85",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
