{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import jax\n",
    "jax.config.update('jax_platforms', 'cpu')\n",
    "import bmi\n",
    "from src.libs.minde import MINDE\n",
    "from src.scripts.helper import get_data_loader, get_default_config\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jax.__version__\n",
    "# import jaxlib\n",
    "# jaxlib.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install jaxlib==0.4.30 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# !pip install jax[cuda]==0.4.30 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Registry of the available BMI benchmark tasks, indexed by task name (see `name_task`).\n",
    "bmi.benchmark.BENCHMARK_TASKS\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run parameters (this cell is tagged `parameters`, e.g. for papermill injection).\n",
    "# Key into bmi.benchmark.BENCHMARK_TASKS.\n",
    "name_task = \"half_cube-multinormal-sparse-25-25-2-2.0\"\n",
    "# name_task=\"swissroll_x-normal_cdf-1v1-normal-0.75\"\n",
    "\n",
    "# Brownian-bridge noise level: the bridge std at time t is sqrt(t * (1 - t) * eps).\n",
    "eps = 1\n",
    "\n",
    "# Decay of the exponential moving average kept over the drift-network weights.\n",
    "ema_decay = 0.999\n",
    "\n",
    "n_epochs = 50\n",
    "\n",
    "# Target regressed by the drift network: 'vector_field', 'x_1' or 'noise'.\n",
    "# predict_type = 'x_1'\n",
    "\n",
    "predict_type = 'vector_field'\n",
    "\n",
    "# Seeds args.seed and seed_everything; also part of the log file name.\n",
    "seed = 42\n",
    "\n",
    "# Learning rate of the Adam optimizer over the drift network.\n",
    "lr = 3e-4\n",
    "\n",
    "# Per-sample loss weighting: None/False disables it; 'mean' or 'sum' picks the normaliser.\n",
    "loss_weight = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the benchmark task; a typo in name_task fails here with the valid names.\n",
    "# Other tasks tried: \"1v1-normal-0.75\", \"student-identity-5-5-2\",\n",
    "# \"spiral-multinormal-sparse-25-25-2-2.0\", \"multinormal-dense-5-5-0.5\".\n",
    "if name_task not in bmi.benchmark.BENCHMARK_TASKS:\n",
    "    raise KeyError(\n",
    "        f\"Unknown task {name_task!r}; available: {sorted(bmi.benchmark.BENCHMARK_TASKS)}\"\n",
    "    )\n",
    "task = bmi.benchmark.BENCHMARK_TASKS[name_task]\n",
    "task.mutual_information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project defaults from get_default_config(), overridden for this run; `args` is\n",
    "# then passed to get_data_loader together with the task.\n",
    "args = get_default_config()\n",
    "args.type =\"c\"\n",
    "args.test_epoch=2\n",
    "# NOTE(review): max_epochs (100) differs from n_epochs in the parameters cell;\n",
    "# confirm which one the training loop actually uses.\n",
    "args.max_epochs=100\n",
    "args.warmup_epochs = 0\n",
    "args.bs = 512\n",
    "# NOTE(review): the Adam optimizer later in the notebook is built with `lr` from\n",
    "# the parameters cell, not with args.lr.\n",
    "args.lr = 1e-4\n",
    "args.arch = \"mlp\"\n",
    "args.importance_sampling = True\n",
    "args.use_ema = True\n",
    "args.seed = seed\n",
    "\n",
    "# NOTE(review): seed_everything is only called in a later cell, so randomness used\n",
    "# here is not covered by it (unless get_data_loader seeds itself from args.seed).\n",
    "train_l,test_l = get_data_loader(args,task)\n",
    "\n",
    "# Feature dimension of the 'x' variable, read from the first training batch.\n",
    "dim = next(iter(train_l))['x'].shape[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Network width grows with the data dimension; the time embedding uses the same width.\n",
    "if dim > 25:\n",
    "    time_embed = n_filters = 256\n",
    "elif dim > 5:\n",
    "    time_embed = n_filters = 128\n",
    "else:\n",
    "    time_embed = n_filters = 64\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "default_seed = 42\n",
    "\n",
    "\n",
    "def seed_basic(seed=default_seed):\n",
    "    \"\"\"Seed Python's `random` module and export PYTHONHASHSEED.\"\"\"\n",
    "    import os\n",
    "    import random\n",
    "\n",
    "    random.seed(seed)\n",
    "    # Hash randomisation is fixed at interpreter start-up, so this mainly\n",
    "    # affects child processes launched afterwards.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "\n",
    "\n",
    "def seed_numpy(seed=default_seed):\n",
    "    \"\"\"Seed numpy's legacy global RNG (numpy.random.seed).\"\"\"\n",
    "    from numpy import random as np_random\n",
    "\n",
    "    np_random.seed(seed)\n",
    "\n",
    "\n",
    "def seed_tensorflow(seed=default_seed):\n",
    "    \"\"\"Seed TensorFlow's global RNG; raises ImportError when TensorFlow is absent.\"\"\"\n",
    "    import tensorflow as tf\n",
    "\n",
    "    tf.random.set_seed(seed)\n",
    "\n",
    "\n",
    "def seed_torch(seed=default_seed):\n",
    "    \"\"\"Seed torch's CPU and CUDA RNGs and request deterministic cuDNN behaviour.\"\"\"\n",
    "    import torch\n",
    "\n",
    "    for seeder in (torch.manual_seed, torch.cuda.manual_seed):\n",
    "        seeder(seed)\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "\n",
    "def seed_everything(seed=default_seed,\n",
    "                    to_be_seeded=(\"basic\", \"numpy\", \"tensorflow\", \"torch\")):\n",
    "    \"\"\"Best-effort seeding of each requested RNG backend.\n",
    "\n",
    "    Every name in `to_be_seeded` is dispatched to the module-level function\n",
    "    `seed_<name>`. A missing seeder, or one that fails (e.g. TensorFlow not\n",
    "    installed), is reported and skipped so the remaining backends still get seeded.\n",
    "    The tuple default avoids sharing one mutable list between calls.\n",
    "    \"\"\"\n",
    "    for name in to_be_seeded:\n",
    "        seeder = globals().get(\"seed_\" + name)\n",
    "        if seeder is None:\n",
    "            print(f\"seed_everything: no function 'seed_{name}', skipping\")\n",
    "            continue\n",
    "        try:\n",
    "            seeder(seed)\n",
    "        except Exception as exception:\n",
    "            print(f\"seed_everything: could not seed {name!r}: {exception!r}\")\n",
    "\n",
    "\n",
    "# NOTE(review): get_data_loader already ran in an earlier cell, so randomness it\n",
    "# consumed is not covered by this call (unless it seeds itself from args.seed).\n",
    "seed_everything(seed=seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\t\n",
    "# train_l,test_l = get_data_loader(args,task)\n",
    "# model = MINDE(args,var_list={\"x\":task.dim_x,\"y\":task.dim_y}, gt = task.mutual_information)\n",
    "# model.fit(train_l,test_l)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import wandb\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Log every hyper-parameter that distinguishes runs, including the ones that also\n",
    "# appear in log_path (predict_type, loss_weight) and the time-embedding width.\n",
    "wandb_config = {\n",
    "    'dim': dim,\n",
    "    'eps': eps,\n",
    "    'task': name_task,\n",
    "    'ema_decay': ema_decay,\n",
    "    'n_epochs': n_epochs,\n",
    "    'gt': task.mutual_information,\n",
    "    'seed': args.seed,\n",
    "    'n_filters': n_filters,\n",
    "    'lr': lr,\n",
    "    'predict_type': predict_type,\n",
    "    'loss_weight': loss_weight,\n",
    "    'time_embed': time_embed,\n",
    "}\n",
    "\n",
    "wandb.init(project=\"Bridge_MI\", name=f\"{name_task}\", config=wandb_config)\n",
    "\n",
    "wandb.log({'GT MI': task.mutual_information})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: dimensions of the task's X and Y variables.\n",
    "task.dim_x, task.dim_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "\n",
    "# Results go to log_bmi/<model>/<task>/<hyper-parameters>/seed_<seed>.txt.\n",
    "# os.makedirs creates every missing level, including the top-level 'log_bmi'\n",
    "# directory itself, and exist_ok=True makes re-running this cell harmless.\n",
    "log_path = os.path.join(\n",
    "    'log_bmi',\n",
    "    f'n_feat_{n_filters}_predict_{predict_type}',\n",
    "    f'{name_task}',\n",
    "    f'predict_{predict_type}_loss_weight_{loss_weight}_lr_{lr}_eps_{eps}_ema_{ema_decay}_n_epochs_{n_epochs}_n_filters_{n_filters}',\n",
    ")\n",
    "os.makedirs(log_path, exist_ok=True)\n",
    "\n",
    "log_path = os.path.join(log_path, f'seed_{seed}.txt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "from functools import partial\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "# helpers functions\n",
    "\n",
    "\n",
    "def exists(x):\n",
    "    \"\"\"True when `x` is not None.\"\"\"\n",
    "    return x is not None\n",
    "\n",
    "\n",
    "def default(val, d):\n",
    "    \"\"\"Return `val`, or the fallback `d` (called first if callable) when `val` is None.\"\"\"\n",
    "    if val is None:\n",
    "        return d() if callable(d) else d\n",
    "    return val\n",
    "\n",
    "\n",
    "def identity(t, *args, **kwargs):\n",
    "    \"\"\"Return `t` unchanged, ignoring any extra arguments.\"\"\"\n",
    "    return t\n",
    "\n",
    "\n",
    "def num_to_groups(num, divisor):\n",
    "    \"\"\"Split `num` into chunks of `divisor`, with any remainder as a final chunk.\"\"\"\n",
    "    full, rest = divmod(num, divisor)\n",
    "    chunks = [divisor for _ in range(full)]\n",
    "    if rest > 0:\n",
    "        chunks.append(rest)\n",
    "    return chunks\n",
    "\n",
    "\n",
    "# small helper modules\n",
    "\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Add an identity skip connection around `fn`: returns fn(x, ...) + x.\"\"\"\n",
    "\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        branch = self.fn(x, *args, **kwargs)\n",
    "        return branch + x\n",
    "\n",
    "\n",
    "def Upsample(dim, dim_out=None):\n",
    "    \"\"\"Linear stand-in for upsampling: dim -> dim_out (defaults to dim), in an nn.Sequential.\"\"\"\n",
    "    out_features = dim if dim_out is None else dim_out\n",
    "    return nn.Sequential(nn.Linear(dim, out_features))\n",
    "\n",
    "\n",
    "def Downsample(dim, dim_out=None):\n",
    "    \"\"\"Linear stand-in for downsampling: dim -> dim_out (defaults to dim).\"\"\"\n",
    "    out_features = dim if dim_out is None else dim_out\n",
    "    return nn.Linear(dim, out_features)\n",
    "\n",
    "\n",
    "# `Residual` is defined a single time, earlier in this cell.\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"Pre-activation MLP block: GroupNorm -> SiLU -> Linear, optionally conditioned.\n",
    "\n",
    "    Args:\n",
    "        dim: input width (GroupNorm requires it to be divisible by `groups`).\n",
    "        dim_out: output width.\n",
    "        groups: number of GroupNorm groups.\n",
    "        shift_scale: if True the conditioning is applied as\n",
    "            x * (scale + 1) + shift; otherwise it is simply added to x.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, groups=8, shift_scale=True):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Linear(dim, dim_out)\n",
    "        self.act = nn.SiLU()\n",
    "        self.norm = nn.GroupNorm(groups, dim)\n",
    "        self.shift_scale = shift_scale\n",
    "\n",
    "    def forward(self, x, t=None):\n",
    "        \"\"\"Apply the block; `t` is None, a (scale, shift) pair, or a single tensor.\n",
    "\n",
    "        With `shift_scale`, a single tensor (e.g. the `dim_out * 2` embedding that\n",
    "        ResnetBlock produces) is split in half along its last axis.\n",
    "        \"\"\"\n",
    "        x = self.proj(self.act(self.norm(x)))\n",
    "\n",
    "        if t is None:\n",
    "            return x\n",
    "\n",
    "        if not self.shift_scale:\n",
    "            return x + t\n",
    "\n",
    "        if isinstance(t, (tuple, list)):\n",
    "            scale, shift = t\n",
    "        else:\n",
    "            # Unpacking a tensor directly would split it along the batch axis.\n",
    "            scale, shift = t.chunk(2, dim=-1)\n",
    "        return x * (scale.squeeze() + 1) + shift.squeeze()\n",
    "\n",
    "\n",
    "class ResnetBlock(nn.Module):\n",
    "    \"\"\"Two stacked `Block`s plus a skip connection, optionally time-conditioned.\n",
    "\n",
    "    A given time embedding passes through SiLU + Linear and conditions only the\n",
    "    first block. The skip path is a Linear map when `dim != dim_out`, else identity.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=32, shift_scale=False):\n",
    "        super().__init__()\n",
    "        self.shift_scale = shift_scale\n",
    "        if time_emb_dim is None:\n",
    "            self.mlp = None\n",
    "        else:\n",
    "            emb_width = dim_out * 2 if shift_scale else dim_out\n",
    "            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, emb_width))\n",
    "\n",
    "        self.block1 = Block(dim, dim_out, groups=groups, shift_scale=shift_scale)\n",
    "        self.block2 = Block(dim_out, dim_out, groups=groups, shift_scale=shift_scale)\n",
    "        self.lin_layer = nn.Identity() if dim == dim_out else nn.Linear(dim, dim_out)\n",
    "\n",
    "    def forward(self, x, time_emb=None):\n",
    "        cond = None\n",
    "        if self.mlp is not None and time_emb is not None:\n",
    "            cond = self.mlp(time_emb)\n",
    "        return self.block2(self.block1(x, t=cond)) + self.lin_layer(x)\n",
    "\n",
    "\n",
    "class UnetMLP_simple(nn.Module):\n",
    "    \"\"\"U-Net-shaped MLP used as the drift network.\n",
    "\n",
    "    The input `x` has `2 * dim` features (the drift wrappers concatenate the\n",
    "    current state with x_0) and the output has `dim` features. `t` is reshaped\n",
    "    to (batch, nb_var), embedded by a two-layer MLP into `time_dim` features,\n",
    "    and conditions every residual block.\n",
    "\n",
    "    Args:\n",
    "        dim: output width (half of the input width).\n",
    "        init_dim: hidden width after the input projection (None -> dim).\n",
    "        dim_mults: width multipliers of the down/up path. An empty sequence\n",
    "            gives input projection -> final residual block -> output.\n",
    "        resnet_block_groups: GroupNorm groups used in every block.\n",
    "        time_dim: width of the time embedding.\n",
    "        nb_var: number of scalar conditioning inputs packed into `t`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        init_dim=128,\n",
    "        dim_mults=(1, 1),\n",
    "        resnet_block_groups=8,\n",
    "        time_dim=128,\n",
    "        nb_var=1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.nb_var = nb_var\n",
    "        init_dim = default(init_dim, dim)\n",
    "\n",
    "        dims = [init_dim, *(init_dim * m for m in dim_mults)]\n",
    "        in_out = list(zip(dims[:-1], dims[1:]))\n",
    "\n",
    "        block_klass = partial(ResnetBlock, groups=resnet_block_groups)\n",
    "\n",
    "        self.init_lin = nn.Linear(dim * 2, init_dim)\n",
    "\n",
    "        self.time_mlp = nn.Sequential(\n",
    "            nn.Linear(nb_var, time_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(time_dim, time_dim)\n",
    "        )\n",
    "\n",
    "        # Submodules are created in a fixed order; changing it would change the\n",
    "        # seeded initialisation and the state_dict layout.\n",
    "        self.downs = nn.ModuleList([])\n",
    "        self.ups = nn.ModuleList([])\n",
    "\n",
    "        for dim_in, _ in in_out:\n",
    "            self.downs.append(nn.ModuleList([\n",
    "                block_klass(dim_in, dim_in, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        mid_dim = dims[-1]\n",
    "        # NOTE(review): mid_block1 is never called in forward(); it is kept so the\n",
    "        # parameter list, state_dict keys and initialisation stay unchanged.\n",
    "        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        for dim_in, dim_out in reversed(in_out):\n",
    "            self.ups.append(nn.ModuleList([\n",
    "                block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),\n",
    "            ]))\n",
    "\n",
    "        # Width entering the down path, or `dim` when there is no down/up path.\n",
    "        self.out_dim = init_dim if in_out else dim\n",
    "\n",
    "        self.final_res_block = block_klass(\n",
    "            init_dim * 2, init_dim, time_emb_dim=time_dim)\n",
    "\n",
    "        # Zero-initialised output layer: the freshly built network outputs 0.\n",
    "        self.proj = nn.Linear(init_dim, dim)\n",
    "        self.proj.weight.data.fill_(0.0)\n",
    "        self.proj.bias.data.fill_(0.0)\n",
    "\n",
    "        self.final_lin = nn.Sequential(\n",
    "            nn.GroupNorm(resnet_block_groups, init_dim),\n",
    "            nn.SiLU(),\n",
    "            self.proj\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t=None, std=None):\n",
    "        \"\"\"Map `x` of shape (batch, 2 * dim) and conditioning `t` to (batch, dim).\n",
    "\n",
    "        `t` is required despite its default (it is reshaped to (batch, nb_var)).\n",
    "        When `std` is given, the output is divided by it.\n",
    "        \"\"\"\n",
    "        t = t.reshape(t.size(0), self.nb_var)\n",
    "\n",
    "        x = self.init_lin(x.float())\n",
    "        r = x.clone()\n",
    "\n",
    "        # NOTE(review): squeeze() also drops the batch axis when batch == 1; the\n",
    "        # embedding then broadcasts against x inside the blocks.\n",
    "        t = self.time_mlp(t).squeeze()\n",
    "\n",
    "        skips = []\n",
    "        for blocks in self.downs:\n",
    "            x = blocks[0](x, t)\n",
    "            skips.append(x)\n",
    "\n",
    "        for blocks in self.ups:\n",
    "            x = torch.cat((x, skips.pop()), dim=1)\n",
    "            x = blocks[0](x, t)\n",
    "\n",
    "        x = torch.cat((x, r), dim=1)\n",
    "        x = self.final_res_block(x, t)\n",
    "\n",
    "        out = self.final_lin(x)\n",
    "        if std is not None:\n",
    "            return out / std\n",
    "        return out\n",
    "\n",
    "\n",
    "# UnetMLP_simple over the two variables x0, x1 of size `dim` each (its `dim` is 2 * dim).\n",
    "# NOTE(review): `score` is not used in the cells shown here, but building it consumes\n",
    "# torch RNG and so shifts the initialisation of networks created in later cells.\n",
    "var_list_0 = {f\"x{i}\": dim for i in range(2)}\n",
    "var_list = [*var_list_0]\n",
    "sizes = [*var_list_0.values()]\n",
    "hidden_dim = 128\n",
    "\n",
    "score = UnetMLP_simple(\n",
    "    dim=np.sum(sizes),\n",
    "    init_dim=128,\n",
    "    dim_mults=[],\n",
    "    time_dim=hidden_dim,\n",
    "    nb_var=len(var_list_0),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Fall back to CPU so the notebook also runs on machines without a GPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# The drift-network wrapper `CondMLPTimeMINDE` is defined once, in a later cell\n",
    "# (it appends a `plan` flag to the time input).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "class BridgeMathcing(nn.Module):\n",
    "    \"\"\"Brownian-bridge matching between source samples x_0 and targets x_1.\n",
    "\n",
    "    The bridge is x_t = (1 - t) * x_0 + t * x_1 + sqrt(t * (1 - t) * eps) * z with\n",
    "    z ~ N(0, I). `vector_net(x_t, x_0, t)` is trained to predict, per `predict_type`:\n",
    "        'vector_field': the bridge drift (x_1 - x_t) / (1 - t),\n",
    "        'x_1':          the endpoint x_1,\n",
    "        'noise':        the noise z used to build x_t.\n",
    "\n",
    "    Args:\n",
    "        unet: callable (x_t, x_0, t) -> tensor shaped like x_t.\n",
    "        eps: bridge noise level.\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise'.\n",
    "        loss_weight: falsy to disable per-sample loss weights, or 'mean' / 'sum'\n",
    "            to pick how they are normalised ('vector_field' and 'x_1' only).\n",
    "\n",
    "    The class name keeps its original spelling so existing callers keep working.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unet, eps, predict_type='vector_field', loss_weight=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.loss_weight = loss_weight\n",
    "\n",
    "        assert predict_type in ['vector_field', 'x_1', 'noise']\n",
    "\n",
    "        self.predict_type = predict_type\n",
    "        self.vector_net = unet\n",
    "\n",
    "        self.eps = eps\n",
    "\n",
    "    def _euler_maruyama(self, x_0, nfe, pbar):\n",
    "        \"\"\"Integrate dx = v(x, x_0, t) dt + sqrt(eps) dW over t in [0, 1) in `nfe` steps.\"\"\"\n",
    "        import math  # local: `math` is otherwise only imported in a later cell\n",
    "\n",
    "        euler_dt = 1. / nfe\n",
    "        noise_scale = math.sqrt(euler_dt * self.eps)\n",
    "\n",
    "        t_range = torch.arange(0, 1, step=euler_dt).to(x_0.device)\n",
    "        if pbar:\n",
    "            t_range = tqdm(t_range)\n",
    "\n",
    "        x_t = x_0\n",
    "        for t in t_range:\n",
    "            eps_noise = torch.randn_like(x_t, device=x_0.device)\n",
    "            x_t = x_t + self.vector_net(x_t, x_0, t.squeeze()) * euler_dt + noise_scale * eps_noise\n",
    "\n",
    "        return x_t\n",
    "\n",
    "    def forward(self, x_0, nfe=100, pbar=False):\n",
    "        \"\"\"Simulate the learned SDE from x_0 with gradients enabled.\n",
    "\n",
    "        Uses the same Euler-Maruyama scheme as `sample`; `nfe` is the number of\n",
    "        steps and `pbar=True` wraps them in the notebook-global `tqdm`.\n",
    "        \"\"\"\n",
    "        return self._euler_maruyama(x_0, nfe, pbar)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample(self, x_0, nfe=100, pbar=True):\n",
    "        \"\"\"Simulate the learned SDE from x_0 without tracking gradients.\"\"\"\n",
    "        return self._euler_maruyama(x_0, nfe, pbar)\n",
    "\n",
    "    def sample_x_t(self, x_0, x_1, t):\n",
    "        \"\"\"Draw x_t from the Brownian bridge between x_0 and x_1; returns (x_t, z).\n",
    "\n",
    "        `t` is reshaped to a column, so x_0 / x_1 are presumably (batch, features).\n",
    "        \"\"\"\n",
    "        coef_0, coef_1 = 1 - t, t\n",
    "\n",
    "        std_t = torch.sqrt(t * (1 - t) * self.eps)\n",
    "\n",
    "        z = torch.randn_like(x_0, device=x_0.device)\n",
    "\n",
    "        x_t = coef_1.reshape([-1, 1]) * x_1 + coef_0.reshape([-1, 1]) * x_0 + z * std_t.reshape([-1, 1])\n",
    "        return x_t, z\n",
    "\n",
    "    def step(self, x_0, x_1, t):\n",
    "        \"\"\"Batch-mean training loss at times `t` (one time per sample).\"\"\"\n",
    "        t = t.reshape([-1, 1])\n",
    "        x_t, z = self.sample_x_t(x_0, x_1, t)\n",
    "        x_t_hat = self.vector_net(x_t, x_0, t.squeeze())\n",
    "        return self.loss(x_t_hat, x_1, x_0, x_t, t, z).mean()\n",
    "\n",
    "    def _loss_weights(self, x_t_hat, x_1, x_t, t):\n",
    "        \"\"\"Detached per-sample weights, normalised by their batch mean or sum.\n",
    "\n",
    "        With target = (x_1 - x_t) / (1 - t) and r = x_t_hat - target, the raw\n",
    "        weight is |mean(r ** 2 - target * r)| over the feature axis.\n",
    "        \"\"\"\n",
    "        batch = x_1.shape[0]\n",
    "        with torch.no_grad():\n",
    "            target = ((x_1 - x_t) / (1 - t)).reshape([batch, -1])\n",
    "            residual = (x_t_hat - (x_1 - x_t) / (1 - t)).reshape([batch, -1])\n",
    "            z_coef = torch.abs((residual**2 - target * residual).mean(-1))\n",
    "\n",
    "            if self.loss_weight == 'mean':\n",
    "                norm_const = z_coef.mean()\n",
    "            elif self.loss_weight == 'sum':\n",
    "                norm_const = z_coef.sum()\n",
    "            else:\n",
    "                raise RuntimeError('Unknown loss weight')\n",
    "\n",
    "            return z_coef / norm_const\n",
    "\n",
    "    def loss(self, x_t_hat, x_1, x_0, x_t, t, z):\n",
    "        \"\"\"Per-sample (unsquared) L2 error of `x_t_hat` against the `predict_type` target.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: `loss_weight` is truthy but neither 'mean' nor 'sum'.\n",
    "            ValueError: `predict_type` was changed to an unsupported value.\n",
    "        \"\"\"\n",
    "        batch = x_1.shape[0]\n",
    "\n",
    "        if self.predict_type == 'noise':\n",
    "            # The noise objective is never reweighted.\n",
    "            return torch.norm((x_t_hat - z).reshape([batch, -1]), dim=-1)\n",
    "\n",
    "        if self.predict_type == 'x_1':\n",
    "            target = x_1\n",
    "        elif self.predict_type == 'vector_field':\n",
    "            target = (x_1 - x_t) / (1 - t)\n",
    "        else:\n",
    "            raise ValueError(f'Unknown predict_type {self.predict_type!r}')\n",
    "\n",
    "        per_sample = torch.norm((x_t_hat - target).reshape([batch, -1]), dim=-1)\n",
    "\n",
    "        if self.loss_weight:\n",
    "            # NOTE(review): for 'x_1' the weights still treat x_t_hat as a drift\n",
    "            # estimate (same formula as for 'vector_field'); confirm this is intended.\n",
    "            return per_sample * self._loss_weights(x_t_hat, x_1, x_t, t)\n",
    "\n",
    "        return per_sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def estimate_kl(drift_1, drift_2, eps, test_loader, t_eps=1e-3, posterior='partial', predict_type='vector_field', n_repeat=10, data_device=None):\n",
    "    \"\"\"Monte-Carlo estimate of E ||b_1 - b_2||^2 / (2 * eps) between two bridge drifts.\n",
    "\n",
    "    For every (x, y) batch of `test_loader` (repeated `n_repeat` times), draws\n",
    "    t ~ U[0, 1 - t_eps), samples x_t on the Brownian bridge from x to y, turns\n",
    "    both networks' outputs into drifts b_1, b_2 according to `predict_type`, and\n",
    "    averages the squared drift gap over all samples.\n",
    "\n",
    "    Args:\n",
    "        drift_1, drift_2: callables (x_t, x_0, t) -> tensor shaped like x_t.\n",
    "        eps: bridge noise level.\n",
    "        test_loader: iterable of dicts with 'x' (x_0) and 'y' (x_1) batches.\n",
    "        t_eps: keeps t away from 1, where the drifts divide by (1 - t).\n",
    "        posterior: unused; kept for backward compatibility.\n",
    "        predict_type: what the networks output: 'vector_field', 'x_1' or 'noise'.\n",
    "        n_repeat: number of passes over `test_loader`.\n",
    "        data_device: device for the batches; defaults to the notebook-global `device`.\n",
    "\n",
    "    Returns:\n",
    "        The estimate as a 0-d tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: unknown `predict_type`, or `test_loader` yields no samples.\n",
    "    \"\"\"\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        raise ValueError(f'Unknown predict_type {predict_type!r}')\n",
    "    if data_device is None:\n",
    "        data_device = device\n",
    "\n",
    "    def bridge_sample(x_1, x_0, t):\n",
    "        # x_t = t * x_1 + (1 - t) * x_0 + sqrt(eps * t * (1 - t)) * N(0, I)\n",
    "        mean = x_1 * t + x_0 * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "        return mean + std * torch.randn_like(x_1)\n",
    "\n",
    "    kl_value = 0\n",
    "    n_elems = 0\n",
    "\n",
    "    for _ in range(n_repeat):\n",
    "        for item in test_loader:\n",
    "            x_batch = item['x'].to(data_device)\n",
    "            y_batch = item['y'].to(data_device)\n",
    "            n_elems += x_batch.shape[0]\n",
    "\n",
    "            t = (torch.rand(x_batch.shape[0]).to(data_device) * (1 - t_eps)).reshape([-1, 1])\n",
    "            x_t = bridge_sample(y_batch, x_batch, t)\n",
    "\n",
    "            out_1 = drift_1(x_t, x_batch, t.squeeze())\n",
    "            out_2 = drift_2(x_t, x_batch, t.squeeze())\n",
    "\n",
    "            if predict_type == 'vector_field':\n",
    "                gap = out_1 - out_2\n",
    "            elif predict_type == 'x_1':\n",
    "                # drift = (x_1_hat - x_t) / (1 - t), so x_t cancels in the difference\n",
    "                gap = (out_1 - out_2) / (1 - t)\n",
    "            else:  # 'noise': recover x_1 predictions from the predicted noise first\n",
    "                noise_std = torch.sqrt(t * (1 - t) * eps)\n",
    "                x_1_hat_1 = (x_t - (1 - t) * x_batch - noise_std * out_1) / t\n",
    "                x_1_hat_2 = (x_t - (1 - t) * x_batch - noise_std * out_2) / t\n",
    "                gap = (x_1_hat_1 - x_1_hat_2) / (1 - t)\n",
    "\n",
    "            kl_value += (gap ** 2).sum([-1]).sum()\n",
    "\n",
    "    if n_elems == 0:\n",
    "        raise ValueError('test_loader yielded no samples')\n",
    "\n",
    "    return 1 / (2 * eps) * kl_value / n_elems\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully-connected network: Linear layers with `activation_fn` between them.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features.\n",
    "        layer_widths: output width of each Linear layer (the last one is the\n",
    "            output size); must be non-empty.\n",
    "        activate_final: also apply `activation_fn` to the output.\n",
    "        activation_fn: element-wise non-linearity.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: `layer_widths` is empty.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, layer_widths=(100, 100, 2), activate_final = False, activation_fn=F.tanh):\n",
    "        super(MLP, self).__init__()\n",
    "        # Copy, so instances never share (or mutate) the caller's list or the default.\n",
    "        layer_widths = list(layer_widths)\n",
    "        if not layer_widths:\n",
    "            raise ValueError('MLP needs at least one layer width')\n",
    "        widths = [input_dim, *layer_widths]\n",
    "        self.layers = nn.ModuleList(\n",
    "            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])\n",
    "        )\n",
    "        self.input_dim = input_dim\n",
    "        self.layer_widths = layer_widths\n",
    "        self.activate_final = activate_final\n",
    "        self.activation_fn = activation_fn\n",
    "\n",
    "    def forward(self, x):\n",
    "        *hidden, last = self.layers\n",
    "        for layer in hidden:\n",
    "            x = self.activation_fn(layer(x))\n",
    "        x = last(x)\n",
    "        if self.activate_final:\n",
    "            x = self.activation_fn(x)\n",
    "        return x\n",
    "\n",
    "import math\n",
    "class SinusoidalPosEmb(nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of scalar times into [sin | cos] features.\n",
    "\n",
    "    Frequencies are 10000 ** (-k / (dim // 2 - 1)) for k = 0 .. dim // 2 - 1, so\n",
    "    `dim` should be even and >= 4 (dim 2 or 3 divides by zero; odd dims give\n",
    "    dim - 1 features).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        times = x.squeeze()\n",
    "        n_freqs = self.dim // 2\n",
    "        log_step = math.log(10000) / (n_freqs - 1)\n",
    "        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=times.device))\n",
    "        angles = times.unsqueeze(-1) * freqs.unsqueeze(0)\n",
    "        return torch.cat([angles.sin(), angles.cos()], dim=-1)\n",
    "    \n",
    "class MLPTime(nn.Module):\n",
    "    \"\"\"Residual MLP conditioned on time through a concatenated sinusoidal embedding.\"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, hidden_channels, time_embed_dim, num_hidden_blocks):\n",
    "        super().__init__()\n",
    "        self.to_time_embed = SinusoidalPosEmb(time_embed_dim)\n",
    "        self.in_layer = nn.Sequential(\n",
    "            nn.Linear(in_channels + time_embed_dim, hidden_channels),\n",
    "            nn.SiLU(),\n",
    "        )\n",
    "\n",
    "        def residual_branch():\n",
    "            return nn.Sequential(\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "            )\n",
    "\n",
    "        self.hidden_blocks = nn.ModuleList([residual_branch() for _ in range(num_hidden_blocks)])\n",
    "        self.out_layer = nn.Linear(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, time):\n",
    "        h = self.in_layer(torch.cat([x, self.to_time_embed(time)], dim=-1))\n",
    "        for block in self.hidden_blocks:\n",
    "            h = h + block(h)\n",
    "        return self.out_layer(h)\n",
    "\n",
    "\n",
    "class CondMLPTime(nn.Module):\n",
    "    \"\"\"Condition a time-dependent network on x_0 by concatenating it to the input.\"\"\"\n",
    "\n",
    "    def __init__(self, mlp_time):\n",
    "        super().__init__()\n",
    "        self.net = mlp_time\n",
    "\n",
    "    def forward(self, x, x_0, time):\n",
    "        joint = torch.cat((x, x_0), dim=-1)\n",
    "        return self.net(joint, time)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "class CondMLPTimeMINDE(nn.Module):\n",
    "    \"\"\"Wrap a MINDE-style network as a drift net f(x, x_0, time).\n",
    "\n",
    "    The wrapped net receives cat([x, x_0], -1) and a two-column conditioning\n",
    "    tensor [time, flag] with flag = +1 when `plan` is True and -1 otherwise, so\n",
    "    one shared network can serve as two different drifts.\n",
    "\n",
    "    The flag column follows the device and dtype of `time`, and a scalar `time`\n",
    "    (what `.squeeze()` leaves for a batch of one) is broadcast over the batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, minde_mlp, plan=True):\n",
    "        super().__init__()\n",
    "        self.net = minde_mlp\n",
    "        self.plan = plan\n",
    "\n",
    "    def forward(self, x, x_0, time):\n",
    "        batch = x.shape[0]\n",
    "        time_col = time.reshape(-1, 1).expand(batch, 1)\n",
    "        flag = time_col.new_full((batch, 1), 1.0 if self.plan else -1.0)\n",
    "        return self.net(torch.cat([x, x_0], dim=-1), torch.cat([time_col, flag], dim=1))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Fall back to CPU so the notebook also runs on machines without a GPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "var_list = list(var_list_0.keys())\n",
    "sizes = list(var_list_0.values())\n",
    "\n",
    "# A single network serves both drifts: CondMLPTimeMINDE appends a +1 / -1 flag to\n",
    "# the time input (hence nb_var=2) to tell the two drifts apart.\n",
    "net_1 = UnetMLP_simple(dim=dim, init_dim=n_filters, dim_mults=[],\n",
    "                            time_dim=time_embed, nb_var=2)\n",
    "\n",
    "drift_net_1 = CondMLPTimeMINDE(net_1, plan=True).to(device)\n",
    "drift_net_2 = CondMLPTimeMINDE(net_1, plan=False).to(device)\n",
    "\n",
    "from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "# Both drift nets share net_1's parameters, so one EMA tracker covers both.\n",
    "ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "ema_g_bm_2 = ema_g_bm_1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Disabled alternative setup: two independent drift networks (net_1 and net_2),\n",
    "# each tracked by its own ExponentialMovingAverage, instead of the shared-net_1 setup.\n",
    "# NOTE(review): uncommenting this cell reassigns device, drift_net_1, drift_net_2,\n",
    "# ema_g_bm_1 and ema_g_bm_2.\n",
    "\n",
    "# device = 'cuda'\n",
    "\n",
    "# drift_net_1 = CondMLPTime(MLPTime(dim * 2, dim, n_filters, 32, 2)).to(device)\n",
    "\n",
    "# drift_net_2 = CondMLPTime(MLPTime(dim * 2, dim, n_filters, 32, 2)).to(device)\n",
    "\n",
    "# var_list_0 = {\"x\" + str(i): dim for i in range(2)}\n",
    "\n",
    "# var_list = list(var_list_0.keys())\n",
    "# sizes = list(var_list_0.values())\n",
    "# hidden_dim = 128\n",
    "\n",
    "# net_1 = UnetMLP_simple(dim=dim, init_dim=512, dim_mults=[],\n",
    "#                             time_dim=hidden_dim, nb_var=1)\n",
    "\n",
    "# drift_net_1 = CondMLPTimeMINDE(net_1).to(device)\n",
    "\n",
    "# # for i, item in enumerate(train_l):\n",
    "    \n",
    "# #     x_samples = item['x'].to(device)\n",
    "# #     y_samples = item['y'].to(device)\n",
    "\n",
    "    \n",
    "# #     t = sample_fn(x_samples.shape[0])\n",
    "\n",
    "# # drift_net_1(x_samples, y_samples, t).shape\n",
    "\n",
    "\n",
    "# net_2 = UnetMLP_simple(dim=dim, init_dim=512, dim_mults=[],\n",
    "#                             time_dim=hidden_dim, nb_var=1)\n",
    "\n",
    "# drift_net_2 = CondMLPTimeMINDE(net_2).to(device)\n",
    "\n",
    "\n",
    "# from torch_ema import ExponentialMovingAverage\n",
    "\n",
    "# ema_g_bm_1 = ExponentialMovingAverage(drift_net_1.parameters(), decay=ema_decay)\n",
    "# ema_g_bm_2 = ExponentialMovingAverage(drift_net_2.parameters(), decay=ema_decay)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Bridge-matching objectives: bm_1 is stepped on paired (x, y) batches and bm_2 on\n",
    "# (x, shuffled y) batches in the training loop.\n",
    "bm_1 = BridgeMathcing(drift_net_1, eps=eps, predict_type=predict_type, loss_weight=loss_weight)\n",
    "bm_2 = BridgeMathcing(drift_net_2, eps=eps, predict_type=predict_type, loss_weight=loss_weight)\n",
    "\n",
    "# One optimizer is shared by both objectives. It optimizes drift_net_1's full parameter\n",
    "# set -- the same set ema_g_bm_1 tracks -- rather than only the inner net_1, so any\n",
    "# parameters the CondMLPTimeMINDE wrapper adds on top of net_1 are trained too.\n",
    "# NOTE(review): in the shared setup drift_net_2 also wraps net_1; if the wrapper owns\n",
    "# extra parameters, drift_net_2's copies are still not optimized -- confirm.\n",
    "opt = torch.optim.Adam(drift_net_1.parameters(), lr=lr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Total number of parameter elements in drift_net_1. numel() returns a Python int even for\n",
    "# 0-d parameters, whereas np.prod of an empty size is the float 1.0.\n",
    "sum(p.numel() for p in drift_net_1.parameters())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Sampler for the training time t. It draws from Beta(t_alpha, t_beta), or from a uniform\n",
    "# distribution when either shape parameter is None, and scales by (1 - t_eps) so that\n",
    "# t <= 1 - t_eps. With t_alpha = t_beta = 1 the Beta draws are uniform on [0, 1].\n",
    "t_alpha, t_beta = 1, 1\n",
    "t_eps = 1e-3\n",
    "\n",
    "if t_alpha is not None and t_beta is not None:\n",
    "    print(f'Beta alpha {t_alpha} beta {t_beta}')\n",
    "    dist = torch.distributions.beta.Beta(t_alpha, t_beta)\n",
    "\n",
    "    def sample_fn(batch_size):\n",
    "        return dist.sample([batch_size]).to(device) * (1 - t_eps)\n",
    "else:\n",
    "    def sample_fn(batch_size):\n",
    "        return torch.rand(batch_size).to(device) * (1 - t_eps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Patience-based early stopping on the per-epoch MI estimate. The state lives in these\n",
    "# globals, so re-run this cell before re-running the training loop to reset it.\n",
    "# Starting from -inf guarantees the first epoch is recorded as the best, even if its\n",
    "# estimate is <= 0.\n",
    "highest_mi = float('-inf')\n",
    "since_best_mi = 0\n",
    "\n",
    "def early_stopping(mi, patience=1):\n",
    "    \"\"\"Record the current MI estimate and report whether training should stop.\n",
    "\n",
    "    Args:\n",
    "        mi: MI estimate for the current epoch (any value comparable with `>`).\n",
    "        patience: number of consecutive non-improving epochs tolerated. The default of 1\n",
    "            stops after two consecutive epochs without improvement.\n",
    "\n",
    "    Returns:\n",
    "        True once `since_best_mi` exceeds `patience`, otherwise False.\n",
    "    \"\"\"\n",
    "    global highest_mi, since_best_mi\n",
    "\n",
    "    if mi > highest_mi:\n",
    "        since_best_mi = 0\n",
    "        highest_mi = mi\n",
    "        print('Update best MI')\n",
    "    else:\n",
    "        since_best_mi += 1\n",
    "\n",
    "    print(f'since_best_mi {since_best_mi}')\n",
    "\n",
    "    return since_best_mi > patience\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from contextlib import nullcontext\n",
    "\n",
    "for j in range(n_epochs):\n",
    "\n",
    "    for item in train_l:\n",
    "        x_samples = item['x'].to(device)\n",
    "        y_samples = item['y'].to(device)\n",
    "\n",
    "        # Step 1: bridge matching on the paired (x, y) samples.\n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        loss_1 = bm_1.step(x_samples, y_samples, t)\n",
    "        opt.zero_grad()\n",
    "        loss_1.backward()\n",
    "        opt.step()\n",
    "        ema_g_bm_1.update()\n",
    "\n",
    "        # Step 2: bridge matching on x with y shuffled across the batch (breaks the pairing).\n",
    "        y_samples_permuted = y_samples[torch.randperm(y_samples.shape[0])]\n",
    "        t = sample_fn(x_samples.shape[0])\n",
    "        loss_2 = bm_2.step(x_samples, y_samples_permuted, t)\n",
    "        opt.zero_grad()\n",
    "        loss_2.backward()\n",
    "        opt.step()\n",
    "        # When ema_g_bm_2 is ema_g_bm_1 this is a second EMA update after the second step.\n",
    "        ema_g_bm_2.update()\n",
    "\n",
    "        # Log plain floats rather than tensors that still reference the autograd graph.\n",
    "        wandb.log({'Loss plan': loss_1.item(), 'Loss ind': loss_2.item()})\n",
    "\n",
    "    # ema_g_bm_2 may be the very same ExponentialMovingAverage object as ema_g_bm_1.\n",
    "    # average_parameters() backs up the live weights on entry and restores them on exit.\n",
    "    # Nesting it twice on one object lets the inner backup overwrite the outer one, so the\n",
    "    # EMA weights would silently replace the training weights. Enter each distinct EMA once.\n",
    "    ema_ctx_2 = nullcontext() if ema_g_bm_2 is ema_g_bm_1 else ema_g_bm_2.average_parameters()\n",
    "    with ema_g_bm_1.average_parameters(), ema_ctx_2:\n",
    "        mutual_entropy_est_ema = estimate_kl(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps, posterior='partial', predict_type=predict_type, n_repeat=10)\n",
    "\n",
    "    # Same estimate with the raw (non-averaged) training weights.\n",
    "    mutual_entropy_est = estimate_kl(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps, posterior='partial', predict_type=predict_type, n_repeat=10)\n",
    "\n",
    "    # The train-set EMA estimate is not computed, so it is neither printed nor logged.\n",
    "    print(f'Epoch {j} MI: {mutual_entropy_est} EMA MI {mutual_entropy_est_ema}')\n",
    "    wandb.log({'MI': mutual_entropy_est, 'MI_est': mutual_entropy_est_ema})\n",
    "\n",
    "    # Check early stopping only after reporting, so the last epoch is still printed and logged.\n",
    "    if early_stopping(mutual_entropy_est_ema):\n",
    "        print('!!EARLY STOPPING!!')\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from contextlib import nullcontext\n",
    "\n",
    "# Final MI estimate with the EMA weights loaded, written to log_path.\n",
    "# ema_g_bm_2 may be the same ExponentialMovingAverage object as ema_g_bm_1. Entering\n",
    "# average_parameters() twice on one object would leave the EMA weights in the model\n",
    "# on exit, so each distinct EMA is entered only once.\n",
    "ema_ctx_2 = nullcontext() if ema_g_bm_2 is ema_g_bm_1 else ema_g_bm_2.average_parameters()\n",
    "with ema_g_bm_1.average_parameters(), ema_ctx_2:\n",
    "    mutual_entropy_est_ema = estimate_kl(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps,\n",
    "                                         posterior='partial', predict_type=predict_type)\n",
    "\n",
    "with open(log_path, 'w') as f:\n",
    "    f.write(f'MI={mutual_entropy_est_ema}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finish the wandb run used for the loss and MI logging in this notebook.\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def estimate_kl_time(drift_1, drift_2, eps, loader, t_eps=1e-3, step=0.001, predict_type=predict_type, verbose=True):\n",
    "    \"\"\"Per-time contribution to the KL/MI estimate, computed on one batch from `loader`.\n",
    "\n",
    "    For each t on the grid [1e-8, 1 - t_eps) with spacing `step`, the function:\n",
    "      - samples x_t from the Brownian bridge between the batch's x (t = 0) and y (t = 1),\n",
    "        with mean t * y + (1 - t) * x and variance eps * t * (1 - t);\n",
    "      - turns both networks' outputs into velocities v_1 and v_2 according to `predict_type`;\n",
    "      - records 1 / (2 eps) * mean over the batch of ||v_1 - v_2||^2.\n",
    "\n",
    "    Args:\n",
    "        drift_1, drift_2: networks called as drift(x_t, x, t).\n",
    "        eps: bridge noise level.\n",
    "        loader: iterable of dicts with 'x' and 'y' tensors; only the first batch is used.\n",
    "        t_eps: keeps the time grid strictly below 1.\n",
    "        step: spacing of the time grid.\n",
    "        predict_type: 'vector_field', 'x_1' or 'noise'.\n",
    "        verbose: print each time point's value (default True).\n",
    "\n",
    "    Returns:\n",
    "        List of (t, value) pairs of Python floats.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if predict_type is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if predict_type not in ('vector_field', 'x_1', 'noise'):\n",
    "        raise ValueError(f'Unknown predict_type: {predict_type}')\n",
    "\n",
    "    item = next(iter(loader))\n",
    "    x_batch = item['x'].to(device)\n",
    "    y_batch = item['y'].to(device)\n",
    "\n",
    "    # Build the time grid on the same device as the data rather than a hard-coded GPU.\n",
    "    t_range = torch.arange(1e-8, 1 - t_eps, step=step).to(device)\n",
    "\n",
    "    data = []\n",
    "    for t in t_range:\n",
    "        t = t.reshape([-1, 1])\n",
    "\n",
    "        # Brownian-bridge sample between x (t = 0) and y (t = 1).\n",
    "        mean = y_batch * t + x_batch * (1 - t)\n",
    "        std = torch.sqrt(eps * (1 - t) * t)\n",
    "        x_t = mean + std * torch.randn_like(y_batch)\n",
    "\n",
    "        if predict_type == 'vector_field':\n",
    "            v_1 = drift_1(x_t, x_batch, t)\n",
    "            v_2 = drift_2(x_t, x_batch, t)\n",
    "        elif predict_type == 'x_1':\n",
    "            # The networks predict the endpoint; velocity = (x_1 - x_t) / (1 - t).\n",
    "            v_1 = (drift_1(x_t, x_batch, t) - x_t) / (1 - t)\n",
    "            v_2 = (drift_2(x_t, x_batch, t) - x_t) / (1 - t)\n",
    "        else:  # 'noise'\n",
    "            # The networks predict the bridge noise; recover x_1, then the velocity.\n",
    "            z_1 = drift_1(x_t, x_batch, t.squeeze())\n",
    "            z_2 = drift_2(x_t, x_batch, t.squeeze())\n",
    "            noise_scale = torch.sqrt(t * (1 - t) * eps)\n",
    "            x_1_pred_1 = (x_t - (1 - t) * x_batch - noise_scale * z_1) / t\n",
    "            x_1_pred_2 = (x_t - (1 - t) * x_batch - noise_scale * z_2) / t\n",
    "            v_1 = (x_1_pred_1 - x_t) / (1 - t)\n",
    "            v_2 = (x_1_pred_2 - x_t) / (1 - t)\n",
    "\n",
    "        value = (1 / (2 * eps) * ((v_1 - v_2) ** 2).sum([-1]).mean()).item()\n",
    "        t_value = t.item()\n",
    "        if verbose:\n",
    "            print(f'Time: {t_value} ', value)\n",
    "        data.append((t_value, value))\n",
    "\n",
    "    return data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from contextlib import nullcontext\n",
    "\n",
    "# Per-time KL contributions with the EMA weights loaded. ema_g_bm_2 may be the same\n",
    "# ExponentialMovingAverage object as ema_g_bm_1. Entering average_parameters() twice on\n",
    "# one object would leave the EMA weights in the model on exit, so each distinct EMA is\n",
    "# entered only once.\n",
    "ema_ctx_2 = nullcontext() if ema_g_bm_2 is ema_g_bm_1 else ema_g_bm_2.average_parameters()\n",
    "with ema_g_bm_1.average_parameters(), ema_ctx_2:\n",
    "    kl_time_data = estimate_kl_time(drift_net_1, drift_net_2, eps, test_l, t_eps=t_eps, step=0.001, predict_type=predict_type)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Disabled: collect every test-set x batch into a single tensor x_start\n",
    "# (it is only consumed by the commented-out EM-steps sweep).\n",
    "\n",
    "# x_start = []\n",
    "\n",
    "# for item in iter(test_l):\n",
    "    \n",
    "#     x_start.append(item['x'].to(device))\n",
    "\n",
    "# x_start = torch.cat(x_start, dim=0)\n",
    "# x_start.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Disabled: sweep over n_em_steps (presumably Euler-Maruyama steps) for the MI estimate.\n",
    "# NOTE(review): this call uses a stale estimate_kl signature (x_start before eps, n_em_steps).\n",
    "# The active cells call estimate_kl(drift_1, drift_2, eps, loader, t_eps=..., posterior=...,\n",
    "# predict_type=..., n_repeat=...), so this would need updating before being re-enabled.\n",
    "\n",
    "# em_steps_list = [5, 10, 25, 50, 100]\n",
    "\n",
    "# for n_em_steps in em_steps_list:\n",
    "\n",
    "#     mi_estimate = estimate_kl(drift_net_1, drift_net_2, x_start, eps, n_em_steps=n_em_steps, t_eps=1e-4)\n",
    "\n",
    "#     print(f'EM steps {n_em_steps} MI {mi_estimate}')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
