{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a63d0b8c-bede-441e-b854-8a527702e471",
   "metadata": {},
   "source": [
    "# Bridge-matching MI estimator + images (finite samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebc4c6b2-927a-49a5-b849-29fecadac4fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import wandb\n",
    "\n",
    "import torchvision\n",
    "import torch\n",
    "\n",
    "import torchvision.transforms.functional as F\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append('../python')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f46b62f-7e39-41ef-9dcf-06c881915ce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import wandb\n",
    "\n",
    "import torchvision\n",
    "import torch\n",
    "\n",
    "import torchvision.transforms.functional as F\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm as tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c910291-2aa1-4806-a608-783f6b4a7c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from utils.plots import *\n",
    "# from utils.tests import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edaf7b4d-24c6-454f-a247-4c5905baf547",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# from utils.plots import *\n",
    "# from utils.tests import *\n",
    "\n",
    "font = {'size' : 16}\n",
    "matplotlib.rc('font', **font)\n",
    "\n",
    "from mutinfo.distributions.base import CorrelatedNormal, CorrelatedStudent, CorrelatedUniform,  SmoothedUniform, UniformlyQuantized\n",
    "from mutinfo.distributions.tools import mapped_multi_rv_frozen\n",
    "from mutinfo.distributions.images.geometric import uniform_to_rectangle, draw_rectangle\n",
    "from mutinfo.distributions.images.field import draw_field, symmetric_gaussian_field\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f879820c-cb41-41e4-a566-76be1e6c45d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "MI_grid = np.linspace(0.0, 10.0, 11)\n",
    "n_samples = 1000\n",
    "n_runs = 10\n",
    "\n",
    "image_shape = (16, 16)\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bec5d239-c84b-48cd-90cf-b0ac89b9e8af",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5db2dd3f-eb0b-4c77-9e94-7ac526f9f3b8",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "eps = 1\n",
    "\n",
    "image_size = 16\n",
    "\n",
    "mi_gt = 10\n",
    "\n",
    "shape = 'gaussian'\n",
    "\n",
    "# shape = 'rectangle'\n",
    "\n",
    "n_features = 256\n",
    "\n",
    "layers_per_block = 2\n",
    "\n",
    "batch_size = 128\n",
    "\n",
    "max_iter = 50000\n",
    "eval_freq = 2000\n",
    "\n",
    "loss_weight = 'sum'\n",
    "\n",
    "predict_type = 'vector_field'\n",
    "\n",
    "lr = 1e-4\n",
    "# ckpt_path = None\n",
    "ind_opt_steps = 1\n",
    "\n",
    "t_alpha = None\n",
    "t_beta = None\n",
    "\n",
    "seed = 42\n",
    "\n",
    "train_set_size = 10000\n",
    "\n",
    "val_set_size = 10000\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c6c2ab-dd26-4abe-83ac-02b4b0ca676d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# ckpt_path = ''\n",
    "# device = 'cuda:0'\n",
    "\n",
    "# t_alpha = 0.8\n",
    "# t_beta = 0.4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3cb415b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Seed both NumPy and PyTorch so data generation and training are repeatable.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Fall back to the CPU so the notebook still runs on machines without CUDA.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "image_shape = (image_size, image_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5319f89-358b-4136-bc3f-992c443dfb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "wandb_config = {'image size': image_size, 'mi': mi_gt,\n",
    "              'eps': eps, 'shape': shape, 'n_features': n_features,\n",
    "              'max_iter': max_iter, 'eval_freq': eval_freq, 'predict_type': predict_type, 'ind_opt_steps': ind_opt_steps,\n",
    "              'lr': lr, 't_alpha': t_alpha, 't_beta': t_beta, 'loss_weight': loss_weight, 'seed': seed,\n",
    "               'train_set_size': train_set_size, 'val_set_size': val_set_size}\n",
    "\n",
    "wandb.init(project=\"Bridge_MI_Image_finite_samples\", name=f\"{shape}_{image_size}_mi_{mi_gt}_eps_{eps}_predict_{predict_type}_lr_{lr}_loss_weight_{loss_weight}_beta_dist_{t_alpha}\", config=wandb_config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864ae1e6-751c-43ab-a051-542a71d26ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# base_dist = CorrelatedUniform(mi_gt, 2, 2)\n",
    "\n",
    "# Pick the latent dimensionality and the latent -> image renderer for the requested shape.\n",
    "# CorrelatedUniform(mi, d, d): presumably d-dimensional latents for x and y -- confirm against mutinfo.\n",
    "if shape == 'gaussian':\n",
    "    latent_dim = 2\n",
    "\n",
    "    def render_images(latent):\n",
    "        return draw_field(latent, symmetric_gaussian_field, image_shape)\n",
    "\n",
    "elif shape == 'rectangle':\n",
    "    latent_dim = 4\n",
    "\n",
    "    def render_images(latent):\n",
    "        return draw_rectangle(uniform_to_rectangle(latent, min_size=(0.2, 0.2)), image_shape)\n",
    "\n",
    "else:\n",
    "    # A misspelled shape used to fall through to the rectangle branch silently.\n",
    "    raise ValueError(f'Unknown shape: {shape}')\n",
    "\n",
    "base_dist = CorrelatedUniform(mi_gt, latent_dim, latent_dim, randomize_interactions=False,\n",
    "                              shuffle_interactions=True)\n",
    "MI_pictures_obj = mapped_multi_rv_frozen(\n",
    "        base_dist,\n",
    "        lambda x, y: (render_images(x), render_images(y))\n",
    "    )\n",
    "\n",
    "\n",
    "# MI_pictures_obj = mapped_multi_rv_frozen(\n",
    "#         base_dist,\n",
    "#         lambda x, y: (\n",
    "#             draw_field(x, symmetric_gaussian_field, image_shape),\n",
    "#             draw_field(y, symmetric_gaussian_field, image_shape)\n",
    "#         )\n",
    "#     )\n",
    "\n",
    "x, y = MI_pictures_obj.rvs(1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfdcbc8-2655-4ca7-8782-f95640bb897a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "x_data, y_data = MI_pictures_obj.rvs(train_set_size)\n",
    "\n",
    "x_data_val, y_data_val = MI_pictures_obj.rvs(val_set_size)\n",
    "\n",
    "print(x_data.shape, y_data.shape)\n",
    "\n",
    "dataset = torch.utils.data.TensorDataset(torch.Tensor(x_data), torch.Tensor(y_data))\n",
    "\n",
    "dataset_val = torch.utils.data.TensorDataset(torch.Tensor(x_data_val), torch.Tensor(y_data_val))\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e6c3328-de41-4fe7-aff9-23deada60385",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def show(imgs):\n",
    "    \"\"\"Plot one image tensor, or a list of them, in a single row without axis ticks.\"\"\"\n",
    "    images = imgs if isinstance(imgs, list) else [imgs]\n",
    "    _, axes = plt.subplots(ncols=len(images), squeeze=False)\n",
    "    for ax, tensor in zip(axes[0], images):\n",
    "        pil_image = F.to_pil_image(tensor.detach())\n",
    "        ax.imshow(np.asarray(pil_image))\n",
    "        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
    "\n",
    "\n",
    "x_batch, y_batch = MI_pictures_obj.rvs(16)\n",
    "\n",
    "grid_x = torchvision.utils.make_grid(torch.Tensor(x_batch).unsqueeze(1))\n",
    "\n",
    "show(grid_x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c4f1e9-99b1-4154-b9ae-76dd96af134b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "grid_y = torchvision.utils.make_grid(torch.Tensor(y_batch).unsqueeze(1))\n",
    "\n",
    "show(grid_y)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d933dc1a-e16a-4a37-b8ac-e3cd2b8025a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def mmd(x, y):\n",
    "    \"\"\"Biased (minimum-variance) MMD estimate with a Gaussian RBF kernel, in PyTorch.\n",
    "\n",
    "    This implements the minimum-variance/biased version of the estimator described\n",
    "    in Eq.(5) of\n",
    "    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf.\n",
    "    As described in Lemma 6's proof in that paper, the unbiased estimate and the\n",
    "    minimum-variance estimate for MMD are almost identical.\n",
    "\n",
    "    Args:\n",
    "      x: The first set of embeddings, a tensor of shape (n, embedding_dim).\n",
    "      y: The second set of embeddings, a tensor of shape (m, embedding_dim).\n",
    "\n",
    "    Returns:\n",
    "      A scalar tensor: mean k(x, x) + mean k(y, y) - 2 * mean k(x, y), times _SCALE.\n",
    "    \"\"\"\n",
    "    # The bandwidth parameter for the Gaussian RBF kernel. See the paper for more\n",
    "    # details.\n",
    "    _SIGMA = 10\n",
    "    # The following is used to make the metric more human readable.\n",
    "    _SCALE = 1000\n",
    "\n",
    "    gamma = 1 / (2 * _SIGMA**2)\n",
    "\n",
    "    # Row-wise squared norms; avoids building a full Gram matrix just to read its diagonal.\n",
    "    x_sqnorms = (x * x).sum(dim=-1)\n",
    "    y_sqnorms = (y * y).sum(dim=-1)\n",
    "\n",
    "    def mean_kernel(a, b, a_sqnorms, b_sqnorms):\n",
    "        # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>\n",
    "        sq_dists = -2 * torch.matmul(a, b.T) + torch.unsqueeze(a_sqnorms, 1) + torch.unsqueeze(b_sqnorms, 0)\n",
    "        return torch.mean(torch.exp(-gamma * sq_dists))\n",
    "\n",
    "    k_xx = mean_kernel(x, x, x_sqnorms, x_sqnorms)\n",
    "    k_xy = mean_kernel(x, y, x_sqnorms, y_sqnorms)\n",
    "    k_yy = mean_kernel(y, y, y_sqnorms, y_sqnorms)\n",
    "\n",
    "    return _SCALE * (k_xx + k_yy - 2 * k_xy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf92be0f-43d0-467d-ba7d-88c9dfb3abdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from tqdm import tqdm as tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class BridgeNoiseScheduler:\n",
    "    \"\"\"Interface for bridge noise schedules; subclasses implement all three methods.\"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        pass\n",
    "\n",
    "    def beta(self, t) -> torch.Tensor:\n",
    "        raise NotImplementedError('Abstract class')\n",
    "\n",
    "    def sigma(self, t) -> torch.Tensor:\n",
    "        raise NotImplementedError('Abstract class')\n",
    "    \n",
    "    def sigma_overlined(self, t) -> torch.Tensor:\n",
    "        raise NotImplementedError('Abstract class')\n",
    "    \n",
    "class LinearBridgeNoiseScheduler(BridgeNoiseScheduler):\n",
    "    \"\"\"Triangular schedule: beta rises linearly on [0, 0.5] and falls linearly on [0.5, 1].\n",
    "\n",
    "    Args:\n",
    "        max_beta: value of ``beta`` at its peak, t = 0.5. It is stored doubled because\n",
    "            ``beta`` multiplies it by a factor that is at most 0.5.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_beta=1.5e-4) -> None:\n",
    "        super().__init__()\n",
    "        self.max_beta = 2 * max_beta\n",
    "\n",
    "    def beta(self, t) -> torch.Tensor:\n",
    "        return (0.5 - torch.abs(t - 0.5)) * self.max_beta\n",
    "\n",
    "    def _intergrate_beta(self, t_1, t_2):\n",
    "        \"\"\"Closed-form integral of ``beta`` over [t_1, t_2] for tensors with values in [0, 1].\"\"\"\n",
    "\n",
    "        def antiderivative(t):\n",
    "            # Rising half contributes s^2 / 2 up to min(t, 0.5); the falling half contributes\n",
    "            # -(1 - s)^2 / 2 from 0.5 up to max(t, 0.5). Constant offsets cancel in the difference.\n",
    "            rising = torch.clamp(t, max=0.5) ** 2 / 2\n",
    "            falling = -(1 - torch.clamp(t, min=0.5)) ** 2 / 2\n",
    "            return rising + falling\n",
    "\n",
    "        return (antiderivative(t_2) - antiderivative(t_1)) * self.max_beta\n",
    "    \n",
    "    def sigma(self, t) -> torch.Tensor:\n",
    "        \"\"\"Square root of the integral of ``beta`` over [0, t].\"\"\"\n",
    "        return torch.sqrt(self._intergrate_beta(torch.zeros_like(t), t))\n",
    "    \n",
    "    def sigma_overlined(self, t) -> torch.Tensor:\n",
    "        \"\"\"Square root of the integral of ``beta`` over [t, 1].\"\"\"\n",
    "        return torch.sqrt(self._intergrate_beta(t, torch.ones_like(t)))\n",
    "\n",
    "class BridgeMathcing(nn.Module):\n",
    "    \"\"\"Bridge-matching training and sampling logic around a drift network.\n",
    "\n",
    "    The wrapped network is always called as ``unet(x_t, x_0, t)``.\n",
    "\n",
    "    Args:\n",
    "        unet: callable network, stored as ``self.vector_net``.\n",
    "        eps: noise level of the bridge; scales the variance in ``sample_x_t`` and in the samplers.\n",
    "        predict_type: regression target of the network: 'vector_field', 'x_1' or 'noise'.\n",
    "        loss_weight: falsy for no re-weighting, otherwise 'mean' or 'sum'. Only read by the 'x_1' loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unet, eps, predict_type='vector_field', loss_weight=False):\n",
    "        super().__init__()\n",
    "\n",
    "        assert predict_type in ['vector_field', 'x_1', 'noise']\n",
    "\n",
    "        self.loss_weight = loss_weight\n",
    "        self.predict_type = predict_type\n",
    "        self.vector_net = unet\n",
    "        self.eps = eps\n",
    "\n",
    "    def _euler_maruyama(self, x_0, nfe, pbar):\n",
    "        \"\"\"Euler steps of size ``1 / nfe`` from t = 0, using the network output as the drift.\"\"\"\n",
    "        euler_dt = 1. / nfe\n",
    "        noise_scale = math.sqrt(euler_dt * self.eps)\n",
    "\n",
    "        t_range = torch.arange(0, 1, step=euler_dt).to(x_0.device)\n",
    "        if pbar:\n",
    "            t_range = tqdm(t_range)\n",
    "\n",
    "        x_t = x_0\n",
    "        for t in t_range:\n",
    "            eps_noise = torch.randn_like(x_t, device=x_0.device)\n",
    "            x_t = x_t + self.vector_net(x_t, x_0, t.squeeze()) * euler_dt + noise_scale * eps_noise\n",
    "\n",
    "        return x_t\n",
    "\n",
    "    def forward(self, x_0, nfe=100, pbar=True):\n",
    "        \"\"\"Same rollout as ``sample`` but with gradients enabled.\n",
    "\n",
    "        The step size is derived from ``nfe``; no ``euler_dt`` attribute is required on the instance.\n",
    "        \"\"\"\n",
    "        return self._euler_maruyama(x_0, nfe, pbar)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample(self, x_0, nfe=100, pbar=True):\n",
    "        \"\"\"Roll ``x_0`` forward with ``nfe`` Euler steps without tracking gradients.\"\"\"\n",
    "        return self._euler_maruyama(x_0, nfe, pbar)\n",
    "\n",
    "    def sample_x_t(self, x_0, x_1, t):\n",
    "        \"\"\"Draw ``x_t = (1 - t) x_0 + t x_1 + sqrt(t (1 - t) eps) z`` and return ``(x_t, z)``.\"\"\"\n",
    "        coef_0, coef_1 = 1 - t, t\n",
    "\n",
    "        std_t = torch.sqrt(t * (1 - t) * self.eps)\n",
    "\n",
    "        z = torch.randn_like(x_0, device=x_0.device)\n",
    "\n",
    "        # Coefficients are reshaped to (batch, 1, 1, 1) so they broadcast over 4-d image batches.\n",
    "        x_t = coef_1.reshape([-1, 1, 1, 1]) * x_1 + coef_0.reshape([-1, 1, 1, 1]) * x_0 + z * std_t.reshape([-1, 1, 1, 1])\n",
    "        return x_t, z\n",
    "\n",
    "    def step(self, x_0, x_1, t):\n",
    "        \"\"\"Return the batch-mean training loss for the pair ``(x_0, x_1)`` at times ``t``.\"\"\"\n",
    "        t = t.reshape([-1, 1, 1, 1])\n",
    "        x_t, z = self.sample_x_t(x_0, x_1, t)\n",
    "        x_t_hat = self.vector_net(x_t, x_0, t.squeeze())\n",
    "        return self.loss(x_t_hat, x_1, x_0, x_t, t, z).mean()\n",
    "\n",
    "    def loss(self, x_t_hat, x_1, x_0, x_t, t, z):\n",
    "        \"\"\"Per-sample L2 norm between the network output and the target chosen by ``predict_type``.\n",
    "\n",
    "        NOTE(review): this is the norm itself, not its square -- confirm that is intended.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: for 'x_1' with a truthy ``loss_weight`` other than 'mean' or 'sum'.\n",
    "        \"\"\"\n",
    "        batch = x_1.shape[0]\n",
    "\n",
    "        def per_sample_norm(residual):\n",
    "            return torch.norm(residual.reshape([batch, -1]), dim=-1)\n",
    "\n",
    "        if self.predict_type == 'x_1':\n",
    "\n",
    "            if self.loss_weight:\n",
    "\n",
    "                # The weights are treated as constants: no gradient flows through them.\n",
    "                with torch.no_grad():\n",
    "\n",
    "                    target_field = ((x_1 - x_t) / (1 - t)).reshape([batch, -1])\n",
    "                    vector_field = (x_t_hat - (x_1 - x_t) / (1 - t)).reshape([batch, -1])\n",
    "\n",
    "                    z_coef = torch.abs((vector_field**2 - (target_field * vector_field)).mean(-1))\n",
    "\n",
    "                    if self.loss_weight == 'mean':\n",
    "                        norm_const = z_coef.mean()\n",
    "                    elif self.loss_weight == 'sum':\n",
    "                        norm_const = z_coef.sum()\n",
    "                    else:\n",
    "                        raise RuntimeError('Unknown loss weight')\n",
    "\n",
    "                return per_sample_norm(x_t_hat - x_1) * z_coef / norm_const\n",
    "\n",
    "            return per_sample_norm(x_t_hat - x_1)\n",
    "\n",
    "        if self.predict_type == 'vector_field':\n",
    "\n",
    "            return per_sample_norm(x_t_hat - (x_1 - x_t) / (1 - t))\n",
    "\n",
    "        if self.predict_type == 'noise':\n",
    "\n",
    "            return per_sample_norm(x_t_hat - z)\n",
    "\n",
    "import math\n",
    "\n",
    "from copy import deepcopy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd74753-2c8d-4368-acf3-73eda79bf15d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def sample_posterior(vector_net, eps, x_0, nfe=100, pbar=False, posterior='partial', predict_type='vector_field'):\n",
    "    \"\"\"Sample an end point by repeatedly predicting ``x_1`` and re-sampling the next state.\n",
    "\n",
    "    Args:\n",
    "        vector_net: network called as ``vector_net(x_t, x_0, t)``.\n",
    "        eps: noise level of the bridge.\n",
    "        x_0: starting batch; also fixes the device of the time grid.\n",
    "        nfe: the time step is ``1 / nfe``.\n",
    "        pbar: wrap the time grid in ``tqdm``.\n",
    "        posterior: 'partial' re-samples between the current state and the predicted ``x_1``;\n",
    "            'full' re-samples between ``x_0`` and the predicted ``x_1``.\n",
    "        predict_type: how the network output is turned into an ``x_1`` prediction:\n",
    "            'vector_field', 'x_1' or 'noise'.\n",
    "\n",
    "    Returns:\n",
    "        The state after the last step, same shape as ``x_0``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``posterior`` or ``predict_type`` is not one of the values above.\n",
    "    \"\"\"\n",
    "    if posterior not in ('partial', 'full'):\n",
    "        raise ValueError(f'Unknown posterior: {posterior}')\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        raise ValueError(f'Unknown predict_type: {predict_type}')\n",
    "\n",
    "    euler_dt = 1. / nfe\n",
    "\n",
    "    # The 'noise' parameterisation divides by t, so its grid starts away from 0.\n",
    "    t_start = 0.01 if predict_type == 'noise' else 0\n",
    "    t_range = torch.arange(t_start, 1, step=euler_dt).to(x_0.device)\n",
    "    # Float rounding in arange can append an extra point at t ~ 1, where 1 - t vanishes; drop it.\n",
    "    t_range = t_range[t_range < 1 - euler_dt / 2]\n",
    "\n",
    "    if pbar:\n",
    "        t_range = tqdm(t_range)\n",
    "\n",
    "    x_t = x_0\n",
    "\n",
    "    for t in t_range:\n",
    "\n",
    "        t_next = t + euler_dt\n",
    "\n",
    "        if predict_type == 'vector_field':\n",
    "            x_1_predict = x_t + (1 - t) * vector_net(x_t, x_0, t.squeeze())\n",
    "        elif predict_type == 'x_1':\n",
    "            x_1_predict = vector_net(x_t, x_0, t.squeeze())\n",
    "        else:\n",
    "            z_predict = vector_net(x_t, x_0, t.squeeze())\n",
    "            x_1_predict = (x_t - (1 - t) * x_0 - torch.sqrt(t * (1 - t) * eps) * z_predict) / t\n",
    "\n",
    "        if posterior == 'partial':\n",
    "            variance = eps * (t_next - t) * (1 - t_next) / (1 - t)\n",
    "            keep = (1 - t_next) / (1 - t)\n",
    "            mean_x_t = x_t * keep + x_1_predict * (1 - keep)\n",
    "        else:\n",
    "            variance = eps * (t_next) * (1 - t_next)\n",
    "            mean_x_t = x_0 * (1 - t_next) + x_1_predict * (t_next)\n",
    "\n",
    "        # t_next can overshoot 1 by a rounding error, which would make the variance negative.\n",
    "        std = torch.sqrt(torch.clamp(variance, min=0))\n",
    "\n",
    "        x_t = mean_x_t + std * torch.randn_like(x_t)\n",
    "\n",
    "    return x_t\n",
    "\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_kl(drift_1, drift_2, eps, dataloader, nfe=100, t_eps=1e-5, posterior='partial', iters=10, predict_type='vector_field'):\n",
    "    \"\"\"Monte-Carlo estimate of ``1 / (2 * eps) * E ||v_1 - v_2||^2`` over ``dataloader``.\n",
    "\n",
    "    For every ``(x, y)`` batch one time per sample is drawn uniformly from\n",
    "    ``[t_eps, 1 - t_eps)``, ``x_t`` is drawn from the bridge between ``x`` and ``y``, and both\n",
    "    networks are evaluated at ``(x_t, x, t)``. Their outputs are converted to drifts according\n",
    "    to ``predict_type`` before the squared difference is summed over the last two axes.\n",
    "\n",
    "    Args:\n",
    "        drift_1, drift_2: networks called as ``net(x_t, x_0, t)``.\n",
    "        eps: noise level of the bridge.\n",
    "        dataloader: iterable of ``(x, y)`` batches without a channel axis (one is inserted here).\n",
    "        t_eps: margin kept away from t = 0 and t = 1.\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise'.\n",
    "        nfe, posterior, iters: not used by the estimate; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor with the estimate averaged over batches (it is also printed).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: on an unknown ``predict_type`` or an empty ``dataloader``.\n",
    "    \"\"\"\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        raise ValueError(f'Unknown predict_type: {predict_type}')\n",
    "\n",
    "    def BB_sample(x_1, x_0, t):\n",
    "        mean = x_1 * t + x_0 * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "\n",
    "        return mean + std * torch.randn_like(x_1)\n",
    "\n",
    "    kl_value = 0\n",
    "    n_iters = 0\n",
    "\n",
    "    for x_batch, y_batch in dataloader:\n",
    "\n",
    "        n_iters += 1\n",
    "\n",
    "        # No posterior sampling here: a generated batch would not enter the estimate below and\n",
    "        # would cost nfe network evaluations per batch.\n",
    "        x_batch = torch.Tensor(x_batch).unsqueeze(1).to(device)\n",
    "        y_batch = torch.Tensor(y_batch).unsqueeze(1).to(device)\n",
    "\n",
    "        t = (torch.rand(x_batch.shape[0]).to(device)) * (1 - 2 * t_eps) + t_eps\n",
    "        t = t.reshape([-1, 1, 1, 1])\n",
    "\n",
    "        x_t = BB_sample(y_batch, x_batch, t)\n",
    "\n",
    "        out_1, out_2 = drift_1(x_t, x_batch, t.squeeze()), drift_2(x_t, x_batch, t.squeeze())\n",
    "\n",
    "        if predict_type == 'vector_field':\n",
    "\n",
    "            v_1, v_2 = out_1, out_2\n",
    "\n",
    "        elif predict_type == 'x_1':\n",
    "\n",
    "            v_1, v_2 = (out_1 - x_t) / (1 - t), (out_2 - x_t) / (1 - t)\n",
    "\n",
    "        else:\n",
    "\n",
    "            # Invert the bridge sample for x_1 given the predicted noise, then convert to a drift.\n",
    "            x_1_predict_1 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * out_1) / t\n",
    "            x_1_predict_2 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * out_2) / t\n",
    "\n",
    "            v_1, v_2 = (x_1_predict_1 - x_t) / (1 - t), (x_1_predict_2 - x_t) / (1 - t)\n",
    "\n",
    "        kl_value += ((v_1 - v_2)**2).sum([-1, -2]).mean()\n",
    "\n",
    "    if n_iters == 0:\n",
    "        raise ValueError('dataloader yielded no batches')\n",
    "\n",
    "    kl_estimate = 1 / (2 * eps) * kl_value / n_iters\n",
    "\n",
    "    print(kl_estimate)\n",
    "\n",
    "    return kl_estimate\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07e670d4-17d8-4e6b-9d54-beb06e9aa10d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def true_KL_approx(drift_net, eps, x_0, x_1, predict_type='vector_field', t_eps=1e-3):\n",
    "    \"\"\"One-batch estimate of ``2 * E[||v||^2 - <(x_1 - x_t) / (1 - t), v>]``.\n",
    "\n",
    "    ``v`` is the drift obtained from ``drift_net(x_t, x_0, t)`` according to ``predict_type``;\n",
    "    ``x_t`` is drawn from the bridge between ``x_0`` and ``x_1`` at one time per sample, uniform\n",
    "    on ``[t_eps, 1 - t_eps)``. Inner products sum over the last two axes; the expectation is\n",
    "    a mean over the remaining axes.\n",
    "\n",
    "    Args:\n",
    "        drift_net: network called as ``drift_net(x_t, x_0, t)``.\n",
    "        eps: noise level of the bridge.\n",
    "        x_0, x_1: paired 4-d batches on the same device.\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise'.\n",
    "        t_eps: margin kept away from t = 0 and t = 1.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: on an unknown ``predict_type``.\n",
    "    \"\"\"\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        raise ValueError(f'Unknown predict_type: {predict_type}')\n",
    "\n",
    "    def BB_sample(x_1, x_0, t):\n",
    "        mean = x_1 * t + x_0 * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "\n",
    "        return mean + std * torch.randn_like(x_1)\n",
    "\n",
    "    # Times live on the inputs' device rather than on a notebook-level global.\n",
    "    t = (torch.rand(x_1.shape[0]).to(x_1.device)) * (1 - 2 * t_eps) + t_eps\n",
    "    t = t.reshape([-1, 1, 1, 1])\n",
    "\n",
    "    x_t = BB_sample(x_1, x_0, t)\n",
    "\n",
    "    net_out = drift_net(x_t, x_0, t.squeeze())\n",
    "\n",
    "    if predict_type == 'vector_field':\n",
    "        drift = net_out\n",
    "    elif predict_type == 'noise':\n",
    "        # Invert the bridge sample for x_1 given the predicted noise, then convert to a drift.\n",
    "        x_1_predict = (x_t - (1 - t) * x_0 - torch.sqrt(t * (1 - t) * eps) * net_out) / t\n",
    "        drift = (x_1_predict - x_t) / (1 - t)\n",
    "    else:\n",
    "        drift = (net_out - x_t) / (1 - t)\n",
    "\n",
    "    first_part = 2 * ((drift)**2).sum([-1, -2]).mean()\n",
    "\n",
    "    second_part = 2 * ((x_1 - x_t) / (1 - t) * drift).sum([-1, -2]).mean()\n",
    "\n",
    "    return first_part - second_part\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1df07fb3-8df0-4b48-b318-77328d54af38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a neural network\n",
    "\n",
    "import diffusers\n",
    "\n",
    "from diffusers import UNet2DModel\n",
    "\n",
    "class VectorNet(nn.Module):\n",
    "    \"\"\"Adapter that feeds two image batches and a +1 / -1 class label to a shared network.\n",
    "\n",
    "    Args:\n",
    "        net: network called as ``net(images, t, class_labels=...)`` whose result has a ``.sample`` field.\n",
    "        plan: if True the label is +1, otherwise -1, so one network can serve two roles.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net, plan=False):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        self.plan = plan\n",
    "\n",
    "    def forward(self, x_0, x_t, t):\n",
    "        \"\"\"Concatenate both inputs on the channel axis and return the network's ``.sample``.\n",
    "\n",
    "        NOTE(review): callers in this notebook pass ``(x_t, x_0, t)`` positionally, so the\n",
    "        parameter names here are swapped relative to their use; the channel order is consistent.\n",
    "        \"\"\"\n",
    "        # Labels are created on the input's device instead of a notebook-level global.\n",
    "        class_labels = torch.ones(x_0.shape[0], device=x_0.device)\n",
    "        if not self.plan:\n",
    "            class_labels = -class_labels\n",
    "\n",
    "        return self.net(torch.cat([x_0, x_t], dim=1), t, class_labels=class_labels).sample\n",
    "        \n",
    "        # if self.plan:\n",
    "            \n",
    "        #     return self.net(torch.cat([x_0, x_t], dim=-1).sample, torch.cat([t.unsqueeze(1), torch.ones(x.shape[0], 1).to(device)], dim=1) )\n",
    "        # else:\n",
    "            \n",
    "        #     return self.net(torch.cat([x_0, x_t], dim=-1).sample, torch.cat([t.unsqueeze(1), -torch.ones(x.shape[0], 1).to(device)], dim=1) )\n",
    "\n",
    "\n",
    "\n",
    "# class VectorNet(nn.Module):\n",
    "#     def __init__(self, net, plan=False):\n",
    "#         super().__init__()\n",
    "#         self.net = net\n",
    "#         self.plan = plan\n",
    "\n",
    "\n",
    "#     def forward(self, x_0, x_t, t):\n",
    "        \n",
    "#         return self.net(torch.cat([x_0, x_t], dim=1), t).sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbea1e20-9eee-4c34-ae9a-bd2665e5480d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "ema_decay = 0.999\n",
    "\n",
    "# Fall back to the CPU so the network can still be built on machines without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "main_network = UNet2DModel(\n",
    "    sample_size=image_shape[0],  # the target image resolution\n",
    "    in_channels=2,  # the number of input channels, 3 for RGB images\n",
    "    out_channels=1,  # the number of output channels\n",
    "    layers_per_block=layers_per_block,  # how many ResNet layers to use per UNet block\n",
    "    block_out_channels=(n_features, n_features),  # More channels -> more parameters\n",
    "    down_block_types=(\n",
    "        \"DownBlock2D\",  # a regular ResNet downsampling block\n",
    "        \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n",
    "    ),\n",
    "    up_block_types=(\n",
    "        \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n",
    "        \"UpBlock2D\",  # a regular ResNet upsampling block\n",
    "    ),\n",
    "    class_embed_type='timestep'\n",
    ").to(device)\n",
    "\n",
    "drift_net_1 = VectorNet(main_network, plan=True)\n",
    "\n",
    "drift_net_2 = VectorNet(main_network, plan=False)\n",
    "\n",
    "from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "ema_g_bm_2 = ema_g_bm_1\n",
    "\n",
    "\n",
    "# drift_net_1 = torch.compile(drift_net_1)\n",
    "# drift_net_2 = torch.compi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a920c90-fe92-40c4-ae01-5ecf05856a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# ema_decay = 0.999\n",
    "\n",
    "# device = 'cuda'\n",
    "\n",
    "# drift_net_1 = VectorNet(UNet2DModel(\n",
    "#     sample_size=image_shape[0],  # the target image resolution\n",
    "#     in_channels=2,  # the number of input channels, 3 for RGB images\n",
    "#     out_channels=1,  # the number of output channels\n",
    "#     layers_per_block=layers_per_block,  # how many ResNet layers to use per UNet block\n",
    "#     block_out_channels=(n_features, n_features),  # More channels -> more parameters\n",
    "#     down_block_types=(\n",
    "#         \"DownBlock2D\",  # a regular ResNet downsampling block\n",
    "#         \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n",
    "#     ),\n",
    "#     up_block_types=(\n",
    "#         \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n",
    "#         \"UpBlock2D\",  # a regular ResNet upsampling block\n",
    "#     ),\n",
    "#     class_embed_type='timestep'\n",
    "# )).to(device)\n",
    "\n",
    "# drift_net_2 = VectorNet(UNet2DModel(\n",
    "#     sample_size=image_shape[0],  # the target image resolution\n",
    "#     in_channels=2,  # the number of input channels, 3 for RGB images\n",
    "#     out_channels=1,  # the number of output channels\n",
    "#     layers_per_block=layers_per_block,  # how many ResNet layers to use per UNet block\n",
    "#     block_out_channels=(n_features, n_features),  # More channels -> more parameters\n",
    "#     down_block_types=(\n",
    "#         \"DownBlock2D\",  # a regular ResNet downsampling block\n",
    "#         \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n",
    "#     ),\n",
    "#     up_block_types=(\n",
    "#         \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n",
    "#         \"UpBlock2D\",  # a regular ResNet upsampling block\n",
    "#     ),\n",
    "#     class_embed_type='timestep'\n",
    "# )).to(device)\n",
    "\n",
    "# from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "# # drift_net_1 = torch.compile(drift_net_1)\n",
    "# # drift_net_2 = torch.compile(drift_net_2)\n",
    "\n",
    "# # mod = MyModule()\n",
    "# # opt_mod = torch.compile(mod))\n",
    "\n",
    "# ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "# ema_g_bm_2 = ExponentialMovingAverage(drift_net_2.parameters(), decay=ema_decay)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6570355a-5666-41e5-91a9-ab71e4931701",
   "metadata": {},
   "outputs": [],
   "source": [
    "# lr=1e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b5b14cf-a08b-4091-9de1-30c8291ecacc",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "bm_1 = BridgeMathcing(drift_net_1, eps=eps, predict_type=predict_type, loss_weight=loss_weight)\n",
    "\n",
    "opt = torch.optim.AdamW(main_network.parameters(), lr=lr)\n",
    "\n",
    "bm_2 = BridgeMathcing(drift_net_2, eps=eps, predict_type=predict_type, loss_weight=loss_weight) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c45ce8-05a1-40bc-888d-3a60b9fb6db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# bm_1 = BridgeMathcing(drift_net_1, eps=eps)\n",
    "\n",
    "# bm_2 = BridgeMathcing(drift_net_2, eps=eps) \n",
    "\n",
    "import os\n",
    "\n",
    "# One directory per (shape, image size) / MI value / training configuration.\n",
    "ckpt_path = os.path.join(\n",
    "    'ckpts',\n",
    "    f'{shape}_{image_size}_finite_samples',\n",
    "    f'MI_{mi_gt}',\n",
    "    f'predict_{predict_type}_loss_weight_{loss_weight}_lr_{lr}_beta_dist_{t_alpha}',\n",
    ")\n",
    "\n",
    "# makedirs also creates the top-level 'ckpts' directory, which a plain mkdir chain would not.\n",
    "os.makedirs(ckpt_path, exist_ok=True)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ef2850-8ac3-4795-8211-d2d744b66d7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(p.numel() for p in drift_net_1.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1771260-344e-46f1-9884-8115667af2da",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "n_samples = 4096\n",
    "\n",
    "x_batch, y_batch = MI_pictures_obj.rvs(n_samples)\n",
    "x_0_test_plan = torch.Tensor(x_batch).unsqueeze(1).to(device)\n",
    "x_1_test_plan = torch.Tensor(y_batch).unsqueeze(1).to(device)\n",
    "\n",
    "x_1_test_ind = x_1_test_plan[torch.randperm(x_1_test_plan.shape[0])]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99ccc720-1cca-40de-bf5b-f1196dbf22fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define sample_fn(batch_size): draws batch_size values and scales them by\n",
    "# (1 - t_eps). t_eps is looked up when sample_fn is called, not when it is\n",
    "# defined, so it has to be assigned before the first call.\n",
    "if t_alpha is not None and t_beta is not None:\n",
    "\n",
    "    print(f'Beta alpha {t_alpha} beta {t_beta}')\n",
    "    dist = torch.distributions.beta.Beta(t_alpha, t_beta)\n",
    "\n",
    "    def sample_fn(batch_size):\n",
    "        # Beta(t_alpha, t_beta) draws, moved to device, then scaled.\n",
    "        return dist.sample([batch_size]).to(device) * (1 - t_eps)\n",
    "\n",
    "else:\n",
    "\n",
    "    def sample_fn(batch_size):\n",
    "        # torch.rand draws (uniform on [0, 1)), moved to device, then scaled.\n",
    "        return (torch.rand(batch_size).to(device)) * (1 - t_eps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4359c885-1fae-4070-9523-745dd49fda43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop: alternates one optimizer step for bm_1 on paired (x, y)\n",
    "# batches with one step for bm_2 on the same x and a shuffled y. Every\n",
    "# eval_freq iterations it evaluates estimate_kl; every 10000 iterations it\n",
    "# writes checkpoints.\n",
    "\n",
    "# Sampled times are scaled by (1 - t_eps); the same value is passed to estimate_kl.\n",
    "t_eps = 1e-3\n",
    "\n",
    "# Keep one iterator over the training loader. Calling next(iter(dataloader))\n",
    "# on every step builds a fresh iterator each time and only consumes its first\n",
    "# batch, which is wasteful and never advances past that batch of a given pass.\n",
    "train_iter = iter(dataloader)\n",
    "\n",
    "for i in tqdm(range(1, max_iter + 1)):\n",
    "\n",
    "    try:\n",
    "        x_batch, y_batch = next(train_iter)\n",
    "    except StopIteration:\n",
    "        # Loader exhausted: start another pass over the data.\n",
    "        train_iter = iter(dataloader)\n",
    "        x_batch, y_batch = next(train_iter)\n",
    "\n",
    "    x_samples = torch.Tensor(x_batch).unsqueeze(1).to(device)\n",
    "    y_samples = torch.Tensor(y_batch).unsqueeze(1).to(device)\n",
    "\n",
    "    # ---- step 1: bm_1 on the paired batch ----\n",
    "    # One time value per batch element, reshaped to [batch, 1, 1, 1].\n",
    "    t = sample_fn(x_samples.shape[0]).reshape([-1, 1, 1, 1])\n",
    "\n",
    "    loss_1 = bm_1.step(x_samples, y_samples, t)\n",
    "\n",
    "    opt.zero_grad()\n",
    "    loss_1.backward()\n",
    "    opt.step()\n",
    "    ema_g_bm_1.update()\n",
    "\n",
    "    # ---- step 2: bm_2 on the same x with y shuffled along dim 0 ----\n",
    "    y_samples_permuted = y_samples[torch.randperm(y_samples.shape[0])]\n",
    "\n",
    "    t = sample_fn(x_samples.shape[0]).reshape([-1, 1, 1, 1])\n",
    "\n",
    "    loss_2 = bm_2.step(x_samples, y_samples_permuted, t)\n",
    "\n",
    "    opt.zero_grad()\n",
    "    loss_2.backward()\n",
    "    opt.step()\n",
    "    ema_g_bm_2.update()\n",
    "\n",
    "    if wandb.run:\n",
    "        # .item() logs plain Python floats rather than tensors that still\n",
    "        # reference the autograd graph.\n",
    "        wandb.log({'Loss plan': loss_1.item(), 'Loss ind': loss_2.item()})\n",
    "\n",
    "    # ---- periodic evaluation ----\n",
    "    if i % eval_freq == 0:\n",
    "\n",
    "        # Validation loader, EMA weights swapped in for both networks.\n",
    "        with ema_g_bm_1.average_parameters(), ema_g_bm_2.average_parameters():\n",
    "            mutual_entropy_est_ema = estimate_kl(drift_net_1, drift_net_2, eps, dataloader_val, nfe=10, t_eps=t_eps, posterior='partial', iters=20, predict_type=predict_type)\n",
    "\n",
    "        # Validation loader, current (non-EMA) weights.\n",
    "        mutual_entropy_est = estimate_kl(drift_net_1, drift_net_2, eps, dataloader_val, nfe=10, t_eps=t_eps, posterior='partial', iters=20, predict_type=predict_type)\n",
    "\n",
    "        # Training loader, EMA weights.\n",
    "        with ema_g_bm_1.average_parameters(), ema_g_bm_2.average_parameters():\n",
    "            mutual_entropy_est_ema_train = estimate_kl(drift_net_1, drift_net_2, eps, dataloader, nfe=10, t_eps=t_eps, posterior='partial', iters=20, predict_type=predict_type)\n",
    "\n",
    "        print(f'MI: EMA {mutual_entropy_est_ema} non EMA {mutual_entropy_est}')\n",
    "\n",
    "        if wandb.run:\n",
    "            wandb.log({'MI EMA': mutual_entropy_est_ema, 'MI non EMA': mutual_entropy_est, 'MI EMA train': mutual_entropy_est_ema_train})\n",
    "\n",
    "        print(f'Iter {i} EMA MI {mutual_entropy_est_ema}')\n",
    "\n",
    "    # ---- periodic checkpoint ----\n",
    "    if i % 10000 == 0:\n",
    "\n",
    "        ckpt_path_paired = os.path.join(ckpt_path, f'BM_net_{n_features}_paired_{shape}_{image_size}_mi_{mi_gt}_eps_{eps}_predict_{predict_type}_{i}.pth')\n",
    "        ckpt_path_ind = os.path.join(ckpt_path, f'BM_net_{n_features}_ind_{shape}_{image_size}_mi_{mi_gt}_eps_{eps}_predict_{predict_type}_{i}.pth')\n",
    "\n",
    "        # Covariance of the underlying distribution, stored next to the weights.\n",
    "        torch.save(torch.Tensor(MI_pictures_obj._dist._dist.cov), os.path.join(ckpt_path, f'cov_MI_{mi_gt}.pth'))\n",
    "\n",
    "        # The saved state dicts hold the EMA weights, not the raw ones.\n",
    "        with ema_g_bm_1.average_parameters(), ema_g_bm_2.average_parameters():\n",
    "            torch.save(drift_net_1.state_dict(), ckpt_path_paired)\n",
    "            torch.save(drift_net_2.state_dict(), ckpt_path_ind)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1033f9f2-897a-4ab2-a0b7-0846bed4648d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b85fe7b2-5d8a-45ee-9307-1a9cc0ea71f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
