{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "960c5136",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from sklearn.decomposition import PCA\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import wandb\n",
    "\n",
    "from eot_msci_utils import pca_plot\n",
    "\n",
    "import copy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "175dcdb1",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run configuration. This cell carries the 'parameters' tag in the notebook metadata\n",
    "# (presumably so a runner such as papermill can override the values -- TODO confirm).\n",
    "\n",
    "# Data dimensionality. SwissRollSampler asserts dim == 2, so only 2 works here.\n",
    "DIM = 2\n",
    "\n",
    "# Seed handed to torch.manual_seed and np.random.seed further down.\n",
    "SEED = 42\n",
    "# NOTE(review): BATCH_SIZE, SAMPLING_BATCH_SIZE, DAY_START, DAY_END, DAY_EVAL,\n",
    "# EVAL_EVERY, SERIES_ID, MAX_STEPS and CONTINUE are not read by any visible cell;\n",
    "# training uses the lowercase `batch_size` defined at the bottom of this cell.\n",
    "BATCH_SIZE = 128\n",
    "# Entropic regularisation; the run cell passes sqrt(eps) to DSBM as `sig`.\n",
    "EPSILON = 0.1\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "DAY_START = 3\n",
    "DAY_END = 7\n",
    "DAY_EVAL = 4\n",
    "DEVICE = \"cuda:0\"\n",
    "# Lowercase aliases; these are the names the later cells actually use.\n",
    "device = DEVICE\n",
    "dim = DIM\n",
    "eps = EPSILON\n",
    "EVAL_EVERY = 10000\n",
    "SERIES_ID = 1\n",
    "\n",
    "MAX_STEPS = 10000\n",
    "CONTINUE = -1\n",
    "\n",
    "# Starting prior for the first training pass. Values handled by the visible code:\n",
    "#   'ipf'         -> plan 'ref'\n",
    "#   'imf'         -> plan 'ind'\n",
    "#   'imf_mbot'    -> plan 'mb_ot' (mini-batch OT; currently raises RuntimeError)\n",
    "#   'id'          -> plan 'ref'; first-pass pairs are collapsed to identical points\n",
    "#   'id_permuted' -> plan 'ref'; as 'id', but one side of each batch is shuffled\n",
    "starting_prior = 'id'\n",
    "\n",
    "# Passed to fit_imf as `outer_iters`.\n",
    "imf_iters = 20\n",
    "\n",
    "# Adam learning rate used by train_imf_iter.\n",
    "lr = 1e-4\n",
    "\n",
    "# Mini-batch size passed to fit_imf.\n",
    "batch_size = 128\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4f3cda1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Translate the starting prior into the name of DSBM's first coupling.\n",
    "# Anything other than 'imf' / 'imf_mbot' (including 'ipf', 'id' and\n",
    "# 'id_permuted') falls back to the reference coupling.\n",
    "if starting_prior == 'imf':\n",
    "    plan = 'ind'\n",
    "elif starting_prior == 'imf_mbot':\n",
    "    plan = 'mb_ot'\n",
    "else:\n",
    "    plan = 'ref'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "432dc40d",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "if plan == 'mb_ot':\n",
    "    # Mini-batch OT is not supported yet, so stop before touching the optional\n",
    "    # `optimal_transport` module: if that module were missing, the import would\n",
    "    # fail first and hide the real reason this configuration cannot run.\n",
    "    # TODO: when the plan is implemented, restore\n",
    "    #   from optimal_transport import OTPlanSampler\n",
    "    #   ot_plan_sampler = OTPlanSampler('exact')\n",
    "    raise RuntimeError('Mini batch OT is not implemented because of plan entity issues')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6b83208c",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seed the torch and numpy generators so repeated runs draw the same numbers.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Two further names bound to the same value as EPSILON.\n",
    "EPS = EPSILON_END = EPSILON"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44b7535a",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca710850",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "from scipy.linalg import sqrtm\n",
    "from sklearn import datasets\n",
    "\n",
    "class Sampler:\n",
    "    \"\"\"Base class for samplers that produce batches on a given device.\n",
    "\n",
    "    Subclasses override `sample`; the base implementation only reports that\n",
    "    it has not been overridden.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, device='cuda',\n",
    "    ):\n",
    "        # Device on which subclasses create the tensors they return.\n",
    "        self.device = device\n",
    "\n",
    "    def sample(self, size=5):\n",
    "        \"\"\"Return a batch of `size` samples.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: always; a silent `None` from the base class\n",
    "                would only fail later, far from the missing override.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError(\n",
    "            f'{type(self).__name__} must override sample()'\n",
    "        )\n",
    "    \n",
    "class SwissRollSampler(Sampler):\n",
    "    \"\"\"Sampler of 2-D points taken from scikit-learn's swiss-roll generator.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, dim=2, device='cuda'\n",
    "    ):\n",
    "        super().__init__(device=device)\n",
    "        # Only the planar projection is supported.\n",
    "        assert dim == 2\n",
    "        self.dim = 2\n",
    "\n",
    "    def sample(self, batch_size=10):\n",
    "        \"\"\"Draw `batch_size` points and return them as a tensor on `self.device`.\"\"\"\n",
    "        roll_points = datasets.make_swiss_roll(n_samples=batch_size, noise=0.8)[0]\n",
    "        # Keep coordinates 0 and 2 of every point, then rescale by 1 / 7.5.\n",
    "        planar = roll_points.astype('float32')[:, [0, 2]] / 7.5\n",
    "        return torch.tensor(planar, device=self.device)\n",
    "    \n",
    "    \n",
    "class StandardNormalSampler(Sampler):\n",
    "    \"\"\"Sampler of vectors with independent standard normal entries.\"\"\"\n",
    "\n",
    "    def __init__(self, dim=1, device='cuda'):\n",
    "        super().__init__(device=device)\n",
    "        # Number of entries in each sampled vector.\n",
    "        self.dim = dim\n",
    "\n",
    "    def sample(self, batch_size=10):\n",
    "        \"\"\"Return a (batch_size, dim) tensor of normal draws on `self.device`.\"\"\"\n",
    "        shape = (batch_size, self.dim)\n",
    "        return torch.randn(*shape, device=self.device)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "75541de1-bbee-463d-94ee-c2ac2232a9b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Source (x0) and target (x1) samplers; both create their batches on the CPU.\n",
    "# They take the configured `dim` instead of a literal 2, so a changed DIM trips\n",
    "# SwissRollSampler's assertion here rather than surfacing later as a shape\n",
    "# mismatch inside the networks (which are sized from `dim`).\n",
    "X_sampler = StandardNormalSampler(dim=dim, device='cpu')\n",
    "Y_sampler = SwissRollSampler(dim=dim, device='cpu')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e269c7a0",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "008f4f3b",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# DSBM\n",
    "class DSBM(nn.Module):\n",
    "    \"\"\"Pair of drift networks with the helpers used to train and sample them.\n",
    "\n",
    "    Args:\n",
    "        net_fwd: module called as `net(z, t)` for direction 'f'.\n",
    "        net_bwd: module called as `net(z, t)` for direction 'b'.\n",
    "        num_steps: default number of Euler steps in `sample_sde`.\n",
    "        sig: scale of the Gaussian noise added to interpolants and SDE steps.\n",
    "        safe_t_eps: training times are drawn from [safe_t_eps, 1 - safe_t_eps).\n",
    "        first_coupling: 'ref' or 'ind'; how `generate_new_dataset` pairs the\n",
    "            samples when no previous model is available.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net_fwd=None, net_bwd=None, num_steps=1000, sig=0, safe_t_eps=1e-3, first_coupling=\"ref\"):\n",
    "        super().__init__()\n",
    "        self.net_fwd = net_fwd\n",
    "        self.net_bwd = net_bwd\n",
    "        self.net_dict = {'f': self.net_fwd, 'b': self.net_bwd}\n",
    "        self.N = num_steps\n",
    "        self.sig = sig\n",
    "        self.eps = safe_t_eps\n",
    "        self.first_coupling = first_coupling\n",
    "        # Direction most recently trained. The training loop assigns it and\n",
    "        # `generate_new_dataset` reads it from the previous model, so it is\n",
    "        # declared here instead of appearing only as an external assignment.\n",
    "        self.fb = None\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def get_train_tuple(self, x_pairs=None, fb='', **kwargs):\n",
    "        \"\"\"Build one regression batch from paired samples.\n",
    "\n",
    "        Args:\n",
    "            x_pairs: tensor indexed as x_pairs[:, 0] (z0) and x_pairs[:, 1] (z1).\n",
    "            fb: 'f' selects the forward target; any other value the backward one.\n",
    "            **kwargs: accepted and ignored.\n",
    "\n",
    "        Returns:\n",
    "            (z_t, t, target): the noised interpolant, the sampled times with\n",
    "            shape (batch, 1), and the regression target for the chosen direction.\n",
    "        \"\"\"\n",
    "        z0, z1 = x_pairs[:, 0], x_pairs[:, 1]\n",
    "        # Sample the times on the data's own device; depending on a notebook\n",
    "        # global named `device` breaks as soon as the class is used elsewhere.\n",
    "        t = torch.rand((z1.shape[0], 1), device=z1.device) * (1-2*self.eps) + self.eps\n",
    "        z_t = t * z1 + (1.-t) * z0\n",
    "        z = torch.randn_like(z_t)\n",
    "        z_t = z_t + self.sig * torch.sqrt(t*(1.-t)) * z\n",
    "        if fb == 'f':\n",
    "            # Equals (z1 - z_t) / (1 - t).\n",
    "            target = z1 - z0\n",
    "            target = target - self.sig * torch.sqrt(t/(1.-t)) * z\n",
    "        else:\n",
    "            # Equals (z0 - z_t) / t.\n",
    "            target = - (z1 - z0)\n",
    "            target = target - self.sig * torch.sqrt((1.-t)/t) * z\n",
    "        return z_t, t, target\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate_new_dataset(self, x_pairs, prev_model=None, fb='', first_it=False):\n",
    "        \"\"\"Produce the (z0, z1) coupling used to train direction `fb`.\n",
    "\n",
    "        Without a previous model (first iteration only) the coupling follows\n",
    "        `first_coupling`: 'ref' perturbs x_pairs[:, 0] with noise of scale\n",
    "        `sig`, 'ind' pairs it with a shuffled copy of x_pairs[:, 1]. Otherwise\n",
    "        the previous model is simulated in its own direction from the matching\n",
    "        marginal and the end points of that simulation complete the pairs.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: on an invalid `fb`, an inconsistent `first_it`, or a\n",
    "                previous model whose `fb` was never set.\n",
    "            NotImplementedError: for an unknown `first_coupling`.\n",
    "        \"\"\"\n",
    "        assert fb in ['f', 'b']\n",
    "        if prev_model is None:\n",
    "            assert first_it\n",
    "            if self.first_coupling == 'ind':\n",
    "                assert fb == 'b'\n",
    "            zstart = x_pairs[:, 0]\n",
    "            if self.first_coupling == 'ref':\n",
    "                # First coupling is (x_0, x_0 perturbed).\n",
    "                zend = zstart + torch.randn_like(zstart) * self.sig\n",
    "            elif self.first_coupling == 'ind':\n",
    "                zend = x_pairs[:, 1].clone()\n",
    "                zend = zend[torch.randperm(len(zend))]\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "            z0, z1 = zstart, zend\n",
    "        else:\n",
    "            assert not first_it\n",
    "            # A model that was never trained has fb None; fail clearly instead\n",
    "            # of silently treating it as a backward model.\n",
    "            assert prev_model.fb in ['f', 'b']\n",
    "            if prev_model.fb == 'f':\n",
    "                zstart = x_pairs[:, 0]\n",
    "            else:\n",
    "                zstart = x_pairs[:, 1]\n",
    "            zend = prev_model.sample_sde(zstart=zstart, fb=prev_model.fb)[-1]\n",
    "            if prev_model.fb == 'f':\n",
    "                z0, z1 = zstart, zend\n",
    "            else:\n",
    "                z0, z1 = zend, zstart\n",
    "        return z0, z1\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample_sde(self, zstart=None, N=None, fb='', first_it=False):\n",
    "        \"\"\"Simulate the learned dynamics from `zstart` with fixed-size Euler steps.\n",
    "\n",
    "        Args:\n",
    "            zstart: starting points; the batch size is zstart.shape[0].\n",
    "            N: number of steps; defaults to the `num_steps` given at construction.\n",
    "            fb: 'f' or 'b', the network to use. For 'b' the time fed to the\n",
    "                network runs from 1 downwards.\n",
    "            first_it: accepted and ignored.\n",
    "\n",
    "        Returns:\n",
    "            List of N + 1 tensors: the start followed by the state after each step.\n",
    "        \"\"\"\n",
    "        assert fb in ['f', 'b']\n",
    "        if N is None:\n",
    "            N = self.N\n",
    "        dt = 1./N\n",
    "        z = zstart.detach().clone()\n",
    "        batchsize = z.shape[0]\n",
    "\n",
    "        traj = [z.detach().clone()]\n",
    "        ts = np.arange(N) / N\n",
    "        if fb == 'b':\n",
    "            ts = 1 - ts\n",
    "        for i in range(N):\n",
    "            # Time tensor lives wherever the state lives (no global `device`).\n",
    "            t = torch.ones((batchsize, 1), device=z.device) * ts[i]\n",
    "            pred = self.net_dict[fb](z, t)\n",
    "            z = z.detach().clone() + pred * dt\n",
    "            z = z + self.sig * torch.randn_like(z) * np.sqrt(dt)\n",
    "            traj.append(z.detach().clone())\n",
    "\n",
    "        return traj\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4ad19e49",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected network with one linear layer per entry of `layer_widths`.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the input features.\n",
    "        layer_widths: output width of each linear layer, in order.\n",
    "        activate_final: whether the activation is also applied after the last layer.\n",
    "        activation_fn: callable applied after every layer except (by default) the last.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, layer_widths=[100,100,2], activate_final = False, activation_fn=F.tanh):\n",
    "        super(MLP, self).__init__()\n",
    "        # Keep a private copy: the default list object is shared by every call,\n",
    "        # so storing it directly lets a change made through one instance's\n",
    "        # `layer_widths` alter the layers built for the next instance.\n",
    "        layer_widths = list(layer_widths)\n",
    "        layers = []\n",
    "        prev_width = input_dim\n",
    "        for layer_width in layer_widths:\n",
    "            layers.append(torch.nn.Linear(prev_width, layer_width))\n",
    "            prev_width = layer_width\n",
    "        self.input_dim = input_dim\n",
    "        self.layer_widths = layer_widths\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.activate_final = activate_final\n",
    "        self.activation_fn = activation_fn\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the layers in order, activating between them.\"\"\"\n",
    "        for layer in self.layers[:-1]:\n",
    "            x = self.activation_fn(layer(x))\n",
    "        x = self.layers[-1](x)\n",
    "        if self.activate_final:\n",
    "            x = self.activation_fn(x)\n",
    "        return x\n",
    "\n",
    "import math\n",
    "class SinusoidalPosEmb(nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of one scalar per sample.\n",
    "\n",
    "    The input is multiplied by `dim // 2` frequencies spaced geometrically from\n",
    "    1 down to 1/10000; the sines and cosines of the products are concatenated\n",
    "    along the last axis.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Drop singleton axes so a (batch, 1) input becomes (batch,).\n",
    "        flat = x.squeeze()\n",
    "        n_freqs = self.dim // 2\n",
    "        log_step = math.log(10000) / (n_freqs - 1)\n",
    "        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)\n",
    "        angles = flat.unsqueeze(-1) * freqs.unsqueeze(0)\n",
    "        return torch.cat((angles.sin(), angles.cos()), dim=-1)\n",
    "    \n",
    "class MLPTime(nn.Module):\n",
    "    \"\"\"Residual MLP that receives a time input through a sinusoidal embedding.\n",
    "\n",
    "    Args:\n",
    "        in_channels: feature size of `x`.\n",
    "        out_channels: feature size of the output.\n",
    "        hidden_channels: width of every hidden layer.\n",
    "        time_embed_dim: size passed to `SinusoidalPosEmb`.\n",
    "        num_hidden_blocks: number of residual blocks.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, hidden_channels, time_embed_dim, num_hidden_blocks):\n",
    "        super().__init__()\n",
    "        self.to_time_embed = SinusoidalPosEmb(time_embed_dim)\n",
    "        self.in_layer = nn.Sequential(\n",
    "            nn.Linear(in_channels + time_embed_dim, hidden_channels),\n",
    "            nn.SiLU(),\n",
    "        )\n",
    "\n",
    "        def make_block():\n",
    "            # Two Linear + SiLU stages that keep the hidden width.\n",
    "            return nn.Sequential(\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "                nn.Linear(hidden_channels, hidden_channels),\n",
    "                nn.SiLU(),\n",
    "            )\n",
    "\n",
    "        blocks = []\n",
    "        for _ in range(num_hidden_blocks):\n",
    "            blocks.append(make_block())\n",
    "        self.hidden_blocks = nn.ModuleList(blocks)\n",
    "\n",
    "        self.out_layer = nn.Linear(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, time):\n",
    "        \"\"\"Embed `time`, concatenate it to `x` and run the residual stack.\"\"\"\n",
    "        features = torch.cat([x, self.to_time_embed(time)], dim=-1)\n",
    "        hidden = self.in_layer(features)\n",
    "        for block in self.hidden_blocks:\n",
    "            hidden = block(hidden) + hidden\n",
    "        return self.out_layer(hidden)\n",
    "\n",
    "class ScoreNetwork(nn.Module):\n",
    "    \"\"\"Wrapper around an `MLPTime` built with fixed hyper-parameters.\n",
    "\n",
    "    `layer_widths`, `activate_final` and `activation_fn` are accepted but not\n",
    "    used: the wrapped network is always\n",
    "    MLPTime(input_dim - 1, input_dim - 1, 128, 12, 2).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, layer_widths=[100,100,2], activate_final = False, activation_fn=F.tanh):\n",
    "        super().__init__()\n",
    "        feature_dim = input_dim - 1\n",
    "        self.net = MLPTime(feature_dim, feature_dim, 128, 12, 2)\n",
    "\n",
    "    def forward(self, t, x_input, *args, **kwargs):\n",
    "        # NOTE(review): the first argument is handed to MLPTime as `x` and the\n",
    "        # second as `time`. The callers in this notebook pass (state, time), so\n",
    "        # despite its name `t` carries the state here -- confirm before renaming.\n",
    "        if t.dim() == 0:\n",
    "            t = t.reshape([-1, 1]).repeat([x_input.shape[0], 1]).to(x_input.device)\n",
    "        return self.net(t.squeeze(), x_input)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8967cf64",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Train DSBM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "609f4471",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Module-level history of trained models. fit_imf appends one\n",
    "# {'fb': ..., 'model': ...} entry after every training pass and reads the last\n",
    "# entry as the previous model for the next pass.\n",
    "model_list = []\n",
    "\n",
    "def train_imf_iter(dsbm_ipf, x_pairs, batch_size, inner_iters, prev_model=None, fb='', first_it=False, lr=1e-4):\n",
    "    \"\"\"Train one direction of `dsbm_ipf` for `inner_iters` optimisation steps.\n",
    "\n",
    "    Args:\n",
    "        dsbm_ipf: DSBM model; its `fb` attribute is set to `fb`.\n",
    "        x_pairs: paired samples handed to `generate_new_dataset`.\n",
    "        batch_size: mini-batch size (incomplete batches are dropped).\n",
    "        inner_iters: number of optimisation steps.\n",
    "        prev_model: model of the previous pass, or None on the first pass.\n",
    "        fb: 'f' or 'b', the network to train.\n",
    "        first_it: True on the very first pass; enables the `starting_prior`\n",
    "            handling for 'id' and 'id_permuted' (read from the notebook globals).\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        (dsbm_ipf, loss_curve) where loss_curve holds log(loss) per step.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the loss becomes NaN, or if the generated dataset does\n",
    "            not contain a single full batch.\n",
    "    \"\"\"\n",
    "    assert fb in ['f', 'b']\n",
    "    dsbm_ipf.fb = fb\n",
    "    optimizer = torch.optim.Adam(dsbm_ipf.net_dict[fb].parameters(), lr=lr)\n",
    "    loss_curve = []\n",
    "\n",
    "    def new_batch_iterator():\n",
    "        # Regenerate the coupling and wrap it in a shuffled loader.\n",
    "        coupling = dsbm_ipf.generate_new_dataset(x_pairs, prev_model=prev_model, fb=fb, first_it=first_it)\n",
    "        return iter(DataLoader(TensorDataset(*coupling),\n",
    "                               batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=True))\n",
    "\n",
    "    dl = new_batch_iterator()\n",
    "    # Stay None when no step runs, so the final plot can be skipped safely.\n",
    "    z0 = z1 = None\n",
    "\n",
    "    for i in tqdm(range(inner_iters)):\n",
    "        batch = next(dl, None)\n",
    "        if batch is None:\n",
    "            # Loader exhausted: draw a fresh coupling and continue.\n",
    "            dl = new_batch_iterator()\n",
    "            batch = next(dl, None)\n",
    "            if batch is None:\n",
    "                raise ValueError('Dataset does not contain a full batch; lower batch_size')\n",
    "        z0, z1 = batch\n",
    "\n",
    "        if first_it:\n",
    "            if starting_prior == 'id':\n",
    "                # Collapse the pair onto one marginal.\n",
    "                if fb == 'f':\n",
    "                    z1 = z0.clone().detach()\n",
    "                else:\n",
    "                    z0 = z1.clone().detach()\n",
    "            elif starting_prior == 'id_permuted':\n",
    "                # Same marginal on both sides, paired in shuffled order.\n",
    "                if fb == 'f':\n",
    "                    z1 = z0.clone().detach()\n",
    "                    idx = torch.randperm(z1.size(0))\n",
    "                    z1 = z1[idx]\n",
    "                else:\n",
    "                    z0 = z1.clone().detach()\n",
    "                    idx = torch.randperm(z0.size(0))\n",
    "                    z0 = z0[idx]\n",
    "\n",
    "        z_pairs = torch.stack([z0, z1], dim=1)\n",
    "        z_t, t, target = dsbm_ipf.get_train_tuple(z_pairs, fb=fb, first_it=first_it)\n",
    "        optimizer.zero_grad()\n",
    "        pred = dsbm_ipf.net_dict[fb](z_t, t)\n",
    "        # Squared error summed over features, averaged over the batch.\n",
    "        loss = (target - pred).view(pred.shape[0], -1).abs().pow(2).sum(dim=1)\n",
    "        loss = loss.mean()\n",
    "        loss.backward()\n",
    "\n",
    "        if wandb.run and (i % 100) == 0:\n",
    "            wandb.log({'Loss': loss.item()})\n",
    "\n",
    "        # Stop before the optimiser writes NaNs into the weights.\n",
    "        if torch.isnan(loss).any():\n",
    "            raise ValueError('Loss is nan')\n",
    "\n",
    "        optimizer.step()\n",
    "        loss_curve.append(np.log(loss.item()))  # store the log-loss curve\n",
    "\n",
    "    if z0 is not None:\n",
    "        # Plot the last training batch; the argument order depends on direction.\n",
    "        flat_z0 = z0.reshape([z0.shape[0], -1])\n",
    "        flat_z1 = z1.reshape([z1.shape[0], -1])\n",
    "        if fb == 'f':\n",
    "            pca_plot(flat_z0, flat_z1, flat_z1, 128, save_name='Train_pairs.png', is_wandb=wandb.run)\n",
    "        else:\n",
    "            pca_plot(flat_z1, flat_z0, flat_z0, 128, save_name='Train_pairs.png', is_wandb=wandb.run)\n",
    "\n",
    "    return dsbm_ipf, loss_curve\n",
    "\n",
    "\n",
    "def _plot_swiss_roll_eval(dsbm_model, device, save_path='Swiss_roll_wandb_log.png'):\n",
    "    \"\"\"Plot forward samples and a few trajectories, save the figure and log it.\"\"\"\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(4., 4.), dpi=200)\n",
    "    try:\n",
    "        ax.grid(zorder=-20)\n",
    "        ax.get_xaxis().set_ticklabels([])\n",
    "        ax.get_yaxis().set_ticklabels([])\n",
    "\n",
    "        x_samples = X_sampler.sample(512)\n",
    "\n",
    "        # Four fixed start points, each repeated three times.\n",
    "        tr_samples = torch.tensor([[0.0, 0.0], [1.75, -1.75], [-1.5, 1.5], [2, 2]])\n",
    "        tr_samples = tr_samples[None].repeat(3, 1, 1).reshape(12, 2)\n",
    "\n",
    "        y_pred = dsbm_model.sample_sde(zstart=x_samples.to(device), N=100, fb='f', first_it=False)[-1].cpu()\n",
    "        ax.scatter(y_pred[:, 0], y_pred[:, 1],\n",
    "                   c='salmon', s=64, edgecolors='black', label='Fitted distribution', zorder=1)\n",
    "\n",
    "        # Stack the per-step states into (sample, step, coordinate).\n",
    "        trajectory = torch.stack(\n",
    "            dsbm_model.sample_sde(zstart=tr_samples.to(device), N=100, fb='f', first_it=False), dim=1\n",
    "        ).cpu()\n",
    "\n",
    "        ax.scatter(tr_samples[:, 0], tr_samples[:, 1],\n",
    "                   c='lime', s=128, edgecolors='black', label=r'Trajectory start ($x \\sim p_0$)', zorder=3)\n",
    "        ax.scatter(trajectory[:, -1, 0], trajectory[:, -1, 1],\n",
    "                   c='yellow', s=64, edgecolors='black', label=r'Trajectory end (fitted)', zorder=3)\n",
    "\n",
    "        for i in range(trajectory.shape[0]):\n",
    "            xs, ys = trajectory[i, :, 0], trajectory[i, :, 1]\n",
    "            ax.plot(xs, ys, 'black', markeredgecolor='black', linewidth=1.5, zorder=2)\n",
    "            # Only the first thin line carries the legend entry.\n",
    "            ax.plot(xs, ys, 'grey', markeredgecolor='black', linewidth=0.5, zorder=2,\n",
    "                    label=r'Trajectory of $S_{\\theta}$' if i == 0 else None)\n",
    "\n",
    "        ax.set_xlim([-2.5, 2.5])\n",
    "        ax.set_ylim([-2.5, 2.5])\n",
    "        ax.legend(loc='lower left')\n",
    "        fig.tight_layout(pad=0.1)\n",
    "\n",
    "        fig.savefig(save_path)\n",
    "        # Same guard as the loss logging: only log when a run is active.\n",
    "        if wandb.run:\n",
    "            wandb.log({save_path: wandb.Image(save_path)})\n",
    "    finally:\n",
    "        # One figure is created per training pass; close it so they do not pile up.\n",
    "        plt.close(fig)\n",
    "\n",
    "\n",
    "def fit_imf(dsbm_model, x_pairs, dim, eps, outer_iters=100, inner_iters=10000, batch_size=128, lr=1e-4):\n",
    "    \"\"\"Alternate backward and forward training passes and plot after each one.\n",
    "\n",
    "    Args:\n",
    "        dsbm_model: DSBM model trained in place.\n",
    "        x_pairs: paired samples, x_pairs[:, 0] from the source and\n",
    "            x_pairs[:, 1] from the target; its device is used for sampling.\n",
    "        dim: accepted for the caller's convenience; not used here.\n",
    "        eps: accepted for the caller's convenience; not used here.\n",
    "        outer_iters: number of outer iterations, each a 'b' pass then an 'f' pass.\n",
    "        inner_iters: optimisation steps per pass.\n",
    "        batch_size: mini-batch size per step.\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Every pass appends a CPU copy of the model to the global `model_list`.\n",
    "    \"\"\"\n",
    "    device = x_pairs.device\n",
    "    # Inclusive upper bound: `range(1, outer_iters)` ran one iteration too few\n",
    "    # and did nothing at all for outer_iters=1.\n",
    "    for it in range(1, outer_iters + 1):\n",
    "        for fb in ['b', 'f']:\n",
    "            print(f'Iteration {it}/{outer_iters} {fb}')\n",
    "            first_it = (it == 1 and fb == 'b')\n",
    "\n",
    "            if first_it:\n",
    "                prev_model = None\n",
    "            else:\n",
    "                prev_model = model_list[-1]['model'].to(device).eval()\n",
    "\n",
    "            model, loss_curve = train_imf_iter(dsbm_model, x_pairs, batch_size, inner_iters, prev_model=prev_model, fb=fb, first_it=first_it, lr=lr)\n",
    "\n",
    "            if prev_model is not None:\n",
    "                # `.to(device)` moved the stored copy in place; return it to the\n",
    "                # CPU so the history does not accumulate on the device.\n",
    "                prev_model.to('cpu')\n",
    "            model_list.append({'fb': fb, 'model': copy.deepcopy(model).to('cpu').eval()})\n",
    "\n",
    "            # Push every source sample through the forward network and plot it\n",
    "            # next to the target samples.\n",
    "            x_0_samples = x_pairs[:, 0]\n",
    "            x_1_samples = x_pairs[:, 1]\n",
    "            x_1_pred = dsbm_model.sample_sde(zstart=x_0_samples, N=100, fb='f', first_it=False)[-1]\n",
    "\n",
    "            x_0_samples, x_1_samples, x_1_pred = x_0_samples.cpu(), x_1_samples.cpu(), x_1_pred.cpu()\n",
    "            n_rows = x_0_samples.shape[0]\n",
    "            pca_plot(x_0_samples.reshape([n_rows, -1]), x_1_samples.reshape([n_rows, -1]), x_1_pred.reshape([n_rows, -1]), 128, save_name='SDE_gen.png', is_wandb=wandb.run)\n",
    "\n",
    "            _plot_swiss_roll_eval(dsbm_model, device)\n",
    "                \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ac353e4",
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Start from an empty history so re-running this cell does not reuse old models.\n",
    "model_list = []\n",
    "\n",
    "# Fixed training set: one batch from each sampler, paired row by row below.\n",
    "dataset_size = 100000\n",
    "x0 = X_sampler.sample(dataset_size)\n",
    "x1 = Y_sampler.sample(dataset_size)\n",
    "\n",
    "# NOTE(review): ScoreNetwork ignores `layer_widths` and `activation_fn`; it always\n",
    "# builds MLPTime(dim, dim, 128, 12, 2) from the first argument.\n",
    "net_fwd = ScoreNetwork(dim + 1, layer_widths=[128, 128, dim], activation_fn=F.tanh)\n",
    "net_bwd = ScoreNetwork(dim + 1, layer_widths=[128, 128, dim], activation_fn=F.tanh)\n",
    "\n",
    "# The noise scale is sqrt(eps); `plan` selects the first coupling.\n",
    "dsbm_model = DSBM(net_fwd=net_fwd, net_bwd=net_bwd, num_steps=100, sig=math.sqrt(eps),\n",
    "                  safe_t_eps=1e-3, first_coupling=plan)\n",
    "\n",
    "dsbm_model.to(device)\n",
    "\n",
    "wandb_config = {'eps': eps, 'dim': dim, 'plan': plan, 'starting_prior': starting_prior, 'imf_iters': imf_iters}\n",
    "\n",
    "wandb.init(project='DSBM_exps', name=f\"DSBM_SwissRoll_{dim}_eps_{eps}\", config=wandb_config)\n",
    "\n",
    "# Pair the samples as x_pairs[:, 0] = x0 and x_pairs[:, 1] = x1 on the training device.\n",
    "x_pairs = torch.stack([x0, x1], dim=1).to(device)\n",
    "# NOTE(review): inner_iters is the literal 20000 here; MAX_STEPS from the\n",
    "# parameters cell is not used -- confirm which one is intended.\n",
    "fit_imf(dsbm_model, x_pairs, dim=dim, eps=eps, outer_iters=imf_iters, inner_iters=20000, lr=lr, batch_size=batch_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092ec9dc-e54d-4789-84f2-74d146adb6fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
