{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66da8413",
   "metadata": {},
   "source": [
    "# Ave celeba in data space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ea8d838",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.distributions as TD\n",
    "\n",
    "\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "import itertools\n",
    "import sys \n",
    "sys.path.append(\"../../\")\n",
    "from src.utils import Config, Distrib2Sampler , normalize_out_to_0_1, make_f_pot, freeze, unfreeze, plot_barycenter_map_in_data_space_more\n",
    "from src.resnet2 import weights_init_D, ResNet_D\n",
    "from src.distributions import DatasetSampler\n",
    "from src.cost import cond_score,  cost_l2_grad_y\n",
    "from src.eot import sample_langevin_batch\n",
    "from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler\n",
    "from src.eot_utils import computePotGrad, evaluating\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "from typing import Callable, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19a39191",
   "metadata": {},
   "source": [
    "## 1. Papermill parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39a2583",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Papermill-injected parameters (this cell carries the `parameters` tag).\n",
    "# GPU_DEVICES: CUDA device indices; the first entry becomes CONFIG.DEVICE.\n",
    "GPU_DEVICES = [0]\n",
    "# EPS: stored as CONFIG.HREG (presumably the entropic regularization strength -- confirm).\n",
    "EPS = 0.01\n",
    "# LR: learning rate handed to the Adam optimizer via CONFIG.LR.\n",
    "LR = 1e-4\n",
    "# ENERGY_ITRS: number of Langevin steps per sampling call (CONFIG.ENERGY_SAMPLING_ITERATIONS).\n",
    "ENERGY_ITRS = 500\n",
    "# BATCH_SIZE: samples drawn from each data sampler per training step.\n",
    "BATCH_SIZE = 64\n",
    "# ALPHAS: per-distribution weights used in the training loss (CONFIG.ALPHAS_BARYCENTER).\n",
    "ALPHAS=[0.25, 0.5, 0.25]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bcb27f4",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de84f442",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Central settings object shared by the sampler, the networks and the training loop.\n",
    "CONFIG = Config()\n",
    "\n",
    "# --- Dataset ---------------------------------------------------------------\n",
    "CONFIG.DATASET = 'ave_celeba'\n",
    "CONFIG.DATASET_PATH = '../../data/ave_celeba_green_v2/'\n",
    "CONFIG.BATCH_SIZE = BATCH_SIZE\n",
    "CONFIG.CLASSES = [0, 1, 2]\n",
    "CONFIG.IMG_SIZE = 64\n",
    "CONFIG.NC = 3\n",
    "\n",
    "# --- Barycenter problem ----------------------------------------------------\n",
    "CONFIG.ALPHAS_BARYCENTER = ALPHAS\n",
    "CONFIG.MAX_STEPS = 1000\n",
    "# K is the number of barycenter weights, i.e. the number of input distributions.\n",
    "CONFIG.K = len(CONFIG.ALPHAS_BARYCENTER)\n",
    "CONFIG.HREG = EPS\n",
    "\n",
    "# --- Optimizer -------------------------------------------------------------\n",
    "CONFIG.LR = LR\n",
    "CONFIG.CLIP_GRADS_NORM = False\n",
    "CONFIG.BETAS = (0.2, 0.99)\n",
    "\n",
    "# --- Langevin sampling -----------------------------------------------------\n",
    "CONFIG.LANGEVIN_THRESH = None\n",
    "CONFIG.LANGEVIN_SAMPLING_NOISE = 0.1\n",
    "CONFIG.ENERGY_SAMPLING_ITERATIONS = ENERGY_ITRS\n",
    "CONFIG.LANGEVIN_DECAY = 1.0\n",
    "CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0\n",
    "CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0\n",
    "\n",
    "# --- Initial noise and devices ---------------------------------------------\n",
    "CONFIG.BASIC_NOISE_VAR = 2.0\n",
    "CONFIG.DEVICE = f\"cuda:{GPU_DEVICES[0]}\"\n",
    "CONFIG.DEVICES_IDS = GPU_DEVICES\n",
    "\n",
    "# --- Evaluation ------------------------------------------------------------\n",
    "CONFIG.NUM_TEST_RUNS = 4\n",
    "CONFIG.FLAG_F_G_LATENT = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f87814bc",
   "metadata": {},
   "source": [
    "## 3. Data Sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cf80f6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Preprocessing applied to every image before it reaches the samplers.\n",
    "# The two affine lambdas are algebraic inverses of each other, and the final\n",
    "# lambda clips the result to [0, 1].\n",
    "preprocessing_steps = [\n",
    "    torchvision.transforms.Resize(CONFIG.IMG_SIZE),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Lambda(lambda x: 2 * x - 1),\n",
    "    torchvision.transforms.Lambda(lambda x: (x + 1) / 2),\n",
    "    torchvision.transforms.Lambda(lambda x: torch.clip(x, 0, 1)),\n",
    "]\n",
    "transform = torchvision.transforms.Compose(preprocessing_steps)\n",
    "\n",
    "# One sampler per input distribution; images for index k are read from\n",
    "# <DATASET_PATH>/ave_celeba_<k>/.\n",
    "data_samplers = []\n",
    "for k in tqdm(range(CONFIG.K)):\n",
    "    class_dir = os.path.join(CONFIG.DATASET_PATH, f\"ave_celeba_{k}/\")\n",
    "    dataset = torchvision.datasets.ImageFolder(class_dir, transform=transform)\n",
    "    sampler = DatasetSampler(dataset, flag_label=True, batch_size=256, num_workers=40)\n",
    "    data_samplers.append(sampler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5aaaf69",
   "metadata": {},
   "source": [
    "## 4. Networks "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fffe6170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one ResNet_D potential network per barycenter weight and initialise it\n",
    "# with weights_init_D right after construction.\n",
    "f1 = []\n",
    "for _alpha in CONFIG.ALPHAS_BARYCENTER:\n",
    "    potential_net = ResNet_D(\n",
    "        size=CONFIG.IMG_SIZE,\n",
    "        nc=CONFIG.NC,\n",
    "        nfilter=64,\n",
    "        nfilter_max=512,\n",
    "        res_ratio=0.1,\n",
    "    ).to(CONFIG.DEVICE)\n",
    "    f1.append(potential_net)\n",
    "    weights_init_D(potential_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5de90a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single Adam optimizer over the parameters of every potential network in f1.\n",
    "# Chaining over the whole list (instead of indexing f1[0..2]) keeps the optimizer\n",
    "# in sync with however many networks were built from CONFIG.ALPHAS_BARYCENTER.\n",
    "f1_opt = torch.optim.Adam(\n",
    "    itertools.chain.from_iterable(net.parameters() for net in f1),\n",
    "    CONFIG.LR,\n",
    "    betas=CONFIG.BETAS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f058b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One potential callable per entry of CONFIG.CLASSES, each built by make_f_pot\n",
    "# from the shared list of networks f1.\n",
    "f_pots = []\n",
    "for pot_idx in range(len(CONFIG.CLASSES)):\n",
    "    f_pots.append(make_f_pot(pot_idx, f1, CONFIG))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6e97376",
   "metadata": {},
   "source": [
    "## 5. Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0bc5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_langevin_mu_f(\n",
    "        f: Callable[[torch.Tensor], torch.Tensor],\n",
    "        x: torch.Tensor,\n",
    "        y_init: torch.Tensor,\n",
    "        config: Config\n",
    "    ) -> torch.Tensor:\n",
    "    \"\"\"Draw samples with `sample_langevin_batch`, starting the chain at `y_init`.\n",
    "\n",
    "    The score handed to the sampler is `cond_score` evaluated with the potential\n",
    "    `f`, the `cost_l2_grad_y` cost gradient and the conditioning batch `x`.\n",
    "\n",
    "    Args:\n",
    "        f: potential passed to `cond_score`.\n",
    "        x: conditioning batch passed to `cond_score`.\n",
    "        y_init: initial state of the Langevin chain.\n",
    "        config: settings object; this function reads FLAG_F_G_LATENT,\n",
    "            ENERGY_SAMPLING_ITERATIONS, LANGEVIN_DECAY, LANGEVIN_THRESH and\n",
    "            LANGEVIN_SAMPLING_NOISE from it (and forwards it to `cond_score`).\n",
    "\n",
    "    Returns:\n",
    "        The first of the five values returned by `sample_langevin_batch`;\n",
    "        the other four are discarded.\n",
    "    \"\"\"\n",
    "\n",
    "    def score(y, ret_stats=False):\n",
    "        # Read the latent flag from the `config` argument rather than the\n",
    "        # notebook-level CONFIG, so the caller's settings are the ones applied.\n",
    "        return cond_score(\n",
    "            f, cost_l2_grad_y, y, x, config,\n",
    "            flag_grayscale=False,\n",
    "            flag_f_G_latent=config.FLAG_F_G_LATENT,\n",
    "            latent2data_gen=None,\n",
    "            ret_stats=ret_stats,\n",
    "        )\n",
    "\n",
    "    # With compute_stats=True the call is unpacked into five values; only the\n",
    "    # first one is used here.\n",
    "    y, _r_t, _cost_r_t, _score_r_t, _noise_norm = sample_langevin_batch(\n",
    "        score,\n",
    "        y_init,\n",
    "        n_steps=config.ENERGY_SAMPLING_ITERATIONS,\n",
    "        decay=config.LANGEVIN_DECAY,\n",
    "        thresh=config.LANGEVIN_THRESH,\n",
    "        noise=config.LANGEVIN_SAMPLING_NOISE,\n",
    "        data_projector=lambda sample: sample,  # identity: no projection applied\n",
    "        compute_stats=True,\n",
    "    )\n",
    "\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ec35f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run name passed to wandb.init; encodes the papermill hyper-parameters of this run.\n",
    "name_exp = f\"EPS_{EPS}_LR_{LR}_NS_{ENERGY_ITRS}_BATCH_SIZE_{BATCH_SIZE}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06ca54c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Start the Weights & Biases run that the training loop logs to.\n",
    "wandb.init(\n",
    "    project=\"Ave_celeba_in_data_space\",\n",
    "    name=name_exp,\n",
    "    config=CONFIG,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43acd596",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler for the initial Langevin states: an image-shaped Gaussian on the CPU.\n",
    "# NOTE(review): TD.Normal's second argument is the scale (standard deviation), so\n",
    "# BASIC_NOISE_VAR acts as a std here despite its name -- confirm this is intended.\n",
    "noise_shape = (CONFIG.NC, CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)\n",
    "noise_loc = torch.zeros(noise_shape).to('cpu')\n",
    "noise_scale = torch.ones(noise_shape).to('cpu') * CONFIG.BASIC_NOISE_VAR\n",
    "init_noise_sampler = Distrib2Sampler(TD.Normal(noise_loc, noise_scale))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35719fff",
   "metadata": {},
   "outputs": [],
   "source": [
    "SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=20)\n",
    "\n",
    "# Number of images per distribution used for the periodic visual check.\n",
    "N_ESTIMATE_POINTS = 8\n",
    "\n",
    "\n",
    "def set_potentials_trainable(trainable):\n",
    "    \"\"\"Apply `unfreeze` (True) or `freeze` (False) to every network in `f1`.\"\"\"\n",
    "    toggle = unfreeze if trainable else freeze\n",
    "    for net in f1:\n",
    "        toggle(net)\n",
    "\n",
    "\n",
    "def l2_transport_cost(Y, X):\n",
    "    \"\"\"Return 0.5 * the squared difference of Y and X summed over all dims after the first.\"\"\"\n",
    "    return 0.5 * torch.flatten(Y - X, start_dim=1).pow(2).sum(dim=1, keepdim=True)\n",
    "\n",
    "\n",
    "for step in tqdm(range(CONFIG.MAX_STEPS)):\n",
    "    # The evaluation branch at the end of a step leaves the networks frozen,\n",
    "    # so re-enable them at the start of every iteration.\n",
    "    set_potentials_trainable(True)\n",
    "\n",
    "    X1 = data_samplers[0].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    X2 = data_samplers[1].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    X3 = data_samplers[2].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    Y1_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y2_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y3_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    # Langevin sampling runs with frozen potentials and under no_grad; the\n",
    "    # call order (1, 3, 2) is kept from the original cell so the random\n",
    "    # stream is consumed identically.\n",
    "    set_potentials_trainable(False)\n",
    "    with torch.no_grad():\n",
    "        Y1 = sample_langevin_mu_f(f_pots[0], X1, Y1_init, CONFIG)\n",
    "        Y3 = sample_langevin_mu_f(f_pots[2], X3, Y3_init, CONFIG)\n",
    "        Y2 = sample_langevin_mu_f(f_pots[1], X2, Y2_init, CONFIG)\n",
    "    set_potentials_trainable(True)\n",
    "\n",
    "    if CONFIG.FLAG_F_G_LATENT:\n",
    "        # NOTE(review): G is neither defined nor imported in this notebook, so\n",
    "        # this branch raises NameError if FLAG_F_G_LATENT is ever enabled.\n",
    "        pot_inputs = [normalize_out_to_0_1(G(Y, c=None)) for Y in (Y1, Y2, Y3)]\n",
    "    else:\n",
    "        pot_inputs = [Y1, Y2, Y3]\n",
    "\n",
    "    # Weighted sum of the mean potential values on the sampled points.\n",
    "    loss = sum(\n",
    "        alpha * f_pot(pot_input).mean()\n",
    "        for alpha, f_pot, pot_input in zip(CONFIG.ALPHAS_BARYCENTER, f_pots, pot_inputs)\n",
    "    )\n",
    "\n",
    "    f1_opt.zero_grad()\n",
    "    loss.backward()\n",
    "    f1_opt.step()\n",
    "\n",
    "    loss_value = loss.item()\n",
    "    SMDS.SM.upd('loss', loss_value)\n",
    "    SMDS.epoch()\n",
    "    wandb.log({\"loss train\": loss_value}, step=step)\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        if step % 10 == 0:\n",
    "            # Mean transport cost between each training batch and its samples.\n",
    "            wandb.log(\n",
    "                {\n",
    "                    f\"Transport by cost {k}\": l2_transport_cost(Y, X).mean().item()\n",
    "                    for k, (Y, X) in enumerate(((Y1, X1), (Y2, X2), (Y3, X3)))\n",
    "                },\n",
    "                step=step,\n",
    "            )\n",
    "\n",
    "        if step % 50 == 0:\n",
    "            # Fresh, small batches for the visual check; this overwrites the\n",
    "            # training batches, which are no longer needed in this step.\n",
    "            X1 = data_samplers[0].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X2 = data_samplers[1].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X3 = data_samplers[2].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "\n",
    "            set_potentials_trainable(False)\n",
    "\n",
    "            map_1, map_2, map_3 = [], [], []\n",
    "            for _run in range(CONFIG.NUM_TEST_RUNS):\n",
    "                Y1_init = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                Y2_init = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                Y3_init = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                Y1 = sample_langevin_mu_f(f_pots[0], X1, Y1_init, CONFIG)\n",
    "                Y3 = sample_langevin_mu_f(f_pots[2], X3, Y3_init, CONFIG)\n",
    "                Y2 = sample_langevin_mu_f(f_pots[1], X2, Y2_init, CONFIG)\n",
    "\n",
    "                map_1.append(Y1)\n",
    "                map_2.append(Y2)\n",
    "                map_3.append(Y3)\n",
    "\n",
    "            # Pass the same constant used for sampling so the plot always\n",
    "            # matches the number of points actually drawn.\n",
    "            plot_barycenter_map_in_data_space_more(\n",
    "                X1, X2, X3, map_1, map_2, map_3,\n",
    "                step=step, n_estimate_points=N_ESTIMATE_POINTS,\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a30b44",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
