{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9ebe1d4",
   "metadata": {},
   "source": [
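    "# Continuous entropic barycenter estimation of MNIST digits 0 and 1 in data space\n",
    "\n",
    "Trains a single potential $f_1$ (with the second potential tied as $f_2 = -f_1$) whose conditional Langevin samplers map MNIST 0s and 1s towards their equal-weight entropic barycenter, directly in pixel space."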
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af3d971b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.distributions as TD\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "import os\n",
    "import sys \n",
    "sys.path.append(\"..\")\n",
    "from src.utils import Config, Distrib2Sampler, plot_barycenter_map_in_data_space\n",
    "from src.eot_utils import computePotGrad, evaluating\n",
    "from src.eot import sample_langevin_batch\n",
    "from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler\n",
    "from src.cost import cond_score, cost_l2_grad_y\n",
    "from src.distributions import DatasetSampler\n",
    "from src.resnet2 import  ResNet_D, weights_init_D\n",
    "\n",
    "from typing import Callable, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "439f3a20",
   "metadata": {},
   "source": [
    "## 1. Parameters for Papermill"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6c0811",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Papermill injects overrides right after this cell (it is tagged \"parameters\").\n",
    "GPU_DEVICES = [0]   # CUDA device indices; the first one determines CONFIG.DEVICE\n",
    "EPS = 0.01          # entropic regularisation coefficient, stored as CONFIG.HREG\n",
    "LR = 1e-4           # Adam learning rate of the potential network\n",
    "ENERGY_ITRS = 10    # Langevin steps per sampling call\n",
    "BATCH_SIZE = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56790093",
   "metadata": {},
   "source": [
    "## 2. Create Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e988271e",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = Config()\n",
    "\n",
    "# Data: MNIST restricted to CONFIG.CLASSES, one class per barycenter marginal\n",
    "CONFIG.DATASET = 'mnist01'\n",
    "CONFIG.DATASET_PATH = './data/train_MNIST'\n",
    "CONFIG.BATCH_SIZE = BATCH_SIZE\n",
    "CONFIG.CLASSES = [0, 1]\n",
    "CONFIG.IMG_SIZE = 32\n",
    "CONFIG.NC = 1\n",
    "\n",
    "# Barycenter problem; the training loop below handles exactly two marginals (K = 2)\n",
    "CONFIG.ALPHAS_BARYCENTER = [.5, .5]\n",
    "CONFIG.MAX_STEPS = 1000\n",
    "CONFIG.K = len(CONFIG.ALPHAS_BARYCENTER)\n",
    "CONFIG.HREG = EPS\n",
    "\n",
    "# Optimiser\n",
    "CONFIG.LR = LR\n",
    "CONFIG.CLIP_GRADS_NORM = False\n",
    "CONFIG.BETAS = (0.2, 0.99)\n",
    "\n",
    "# Langevin sampler\n",
    "CONFIG.LANGEVIN_THRESH = None\n",
    "CONFIG.LANGEVIN_SAMPLING_NOISE = 0.1\n",
    "CONFIG.ENERGY_SAMPLING_ITERATIONS = ENERGY_ITRS\n",
    "CONFIG.LANGEVIN_DECAY = 1.0\n",
    "CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0\n",
    "CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0\n",
    "\n",
    "# Passed as the *scale* (standard deviation) of the Gaussian initial noise, despite the name\n",
    "CONFIG.BASIC_NOISE_VAR = 2.0\n",
    "# Fall back to CPU so the notebook still runs on machines without CUDA\n",
    "CONFIG.DEVICE = f\"cuda:{GPU_DEVICES[0]}\" if torch.cuda.is_available() else \"cpu\"\n",
    "CONFIG.DEVICES_IDS = GPU_DEVICES\n",
    "\n",
    "CONFIG.NUM_TEST_RUNS = 4\n",
    "\n",
    "# Seed the torch and NumPy RNGs so sampling and training are reproducible\n",
    "CONFIG.SEED = 0\n",
    "torch.manual_seed(CONFIG.SEED)\n",
    "np.random.seed(CONFIG.SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db228842",
   "metadata": {},
   "source": [
    "## 3. Data samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d3e7cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize((CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Lambda(lambda x: 2 * x - 1)  # rescale ToTensor's [0, 1] range to [-1, 1]\n",
    "])\n",
    "\n",
    "# One sampler per barycenter marginal: marginal k holds only the images of digit CONFIG.CLASSES[k].\n",
    "data_samplers = []\n",
    "for k in range(CONFIG.K):\n",
    "    dataset = torchvision.datasets.MNIST(root=CONFIG.DATASET_PATH, download=True,\n",
    "                                         transform=transform)\n",
    "    # Vectorised boolean mask instead of one Python-level comparison per target; indexing\n",
    "    # the existing tensor directly also avoids the torch.tensor(tensor) copy warning.\n",
    "    mask = dataset.targets == CONFIG.CLASSES[k]\n",
    "    # targets are stored as a NumPy int array, the type DatasetSampler has been given so far\n",
    "    dataset.targets = dataset.targets[mask].numpy()\n",
    "    dataset.data = dataset.data[mask]\n",
    "    data_samplers.append(DatasetSampler(dataset, flag_label=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0915b95f",
   "metadata": {},
   "source": [
    "## 4. Potentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d0984c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single potential network f1. The second potential is tied to it as f2 = -f1\n",
    "# (see the training loop), so only f1 has parameters and an optimiser.\n",
    "f1 = ResNet_D(\n",
    "    size=CONFIG.IMG_SIZE,\n",
    "    nc=CONFIG.NC,\n",
    "    nfilter=64,\n",
    "    nfilter_max=512,\n",
    "    res_ratio=0.1,\n",
    ").to(CONFIG.DEVICE)\n",
    "weights_init_D(f1)\n",
    "\n",
    "f1_opt = torch.optim.Adam(f1.parameters(), CONFIG.LR, betas=CONFIG.BETAS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d286785",
   "metadata": {},
   "source": [
    "## 5. Conditional Langevin sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7bb172",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_langevin_mu_f(\n",
    "    f: Callable[[torch.Tensor], torch.Tensor],\n",
    "    x: torch.Tensor,\n",
    "    y_init: torch.Tensor,\n",
    "    config: Config,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Run Langevin chains for y, conditioned on the source batch x.\n",
    "\n",
    "    The score handed to ``sample_langevin_batch`` is\n",
    "    ``cond_score(f, cost_l2_grad_y, y, x, config)``.\n",
    "\n",
    "    Args:\n",
    "        f: potential (the network itself or a wrapper such as ``lambda y: -f1(y)``).\n",
    "        x: batch of source samples the chains are conditioned on.\n",
    "        y_init: starting points of the chains.\n",
    "        config: supplies step count, decay, threshold and noise level.\n",
    "\n",
    "    Returns:\n",
    "        The final chain states. The diagnostics also returned by\n",
    "        ``sample_langevin_batch`` are discarded.\n",
    "    \"\"\"\n",
    "    def conditional_score(y, ret_stats=False):\n",
    "        return cond_score(f, cost_l2_grad_y, y, x, config, ret_stats=ret_stats)\n",
    "\n",
    "    y_final, _r_t, _cost_r_t, _score_r_t, _noise_norm = sample_langevin_batch(\n",
    "        conditional_score,\n",
    "        y_init,\n",
    "        n_steps=config.ENERGY_SAMPLING_ITERATIONS,\n",
    "        decay=config.LANGEVIN_DECAY,\n",
    "        thresh=config.LANGEVIN_THRESH,\n",
    "        noise=config.LANGEVIN_SAMPLING_NOISE,\n",
    "        data_projector=lambda y_proj: y_proj,  # identity projector\n",
    "        compute_stats=True,\n",
    "    )\n",
    "    return y_final"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9b723cb",
   "metadata": {},
   "source": [
    "## 6. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf6bfa9c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The run name encodes the main hyper-parameters of this configuration.\n",
    "run_name = f\"EPS_{CONFIG.HREG}_LR_{CONFIG.LR}_BS_{CONFIG.BATCH_SIZE}_NS_{CONFIG.ENERGY_SAMPLING_ITERATIONS}\"\n",
    "wandb.init(\n",
    "    project=\"MNIST_01_barycenter_in_data_space\",\n",
    "    name=run_name,\n",
    "    config=CONFIG,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86015b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian noise used as the starting point of the Langevin chains.\n",
    "# NOTE: CONFIG.BASIC_NOISE_VAR is passed as the *scale* (std) argument of TD.Normal.\n",
    "noise_shape = (CONFIG.NC, CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)\n",
    "noise_loc = torch.zeros(noise_shape, device='cpu')\n",
    "noise_scale = torch.full(noise_shape, CONFIG.BASIC_NOISE_VAR, device='cpu')\n",
    "init_noise_sampler = Distrib2Sampler(TD.Normal(noise_loc, noise_scale))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db386c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=20)\n",
    "\n",
    "PLOT_EVERY = 50        # draw the barycenter maps every PLOT_EVERY training steps\n",
    "N_ESTIMATE_POINTS = 8  # source images per marginal shown in each plot\n",
    "\n",
    "for step in tqdm(range(CONFIG.MAX_STEPS)):\n",
    "    # --- one gradient step on the shared potential f1 (the second potential is -f1) ---\n",
    "    f1.train(True)\n",
    "    X1 = data_samplers[0].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    X2 = data_samplers[1].sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    Y1_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "    Y2_init = init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(CONFIG.DEVICE)\n",
    "\n",
    "    # Langevin samples for both marginals; `evaluating` presumably puts f1 in eval mode meanwhile.\n",
    "    with evaluating(f1):\n",
    "        Y1 = sample_langevin_mu_f(lambda y: f1(y), X1, Y1_init, CONFIG)\n",
    "        Y2 = sample_langevin_mu_f(lambda y: -f1(y), X2, Y2_init, CONFIG)\n",
    "\n",
    "    # alpha_1 * E[f1(Y1)] + alpha_2 * E[f2(Y2)] with f2 = -f1\n",
    "    loss = CONFIG.ALPHAS_BARYCENTER[0] * f1(Y1).mean() - CONFIG.ALPHAS_BARYCENTER[1] * f1(Y2).mean()\n",
    "    f1_opt.zero_grad()\n",
    "    loss.backward()\n",
    "    f1_opt.step()\n",
    "\n",
    "    loss_value = loss.item()\n",
    "    SMDS.SM.upd('loss', loss_value)\n",
    "    SMDS.epoch()\n",
    "    wandb.log({\"loss train\": loss_value}, step=step)\n",
    "\n",
    "    # --- periodic visualisation on a small fresh batch; *_eval names keep the training batch intact ---\n",
    "    if step % PLOT_EVERY == 0:\n",
    "        # NOTE(review): Langevin sampling needs gradients w.r.t. y; running it under no_grad\n",
    "        # relies on the src samplers re-enabling grad internally -- confirm.\n",
    "        with torch.no_grad():\n",
    "            X1_eval = data_samplers[0].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            X2_eval = data_samplers[1].sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "            map_1, map_2 = [], []\n",
    "            with evaluating(f1):\n",
    "                # CONFIG.NUM_TEST_RUNS independent chains from fresh noise for the same sources\n",
    "                for _ in range(CONFIG.NUM_TEST_RUNS):\n",
    "                    Y1_init_eval = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                    Y2_init_eval = init_noise_sampler.sample(N_ESTIMATE_POINTS).to(CONFIG.DEVICE)\n",
    "                    map_1.append(sample_langevin_mu_f(lambda y: f1(y), X1_eval, Y1_init_eval, CONFIG))\n",
    "                    map_2.append(sample_langevin_mu_f(lambda y: -f1(y), X2_eval, Y2_init_eval, CONFIG))\n",
    "            plot_barycenter_map_in_data_space(X1_eval, X2_eval, map_1, map_2,\n",
    "                                              step=step, n_estimate_points=N_ESTIMATE_POINTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2228ef53",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
