{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "992a452f-308f-4b07-b516-3697265727fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# from lmi import lmi\n",
    "import torch\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import time\n",
    "import os\n",
    "\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".10\"\n",
    "os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"\n",
    "\n",
    "from bmi.estimators import MINEEstimator as MINE\n",
    "from bmi.estimators import InfoNCEEstimator as InfoNCE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50b07b1-30ba-42aa-893a-9becefa1334c",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run configuration; this cell carries the notebook \"parameters\" tag.\n",
    "# Torch device string that the models and batches are moved to.\n",
    "device = 'cuda'\n",
    "\n",
    "# Passed as `init_dim` of UnetMLP_simple when the drift network is built.\n",
    "n_filters = 256\n",
    "\n",
    "# `time_embed` becomes the network's `time_dim`; `ema_decay` is the decay of the\n",
    "# ExponentialMovingAverage tracker; `eps` is handed to BridgeMathcing and\n",
    "# estimate_kl; `predict_type` / `loss_weight` select the BridgeMathcing loss;\n",
    "# `lr` is the Adam learning rate.\n",
    "time_embed = 256\n",
    "ema_decay = 0.999\n",
    "eps = 1\n",
    "predict_type = 'vector_field'\n",
    "loss_weight = False\n",
    "lr = 3e-4\n",
    "\n",
    "# Fraction of the sampled X rows that generate_dataset shuffles for the training data.\n",
    "shuffle_coef = 0.1\n",
    "\n",
    "# Number of passes over train_loader in the training loop.\n",
    "n_epochs = 500\n",
    "\n",
    "# Batch size of both the train and validation DataLoader.\n",
    "batch_size = 64\n",
    "\n",
    "# Passed to Adam as `weight_decay`.\n",
    "wd_reg = 0.001\n",
    "\n",
    "# Seed for torch.manual_seed and np.random.seed.\n",
    "seed = 42\n",
    "\n",
    "# Dropout probability inside every residual block of the drift network.\n",
    "dropout_prob = 0.2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "54ba715e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Seed the torch RNG and the legacy global NumPy RNG from the `seed` parameter.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a0119d-f725-491d-9de4-7dd5480b0716",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Read the preprocessed Human table and shuffle its rows.\n",
    "human = (\n",
    "    pd.read_csv('ProtT5_embeddings/Human_preprocessed.csv', index_col=0)\n",
    "    .sample(frac=1)\n",
    ")\n",
    "print(\"%d sequences in Human proteome\" % len(human))\n",
    "\n",
    "\n",
    "# Same treatment for the A. thaliana table.\n",
    "athali = (\n",
    "    pd.read_csv('ProtT5_embeddings/Athaliana_preprocessed.csv', index_col=0)\n",
    "    .sample(frac=1)\n",
    ")\n",
    "print(\"%d sequences in A. thaliana proteome\" % len(athali))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da477716-bf4c-49c3-bf0b-a24e4fda7170",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Pair every species with itself: the first N_pairs rows of a table are the X\n",
    "# samples and the next N_pairs rows are their Y partners, so both members of a\n",
    "# pair always carry the same species label.\n",
    "specs = [human.to_numpy(), athali.iloc[:-1].to_numpy()]  # last A. thaliana row is dropped\n",
    "\n",
    "# One label per entry of `specs`, in the same order (specs[0] is the Human table).\n",
    "s_names = ['Human', 'A. Thaliana']\n",
    "\n",
    "# Each species contributes the same number of pairs, limited by the smaller table.\n",
    "N_pairs = min(len(spec) for spec in specs) // 2\n",
    "\n",
    "# Concatenating the slices directly avoids seeding the stack with a dummy row of\n",
    "# a hard-coded width, so any embedding dimension works.\n",
    "X_data = np.concatenate([spec[:N_pairs] for spec in specs])\n",
    "Y_data = np.concatenate([spec[N_pairs:2 * N_pairs] for spec in specs])\n",
    "labels = np.repeat(s_names, N_pairs)\n",
    "\n",
    "assert len(X_data) == len(Y_data) == len(labels)\n",
    "\n",
    "def generate_dataset(percent_shuffle, N_samples=2 * 10**4):\n",
    "    \"\"\"Draw a random subset of the (X, Y) pairs and break a fraction of them.\n",
    "\n",
    "    Up to `N_samples` pairs are picked at random from the module-level X_data /\n",
    "    Y_data / labels. The first `percent_shuffle` share of the picked X rows is\n",
    "    then shuffled among itself, together with its labels, while Y stays put.\n",
    "\n",
    "    Returns:\n",
    "        (Xs, Ys, Lx, Ly): the X and Y arrays followed by their label arrays.\n",
    "    \"\"\"\n",
    "    order = np.arange(len(X_data))\n",
    "    np.random.shuffle(order)\n",
    "    picked = order[:N_samples]\n",
    "\n",
    "    Xs, Ys = X_data[picked].copy(), Y_data[picked].copy()\n",
    "    Lx, Ly = labels[picked].copy(), labels[picked].copy()\n",
    "\n",
    "    rows_to_shuffle = int(percent_shuffle * len(Xs))\n",
    "\n",
    "    # Replaying the same RNG state before each in-place shuffle gives Xs and Lx\n",
    "    # the same permutation (https://stackoverflow.com/questions/4601373/).\n",
    "    rng_state = np.random.get_state()\n",
    "    for view in (Xs[:rows_to_shuffle], Lx[:rows_to_shuffle]):\n",
    "        np.random.set_state(rng_state)\n",
    "        np.random.shuffle(view)\n",
    "\n",
    "    return Xs, Ys, Lx, Ly\n",
    "\n",
    "# With half of the X rows shuffled and two equally sized species, roughly\n",
    "# 0.5 + 0.5 * 0.5 = 0.75 of the pairs should still share a label.\n",
    "Xs, Ys, Lx, Ly = generate_dataset(0.5)\n",
    "same_label_fraction = sum(Lx == Ly)/len(Lx)\n",
    "print(\"Sanity check (percent identical labels, should be ~0.75): %f\"\n",
    "      % same_label_fraction)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9e75e5a-fba7-42ff-bc68-81d4f816f76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "from functools import partial\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "# helpers functions\n",
    "\n",
    "\n",
    "def exists(x):\n",
    "    return x is not None\n",
    "\n",
    "\n",
    "def default(val, d):\n",
    "    \"\"\"Return `val` unless it is None; otherwise fall back to `d`.\n",
    "\n",
    "    A callable `d` is invoked to produce the fallback, any other `d` is returned as is.\n",
    "    \"\"\"\n",
    "    if not exists(val):\n",
    "        return d() if callable(d) else d\n",
    "    return val\n",
    "\n",
    "\n",
    "def identity(t, *args, **kwargs):\n",
    "    \"\"\"Return `t` unchanged; any extra positional or keyword arguments are ignored.\"\"\"\n",
    "    return t\n",
    "\n",
    "\n",
    "def num_to_groups(num, divisor):\n",
    "    groups = num // divisor\n",
    "    remainder = num % divisor\n",
    "    arr = [divisor] * groups\n",
    "    if remainder > 0:\n",
    "        arr.append(remainder)\n",
    "    return arr\n",
    "\n",
    "\n",
    "# small helper modules\n",
    "\n",
    "class Residual(nn.Module):\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        return self.fn(x, *args, **kwargs) + x\n",
    "\n",
    "\n",
    "def Upsample(dim, dim_out=None):\n",
    "    \"\"\"Build a one-layer `nn.Sequential` holding a Linear from `dim` to `dim_out`.\n",
    "\n",
    "    `dim_out` falls back to `dim` when it is None.\n",
    "    \"\"\"\n",
    "    out_features = default(dim_out, dim)\n",
    "    return nn.Sequential(nn.Linear(dim, out_features))\n",
    "\n",
    "\n",
    "def Downsample(dim, dim_out=None):\n",
    "    \"\"\"Build a Linear layer from `dim` to `dim_out` (`dim` when `dim_out` is None).\"\"\"\n",
    "    out_features = default(dim_out, dim)\n",
    "    return nn.Linear(dim, out_features)\n",
    "\n",
    "\n",
    "# NOTE(review): this re-defines `Residual`, which is already defined earlier in\n",
    "# this cell with the same behaviour; the later definition is the one left bound.\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Skip connection: the module returns `fn(x, ...) + x`.\"\"\"\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        return self.fn(x, *args, **kwargs) + x\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    def __init__(self, dim, dim_out, groups=8, shift_scale=True, dropout_prob=0.1):\n",
    "        super().__init__()\n",
    "        # self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)\n",
    "        self.proj = nn.Linear(dim, dim_out)\n",
    "        self.act = nn.SiLU()\n",
    "        # self.act = nn.Relu()\n",
    "        self.norm = nn.GroupNorm(groups, dim)\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "        # self.norm = nn.BatchNorm1d( dim)\n",
    "        self.shift_scale = shift_scale\n",
    "\n",
    "    def forward(self, x, t=None):\n",
    "        x = self.norm(x)\n",
    "        x = self.act(x)\n",
    "        x = self.dropout(x)\n",
    "        x = self.proj(x)\n",
    "\n",
    "        if exists(t):\n",
    "            if self.shift_scale:\n",
    "                scale, shift = t\n",
    "                x = x * (scale.squeeze() + 1) + shift.squeeze()\n",
    "            else:\n",
    "                x = x + t\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "class ResnetBlock(nn.Module):\n",
    "    \"\"\"Two stacked `Block`s with a residual path, optionally conditioned on a time embedding.\n",
    "\n",
    "    The time embedding, when given, is mapped by a SiLU + Linear head and fed to\n",
    "    the first block only. The residual path is a Linear layer when the input and\n",
    "    output widths differ and an identity otherwise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=32, shift_scale=False, dropout_prob=0.1):\n",
    "        super().__init__()\n",
    "        self.shift_scale = shift_scale\n",
    "\n",
    "        if exists(time_emb_dim):\n",
    "            # Scale-and-shift conditioning needs two values per output feature.\n",
    "            cond_width = dim_out * 2 if shift_scale else dim_out\n",
    "            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, cond_width))\n",
    "        else:\n",
    "            self.mlp = None\n",
    "\n",
    "        block_kwargs = dict(groups=groups, shift_scale=shift_scale, dropout_prob=dropout_prob)\n",
    "        self.block1 = Block(dim, dim_out, **block_kwargs)\n",
    "        self.block2 = Block(dim_out, dim_out, **block_kwargs)\n",
    "\n",
    "        if dim != dim_out:\n",
    "            self.lin_layer = nn.Linear(dim, dim_out)\n",
    "        else:\n",
    "            self.lin_layer = nn.Identity()\n",
    "\n",
    "    def forward(self, x, time_emb=None):\n",
    "        \"\"\"Return `block2(block1(x, cond)) + lin_layer(x)`.\"\"\"\n",
    "        cond = None\n",
    "        if exists(self.mlp) and exists(time_emb):\n",
    "            cond = self.mlp(time_emb)\n",
    "\n",
    "        hidden = self.block1(x, t=cond)\n",
    "        hidden = self.block2(hidden)\n",
    "        return hidden + self.lin_layer(x)\n",
    "\n",
    "\n",
    "class UnetMLP_simple(nn.Module):\n",
    "    \"\"\"MLP \"U-Net\" over two concatenated variables, conditioned on a small vector `t`.\n",
    "\n",
    "    The input of width `2 * dim` is projected to `init_dim`, passed through\n",
    "    residual blocks with skip connections, and projected back to `dim`. The\n",
    "    output projection is zero-initialised, so a fresh network outputs zeros.\n",
    "\n",
    "    Args:\n",
    "        dim: width of one variable; the network input is `2 * dim` wide.\n",
    "        init_dim: hidden width (falls back to `dim` when None).\n",
    "        dim_mults: one multiplier per down/up stage; an empty sequence builds no\n",
    "            stages. NOTE(review): the down blocks keep their input width, so the\n",
    "            skip concatenation only appears to line up when every multiplier\n",
    "            is 1 - confirm before using other values.\n",
    "        resnet_block_groups: GroupNorm group count used throughout.\n",
    "        time_dim: width of the conditioning embedding.\n",
    "        nb_var: number of entries in the conditioning vector `t`.\n",
    "        dropout_prob: dropout probability inside each residual block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        init_dim=128,\n",
    "        dim_mults=(1, 1),\n",
    "        resnet_block_groups=8,\n",
    "        time_dim=128,\n",
    "        nb_var=1,\n",
    "        dropout_prob=0.1\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.nb_var = nb_var\n",
    "        init_dim = default(init_dim, dim)\n",
    "\n",
    "        # Widths at the stage boundaries and the (in, out) pair of every stage.\n",
    "        dims = [init_dim, *(init_dim * m for m in dim_mults)]\n",
    "        in_out = list(zip(dims[:-1], dims[1:]))\n",
    "\n",
    "        block_klass = partial(ResnetBlock, groups=resnet_block_groups, dropout_prob=dropout_prob)\n",
    "\n",
    "        # Modules are created in a fixed order so seeded initialisation is stable.\n",
    "        self.init_lin = nn.Linear(dim * 2, init_dim)\n",
    "\n",
    "        self.time_mlp = nn.Sequential(\n",
    "            nn.Linear(nb_var, time_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(time_dim, time_dim)\n",
    "        )\n",
    "\n",
    "        self.downs = nn.ModuleList([])\n",
    "        self.ups = nn.ModuleList([])\n",
    "\n",
    "        for stage_in, _ in in_out:\n",
    "            self.downs.append(nn.ModuleList([\n",
    "                block_klass(stage_in, stage_in, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        # Registered as a sub-module (so it is optimised and checkpointed) but\n",
    "        # never called in forward().\n",
    "        mid_dim = dims[-1]\n",
    "        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        for stage_in, stage_out in reversed(in_out):\n",
    "            self.ups.append(nn.ModuleList([\n",
    "                block_klass(stage_out + stage_in, stage_out, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        # Width entering the first stage, or `dim` when there are no stages.\n",
    "        self.out_dim = in_out[0][0] if in_out else dim\n",
    "\n",
    "        self.final_res_block = block_klass(\n",
    "            init_dim * 2, init_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        self.proj = nn.Linear(init_dim, dim)\n",
    "\n",
    "        self.proj.weight.data.fill_(0.0)\n",
    "        self.proj.bias.data.fill_(0.0)\n",
    "\n",
    "        self.final_lin = nn.Sequential(\n",
    "            nn.GroupNorm(resnet_block_groups, init_dim),\n",
    "            nn.SiLU(),\n",
    "            self.proj\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t=None, std=None):\n",
    "        \"\"\"Run the network.\n",
    "\n",
    "        Args:\n",
    "            x: input whose last dimension is `2 * dim`; it is cast to float.\n",
    "            t: conditioning values, reshaped to `(batch, nb_var)`. It is required\n",
    "                even though the signature gives it a default of None.\n",
    "            std: optional divisor applied to the output.\n",
    "\n",
    "        Returns:\n",
    "            A tensor whose last dimension is `dim`.\n",
    "        \"\"\"\n",
    "        t = t.reshape(t.size(0), self.nb_var)\n",
    "\n",
    "        x = self.init_lin(x.float())\n",
    "        r = x.clone()  # kept for the final skip connection\n",
    "\n",
    "        t = self.time_mlp(t).squeeze()\n",
    "\n",
    "        skips = []\n",
    "        for blocks in self.downs:\n",
    "            x = blocks[0](x, t)\n",
    "            skips.append(x)\n",
    "\n",
    "        for blocks in self.ups:\n",
    "            x = torch.cat((x, skips.pop()), dim=1)\n",
    "            x = blocks[0](x, t)\n",
    "\n",
    "        x = torch.cat((x, r), dim=1)\n",
    "        x = self.final_res_block(x, t)\n",
    "\n",
    "        out = self.final_lin(x)\n",
    "        # Identity check: `!= None` is ambiguous for array-like `std` values.\n",
    "        return out / std if std is not None else out\n",
    "\n",
    "# Feature width of one variable, taken from the dataset generated above.\n",
    "dim = Xs.shape[1] \n",
    "\n",
    "# Two variables, \"x0\" and \"x1\", each `dim` wide.\n",
    "var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "# NOTE(review): var_list, sizes and hidden_dim are not read by any later visible\n",
    "# cell (var_list and sizes are rebuilt before the network is created).\n",
    "var_list = list(var_list_0.keys())\n",
    "sizes = list(var_list_0.values())\n",
    "hidden_dim = 128\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58dfa056-ef4b-41e7-9f19-3b1d5dbec16c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "device = 'cuda'\n",
    "\n",
    "class CondMLPTimeMINDE(nn.Module):\n",
    "    \n",
    "    def __init__(self, minde_mlp, plan=True):\n",
    "        super().__init__()\n",
    "        self.net = minde_mlp\n",
    "        self.plan = plan\n",
    "        \n",
    "    def forward(self, x, x_0, time):\n",
    "        # print(torch.ones(x.shape[0]))\n",
    "        if self.plan:\n",
    "            \n",
    "            return self.net(torch.cat([x, x_0], dim=-1), torch.cat([time.unsqueeze(1), torch.ones(x.shape[0], 1).to(device)], dim=1) )\n",
    "        else:\n",
    "            \n",
    "            return self.net(torch.cat([x, x_0], dim=-1), torch.cat([time.unsqueeze(1), -torch.ones(x.shape[0], 1).to(device)], dim=1) )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "862a06c4-42c5-4f2b-b384-541fc2f520a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "class BridgeMathcing(nn.Module):\n",
    "    def __init__(self, unet, eps, predict_type='vector_field', loss_weight=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.loss_weight = loss_weight\n",
    "\n",
    "        assert predict_type in ['vector_field', 'x_1', 'noise']\n",
    "        \n",
    "        self.predict_type = predict_type\n",
    "        self.vector_net = unet\n",
    "        \n",
    "        self.eps = eps\n",
    "        \n",
    "    def forward(self, x_0):\n",
    "        # solve forward ODE via Euler or torchdiffeq solver\n",
    "        x_t = x_0\n",
    "        \n",
    "        t_range = tqdm(torch.arange(0, 1, step=self.euler_dt))\n",
    "        \n",
    "        for t in t_range:\n",
    "            eps_noise = torch.randn_like(x_t, device=x_0.device)\n",
    "            x_t = x_t + self.vector_net(x_t, x_0, t.squeeze()) * self.euler_dt + torch.sqrt(self.euler_dt * self.eps) * eps_noise\n",
    "        \n",
    "        return x_t\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample(self, x_0, nfe=100, pbar=True):\n",
    "\n",
    "        euler_dt = 1. / nfe\n",
    "        \n",
    "        x_t = x_0\n",
    "        \n",
    "        if pbar:\n",
    "            t_range = tqdm(torch.arange(0, 1, step=euler_dt).to(x_0.device))\n",
    "        else:\n",
    "            t_range = torch.arange(0, 1, step=euler_dt).to(x_0.device)\n",
    "        \n",
    "        for t in t_range:\n",
    "            eps_noise = torch.randn_like(x_t, device=x_0.device)\n",
    "            \n",
    "            x_t = x_t + self.vector_net(x_t, x_0, t.squeeze()) * euler_dt + math.sqrt(euler_dt * self.eps) * eps_noise\n",
    "            \n",
    "        return x_t\n",
    "\n",
    "    def sample_x_t(self, x_0, x_1, t):\n",
    "\n",
    "        coef_0, coef_1 = 1 - t, t\n",
    "\n",
    "        std_t = torch.sqrt(t * (1 - t) * self.eps)\n",
    "\n",
    "        z = torch.randn_like(x_0, device=x_0.device)\n",
    "        \n",
    "        x_t = coef_1.reshape([-1, 1]) * x_1 + coef_0.reshape([-1, 1]) * x_0 + z * std_t.reshape([-1, 1])\n",
    "        return x_t, z\n",
    "    \n",
    "    def step(self, x_0, x_1, t):\n",
    "        t = t.reshape([-1, 1])\n",
    "        x_t, z = self.sample_x_t(x_0, x_1, t)\n",
    "        x_t_hat = self.vector_net(x_t, x_0, t.squeeze())\n",
    "        return self.loss(x_t_hat, x_1, x_0, x_t, t, z).mean()\n",
    "    \n",
    "    def loss(self, x_t_hat, x_1, x_0, x_t, t, z):\n",
    "\n",
    "        if self.predict_type == 'x_1':\n",
    "\n",
    "            vector_field = (x_t_hat - (x_1 - x_t) / (1 - t)).reshape([x_1.shape[0], -1])\n",
    "\n",
    "            if self.loss_weight:\n",
    "\n",
    "                with torch.no_grad():\n",
    "                    \n",
    "                    z_coef = torch.abs((vector_field**2 - (((x_1 - x_t) / (1 - t)).reshape([x_1.shape[0], -1]) * vector_field)).mean(-1))\n",
    "\n",
    "                    if self.loss_weight == 'mean':\n",
    "                        norm_const = z_coef.mean()\n",
    "                    elif self.loss_weight == 'sum':\n",
    "                        norm_const = z_coef.sum()\n",
    "                    else:\n",
    "                        raise RuntimeError('Unknown loss weight')\n",
    "                    \n",
    "                return torch.norm((x_t_hat - x_1).reshape([x_1.shape[0], -1]), dim=-1) * z_coef / norm_const\n",
    "            \n",
    "            return torch.norm((x_t_hat - x_1).reshape([x_1.shape[0], -1]), dim=-1)\n",
    "        \n",
    "        if self.predict_type == 'vector_field':\n",
    "\n",
    "            \n",
    "\n",
    "            # print(x_t_hat.shape, x_1.shape, vector_field.shape, ((x_1 - x_t) / (1 - t)).shape)\n",
    "            \n",
    "            if self.loss_weight:\n",
    "                \n",
    "                vector_field = (x_t_hat - (x_1 - x_t) / (1 - t)).reshape([x_1.shape[0], -1])\n",
    "\n",
    "                with torch.no_grad():\n",
    "                    \n",
    "                    z_coef = torch.abs((vector_field**2 - (((x_1 - x_t) / (1 - t)).reshape([x_1.shape[0], -1]) * vector_field)).mean(-1))\n",
    "\n",
    "                    if self.loss_weight == 'mean':\n",
    "                        norm_const = z_coef.mean()\n",
    "                    elif self.loss_weight == 'sum':\n",
    "                        norm_const = z_coef.sum()\n",
    "                    else:\n",
    "                        raise RuntimeError('Unknown loss weight')\n",
    "                    \n",
    "                return torch.norm((x_t_hat - ((x_1 - x_t) / (1 - t))).reshape([x_1.shape[0], -1]), dim=-1) * z_coef / norm_const\n",
    "\n",
    "            # print(x_t_hat.shape, vector_field.shape, ((x_1 - x_t) / (1 - t)).shape)\n",
    "            \n",
    "            return torch.norm((x_t_hat - ((x_1 - x_t) / (1 - t))).reshape([x_1.shape[0], -1]), dim=-1)\n",
    "\n",
    "        if self.predict_type == 'noise':\n",
    "            \n",
    "            return torch.norm((x_t_hat - z).reshape([x_1.shape[0], -1]), dim=-1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8171d3a7-da98-4907-980c-1ab7f236b445",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def estimate_kl(drift_1, drift_2, eps, test_loader, t_eps=1e-3, posterior='partial', predict_type='vector_field', n_repeat=10):\n",
    "    \n",
    "    # inference\n",
    "    \n",
    "    def BB_sample(x_1, x_0, t):\n",
    "        mean = x_1 * t + x_0 * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "\n",
    "        return mean + std * torch.randn_like(x_1)\n",
    "    \n",
    "    kl_value = 0\n",
    "    \n",
    "    fn = lambda x: x\n",
    "\n",
    "    n_elems = 0\n",
    "\n",
    "    for k in range(n_repeat):\n",
    "        \n",
    "        for item in iter(test_loader):\n",
    "            \n",
    "            samples_x = item[0].to(device)\n",
    "            samples_y = item[1].to(device)\n",
    "            \n",
    "            n_elems += samples_x.shape[0]\n",
    "    \n",
    "            x_batch = samples_x\n",
    "            y_batch = samples_y\n",
    "            \n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
    "            \n",
    "            # print(x_batch.shape, y_batch.shape)\n",
    "            \n",
    "            t = (torch.rand(x_batch.shape[0]).to(device)) * (1 - t_eps)\n",
    "            \n",
    "            t = t.reshape([-1, 1])\n",
    "                    \n",
    "            x_t = BB_sample(y_batch, x_batch, t)\n",
    "    \n",
    "            if predict_type == 'vector_field':\n",
    "                \n",
    "                v_1, v_2 = drift_1(x_t, x_batch, t.squeeze()), drift_2(x_t, x_batch, t.squeeze())\n",
    "                kl_value += ((fn( v_1 ) - fn( v_2 ) )**2).sum([-1]).sum()\n",
    "                \n",
    "            elif predict_type == 'x_1':\n",
    "                \n",
    "                v_1, v_2 = drift_1(x_t, x_batch, t.squeeze()), drift_2(x_t, x_batch, t.squeeze())\n",
    "                \n",
    "                kl_value += ( ( (fn( v_1 ) - fn( v_2 ) ) / (1 - t) )**2).sum([-1]).sum()\n",
    "                \n",
    "                # kl_value += ( ( (fn( v_1 ) - fn( v_2 ) ) / (1 - t) )**2).sum([-1, -2]).sum()\n",
    "    \n",
    "            elif predict_type == 'noise':\n",
    "                \n",
    "                z_predict_1, z_predict_2 = drift_1(x_t, x_batch, t.squeeze()), drift_2(x_t, x_batch, t.squeeze())\n",
    "    \n",
    "                x_1_predict_1 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * z_predict_1) / t\n",
    "                \n",
    "                x_1_predict_2 = (x_t - (1 - t) * x_batch - torch.sqrt(t * (1 - t) * eps) * z_predict_2) / t\n",
    "                \n",
    "                v_1, v_2 = (x_1_predict_1), (x_1_predict_2)\n",
    "                \n",
    "                kl_value += ( ( (fn( v_1 ) - fn( v_2 ) ) / (1 - t) )**2).sum([-1]).sum()\n",
    "            \n",
    "            \n",
    "\n",
    "        # print( 1 / (2 * eps) * kl_value / n_elems )\n",
    "        \n",
    "    return 1 / (2 * eps) * kl_value / n_elems\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccf2157e-9670-40be-b8ad-7c0b2892d682",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Two variables, \"x0\" and \"x1\", each `dim` wide. NOTE(review): var_list_0,\n",
    "# var_list and sizes are not read again in the visible cells.\n",
    "var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "var_list = list(var_list_0.keys())\n",
    "sizes = list(var_list_0.values())\n",
    "\n",
    "# One network serves both drifts. nb_var=2 because the conditioning vector is\n",
    "# [time, +/-1 flag] (see CondMLPTimeMINDE); dim_mults=[] builds no down/up stages.\n",
    "net_1 = UnetMLP_simple(dim=dim, init_dim=n_filters, dim_mults=[],\n",
    "                            time_dim=time_embed, nb_var=2, dropout_prob=dropout_prob)\n",
    "\n",
    "# Both wrappers share net_1's weights and differ only in the flag they append.\n",
    "drift_net_1 = CondMLPTimeMINDE(net_1, plan=True).to(device)\n",
    "\n",
    "drift_net_2 = CondMLPTimeMINDE(net_1, plan=False).to(device)\n",
    "\n",
    "\n",
    "from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "# A single EMA tracker over the shared parameters; ema_g_bm_2 is an alias of the\n",
    "# same object, not a second tracker.\n",
    "ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "ema_g_bm_2 = ema_g_bm_1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c15d5e8-b161-4726-bbb3-74457f06abcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Bridge-matching objective for the drift with the +1 flag.\n",
    "bm_1 = BridgeMathcing(drift_net_1, eps=eps, predict_type=predict_type, loss_weight=loss_weight)\n",
    "\n",
    "# A single optimizer over net_1, whose weights both drift wrappers share.\n",
    "# opt = torch.optim.AdamW(net_1.parameters(), lr=lr, 0.1)\n",
    "opt = torch.optim.Adam(net_1.parameters(), lr=lr, weight_decay=wd_reg)\n",
    "\n",
    "# Bridge-matching objective for the drift with the -1 flag.\n",
    "bm_2 = BridgeMathcing(drift_net_2, eps=eps, predict_type=predict_type, loss_weight=loss_weight) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "26816f6f-69f3-453e-b824-3bc893229a65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training times are drawn uniformly from [0, 1 - t_eps), which keeps the\n",
    "# (1 - t) divisor in the loss away from zero.\n",
    "t_eps = 1e-3\n",
    "# Returns `batch_size` time values on `device`.\n",
    "sample_fn = lambda batch_size: (torch.rand(batch_size).to(device)) * (1 - t_eps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d437938a-0cef-436f-8e9a-94e8041f642e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Regenerate the dataset with the configured shuffle rate; this replaces the\n",
    "# Xs / Ys / Lx / Ly produced by the sanity check.\n",
    "Xs, Ys, Lx, Ly = generate_dataset(shuffle_coef)\n",
    "\n",
    "# Standardise every feature column using statistics of the full sample (before\n",
    "# the train/validation split). nan_to_num replaces the NaNs that a\n",
    "# zero-variance column would produce.\n",
    "Xs = np.nan_to_num((Xs - Xs.mean(axis=0)) / Xs.std(axis=0))\n",
    "Ys = np.nan_to_num((Ys - Ys.mean(axis=0)) / Ys.std(axis=0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "999e3777-16ac-478b-bdf0-d9cd939d226e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# First 90% of the rows are used for training, the remaining rows for validation.\n",
    "n_train = int(Xs.shape[0] * 0.9)\n",
    "\n",
    "train_dataset = TensorDataset(torch.Tensor(Xs[:n_train]), torch.Tensor(Ys[:n_train]))\n",
    "val_dataset = TensorDataset(torch.Tensor(Xs[n_train:]), torch.Tensor(Ys[n_train:]))\n",
    "\n",
    "# Both loaders reshuffle on every pass.\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574c3d0b-4a74-4335-963f-1440da4ef818",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# DISCRETE ESTIMATORS\n",
    "def entropyd(sx, base=2):\n",
    "    \"\"\"Discrete entropy estimator\n",
    "    sx is a list of samples\n",
    "    \"\"\"\n",
    "    unique, count = np.unique(sx, return_counts=True, axis=0)\n",
    "    # Convert to float as otherwise integer division results in all 0 for proba.\n",
    "    proba = count.astype(float) / len(sx)\n",
    "    # Avoid 0 division; remove probabilities == 0.0 (removing them does not change the entropy estimate as 0 * log(1/0) = 0.\n",
    "    proba = proba[proba > 0.0]\n",
    "    return np.sum(proba * np.log(1.0 / proba)) / np.log(base)\n",
    "\n",
    "def centropyd(x, y, base=2):\n",
    "    \"\"\"Discrete conditional entropy H(X | Y) = H(X, Y) - H(Y), in units of log `base`.\"\"\"\n",
    "    joint = np.c_[x, y]\n",
    "    joint_entropy = entropyd(joint, base)\n",
    "    marginal_entropy = entropyd(y, base)\n",
    "    return joint_entropy - marginal_entropy\n",
    "\n",
    "\n",
    "def midd(x, y, base=2):\n",
    "    \"\"\"Discrete mutual information I(X; Y) = H(X) - H(X | Y), in units of log `base`.\n",
    "\n",
    "    `x` and `y` are equally long sequences of samples.\n",
    "    \"\"\"\n",
    "    assert len(x) == len(y), \"Arrays should have same length\"\n",
    "    marginal_entropy = entropyd(x, base)\n",
    "    conditional_entropy = centropyd(x, y, base)\n",
    "    return marginal_entropy - conditional_entropy\n",
    "\n",
    "\n",
    "# NOTE(review): midd() is called with its default base=2, so it already returns\n",
    "# bits; after the extra division by ln(2) `mi_gt` is neither bits nor nats.\n",
    "# Verify the intended unit before comparing it with other estimates.\n",
    "mi_gt = midd(Lx, Ly) / np.log(2)\n",
    "\n",
    "# Shown as the cell output.\n",
    "'GT MI: ', mi_gt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d6efaa7-32b2-4fe3-a6e7-6b5f6c5c7336",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "# `mi_gt` is midd()'s base-2 value divided by ln(2); multiplying by ln(2)**2\n",
    "# leaves bits * ln(2), i.e. the discrete mutual information in nats.\n",
    "mi_gt_nats = mi_gt * np.log(2)**2\n",
    "\n",
    "# Hyper-parameters recorded with the run.\n",
    "wandb_config = {'mi_gt': mi_gt_nats, 'n_filters': n_filters, 'time_embed': time_embed,\n",
    "                'ema_decay': ema_decay, 'eps': eps, 'predict_type': predict_type, \n",
    "               'lr': lr, 'shuffle_coef': shuffle_coef, 'n_epochs': n_epochs,\n",
    "                'wd_reg': wd_reg, 'seed': seed, 'dropout_prob': dropout_prob}\n",
    "\n",
    "# The run is named after the ground-truth value in nats.\n",
    "wandb.init(project='ProtTrans5', name=f'MI_{mi_gt_nats}', config=wandb_config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "7e3f8983-f6d3-417d-8ddc-d284c5ef356d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RunningMeanLast5:\n",
    "    def __init__(self):\n",
    "        self.values = []\n",
    "        self.mean = 0.0\n",
    "\n",
    "    def add_value(self, value):\n",
    "        self.values.append(value)\n",
    "        if len(self.values) > 5:\n",
    "            self.values.pop(0)  # Remove the oldest value if more than 5 values are present\n",
    "        self.mean = sum(self.values) / len(self.values)\n",
    "\n",
    "    def get_mean(self):\n",
    "        return self.mean\n",
    "\n",
    "    def clear(self):\n",
    "        self.values = []\n",
    "        self.mean = 0.0\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c5a736-cee0-4b13-b9d7-1e84343bc0fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import ExitStack\n",
    "\n",
    "import wandb\n",
    "\n",
    "rm = RunningMeanLast5()\n",
    "\n",
    "# ema_g_bm_1 and ema_g_bm_2 may be the same object. Entering\n",
    "# average_parameters() twice on one tracker lets the inner call overwrite the\n",
    "# stored training weights with the averaged ones, so the model would leave the\n",
    "# block carrying EMA weights. Swap each distinct tracker in exactly once.\n",
    "ema_trackers = list({id(ema): ema for ema in (ema_g_bm_1, ema_g_bm_2)}.values())\n",
    "\n",
    "\n",
    "def _evaluate_mi(loader):\n",
    "    \"\"\"Estimate between the two drifts on `loader`, returned as a plain float.\"\"\"\n",
    "    estimate = estimate_kl(drift_net_1, drift_net_2, eps, loader, t_eps=t_eps,\n",
    "                           posterior='partial', predict_type=predict_type, n_repeat=10)\n",
    "    return float(estimate)\n",
    "\n",
    "\n",
    "def _optimizer_step(loss):\n",
    "    \"\"\"One gradient update of the shared optimizer on `loss`.\"\"\"\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "\n",
    "net_1.train()\n",
    "\n",
    "for j in range(n_epochs):\n",
    "\n",
    "    for item in train_loader:\n",
    "\n",
    "        x_samples = item[0].to(device)\n",
    "        y_samples = item[1].to(device)\n",
    "\n",
    "        # Drift 1: bridge from x to its paired y.\n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        _optimizer_step(bm_1.step(x_samples, y_samples, t))\n",
    "        ema_g_bm_1.update()\n",
    "\n",
    "        # Drift 2: same x, partners permuted within the batch.\n",
    "        y_samples_permuted = y_samples[torch.randperm(y_samples.shape[0])]\n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        _optimizer_step(bm_2.step(x_samples, y_samples_permuted, t))\n",
    "        ema_g_bm_2.update()\n",
    "\n",
    "    # All estimates are taken in eval mode so dropout does not perturb them.\n",
    "    net_1.eval()\n",
    "\n",
    "    # Raw (non-averaged) weights first, while they are still in place.\n",
    "    mutual_entropy_est = _evaluate_mi(val_loader)\n",
    "\n",
    "    with ExitStack() as stack:\n",
    "        for ema in ema_trackers:\n",
    "            stack.enter_context(ema.average_parameters())\n",
    "        mutual_entropy_est_ema = _evaluate_mi(val_loader)\n",
    "        mutual_entropy_est_ema_train = _evaluate_mi(train_loader)\n",
    "\n",
    "    net_1.train()\n",
    "\n",
    "    rm.add_value(mutual_entropy_est_ema)\n",
    "\n",
    "    print(f'Epoch {j} MI: {mutual_entropy_est} EMA MI {mutual_entropy_est_ema}')\n",
    "\n",
    "    wandb.log({'MI': mutual_entropy_est, 'MI_est': mutual_entropy_est_ema, 'MI EMA train': mutual_entropy_est_ema_train})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43497a58-6d72-46f2-98cd-4b4798df78e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import os.path as osp\n",
    "\n",
    "# float() turns a 0-d tensor mean into a plain number, so the log line holds a\n",
    "# parseable value instead of a tensor repr.\n",
    "final_mi = float(rm.get_mean())\n",
    "\n",
    "dir_path = f'logs/infobridge/eps_{eps}_seed_{seed}'\n",
    "\n",
    "# exist_ok avoids the check-then-create race when several runs start together.\n",
    "os.makedirs(dir_path, exist_ok=True)\n",
    "\n",
    "# NOTE(review): `mi_gt` is the value computed before the nats conversion that is\n",
    "# applied for wandb - confirm which unit downstream readers of this file expect.\n",
    "log_path = osp.join(dir_path, f'MI_{mi_gt}.txt')\n",
    "\n",
    "with open(log_path, 'w') as f:\n",
    "\n",
    "    f.write(f\"Shuffle Rate: {shuffle_coef}\\n\")\n",
    "    f.write(f\"DMI: {mi_gt}\\n\")\n",
    "    f.write(f\"Infobridge MI: {final_mi}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b13e74cd-c1a6-40fc-8e03-007116bbd751",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base_env",
   "language": "python",
   "name": "base_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
