{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Set before any protobuf-related imports\n",
    "os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"] = \"python\"\n",
    "\n",
    "import jax\n",
    "jax.config.update('jax_platforms', 'cpu')\n",
    "import bmi\n",
    "import numpy as np\n",
    "\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run configuration. This cell carries the `parameters` tag in its metadata.\n",
    "\n",
    "# Noise scale of the bridge; passed as `eps` to BridgeMathcing.\n",
    "eps = 0.01\n",
    "\n",
    "# Decay of the ExponentialMovingAverage kept over the drift network parameters.\n",
    "ema_decay = 0.999\n",
    "\n",
    "# Number of iterations of the outer training loop (`for j in range(n_epochs)`).\n",
    "n_epochs = 200\n",
    "\n",
    "# predict_type = 'x_1'\n",
    "\n",
    "# Regression target of the network; BridgeMathcing accepts 'vector_field', 'x_1' or 'noise'.\n",
    "predict_type = 'vector_field'\n",
    "\n",
    "# Seed handed to seed_everything().\n",
    "seed = 42\n",
    "\n",
    "# Learning rate of the Adam optimizer.\n",
    "lr = 3e-4\n",
    "\n",
    "# Per-sample loss re-weighting in BridgeMathcing.loss: 'mean' or 'sum'; a falsy value disables it.\n",
    "loss_weight = None\n",
    "\n",
    "# Dimensionality passed to the distribution constructors and to the network.\n",
    "dim = 160\n",
    "\n",
    "# First argument of the distribution constructors, logged as 'GT MI'.\n",
    "# Presumably the ground-truth mutual information - confirm against the mutinfo API.\n",
    "mi_gt = 80\n",
    "\n",
    "# Selects the distribution and the transform applied to its samples.\n",
    "task_name = 'gaussian_hc'\n",
    "\n",
    "# Batch size of both DataLoaders.\n",
    "batch_size = 128\n",
    "\n",
    "# NOTE(review): not referenced in the cells visible here; presumably a cap on optimisation steps - confirm.\n",
    "max_steps = 50000\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from mutinfo.distributions.base import CorrelatedNormal, CorrelatedStudent, CorrelatedUniform,  SmoothedUniform, UniformlyQuantized\n",
    "\n",
    "# Seed NumPy before the distribution is built and sampled: seed_everything() only runs in a\n",
    "# later cell, so without this the data can differ between runs that use the same `seed`.\n",
    "# NOTE(review): assumes mutinfo draws from NumPy's global RNG (SciPy-style default) - confirm.\n",
    "np.random.seed(seed)\n",
    "\n",
    "# 'gaussian_hc' and 'gaussian_cubic' start from the same correlated Gaussian as 'gaussian';\n",
    "# their samples are distorted element-wise after sampling.\n",
    "if task_name in ('gaussian', 'gaussian_hc', 'gaussian_cubic'):\n",
    "    task = CorrelatedNormal(mi_gt, dim, dim, randomize_interactions=False,\n",
    "                            shuffle_interactions=True)\n",
    "elif task_name == 'corr_uniform':\n",
    "    task = CorrelatedUniform(mi_gt, dim, dim, randomize_interactions=False,\n",
    "                             shuffle_interactions=True)\n",
    "elif task_name == 'smoothed_uniform':\n",
    "    task = SmoothedUniform(mi_gt, dim, randomize_interactions=False)\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on `task` further down.\n",
    "    raise ValueError(f'Unknown task_name: {task_name!r}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "x_train, y_train = task.rvs(100000)\n",
    "x_test, y_test = task.rvs(10000)\n",
    "\n",
    "\n",
    "def _distort(samples):\n",
    "    \"\"\"Apply the element-wise distortion selected by `task_name`; other tasks pass through.\"\"\"\n",
    "    if task_name == 'gaussian_cubic':\n",
    "        return samples ** 3\n",
    "    if task_name == 'gaussian_hc':\n",
    "        # sign-preserving power: x * sqrt(|x|)\n",
    "        return samples * np.sqrt(np.absolute(samples))\n",
    "    return samples\n",
    "\n",
    "\n",
    "# One code path for all four arrays; torch.as_tensor with an explicit dtype replaces the\n",
    "# legacy torch.Tensor(...) constructor and still yields float32 tensors.\n",
    "x_train, y_train, x_test, y_test = (\n",
    "    torch.as_tensor(np.asarray(_distort(samples)), dtype=torch.float32)\n",
    "    for samples in (x_train, y_train, x_test, y_test)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# Pair x and y so that every dataset item is one (x, y) sample.\n",
    "train_dataset, test_dataset = TensorDataset(x_train, y_train), TensorDataset(x_test, y_test)\n",
    "\n",
    "# Options shared by both loaders.\n",
    "_loader_common = dict(batch_size=batch_size, pin_memory=True)\n",
    "\n",
    "# Training loader: reshuffled on every pass, 4 workers, incomplete last batch dropped.\n",
    "train_l = DataLoader(train_dataset, shuffle=True, num_workers=4, drop_last=True, **_loader_common)\n",
    "\n",
    "# Evaluation loader: fixed order, 2 workers, incomplete last batch kept.\n",
    "test_l = DataLoader(test_dataset, shuffle=False, num_workers=2, drop_last=False, **_loader_common)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Hidden width grows with the data dimensionality; the time embedding uses the same size.\n",
    "if dim > 25:\n",
    "    n_filters = 256\n",
    "elif dim > 5:\n",
    "    n_filters = 128\n",
    "else:\n",
    "    n_filters = 64\n",
    "\n",
    "time_embed = n_filters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "default_seed = 42\n",
    "\n",
    "def seed_basic(seed=default_seed):\n",
    "    \"\"\"Seed Python's `random` module and export PYTHONHASHSEED with the same value.\"\"\"\n",
    "    import os\n",
    "    import random\n",
    "\n",
    "    # The interpreter reads PYTHONHASHSEED at start-up, so this only affects processes\n",
    "    # started after this call.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "\n",
    "def seed_numpy(seed=default_seed):\n",
    "    \"\"\"Seed NumPy's global random number generator.\"\"\"\n",
    "    from numpy import random as numpy_random\n",
    "\n",
    "    numpy_random.seed(seed)\n",
    "\n",
    "\n",
    "# TensorFlow random seed\n",
    "def seed_tensorflow(seed=default_seed):\n",
    "    \"\"\"Seed TensorFlow; the import is local so the notebook does not require TensorFlow.\"\"\"\n",
    "    import tensorflow as tf\n",
    "\n",
    "    tf.random.set_seed(seed)\n",
    "\n",
    "\n",
    "# PyTorch random seed\n",
    "def seed_torch(seed=default_seed):\n",
    "    \"\"\"Seed PyTorch (CPU and current CUDA device) and configure cuDNN for repeatable runs.\"\"\"\n",
    "    import torch\n",
    "\n",
    "    # Request deterministic cuDNN kernels and switch off the benchmark auto-tuner.\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "\n",
    "\n",
    "def seed_everything(seed=default_seed,\n",
    "                    to_be_seeded=('basic', 'numpy', 'tensorflow', 'torch')):\n",
    "    \"\"\"Call `seed_<name>(seed)` for every name in `to_be_seeded`.\n",
    "\n",
    "    Seeding is best-effort: a framework that is not installed, or any other failure in\n",
    "    one seeder, is reported with print() and the remaining names are still processed.\n",
    "\n",
    "    Args:\n",
    "        seed: value forwarded to every seeder.\n",
    "        to_be_seeded: iterable of suffixes of module-level `seed_*` functions. The default\n",
    "            is an immutable tuple, so it cannot be altered between calls.\n",
    "    \"\"\"\n",
    "    for name in to_be_seeded:\n",
    "        seeder = globals().get('seed_' + name)\n",
    "        if seeder is None:\n",
    "            print(f'Unknown seeding target: {name!r}')\n",
    "            continue\n",
    "        try:\n",
    "            seeder(seed)\n",
    "        except Exception as exception:\n",
    "            print(f'Could not seed {name}: {exception}')\n",
    "\n",
    "seed_everything(seed=seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import wandb\n",
    "\n",
    "from comet_ml import Experiment\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Hyper-parameters recorded with the run (sent to the Comet experiment below).\n",
    "wandb_config = dict(\n",
    "    dim=dim, eps=eps, task=task_name, ema_decay=ema_decay,\n",
    "    n_epochs=n_epochs, gt=mi_gt, seed=seed,\n",
    "    n_filters=n_filters, lr=lr,\n",
    ")\n",
    "\n",
    "experiment = Experiment(project_name='Bridge_MI_Image_finite_samples')\n",
    "experiment.set_name(task_name)\n",
    "experiment.log_parameters(wandb_config)\n",
    "\n",
    "# Log the reference value as a metric as well.\n",
    "experiment.log_metrics({'GT MI': mi_gt})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "\n",
    "# Layout: log_high_mi_<task>/n_feat_<n>_predict_<type>/dim_<d>_mi_gt_<mi>/<hyper-parameters>/seed_<seed>.txt\n",
    "log_dir = os.path.join(\n",
    "    f'log_high_mi_{task_name}',\n",
    "    f'n_feat_{n_filters}_predict_{predict_type}',\n",
    "    f'dim_{dim}_mi_gt_{mi_gt}',\n",
    "    f'predict_{predict_type}_loss_weight_{loss_weight}_lr_{lr}_eps_{eps}_ema_{ema_decay}_n_epochs_{n_epochs}_n_filters_{n_filters}',\n",
    ")\n",
    "\n",
    "# makedirs(exist_ok=True) creates every missing level in one call and does not fail when a\n",
    "# concurrent run creates the same directory (a separate exists() + mkdir() pair can race).\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "\n",
    "log_path = os.path.join(log_dir, f'seed_{seed}.txt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "from functools import partial\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "# helpers functions\n",
    "\n",
    "\n",
    "def exists(x):\n",
    "    \"\"\"Return True for every value except None.\"\"\"\n",
    "    return not (x is None)\n",
    "\n",
    "\n",
    "def default(val, d):\n",
    "    \"\"\"Return `val` unless it is None; otherwise the fallback `d`, calling it first if it is callable.\"\"\"\n",
    "    if not exists(val):\n",
    "        return d() if callable(d) else d\n",
    "    return val\n",
    "\n",
    "\n",
    "def identity(t, *args, **kwargs):\n",
    "    \"\"\"Return the first argument unchanged; any further arguments are accepted and ignored.\"\"\"\n",
    "    passthrough = t\n",
    "    return passthrough\n",
    "\n",
    "\n",
    "def num_to_groups(num, divisor):\n",
    "    \"\"\"Split `num` into chunks of size `divisor`, followed by the remainder when it is positive.\"\"\"\n",
    "    full_groups, leftover = divmod(num, divisor)\n",
    "    sizes = [divisor for _ in range(full_groups)]\n",
    "    if leftover > 0:\n",
    "        sizes.append(leftover)\n",
    "    return sizes\n",
    "\n",
    "\n",
    "# small helper modules\n",
    "\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Skip-connection wrapper: forward(x, ...) returns fn(x, ...) + x.\"\"\"\n",
    "\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        transformed = self.fn(x, *args, **kwargs)\n",
    "        return transformed + x\n",
    "\n",
    "\n",
    "def Upsample(dim, dim_out=None):\n",
    "    \"\"\"Return a Sequential holding one Linear(dim -> dim_out); dim_out falls back to dim.\"\"\"\n",
    "    out_features = default(dim_out, dim)\n",
    "    return nn.Sequential(nn.Linear(dim, out_features))\n",
    "\n",
    "\n",
    "def Downsample(dim, dim_out=None):\n",
    "    \"\"\"Return a Linear(dim -> dim_out); dim_out falls back to dim.\"\"\"\n",
    "    out_features = default(dim_out, dim)\n",
    "    return nn.Linear(dim, out_features)\n",
    "\n",
    "\n",
    "# NOTE(review): second definition of Residual in this cell. It behaves the same as the\n",
    "# earlier one and simply rebinds the name; one of the two can be removed.\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Skip-connection wrapper: forward(x, ...) returns fn(x, ...) + x.\"\"\"\n",
    "\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        return self.fn(x, *args, **kwargs) + x\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"GroupNorm -> SiLU -> Linear, optionally modulated by a conditioning signal.\n",
    "\n",
    "    Args:\n",
    "        dim: number of input features (also the channel count of the GroupNorm).\n",
    "        dim_out: number of output features.\n",
    "        groups: number of GroupNorm groups.\n",
    "        shift_scale: when True the conditioning is applied as x * (scale + 1) + shift,\n",
    "            otherwise it is added to the projected features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, groups=8, shift_scale=True):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Linear(dim, dim_out)\n",
    "        self.act = nn.SiLU()\n",
    "        self.norm = nn.GroupNorm(groups, dim)\n",
    "        self.shift_scale = shift_scale\n",
    "\n",
    "    def forward(self, x, t=None):\n",
    "        \"\"\"Transform `x`; `t` is None, an additive term, or the scale/shift conditioning.\n",
    "\n",
    "        With shift_scale=True, `t` may be a (scale, shift) pair or a single tensor whose\n",
    "        last axis holds scale and shift side by side.\n",
    "        \"\"\"\n",
    "        x = self.proj(self.act(self.norm(x)))\n",
    "\n",
    "        if t is None:\n",
    "            return x\n",
    "        if not self.shift_scale:\n",
    "            return x + t\n",
    "\n",
    "        if torch.is_tensor(t):\n",
    "            # ResnetBlock passes one (batch, 2 * dim_out) tensor; unpacking it directly\n",
    "            # would iterate over the batch axis, so split along the feature axis instead.\n",
    "            scale, shift = t.chunk(2, dim=-1)\n",
    "        else:\n",
    "            scale, shift = t\n",
    "        return x * (scale.squeeze() + 1) + shift.squeeze()\n",
    "\n",
    "\n",
    "class ResnetBlock(nn.Module):\n",
    "    \"\"\"Two Blocks with a residual connection; only the first Block receives the time conditioning.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=32, shift_scale=False):\n",
    "        super().__init__()\n",
    "        self.shift_scale = shift_scale\n",
    "\n",
    "        # Time projection: twice the width when it has to carry both scale and shift.\n",
    "        if exists(time_emb_dim):\n",
    "            emb_width = dim_out * 2 if shift_scale else dim_out\n",
    "            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, emb_width))\n",
    "        else:\n",
    "            self.mlp = None\n",
    "\n",
    "        self.block1 = Block(dim, dim_out, groups=groups, shift_scale=shift_scale)\n",
    "        self.block2 = Block(dim_out, dim_out, groups=groups, shift_scale=shift_scale)\n",
    "\n",
    "        # The skip path needs a projection only when the feature width changes.\n",
    "        if dim != dim_out:\n",
    "            self.lin_layer = nn.Linear(dim, dim_out)\n",
    "        else:\n",
    "            self.lin_layer = nn.Identity()\n",
    "\n",
    "    def forward(self, x, time_emb=None):\n",
    "        conditioning = None\n",
    "        if exists(self.mlp) and exists(time_emb):\n",
    "            conditioning = self.mlp(time_emb)\n",
    "\n",
    "        hidden = self.block1(x, t=conditioning)\n",
    "        hidden = self.block2(hidden)\n",
    "\n",
    "        return hidden + self.lin_layer(x)\n",
    "\n",
    "class UnetMLP_simple(nn.Module):\n",
    "    \"\"\"U-Net shaped MLP built from ResnetBlocks, conditioned on a time-like vector.\n",
    "\n",
    "    Args:\n",
    "        dim: output width; the first layer expects 2 * dim input features.\n",
    "        init_dim: hidden width (falls back to `dim` when None).\n",
    "        dim_mults: multipliers of `init_dim`, one per down/up stage; an empty sequence\n",
    "            leaves only the input layer, the final residual block and the output head.\n",
    "        resnet_block_groups: GroupNorm group count used in every block.\n",
    "        time_dim: width of the conditioning embedding.\n",
    "        nb_var: number of scalar conditioning inputs per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        init_dim=128,\n",
    "        dim_mults=(1, 1),\n",
    "        resnet_block_groups=8,\n",
    "        time_dim=128,\n",
    "        nb_var=1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        # determine dimensions\n",
    "        self.nb_var = nb_var\n",
    "        init_dim = default(init_dim, dim)\n",
    "        if init_dim is None:\n",
    "            init_dim = dim * dim_mults[0]\n",
    "\n",
    "        # Kept because `out_dim` below reads it: it stays `dim` when there are no stages and\n",
    "        # is otherwise whatever the last stage loop left behind.\n",
    "        dim_in = dim\n",
    "        dims = [init_dim, *(init_dim * mult for mult in dim_mults)]\n",
    "        in_out = list(zip(dims[:-1], dims[1:]))\n",
    "\n",
    "        block_klass = partial(ResnetBlock, groups=resnet_block_groups)\n",
    "\n",
    "        # Sub-modules are created in the original order so that seeded initialisation and\n",
    "        # state_dict keys stay the same.\n",
    "        self.init_lin = nn.Linear(dim * 2, init_dim)\n",
    "\n",
    "        self.time_mlp = nn.Sequential(\n",
    "            nn.Linear(nb_var, time_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(time_dim, time_dim)\n",
    "        )\n",
    "\n",
    "        # layers\n",
    "        self.downs = nn.ModuleList([])\n",
    "        self.ups = nn.ModuleList([])\n",
    "\n",
    "        for dim_in, dim_out in in_out:\n",
    "            self.downs.append(nn.ModuleList([\n",
    "                block_klass(dim_in, dim_in, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        mid_dim = dims[-1]\n",
    "        # Constructed and registered, but forward() does not call it.\n",
    "        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        for dim_in, dim_out in reversed(in_out):\n",
    "            # Input is the running features concatenated with the matching skip tensor.\n",
    "            self.ups.append(nn.ModuleList([\n",
    "                block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        self.out_dim = dim_in\n",
    "\n",
    "        self.final_res_block = block_klass(\n",
    "            init_dim * 2, init_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        # Zero-initialised head: a freshly built network outputs zeros.\n",
    "        self.proj = nn.Linear(init_dim, dim)\n",
    "        self.proj.weight.data.fill_(0.0)\n",
    "        self.proj.bias.data.fill_(0.0)\n",
    "\n",
    "        self.final_lin = nn.Sequential(\n",
    "            nn.GroupNorm(resnet_block_groups, init_dim),\n",
    "            nn.SiLU(),\n",
    "            self.proj\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t=None, std=None):\n",
    "        \"\"\"Map `x` (batch, 2 * dim) and conditioning `t` (batch * nb_var values) to (batch, dim).\n",
    "\n",
    "        When `std` is given the output is divided by it.\n",
    "        \"\"\"\n",
    "        t = t.reshape(t.size(0), self.nb_var)\n",
    "\n",
    "        x = self.init_lin(x.float())\n",
    "        # Copy of the input features for the final skip connection.\n",
    "        r = x.clone()\n",
    "\n",
    "        t = self.time_mlp(t).squeeze()\n",
    "\n",
    "        skips = []\n",
    "        for blocks in self.downs:\n",
    "            x = blocks[0](x, t)\n",
    "            skips.append(x)\n",
    "\n",
    "        for blocks in self.ups:\n",
    "            x = torch.cat((x, skips.pop()), dim=1)\n",
    "            x = blocks[0](x, t)\n",
    "\n",
    "        x = torch.cat((x, r), dim=1)\n",
    "        x = self.final_res_block(x, t)\n",
    "\n",
    "        out = self.final_lin(x)\n",
    "        # Identity test: `std != None` would go through tensor comparison when std is a tensor.\n",
    "        if std is not None:\n",
    "            return out / std\n",
    "        return out\n",
    "\n",
    "\n",
    "# Two variables named 'x0' and 'x1', each of size `dim`.\n",
    "var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "var_list = list(var_list_0.keys())\n",
    "sizes = list(var_list_0.values())\n",
    "hidden_dim = 128\n",
    "\n",
    "# Stand-alone network whose `dim` is the summed variable size (2 * dim).\n",
    "# NOTE(review): `score` and `hidden_dim` are not referenced again in the cells visible here;\n",
    "# this looks like leftover code that builds an unused model - confirm before removing.\n",
    "score = UnetMLP_simple(dim=np.sum(sizes), init_dim=128, dim_mults=[],\n",
    "                            time_dim=hidden_dim, nb_var=len(var_list_0.keys()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Use the GPU when one is available; otherwise run on the CPU instead of failing at the first .to(device).\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "class CondMLPTimeMINDE(nn.Module):\n",
    "    \"\"\"Adapter giving a two-input network the (x, x_0, time) call signature.\n",
    "\n",
    "    `x` and `x_0` are concatenated along the last axis and passed on together with `time`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, minde_mlp):\n",
    "        super().__init__()\n",
    "        self.net = minde_mlp\n",
    "\n",
    "    def forward(self, x, x_0, time):\n",
    "        joint = torch.cat([x, x_0], dim=-1)\n",
    "        return self.net(joint, time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "class BridgeMathcing(nn.Module):\n",
    "    \"\"\"Bridge-matching training objective and Euler-Maruyama sampler around a drift network.\n",
    "\n",
    "    The network is always called as `unet(x_t, x_0, t)` with `t` of shape (batch,).\n",
    "\n",
    "    Args:\n",
    "        unet: conditional drift network.\n",
    "        eps: noise scale of the bridge; x_t has standard deviation sqrt(t * (1 - t) * eps).\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise' - the regression target used by `loss`.\n",
    "        loss_weight: falsy for an unweighted loss, or 'mean' / 'sum' to re-weight samples\n",
    "            (ignored for predict_type='noise').\n",
    "        euler_dt: step size used by `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unet, eps, predict_type='vector_field', loss_weight=False, euler_dt=0.01):\n",
    "        super().__init__()\n",
    "\n",
    "        assert predict_type in ['vector_field', 'x_1', 'noise']\n",
    "\n",
    "        self.loss_weight = loss_weight\n",
    "        self.predict_type = predict_type\n",
    "        self.vector_net = unet\n",
    "        self.eps = eps\n",
    "        # forward() read this attribute without it ever being set.\n",
    "        self.euler_dt = euler_dt\n",
    "\n",
    "    def _integrate(self, x_0, n_steps, pbar=False):\n",
    "        \"\"\"Run `n_steps` Euler-Maruyama steps of size 1 / n_steps starting from `x_0`.\"\"\"\n",
    "        import math\n",
    "\n",
    "        dt = 1. / n_steps\n",
    "        noise_scale = math.sqrt(dt * self.eps)\n",
    "        batch = x_0.shape[0]\n",
    "\n",
    "        times = torch.arange(n_steps, device=x_0.device) * dt\n",
    "        if pbar:\n",
    "            # `tqdm` has to be imported by the notebook; no visible cell imports it, so\n",
    "            # pass pbar=False when it is unavailable.\n",
    "            times = tqdm(times)\n",
    "\n",
    "        x_t = x_0\n",
    "        for t in times:\n",
    "            eps_noise = torch.randn_like(x_t, device=x_0.device)\n",
    "            # The networks expect one time value per sample, not a 0-dim tensor.\n",
    "            drift = self.vector_net(x_t, x_0, t.expand(batch))\n",
    "            x_t = x_t + drift * dt + noise_scale * eps_noise\n",
    "\n",
    "        return x_t\n",
    "\n",
    "    def forward(self, x_0):\n",
    "        \"\"\"Integrate the learned SDE from `x_0` with step `self.euler_dt`; gradients are kept.\"\"\"\n",
    "        n_steps = max(1, round(1. / self.euler_dt))\n",
    "        return self._integrate(x_0, n_steps)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample(self, x_0, nfe=100, pbar=True):\n",
    "        \"\"\"Draw samples at t = 1 from `x_0` using `nfe` network evaluations (no gradients).\"\"\"\n",
    "        return self._integrate(x_0, nfe, pbar=pbar)\n",
    "\n",
    "    def sample_x_t(self, x_0, x_1, t):\n",
    "        \"\"\"Sample x_t = (1 - t) * x_0 + t * x_1 + sqrt(t * (1 - t) * eps) * z; returns (x_t, z).\"\"\"\n",
    "        coef_0, coef_1 = 1 - t, t\n",
    "\n",
    "        std_t = torch.sqrt(t * (1 - t) * self.eps)\n",
    "\n",
    "        z = torch.randn_like(x_0, device=x_0.device)\n",
    "\n",
    "        x_t = coef_1.reshape([-1, 1]) * x_1 + coef_0.reshape([-1, 1]) * x_0 + z * std_t.reshape([-1, 1])\n",
    "        return x_t, z\n",
    "\n",
    "    def step(self, x_0, x_1, t):\n",
    "        \"\"\"Return the mean training loss for a batch of pairs (x_0, x_1) at times `t`.\"\"\"\n",
    "        t = t.reshape([-1, 1])\n",
    "        x_t, z = self.sample_x_t(x_0, x_1, t)\n",
    "        # reshape(-1) keeps the batch axis for a batch of one, unlike squeeze().\n",
    "        x_t_hat = self.vector_net(x_t, x_0, t.reshape(-1))\n",
    "        return self.loss(x_t_hat, x_1, x_0, x_t, t, z).mean()\n",
    "\n",
    "    def _loss_weights(self, x_t_hat, target_field):\n",
    "        \"\"\"Return (z_coef, norm_const) for the re-weighted loss; computed without gradients.\"\"\"\n",
    "        batch = x_t_hat.shape[0]\n",
    "        with torch.no_grad():\n",
    "            vector_field = (x_t_hat - target_field).reshape([batch, -1])\n",
    "            flat_target = target_field.reshape([batch, -1])\n",
    "            z_coef = torch.abs((vector_field ** 2 - flat_target * vector_field).mean(-1))\n",
    "\n",
    "            if self.loss_weight == 'mean':\n",
    "                norm_const = z_coef.mean()\n",
    "            elif self.loss_weight == 'sum':\n",
    "                norm_const = z_coef.sum()\n",
    "            else:\n",
    "                raise RuntimeError('Unknown loss weight')\n",
    "        return z_coef, norm_const\n",
    "\n",
    "    def loss(self, x_t_hat, x_1, x_0, x_t, t, z):\n",
    "        \"\"\"Per-sample loss between the network output and the target set by `predict_type`.\n",
    "\n",
    "        NOTE(review): the distance is the (unsquared) Euclidean norm, exactly as before -\n",
    "        confirm that this, and not the squared norm, is intended.\n",
    "        \"\"\"\n",
    "        batch = x_1.shape[0]\n",
    "\n",
    "        if self.predict_type == 'noise':\n",
    "            return torch.norm((x_t_hat - z).reshape([batch, -1]), dim=-1)\n",
    "\n",
    "        if self.predict_type == 'x_1':\n",
    "            per_sample = torch.norm((x_t_hat - x_1).reshape([batch, -1]), dim=-1)\n",
    "        elif self.predict_type == 'vector_field':\n",
    "            per_sample = torch.norm((x_t_hat - ((x_1 - x_t) / (1 - t))).reshape([batch, -1]), dim=-1)\n",
    "        else:\n",
    "            raise ValueError(f'Unknown predict_type: {self.predict_type!r}')\n",
    "\n",
    "        if self.loss_weight:\n",
    "            z_coef, norm_const = self._loss_weights(x_t_hat, (x_1 - x_t) / (1 - t))\n",
    "            return per_sample * z_coef / norm_const\n",
    "\n",
    "        return per_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def estimate_kl(drift_1, drift_2, eps, test_loader, t_eps=1e-3, posterior='partial', predict_type='vector_field', n_repeat=10):\n",
    "    \"\"\"Monte-Carlo estimate of 1 / (2 * eps) * E ||v_1 - v_2||^2 along bridge samples.\n",
    "\n",
    "    For every batch (x, y) of `test_loader`, a time t is drawn uniformly from [0, 1 - t_eps),\n",
    "    x_t is sampled from the bridge between x and y, and the squared difference of the two\n",
    "    drifts at (x_t, x, t) is accumulated. Batches are moved to the notebook-level `device`.\n",
    "\n",
    "    Args:\n",
    "        drift_1, drift_2: callables taking (x_t, x_0, t) with t of shape (batch,).\n",
    "        eps: noise scale of the bridge.\n",
    "        test_loader: iterable of (x, y) batches.\n",
    "        t_eps: margin that keeps t away from 1.\n",
    "        posterior: accepted for compatibility; not used.\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise' - how the outputs map to drifts.\n",
    "        n_repeat: number of passes over `test_loader`.\n",
    "\n",
    "    Returns:\n",
    "        The accumulated sum divided by the number of samples, times 1 / (2 * eps).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `predict_type` is not one of the three supported values.\n",
    "    \"\"\"\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        # An unknown type used to skip every branch and return 0 without any warning.\n",
    "        raise ValueError(f'Unknown predict_type: {predict_type!r}')\n",
    "\n",
    "    def BB_sample(x_1, x_0, t):\n",
    "        mean = x_1 * t + x_0 * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "\n",
    "        return mean + std * torch.randn_like(x_1)\n",
    "\n",
    "    kl_value = 0\n",
    "    n_elems = 0\n",
    "\n",
    "    for _ in range(n_repeat):\n",
    "\n",
    "        for item in test_loader:\n",
    "\n",
    "            x_batch = item[0].to(device)\n",
    "            y_batch = item[1].to(device)\n",
    "\n",
    "            n_elems += x_batch.shape[0]\n",
    "\n",
    "            # Flat (batch,) times for the networks; reshape keeps the batch axis even for a\n",
    "            # batch of one, which squeeze() would drop.\n",
    "            t_flat = (torch.rand(x_batch.shape[0]).to(device)) * (1 - t_eps)\n",
    "            t = t_flat.reshape([-1, 1])\n",
    "\n",
    "            x_t = BB_sample(y_batch, x_batch, t)\n",
    "\n",
    "            out_1 = drift_1(x_t, x_batch, t_flat)\n",
    "            out_2 = drift_2(x_t, x_batch, t_flat)\n",
    "\n",
    "            if predict_type == 'vector_field':\n",
    "                diff = out_1 - out_2\n",
    "            elif predict_type == 'x_1':\n",
    "                # The outputs are end-point predictions; dividing by (1 - t) gives drifts.\n",
    "                diff = (out_1 - out_2) / (1 - t)\n",
    "            else:\n",
    "                # 'noise': recover the end-point prediction from the predicted noise first.\n",
    "                x_1_predict_1 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * out_1) / t\n",
    "                x_1_predict_2 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * out_2) / t\n",
    "                diff = (x_1_predict_1 - x_1_predict_2) / (1 - t)\n",
    "\n",
    "            kl_value += (diff ** 2).sum([-1]).sum()\n",
    "\n",
    "    return 1 / (2 * eps) * kl_value / n_elems\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected network with one activation after every layer except the last.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features.\n",
    "        layer_widths: output width of each Linear layer, in order.\n",
    "        activate_final: apply the activation after the last layer as well.\n",
    "        activation_fn: element-wise activation function.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, layer_widths=(100, 100, 2), activate_final=False, activation_fn=torch.tanh):\n",
    "        super(MLP, self).__init__()\n",
    "        # Private copy: the caller's sequence (or the default) must not be shared with the\n",
    "        # instance, otherwise a later mutation leaks into other instances.\n",
    "        layer_widths = list(layer_widths)\n",
    "\n",
    "        layers = []\n",
    "        prev_width = input_dim\n",
    "        for layer_width in layer_widths:\n",
    "            layers.append(torch.nn.Linear(prev_width, layer_width))\n",
    "            prev_width = layer_width\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.layer_widths = layer_widths\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.activate_final = activate_final\n",
    "        self.activation_fn = activation_fn\n",
    "\n",
    "    def forward(self, x):\n",
    "        for layer in self.layers[:-1]:\n",
    "            x = self.activation_fn(layer(x))\n",
    "        x = self.layers[-1](x)\n",
    "        if self.activate_final:\n",
    "            x = self.activation_fn(x)\n",
    "        return x\n",
    "\n",
    "import math\n",
    "class SinusoidalPosEmb(nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of scalar positions: sines followed by cosines, dim // 2 of each.\"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        positions = x.squeeze()\n",
    "        n_freqs = self.dim // 2\n",
    "        # Frequencies fall geometrically from 1 to 1 / 10000.\n",
    "        log_step = math.log(10000) / (n_freqs - 1)\n",
    "        freqs = torch.exp(torch.arange(n_freqs, device=positions.device) * -log_step)\n",
    "        angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)\n",
    "\n",
    "        return torch.cat((angles.sin(), angles.cos()), dim=-1)\n",
    "    \n",
    "class MLPTime(nn.Module):\n",
    "    \"\"\"Residual MLP that receives a sinusoidal time embedding concatenated to its input.\"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, hidden_channels, time_embed_dim, num_hidden_blocks):\n",
    "        super().__init__()\n",
    "        self.to_time_embed = SinusoidalPosEmb(time_embed_dim)\n",
    "        self.in_layer = nn.Sequential(\n",
    "            nn.Linear(in_channels + time_embed_dim, hidden_channels),\n",
    "            nn.SiLU(),\n",
    "        )\n",
    "\n",
    "        blocks = []\n",
    "        for _ in range(num_hidden_blocks):\n",
    "            blocks.append(nn.Sequential(\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "            ))\n",
    "        self.hidden_blocks = nn.ModuleList(blocks)\n",
    "\n",
    "        self.out_layer = nn.Linear(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, time):\n",
    "        features = torch.cat([x, self.to_time_embed(time)], dim=-1)\n",
    "        hidden = self.in_layer(features)\n",
    "        # Each hidden block is wrapped in a skip connection.\n",
    "        for hidden_block in self.hidden_blocks:\n",
    "            hidden = hidden_block(hidden) + hidden\n",
    "\n",
    "        return self.out_layer(hidden)\n",
    "\n",
    "\n",
    "class CondMLPTime(nn.Module):\n",
    "    \"\"\"Adapter that feeds concat(x, x_0) and `time` to a two-input network.\"\"\"\n",
    "\n",
    "    def __init__(self, mlp_time):\n",
    "        super().__init__()\n",
    "        self.net = mlp_time\n",
    "\n",
    "    def forward(self, x, x_0, time):\n",
    "        joint = torch.cat([x, x_0], dim=-1)\n",
    "        return self.net(joint, time)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "class CondMLPTimeMINDE(nn.Module):\n",
    "    \"\"\"Adapter that lets one network serve two drifts, told apart by a +1 / -1 flag.\n",
    "\n",
    "    The wrapped network is called with concat(x, x_0) and a (batch, 2) conditioning whose\n",
    "    columns are the time and the flag (+1 when `plan` is True, -1 otherwise).\n",
    "\n",
    "    Args:\n",
    "        minde_mlp: network taking (features, conditioning).\n",
    "        plan: selects the sign of the flag column.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, minde_mlp, plan=True):\n",
    "        super().__init__()\n",
    "        self.net = minde_mlp\n",
    "        self.plan = plan\n",
    "\n",
    "    def forward(self, x, x_0, time):\n",
    "        batch = x.shape[0]\n",
    "        # Accept a scalar, (batch,) or (batch, 1) time and bring it to shape (batch,).\n",
    "        time = time.reshape(-1).expand(batch)\n",
    "        # Build the flag next to `time` (same device and dtype) instead of relying on the\n",
    "        # notebook-level `device` variable.\n",
    "        flag = torch.full((batch, 1), 1.0 if self.plan else -1.0,\n",
    "                          dtype=time.dtype, device=time.device)\n",
    "        conditioning = torch.cat([time.unsqueeze(1), flag], dim=1)\n",
    "\n",
    "        return self.net(torch.cat([x, x_0], dim=-1), conditioning)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Use the GPU when one is available; otherwise run on the CPU instead of failing at the first .to(device).\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# drift_net_1 = CondMLPTime(MLPTime(dim * 2, dim, n_filters, 32, 2)).to(device)\n",
    "\n",
    "# drift_net_2 = CondMLPTime(MLPTime(dim * 2, dim, n_filters, 32, 2)).to(device)\n",
    "\n",
    "# Names and sizes of the two variables; the statements below in this cell do not read them.\n",
    "var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "var_list = list(var_list_0.keys())\n",
    "sizes = list(var_list_0.values())\n",
    "\n",
    "# One network for both drifts. Its first layer takes 2 * dim features (the wrappers\n",
    "# concatenate x and x_0), and nb_var=2 matches the two-column conditioning [time, flag].\n",
    "net_1 = UnetMLP_simple(dim=dim, init_dim=n_filters, dim_mults=[],\n",
    "                            time_dim=time_embed, nb_var=2)\n",
    "\n",
    "# Both wrappers hold the same `net_1`, so they share every parameter and differ only in `plan`.\n",
    "drift_net_1 = CondMLPTimeMINDE(net_1, plan=True).to(device)\n",
    "    \n",
    "drift_net_2 = CondMLPTimeMINDE(net_1, plan=False).to(device)\n",
    "\n",
    "\n",
    "from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "# A single moving average over the shared parameters; `ema_g_bm_2` is another name for the\n",
    "# same object, not a second average.\n",
    "ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "ema_g_bm_2 = ema_g_bm_1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# One objective per drift wrapper, configured identically.\n",
    "bm_1, bm_2 = (\n",
    "    BridgeMathcing(drift, eps=eps, predict_type=predict_type, loss_weight=loss_weight)\n",
    "    for drift in (drift_net_1, drift_net_2)\n",
    ")\n",
    "\n",
    "# A single optimizer over the parameters of the shared network.\n",
    "opt = torch.optim.Adam(net_1.parameters(), lr=lr)\n",
    "\n",
    "# opt = torch.optim.Adam(drift_net_2.parameters(), lr=lr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "\n",
    "sum([np.prod(p.size()) for p in drift_net_1.parameters()])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "t_alpha = 1\n",
    "t_beta = 1\n",
    "\n",
    "t_eps = 1e-3\n",
    "\n",
    "if t_alpha is None or t_beta is None:\n",
    "\n",
    "    sample_fn = lambda batch_size: (torch.rand(batch_size).to(device)) * (1 - t_eps)\n",
    "\n",
    "else:\n",
    "    \n",
    "    print(f'Beta alpha {t_alpha} beta {t_beta}')\n",
    "    dist = torch.distributions.beta.Beta(t_alpha, t_beta)\n",
    "\n",
    "    sample_fn = lambda batch_size: dist.sample([batch_size]).to(device) * (1 - t_eps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "highest_mi = 0\n",
    "since_best_mi = 0\n",
    "\n",
    "def early_stopping(mi):\n",
    "\n",
    "    global highest_mi\n",
    "    global since_best_mi\n",
    "    \n",
    "    if mi > highest_mi:\n",
    "        since_best_mi = 0\n",
    "        highest_mi = mi\n",
    "        print('Update best MI')\n",
    "    else:\n",
    "        since_best_mi += 1\n",
    "\n",
    "    print(f'since_best_mi {since_best_mi}')\n",
    "\n",
    "    if since_best_mi > 1:\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "n_steps = 0\n",
    "\n",
    "x_start = []\n",
    "\n",
    "for item in iter(train_l):\n",
    "    \n",
    "    x_start.append(item[0].to(device))\n",
    "\n",
    "x_start = torch.cat(x_start, dim=0)\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "\n",
    "\n",
    "for j in range(n_epochs):\n",
    "\n",
    "    \n",
    "    end = time.time()\n",
    "    print(f\"Elapsed time for n steps {n_steps}: {end - start:.4f} seconds\")\n",
    "\n",
    "    if task_name == 'gaussian_dre':\n",
    "        if n_steps > max_steps:\n",
    "            break\n",
    "            \n",
    "    for i, item in enumerate(train_l):\n",
    "        \n",
    "        if task_name == 'gaussian_dre':\n",
    "            x_y = sample_gaussian(batch_size, cov_matrix)\n",
    "            x_samples = torch.Tensor(x_y[:, 1::2]).to(device)\n",
    "            y_samples = torch.Tensor(x_y[:, 0::2]).to(device)\n",
    "        else:\n",
    "            x_samples = item[0].to(device)\n",
    "            y_samples = item[1].to(device)\n",
    "        \n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        \n",
    "        # t = (torch.rand(x_samples.shape[0]).to(device)) * (1 - t_eps)\n",
    "    \n",
    "        loss_1 = bm_1.step(x_samples, y_samples, t)\n",
    "        \n",
    "        opt.zero_grad()\n",
    "    \n",
    "        loss_1.backward()\n",
    "        \n",
    "        # torch.nn.utils.clip_grad_norm_(drift_net_1.parameters(), 1.)\n",
    "    \n",
    "        opt.step()\n",
    "        \n",
    "        ema_g_bm_1.update()\n",
    "        \n",
    "        y_samples_permuted = y_samples[torch.randperm(y_samples.shape[0])]\n",
    "    \n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        \n",
    "        \n",
    "        loss_2 = bm_2.step(x_samples, y_samples_permuted, t)\n",
    "        \n",
    "        opt.zero_grad()\n",
    "    \n",
    "        loss_2.backward()\n",
    "        \n",
    "    \n",
    "        opt.step()\n",
    "        ema_g_bm_2.update()\n",
    "\n",
    "        n_steps += 1\n",
    "        \n",
    "        experiment.log_metrics({'Loss plan': loss_1, 'Loss ind': loss_2})\n",
    "    \n",
    "    with ema_g_bm_1.average_parameters():\n",
    "        with ema_g_bm_2.average_parameters():\n",
    "            mutual_entropy_est_ema = estimate_kl(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps, posterior='partial', predict_type=predict_type, n_repeat=10)\n",
    "    \n",
    "        \n",
    "    \n",
    "    with ema_g_bm_1.average_parameters():\n",
    "        with ema_g_bm_2.average_parameters():\n",
    "            mutual_entropy_est_ema_train = estimate_kl(drift_net_1, drift_net_2, eps, train_l, t_eps=t_eps, posterior='partial', predict_type=predict_type, n_repeat=1)\n",
    "    \n",
    "    mutual_entropy_est = estimate_kl(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps, posterior='partial', predict_type=predict_type, n_repeat=10)\n",
    "    \n",
    "    print(f'Epoch {j} MI: {mutual_entropy_est} EMA MI {mutual_entropy_est_ema} EMA MI train {mutual_entropy_est_ema_train}')\n",
    "    \n",
    "    experiment.log_metrics({f'MI EMA': mutual_entropy_est_ema, 'MI non EMA': mutual_entropy_est, 'MI EMA train': mutual_entropy_est_ema_train})\n",
    "    \n",
    "    # wandb.log({'MI': mutual_entropy_est, 'MI_est': mutual_entropy_est_ema, 'MI EMA train': mutual_entropy_est_ema_train})\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base_env",
   "language": "python",
   "name": "base_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
