{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f2db60eb",
   "metadata": {},
   "source": [
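    "This notebook builds a synthetic video dataset and trains a semi-supervised (seed-supervised) U-Net for online, frame-by-frame image segmentation. Each network input is a noisy frame of a moving dumbbell stacked with sparse foreground and background seed maps; the training loss uses only the seed pixels plus a total-variation smoothness term."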
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "ac3dc89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "import numpy as np\n",
    "\n",
    "# Frame size in pixels.\n",
    "H, W = 50, 50\n",
    "\n",
    "# The dumbbell appears to be laid out on an 80x80 canvas; `s` rescales it to H x W.\n",
    "# Each part is an axis-aligned ellipse: center (row, col) and radii (y, x).\n",
    "s = 50.0 / 80.0\n",
    "top_row, top_col = 22 * s, 40 * s  # upper rod, drawn with value 60\n",
    "top_r_y, top_r_x = 12 * s, 18 * s\n",
    "bot_row, bot_col = 55 * s, 40 * s  # lower rod, drawn with value 60\n",
    "bot_r_y, bot_r_x = 12 * s, 18 * s\n",
    "cen_row, cen_col = 38 * s, 40 * s  # bridge (value 90); also the rotation center\n",
    "cen_r_y, cen_r_x = 10 * s, 10 * s\n",
    "\n",
    "\n",
    "def build_base_img_and_seeds(rng):\n",
    "    \"\"\"Render the untransformed dumbbell and sample seed pixels on it.\n",
    "\n",
    "    Args:\n",
    "        rng: numpy Generator; consumed by the three ``rng.choice`` seed draws\n",
    "            (bridge background, remaining background, foreground - in that order).\n",
    "\n",
    "    Returns:\n",
    "        img: (H, W) float array without noise: 120 = background, 60 = rods,\n",
    "            90 = bridge (drawn last, so it overrides overlapping rod pixels).\n",
    "        fg_seeds: (50, 2) int array of (row, col) pixels whose value is 60.\n",
    "        bg_seeds: (50, 2) int array of (row, col) pixels; the first 10 lie on\n",
    "            the bridge (90), the other 40 on bridge or background (90/120).\n",
    "    \"\"\"\n",
    "    img = np.ones((H, W)) * 120\n",
    "    yy, xx = np.ogrid[:H, :W]\n",
    "\n",
    "    # base dumbbell\n",
    "    mask_top = (yy - top_row)**2 / top_r_y**2 + (xx - top_col)**2 / top_r_x**2 <= 1\n",
    "    img[mask_top] = 60\n",
    "    mask_bottom = (yy - bot_row)**2 / bot_r_y**2 + (xx - bot_col)**2 / bot_r_x**2 <= 1\n",
    "    img[mask_bottom] = 60\n",
    "    mask_center = (yy - cen_row)**2 / cen_r_y**2 + (xx - cen_col)**2 / cen_r_x**2 <= 1\n",
    "    img[mask_center] = 90\n",
    "\n",
    "    # Seeds are sampled from the clean image (this function adds no noise),\n",
    "    # so every seed is guaranteed to lie in its labelled region.\n",
    "    coords_60 = np.argwhere(img == 60)\n",
    "    coords_90 = np.argwhere(img == 90)\n",
    "    coords_120 = np.argwhere(img == 120)\n",
    "\n",
    "    n_fg = 50         # foreground (rod) seeds\n",
    "    n_bg = 50         # background seeds in total\n",
    "    n_bg_90_min = 10  # of which at least this many are forced onto the bridge\n",
    "\n",
    "    # NOTE: the draw order (bg90, bg_rest, fg) fixes the RNG stream; keep it stable.\n",
    "    bg90_indices = rng.choice(len(coords_90), size=n_bg_90_min, replace=False)\n",
    "    bg_seeds_90 = coords_90[bg90_indices]\n",
    "\n",
    "    # NOTE(review): drawn independently of bg_seeds_90, so a bridge pixel can be\n",
    "    # picked twice and bg_seeds may contain duplicate coordinates.\n",
    "    coords_bg_rest = np.vstack([coords_90, coords_120])\n",
    "    remaining = n_bg - n_bg_90_min\n",
    "    bg_rest_indices = rng.choice(len(coords_bg_rest), size=remaining, replace=False)\n",
    "    bg_seeds_rest = coords_bg_rest[bg_rest_indices]\n",
    "\n",
    "    fg_indices = rng.choice(len(coords_60), size=n_fg, replace=False)\n",
    "    fg_seeds = [(int(r), int(c)) for r, c in coords_60[fg_indices]]\n",
    "    bg_seeds = [(int(r), int(c)) for r, c in np.vstack([bg_seeds_90, bg_seeds_rest])]\n",
    "\n",
    "    return img, np.array(fg_seeds, dtype=int), np.array(bg_seeds, dtype=int)\n",
    "\n",
    "def transform_image_and_seeds_random(\n",
    "    t, T_total, rng):\n",
    "    \"\"\"Render a randomly moved, noisy dumbbell frame and resample seeds on it.\n",
    "\n",
    "    Training-time counterpart of ``transform_image_and_seeds``: the total\n",
    "    rotation (360-1080 degrees over the clip) and the shift amplitudes and\n",
    "    phases are drawn from ``rng`` on every call, then scaled by the clip\n",
    "    progress ``tau = t / T_total``.\n",
    "\n",
    "    Args:\n",
    "        t: frame index.\n",
    "        T_total: number of frames in the clip.\n",
    "        rng: numpy Generator for motion, noise and seed draws. The order of the\n",
    "            draws determines the output, so keep the statements in order.\n",
    "\n",
    "    Returns:\n",
    "        img_noisy: (H, W) float image with N(0, 10) noise, clipped to [0, 255].\n",
    "        fg_seeds_t: (50, 2) int (row, col) pixels on the rods (clean value 60).\n",
    "        bg_seeds_t: (50, 2) int pixels; the first 10 on the bridge (90), then 40\n",
    "            from bridge or background (90/120).\n",
    "\n",
    "    NOTE(review): only the ellipse centers are rotated; the ellipses stay\n",
    "    axis-aligned (ry, rx are used unrotated), so the rods never change\n",
    "    orientation. The evaluation and ground-truth code use the same model.\n",
    "    \"\"\"\n",
    "\n",
    "    tau = t / T_total\n",
    "\n",
    "    # random motion: angle_deg drawn first, then amplitude/phase per shift axis\n",
    "    angle_deg = rng.uniform(360, 1080) * tau\n",
    "    shift_row = rng.uniform(2, 8) * np.sin(2*np.pi*tau + rng.uniform(0, 2*np.pi))\n",
    "    shift_col = rng.uniform(2, 8) * np.cos(2*np.pi*1.3*tau + rng.uniform(0, 2*np.pi))\n",
    "\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    cos_t, sin_t = np.cos(theta), np.sin(theta)\n",
    "    cy0, cx0 = cen_row, cen_col  # rotation center\n",
    "\n",
    "    # transformed dumbbell (analytic)\n",
    "    img = np.ones((H, W)) * 120\n",
    "    yy, xx = np.indices((H, W))\n",
    "\n",
    "    ellipses = [\n",
    "        (top_row, top_col, top_r_y, top_r_x, 60),\n",
    "        (bot_row, bot_col, bot_r_y, bot_r_x, 60),\n",
    "        (cen_row, cen_col, cen_r_y, cen_r_x, 90),\n",
    "    ]\n",
    "\n",
    "    # rotate each ellipse center about (cy0, cx0), then shift; the bridge is\n",
    "    # drawn last so it overwrites overlapping rod pixels\n",
    "    for cy, cx, ry, rx, val in ellipses:\n",
    "        dy = cy - cy0\n",
    "        dx = cx - cx0\n",
    "        cy_r = cy0 + cos_t * dy - sin_t * dx + shift_row\n",
    "        cx_r = cx0 + sin_t * dy + cos_t * dx + shift_col\n",
    "        mask = ((yy - cy_r) ** 2) / (ry ** 2) + ((xx - cx_r) ** 2) / (rx ** 2) <= 1\n",
    "        img[mask] = val\n",
    "\n",
    "    # add noise\n",
    "    noise = rng.normal(0, 10, (H, W))\n",
    "    img_noisy = np.clip(img + noise, 0, 255)\n",
    "\n",
    "    # ---- RESAMPLE SEEDS ON THIS FRAME ----\n",
    "    # (from the clean `img`, so seeds always fall inside the correct regions)\n",
    "    coords_60 = np.argwhere(img == 60)   # foreground (rods)\n",
    "    coords_90 = np.argwhere(img == 90)   # bridge\n",
    "    coords_120 = np.argwhere(img == 120) # background\n",
    "\n",
    "    n_fg = 50\n",
    "    n_bg = 50\n",
    "    n_bg_90_min = 10\n",
    "\n",
    "    # foreground seeds: inside 60-valued regions\n",
    "    fg_indices = rng.choice(len(coords_60), size=n_fg, replace=False)\n",
    "    fg_seeds_t = coords_60[fg_indices]\n",
    "\n",
    "    # background seeds: some on bridge (90), rest on 90/120\n",
    "    bg90_indices = rng.choice(len(coords_90), size=n_bg_90_min, replace=False)\n",
    "    bg_seeds_90 = coords_90[bg90_indices]\n",
    "\n",
    "    coords_bg_rest = np.vstack([coords_90, coords_120])\n",
    "    remaining = n_bg - n_bg_90_min\n",
    "    bg_rest_indices = rng.choice(len(coords_bg_rest), size=remaining, replace=False)\n",
    "    bg_seeds_rest = coords_bg_rest[bg_rest_indices]\n",
    "\n",
    "    bg_seeds_t = np.vstack([bg_seeds_90, bg_seeds_rest])\n",
    "\n",
    "    return img_noisy, fg_seeds_t.astype(int), bg_seeds_t.astype(int)\n",
    "\n",
    "\n",
    "class SeedAwareDataset(Dataset):\n",
    "    \"\"\"On-the-fly dataset of randomly moved dumbbell frames with seed maps.\n",
    "\n",
    "    Every access draws a fresh random frame index and motion from ``rng``\n",
    "    (the ``idx`` argument is ignored), so ``n_samples`` only sets the epoch\n",
    "    length and the clip length used for ``tau``. Items are ``(x, fg, bg)``\n",
    "    where ``x`` is a float32 tensor (3, H, W) = [intensity / 255, foreground\n",
    "    seed map, background seed map] and fg/bg are (N, 2) int seed coordinates.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_samples, rng):\n",
    "        self.n = n_samples\n",
    "        self.rng = rng\n",
    "        # Not used by __getitem__, but the call also advances `rng`, so the\n",
    "        # sample stream depends on it being made here.\n",
    "        self.base_img, self.fg_base, self.bg_base = build_base_img_and_seeds(rng)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        frame_idx = self.rng.integers(0, self.n)\n",
    "        img, fg, bg = transform_image_and_seeds_random(frame_idx, self.n, self.rng)\n",
    "\n",
    "        intensity = img.astype(np.float32) / 255.0\n",
    "        seed_fg = np.zeros_like(intensity)\n",
    "        seed_bg = np.zeros_like(intensity)\n",
    "        seed_fg[fg[:, 0], fg[:, 1]] = 1.0\n",
    "        seed_bg[bg[:, 0], bg[:, 1]] = 1.0\n",
    "\n",
    "        stacked = np.stack((intensity, seed_fg, seed_bg), axis=0)\n",
    "        return torch.tensor(stacked), fg, bg\n",
    "    \n",
    "\n",
    "    \n",
    "def transform_image_and_seeds(t, T_total, rng):\n",
    "    \"\"\"Render frame ``t`` of the deterministic evaluation clip and resample seeds.\n",
    "\n",
    "    For frame index t in 0..T_total-1 the dumbbell rotates by 720 * t / T_total\n",
    "    degrees about the bridge center and shifts sinusoidally by up to 5 px.\n",
    "    The motion depends only on t; ``rng`` is used only for the pixel noise and\n",
    "    the seed draws (whose order determines the result).\n",
    "\n",
    "    Returns:\n",
    "      - img_noisy: transformed image with N(0, 10) noise, clipped to [0, 255]\n",
    "      - fg_seeds_t: (50, 2) int array of NEW seeds on this frame's rods (60)\n",
    "      - bg_seeds_t: (50, 2) int array, 10 on the bridge (90) + 40 on 90/120\n",
    "        (seeds are sampled from the noise-free render, so they are always\n",
    "        inside the correct regions).\n",
    "\n",
    "    NOTE(review): only the ellipse centers are rotated; the ellipse axes stay\n",
    "    axis-aligned. The ground-truth cell reproduces this same motion model.\n",
    "    \"\"\"\n",
    "    # motion parameters\n",
    "    tau = t / T_total\n",
    "    angle_deg = 720 * tau              # 2 full rotations over the clip\n",
    "    shift_row = 5 * np.sin(2 * np.pi * tau)\n",
    "    shift_col = 5 * np.cos(2 * np.pi * 1.5 * tau)\n",
    "\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    cos_t, sin_t = np.cos(theta), np.sin(theta)\n",
    "    cy0, cx0 = cen_row, cen_col  # rotation center\n",
    "\n",
    "    # transformed dumbbell (analytic)\n",
    "    img = np.ones((H, W)) * 120\n",
    "    yy, xx = np.indices((H, W))\n",
    "\n",
    "    ellipses = [\n",
    "        (top_row, top_col, top_r_y, top_r_x, 60),\n",
    "        (bot_row, bot_col, bot_r_y, bot_r_x, 60),\n",
    "        (cen_row, cen_col, cen_r_y, cen_r_x, 90),\n",
    "    ]\n",
    "\n",
    "    # rotate each ellipse center about (cy0, cx0), then shift; the bridge is\n",
    "    # drawn last so it overwrites overlapping rod pixels\n",
    "    for cy, cx, ry, rx, val in ellipses:\n",
    "        dy = cy - cy0\n",
    "        dx = cx - cx0\n",
    "        cy_r = cy0 + cos_t * dy - sin_t * dx + shift_row\n",
    "        cx_r = cx0 + sin_t * dy + cos_t * dx + shift_col\n",
    "        mask = ((yy - cy_r) ** 2) / (ry ** 2) + ((xx - cx_r) ** 2) / (rx ** 2) <= 1\n",
    "        img[mask] = val\n",
    "\n",
    "    # add noise\n",
    "    noise = rng.normal(0, 10, (H, W))\n",
    "    img_noisy = np.clip(img + noise, 0, 255)\n",
    "\n",
    "    # ---- RESAMPLE SEEDS ON THIS FRAME ----\n",
    "    coords_60 = np.argwhere(img == 60)   # foreground (rods)\n",
    "    coords_90 = np.argwhere(img == 90)   # bridge\n",
    "    coords_120 = np.argwhere(img == 120) # background\n",
    "\n",
    "    n_fg = 50\n",
    "    n_bg = 50\n",
    "    n_bg_90_min = 10\n",
    "\n",
    "    # foreground seeds: inside 60-valued regions\n",
    "    fg_indices = rng.choice(len(coords_60), size=n_fg, replace=False)\n",
    "    fg_seeds_t = coords_60[fg_indices]\n",
    "\n",
    "    # background seeds: some on bridge (90), rest on 90/120\n",
    "    bg90_indices = rng.choice(len(coords_90), size=n_bg_90_min, replace=False)\n",
    "    bg_seeds_90 = coords_90[bg90_indices]\n",
    "\n",
    "    coords_bg_rest = np.vstack([coords_90, coords_120])\n",
    "    remaining = n_bg - n_bg_90_min\n",
    "    bg_rest_indices = rng.choice(len(coords_bg_rest), size=remaining, replace=False)\n",
    "    bg_seeds_rest = coords_bg_rest[bg_rest_indices]\n",
    "\n",
    "    bg_seeds_t = np.vstack([bg_seeds_90, bg_seeds_rest])\n",
    "\n",
    "    return img_noisy, fg_seeds_t.astype(int), bg_seeds_t.astype(int)\n",
    "\n",
    "\n",
    "# Fixed seed so the base image and its seed sets are reproducible across runs.\n",
    "rng = np.random.default_rng(0)\n",
    "img_base, fg_seeds_base, bg_seeds_base = build_base_img_and_seeds(rng)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "8f5ee938",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class DoubleConv(nn.Module):\n",
    "    \"\"\"Two 3x3 convolutions (padding=1), each followed by ReLU.\n",
    "\n",
    "    Maps (B, in_ch, H, W) -> (B, out_ch, H, W); the spatial size is unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_ch, out_ch):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Conv2d(in_ch, out_ch, 3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(out_ch, out_ch, 3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "        ]\n",
    "        # The attribute name and layer order define the state_dict keys\n",
    "        # (net.0.*, net.2.*) that saved checkpoints rely on.\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.net(x)\n",
    "        return out\n",
    "\n",
    "def center_crop(tensor, target):\n",
    "    \"\"\"Center-crop a 4-D (B, C, H, W) tensor to the spatial size of ``target``.\n",
    "\n",
    "    When the size difference is odd, the extra row/column is dropped from the\n",
    "    bottom/right. Expects ``target`` to be no larger than ``tensor``.\n",
    "    \"\"\"\n",
    "    _, _, src_h, src_w = tensor.shape\n",
    "    _, _, dst_h, dst_w = target.shape\n",
    "    top = (src_h - dst_h) // 2\n",
    "    left = (src_w - dst_w) // 2\n",
    "    rows = slice(top, top + dst_h)\n",
    "    cols = slice(left, left + dst_w)\n",
    "    return tensor[:, :, rows, cols]\n",
    "\n",
    "class SeedAwareUNet(nn.Module):\n",
    "    \"\"\"Small 3-level U-Net that segments a frame given sparse seed maps.\n",
    "\n",
    "    Input:  (B, 3, H, W); SeedAwareDataset stacks [intensity, fg seeds, bg seeds].\n",
    "    Output: (B, 1, H, W) foreground probabilities. ``torch.sigmoid`` is applied\n",
    "            inside ``forward``, so callers must not apply it again.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # 3 input channels (image + two seed maps); widths 32 -> 64 -> 128\n",
    "        self.enc1 = DoubleConv(3, 32)\n",
    "        self.enc2 = DoubleConv(32, 64)\n",
    "        self.enc3 = DoubleConv(64, 128)\n",
    "\n",
    "        self.pool = nn.MaxPool2d(2)\n",
    "\n",
    "        # decoder input = upsampled features + center-cropped skip connection\n",
    "        self.dec2 = DoubleConv(128 + 64, 64)\n",
    "        self.dec1 = DoubleConv(64 + 32, 32)\n",
    "\n",
    "        self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False)\n",
    "        self.outc = nn.Conv2d(32, 1, kernel_size=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Remember the input resolution: pooling floors odd sizes, so the\n",
    "        # decoder output can be smaller than the input (48x48 for 50x50).\n",
    "        out_size = tuple(x.shape[-2:])\n",
    "\n",
    "        # Encoder (sizes shown for a 50x50 input)\n",
    "        x1 = self.enc1(x)              # 50x50\n",
    "        x2 = self.enc2(self.pool(x1))  # 25x25\n",
    "        x3 = self.enc3(self.pool(x2))  # 12x12 (after floor)\n",
    "\n",
    "        # Decoder\n",
    "        x = self.up(x3)                # 24x24\n",
    "        x2_crop = center_crop(x2, x)\n",
    "        x = self.dec2(torch.cat([x, x2_crop], dim=1))\n",
    "\n",
    "        x = self.up(x)                 # 48x48\n",
    "        x1_crop = center_crop(x1, x)\n",
    "        x = self.dec1(torch.cat([x, x1_crop], dim=1))\n",
    "\n",
    "        x = self.outc(x)               # 48x48, 1 channel\n",
    "        # Resize to the input's size (not a fixed 50x50) so the prediction stays\n",
    "        # pixel-aligned with the seed coordinates for any frame size.\n",
    "        x = F.interpolate(x, size=out_size, mode=\"bilinear\", align_corners=False)\n",
    "        return torch.sigmoid(x)\n",
    "\n",
    "class SeedAwareUNet2(nn.Module):\n",
    "    \"\"\"\n",
    "    3-channel U-Net-like architecture with ~1M parameters.\n",
    "    Same depth as SeedAwareUNet, but wider (48 -> 96 -> 192 channels).\n",
    "\n",
    "    Input:  (B, 3, H, W); SeedAwareDataset stacks [intensity, fg seeds, bg seeds].\n",
    "    Output: (B, 1, H, W) foreground probabilities. ``torch.sigmoid`` is applied\n",
    "            inside ``forward``, so callers must not apply it again.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Encoder: widen channels\n",
    "        self.enc1 = DoubleConv(3,   48)   # 3 -> 48\n",
    "        self.enc2 = DoubleConv(48,  96)   # 48 -> 96\n",
    "        self.enc3 = DoubleConv(96, 192)   # 96 -> 192\n",
    "\n",
    "        self.pool = nn.MaxPool2d(2)\n",
    "\n",
    "        # Decoder: mirror encoder widths\n",
    "        self.dec2 = DoubleConv(192 + 96, 96)   # (192+96)=288 -> 96\n",
    "        self.dec1 = DoubleConv(96 + 48, 48)    # (96+48)=144 -> 48\n",
    "\n",
    "        self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False)\n",
    "        self.outc = nn.Conv2d(48, 1, kernel_size=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Remember the input resolution: pooling floors odd sizes, so the\n",
    "        # decoder output can be smaller than the input (48x48 for 50x50).\n",
    "        out_size = tuple(x.shape[-2:])\n",
    "\n",
    "        # Encoder (shapes shown for a 50x50 input)\n",
    "        x1 = self.enc1(x)              # (B,48,50,50)\n",
    "        x2 = self.enc2(self.pool(x1))  # (B,96,25,25)\n",
    "        x3 = self.enc3(self.pool(x2))  # (B,192,12,12)\n",
    "\n",
    "        # Decoder\n",
    "        x = self.up(x3)                # (B,192,24,24)\n",
    "        x2_crop = center_crop(x2, x)\n",
    "        x = self.dec2(torch.cat([x, x2_crop], dim=1))  # (B,96,24,24)\n",
    "\n",
    "        x = self.up(x)                 # (B,96,48,48)\n",
    "        x1_crop = center_crop(x1, x)\n",
    "        x = self.dec1(torch.cat([x, x1_crop], dim=1))  # (B,48,48,48)\n",
    "\n",
    "        x = self.outc(x)               # (B,1,48,48)\n",
    "        # Resize to the input's size (not a fixed 50x50) so the prediction stays\n",
    "        # pixel-aligned with the seed coordinates for any frame size.\n",
    "        x = F.interpolate(x, size=out_size, mode=\"bilinear\", align_corners=False)\n",
    "        return torch.sigmoid(x)\n",
    "    \n",
    "    \n",
    "def seed_bce_loss(pred_probs, fg, bg, device):\n",
    "    \"\"\"Binary cross-entropy evaluated only at the seed pixels.\n",
    "\n",
    "    Args:\n",
    "        pred_probs: (B, 1, H, W) foreground probabilities in [0, 1].\n",
    "        fg, bg: per-sample (row, col) seed coordinates, indexable by batch\n",
    "            position: e.g. a (B, N, 2) tensor from the default DataLoader\n",
    "            collate, or a sequence of (N, 2) tensors, numpy arrays or nested\n",
    "            lists. Samples with no seeds are skipped.\n",
    "        device: device of ``pred_probs``; seed indices are moved there.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: mean of the per-sample foreground BCE (target 1) and\n",
    "        background BCE (target 0) terms, each non-empty seed set counting once;\n",
    "        0.0 when there are no seeds at all.\n",
    "    \"\"\"\n",
    "    total_loss = 0.0\n",
    "    total_count = 0\n",
    "    for b in range(pred_probs.shape[0]):\n",
    "        p = pred_probs[b, 0]  # (H, W)\n",
    "        for seeds, is_fg in ((fg[b], True), (bg[b], False)):\n",
    "            # as_tensor accepts tensors, numpy arrays and lists alike, so callers\n",
    "            # can pass seed arrays directly (a bare `.to()` needs a tensor).\n",
    "            idx = torch.as_tensor(seeds, dtype=torch.long, device=device)\n",
    "            if idx.numel() == 0:\n",
    "                continue\n",
    "            idx = idx.reshape(-1, 2)\n",
    "            vals = p[idx[:, 0], idx[:, 1]]\n",
    "            target = torch.ones_like(vals) if is_fg else torch.zeros_like(vals)\n",
    "            total_loss += F.binary_cross_entropy(vals, target)\n",
    "            total_count += 1\n",
    "    if total_count == 0:\n",
    "        return torch.tensor(0.0, device=device)\n",
    "    return total_loss / total_count\n",
    "\n",
    "\n",
    "def tv_smoothness_loss(pred_probs, weight=1.0):\n",
    "    \"\"\"Anisotropic total-variation penalty on neighbouring-pixel differences.\n",
    "\n",
    "    Args:\n",
    "        pred_probs: (B, 1, H, W) tensor of probabilities.\n",
    "        weight: scalar multiplier applied to the penalty.\n",
    "\n",
    "    Returns:\n",
    "        weight * (mean |vertical diff| + mean |horizontal diff|).\n",
    "    \"\"\"\n",
    "    vertical = (pred_probs[:, :, 1:, :] - pred_probs[:, :, :-1, :]).abs()\n",
    "    horizontal = (pred_probs[:, :, :, 1:] - pred_probs[:, :, :, :-1]).abs()\n",
    "    penalty = vertical.mean() + horizontal.mean()\n",
    "    return weight * penalty\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "b76fc6b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UNet params: 1059841\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Seed both generators: numpy drives data generation, torch drives weight\n",
    "# initialisation and DataLoader shuffling.\n",
    "torch.manual_seed(0)\n",
    "rng = np.random.default_rng(0)\n",
    "train_ds = SeedAwareDataset(3000, rng)  # 3000 random frames per epoch\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "\n",
    "model = SeedAwareUNet2().to(device)\n",
    "# Smaller alternative: model = SeedAwareUNet().to(device)\n",
    "optimizer = optim.Adam(model.parameters())  # default lr=1e-3\n",
    "\n",
    "lambda_tv = 0.2  # weight of the total-variation smoothness term\n",
    "\n",
    "num_epochs = 20\n",
    "print(\"UNet params:\", sum(p.numel() for p in model.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "d58d5c56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: loss 0.1646\n",
      "Epoch 2: loss 0.0211\n",
      "Epoch 3: loss 0.0169\n",
      "Epoch 4: loss 0.0152\n",
      "Epoch 5: loss 0.0146\n",
      "Epoch 6: loss 0.0139\n",
      "Epoch 7: loss 0.0138\n",
      "Epoch 8: loss 0.0136\n",
      "Epoch 9: loss 0.0134\n",
      "Epoch 10: loss 0.0131\n",
      "Epoch 11: loss 0.0130\n",
      "Epoch 12: loss 0.0130\n",
      "Epoch 13: loss 0.0129\n",
      "Epoch 14: loss 0.0132\n",
      "Epoch 15: loss 0.0126\n",
      "Epoch 16: loss 0.0127\n",
      "Epoch 17: loss 0.0125\n",
      "Epoch 18: loss 0.0125\n",
      "Epoch 19: loss 0.0125\n",
      "Epoch 20: loss 0.0124\n"
     ]
    }
   ],
   "source": [
    "# Train on seed pixels only (BCE) plus a TV smoothness term; the printed\n",
    "# loss is the mean over this epoch's batches.\n",
    "for ep in range(num_epochs):\n",
    "    model.train()\n",
    "    running = 0.0\n",
    "    for x, fg, bg in train_loader:\n",
    "        x = x.to(device)           # (B,3,50,50)\n",
    "        # fg, bg: per-sample (row, col) seed coordinates, presumably batched by the\n",
    "        # default collate into (B, 50, 2) integer tensors; they stay on CPU and\n",
    "        # seed_bce_loss moves the indices to `device`.\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        probs = model(x)           # (B,1,50,50), already sigmoid-activated\n",
    "\n",
    "        loss_seed = seed_bce_loss(probs, fg, bg, device)\n",
    "        loss_tv = tv_smoothness_loss(probs, weight=lambda_tv)\n",
    "        loss = loss_seed + loss_tv  # (+ optional area_prior)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running += loss.item()\n",
    "\n",
    "    print(f\"Epoch {ep+1}: loss {running/len(train_loader):.4f}\")\n",
    "\n",
    "# torch.save(model.state_dict(), \"seed_aware_unet1.pt\")\n",
    "torch.save(model.state_dict(), \"seed_aware_unet2.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "2e71ce16",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "fps = 60\n",
    "T_sec = 180\n",
    "n_frames = fps * T_sec  # 10,800 frames = 3 minutes at 60 fps\n",
    "\n",
    "nn_seg_video = []  # predicted binary masks, one per frame\n",
    "nn_video = []      # (noisy frame, mask) pairs for the video writer\n",
    "with torch.no_grad():\n",
    "    for t in range(n_frames):\n",
    "\n",
    "        # Render frame t of the evaluation clip and sample fresh seeds on it\n",
    "        # (noise and seeds are drawn from the notebook-global `rng`).\n",
    "        img, fg_seeds_t, bg_seeds_t = transform_image_and_seeds(\n",
    "            t, n_frames, rng)\n",
    "\n",
    "        # ---- build 3-channel input: [intensity, fg seed map, bg seed map] ----\n",
    "        I = img.astype(np.float32) / 255.0\n",
    "        S_fg = np.zeros((H, W), dtype=np.float32)\n",
    "        S_bg = np.zeros((H, W), dtype=np.float32)\n",
    "\n",
    "        for r, c in fg_seeds_t:\n",
    "            S_fg[r, c] = 1.0\n",
    "        for r, c in bg_seeds_t:\n",
    "            S_bg[r, c] = 1.0\n",
    "\n",
    "        x = np.stack([I, S_fg, S_bg], axis=0)  # (3,H,W)\n",
    "\n",
    "        x_t = torch.from_numpy(x).unsqueeze(0).to(device)  # (1,3,H,W)\n",
    "\n",
    "        # ---- forward pass ----\n",
    "        # forward() already ends in torch.sigmoid, so the output IS a probability\n",
    "        # map. A second sigmoid would squash every value into [0.5, 0.73] and\n",
    "        # turn the `> 0.5` test below into roughly `prob > 0`.\n",
    "        probs = model(x_t)\n",
    "\n",
    "        seg = (probs[0, 0].cpu().numpy() > 0.5).astype(np.float32)\n",
    "\n",
    "        nn_seg_video.append(seg)\n",
    "        nn_video.append((img, seg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "1bc62cad",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (50, 50) to (64, 64) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
     ]
    }
   ],
   "source": [
    "import imageio.v2 as imageio\n",
    "def img_seg_to_uint8(img, seg):\n",
    "    \"\"\"Black out pixels outside a 0/1 segmentation and return a uint8 frame.\n",
    "\n",
    "    The masked image is clipped to [0, 255] before the cast, so out-of-range\n",
    "    noise values saturate instead of wrapping around.\n",
    "    \"\"\"\n",
    "    return np.clip(img * seg, 0, 255).astype(np.uint8)\n",
    "\n",
    "# NN video: each noisy input frame with pixels outside the predicted mask set to 0.\n",
    "fps = 60  \n",
    "# ffmpeg resizes the 50x50 frames to 64x64 (macro_block_size=16; see the warning).\n",
    "with imageio.get_writer(\"unet_segmentation_seedaware2.mp4\", fps=fps) as writer:\n",
    "    for img, seg in nn_video:\n",
    "        frame = img_seg_to_uint8(img, seg)\n",
    "        writer.append_data(frame)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "7d1029ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean IoU over video: NN = 0.8474\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_iou(mask_pred, mask_gt):\n",
    "    \"\"\"Intersection-over-union of two same-shape binary masks.\n",
    "\n",
    "    Any non-zero value counts as foreground. Returns 1.0 when both masks are\n",
    "    empty, since there is nothing to disagree on.\n",
    "    \"\"\"\n",
    "    pred = mask_pred.astype(bool)\n",
    "    gt = mask_gt.astype(bool)\n",
    "\n",
    "    union = (pred | gt).sum()\n",
    "    if union == 0:\n",
    "        return 1.0  # both empty -> perfect match\n",
    "    intersection = (pred & gt).sum()\n",
    "    return intersection / union\n",
    "\n",
    "fps = 60\n",
    "T_sec = 180\n",
    "n_frames = fps * T_sec\n",
    "\n",
    "gt_video = []       # ground truth masks (0/1)\n",
    "gt_video_orig = []  # noisy images for the GT video (noise drawn separately from the NN inputs)\n",
    "\n",
    "# Seeded so the GT video is reproducible. This rebinds the notebook-global `rng`;\n",
    "# the GT masks themselves are noise-free and do not depend on it.\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "# center of rotation\n",
    "cy0, cx0 = cen_row, cen_col\n",
    "\n",
    "yy, xx = np.indices((H, W))\n",
    "\n",
    "# Same ellipses as transform_image_and_seeds; independent of t, so build once.\n",
    "ellipses = [\n",
    "    (top_row, top_col, top_r_y, top_r_x, 60),\n",
    "    (bot_row, bot_col, bot_r_y, bot_r_x, 60),\n",
    "    (cen_row, cen_col, cen_r_y, cen_r_x, 90),\n",
    "]\n",
    "\n",
    "for t in range(n_frames):\n",
    "    # Deterministic motion, identical to transform_image_and_seeds(t, n_frames, ...)\n",
    "    tau = t / n_frames\n",
    "    angle_deg = 720 * tau\n",
    "    shift_row = 5 * np.sin(2 * np.pi * tau)\n",
    "    shift_col = 5 * np.cos(2 * np.pi * 1.5 * tau)\n",
    "\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    cos_t, sin_t = np.cos(theta), np.sin(theta)\n",
    "\n",
    "    # full dumbbell image (no noise)\n",
    "    img = np.ones((H, W)) * 120\n",
    "\n",
    "    for cy, cx, ry, rx, val in ellipses:\n",
    "        dy = cy - cy0\n",
    "        dx = cx - cx0\n",
    "        cy_r = cy0 + cos_t * dy - sin_t * dx + shift_row\n",
    "        cx_r = cx0 + sin_t * dy + cos_t * dx + shift_col\n",
    "        mask = ((yy - cy_r) ** 2) / (ry ** 2) + ((xx - cx_r) ** 2) / (rx ** 2) <= 1\n",
    "        img[mask] = val\n",
    "\n",
    "    # GT mask: rod pixels (exactly 60); the bridge (90) counts as background\n",
    "    mask_gt = (img == 60).astype(np.float32)\n",
    "\n",
    "    # noisy image for the GT visualisation video\n",
    "    img_noisy = np.clip(img + rng.normal(0, 10, (H, W)), 0, 255)\n",
    "\n",
    "    gt_video.append(mask_gt)\n",
    "    gt_video_orig.append(img_noisy)\n",
    "\n",
    "# Per-frame IoU of the NN masks against GT (aligned by frame index t)\n",
    "iou_nn = [compute_iou(pred, gt) for pred, gt in zip(nn_seg_video, gt_video)]\n",
    "mean_iou_nn = np.mean(iou_nn)\n",
    "\n",
    "# # Compute IoU per frame for your algorithm\n",
    "# iou_algo = [compute_iou(pred, gt) for pred, gt in zip(algo_seg_video, gt_video)]\n",
    "# mean_iou_algo = np.mean(iou_algo)\n",
    "\n",
    "print(f\"Mean IoU over video: NN = {mean_iou_nn:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "d800fad5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (50, 50) to (64, 64) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
     ]
    }
   ],
   "source": [
    "def img_mask_to_uint8(img, mask):\n",
    "    \"\"\"Apply a 0/1 mask to a frame and return it as uint8 for video writing.\n",
    "\n",
    "    Values are clipped to [0, 255] before the cast so they saturate rather\n",
    "    than wrap around.\n",
    "    \"\"\"\n",
    "    frame = (img * mask).clip(0, 255)\n",
    "    return frame.astype(np.uint8)\n",
    "\n",
    "# GT video: each GT noisy frame with pixels outside the exact rod mask set to 0.\n",
    "with imageio.get_writer(\"unet_seed_gt.mp4\", fps=fps) as writer:\n",
    "    for img, mask in zip(gt_video_orig, gt_video):  \n",
    "        frame = img_mask_to_uint8(img, mask)\n",
    "        writer.append_data(frame)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "905856b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video Precision: 0.8537\n",
      "Video Recall:    0.9914\n",
      "Video F1 score:  0.9173\n"
     ]
    }
   ],
   "source": [
    "def compute_prf(mask_pred, mask_gt, eps=1e-8):\n",
    "    \"\"\"Precision, recall and F1 of a binary prediction against binary ground truth.\n",
    "\n",
    "    Args:\n",
    "        mask_pred, mask_gt: same-shape arrays; any non-zero value is foreground.\n",
    "        eps: added to every denominator so empty masks yield 0 instead of a\n",
    "            division by zero.\n",
    "\n",
    "    Returns:\n",
    "        (precision, recall, f1) as numpy floats.\n",
    "    \"\"\"\n",
    "    pred = mask_pred.astype(bool)\n",
    "    gt = mask_gt.astype(bool)\n",
    "\n",
    "    tp = (pred & gt).sum()\n",
    "    fp = (pred & ~gt).sum()\n",
    "    fn = (~pred & gt).sum()\n",
    "\n",
    "    precision = tp / (tp + fp + eps)\n",
    "    recall = tp / (tp + fn + eps)\n",
    "    f1 = 2 * precision * recall / (precision + recall + eps)\n",
    "    return precision, recall, f1\n",
    "\n",
    "# Lists to store per-frame metrics\n",
    "precisions, recalls, f1s = [], [], []\n",
    "\n",
    "# Loop over frames (NN predictions and GT masks are aligned by frame index)\n",
    "for pred, gt in zip(nn_seg_video, gt_video):\n",
    "    p, r, f = compute_prf(pred, gt)\n",
    "    precisions.append(p)\n",
    "    recalls.append(r)\n",
    "    f1s.append(f)\n",
    "\n",
    "# Average over the video: unweighted mean of per-frame scores, not precision/recall\n",
    "# of all pixels pooled across frames.\n",
    "mean_precision = np.mean(precisions)\n",
    "mean_recall = np.mean(recalls)\n",
    "mean_f1 = np.mean(f1s)\n",
    "\n",
    "print(f\"Video Precision: {mean_precision:.4f}\")\n",
    "print(f\"Video Recall:    {mean_recall:.4f}\")\n",
    "print(f\"Video F1 score:  {mean_f1:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "c67e964d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABCIAAAFtCAYAAADS58tNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA0cElEQVR4nO3debhedXUo/vXNTBIykZCQAUIGEgEZ4hWBVq/Tba/e52prr9XSW6VXq63XAatob7WttyiWghXrUKdbrf09rZVbf/ys1ae2FVCUQZkhBAwhMyFknkf274/3RI+HvXbOCSf7nCSfz/O8D2Std+293zfn/Z59VvbZq1RVFQAAAABtGDLQBwAAAACcODQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhHQgtJxXynljT3iZ5VSPlRKGXGU9juslPIHpZSflFL2llJWl1I+XvO8s0sp/15K2VVKWVtK+dNSytA+7OfTpZT/079HD/DsDMTaW0qZV0r5XNd+D5ZSbk6e99pSyjdKKWtKKTtKKXeVUn4jea41GjjmDeY1ueu5ry+l3N21Jq8ppXyllDK95nnW5H6gEZEopfxmKeX7pZSbSinfK6XcW0r521LKSwb62I6GUsoFpZQP9dO2PllKWV5KWd7wnBmllJtLKVu6Hjd3Pe4upSwupfzP/jiWZN+v7jq+kUdrHzV+PSImRsTf9Yi/IiLeVlXVvqO03y9FxDsj4rqI+KWI+IOI2N39CaWUiRHxbxFRRcSrI+JPI+I9EfG/+7CfayPiN0sp8/rhmDmBWXuf1basvc80EGvvORHxyoh4tOuR+f2I2BER746IV0XETRHxd6WUd3R/kjWagWRNflbbsiY/06Bdk0spr4qIv4+IH0ZnrX1/RLwoIr5ZShnS7XnW5P5SVZVHj0dEvC4i1kXE6d1ip0bE/RFx/UAf31F6zZd3vhz6bXsfiojlvXjezRFxc4/Ya6Lz4f7vR+m1vjAi/jUihrb4/v4gIj5SE/8/EfHdo7TP/xwR+yPi7MM8739FxOaIGNct9r6I2NU91ov9/VtEfKyt99Tj+HtYe/tle9ben9/nQKy9Q7r9///t+T53y02uif1dRDzeI2aN9hiQhzW5X7ZnTf75fQ7mNfmrEXFXj9iruv4OntMtZk3up4crIur9ekR8v6qqlYcCVVWtj4iPRsSmATuqE0RVVV+PiC3R6TIeje1/v6qq/1RV1cGjsf2eurqdl0Zn8esefyoi/kdEvKSUUnU9XtqPu/4f0VnUFx/mea+IiH+pqmpbt9hXI+KkiPiPPY75P5dSbu3q2m8rpfy4lPKfutL/GJ3urnWFI2XtHUDW3v5RVdXTvXzehprwPdH5Qa+7Xq3Rh1mfI6zR9J01eQBZk/tHb9fkiBgeEVt7xLZ0/bd0i1mT+8kJ+8IPY19EvKCUMql7sKqqv6+q6k8P/bl0vK/rMrXvlVJ+WEp5V4/Ld0aWzu8BrSyl3FJK+ZtSyidKKXu6Lr16Vynl9q4P3eWllL8rnd8T/XEp5dxSyqWllK+VUpaUUv6xlHJyb/dfSnlVzbbvKKU8WEp5YbftvD06l+xHt0vCLj/C13hzKeWvovNhfDaGRcTPLRy9eL2He6/P7PGevLhm2w905R8opby3lFL68l4mXhYROyPivu77i4j/GhF7IuIjEXFJ1+OHNa952OEeyX5fEBGPllI+VToL4K5SytfLM3/XbWFELOke6Drp2NWVO3QsL46IG6PTPX9NRFwWnW7uga6n/DAipkbEcw/zfkDG2mvtPR7W3mfj0ojo2Tw+7BpdDr8+R1ij6TtrsjX5RFqT/zoiXlhKeUMpZVwp5ayI+HBE3NTjH/Wsyf1loC/JGIyPiHhxdBbfTRHxlxHx8ogYUfO8qyNiWURM6frz5Ih4PCLe3+05H4uIlRExrevPz4lOt215t+fMjs5lP988tJ+I
uCEi7o6I3+/68+iIWBURf9jH/R/a9rcjYmRX7NMRsbTHa7k8ai5FexavcX0c+aVo746IvRHx4r4cS2/e6x7vyYt7bPuJiJjV9efTu/78p319L2te4+cj4kc18UPbu7ih9vKu5zQ+ktq9EbE9Im6Nzu/GvS4iVkTEHRFRuj1vf0RcUVO/OiKu7vbnL0TEPzYc67DoLLC/M9CfYY9j8xHW3v54jdben9UNyNrbYzvpZcA1z31ZdH7guLxH/LBrdBxmfe56jjXao0+PsCb3x2u0Jv+sbtCvyRHxm9Fpihza5g8iYkKP51iT++kx4AcwWB8R8bzoLH6Hvhg3dX2wx3Tlx0bnpn/v61F3TUSs7/r/0V31V/V4zt9F/cL7hm6x/9kV6/57ef83Iv7f3u6/x7bf2C32K12x8d1il/f8APfna2x4n2+OzmVPN3d92LdFxPciYkFfjqUvxxE9Ft5u2766x/M+Gp3u5pgedY3vZc1r/EZEfLsm/qronHSObag9JSL+w+EeSe2+6NwI7ZRusRd1He/LusX2R8S7aurXRLff44vOSciOiLgiur651dRsiIg/OtqfT4/j9xHW3n59jQ3v881h7T0qa2+P7fSqEdH1Gp889HXWI3fYNTp6sT53Pc8a7dGnR1iT+/U1NrzPN4c1eUDX5Ih4SXT+Ae+a6DThXhcRD0fnRsJDuz3PmtxPj6NxWeFxoaqquyLitaWUsRHxy9FZmH4/OpfP/FJEnB0RoyLijaWUV3YrHRcRO7suGTszIkZGxNIem388Opdf9rSm2//vrIntiIgzuv7/sPuvqmp7t/jqbv9/6PefJsQzfxequ6PxGuvcW1XViyMiSikvis6lS78SnYWgV8cSEfOfxXEc2nbPO+k+Ep1L6s6OiB91i/f1vRwVnQW8p/MiYllVVTsajm1Tw3YPZ3PX9jd2i90anQbF2RHx792eN6Gmfnz87HfjIiL+KDq/zvXHEfGxUsq/RedfKB7q9py90Xm9cESsvb3bR1h7B/Pa2yelc9n7t6PzL5j/veYpvVmje7M+R1ij6SNrcu/2Edbk42FN/lhEfKOqqvcfCpRS7o3Or2G8OiK+3hW2JvcTjYgapZTJEbGjqqo9XR+Kf4yIfyylfDoi3lZKmdDt6R+rquqvk+2UuniDZ9wspnrmDWR6bjPdf8O2q2Rbmf58jY2qqvpeKeXzEfHBUsoXe/wQnR5LKeX8/th9z80mz+vre7kpIqbVxM+Lbr8nl3hjdEZwHk7d/h+Ozjejuud2/33DJdHtXhAREaWUWRExJrr9DlxVVVsj4u2llHdF5xLiz0bnLscXdyudEG5gxRGy9j6DtffnHStrb6+VUkZH12XoEfFfqqraWfO0w67RvVyfI6zR9IE1+RmsyT/veFuTF0ZnfOdPVVX1SClld0TM7Ra2JvcTN6usd11E/GpN/JHofNCejs7NpPZEpzv4U6WU2aVzc5qIiJ9Ep9PVc0bsmf1wjL3Zf1/89AfTUsqQru7uQL3GP4vOSdkV3WKHO5ZncxyHtr2wR/ys6FyidripE4fzSHIcz7jZTY1/iojn9+JR55sRcV7XicQhL4rOXYG7L/jfjohfLt1u/BSdy9F2R8QtPTdaVdXBqqq+ExH/EhFDD8VLKVOic0lgOqMZDsPaa+2NOPbX3l7purHaDdH5F8xXVJ1pBHV6vUZn63PX/qzR9JU12ZoccYKsydG5j9qi7oFSynOiczXI8m5ha3I/0YjIvbv7D3BdXyyXR8Q3q6ra1tUZvi4iLi+lLOh6zvDo/B7VmoiIqqp2RcSnonP51NSu5yyMTmfsWenN/vtoXdc2JkXERRHx78/yNb7yGXvo/WtbHZ2u5ztKKeO7Yo3H8mze6x7bPr2rdlZ0/r4/lvwLVV/8ICJO7/oa6m5bRLyolPKiUsrFdd30qqo2VlX148M9kv1+PiI2RsQ/lVL+aynlsoj424j4t6qqbu32vM9G55vW10spLy+lvCU6c6//ouoaTVRK
+Vzp3IH5v5VSXlJK+eOIeFN0bk50yH+IzonJz93pGPrI2mvtvTyO4bW3lDK6a638bxExIyKmHPpz1xUQh3wmOn9fV0XEpK5jOfTofjVb4xrdy/U5whrNkbEmW5MvjxNjTf5sRLyulPKxrrX2N6Mz+WJ5RHyrx/Osyf2h7sYRJ/ojIl4Ynctn7o7OTWO+1/X//zu63UglOpf/vCc6HcLbovMB+1/x8xMJRkbnC29VdG528lcRcW1E/KQr/9KIuD06X4j3Rqfz/HvR6QxWXfs/MyI+F53FcUtEfKc3+0+2/aqu/6+6cv+x67nDIuL/68r9OCJeeYSv8ZaI+HJEXB+drurNEXFmzXs8I352Y54tXf//i93ys6NzL4MHIuKjvXy9je9113Ne1eM9eUu3bb+3a393RMSDEXHlkbyXNa91RHQaAr/VI/6LEfFQdG56s+4ofS3Pi87iuTM6v9P25YiYWPO8syPiu9Hp5j4RnRPj7jfm+f2IuLPr72pb1/v/6h7b+ER0RhwN+GfY49h8hLXX2nscrL3xsxu51T1md3ve8t48r+u56RodvVifu55njfbo0yOsydbkE2tNLtH5mrs/OufNayLiHyJiTs02rcn98Dj0RcVRUjq/P7e7qqq93WJfiM7df395wA7sODSY3+tSyiciYl5VVf9lII/jaCmlDI3OJW1/UFXV/zPQxwODeT043gzm9/p4X3t7yxrNQBvM68TxZjC/19bkDmtyh1/NOPquiIg/PPSHUsrsiHhNRHxxgI7neHZFDN73+tqIeHEp5ayBPpCj5LXR6Qp/daAPBLpcEYN3PTjeXBGD970+3tfe3rJGM9CuiMG7ThxvrojB+15bkzusyRGuiDjaSikvjs7olpHRubvsyIj4bFVVXxrAwzouDfb3upTy+oh4oqqqZ9wA8lhXSvmN6Pxu4vcG+lggYvCvB8eTwf5eH89rb29Zoxlog32dOJ4M9vfammxNPkQjAgAAAGiNX80AAAAAWqMRAQAAALRmWFOylOL3NgB6oaqqZ8y97m/WZIDeOdprsvUYoHey9dgVEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGjNsIE+AACA49mdd96Z5s4888za+K233prWXHLJJWlu2rRpvT8wABggrogAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1paqqPFlKngTgp6qqKkd7H9ZkGLweeOCBNHfdddeluW3bttXGFy5cmNZs2LAhzS1YsKA2vnTp0rRm4sSJae6jH/1omhvMjvaabD0eXO67777a+A033JDWjBw5Ms0tX768Nt70WTl48GBtvOnzOmHChDQ3d+7c2vgVV1yR1sBglK3HrogAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAa4zvBOgHxnfC8ePVr351mjv77LNr4w8//HBac95556W5bHzn8OHD05qmcYBjx46tje/bty+tGTVqVJrbtGlTbXzIkPzfsp773OfWxt/73vemNf3N+M7BbeXKlbXx97///WnNpEmT0tzevXtr4y94wQvSmgcffDDNrVixoja+aNGitOaJJ56ojTeN6Fy3bl2au+iii2rjmzdvTmu++93vprnLLrusNv4v//Ivac0//MM/pDnoLeM7AQAAgAGnEQEA
AAC0RiMCAAAAaI1GBAAAANAajQgAAACgNRoRAAAAQGuM7wToB8Z3wuD0zne+szZ+8ODBtKZp3OayZctq49mYy4iIYcOGpbnVq1fXxhcuXJjWTJ06Nc1lo/127tyZ1owZMybN7dq1qzY+efLktCYbL7pnz560pin3pS99Kc1ljO8ceMuXL09z1113XW189+7daU3TGMzs85yN9Yxo/lzOmTOnNr506dK0JvusNI3H3b9/f5obP358bTxbgyIi5s2bl+ayn/keeuihtOYXf/EX09wrXvGK2njTiFNOTMZ3AgAAAANOIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAa0zNAOgHpmbAwPnQhz6U5p588sna+AMPPJDWvOQlL0lz2XlT0xSOLVu2pLlsSkBTzRlnnJHmRo8eXRtvOr6myQJDhtT/m9Vjjz2W1lxwwQW18abpAQcOHEhzK1asqI1/5jOfSWtGjBhhakYLmn6OeOtb35rm7r///tr4RRddlNacfPLJaS6bJLN27dq0Jvu6isgn56xZsyatyT5HQ4cOTWua3r+RI0fWxps+R01TR7LpOOPGjUtrNm7cmOaGDx9eG3/Tm96U1syfPz/NcfwyNQMAAAAYcBoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFqjEQEAAAC0xvhOgH5gfCcMnPe85z1p7umnn66NN42LXLduXZo76aSTauPZaLyIiDlz5qS5DRs21MbXr1+f1kyfPj3NjRgxojbeNA606b0YNmxYbXzTpk1pTTZKMRulGhExYcKENPf444/Xxp/73OemNR/+8IeN72zBNddck+aaRryee+65tfFly5alNUfyNfzCF74wrbn11lvTXDYiM/vajojYv39/bXznzp1pzfjx49Pc4sWLa+Pz5s1La5rWjYkTJ9bGN2/enNZs3bo1zWWf2cmTJ6c1v/3bv10bP+ecc9Iajn3GdwIAAAADTiMCAAAAaI1GBAAAANAajQgAAACgNRoRAAAAQGs0IgAAAIDW1M9kAgA4Rmzbti3NZeMxTz311LTmrLPOSnM7duyojWdjLiPyMXwR+TjAKVOmpDV33HFHmlu4cGFtvGn04ZAh+b9LZeNPm15vKfWTM6dNm5bWZCM6I/IxnXv27ElraMejjz6a5po+YzfddFNtfPbs2WnNaaedlua2b99eG286vqYxmNkI36bxnatWraqNDx06NK1pyl1yySW18aaxqE1rVzZyt+k1Na2t2bHv3r07rRk5cmSa48TjiggAAACgNRoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFpjagYAcEx7wQtekOaWLFlSGx8/fnxas2vXrjSX3Xl+zJgxac24cePSXKZpisRznvOcNJdNwNi3b19a87KXvSzN3X777bXxpokVe/furY1n0zQiIqqqSnMTJkzo07HR/173utfVxpu+tpsmtWTTMZo+e01fw2PHju3z9rKJMBH5RIh77703rRkxYkRtvGnax5133pnm5s+fXxsfPnx4WjN58uQ0t3nz5j5v7/zzz09zW7durY0fPHgwrXnXu95VG296Xy+44II0x7HNFREAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojfGdAMAx7c1vfnOa+/KXv1wb3717d1rTNHZw+/btvT2sn5o4cWKaGzVqVG18wYIFaU3TsWej+LL9RETcc889aW7SpEm18QMHDqQ12VjE/fv3pzVNYyCz0X7ZyEb63/Tp02vjTeMiszGuERHLly+vjU+bNi2tyUbTRuTjNpu+rp588sk0N2PGjNp4Nr43Ih/fmR1bRPN4zJUrV9bGmz7LixcvTnPZ6236HI0cOTLNZcexdu3atOYXfuEXauNHMuKYY58rIgAAAIDWaEQAAAAArdGIAAAAAFqjEQEAAAC0RiMCAAAAaI1GBAAAANAa4zsBgOPW5ZdfXhv/7ne/m9bcfvvtae7d7353bTwbExoR
sXHjxjQ3fvz42njT2Lxt27alualTp9bGd+3aldbs2bMnzWUjGLMRnRH5qMdTTjklrXnwwQfT3MyZM2vjTe8D7Wga1dg0bvPkk0+ujTeNpt2xY0eau+CCC2rjS5YsSWvmzJmT5lavXl0bb/ocZaNzV6xYkdZk4ywjIh544IHaePYZj2h+/0477bTa+NKlS9OapvGsc+fOrY03fS7XrVtXG2/6u+D45YoIAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaU6qqypOl5MlB4I//+I9r49kdWQ9nwoQJtfEnn3zyiLaX3dV206ZNac3EiRPT3ObNm2vjX/ziF/t2YEC/q6qqHO19DPY1GU4EX/jCF2rj06ZNS2v++Z//Oc1l39vnzZuX1jTduT+bZtG0vaaJFWPHjq2Nb9myJa3J7qbfND1k586daa6U+uW1adrHX//1Xx/VNdl63PGBD3wgzTVNrBg1alRtfMSIEWnNrFmz0lw2saLpnPtIJkzce++9aU02FaZp8kTTRI1JkybVxg8cOJDWTJkyJc099dRTtfGTTjoprWn6GSib+NP0s8xHP/rRNMfxKztHdkUEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWjNsoA/gcH7jN34jzWVjaPbv35/WZON9mra3cuXKtGbRokVpbsiQ+j5PNrIoImLYsPyvJBuvc+WVV6Y1TaOJPvWpT6U5APrfAw880OeabBRiRMS73vWu2njTWMPp06enuWuuuab3B3aCGjp0aJ9rtm/fnubGjRtXG28amzdjxow0N3r06Nr4HXfckdY0vaZshGB23BERBw8erI1nYzgjmkcI7tixoza+YMGCtIZ2rFu3Ls09//nPT3NPPPFEbbxpjGvTOW12HNnnIaJ5VOi+fftq401jerNz+Kav7abtZZ+jpvW96Weg8847rzZ+6623pjVNa0P280z28xT05IoIAAAAoDUaEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtKZUVZUnS8mTR+A973lPmtu6dWttfPz48WnNqlWrauNjxozp24F1Wb9+fW38oosuSmuaxgxlI4Oy445oHkX1+OOP18Y3b96c1jSN+Mre84kTJ6Y11157bZqDE1lVVflcun7S32syz87dd99dG7/++uvTmvnz5/d5P01jHLPvQU3jIs8999w0t2XLltp402ui4ytf+Uqamzp1apr7/ve/Xxvfu3dvWrN8+fI0l517TJ48Oa1pGsE4YcKE2njTuPFHH320Nr5w4cK0ZtmyZWkuOxf89Kc/ndaMHTv2qK7J1uOObAxnRPM44GykZXYuHnFkI2ObzoOb1slsbG3TOMtsHGjT+M4NGzakuaeffro2vmvXrrTmlltuSXOzZs2qjTe9pqafMbLP85/92Z+lNZyYsnNkV0QAAAAArdGIAAAAAFqjEQEAAAC0RiMCAAAAaI1GBAAAANCaxqkZO3bsSJPZXXL/8A//MN1edrfWiPyOyNndmiOa7ySeWbJkSZobNWpUbXzbtm1pzZw5c9Lc/v37a+NNUz3OOOOMNPfNb36zz9trOr5sakbTMWTv+ac+9am0Bk4EpmYcn6666qo09+CDD9bGm77XNd0h/cILL6yN33nnnWnNnj17auNN3xea7rieTfV40YtelNa84hWvSHN0fOITn0hza9asqY0/9dRTac0pp5yS5kqpX4qaJnY1fb1kkwqaJndkmqaHNHnLW95SG1+0aFFac7TXZOvx4WXT3iLyc9qHHnoorWn6mSCbmnHeeeelNU327dtXGz/11FPTmuxno5NPPjmtaXpN2Xl60/eYG2+8Mc0dOHCgNp5NS2qqiYj427/92zQH3ZmaAQAAAAw4jQgAAACgNRoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFozrCm5fv36NHfDDTfUxptGP65bty7N7dy5sza+d+/etCYbrdM0ouqcc85J
cyNHjqyNr127Nq2ZOHFimvvRj35UG1+4cGFak43diog466yzauNN7+uIESPSXKbp9c6ePbvP2wMYzF7/+tenuabRmdOmTauNZ+MYIyKGDx+e5rLx0pMmTUprxo4dWxvfvHnzER1DNqrxtttuS2u+9rWvpblf//VfT3MnkiFD8n/3afoensnGBEZETJ8+vTY+dOjQtKZprGw2XrDp6yjTdE73pS99Kc199rOf7fO+GHhnnnlmn2uaRkleffXVaS5bu+bOnZvWNI2gfdOb3pTmgP7higgAAACgNRoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFpTqqpKk+94xzvSZDblommqQtOdxLM7KTfdYTnb10knnZTWNN2NN3tNTdtrutP0jh07auMXX3xxWvPggw+muezO6bt3705rmv5+s9eb3SE7ImLWrFm18Q984ANpDZwIqqrKR970k1JK/oHmsH73d3+3Nt40rWj06NFp7tRTT62Nr169Oq1ZsWJFmpsxY0ZtPPteEhGxaNGi2njTlIusJiI/vpe+9KVpzWtf+9o0x+FdccUVtfGpU6emNY8//niay84V7rvvvrTmkksuSXP79++vjV911VVpzWBwtNdk6zFA72TrsSsiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0JphR1o4ZcqU2vjKlSvTmtNOOy3Nbd26tTY+bdq0tGbdunW18YkTJ6Y169evT3O7du2qjY8aNSqtacpNnz69Nt408rPp2IcMqe8bbdy4Ma1pGv81b9682ng2Qi4i4vd+7/fSHMBAu/LKK9Pcnj17+hSPiDh48GCay9bXxx57LK0ZOnRomsu+B40bNy6tycY4XnjhhWnN4sWL09yTTz5ZG9+3b19aw7Nz/fXXD/QhxLe//e2BPgQ4JvzN3/xNmnvjG9/Y4pHAsc8VEQAAAEBrNCIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGhN4/jOppFm2QjK1atXpzUnnXRSmhs7dmxtfPfu3WlNZu3atWmuaTzmiBEjauM7d+5Ma5reo+HDh9fG9+7dm9Y0mTt3bm1827Ztac3znve8NJf9HRrRCRyrmkYqb968uTY+evTotGbkyJFp7uGHH+79gXU5++yz09wdd9xRG3/uc5+b1mRjsZtGaTeNEL3gggtq448++mhac+ONN6a5bOz0q171qrQGYLAyohP6jysiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABoTePUjClTpqS5jRs31sYvvfTStKbpLt7Znb+zyQ4REdOnT+9zzbJly9LcsGH1b8cpp5zS55qm3Jo1a9Kap59+Os0dPHiwNn7ZZZelNU13dc/e869//etpzWte85o0BzDQtm7dmuYmTZpUG8/W1oh8nYzIJzRl3x8jIvbt25fmsslIjz32WFqTTQKZNWtWWrNq1ao0t2XLltr4ueeem9b867/+a5p7/vOfXxv/2te+ltZMmDAhzf3SL/1SmgMAjh2uiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrSlVVafLgwYNp8oMf/GBtPBv9FRExZsyYNHfyySfXxptGsWUOHDiQ5lavXp3mxo4dWxsfP358WtM0im3kyJG18b1796Y1M2bMSHPZ39WcOXPSmje84Q1pDug/VVWVo72PUkq+YBMREVdddVWa27lzZ21806ZNac369evT3IUXXlgb/9a3vpXWnH322Wlu6NChtfEhQ/J/M8jGSy9fvjytaRoTPXr06Nr4uHHj0pqm9y/7npa91oj8fCAi4pxzzqmNb9u2La3J9vX6178+reHYd7TXZOsxQO9k67ErIgAAAIDWaEQAAAAArdGIAAAAAFqjEQEAAAC0RiMCAAAAaI1GBAAAANCaxvGdK1euTJOf/exna+NTp05Nt7d06dI0l43pzEZqRuQjvppGiY0aNSrN
ZXbv3p3mtm/fnuaysWVNI06bRo/+2q/9Wm38V37lV9IaOt73vveluZUrV6a5008/vTa+a9eutGbdunW18QkTJqQ1n//859PckiVLauNNYwBpn/Gdg8Pll1+e5rIR0pMnT05r7r///jSXfaYvvfTStOaWW25Jc9l40aZxltn33Kbvg6eeemqay/a1Zs2atKZpJHX23u7fvz+t2bNnT5/31fR9Oluv582bl9Zce+21aY5jg/GdAIOD8Z0AAADAgNOIAAAAAFqjEQEAAAC0RiMCAAAAaI1GBAAAANCaxqkZt912W5rM7tTddKfp4cOHp7nsTtjZXc4jIu65557a+Pnnn5/WPPnkk2lu9uzZtfFly5alNU2TEMaPH18bb7q7dzZpIyJi5syZtfHf+q3fSmuOR5/85CfT3KRJk2rj3//+99OaiRMn9jm3atWqtCb7e2+68/2jjz6a5hYsWFAbv/vuu9Oaj3zkI2lu7ty5aY4jZ2pGe+666640d/XVV6e5bApTNrUpImLWrFlpLvt+cuaZZ6Y1GzZs6PO+HnnkkbTmrLPOqo2fdNJJaU3TlIvs+93FF1+c1jRNFpk+fXptfMWKFX2uicind2QTRyIiSqn/aP793/99WsOxz9QMgMHB1AwAAABgwGlEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQmsbxnT/5yU/SZDYybPv27en2rrzyyjSXjUEbOXJkWpON9jz11FPTmqYxY9nox6ZRZ6ecckqa27dvX2183Lhxac0555yT5t785jenuePNDTfckOYef/zxNJf93TeNdlu+fHmay/4Ot23bltYMGVLf38s+M4c7hj179vTp2A7nHe94R228aTwfh2d85+CwdOnSNHfNNdfUxrNx1BHNn/Xse8bTTz+d1jR9brORm00jn7P18EjGjkbkr+nAgQNpTTYWNSJfv9avX5/WjBo1Ks1l5xjnnntuWpONJP3KV76S1nDsM76TY8Ff/uVfprmDBw/WxkePHp3WvPWtb33WxwT9zfhOAAAAYMBpRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0JrG8Z1tjiZ67LHHauMf/vCH05psTOKECRPSmjVr1qS5bIRbNs7scNvLjq9pHOiiRYvSXDZ28ViWjdObMmVKWjN8+PA098QTT9TGH3300bSmaWzlXXfdVRtvGtu6cePG2vjQoUPTmq1bt6a58ePH18abRtuedtppaS4bBbhp06a05uMf/3iao8P4zsFv7dq1tfHPfe5zac2OHTvSXPa5bVq/mkYJZ9+PJ0+enNZkx7dixYq0ZsGCBWlu9erVtfHp06enNU2j5BYvXlwbbxpV3fR9deLEibXxbN2NiPjQhz5UG28a28qxz/jO48/v/M7vpLmZM2fWxpvGH1966aVpLluHzjjjjLQmOzfMzk0jmkcZZyOdf/jDH6Y18+bNS3M//vGPa+NN32M++clPpjnoLeM7AQAAgAGnEQEAAAC0RiMCAAAAaI1GBAAAANAajQgAAACgNYNmakbm9ttvT3Ol1N8Q+aKLLkprXvOa16S5GTNm1Mab7gg+atSoNJfdoX3btm1pTTZFIiLizDPPTHPHquxu9U2TSrJJFhERTz/9dG286Y7FTRMmsgkn2Z2MIyIeeeSR2njTXeKbJqlkd0dumvaxatWqNJdN22iaDvDOd74zzb3kJS9JcycSUzOOT5dffnmayybaNE2saPq8ZNN99u3bl9Zkdzs//fTT05rNmzenuXvuuac2ftZZZ6U1TVOJsmkbTdOPsu/tERF//ud/XhtftmxZWjN37tw0x/HL1IyB96u/+qtpLjtfa5o61HS+tn379j7tJ6J5nczW8Wzdj4jYsmVLbXz+/PlpTdNEouz8b8+ePWnNsGHD0tyBAwdq43PmzElrsslHERHjxo2rjX/iE59IazgxmZoBAAAADDiNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt
0YgAAAAAWpPPeBkkmkYUtmX9+vVpbsOGDWkuG0/WNEryeBzR2WT37t218TvvvDOtaRq3lI2tzPYTETFx4sQ0l41vahqPlI17PXjwYFrTdHyXXXZZbfy+++5La5rG1WXHkY2dioi48cYb0xwcz7785S+nuWx9+PjHP57W3HzzzWkuG+2cjZaOiNi4cWNtvGmE76xZs9JcNupu69ataU3T8a1evbo2/kd/9EdpzZAh+b+RNI32BNrX9FluGk2+a9eu2njTaN9Ro0aluWwEedMxPPTQQ2kuO44f/OAHaU02QnTnzp1pTbbuR+Svt+mcMRupGZGv7/fee29ac8EFF6S5u+++uzb+mc98Jq1529veluY48bgiAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrNCIAAACA1gz6qRmDwamnnjrQh3Dcmj17dm185MiRaU3THYazOwk3TSO566670ty0adNq45s2bepzzcqVK9Oapq+xe+65pzY+ZsyYtGbx4sVp7oUvfGFtvGmqx9ChQ9McnKjOOOOMft1eNpli/Pjx/bqfI5HdHT2ieYrQ2LFja+NN69fUqVN7f2BAK5YvX14b/8hHPpLWVFWV5s4///zaeNO5UjYZIyKfBpRNiohonmKXTfXIJqNFRLz85S+vjWfTNCIinnrqqTTXNAEjs3bt2jSXnZ9m63RE8/lk9h41TXWD7lwRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNYY38lR9xd/8RdpLhtLt3Tp0rRm9erVaS4b39Q0bnPv3r1pbuPGjbXxppFU69ev7/N+9uzZ0+djuPjii9OaGTNm9Hl7a9asSWsWLlyY5oD+MRjGdGYWLVo00IcAHGVNox+/9a1v1cabRkxu2bIlzWXneQsWLEhrbrvttjSXjekcMiT/N9emsZrZOOVsBGZE/nqbzv+axqNn52vnnXdeWnPTTTeluQkTJtTGJ02alNY89NBDae7CCy+sjTe9r9CdKyIAAACA1mhEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGuM76ZOvf/3rtfGDBw+mNU1jIb/xjW/UxrMxTBHNI+6eeuqp2vjIkSPTmt27d6e5mTNn1sbXrVuX1mSjM5tqmnLZe3HvvfemNU3v+dSpU2vjs2fPTmt27tyZ5gCAY9+GDRvS3KpVq2rj2dj0iHz8ZEQ+9vP+++/v8zFE5Od5Y8eOTWtOPvnkPufuuOOOtOb5z39+bbzpPPOss85Kc3feeWdt/IknnkhrmkZxZueNTSNYTzvttDSXjSXdtGlTWgPduSICAAAAaI1GBAAAANAajQgAAACgNRoRAAAAQGs0IgAAAIDWmJrBM7z73e9Oc3fffXdtfMuWLWlN08SFoUOH1sab7vrbdLfg7M6/TXfwnTx5cprLpoG88pWvTGvuu+++2njTJJDhw4enuWwKR9PdrZsmYGzbtq023nT36NWrV6c5AODYd95556W5N77xjbXxzZs3pzVTpkxJc/v376+Nz58/P62ZOHFimhsxYkSayzSdR2Xnf4sWLUprsgkiTefBTefP2blc02SRs88+O81l55pN53+7du1Kc1u3bq2NV1WV1kB3rogAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAa4zv5Bn27NmT5rIRlKeddlpas27dujQ3atSo2viQIXmPrJSS5rIRRMuXL09rLrnkkjR322231cbvuuuutOacc86pjT/wwANpTdNozwMHDtTGp02bltaMGTMmzWXvUdMYq7/6q79KcwDA8e1P/uRPauOf+cxn0pphw/IfM7KRlk2jHx9++OE0l517NY3OHDt2bJrLzsuyc7KmfWVjLiOaR2dmdaeffnpak40dbdrX+PHj05qm89Ps3HrmzJlpDXTniggAAACgNRoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFqjEQEAAAC0
xvhOnqFpVOMHP/jB2njTiM6mcUtTp06tjc+YMSOt2bFjR5o76aSTauOzZ89Oa37wgx+kuWxsUdN4zFtuuaU23jTitGlcaTZWqek9bzq+bDxr09hWAODENWfOnNr4W97ylrSm6Vxu2bJltfGmc5v58+enuW9/+9u18de+9rVpzYoVK9JcNjrzOc95Tlqzd+/e2vjEiRPTmvvvvz/NZaNMm87xms6RJ0+eXBvfsmXLEW3v/PPPr43v27cvrYHuXBEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrSnZH1oiIUkqe5IR000031cZvvPHGtGbNmjVpLrujcjatIiKfjBGR34X5jDPOSGuWLl2a5l7+8pfXxjdu3JjWPP7447Xxptf0xBNPpLls2sb+/fvTmu3bt6e5k08+uTb+1a9+Na3h8KqqKkd7H9ZkgN452muy9bij6Rzv6quvTnPZdIymyRhr165Nc/PmzauNZ+c8Efk5bUQ+JWTmzJlpTTY1o2kSyJQpU9JcKfVfwjt37kxrRo0aleayKSHDhuVDFJte7+7du2vjr371q9OaSy65JM1x/MrWY1dEAAAAAK3RiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDXGd9IvHnvssTR3/fXXp7lsFGfTOMszzzwzzWWjhLZu3drnY2jaXjaiKSIftzl+/Pi05r777ktzkydPro2PHDkyrXn729+e5preP46c8Z0Ag4fxne14+OGH09zixYvT3J133lkbbzr/O3jwYJrbs2dPbXzXrl1pzfDhw9Ncdm74vOc9L63JxrdPnTo1rWk6f87Gd2bxiIghQ/J/Y85e76xZs9KapvGs1157bW28aWQqJybjOwEAAIABpxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrjO/kqFuyZEmamzRpUm18xYoVac0//dM/pblVq1bVxs8444y0Zty4cWnuO9/5Tm18wYIFac2OHTtq4xs2bEhrrrvuujQ3d+7c2njTiCbaZ3wnwOBhfOfAazqXy8Zt3nTTTWnNI488kuaykZFNIz/Xr1+f5rZv396n/UTkY9q3bduW1mTneBH5mM5sTGhExM6dO9Pc/v37a+MzZ85Ma7IRndAXxncCAAAAA04jAgAAAGiNRgQAAADQGo0IAAAAoDUaEQAAAEBrTM2AI9B0x+Ls7shNdyVevnx5mnve857X6+Ni4JiaATB4mJpxbNq8eXOau+GGG9Lcww8/XBtfvXp1WjN58uQ0t3v37tr4mDFj0poRI0b0KR4RMWHChDS3Zs2a2vi8efPSmsWLF6e5K6+8sjY+f/78tAb6g6kZAAAAwIDTiAAAAABaoxEBAAAAtEYjAgAAAGiNRgQAAADQGo0IAAAAoDXGdwL0A+M7AQYP4zuJaB63PmRI/u+xo0ePro3fd999ac1FF11UG9+yZUtac/rpp6e5JUuW1MabRn5mxx0RMW7cuDQHR5PxnQAAAMCA04gAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1xncC9APjOwEGD+M7AQYH4zsBAACAAacRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA0BqNCAAAAKA1GhEAAABAazQiAAAAgNZoRAAAAACt0YgAAAAAWqMRAQAAALRGIwIAAABojUYEAAAA
0BqNCAAAAKA1paqqgT4GAAAA4AThiggAAACgNRoRAAAAQGs0IgAAAIDWaEQAAAAArdGIAAAAAFqjEQEAAAC05v8HFfOa+hjxBnYAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1080x360 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Configure text/font rcParams BEFORE creating any artists: each Text object\n",
    "# reads `text.usetex` when it is constructed, so setting it after the titles\n",
    "# exist does not affect them. Type 42 (TrueType) fonts keep text editable in\n",
    "# the exported PDF/PS.\n",
    "mpl.rcParams['pdf.fonttype'] = 42\n",
    "mpl.rcParams['ps.fonttype'] = 42\n",
    "mpl.rcParams['text.usetex'] = False\n",
    "\n",
    "font = {'family': 'serif', 'color': 'black', 'weight': 'normal', 'size': 15}\n",
    "\n",
    "# (frame index into nn_video, panel title). The time labels are hard-coded and\n",
    "# presumably assume the video spans 180 s -- TODO confirm against n_frames.\n",
    "panels = [\n",
    "    (n_frames // 3, r'Segmented Region $(t = 60s)$'),\n",
    "    (2 * n_frames // 3, r'Segmented Region $(t = 120s)$'),\n",
    "    (-1, r'Segmented Region $(t = 180s)$'),\n",
    "]\n",
    "\n",
    "fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))\n",
    "for ax, (frame_idx, title) in zip(axs.flat, panels):\n",
    "    frame = nn_video[frame_idx]  # indexed as (image, mask) below\n",
    "    # Render every panel through the same helper so the three snapshots\n",
    "    # are directly comparable.\n",
    "    ax.imshow(img_mask_to_uint8(frame[0], frame[1]), cmap='gray')\n",
    "    ax.set_title(title, fontdict=font)\n",
    "    ax.set_axis_off()\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('unet_semi.pdf', dpi=300, transparent=False, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
