{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bde8af78",
   "metadata": {},
   "source": [
    "# Continuous entropic barycenter estimation of MNIST 01 in latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ccfeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.distributions as TD\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import dnnlib \n",
    "import legacy\n",
    "\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "import os\n",
    "import sys \n",
    "sys.path.append(\"../..\")\n",
    "from src.utils import Config, Distrib2Sampler, normalize_out_to_0_1, plot_barycenter_map_in_data_space\n",
    "from src.eot_utils import computePotGrad, evaluating\n",
    "from src.eot import sample_langevin_batch\n",
    "from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler\n",
    "from src.cost import cond_score, cost_grad_image_latent\n",
    "from src.distributions import DatasetSampler\n",
    "from src.resnet2 import  ResNet_D, weights_init_D\n",
    " \n",
    "\n",
    "from typing import Callable, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5945e30",
   "metadata": {},
   "source": [
    "## 1. Parameters for papermill"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7fcdc1e",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Default values; papermill looks for the cell tagged `parameters` and injects overrides after it.\n",
    "GPU_DEVICES = [0]    # CUDA device indices; the first one is used to build CONFIG.DEVICE\n",
    "EPS = 0.01           # stored as CONFIG.HREG\n",
    "LR = 1e-4            # learning rate handed to the Adam optimizer of f1\n",
    "ENERGY_ITRS = 250    # stored as CONFIG.ENERGY_SAMPLING_ITERATIONS (Langevin n_steps)\n",
    "BATCH_SIZE = 128     # number of samples drawn from each data sampler per training step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7e85cda",
   "metadata": {},
   "source": [
    "## 2. Create Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b60db97",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = Config()\n",
    "\n",
    "# --- data ---\n",
    "CONFIG.DATASET = 'mnist01'\n",
    "CONFIG.DATASET_PATH = '../../data/train_MNIST'\n",
    "CONFIG.BATCH_SIZE = BATCH_SIZE\n",
    "CONFIG.CLASSES = [0, 1]  # CLASSES[k] is the MNIST digit kept for the k-th data sampler\n",
    "CONFIG.IMG_SIZE = 32     # images are resized to IMG_SIZE x IMG_SIZE\n",
    "CONFIG.NC = 1            # passed to ResNet_D as `nc`\n",
    "\n",
    "# --- barycenter problem ---\n",
    "CONFIG.ALPHAS_BARYCENTER = [0.5, 0.5]\n",
    "CONFIG.MAX_STEPS = 1000\n",
    "CONFIG.K = len(CONFIG.ALPHAS_BARYCENTER)  # one marginal per barycenter weight\n",
    "CONFIG.HREG = EPS\n",
    "\n",
    "# --- optimizer ---\n",
    "CONFIG.LR = LR\n",
    "CONFIG.CLIP_GRADS_NORM = False\n",
    "CONFIG.BETAS = (0.2, 0.99)\n",
    "\n",
    "# --- Langevin sampling ---\n",
    "CONFIG.LANGEVIN_THRESH = None\n",
    "CONFIG.LANGEVIN_SAMPLING_NOISE = 0.1\n",
    "CONFIG.ENERGY_SAMPLING_ITERATIONS = ENERGY_ITRS\n",
    "CONFIG.LANGEVIN_DECAY = 1.0\n",
    "CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0\n",
    "CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0\n",
    "\n",
    "# --- latent space, devices and generator ---\n",
    "CONFIG.BASIC_NOISE_VAR = 2.0\n",
    "CONFIG.DEVICE = f\"cuda:{GPU_DEVICES[0]}\"\n",
    "CONFIG.DEVICES_IDS = GPU_DEVICES\n",
    "CONFIG.LATENT_SIZE = 512\n",
    "CONFIG.NUM_TEST_RUNS = 4\n",
    "CONFIG.FLAG_GRAYSCALE = False\n",
    "CONFIG.FLAG_F_G_LATENT = True\n",
    "CONFIG.GENERATOR_PATH = \"../../SG2_ckpt/mnist/mnist_01.pkl\"  # path to generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "999db7f0",
   "metadata": {},
   "source": [
    "## 4. Data samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9215b6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize to the working resolution and convert to a tensor; the clip keeps values inside [0, 1].\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize((CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Lambda(lambda x: torch.clip(x, 0, 1)),\n",
    "])\n",
    "\n",
    "assert len(CONFIG.CLASSES) == CONFIG.K, \"need exactly one digit class per barycenter marginal\"\n",
    "\n",
    "data_samplers = []\n",
    "\n",
    "for k in range(CONFIG.K):\n",
    "    dataset = torchvision.datasets.MNIST(root=CONFIG.DATASET_PATH, download=True,\n",
    "                                         transform=transform)\n",
    "    # Vectorised class filter. torch.as_tensor reuses an existing tensor instead of\n",
    "    # copy-constructing it (torch.tensor(tensor) copies and emits a UserWarning).\n",
    "    targets = torch.as_tensor(dataset.targets)\n",
    "    keep = targets == CONFIG.CLASSES[k]\n",
    "    dataset.data = torch.as_tensor(dataset.data)[keep]\n",
    "    dataset.targets = targets[keep].numpy()  # labels stay a numpy array, as before\n",
    "    data_samplers.append(DatasetSampler(dataset, flag_label=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7150e1a",
   "metadata": {},
   "source": [
    "## 5. Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1335db1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential network f1; the second potential is taken as f2 = -f1.\n",
    "f1 = ResNet_D(\n",
    "    size=CONFIG.IMG_SIZE,\n",
    "    nc=CONFIG.NC,\n",
    "    nfilter=64,\n",
    "    nfilter_max=512,\n",
    "    res_ratio=0.1,\n",
    ").to(CONFIG.DEVICE)\n",
    "# NOTE(review): weights_init_D is called on the whole module here; confirm in src.resnet2\n",
    "# that it is not meant to be applied per sub-module via f1.apply(weights_init_D).\n",
    "weights_init_D(f1)\n",
    "\n",
    "f1_opt = torch.optim.Adam(f1.parameters(), lr=CONFIG.LR, betas=CONFIG.BETAS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3c072c7",
   "metadata": {},
   "source": [
    "## 6. Style-GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b863c03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'G_ema' network stored in the generator checkpoint and move it to the training device.\n",
    "with dnnlib.util.open_url(CONFIG.GENERATOR_PATH) as generator_file:\n",
    "    G = legacy.load_network_pkl(generator_file)['G_ema']\n",
    "    G = G.to(CONFIG.DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "237e1696",
   "metadata": {},
   "source": [
    "## 7. Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c77ee68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_langevin_mu_f(\n",
    "        f: Callable[[torch.Tensor], torch.Tensor], \n",
    "        cost_grad_fn,\n",
    "        x: torch.Tensor, \n",
    "        y_init: torch.Tensor, \n",
    "        config: Config,\n",
    "        latent2data_gen\n",
    "    ) -> torch.Tensor:\n",
    "    \"\"\"Run the Langevin sampler from `y_init`, conditioned on `x`, and return the samples.\n",
    "\n",
    "    The sampler is driven by `cond_score`, built from the potential `f` and\n",
    "    `cost_grad_fn`. Every setting (number of steps, decay, threshold, noise and\n",
    "    the FLAG_* switches) is read from `config`.\n",
    "\n",
    "    Args:\n",
    "        f: potential passed to `cond_score`.\n",
    "        cost_grad_fn: cost-gradient callable passed to `cond_score`.\n",
    "        x: batch the score is conditioned on.\n",
    "        y_init: starting points of the Langevin chains.\n",
    "        config: configuration object holding the sampling settings.\n",
    "        latent2data_gen: generator forwarded to `cond_score` as `latent2data_gen`.\n",
    "\n",
    "    Returns:\n",
    "        The first value returned by `sample_langevin_batch`; the statistics\n",
    "        returned alongside it are discarded.\n",
    "    \"\"\"\n",
    "\n",
    "    def score(y, ret_stats=False):\n",
    "        # Flags come from the `config` argument (not the notebook-global CONFIG), so the\n",
    "        # function honours whatever configuration the caller passes in.\n",
    "        return cond_score(f, cost_grad_fn, y, x, config,\n",
    "                          flag_grayscale=config.FLAG_GRAYSCALE,\n",
    "                          flag_f_G_latent=config.FLAG_F_G_LATENT,\n",
    "                          latent2data_gen=latent2data_gen, ret_stats=ret_stats)\n",
    "\n",
    "    # Only the samples are needed here; the remaining returned statistics are dropped.\n",
    "    y, *_ = sample_langevin_batch(\n",
    "        score,\n",
    "        y_init,\n",
    "        n_steps=config.ENERGY_SAMPLING_ITERATIONS,\n",
    "        decay=config.LANGEVIN_DECAY,\n",
    "        thresh=config.LANGEVIN_THRESH,\n",
    "        noise=config.LANGEVIN_SAMPLING_NOISE,\n",
    "        data_projector=lambda sample: sample,  # identity: no projection of the samples\n",
    "        compute_stats=True)\n",
    "\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce645396",
   "metadata": {},
   "source": [
    "## 8. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cda69fbc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The run name encodes the main hyper-parameters so runs can be told apart in the dashboard.\n",
    "run_name = f\"FLAG_F_G_{CONFIG.FLAG_F_G_LATENT}_EPS_{CONFIG.HREG}_LR_{CONFIG.LR}_BS_{CONFIG.BATCH_SIZE}_NS_{CONFIG.ENERGY_SAMPLING_ITERATIONS}\"\n",
    "wandb.init(project=\"MNIST_01_barycenter_in_latent_space\", name=run_name, config=CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38f55a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): TD.Normal takes (loc, scale) and `scale` is a standard deviation, so\n",
    "# CONFIG.BASIC_NOISE_VAR acts as the std of the initial latent noise despite its name.\n",
    "noise_loc = torch.zeros(CONFIG.LATENT_SIZE, device='cpu')\n",
    "noise_scale = torch.ones(CONFIG.LATENT_SIZE, device='cpu') * CONFIG.BASIC_NOISE_VAR\n",
    "init_noise_sampler = Distrib2Sampler(TD.Normal(noise_loc, noise_scale))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e76def9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Silence all warnings from this point on.\n",
    "warnings.filterwarnings(action='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d676972",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=20)\n",
    "\n",
    "# Evaluation settings are named once so that sampling and plotting cannot drift apart.\n",
    "EVAL_EVERY = 50        # run the qualitative evaluation every EVAL_EVERY training steps\n",
    "N_ESTIMATE_POINTS = 8  # samples drawn from each data sampler for the evaluation plot\n",
    "alpha_1, alpha_2 = CONFIG.ALPHAS_BARYCENTER\n",
    "\n",
    "\n",
    "def neg_f1(x):\n",
    "    \"\"\"Potential used for the second marginal: the negation of f1.\"\"\"\n",
    "    return -f1(x)\n",
    "\n",
    "\n",
    "def decode(latents):\n",
    "    \"\"\"Pass latent codes through the generator G, then through normalize_out_to_0_1.\"\"\"\n",
    "    return normalize_out_to_0_1(G(latents, c=None))\n",
    "\n",
    "\n",
    "for step in tqdm(range(CONFIG.MAX_STEPS)):\n",
    "\n",
    "    f1.train(True)\n",
    "    X1 = data_samplers[0].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y1_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    X2 = data_samplers[1].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y2_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    # Draw latent samples for both marginals inside the `evaluating(f1)` context.\n",
    "    with evaluating(f1):\n",
    "        Y1 = sample_langevin_mu_f(f1, cost_grad_image_latent, X1, Y1_init, CONFIG, latent2data_gen=G)\n",
    "        Y2 = sample_langevin_mu_f(neg_f1, cost_grad_image_latent, X2, Y2_init, CONFIG, latent2data_gen=G)\n",
    "\n",
    "    # Weighted difference of the mean potential values on the decoded samples.\n",
    "    loss = alpha_1 * f1(decode(Y1)).mean() - alpha_2 * f1(decode(Y2)).mean()\n",
    "    f1_opt.zero_grad()\n",
    "    loss.backward()\n",
    "    f1_opt.step()\n",
    "\n",
    "    SMDS.SM.upd('loss', loss.item())\n",
    "    SMDS.epoch()\n",
    "    wandb.log({\"loss train\": loss.item()}, step=step)\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        if step % EVAL_EVERY == 0:\n",
    "            X1_eval = data_samplers[0].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X2_eval = data_samplers[1].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            with evaluating(f1):\n",
    "                map_1 = []\n",
    "                map_2 = []\n",
    "                # Several independent sampling runs for the same inputs.\n",
    "                for run in range(CONFIG.NUM_TEST_RUNS):\n",
    "                    Y1_init = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                    Y2_init = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                    Y1 = sample_langevin_mu_f(f1, cost_grad_image_latent, X1_eval, Y1_init, CONFIG, latent2data_gen=G)\n",
    "                    Y2 = sample_langevin_mu_f(neg_f1, cost_grad_image_latent, X2_eval, Y2_init, CONFIG, latent2data_gen=G)\n",
    "                    map_1.append(decode(Y1))\n",
    "                    map_2.append(decode(Y2))\n",
    "\n",
    "            # Pass the same count that was used for sampling instead of a separate literal.\n",
    "            plot_barycenter_map_in_data_space(X1_eval, X2_eval, map_1, map_2, step=step,\n",
    "                                              n_estimate_points=N_ESTIMATE_POINTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c9e1a17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
