{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8d27225d",
   "metadata": {},
   "source": [
    "## DDPM Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f18ccb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Args: Namespace(data_root='/data/schoudh8/ecmmd/data/mnist/data', save_dir='/data/schoudh8/ecmmd/data/mnist/chckpoints', epochs=10, batch_size=128, lr=0.0002, timesteps=1000, base_ch=64, time_emb_dim=128, sigma_c=0.9, cf_drop_prob=0.2, guidance_scale=2.0, seed=42, log_interval=100, save_interval=1000, cpu=False)\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import os\n",
    "import random\n",
    "import argparse\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import datasets, transforms, utils\n",
    "\n",
    "import time\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "from torchmetrics.image.inception import InceptionScore\n",
    "from torchmetrics.functional.image.ssim import structural_similarity_index_measure as ssim\n",
    "\n",
    "# ---------------------------\n",
    "# Utilities & noise schedule\n",
    "# ---------------------------\n",
    "\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, torch's default generator and every CUDA generator with `seed`.\"\"\"\n",
    "    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):\n",
    "        seed_fn(seed)\n",
    "\n",
    "def linear_noise_schedule(timesteps, beta_start=1e-4, beta_end=0.02):\n",
    "    \"\"\"\n",
    "    Build the linear DDPM variance schedule.\n",
    "\n",
    "    Returns a 1-D tensor of `timesteps` evenly spaced betas running from\n",
    "    `beta_start` to `beta_end`, both endpoints included.\n",
    "    \"\"\"\n",
    "    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)\n",
    "    return betas\n",
    "\n",
    "def make_ddpm_schedule(timesteps, device):\n",
    "    \"\"\"\n",
    "    Precompute the per-timestep DDPM coefficients on `device`.\n",
    "\n",
    "    Returns a dict holding `betas`, `alphas` (1 - beta), their running product\n",
    "    `alphas_cumprod`, that product shifted right by one step with a leading 1\n",
    "    (`alphas_cumprod_prev`), the two square roots used by the forward process,\n",
    "    and the integer `timesteps`.\n",
    "    \"\"\"\n",
    "    schedule = {'betas': linear_noise_schedule(timesteps).to(device)}  # shape (T,)\n",
    "    schedule['alphas'] = 1.0 - schedule['betas']\n",
    "    alpha_bar = torch.cumprod(schedule['alphas'], dim=0)\n",
    "    schedule['alphas_cumprod'] = alpha_bar\n",
    "    # alpha_bar at the previous step, with the value before step 0 defined as 1\n",
    "    schedule['alphas_cumprod_prev'] = torch.cat([torch.tensor([1.], device=device), alpha_bar[:-1]])\n",
    "    schedule['sqrt_alphas_cumprod'] = torch.sqrt(alpha_bar)\n",
    "    schedule['sqrt_one_minus_alphas_cumprod'] = torch.sqrt(1.0 - alpha_bar)\n",
    "    schedule['timesteps'] = timesteps\n",
    "    return schedule\n",
    "\n",
    "# ---------------------------\n",
    "# Dataset that produces (x_clean, c_noisy)\n",
    "# ---------------------------\n",
    "\n",
    "class MNISTNoisyConditionDataset(Dataset):\n",
    "    \"\"\"\n",
    "    MNIST wrapper that pairs every clean digit with a noisy copy used as conditioning.\n",
    "\n",
    "    Each item is a dict with\n",
    "      - 'x': the transformed image (the default transform maps it to [-1,1])\n",
    "      - 'c': x plus Gaussian noise of std sigma_c, clamped to [-1,1]\n",
    "      - 'eta': standard-normal tensor of shape (len(x), ETA_DIM, ETA_DIM), additional noise for ecmmd\n",
    "    The noise is redrawn on every access, so reading the same index twice gives different 'c' and 'eta'.\n",
    "    The diffused x_t is not produced here; training builds it on the fly with q_sample.\n",
    "    \"\"\"\n",
    "    def __init__(self, root='./data', train=True, download=True, sigma_c=0.6, ETA_DIM=7, transform=None):\n",
    "        self.ETA_DIM = ETA_DIM\n",
    "        self.mnist = datasets.MNIST(root=root, train=train, download=download)\n",
    "        self.sigma_c = sigma_c\n",
    "        if not transform:\n",
    "            # fall back to ToTensor ([0,1]) followed by a shift/scale to [-1,1]\n",
    "            transform = transforms.Compose([\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize((0.5,), (0.5,)),\n",
    "            ])\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.mnist)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img, _label = self.mnist[idx]\n",
    "        clean = self.transform(img)  # 1x28x28, in [-1,1] with the default transform\n",
    "        # noisy measurement: Gaussian perturbation of std sigma_c, clipped back into the image range\n",
    "        noisy = (clean + torch.randn_like(clean) * self.sigma_c).clamp(-1., 1.)\n",
    "        eta = torch.randn(len(clean), self.ETA_DIM, self.ETA_DIM)\n",
    "        return {'x': clean, 'c': noisy, 'eta': eta}\n",
    "\n",
    "# ---------------------------\n",
    "# Model: small UNet-like backbone\n",
    "# ---------------------------\n",
    "\n",
    "# ---- Sinusoidal embedding ----\n",
    "class SinusoidalPosEmb(nn.Module):\n",
    "    \"\"\"Maps a batch of scalar timesteps (B,) to sinusoidal features of shape (B, dim).\"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, t):\n",
    "        n_freqs = self.dim // 2\n",
    "        # geometric frequency ladder: exp(-ln(10000) * k / n_freqs) for k = 0..n_freqs-1\n",
    "        ladder = torch.arange(0, n_freqs, device=t.device)\n",
    "        freqs = torch.exp(-math.log(10000) * ladder / n_freqs)\n",
    "        phase = t.unsqueeze(1) * freqs.unsqueeze(0)  # (B, n_freqs)\n",
    "        features = torch.cat((phase.sin(), phase.cos()), dim=-1)\n",
    "        if self.dim % 2 == 0:\n",
    "            return features\n",
    "        # odd dim: append one zero column so the output width equals dim\n",
    "        return F.pad(features, (0, 1))\n",
    "\n",
    "\n",
    "# ---- Basic conv block ----\n",
    "def conv_block(in_ch, out_ch, kernel_size=3, padding=1):\n",
    "    \"\"\"Conv2d -> GroupNorm (8 groups) -> SiLU, packaged as one nn.Sequential.\"\"\"\n",
    "    layers = [\n",
    "        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),\n",
    "        nn.GroupNorm(8, out_ch),\n",
    "        nn.SiLU(),\n",
    "    ]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "# ---- Simplified UNet (~3M params at the defaults, not ~20M) ----\n",
    "class SmallUNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Two-level UNet that predicts the noise in a (B, in_channels, H, W) input given a timestep.\n",
    "\n",
    "    The encoder halves the resolution twice and the decoder doubles it twice, so the skip\n",
    "    concatenations need H and W to be multiples of 4 (28 -> 14 -> 7 for MNIST). The output\n",
    "    always has a single channel.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels=2, base_ch=48, time_emb_dim=96):\n",
    "        \"\"\"\n",
    "        in_channels: channels of the concatenated input (x_t plus conditioning image => 2)\n",
    "        base_ch: width of the first stage; deeper stages use 2x and 4x (6x inside the bottleneck).\n",
    "                 Must be a multiple of 8 because conv_block normalises with 8 GroupNorm groups.\n",
    "        time_emb_dim: width of the sinusoidal timestep embedding and of the MLP output\n",
    "\n",
    "        Size by a hand count of the layers below: roughly 3M parameters at these defaults and\n",
    "        roughly 5.3M at base_ch=64, time_emb_dim=128.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.time_mlp = nn.Sequential(\n",
    "            SinusoidalPosEmb(time_emb_dim),\n",
    "            nn.Linear(time_emb_dim, time_emb_dim * 2),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(time_emb_dim * 2, time_emb_dim)\n",
    "        )\n",
    "\n",
    "        # encoder\n",
    "        self.enc1 = nn.Sequential(conv_block(in_channels, base_ch), conv_block(base_ch, base_ch))\n",
    "        self.down1 = nn.Conv2d(base_ch, base_ch*2, kernel_size=4, stride=2, padding=1)  # 28->14\n",
    "        self.enc2 = nn.Sequential(conv_block(base_ch*2, base_ch*2), conv_block(base_ch*2, base_ch*2))\n",
    "        self.down2 = nn.Conv2d(base_ch*2, base_ch*4, kernel_size=4, stride=2, padding=1)  # 14->7\n",
    "        self.enc3 = nn.Sequential(conv_block(base_ch*4, base_ch*4), conv_block(base_ch*4, base_ch*4))\n",
    "\n",
    "        # bottleneck\n",
    "        self.mid = nn.Sequential(conv_block(base_ch*4, base_ch*6), conv_block(base_ch*6, base_ch*4))\n",
    "\n",
    "        # decoder (dec2/dec1 take twice the upsampled width because the encoder skip is concatenated)\n",
    "        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1)  # 7->14\n",
    "        self.dec2 = nn.Sequential(conv_block(base_ch*4, base_ch*2), conv_block(base_ch*2, base_ch*2))\n",
    "        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1)  # 14->28\n",
    "        self.dec1 = nn.Sequential(conv_block(base_ch*2, base_ch), conv_block(base_ch, base_ch))\n",
    "\n",
    "        self.final = nn.Sequential(\n",
    "            conv_block(base_ch, base_ch),\n",
    "            nn.Conv2d(base_ch, 1, kernel_size=1)  # predict noise\n",
    "        )\n",
    "\n",
    "        # time embedding projections, one per encoder stage width\n",
    "        self.time_proj1 = nn.Linear(time_emb_dim, base_ch)\n",
    "        self.time_proj2 = nn.Linear(time_emb_dim, base_ch*2)\n",
    "        self.time_proj3 = nn.Linear(time_emb_dim, base_ch*4)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        \"\"\"\n",
    "        x: (B, in_channels, H, W) network input\n",
    "        t: (B,) timesteps; cast to float before embedding\n",
    "        Returns the predicted noise, shape (B, 1, H, W).\n",
    "        \"\"\"\n",
    "        t_emb = self.time_mlp(t.float())\n",
    "\n",
    "        # encode; the time embedding is added as a per-channel bias after each encoder stage only\n",
    "        h1 = self.enc1(x)\n",
    "        h1 = h1 + self.time_proj1(t_emb)[:, :, None, None]\n",
    "\n",
    "        h2 = self.down1(h1)\n",
    "        h2 = self.enc2(h2)\n",
    "        h2 = h2 + self.time_proj2(t_emb)[:, :, None, None]\n",
    "\n",
    "        h3 = self.down2(h2)\n",
    "        h3 = self.enc3(h3)\n",
    "        h3 = h3 + self.time_proj3(t_emb)[:, :, None, None]\n",
    "\n",
    "        h = self.mid(h3)\n",
    "\n",
    "        # decode; each upsampled map is concatenated with the encoder output of the same resolution\n",
    "        h = self.up2(h)\n",
    "        h = torch.cat([h, h2], dim=1)\n",
    "        h = self.dec2(h)\n",
    "        h = self.up1(h)\n",
    "        h = torch.cat([h, h1], dim=1)\n",
    "        h = self.dec1(h)\n",
    "\n",
    "        return self.final(h)\n",
    "\n",
    "# ---------------------------\n",
    "# Helpers for forward noising and sampling (DDPM)\n",
    "# ---------------------------\n",
    "\n",
    "def q_sample(x_start, t, noise, schedule):\n",
    "    \"\"\"\n",
    "    Forward diffusion: jump straight from the clean image to timestep t.\n",
    "\n",
    "    x_start: (B,1,H,W) clean images in [-1,1]\n",
    "    t: (B,) long timestep indices in [0..T-1]\n",
    "    noise: (B,1,H,W) Gaussian noise\n",
    "    schedule: dict from make_ddpm_schedule\n",
    "    Returns x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise.\n",
    "    \"\"\"\n",
    "    def per_sample(key):\n",
    "        # look up one coefficient per batch element and make it broadcast over (C,H,W)\n",
    "        return schedule[key][t].reshape(-1, 1, 1, 1)\n",
    "\n",
    "    signal_scale = per_sample('sqrt_alphas_cumprod')\n",
    "    noise_scale = per_sample('sqrt_one_minus_alphas_cumprod')\n",
    "    return signal_scale * x_start + noise_scale * noise\n",
    "\n",
    "@torch.no_grad()\n",
    "def p_sample(model, x_t, c, t_idx, schedule, cond_scale=1.0, device='cpu'):\n",
    "    \"\"\"\n",
    "    One reverse (denoising) DDPM step x_t -> x_{t-1} with classifier-free guidance.\n",
    "\n",
    "    - model: called as model(torch.cat([x_t, c], dim=1), t); must return the predicted noise\n",
    "    - x_t: (B,1,H,W) current noisy sample\n",
    "    - c: (B,1,H,W) conditioning measurement; the unconditional branch feeds zeros instead\n",
    "    - t_idx: scalar timestep index (int) shared by the whole batch. A per-sample (B,) tensor is\n",
    "      not supported by the scalar indexing and the `t_idx == 0` check below.\n",
    "    - schedule: dict from make_ddpm_schedule\n",
    "    - cond_scale: guidance scale. 1 => conditional prediction only, 0 => unconditional\n",
    "      prediction only, anything else => eps_uncond + cond_scale * (eps_cond - eps_uncond)\n",
    "    - device: kept for backward compatibility; tensors created here follow x_t.device, so a\n",
    "      stale default can no longer put the timestep tensor on a different device than x_t\n",
    "\n",
    "    Returns the posterior mean when t_idx == 0, otherwise the mean plus Gaussian noise scaled\n",
    "    by the posterior standard deviation.\n",
    "    \"\"\"\n",
    "    t = torch.full((x_t.shape[0],), t_idx, dtype=torch.long, device=x_t.device)\n",
    "\n",
    "    def predict(cond):\n",
    "        return model(torch.cat([x_t, cond], dim=1), t)\n",
    "\n",
    "    # Run only the forward passes the guidance scale needs; previously the conditional pass\n",
    "    # was always evaluated, even when cond_scale == 0 threw its result away.\n",
    "    if abs(cond_scale - 1.0) < 1e-8:\n",
    "        eps = predict(c)\n",
    "    elif abs(cond_scale) < 1e-8:\n",
    "        eps = predict(torch.zeros_like(c))\n",
    "    else:\n",
    "        eps_cond = predict(c)\n",
    "        eps_uncond = predict(torch.zeros_like(c))\n",
    "        eps = eps_uncond + cond_scale * (eps_cond - eps_uncond)\n",
    "\n",
    "    beta_t = schedule['betas'][t_idx]\n",
    "    alpha_t = schedule['alphas'][t_idx]\n",
    "    alpha_cumprod_t = schedule['alphas_cumprod'][t_idx]\n",
    "\n",
    "    # posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)\n",
    "    coef1 = 1.0 / torch.sqrt(alpha_t)\n",
    "    coef2 = beta_t / torch.sqrt(1.0 - alpha_cumprod_t)\n",
    "    mean = coef1 * (x_t - coef2 * eps)\n",
    "\n",
    "    if t_idx == 0:\n",
    "        return mean  # final step: no noise is added\n",
    "\n",
    "    # posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)\n",
    "    alpha_cumprod_prev = schedule['alphas_cumprod_prev'][t_idx]\n",
    "    posterior_var = beta_t * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod_t)\n",
    "    return mean + torch.sqrt(posterior_var) * torch.randn_like(x_t)\n",
    "\n",
    "@torch.no_grad()\n",
    "def p_sample_loop(model, shape, c, schedule, cond_scale=1.0, device='cpu', progress=False):\n",
    "    \"\"\"\n",
    "    Run the full reverse chain: start from Gaussian noise x_T and apply p_sample for\n",
    "    t = T-1, ..., 0.\n",
    "\n",
    "    shape: (B,1,H,W) of the samples to draw\n",
    "    c: conditioning noisy measurement (B,1,H,W); moved to `device` so it matches the freshly\n",
    "       drawn starting noise\n",
    "    schedule: dict from make_ddpm_schedule\n",
    "    cond_scale: guidance scale forwarded to p_sample\n",
    "    device: device on which the starting noise is created\n",
    "    progress: wrap the timestep loop in a tqdm bar\n",
    "    Returns the final sample x_0 (not clamped).\n",
    "    \"\"\"\n",
    "    c = c.to(device)  # no-op when c already lives on `device`\n",
    "    x_t = torch.randn(shape, device=device)\n",
    "    T = schedule['timesteps']\n",
    "    iterator = range(T-1, -1, -1)\n",
    "    if progress:\n",
    "        iterator = tqdm(iterator, desc=\"sampling\")\n",
    "    for t in iterator:\n",
    "        x_t = p_sample(model, x_t, c, t, schedule, cond_scale=cond_scale, device=device)\n",
    "    return x_t\n",
    "\n",
    "# ---------------------------\n",
    "# Training loop\n",
    "# ---------------------------\n",
    "\n",
    "def train(args):\n",
    "    \"\"\"\n",
    "    Train the conditional noise-prediction UNet on MNIST.\n",
    "\n",
    "    Every step draws a random timestep per sample, noises the clean image with q_sample,\n",
    "    zeroes the conditioning image with probability args.cf_drop_prob (classifier-free\n",
    "    guidance) and minimises the MSE between predicted and true noise. Step checkpoints,\n",
    "    per-epoch weights, per-epoch sample grids and a final model_final.pt go to args.save_dir.\n",
    "    \"\"\"\n",
    "    use_cuda = torch.cuda.is_available() and not args.cpu\n",
    "    device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "    set_seed(args.seed)\n",
    "\n",
    "    # data: MNIST scaled to [-1,1], paired with noisy conditioning images\n",
    "    to_signed_range = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5,), (0.5,)),\n",
    "    ])\n",
    "    train_ds = MNISTNoisyConditionDataset(\n",
    "        root=args.data_root, train=True, download=True,\n",
    "        sigma_c=args.sigma_c, transform=to_signed_range,\n",
    "    )\n",
    "    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)\n",
    "\n",
    "    # model, noise schedule, optimiser\n",
    "    schedule = make_ddpm_schedule(args.timesteps, device)\n",
    "    model = SmallUNet(in_channels=2, base_ch=args.base_ch, time_emb_dim=args.time_emb_dim).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
    "\n",
    "    save_dir = args.save_dir\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    def in_save_dir(filename):\n",
    "        return os.path.join(save_dir, filename)\n",
    "\n",
    "    step = 0\n",
    "    model.train()\n",
    "    for epoch in range(args.epochs):\n",
    "        pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{args.epochs}\")\n",
    "        for batch in pbar:\n",
    "            x = batch['x'].to(device)  # clean\n",
    "            c = batch['c'].to(device)  # noisy conditioning, already in [-1,1]\n",
    "            n = x.shape[0]\n",
    "\n",
    "            # one random timestep per sample, then the matching noised image\n",
    "            t = torch.randint(0, args.timesteps, (n,), device=device).long()\n",
    "            noise = torch.randn_like(x)\n",
    "            x_t = q_sample(x, t, noise, schedule)\n",
    "\n",
    "            # classifier-free guidance: a zeroed conditioning image stands for 'unconditional'\n",
    "            c_input = c\n",
    "            if args.cf_drop_prob > 0.0:\n",
    "                keep = (torch.rand(n, device=device) >= args.cf_drop_prob).float().view(n, 1, 1, 1)\n",
    "                c_input = c * keep\n",
    "\n",
    "            pred_noise = model(torch.cat([x_t, c_input], dim=1), t)\n",
    "            loss = F.mse_loss(pred_noise, noise)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            step += 1\n",
    "            if step % args.log_interval == 0:\n",
    "                pbar.set_postfix({'loss': loss.item(), 'step': step})\n",
    "\n",
    "            # periodic full checkpoint (weights, optimiser state, step counter, hyperparameters)\n",
    "            if step % args.save_interval == 0:\n",
    "                torch.save(\n",
    "                    {\n",
    "                        'model_state': model.state_dict(),\n",
    "                        'optimizer_state': optimizer.state_dict(),\n",
    "                        'global_step': step,\n",
    "                        'args': vars(args),\n",
    "                    },\n",
    "                    in_save_dir(f'ckpt_step_{step}.pt'),\n",
    "                )\n",
    "\n",
    "        # end of epoch: weights-only snapshot plus a small sample grid for visual inspection\n",
    "        torch.save({'model_state': model.state_dict()}, in_save_dir(f'model_epoch_{epoch+1}.pt'))\n",
    "        sample_and_save_images(model, schedule, train_ds, device, epoch, save_dir, args)\n",
    "\n",
    "    # final save\n",
    "    torch.save({'model_state': model.state_dict()}, in_save_dir('model_final.pt'))\n",
    "    print(\"Training complete. Model saved to\", save_dir)\n",
    "\n",
    "# ---------------------------\n",
    "# Sampling helpers & visualization\n",
    "# ---------------------------\n",
    "\n",
    "@torch.no_grad()\n",
    "def sample_and_save_images(model, schedule, dataset, device, epoch, save_dir, args):\n",
    "    \"\"\"\n",
    "    Sample from the model for the first few dataset items and save a comparison grid as PNG.\n",
    "\n",
    "    Layout of the saved image: one row per example; the columns are the clean image, the noisy\n",
    "    conditioning image, then one sample per guidance scale in [0.0, 1.0, args.guidance_scale].\n",
    "    The file is written to <save_dir>/sample_epoch_<epoch+1>.png.\n",
    "\n",
    "    The model is put in eval mode while sampling and afterwards restored to the mode it had on\n",
    "    entry (also when sampling raises) instead of always being forced into train mode.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        # pick some examples from dataset (first N)\n",
    "        N = min(8, len(dataset))\n",
    "        batch = [dataset[i] for i in range(N)]\n",
    "        x_clean = torch.stack([b['x'] for b in batch], dim=0).to(device)\n",
    "        c = torch.stack([b['c'] for b in batch], dim=0).to(device)\n",
    "\n",
    "        # unconditional, plain conditional, guided\n",
    "        guidance_scales = [0.0, 1.0, args.guidance_scale]\n",
    "        columns = [x_clean, c]\n",
    "        for gs in guidance_scales:\n",
    "            samples = p_sample_loop(model, (N, 1, 28, 28), c, schedule, cond_scale=gs, device=device, progress=False)\n",
    "            columns.append(samples.clamp(-1., 1.))\n",
    "\n",
    "        # (N,1,H,5W): the five panels of every example side by side\n",
    "        rows = torch.cat(columns, dim=3)\n",
    "        # stack the examples vertically -> (1,1,N*H,5W)\n",
    "        grid = torch.cat(rows.split(1, dim=0), dim=2)\n",
    "\n",
    "        out_path = os.path.join(save_dir, f'sample_epoch_{epoch+1}.png')\n",
    "        utils.save_image((grid + 1) * 0.5, out_path)  # to [0,1]\n",
    "        print(f\"Saved sample grid to {out_path}\")\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "# ---------------------------\n",
    "# CLI / hyperparameters\n",
    "# ---------------------------\n",
    "\n",
    "def get_args():\n",
    "    \"\"\"\n",
    "    Build the hyperparameter parser and parse the process arguments.\n",
    "\n",
    "    parse_known_args is used so that arguments the parser does not declare are ignored\n",
    "    rather than treated as errors; only the recognised namespace is returned.\n",
    "    \"\"\"\n",
    "    option_table = [\n",
    "        ('--data-root', dict(type=str, default='./data')),\n",
    "        ('--save-dir', dict(type=str, default='./chckpoints')),\n",
    "        ('--epochs', dict(type=int, default=10)),\n",
    "        ('--batch-size', dict(type=int, default=128)),\n",
    "        ('--lr', dict(type=float, default=2e-4)),\n",
    "        ('--timesteps', dict(type=int, default=1000)),\n",
    "        ('--base-ch', dict(type=int, default=64)),\n",
    "        ('--time-emb-dim', dict(type=int, default=128)),\n",
    "        ('--sigma-c', dict(type=float, default=0.9, help='std of gaussian noise used for conditioning measurement c')),\n",
    "        ('--cf-drop-prob', dict(type=float, default=0.2, help='prob of dropping the conditioning during training (classifier-free guidance)')),\n",
    "        ('--guidance-scale', dict(type=float, default=2.0, help='guidance scale used at sampling (>=0. 0 -> unconditional)')),\n",
    "        ('--seed', dict(type=int, default=42)),\n",
    "        ('--log-interval', dict(type=int, default=100)),\n",
    "        ('--save-interval', dict(type=int, default=1000)),\n",
    "        ('--cpu', dict(action='store_true', help='force CPU even if GPU available')),\n",
    "    ]\n",
    "    parser = argparse.ArgumentParser()\n",
    "    for flag, spec in option_table:\n",
    "        parser.add_argument(flag, **spec)\n",
    "    known, _unrecognised = parser.parse_known_args()\n",
    "    return known\n",
    "\n",
    "# Parsed once at notebook scope; the cells below read this global `args`.\n",
    "args = get_args()\n",
    "print(\"Args:\", args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "909f38e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Uncomment the next line only to train the model\n",
    "# train(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ce2562",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1774544/1977089565.py:101: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(ckpt_path, map_location=device)['model_state'])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model Evaluation Function\n",
    "@torch.no_grad()\n",
    "def evaluate_model(model, dataloader, schedule, device, args, num_batches=10, model_type = \"ddpm\"):\n",
    "    \"\"\"\n",
    "    Evaluate a generative model on batches of (clean image, noisy conditioning) pairs.\n",
    "\n",
    "    - model: for \"ddpm\", the noise predictor handed to p_sample_loop; for \"ecmmd\", a module\n",
    "      called one sample at a time as model(c[b].unsqueeze(0), eta[b])\n",
    "    - dataloader: yields dicts with 'x' and 'c' (plus 'eta' for \"ecmmd\"); images in [-1,1]\n",
    "    - schedule: dict from make_ddpm_schedule (only used by \"ddpm\")\n",
    "    - device: device on which generation and the metrics run\n",
    "    - args: must provide args.guidance_scale (only used by \"ddpm\")\n",
    "    - num_batches: number of leading batches to evaluate\n",
    "    - model_type: \"ddpm\" or \"ecmmd\"\n",
    "\n",
    "    Returns dict of metrics: MSE, PSNR, SSIM (each averaged over batches), FID, Inception\n",
    "    Score (mean), and generation time per batch and per image.\n",
    "    Raises ValueError for an unknown model_type or when no batch was evaluated.\n",
    "    \"\"\"\n",
    "    if model_type not in (\"ddpm\", \"ecmmd\"):\n",
    "        raise ValueError(f'unknown model_type: {model_type!r}')\n",
    "\n",
    "    data_range = 2.0  # images live in [-1,1]\n",
    "    on_cuda = torch.device(device).type == 'cuda'\n",
    "\n",
    "    model.eval()\n",
    "    mse_vals, psnr_vals, ssim_vals, gen_times = [], [], [], []\n",
    "    num_images = 0\n",
    "\n",
    "    # Torchmetrics FID/IS modules\n",
    "    fid = FrechetInceptionDistance(feature=64).to(device)\n",
    "    inception = InceptionScore().to(device)\n",
    "\n",
    "    for i, batch in enumerate(dataloader):\n",
    "        if i >= num_batches:\n",
    "            break\n",
    "\n",
    "        x_clean = batch['x'].to(device)\n",
    "        c = batch['c'].to(device)\n",
    "\n",
    "        # --- Timing the generation ---\n",
    "        # CUDA kernels are launched asynchronously; synchronise on both sides so the clock\n",
    "        # brackets the actual generation work rather than just the kernel launches.\n",
    "        if on_cuda:\n",
    "            torch.cuda.synchronize(device)\n",
    "        start_time = time.perf_counter()\n",
    "        if model_type == \"ddpm\":\n",
    "            samples = p_sample_loop(\n",
    "                model, x_clean.shape, c, schedule,\n",
    "                cond_scale=args.guidance_scale, device=device\n",
    "            ).clamp(-1., 1.)\n",
    "        else:  # \"ecmmd\": only this branch needs the extra noise input\n",
    "            eta = batch['eta'].to(device)\n",
    "            samples = torch.zeros_like(x_clean)\n",
    "            for b in range(x_clean.shape[0]):\n",
    "                samples[b] = model(c[b].unsqueeze(0), eta[b]).squeeze(0).clamp(-1., 1.)\n",
    "        if on_cuda:\n",
    "            torch.cuda.synchronize(device)\n",
    "        gen_times.append(time.perf_counter() - start_time)\n",
    "        num_images += x_clean.shape[0]\n",
    "\n",
    "        # --- Pairwise metrics ---\n",
    "        mse_val = F.mse_loss(samples, x_clean, reduction=\"mean\").item()\n",
    "        # PSNR = 10 * log10(MAX^2 / MSE) with MAX the data range (2 for [-1,1]), consistent\n",
    "        # with the SSIM call; using MAX = 1 here under-reported PSNR by about 6 dB.\n",
    "        psnr_val = 10 * math.log10(data_range ** 2 / (mse_val + 1e-8))\n",
    "        ssim_val = ssim(samples, x_clean, data_range=data_range).item()\n",
    "\n",
    "        mse_vals.append(mse_val)\n",
    "        psnr_vals.append(psnr_val)\n",
    "        ssim_vals.append(ssim_val)\n",
    "\n",
    "        # --- Distribution-level metrics ---\n",
    "        # rescale to [0,255] uint8 for FID/IS and convert to 3 channels\n",
    "        real_uint8 = ((x_clean + 1) * 127.5).clamp(0,255).byte()\n",
    "        fake_uint8 = ((samples + 1) * 127.5).clamp(0,255).byte()\n",
    "\n",
    "        # Convert 1 channel to 3 channels by repeating\n",
    "        real_uint8_3ch = real_uint8.repeat(1, 3, 1, 1)\n",
    "        fake_uint8_3ch = fake_uint8.repeat(1, 3, 1, 1)\n",
    "\n",
    "        fid.update(real_uint8_3ch, real=True)\n",
    "        fid.update(fake_uint8_3ch, real=False)\n",
    "\n",
    "        inception.update(fake_uint8_3ch)\n",
    "\n",
    "    if not gen_times:\n",
    "        raise ValueError('evaluate_model processed no batches (empty dataloader or num_batches <= 0)')\n",
    "\n",
    "    # aggregate\n",
    "    total_time = sum(gen_times)\n",
    "    metrics = {\n",
    "        \"MSE\": sum(mse_vals) / len(mse_vals),\n",
    "        \"PSNR\": sum(psnr_vals) / len(psnr_vals),\n",
    "        \"SSIM\": sum(ssim_vals) / len(ssim_vals),\n",
    "        \"FID\": fid.compute().item(),\n",
    "        \"Inception Score (mean)\": inception.compute()[0].item(),\n",
    "        \"GenTime (s/batch)\": total_time / len(gen_times),  # avg seconds per batch\n",
    "        # total time over total images, so a smaller last batch does not skew the figure\n",
    "        \"GenTime (s/img)\": total_time / num_images\n",
    "    }\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "37f695d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_520765/1955888141.py:16: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(ckpt_path, map_location=device)['model_state'])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# DataLoader, transforms, os and torch are already imported by the first code cell.\n",
    "transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5,), (0.5,))  # [-1,1]\n",
    "    ])\n",
    "\n",
    "# Validation data: MNIST test split; the noisy conditioning image is drawn on access\n",
    "val_ds = MNISTNoisyConditionDataset(root=args.data_root, train=False, download=True, sigma_c=args.sigma_c, transform=transform)\n",
    "val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=True)\n",
    "\n",
    "# Prepare model and schedule for sampling\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')\n",
    "schedule = make_ddpm_schedule(args.timesteps, device)\n",
    "model = SmallUNet(in_channels=2, base_ch=args.base_ch, time_emb_dim=args.time_emb_dim).to(device)\n",
    "ckpt_path = os.path.join(args.save_dir, 'model_final.pt')\n",
    "# model_final.pt is written by train() as {'model_state': state_dict}, i.e. tensors only, so it\n",
    "# can be read with weights_only=True: no arbitrary unpickling, and no FutureWarning.\n",
    "ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)\n",
    "model.load_state_dict(ckpt['model_state'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b40ab033",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed before evaluating: the conditioning noise drawn by the dataset and the sampling noise\n",
    "# are otherwise unseeded here (set_seed is only called inside train), so the quality metrics\n",
    "# would change from run to run.\n",
    "set_seed(args.seed)\n",
    "model.eval()\n",
    "# With the default model_type=\"ddpm\" every batch runs the full reverse chain of\n",
    "# args.timesteps steps, so this cell is slow.\n",
    "metrics = evaluate_model(model, val_loader, schedule, device, args, num_batches=50)\n",
    "print(metrics)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "429baad8",
   "metadata": {},
   "source": [
    "## CGMMD Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eb886a68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the CGMMD section.\n",
    "# NOTE(review): BATCHSIZE is not referenced in the cells visible here (the loaders use args.batch_size) - confirm it is used further down.\n",
    "BATCHSIZE      = 200\n",
    "# Same value as the ETA_DIM default of MNISTNoisyConditionDataset; not passed to the dataset in the visible cells.\n",
    "ETA_DIM        = 7\n",
    "\n",
    "# NOTE(review): NUM_EPOCH and NEIGHBORS are presumably consumed by the training cell below - not visible here, confirm.\n",
    "NUM_EPOCH      = 200\n",
    "# LEARNING_RATE and WEIGHT_DECAY are passed to optim.AdamW in the next cell.\n",
    "LEARNING_RATE  = 1e-4\n",
    "WEIGHT_DECAY   = 1e-5\n",
    "NEIGHBORS      = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b13936f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): these star imports hide where DenoiseCNN (and anything else used below) comes\n",
    "# from and can silently shadow names defined earlier in this notebook; prefer explicit imports\n",
    "# once the required names are confirmed.\n",
    "from utils import *\n",
    "from nn import *\n",
    "import torch.optim as optim\n",
    "\n",
    "# Reuses `transform`, `device` and `args` from the cells above.\n",
    "train_ds = MNISTNoisyConditionDataset(root=args.data_root, train=True, download=True, sigma_c=args.sigma_c, transform=transform)\n",
    "train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)\n",
    "\n",
    "# Model & optimizer\n",
    "# NOTE: this rebinds `model`, replacing the DDPM SmallUNet loaded earlier; re-running the DDPM\n",
    "# evaluation cell after this one would pass this DenoiseCNN to evaluate_model instead.\n",
    "model = DenoiseCNN().to(device)\n",
    "optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38db0746",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/10: 100%|██████████| 469/469 [00:05<00:00, 83.54it/s, loss=0.00573]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAACaCAYAAADrVUwbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArrElEQVR4nO2deXxN19rHf4ckkiYE0TRipm0MQaghpWKqakhRVNU1RYmrRUuV0tb0aiMhbuuq0tbQ25r1xa0XNQ8lWmk1KlSrSgkJN5oihmtY7x+9OdfzO8c5GW30+X4++eO3z9p7r7X2ylln799+nmUzxhgoiqIoikUUsboCiqIoyp8bnYgURVEUS9GJSFEURbEUnYgURVEUS9GJSFEURbEUnYgURVEUS9GJSFEURbEUnYgURVEUS9GJSFEURbGUfE9E06dPh81mQ2hoaJ6PcfLkSYwfPx7fffddfquTI1q0aIEWLVrkqFx+2qX8wfz582Gz2eDt7Y1jx445fJ6ffs7ptSwojh49CpvNhqlTp962c94NZF/j7D9vb28EBQWhZcuWiI2NxenTp62uoqBv376oXLnybT+vzWbD+PHjXZb5M46xfE9Ec+fOBQCkpKTgq6++ytMxTp48iQkTJty2iUixhitXruCNN94o0GPOnDkTM2fOLNBjKnln3rx5SExMxIYNG/Dee+8hLCwMcXFxqFGjBjZu3Gh19ey8+eabWLFihdXVUP5DviaipKQkJCcno3379gCAOXPmFEillHuTJ598EgsXLkRycnKBHbNmzZqoWbNmgR1PyR+hoaEIDw9Hs2bN0KVLF/ztb3/Dvn374Ovri86dOyM9Pd3qKgIAqlWrhnr16lldDeU/5Gsiyp54Jk+ejCZNmmDx4sW4ePGiQ7nU1FTExMSgQoUK8PLyQnBwMLp27Yr09HRs3boVDRs2BABER0fbb+2zb19v9ejF2a31hAkT0LhxY5QuXRolSpRA/fr1MWfOHBRkXlebzYbBgwdj3rx5CAkJgY+PDxo0aIDdu3fDGIMpU6agSpUq8PPzQ6tWrXD48GGx/4YNG9CxY0eUL18e3t7eePDBBzFw4ED861//cjjXqlWrUKdOHRQrVgxVq1bFu+++i/Hjx8Nms4lyxhjMnDkTYWFh8PHxQalSpdC1a1ccOXKkwNpdEIwcORIBAQEYNWqU27KXL1/G6NGjUaVKFXh5eaFcuXJ48cUXkZmZKco5Gx/vv/8+6tatCz8/PxQvXhzVq1fHmDFjAPzx2MPDwwOxsbEO59y+fTtsNhuWLVuWq3ZlP5bavHkzBgwYgICAAJQoUQK9e/dGVlYW0tLS0K1bN5QsWRJly5bFiBEjcPXqVXGMnI7dK1eu4JVXXkFQUBDuu+8+RERE4JtvvkHlypXRt29fUTYtLQ0DBw5E+fLl4eXlhSpVqmDChAm4du1artqXXypWrIiEhAScP38es2fPFp8lJSWhQ4cOKF26NLy9vVGvXj0sXbpUlMnu3y1btmDQoEEoU6YMAgIC0LlzZ5w8eVKUvXHjBuLj41G9enUUK1YMgYGB6N27N06cOCHKOfv+WLZsGRo3bgx/f3/cd999qFq1Kvr16yfKnDt3DiNGjBDj8uWXX0ZWVpZDueyx4OfnhyeffBI//vhjXrpP9MG9OsY88tQrAC5duoRFixahYcOGCA0NRb9+/dC/f38sW7YMffr0sZdLTU1Fw4YNcfXqVYwZMwZ16tRBRkYGvvjiC/z222+oX78+5s2bh+joaLzxxhv2u6vy5cvnuk5Hjx7FwIEDUbFiRQDA7t27MWTIEKSmpmLs2LF5baoDq1evxt69ezF58mTYbDaMGjUK7du3R58+fXDkyBHMmDEDv//+O4YPH44uXbrgu+++s08eP//8Mx599FH0798f/v7+OHr0KKZNm4bHHnsM33//PTw9PQEA69atQ+fOnREREYElS5bg2rVrmDp1qtNflAMHDsT8+fMxdOhQ
xMXF4ezZs5g4cSKaNGmC5ORkPPDAAwXW9vxQvHhxvPHGG3jppZewefNmtGrVymk5Yww6deqETZs2YfTo0WjWrBn27duHcePGITExEYmJiShWrJjTfRcvXowXXngBQ4YMwdSpU1GkSBEcPnwYBw4cAABUrlwZHTp0wKxZszBy5EgULVrUvu+MGTMQHByMp59+Ok/t69+/Pzp37ozFixdj7969GDNmDK5du4ZDhw6hc+fOiImJwcaNGxEXF4fg4GAMHz7cvm9Ox250dDSWLFmCkSNHolWrVjhw4ACefvppnDt3TtQlLS0NjRo1QpEiRTB27FhUq1YNiYmJmDRpEo4ePYp58+blqY15pV27dihatCi2b99u37ZlyxY8+eSTaNy4MWbNmgV/f38sXrwYzz77LC5evOjwpde/f3+0b98eCxcuxPHjx/Hqq6+iZ8+e2Lx5s73MoEGD8MEHH2Dw4MGIiorC0aNH8eabb2Lr1q349ttvUaZMGaf1S0xMxLPPPotnn30W48ePt/uZNx/74sWLaN68OU6cOGH/LktJScHYsWPx/fffY+PGjbDZbPbxu2vXLowdOxYNGzbEzp07ERkZme9+vGfHmMkj//jHPwwAM2vWLGOMMefPnzd+fn6mWbNmoly/fv2Mp6enOXDgwC2PtWfPHgPAzJs3z+Gz5s2bm+bNmzts79Onj6lUqdItj3n9+nVz9epVM3HiRBMQEGBu3Ljh9pjOzl2rVi2xDYAJCgoyFy5csG9buXKlAWDCwsLEed555x0DwOzbt8/p8W/cuGGuXr1qjh07ZgCYVatW2T9r2LChqVChgrly5Yp92/nz501AQIC5+bIlJiYaACYhIUEc+/jx48bHx8eMHDnSbTsLm3nz5hkAZs+ePebKlSumatWqpkGDBva+4n5et26dAWDi4+PFcZYsWWIAmA8++MC+ja/l4MGDTcmSJV3WZ8uWLQaAWbFihX1bamqq8fDwMBMmTHC57y+//GIAmClTpji0b8iQIaJsp06dDAAzbdo0sT0sLMzUr1//lue41dhNSUkxAMyoUaNE+UWLFhkApk+fPvZtAwcONH5+fubYsWOi7NSpUw0Ak5KS4rKdueXma3wrHnjgAVOjRg27rl69uqlXr565evWqKBcVFWXKli1rrl+/Lo79wgsviHLx8fEGgDl16pQxxpiDBw86LffVV18ZAGbMmDH2bfz9kd0vmZmZt6x/bGysKVKkiEMbly9fbgCYNWvWGGOMWbt2rQFg3n33XVHurbfeMgDMuHHjbnkOY/6cYyzPj+bmzJkDHx8fdO/eHQDg5+eHZ555Bjt27MBPP/1kL7d27Vq0bNkSNWrUyOupcszmzZvx+OOPw9/fH0WLFoWnpyfGjh2LjIyMAn1rp2XLlvD19bXr7LZFRkaKx2bZ229+U+z06dP461//igoVKsDDwwOenp6oVKkSAODgwYMAgKysLCQlJaFTp07w8vKy7+vn54ennnpK1GX16tWw2Wzo2bMnrl27Zv8LCgpC3bp1sXXr1gJrd0Hg5eWFSZMmISkpyeERTDbZv0L5F/EzzzwDX19fbNq06ZbHb9SoETIzM/Hcc89h1apVTh95tmjRAnXr1sV7771n3zZr1izYbDbExMTkoVV/EBUVJXT29c++y795O789mJOxu23bNgBAt27dxL5du3aFh4d8uLF69Wq0bNkSwcHBYlxk/yrPPtbtxNz0COjw4cP44Ycf8Je//AUARB3btWuHU6dO4dChQ2L/Dh06CF2nTh0A//3/2rJlCwDHcdOoUSPUqFHD5bjJtge6deuGpUuXIjU11aHM6tWrERoairCwMFHftm3bwmaz2f/XsuuR3bZsevToccvz55R7dYzlaSI6fPgwtm/fjvbt28MYg8zMTGRmZqJr164A/vsmHQCcOXMmT4/ZcsvXX3+NJ554AgDw4YcfYufOndizZw9ef/11AH88SiwoSpcuLXT2ZHGr7ZcvXwbwx/PrJ554Av/7v/+L
kSNHYtOmTfj666+xe/duUcfffvsNxhinj9R4W3p6ur2sp6en+Nu9e7fTL2Kr6d69O+rXr4/XX3/d4Tk2AGRkZMDDwwP333+/2G6z2RAUFISMjIxbHrtXr16YO3cujh07hi5duiAwMBCNGzfGhg0bRLmhQ4di06ZNOHToEK5evYoPP/wQXbt2RVBQUJ7blZtxkT0mgJyP3ex28xjw8PBAQECA2Jaeno7PP//cYUzUqlULAG77uMjKykJGRgaCg4Pt9QOAESNGONTxhRdecFpHbmP241nun7JlyzqcPzg42OW4iYiIwMqVK3Ht2jX07t0b5cuXR2hoKBYtWmQvk56ejn379jnUt3jx4jDG2OubPX65vvkZW9ncq2MsTx7R3LlzYYzB8uXLsXz5cofPP/74Y0yaNAlFixbF/fff72AU5gZvb2/8/vvvDtu5kYsXL4anpydWr14Nb29v+/aVK1fm+dwFzf79+5GcnIz58+cLH41faChVqhRsNptTPygtLU3oMmXKwGazYceOHU59k1t5KVZis9kQFxeHNm3a4IMPPnD4PCAgANeuXcOZM2fEZGSMQVpamv3X662Ijo5GdHQ0srKysH37dowbNw5RUVH48ccf7XefPXr0wKhRo/Dee+8hPDwcaWlpePHFFwu2oTkkp2M3+4sgPT0d5cqVs2+/du2aw5dsmTJlUKdOHbz11ltOz5k9Idwu/u///g/Xr1+3v1iS7dWMHj0anTt3drpPSEhIrs6R3T+nTp1y+PF78uTJW/pD2XTs2BEdO3bElStXsHv3bsTGxqJHjx6oXLkyHn30UZQpUwY+Pj7ih/bNZB8/e/xmZGSIL2/+372d3OljLNcT0fXr1/Hxxx+jWrVq+Oijjxw+X716NRISErB27VpERUUhMjISn3zyCQ4dOnTLgcW/bG6mcuXKWLZsGa5cuWIvl5GRgV27dqFEiRL2cjabDR4eHsJ8vnTpEj755JPcNrHQyH5sx5MDv0nk6+uLBg0aYOXKlZg6dar9V8+FCxewevVqUTYqKgqTJ09Gamqqw+30nczjjz+ONm3aYOLEiahQoYL4rHXr1oiPj8enn36KYcOG2bd/9tlnyMrKQuvWrXN0Dl9fX0RGRuLf//43OnXqhJSUFPtE5O3tjZiYGMyYMQO7du1CWFgYmjZtWnANzAU5HbsREREAgCVLlqB+/fr27cuXL3d4SykqKgpr1qxBtWrVUKpUqUKsvXt+/fVXjBgxAv7+/hg4cCCAPyaZhx56CMnJyXj77bcL5DzZL798+umn4sfKnj17cPDgQfuvf3cUK1YMzZs3R8mSJfHFF19g7969ePTRRxEVFYW3334bAQEBqFKlyi33b9myJeLj47FgwQIMHTrUvn3hwoV5bFn+udPHWK4norVr1+LkyZOIi4tz+lp1aGgoZsyYgTlz5iAqKgoTJ07E2rVrERERgTFjxqB27drIzMzEunXrMHz4cFSvXh3VqlWDj48PFixYgBo1asDPzw/BwcEIDg5Gr169MHv2bPTs2RMDBgxARkYG4uPjxSQE/PGMdNq0aejRowdiYmKQkZGBqVOn3lF3BNltfe2112CMQenSpfH55587PDYCgIkTJ6J9+/Zo27YtXnrpJVy/fh1TpkyBn58fzp49ay/XtGlTxMTEIDo6GklJSYiIiICvry9OnTqFL7/8ErVr18agQYNuZzNzTFxcHB555BGcPn3afjsPAG3atEHbtm0xatQonDt3Dk2bNrW/NVevXj306tXrlsccMGAAfHx80LRpU5QtWxZpaWmIjY2Fv7+/w53UCy+8gPj4eHzzzTdOf1TdLnI6dmvVqoXnnnsOCQkJKFq0KFq1aoWUlBQkJCTA398fRYr890n7xIkTsWHDBjRp0gRDhw5FSEgILl++jKNHj2LNmjWYNWtWoTwy379/v90rOH36NHbs2IF58+ahaNGiWLFihbjDnT17NiIjI9G2bVv07dsX5cqVw9mzZ3Hw4EF8++23uX6N
PiQkBDExMfj73/+OIkWKIDIy0v7WXIUKFcSPGmbs2LE4ceIEWrdujfLlyyMzMxPvvvsuPD090bx5cwDAyy+/jM8++wwREREYNmwY6tSpgxs3buDXX3/F+vXr8corr6Bx48Z44oknEBERgZEjRyIrKwsNGjTAzp07Lf1RfMePsRy/1vAfOnXqZLy8vMzp06dvWaZ79+7Gw8PDpKWlGWP+eIOrX79+JigoyHh6eprg4GDTrVs3k56ebt9n0aJFpnr16sbT09PhzZKPP/7Y1KhRw3h7e5uaNWuaJUuWOH1rbu7cuSYkJMQUK1bMVK1a1cTGxpo5c+YYAOaXX36xl8vvW3Mvvvii2ObsLRdj/vt21rJly+zbDhw4YNq0aWOKFy9uSpUqZZ555hnz66+/On2bZsWKFaZ27drGy8vLVKxY0UyePNkMHTrUlCpVyqGuc+fONY0bNza+vr7Gx8fHVKtWzfTu3dskJSW5bWdh4+qNqh49ehgADv186dIlM2rUKFOpUiXj6elpypYtawYNGmR+++03UY6v5ccff2xatmxpHnjgAePl5WUfa7d6c7FFixamdOnS5uLFizlqi6s3mrh948aNMwDMmTNnxPY+ffoYX19fsS2nY/fy5ctm+PDhJjAw0Hh7e5vw8HCTmJho/P39zbBhw8Qxz5w5Y4YOHWqqVKliPD09TenSpc0jjzxiXn/9dfHWZ0GQ3QfZf15eXiYwMNA0b97cvP3227f8vkhOTjbdunUzgYGBxtPT0wQFBZlWrVrZ38a9+djcv9n/X1u2bLFvu379uomLizMPP/yw8fT0NGXKlDE9e/Y0x48fF/vy98fq1atNZGSkKVeunL3u7dq1Mzt27BD7XbhwwbzxxhsmJCTEeHl5GX9/f1O7dm0zbNgw+/edMcZkZmaafv36mZIlS5r77rvPtGnTxvzwww/5fmvuXh1jNmMKMNpTKVSuXr2KsLAwlCtXDuvXr7e6Onc9p0+fRqVKlTBkyBDEx8dbXZ08s2vXLjRt2hQLFiwokDezFIUp7DGmE9EdzPPPP482bdrYHzHNmjUL27Ztw/r16/H4449bXb27lhMnTuDIkSOYMmUKNm/ejB9//FEYs3cyGzZsQGJiIh555BH4+PggOTkZkydPhr+/P/bt2yeMaEXJC1aMsTxnVlAKn/Pnz2PEiBE4c+YMPD09Ub9+faxZs0YnoXzy0UcfYeLEiahcuTIWLFhw10xCAFCiRAmsX78e77zzDs6fP48yZcogMjISsbGxOgkpBYIVY0zviBRFURRL0YXxFEVRFEvRiUhRFEWxFJ2IFEVRFEvRiUhRFEWxlBy/NceLsSl3J4X5bkpgYKDQnPGcM4GHh4cL7ePj4/L4nIJn586dDmX+8Y9/CM1Lk9+cCNJZHRs1aiT0119/LfRjjz0m9Jdffil09npS2ThL6noznN7o+PHjLssDcIh5GjlypNt9boZT3XCusMIcI9w/N0fqA3DI8Pzvf/9baK7bzSlrnOkrV6441IHPcePGDaF5HF64cEFozkbAqW+4jVwH/pzPz33CY4i/i7k9zo7B/Xj9+nWXx+Q2uts/v+gdkaIoimIpOhEpiqIolqITkaIoimIpmllBKTDOnDkj9ODBg4W+eUVUAKhbt26uju/ME2J69+6dq2PGxcUJ/eGHHwp94MABoWvWrOnyeO48Ia4fe1oMe1oA7Auc5RX2pW4n7Ke484T4c+4P9jbYu2DPyNk52E9hT4d9Kf6c9+c2ch15f/78vvvuc3k8d/6OszJcR3f9xG1mX6ug0TsiRVEUxVJ0IlIURVEsRSciRVEUxVLUI1IKDfaEmI0bNwodExMj9JgxY4T+4YcfhHa1Ums2oaGhQvft21foGTNmCM2eBHtC2cu2Z8N+w9SpU4VOTU0V+m9/+5vQ3bt3F7patWpCs18AAKtWrRJ627ZtQjdr1kxoztb+17/+VehRo0Y5nKOw4OzNufVr2Ktw58c48+w4Tsid58NwjA3HATFc
J19fX6EvXbokNPs37Nfw+Z3F9LirE49jPgb3W2HnxtY7IkVRFMVSdCJSFEVRLEUnIkVRFMVScrwwnuaauzcozGe9PEbc5W1zx8MPP+zy84SEBIdtTz31lNAcM3PixAmhuT/ceUCVK1cWev78+UInJycLzc/qOW5p6NChQrMvlhdee+01oSdPnix0bGys0KNHjxa6MMeIs7xoN8Nehbu8cBz/wt6Gs/Oxr8S+Fccq8TF4jPD+Z8+eFbps2bIuP/fz8xM6IyNDaH9/f6F///13odkzAhx9L24D94G7cc+fO8vhlx/0jkhRFEWxFJ2IFEVRFEvRiUhRFEWxFPWI/mTcTo8ov6SlpQnNz8o7d+7ssM/BgweFzsrKcnkO7o9//etfLsuz/8K54g4dOiQ0e1RRUVFC//Of/xT63LlzQjdo0MChDklJSS7r6I46deoIvW/fPqELc4ywp8MxO/y5u9xz7Bmxl+Esjqh48eIuz8lxRhwLxnFa7BFx7BfXkf9PeIyeP39eaB4TFy9eFNpZPkJuN3tCfEy+Du5y0+l6RIqiKMo9hU5EiqIoiqXoRKQoiqJYyj2Va46fx3NeqsDAwAI9X0hIiNDly5cX+uWXXxaac4Tx/gAwYMAAoTm/2t3E2LFjhZ44cWKu9g8KChI6PDxc6N27dzvs06JFC6G3bt0qdOvWrYXmuB93sKdz8uRJodl/eP3114V+6623hP7oo4+E5mfz0dHRbuvUsGFDoffs2SM0exJdu3YVmj2iwsRdDjRuP/8Ps2fEn/P+zvyu3377TehKlSoJXapUKaE5no3bUKVKFaHvv/9+ocuUKSM0+zUcF3Tq1Cmh2QP69NNPhXb2PcKxSHxO7hceI9xGZ7FKBYneESmKoiiWohORoiiKYik6ESmKoiiWcld5RPwcs2PHjkK/8847QpcoUUJofvZ7u3n++efdlmGP4272iNgT4pxbFy5cEJrjJxiOqQkLC3MoM2vWLKHdxcRwrrfp06e7LM9w/ryVK1cKPWnSJKH5WTuPCY5L+uSTTxzOyeswsUfkLp6LPRL+PylM3OU0c7e+kDuvgsuXLFnSoQzHo3H73a0JxXFYwcHBQrMnxG12l1vup59+Evr06dNC8xpW7CkBjt8bpUuXdihzM5w7jj2igs4tx+gdkaIoimIpOhEpiqIolqITkaIoimIpd5VHxM9eV6xYYVFNCoYjR444bHMWG3O3wnnB+Nn7l19+6fLzn3/+WegZM2a4PecXX3yRmyo6eEL8PN5d7Nn7778v9Pjx412Wd+dT/v3vfxeaPSPA0Qfp1KmTy2MynBuN844VJuwJuVsLiOOE2Ktgv4V9x4CAAIc6HD16VGiOLeNrzmtQcdzRQw89JDR7TtwG1gz3AcemcR45brMzfH19hXa3JhIf012d84veESmKoiiWohORoiiKYik6ESmKoiiWohORoiiKYil31csKbLBt3LhR6Mcff9zl/mzAsSk4d+5cod0laORF0NgQXLduncv9OZANADIzM13uczfhLtEi06hRI6E5aWqfPn2E5kBCAGjbtq3QHHzI1/T7778X+tKlS0JzgsmePXsKfebMGaHZyOagW36ZgV/A4ZcnatasCXdwEK073n77baG3bduWq/3zA7/Awi8f8MsLbMzz/lz+gQceEJoXVwSAJ554QugmTZoIzQGsNWrUEJpfRuCkqMePHxea/w84oJjbyN8jvFgjJ0nlRK+AYxtTUlKE5pcTeKE7DsItbPSOSFEURbEUnYgURVEUS9GJSFEURbGUu8oj4ueYWVlZudqfn88nJCTkt0qKC6pXry70zp07hW7atKnQixYtcqkZXpQOcEy0mpiYKPSaNWtcHnPz5s1C9+vXT2hOODpz5kyh3SUc5SBm7oO//OUvLvcHHINce/fuLXStWrWE3r9/v9B8XerWrev2nAUF+yXs+TD8OQfE8ncCe07s7wBA0aJFheZEqrzAJge4sqfDY4YDUN0lDGWP5+LFi0LzmHLn
kwFAenq60D4+PkKnpqa6PGZOFhgsSPSOSFEURbEUnYgURVEUS9GJSFEURbGUu8ojUu4u2MN56aWXhOa4rwkTJgidlJQkNMcAhYeHO5zzwQcfFPrNN98UesiQIUI3btxY6D179ghdrlw5ofnZ+pNPPulQB1ewX8MLlq1fv15ojgcB3HtCufWtOnTo4PLzgoS9CHexZuwBuUsay/4Pa2fn5Ng99ow4FonjejjuiJOq8vnYc2JPieMd+XzcR85iD7nO/L/DdWZfiq8TX4eCRu+IFEVRFEvRiUhRFEWxFJ2IFEVRFEuxmRy+IO7uObMVcI6tjh07uiw/YsQIof+McUSFGQ/w1FNPCc1xWw0aNBCaFxxjz+jEiRNCc+4+zusGAMeOHRPaXXu3bNki9NatW4XmuCRevC8oKEjoihUrCu0uFophz4if5QOO/cIeA8fBHDhwIFd1KMwxwvEpnNOMY274e4f7w138S0hIiEMdeNzxuAwNDRWa8wlyjA5rzkXHcT48ZjhnIp+P/R72Ttm3BIC9e/cKzdeUYzC5nzlWihfG43iu/KJ3RIqiKIql6ESkKIqiWIpORIqiKIqlaBzRTTz22GNCP/3000I3b95caI45GTVqlNDnzp0rwNrd+fD6UAcPHnRZnp/Vf/TRR0K7i31wtp4Teyx9+/YVev78+UK3bNlS6IULF96qugCA559/Xmhe/4Zz03EbY2NjXR6fY07YDwKAqlWrCs3569x5QtxHzvqxsGAPh+NTeC0e9jK4f3j9KPZfnPldVapUEZr78/z580Lz+kHsl/A15jWk2PPhNagY9mu4PtxnztZN43NwLBJ7a9xPnJuO+72g0TsiRVEUxVJ0IlIURVEsRSciRVEUxVL+VB7R4MGDhe7UqZPQ7BG545FHHhGan29zTrB7nTFjxgg9duxYl+U5ZodhP4HzxO3YscNhH35ePn36dKHZx9u+fbvQ7777rtDTpk0TmmN2eH2jl19+Wejo6GihV6xYIfSvv/4qdE58RfaE2Lts0qSJ0OxLLV26VGj2JAozjoj9DfZb2BNy5xlx7jmO4fHz83NbJ+5zjvvhc3BeNj4nx0bVr19faHc+F/s5XL/PP/9c6IcffhgM78PjlvPTcU4+vi7OfKiCRO+IFEVRFEvRiUhRFEWxFJ2IFEVRFEu5qzwifve9ZMmSudqf3/dnnV84puTPxvfffy80xxHVqFHD5f7Vq1cX+ocffhD68uXLQn/33Xdu6/T+++8LzZ5QYmKi22O4ol27di71p59+KjT3SUxMjNCHDh0S2plnxO3etWuX0Jz7jOF1otiHKUw4PoWvKXs67J+wb+guB6az9Ys4LojXH+LvFb4G7KdwXJK/v7/QvP4Q9z+3kfuE87q1bdtWaM6vCDh6QFwnvuZcB/aEdD0iRVEU5Z5GJyJFURTFUnQiUhRFUSzlrlqPiPMnnT59Wmh3dTx79qzQS5YsEXrOnDlCc9fUrVtX6Llz5wrNOaE4ZoX9CisozBiRevXqCc1rmnTo0EFojgP68ssvXR6f4yc47xvgOEbYE/rll1+EZn+F84TxGkm5/T/g/n7wwQeF/vbbb4V+7bXXhHbmYXGcytdff52rOjHcxg0bNuTreK5gP4ZhP4Q9JfYy2ANi/4fzRQKO3jDnmuMxxOOY44zq1KkjNMf18Jjh/IEcS8beaHJystCcG/Do0aNguB8uXLjg8px8XXhdKHfrRuUXvSNSFEVRLEUnIkVRFMVSdCJSFEVRLOWuiiPiHExxcXFCc6449gO4fEpKSq7Oz+vhMJzPqXv37kLfCR5RYcLxLZ988onQnHuP/ZPJkycLzX4J51hjjxBwjKEZNmyY0K1btxa6UqVKQrN/ULt2baHDw8OF5pgQfvbOHD58WOgePXoI3bBhQ6GdjRmOteE6cS6zRx99VGj2lDp27OiixgUL+yXsNfD/GPsz5cuXF5qvN38HnDp1yqEOHCfEcTicS47XOGKPiNcb4txzHIPDPtfvv/8u
NMcRcZ9xHJMz35f7ldvEbeA68ucF7QkxekekKIqiWIpORIqiKIql6ESkKIqiWMpd5RExo0ePtroKLnnooYesroKl9OrVy+XnHB/B6zsdOHBAaI7x6dKli8MxV65cKfQ777wj9PHjx4X+7LPPhOZYsf/5n/8RmuNSOP7ilVdecaiTK8aPHy90SEiI0Pfff7/DPmXLlhW6X79+QnP+uqSkJKE59oY9icKEvQZ3696wTk1NFbpcuXJCcxwSa8DRo2FPjfuD68y+E+/PcUj8OftaHPPD9WPfbMGCBS7PBzi2mz0f9ts5h19hrz/E6B2RoiiKYik6ESmKoiiWohORoiiKYil3tUek3Fk899xzQnNc1fz584VmP4ZhT2jTpk1Cc0wQADRt2tTlMdkTYjhGY+nSpUJPnz5daM4F99VXX7k8Pufb43yHjLNYKY4rYU/IHexJ8DpShYm73HG85hjHt3AcEV8v9jrYT3N2To6x4f7lHJK87hh7OOwBsV/DnhNrjgvi9ZA4N6CzXHO8D8djuVsXiutQmDkqAb0jUhRFUSxGJyJFURTFUnQiUhRFUSzlrlqPyGpmzpwp9KBBg1yWX7NmjdDt27cv8DrllsJ81stjhL0Ijl8ZPHiw0DNmzCjwOnFcT0JCgtAtW7YUesuWLUJzf7Vr107otWvXujz/7NmzhW7Tpo3Q7O9wrjqOmwGArKwsocPCwoSuWLGi0G+99ZbQPI45n93tHCO59YzYE+I4Is5HyLnnAKBChQpCc+45jt3iOvv7+wvN6xkFBwcLzR4TxxXxNWdfkHPhsSeUmZkJhmP0AgICXJ6D28ieEvc7tyG/6B2RoiiKYik6ESmKoiiWohORoiiKYikaR+QCfj791FNP5Wr/5cuXF2R17jrYE2J/JLeeEOdhO3TokEMZ9jf42feECROEHjdunNCcq86dN8p+gbPn9TezbNkyofnZO8ekNGvWzOEY69atE5o9iJ9//lnoRo0aCd23b1+h33zzzVtXuIBx5zVwTjT2Knx9fYVmv6xBgwZCcy5AwPGacb47zt3GHhLHEbEn5C6Waf/+/UJfuHBBaPbJ2BPi9Ys4bxzgOI54H64jw96drkekKIqi3NPoRKQoiqJYik5EiqIoiqXcVR4RP1+OiIgQmp/98vN+Ly8voTnegvOUjRkzRmjOc8X89NNPQrvLa3avU716daHZ0+G1fVasWOHyeOwRtWrVyqGMO09n0qRJQs+dO1do9k8iIyOF7tSpk9AcB8RxSBynxHndvvjiC5f1dQbn8GNPiH0S9hxWrVqV7zrkFc4dx9eLPSP2b86ePSs0/09yzjT2KQHg4MGDQjdp0kRojiPiOvExuTz7KVxn9oQ4F92JEyeEZp+L8xkGBgaCcbeuE5+T+5m9ucJG74gURVEUS9GJSFEURbEUnYgURVEUS9GJSFEURbGUu+plhcWLFwvdtWtXodkk5EXLOMiLA/1yCxt6I0eOFJoXp/qzwS8fREdHC/3www8LzYGEHKjHC5jxiwGA+4SdaWlpQgcFBbk856uvvio0J7Llz6dMmSI0B5uy0VyiRAmhuQ+cLR7I/cpJPJOSkoRu0aKF0Fu3bhW6du3aQhdm0lN3AascaMmf84sCbLrzGOH+B4Dw8HCXdeJjckArvzTFwaLukpryS038AgsnJM3IyBCaX07g8zuD68QvQHCb+SUSftmhoNE7IkVRFMVSdCJSFEVRLEUnIkVRFMVS7liPaPTo0Q7bunTp4nIffu7Jz8ZzCweiLVy4UGgOWOVn7392Zs2aJTQnOeXAS/aM2K85c+aM0BxMCjg+L9+2bZvQHTp0EJrH2eHDhx2OeTO8MB6Psd27dwvtzMe6GfZnDhw4IDR7HoDjYnDHjx93eQ4+JrN06VKXnxckHHDqbmE89mPYM2Lvgv0VZ3B/caJaDnznYE/2fNh/YR+SA1h/+eUXoTnAlj0kDpjlPuQ+AhzHDfczH4PbzH67uySp+UXviBRFURRL0YlIURRF
sRSdiBRFURRLsZkcBg24SyZZ0Dz44IMO24YPHy70oEGD8nUOd0lKZ8+eLTQvUHU3UpgxIu7GCC8gdvLkyVwdv1atWkJPmzbNoczAgQOFDg0NFfrzzz93WZ6vOXtI/Px//vz5QvOz9pSUFKEfeughhzrnFk68ysl9+Rr36tVL6H/+859Cc7xbYY4R9hrY32A/hr0KLs/HY81JVgHHheo4Nqly5cpCs0fD/gvHIbHnxF4zL+bHvhZ/7m5RO2cxPtxP7ha242O6+1/muKP8ondEiqIoiqXoRKQoiqJYik5EiqIoiqXcsR6RUjhY6RExCQkJQr/yyisuy48bN05o9kYAoGrVqkJz3A/7jPw83l1OLT7+kSNHhI6Pjxea8w+647HHHhN6//79DmXYg3AHX3OOpWLf7HaOEfYy3H3OXoeHhwyFzMkCb+wbBQQECO3r6yu0n5+f0Oz78QKQHFfEfgrHIXG+Qd6f28ieHn+eE9ydg68Dj4mCXjhP74gURVEUS9GJSFEURbEUnYgURVEUS1GP6E/GneQR8bN3zqF26NChfNepfv36QnMdv/nmm1wdj5+lDxgwQGiOM+K1gziuyZkHdDPsSQGOvhTXYfr06UL3799f6AULFggdFhYm9N69e13WKT9wzI2zOJ+bcedlsKfHnzsb7+48Fc5ZyTE27FO5y3/HbShevLjQnF+Pz88+pru1hAD3ueP4OriL1+LrVNDrE+kdkaIoimIpOhEpiqIolqITkaIoimIp6hH9ybidHtHgwYOF5vWImIiICKG3b98udGBgoNA5WXuGad68udC8XlF4eLjQvL4QM2zYMKE5bqlNmza5qt+pU6eE7t27t0MZzrnH62Clp6cLzW1ij+LVV18Vmn2sgoTHCPsd7FWw18F+i7vcdPw54OjpsF/Ca/Wwp8Rr+/Dx3K25xJ4Ot9Gd58Q6J2109/3N18Hdmkd8HfKL3hEpiqIolqITkaIoimIpOhEpiqIolpJjj0hRFEVRCgO9I1IURVEsRSciRVEUxVJ0IlIURVEsRSciRVEUxVJ0IlIURVEsRSciRVEUxVJ0IlIURVEsRSciRVEUxVJ0IlIURVEs5f8BWiO/d+qY9PMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 500x300 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/10: 100%|██████████| 469/469 [00:05<00:00, 83.54it/s, loss=0.00249]\n",
      "Epoch 3/10: 100%|██████████| 469/469 [00:05<00:00, 81.05it/s, loss=0.00091]\n",
      "Epoch 4/10: 100%|██████████| 469/469 [00:05<00:00, 79.47it/s, loss=0.0013]\n",
      "Epoch 5/10: 100%|██████████| 469/469 [00:05<00:00, 82.87it/s, loss=0.000433]\n",
      "Epoch 6/10: 100%|██████████| 469/469 [00:05<00:00, 79.40it/s, loss=0.000617]\n",
      "Epoch 7/10: 100%|██████████| 469/469 [00:05<00:00, 81.29it/s, loss=0.000307]\n",
      "Epoch 8/10: 100%|██████████| 469/469 [00:05<00:00, 79.90it/s, loss=0.000169]\n",
      "Epoch 9/10: 100%|██████████| 469/469 [00:05<00:00, 81.03it/s, loss=0.000163]\n",
      "Epoch 10/10: 100%|██████████| 469/469 [00:05<00:00, 81.73it/s, loss=0.000104]\n",
      "Epoch 11/10: 100%|██████████| 469/469 [00:05<00:00, 81.54it/s, loss=4.93e-5]\n",
      "Epoch 12/10: 100%|██████████| 469/469 [00:05<00:00, 79.22it/s, loss=3.5e-5]\n",
      "Epoch 13/10: 100%|██████████| 469/469 [00:06<00:00, 76.94it/s, loss=5.25e-5]\n",
      "Epoch 14/10: 100%|██████████| 469/469 [00:05<00:00, 79.09it/s, loss=2.35e-5]\n",
      "Epoch 15/10: 100%|██████████| 469/469 [00:05<00:00, 81.05it/s, loss=1.88e-5]\n",
      "Epoch 16/10: 100%|██████████| 469/469 [00:05<00:00, 78.17it/s, loss=2.08e-5]\n",
      "Epoch 17/10: 100%|██████████| 469/469 [00:05<00:00, 81.84it/s, loss=1.22e-5]\n",
      "Epoch 18/10: 100%|██████████| 469/469 [00:06<00:00, 77.21it/s, loss=1.06e-5]\n",
      "Epoch 19/10: 100%|██████████| 469/469 [00:05<00:00, 80.53it/s, loss=2.05e-5]\n",
      "Epoch 20/10: 100%|██████████| 469/469 [00:05<00:00, 79.60it/s, loss=9.76e-6]\n",
      "Epoch 21/10: 100%|██████████| 469/469 [00:05<00:00, 79.08it/s, loss=6.44e-6]\n",
      "Epoch 22/10: 100%|██████████| 469/469 [00:05<00:00, 83.60it/s, loss=1.13e-5]\n",
      "Epoch 23/10: 100%|██████████| 469/469 [00:05<00:00, 82.74it/s, loss=3.98e-6]\n",
      "Epoch 24/10: 100%|██████████| 469/469 [00:05<00:00, 82.24it/s, loss=6.75e-6]\n",
      "Epoch 25/10: 100%|██████████| 469/469 [00:05<00:00, 79.83it/s, loss=1.91e-5]\n",
      "Epoch 26/10: 100%|██████████| 469/469 [00:05<00:00, 82.64it/s, loss=2.83e-6]\n",
      "Epoch 27/10: 100%|██████████| 469/469 [00:05<00:00, 83.07it/s, loss=4.02e-6]\n",
      "Epoch 28/10: 100%|██████████| 469/469 [00:05<00:00, 82.33it/s, loss=9.55e-6]\n",
      "Epoch 29/10: 100%|██████████| 469/469 [00:05<00:00, 81.18it/s, loss=3.53e-6]\n",
      "Epoch 30/10: 100%|██████████| 469/469 [00:05<00:00, 78.64it/s, loss=5.43e-6]\n",
      "Epoch 31/10: 100%|██████████| 469/469 [00:05<00:00, 82.40it/s, loss=2.15e-6]\n",
      "Epoch 32/10: 100%|██████████| 469/469 [00:05<00:00, 81.86it/s, loss=2.88e-6]\n",
      "Epoch 33/10: 100%|██████████| 469/469 [00:05<00:00, 80.09it/s, loss=1.76e-6]\n",
      "Epoch 34/10: 100%|██████████| 469/469 [00:05<00:00, 80.81it/s, loss=3.89e-6]\n",
      "Epoch 35/10: 100%|██████████| 469/469 [00:05<00:00, 81.99it/s, loss=9.65e-7]\n",
      "Epoch 36/10: 100%|██████████| 469/469 [00:05<00:00, 79.87it/s, loss=5.66e-6]\n",
      "Epoch 37/10: 100%|██████████| 469/469 [00:05<00:00, 78.80it/s, loss=2.81e-6]\n",
      "Epoch 38/10: 100%|██████████| 469/469 [00:05<00:00, 79.04it/s, loss=6.77e-7]\n",
      "Epoch 39/10: 100%|██████████| 469/469 [00:05<00:00, 79.68it/s, loss=1.59e-6]\n",
      "Epoch 40/10: 100%|██████████| 469/469 [00:05<00:00, 81.17it/s, loss=2.59e-9]\n",
      "Epoch 41/10: 100%|██████████| 469/469 [00:05<00:00, 81.67it/s, loss=1.97e-7]\n",
      "Epoch 42/10: 100%|██████████| 469/469 [00:05<00:00, 82.63it/s, loss=7.24e-7]\n",
      "Epoch 43/10: 100%|██████████| 469/469 [00:06<00:00, 72.98it/s, loss=1.66e-6]\n",
      "Epoch 44/10: 100%|██████████| 469/469 [00:06<00:00, 73.20it/s, loss=3.9e-6]\n",
      "Epoch 45/10: 100%|██████████| 469/469 [00:05<00:00, 82.66it/s, loss=6.79e-8]\n",
      "Epoch 46/10: 100%|██████████| 469/469 [00:05<00:00, 78.75it/s, loss=1.7e-6]\n",
      "Epoch 47/10: 100%|██████████| 469/469 [00:05<00:00, 81.31it/s, loss=4.12e-7]\n",
      "Epoch 48/10: 100%|██████████| 469/469 [00:05<00:00, 79.33it/s, loss=1.24e-6]\n",
      "Epoch 49/10: 100%|██████████| 469/469 [00:05<00:00, 80.87it/s, loss=1.28e-6]\n",
      "Epoch 50/10: 100%|██████████| 469/469 [00:05<00:00, 81.47it/s, loss=3.39e-6]\n",
      "Epoch 51/10: 100%|██████████| 469/469 [00:05<00:00, 83.25it/s, loss=1.09e-6]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAACaCAYAAADrVUwbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnk0lEQVR4nO3deXxMV/8H8M+QRCIhJClZLEH7EGvs2yOxpYRYaqe2KFJ7LZVSgpSSiBZF7aH24sFPLLWXVuyE2lVjSSQ0hIjShPP7o0/m8f1OzGR1k/i+Xy9/fCZ37pw7ueZk7veec3RKKQUhhBBCI/m0boAQQoh3m3REQgghNCUdkRBCCE1JRySEEEJT0hEJIYTQlHREQgghNCUdkRBCCE1JRySEEEJT0hEJIYTQVKY7orlz50Kn06Fy5coZ3kd0dDQmT56Mc+fOZbY5adK4cWM0btw4Tdtl5rjEP1asWAGdTgdLS0vcunXL4OeZeZ/T+rvMKpGRkdDpdAgJCXlrr5kbpPyOU/5ZWlrC0dERTZo0wfTp03H//n2tm0j07dsXrq6ub/11dTodJk+ebHSbd/Ecy3RHtHz5cgDAxYsXcfz48QztIzo6GlOmTHlrHZHQxosXLzBhwoQs3eeCBQuwYMGCLN2nyLjQ0FCEh4dj7969mD9/Ptzd3REUFAQ3Nzfs27dP6+bpTZw4EVu2bNG6GeK/MtURnTp1ChEREWjdujUAYNmyZVnSKJE3tWzZEmvXrkVERESW7bNixYqoWLFilu1PZE7lypVRr149NGrUCB07dsS3336L8+fPw9raGh06dEBsbKzWTQQAlCtXDtWrV9e6GeK/MtURpXQ8M2bMQIMGDbB+/Xo8e/bMYLuoqCgMHDgQJUuWhIWFBZydndGpUyfExsbi0KFDqF27NgDA19dX/9U+5evrmy69pPbVesqUKahbty7s7OxQuHBh1KhRA8uWLUNWzuuq0+kwdOhQhIaGonz58rCyskKtWrVw7NgxKKUwc+ZMlClTBjY2NmjatClu3LhBnr937160a9cOJUqUgKWlJd5//334+fnhzz//NHitbdu2oWrVqihQoADKli2LOXPmYPLkydDpdGQ7pRQWLFgAd3d3WFlZoWjRoujUqRNu3ryZZcedFcaOHQt7e3v4+/ub3Pb58+cYN24cypQpAwsLC7i4uGDIkCGIj48n26V2fnz//feoVq0abGxsUKhQIVSoUAHjx48H8M9lDzMzM0yfPt3gNQ8fPgydToeNGzem67hSLksdOHAAAwYMgL29PQoXLozevXsjMTERMTEx6NKlC4oUKQInJyeMGTMGSUlJZB9pPXdfvHiB0aNHw9HREQULFoSHhwdOnz4NV1dX9O3bl2wbExMDPz8/lChRAhYWFihTpgymTJmC5OTkdB1fZpUqVQqzZs1CQkICFi1aRH526tQptG3bFnZ2drC0tET16tXx448/km1S3t+DBw9i0KBBcHBwgL29PTp06IDo6Giy7atXrxAcHIwKFSqgQIECKFasGHr37o27d++S7VL7/Ni4cSPq1q0LW1tbFCxYEGXLlkW/fv3INk+ePMGYMWPIefnZZ58hMTHRYLuUc8HGxgYtW7bEtWvXMvL2kfcgr55jZhl6VwD89ddfWLduHWrXro3KlSujX79+6N+/PzZu3Ig+ffrot4uKikLt2rWRlJSE8ePHo2rVqoiLi8NPP/2ER48eoUaNGggNDYWvry8mTJig/3ZVokSJdLcpMjISfn5+KFWqFADg2LFjGDZsGKKiohAQEJDRQzUQFhaGs2fPYsaMGdDpdPD390fr1q3Rp08f3Lx5E/PmzcPjx48xatQodOzYEefOndN3Hr///jvq16+P/v37w9bWFpGRkfjmm2/w73//GxcuXIC5uTkAYPfu3ejQoQM8PDywYcMGJCcnIyQkJNW/KP38/LBixQoMHz4cQUFBePjwIQIDA9GgQQNERESgePHiWXbsmVGo
UCFMmDABI0aMwIEDB9C0adNUt1NKoX379ti/fz/GjRuHRo0a4fz585g0aRLCw8MRHh6OAgUKpPrc9evXY/DgwRg2bBhCQkKQL18+3LhxA5cuXQIAuLq6om3btli4cCHGjh2L/Pnz6587b948ODs746OPPsrQ8fXv3x8dOnTA+vXrcfbsWYwfPx7Jycm4evUqOnTogIEDB2Lfvn0ICgqCs7MzRo0apX9uWs9dX19fbNiwAWPHjkXTpk1x6dIlfPTRR3jy5AlpS0xMDOrUqYN8+fIhICAA5cqVQ3h4OKZOnYrIyEiEhoZm6BgzqlWrVsifPz8OHz6sf+zgwYNo2bIl6tati4ULF8LW1hbr169H165d8ezZM4MPvf79+6N169ZYu3Yt7ty5g88//xw9e/bEgQMH9NsMGjQIixcvxtChQ+Hj44PIyEhMnDgRhw4dwpkzZ+Dg4JBq+8LDw9G1a1d07doVkydP1tczX9/3s2fP4Onpibt37+o/yy5evIiAgABcuHAB+/btg06n05+/R48eRUBAAGrXro1ff/0V3t7emX4f8+w5pjLohx9+UADUwoULlVJKJSQkKBsbG9WoUSOyXb9+/ZS5ubm6dOnSG/d18uRJBUCFhoYa/MzT01N5enoaPN6nTx9VunTpN+7z5cuXKikpSQUGBip7e3v16tUrk/tM7bUrVapEHgOgHB0d1dOnT/WPbd26VQFQ7u7u5HVmz56tAKjz58+nuv9Xr16ppKQkdevWLQVAbdu2Tf+z2rVrq5IlS6oXL17oH0tISFD29vbq9V9beHi4AqBmzZpF9n3nzh1lZWWlxo4da/I4s1toaKgCoE6ePKlevHihypYtq2rVqqV/r/j7vHv3bgVABQcHk/1s2LBBAVCLFy/WP8Z/l0OHDlVFihQx2p6DBw8qAGrLli36x6KiopSZmZmaMmWK0ef+8ccfCoCaOXOmwfENGzaMbNu+fXsFQH3zzTfkcXd3d1WjRo03vsabzt2LFy8qAMrf359sv27dOgVA9enTR/+Yn5+fsrGxUbdu3SLbhoSEKADq4sWLRo8zvV7/Hb9J8eLFlZubmz5XqFBBVa9eXSUlJZHtfHx8lJOTk3r58iXZ9+DBg8l2wcHBCoC6d++eUkqpy5cvp7rd8ePHFQA1fvx4/WP88yPlfYmPj39j+6dPn67y5ctncIybNm1SANTOnTuVUkrt2rVLAVBz5swh202bNk0BUJMmTXrjayj1bp5jGb40t2zZMlhZWaFbt24AABsbG3Tu3BlHjhzB9evX9dvt2rULTZo0gZubW0ZfKs0OHDiA5s2bw9bWFvnz54e5uTkCAgIQFxeXpXftNGnSBNbW1vqccmze3t7kslnK46/fKXb//n18+umnKFmyJMzMzGBubo7SpUsDAC5fvgwASExMxKlTp9C+fXtYWFjon2tjY4M2bdqQtoSFhUGn06Fnz55ITk7W/3N0dES1atVw6NChLDvurGBhYYGpU6fi1KlTBpdgUqT8Fcr/Iu7cuTOsra2xf//+N+6/Tp06iI+PR/fu3bFt27ZUL3k2btwY1apVw/z58/WPLVy4EDqdDgMHDszAUf3Dx8eH5JTff8q3/Ncf53cPpuXc/fnnnwEAXbp0Ic/t1KkTzMzoxY2wsDA0adIEzs7O5LxI+as8ZV9vk3rtEtCNGzdw5coVfPzxxwBA2tiqVSvcu3cPV69eJc9v27YtyVWrVgXwv/9fBw8eBGB43tSpUwdubm5Gz5uU8kCXLl3w448/IioqymCbsLAwVK5cGe7u7qS9LVq0gE6n0/9fS2lHyrGl6NGjxxtfP63y6jmWoY7oxo0bOHz4MFq3bg2lFOLj4xEfH49OnToB+N+ddADw4MGDDF1mS68TJ07gww8/BAAsWbIEv/76K06ePIkvv/wSwD+XErOKnZ0dySmdxZsef/78OYB/rl9/+OGH+M9//oOxY8di//79OHHiBI4dO0ba+OjRIyilUr2kxh+LjY3Vb2tubk7+
HTt2LNUPYq1169YNNWrUwJdffmlwHRsA4uLiYGZmhvfee488rtPp4OjoiLi4uDfuu1evXli+fDlu3bqFjh07olixYqhbty727t1Lths+fDj279+Pq1evIikpCUuWLEGnTp3g6OiY4eNKz3mRck4AaT93U46bnwNmZmawt7cnj8XGxmL79u0G50SlSpUA4K2fF4mJiYiLi4Ozs7O+fQAwZswYgzYOHjw41TbyY0y5PMvfHycnJ4PXd3Z2NnreeHh4YOvWrUhOTkbv3r1RokQJVK5cGevWrdNvExsbi/Pnzxu0t1ChQlBK6dubcv7y9mbm3EqRV8+xDNWIli9fDqUUNm3ahE2bNhn8fOXKlZg6dSry58+P9957z6BQmB6WlpZ4/PixweP8INevXw9zc3OEhYXB0tJS//jWrVsz/NpZ7bfffkNERARWrFhB6mj8hoaiRYtCp9OlWg+KiYkh2cHBATqdDkeOHEm1bvKmWoqWdDodgoKC4OXlhcWLFxv83N7eHsnJyXjw4AHpjJRSiImJ0f/1+ia+vr7w9fVFYmIiDh8+jEmTJsHHxwfXrl3Tf/vs0aMH/P39MX/+fNSrVw8xMTEYMmRI1h5oGqX13E35IIiNjYWLi4v+8eTkZIMPWQcHB1StWhXTpk1L9TVTOoS3ZceOHXj58qX+xpKUWs24cePQoUOHVJ9Tvnz5dL1Gyvtz7949gz9+o6Oj31gfStGuXTu0a9cOL168wLFjxzB9+nT06NEDrq6uqF+/PhwcHGBlZUX+0H5dyv5Tzt+4uDjy4c3/775NOf0cS3dH9PLlS6xcuRLlypXD0qVLDX4eFhaGWbNmYdeuXfDx8YG3tzdWrVqFq1evvvHE4n/ZvM7V1RUbN27Eixcv9NvFxcXh6NGjKFy4sH47nU4HMzMzUnz+66+/sGrVqvQeYrZJuWzHOwd+J5G1tTVq1aqFrVu3IiQkRP9Xz9OnTxEWFka29fHxwYwZMxAVFWXwdTona968Oby8vBAYGIiSJUuSnzVr1gzBwcFYvXo1Ro4cqX988+bNSExMRLNmzdL0GtbW1vD29sbff/+N9u3b4+LFi/qOyNLSEgMHDsS8efNw9OhRuLu7o2HDhll3gOmQ1nPXw8MDALBhwwbUqFFD//imTZsM7lLy8fHBzp07Ua5cORQtWjQbW2/a7du3MWbMGNja2sLPzw/AP53MBx98gIiICHz99ddZ8jopN7+sXr2a/LFy8uRJXL58Wf/XvykFChSAp6cnihQpgp9++glnz55F/fr14ePjg6+//hr29vYoU6bMG5/fpEkTBAcHY82aNRg+fLj+8bVr12bwyDIvp59j6e6Idu3ahejoaAQFBaV6W3XlypUxb948LFu2DD4+PggMDMSuXbvg4eGB8ePHo0qVKoiPj8fu3bsxatQoVKhQAeXKlYOVlRXWrFkDNzc32NjYwNnZGc7OzujVqxcWLVqEnj17YsCAAYiLi0NwcDDphIB/rpF+88036NGjBwYOHIi4uDiEhITkqG8EKcf6xRdfQCkFOzs7bN++3eCyEQAEBgaidevWaNGiBUaMGIGXL19i5syZsLGxwcOHD/XbNWzYEAMHDoSvry9OnToFDw8PWFtb4969e/jll19QpUoVDBo06G0eZpoFBQWhZs2auH//vv7rPAB4eXmhRYsW8Pf3x5MnT9CwYUP9XXPVq1dHr1693rjPAQMGwMrKCg0bNoSTkxNiYmIwffp02NraGnyTGjx4MIKDg3H69OlU/6h6W9J67laqVAndu3fHrFmzkD9/fjRt2hQXL17ErFmzYGtri3z5/nelPTAwEHv37kWDBg0wfPhwlC9fHs+fP0dkZCR27tyJhQsXZssl899++01fK7h//z6OHDmC0NBQ5M+fH1u2bCHfcBctWgRvb2+0aNECffv2hYuLCx4+fIjLly/jzJkz6b6Nvnz58hg4cCC+++475MuXD97e3vq75kqWLEn+qOECAgJw9+5dNGvWDCVKlEB8
fDzmzJkDc3NzeHp6AgA+++wzbN68GR4eHhg5ciSqVq2KV69e4fbt29izZw9Gjx6NunXr4sMPP4SHhwfGjh2LxMRE1KpVC7/++qumfxTn+HMszbc1/Ff79u2VhYWFun///hu36datmzIzM1MxMTFKqX/u4OrXr59ydHRU5ubmytnZWXXp0kXFxsbqn7Nu3TpVoUIFZW5ubnBnycqVK5Wbm5uytLRUFStWVBs2bEj1rrnly5er8uXLqwIFCqiyZcuq6dOnq2XLlikA6o8//tBvl9m75oYMGUIeS+0uF6X+d3fWxo0b9Y9dunRJeXl5qUKFCqmiRYuqzp07q9u3b6d6N82WLVtUlSpVlIWFhSpVqpSaMWOGGj58uCpatKhBW5cvX67q1q2rrK2tlZWVlSpXrpzq3bu3OnXqlMnjzG7G7qjq0aOHAmDwPv/111/K399flS5dWpmbmysnJyc1aNAg9ejRI7Id/12uXLlSNWnSRBUvXlxZWFjoz7U33bnYuHFjZWdnp549e5amYzF2RxM/vkmTJikA6sGDB+TxPn36KGtra/JYWs/d58+fq1GjRqlixYopS0tLVa9ePRUeHq5sbW3VyJEjyT4fPHighg8frsqUKaPMzc2VnZ2dqlmzpvryyy/JXZ9ZIeU9SPlnYWGhihUrpjw9PdXXX3/9xs+LiIgI1aVLF1WsWDFlbm6uHB0dVdOmTfV3476+b/7+pvz/OnjwoP6xly9fqqCgIPWvf/1LmZubKwcHB9WzZ091584d8lz++REWFqa8vb2Vi4uLvu2tWrVSR44cIc97+vSpmjBhgipfvryysLBQtra2qkqVKmrkyJH6zzullIqPj1f9+vVTRYoUUQULFlReXl7qypUrmb5rLq+eYzqlsnC0p8hWSUlJcHd3h4uLC/bs2aN1c3K9+/fvo3Tp0hg2bBiCg4O1bk6GHT16FA0bNsSaNWuy5M4sIbjsPsekI8rBPvnkE3h5eekvMS1cuBA///wz9uzZg+bNm2vdvFzr7t27uHnzJmbOnIkDBw7g2rVrpDCbk+3duxfh4eGoWbMmrKysEBERgRkzZsDW1hbnz58nhWghMkKLcyzDMyuI7JeQkIAxY8bgwYMHMDc3R40aNbBz507phDJp6dKlCAwMhKurK9asWZNrOiEAKFy4MPbs2YPZs2cjISEBDg4O8Pb2xvTp06UTEllCi3NMvhEJIYTQlCyMJ4QQQlPSEQkhhNCUdERCCCE0JR2REEIITaX5rjm+GJvInbLz3pTsPkdSpodJwadGSotx48aRnNoCea/7/PPPSeYzOJ85cyZdz9+1axfJVapUIfn1STbTik8O++DBA6Pbv//++yTzuQ5z8zmSG5h6D/j7n9pCmFrL6jbINyIhhBCako5ICCGEpqQjEkIIoak0D2iVa7t5Q06+/l+hQgWS+Sjuc+fOmdwHnyWBr7TJ1z+qVq0ayXXr1iWZL13CVw2tU6cOyXfu3CH53r17JH/11VckT5w4EcaktqZLdHS00edwo0ePJnn27Nkkv3z5kuScfI7kBaZqPvznfGXU1BaTfNukRiSEECJPkY5ICCGEpqQjEkIIoSmpEb1jsvP6/4oVK0j29fXN1P6uXbtGspubG8njx483eA6vwfDloVu0aEHypEmTSH59pVgA+L//+z+SeR2LrwtVqFAhkhMSEgzamN34OKG7d++SPGbMGJKnTp1KstSIspapmpC1tTXJf/31l9GfP3361Oj+3gapEQkhhMhTpCMSQgihKemIhBBCaEpWaDWCjzEZPnw4yTVr1iSZj+/47rvvSObzjOU1J06cMPrznj17kjxz5kySnZycSL5w4QLJfLzLqFGjDF4jNjaW5HLlypHs4eFhtI0+Pj4km6oJ8XNg/vz5JPNxSYcPHzbankuXLpFcvXp1gzZWrVqVZD626fnz5yTb2NiQzGtCImvxsV98fkNPT0+Sw8PDSeb/j1atWkVyTpx7LrPkG5EQQghNSUckhBBCU9IR
CSGE0NQ7PY7I0dGR5MmTJ5Pcp08fki0sLNK1/7///ptkKyurdD0/O+TkMSKm1grKl4/+3fTq1atMvR5gOI8XrwlVrlyZ5PXr15PM61x8jA7Xpk0bksPCwkjmdTG+XhEANGvWjGS+RhLHx3O1bNmS5K5du5Kck8+RrMDPIz7+LSQkhOTLly+TzM/Tb7/9luTPPvuM5H79+pFsZ2dHMq/h3b59m+SyZcuSnBNqQjKOSAghRJ4iHZEQQghNSUckhBBCU+9UjcjV1ZXkHTt2kMzrA/yYL168SHJAQADJ+fPnJ3nDhg0kN2/enOSDBw8ab3A2yM7ry/xa+NatW0l++PAhyfz95eO2rl+/TvKzZ89ITq02wuefu3LlCsmNGjUi+f79+yTzMRz8/VqyZAnJAwYMIJnXiJ48eULyjBkzSOb1goz8PytRogTJfG45UwoXLkzy48eP092GtMoNnyOmxukULVqU5Fq1apHMx/2cPXuWZP45wGtQfGxZfHy88QZrQGpEQggh8hTpiIQQQmhKOiIhhBCaytNzzfExIvz6PK8JvXjxguSVK1eSzMcZ8XnN+OvVr1+fZFNzseV227ZtI7lGjRok82vv/v7+JEdGRpLMx3FNmDCBZD6eBjCcx4v/ziZOnEgyX9uF1w25evXqkVymTBmS+TFw/ByIiIgguW3btiTzcUaA4fgpXlvj52VSUhLJe/fuJdnLy8tIi989puoffn5+JI8ePZpkXqPjv8MePXqQzGtA2VHH5WMmY2Jisvw1MkO+EQkhhNCUdERCCCE0JR2REEIITeXpcURffPEFydOmTTO6PV+nZdKkSVneJq1l5zgiPp4lKirK6Pam5pbLiAULFpA8ePBgknn9pGnTpiTzsWJ8bjk+L1t6/1+cOnWKZD4Gha9HxMdOAcDcuXNJ/uGHH0iuU6cOyQULFiT50KFDRtuY1+eaM4W3cdGiRSR37tyZZP7+njlzhmS+xhUfT5cd7zcf08jripl9TRlHJIQQIk+RjkgIIYSmpCMSQgihKemIhBBCaCpPDWj19vYmmd9swIuQvFgeFBSUPQ17R1haWqZre75Am6mbFfiNBHwBMsDw5gS+T178/+2334y+Jh/4x9vM8YX0+ABbfnMCHyD7/vvvk9y/f3+D1+A3J/DF9NauXUvyvHnzjLTYcGD3u44PxO7evTvJ/IYSPvEt/3lW3yiQFi9fvsz218hK8o1ICCGEpqQjEkIIoSnpiIQQQmgqT9WI+GSCFhYWJPMaw4oVK7K7Se8UvgjcrFmzSM6Xj/7dM2LEiHTtv1KlSuluE68D8oF+I0eOJJlPXDts2DCSP/nkE5J5zYjvj9eg+DnHF08sVqwYyXzhvNRUqVLF6HNcXFxI5osF8vyu4/WVGzdukPz8+XOS+XmdnJxMMj/nhCH5RiSEEEJT0hEJIYTQlHREQgghNJWra0RVq1YlmS8qxm3atIlkPgGmyBz+/n///fck82vtpvDxFnxhQ17/AYCBAweSvHjxYpL5Ynu9e/cmmY8BcXZ2JrlLly5GWgx07NiR5D179pAcEBBA8u3bt0nm45B4ewCgSZMmJB88eJBkfsydOnUy0mJhiru7O8l83NadO3eMPp/XnExN/Po2xhnlNPKNSAghhKakIxJCCKEp6YiEEEJoKlfXiPj1+EKFCpGclJREMr+W3qBBA6P7a9iwIcn82m1CQgLJmzdvJnnNmjUkP336FHnZzZs3SeY1oa1bt5K8a9cukvkCZHwOr6VLl5psw/bt20nmC+XxsWT8nDh69CjJ//73v0kuUKAAyaNGjSKZ15SaN29OMp+7bsqUKSTzMSi8Dppam/l5yccNcXw8Fn9P3jV87BhfvJD75ZdfSOZzyfFxXI8ePTL6ehz/ffLt+TimvEC+EQkhhNCUdERCCCE0JR2REEIITelUGm9aN3XvuxZ27NhBMl+PiF+b/fPPP0nma79w/JjTe39/dHQ0yY0aNSI5MjIyXfvLCm9zjEKJEiVI5jU3
XlPi1+YHDRpEMh+XlBZ8PZ8hQ4aQzMeIbNu2jeR27dql+zVfN3nyZKP5wIEDJPMxKqmtuWRK0aJFSX748CHJWo5jyYmfI7zmwueG4zUaPi6Ib88/Z3gNib8HvLbNn89fn9dKJ06ciLctq88R+UYkhBBCU9IRCSGE0JR0REIIITSVq2pEfH0hfj9/7dq1STZ1aH///TfJS5YsIfn48eNG98fXgeHznBUpUoTkqKgokt3c3AzalN1jjd7m9X8+TojX8Dhes+vatSvJ06ZNI7l9+/YG++BjlQoWLEjy3r17jeYOHTqQzH/H/v7+JPM64OPHj0nm45r4GB++hhY/h/h6RwAwYcIEkn19fQ22SY8jR46QzMdOZaWc+DmSmJhIsqlxPvxzg2eOjw2zsrIimdeYeA2Kt4fXEVP7ffFjympSIxJCCJGnSEckhBBCU9IRCSGE0FSummvu888/J7lmzZok8+uWpq5j8nFI/H78+Pj4dLXvu+++I/n8+fMk8/rAsGHDDPYxffr0dL1mTsbHAe3cuZPkQ4cOkczn6Priiy9I7tWrF8mrVq0yeE0+dmn27Nkk8/WH0rtGUnBwsNGf8/kLOV4XfPHiBcm8PpDaOTxnzhyS+THwWhofpzJ37lyS+Xx82Vkjygl4TcZU3YqPA+I1n+HDh5PM64x//PEHyXxORV6HXLFihdH28nOIjxsDsr9GlNXkG5EQQghNSUckhBBCU9IRCSGE0FSuGkcUHh5Ocp06dUieP38+yU5OTiS3adOGZHNzc5IXLlxIMp+XLL0+/fRTkvm1+GPHjhk8x1SNIbNy8jxioaGhJJsaH5Pa6/Hr+Xfv3iWZ1+X4uCM+91vTpk2N/jwwMJBkXvfiSpUqRfLt27dJ5nPj8bWDAMM1kfjcZAEBAUbbwFlbW5OcnWPZcsLnCK+ZvffeeyTzmg2v2+3evdvoz3mdkp+DHK8BlStXjuQrV66QzNdZ27Bhg8E+eT01q8k4IiGEEHmKdERCCCE0JR2REEIITeWqcUSm8Lnh1qxZQ3LPnj1JrlWrFsmnT5/O0vbwMSKcqTmqcrsKFSqQzK9187nlzp49S7KXlxfJfF44PpcdYFiDGDlyJMkeHh5Gt+fzF1asWJFk/jvr06cPyXyNKT7uiK+pxMeU8LnlYmNjwTk6Oho8ZgyvtV2+fJnk1GqVeRk/j+rWrUvy/v37Se7evTvJfBwRd+fOnXS1h+/v6tWrJPNzlI8R4u0DgEuXLpHMz0Ne19KafCMSQgihKemIhBBCaEo6IiGEEJrKUzUiU1avXm00ZxYf38HnoOJCQkKy9PW1dv36dZJv3bpF8s8//0wynzNt7dq1JPPxMa6uriSntlbPs2fPSOZrv3CjRo0imY8b4tfafXx8SDZVL+D1hoMHD5LMxyVxfIxJRvA5E3lNyNPTM9OvkZvwdcv4ObNv3z6S+Xn4tvFzgI+VS63WzOe348/JaeQbkRBCCE1JRySEEEJT0hEJIYTQVK6qEfH76U3lt61v374kV6tWjWS+PhEfz5DbrV+/nmS+vpMptra2JPM5u7p27UpyUFCQwT5M1YT4ayQkJJBcvXp1kvn6UOPGjTO6f75m1syZM41uz+fX69evH8kxMTEGz+HzfPHxVK1atSKZz6dnYWFBMq/d5XW8plKwYEGS+Rgb/rmSnfM1psbBwYFkPjdgap97fP65t93m9JJvREIIITQlHZEQQghNSUckhBBCU7mqRnTy5EmS+XgAPpfc4cOHSeZrv6SXjY0NyXzcS4sWLUjm12X5OjGm5qLLbcaOHUsyrxFduHCBZL7uy+PHj0nmc9WlVhMyhb8GH7t04sQJkvl6RXPnzk3X69nZ2ZHMj7ly5cpGn8/XM+LrFwHAgAEDSObrcHH8POTbjxgxwujzc7PU6id8HTK+vs+0adNI5nO/8XFGfBxSZhUuXJjkH374wej2N2/ezNLX14J8IxJCCKEp6YiEEEJoSjoiIYQQmtKpNN5grvUYHcBwnM6y
ZctI5m28f/8+yXytme3bt5NcpkwZo6//8ccfk1y2bFmS+fgDXhOaMWOG0f2/Ddk5noCve7NixQqj2/M5tPj799FHH5G8ZcsWkp2dnQ32ycfI8Dm3+DghvgYSv94eGBhIMj8mfs7xNZb4/Htct27dSObjP9Ji8eLFJD969Ihkf3//dO0vO8+RnPA5Eh0dTbKTk1Om9sfnceO1Tj5Oib+/fI5KU+/R4MGDSd68ebPBNvyzL6tl9Tki34iEEEJoSjoiIYQQmpKOSAghhKZyVY2Iz7F06tQpkvl6NbxekF6m5pji67wMHTqU5HXr1mXq9bODltf/TY3fyAr8+D744AOS+VgnPiaH4/O0xcXFkcxrQo0bNyb54cOHJPP1i4oVK0YyHzt1+fJlgzbxcSv8/0XJkiVJdnFxIZmvR8Tl9RqRmRkdPnn8+HGSa9So8TabY+Dp06ck9+jRg2Q+t6CpNbGyg9SIhBBC5CnSEQkhhNCUdERCCCE0JR2REEIITeWqmxU4PqCRL0TXvXt3ktu2bUtyoUKFjO6fD3wLCQkheenSpSQnJiYa3V9O8DYL0XzALy/084F/nKenJ8l8ATe+yF1q++Rt4oMPTQ1o7dixo9E2cnzS1Hnz5pFcvHhxkv38/EjmA2jTonXr1iTv2LEjXc/nN/Vk52S8OeFzJF8++vc3/z/BfwdNmzYluX79+iSbOia+uOKQIUNI5jeT8JsVciK5WUEIIUSeIh2REEIITUlHJIQQQlO5ukYk0i87a0S9evUiefXq1SRXrFiRZH6tnU9iyxdCtLS0JJkvegcYHl9mz1teR0xISCC5TZs2JPOJdPmg61q1apHM62YNGzYk+cqVKybbyAdyR0ZGkswnfuWT+/IFI/mA2KyUGz9HeC3SVG0zq5kaWK8FqREJIYTIU6QjEkIIoSnpiIQQQmjKzPQmQqQNrwlxfIGwixcvktysWTOSa9euTXJ4eDjJXl5eBq9hqgZRvnx5kq9evWp0e14T4kyN+eA1IT5mx97enmR+7Z3XyQCgTp06JNvZ2ZHMJ/U0teBjqVKljLbhXfe2a0Lcu/D7kG9EQgghNCUdkRBCCE1JRySEEEJTUiMSWYbPxXf69GmSx48fTzKfC5Av2MbHFfFxR1999ZVBG4oWLUryjz/+SLKpmpApvMZ04MABknmNytfXl+TQ0FCj+588eTLJfP5EAGjRogXJfGG33r17k3zhwgWSea3u3LlzRtskRHaTb0RCCCE0JR2REEIITUlHJIQQQlMy19w75m2uR5RefG2eRYsWkXz9+nWSP/jgA4N9VKpUiWQ+P11ERATJRYoUIfnPP/802kZegypdujTJvN7CxxHxuedMWb9+vcFj3bp1I5nX2urVq0fyiBEjSObrePF1t3LyOSJyBplrTgghRJ4iHZEQQghNSUckhBBCU1Ijesdk5/V/BwcHkvlaOzY2NiS/evWK5GfPnpHs4uJCclRUlMk2eHt7k/z777+THBsbS3KHDh1INjXO5+bNmySXLVuWZF4jcnd3N7q/jOD/F9u3b08yH1fEa0AVKlQg+caNGyQnJSVlsoVvJp8jeYPUiIQQQuQp0hEJIYTQlHREQgghNJXmGpEQQgiRHeQbkRBCCE1JRySEEEJT0hEJIYTQlHREQgghNCUdkRBCCE1JRySEEEJT0hEJIYTQlHREQgghNCUdkRBCCE39P8vrKUzI6ftbAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 500x300 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52/10: 100%|██████████| 469/469 [00:05<00:00, 81.99it/s, loss=4.04e-8]\n",
      "Epoch 53/10: 100%|██████████| 469/469 [00:05<00:00, 81.78it/s, loss=4.82e-7]\n",
      "Epoch 54/10: 100%|██████████| 469/469 [00:05<00:00, 79.46it/s, loss=8.09e-7]\n",
      "Epoch 55/10: 100%|██████████| 469/469 [00:06<00:00, 77.16it/s, loss=6.34e-11]\n",
      "Epoch 56/10: 100%|██████████| 469/469 [00:05<00:00, 80.08it/s, loss=7.52e-7]\n",
      "Epoch 57/10: 100%|██████████| 469/469 [00:05<00:00, 81.62it/s, loss=3.01e-6]\n",
      "Epoch 58/10: 100%|██████████| 469/469 [00:05<00:00, 81.20it/s, loss=9.84e-8]\n",
      "Epoch 59/10: 100%|██████████| 469/469 [00:05<00:00, 80.57it/s, loss=6.66e-8]\n",
      "Epoch 60/10: 100%|██████████| 469/469 [00:05<00:00, 81.68it/s, loss=1.84e-7]\n",
      "Epoch 61/10: 100%|██████████| 469/469 [00:05<00:00, 81.12it/s, loss=1.8e-6]\n",
      "Epoch 62/10: 100%|██████████| 469/469 [00:05<00:00, 80.90it/s, loss=1.7e-7]\n",
      "Epoch 63/10: 100%|██████████| 469/469 [00:05<00:00, 80.83it/s, loss=3.81e-6]\n",
      "Epoch 64/10: 100%|██████████| 469/469 [00:05<00:00, 81.15it/s, loss=2.81e-6]\n",
      "Epoch 65/10: 100%|██████████| 469/469 [00:05<00:00, 79.97it/s, loss=3.82e-8]\n",
      "Epoch 66/10: 100%|██████████| 469/469 [00:05<00:00, 83.88it/s, loss=8.48e-9]\n",
      "Epoch 67/10: 100%|██████████| 469/469 [00:05<00:00, 82.09it/s, loss=1.8e-8]\n",
      "Epoch 68/10: 100%|██████████| 469/469 [00:05<00:00, 81.96it/s, loss=1.89e-7]\n",
      "Epoch 69/10: 100%|██████████| 469/469 [00:05<00:00, 81.79it/s, loss=1.4e-7]\n",
      "Epoch 70/10: 100%|██████████| 469/469 [00:05<00:00, 82.19it/s, loss=7.22e-7]\n",
      "Epoch 71/10: 100%|██████████| 469/469 [00:05<00:00, 78.23it/s, loss=4.57e-7]\n",
      "Epoch 72/10: 100%|██████████| 469/469 [00:05<00:00, 80.88it/s, loss=9.79e-10]\n",
      "Epoch 73/10: 100%|██████████| 469/469 [00:05<00:00, 80.72it/s, loss=7.69e-7]\n",
      "Epoch 74/10: 100%|██████████| 469/469 [00:05<00:00, 80.09it/s, loss=2.31e-7]\n",
      "Epoch 75/10: 100%|██████████| 469/469 [00:06<00:00, 77.33it/s, loss=6.08e-7]\n",
      "Epoch 76/10: 100%|██████████| 469/469 [00:05<00:00, 80.44it/s, loss=2.81e-8]\n",
      "Epoch 77/10: 100%|██████████| 469/469 [00:05<00:00, 80.04it/s, loss=1.06e-10]\n",
      "Epoch 78/10: 100%|██████████| 469/469 [00:05<00:00, 84.65it/s, loss=7.97e-7]\n",
      "Epoch 79/10: 100%|██████████| 469/469 [00:05<00:00, 84.55it/s, loss=1.37e-7]\n",
      "Epoch 80/10: 100%|██████████| 469/469 [00:05<00:00, 80.82it/s, loss=1.38e-7]\n",
      "Epoch 81/10: 100%|██████████| 469/469 [00:05<00:00, 80.08it/s, loss=1.68e-7]\n",
      "Epoch 82/10: 100%|██████████| 469/469 [00:05<00:00, 83.04it/s, loss=1.48e-6]\n",
      "Epoch 83/10: 100%|██████████| 469/469 [00:05<00:00, 82.26it/s, loss=3.03e-7]\n",
      "Epoch 84/10: 100%|██████████| 469/469 [00:05<00:00, 81.32it/s, loss=4.42e-7]\n",
      "Epoch 85/10: 100%|██████████| 469/469 [00:05<00:00, 79.46it/s, loss=6.59e-7]\n",
      "Epoch 86/10: 100%|██████████| 469/469 [00:05<00:00, 78.99it/s, loss=3.5e-8]\n",
      "Epoch 87/10: 100%|██████████| 469/469 [00:05<00:00, 80.07it/s, loss=1.05e-7]\n",
      "Epoch 88/10: 100%|██████████| 469/469 [00:05<00:00, 81.23it/s, loss=3.55e-7]\n",
      "Epoch 89/10: 100%|██████████| 469/469 [00:05<00:00, 80.50it/s, loss=3.08e-7]\n",
      "Epoch 90/10: 100%|██████████| 469/469 [00:05<00:00, 79.49it/s, loss=6.58e-7]\n",
      "Epoch 91/10: 100%|██████████| 469/469 [00:05<00:00, 82.10it/s, loss=1.98e-8]\n",
      "Epoch 92/10: 100%|██████████| 469/469 [00:05<00:00, 83.53it/s, loss=5.21e-8]\n",
      "Epoch 93/10: 100%|██████████| 469/469 [00:05<00:00, 78.87it/s, loss=8.27e-8]\n",
      "Epoch 94/10: 100%|██████████| 469/469 [00:05<00:00, 79.34it/s, loss=1.14e-7]\n",
      "Epoch 95/10: 100%|██████████| 469/469 [00:05<00:00, 79.26it/s, loss=1.43e-6]\n",
      "Epoch 96/10: 100%|██████████| 469/469 [00:06<00:00, 76.51it/s, loss=2.21e-7]\n",
      "Epoch 97/10: 100%|██████████| 469/469 [00:05<00:00, 78.39it/s, loss=7.14e-9]\n",
      "Epoch 98/10: 100%|██████████| 469/469 [00:05<00:00, 81.12it/s, loss=3.72e-6]\n",
      "Epoch 99/10: 100%|██████████| 469/469 [00:05<00:00, 80.34it/s, loss=5.28e-7]\n",
      "Epoch 100/10: 100%|██████████| 469/469 [00:05<00:00, 78.23it/s, loss=4.76e-8]\n",
      "Epoch 101/10: 100%|██████████| 469/469 [00:05<00:00, 78.26it/s, loss=1.41e-7]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAACaCAYAAADrVUwbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnv0lEQVR4nO3deVxO6fsH8M9DpRShUNk1M2XPvpfskX3flx8ZW2NJfe00RsoyGyYGMRgaZjAa62CsWTJkxjqGDFFMZAlN5f79Md+er+t60tPqkOv9evXH5+k85znndDp3nevc961TSikIIYQQGsmn9QYIIYR4t0lDJIQQQlPSEAkhhNCUNERCCCE0JQ2REEIITUlDJIQQQlPSEAkhhNCUNERCCCE0JQ2REEIITWW7Ifriiy+g0+lQtWrVLK/j9u3bmDVrFs6ePZvdzcmQZs2aoVmzZhlaLjv7Jf61evVq6HQ6mJub48aNGwbfz85xzujPMqdERUVBp9NhwYIFr+0z3wapP+PUL3Nzc9jZ2cHd3R0BAQG4e/eu1ptIDB48GOXLl3/tn6vT6TBr1qx0l3kXz7FsN0SrVq0CAJw/fx4nTpzI0jpu376N2bNnv7aGSGgjMTER06ZNy9F1Ll26FEuXLs3RdYqsCwkJQXh4OPbu3YslS5bAxcUFgYGBqFSpEn7++WetN09v+vTp2LJli9abIf4rWw1RREQEIiMj0b59ewDAypUrc2SjRN7Utm1bfPvtt4iMjMyxdVauXBmVK1fOsfWJ7KlatSoaNGiApk2bolu3bvj0009x7tw5WFpaomvXroiNjdV6EwEAjo6OqFmzptabIf4rWw1RasMzb948NGrUCBs3bsTTp08NlouOjoaXlxfKlCkDMzMzODg4oHv37oiNjcUvv/yCunXrAgCGDBmi/9c+9d/XV916Setf69mzZ6N+/fooVqwYChcujFq1amHlypXIyXFddTodxowZg5CQEDg5OcHCwgJ16tTB8ePHoZTC/PnzUaFCBVhZWaF58+a4evUqef/evXvRqVMnlC5dGubm5njvvfcwYsQI/P333waftW3bNlSvXh0FChRAxYoV8fnnn2PWrFnQ6XRkOaUUli5dChcXF1hYWKBo0aLo3r07rl27lmP7nRN8fX1hY2MDPz8/o8s+f/4ckydPRoUKFWBmZoZSpUph9OjRiI+PJ8uldX589dVXqFGjBqysrFCoUCE4OztjypQpAP697WFiYoKAgACDzzx06BB0Oh02bdqUqf1KvS21f/9+DB8+HDY2NihcuDAGDhyIhIQExMTEoGfPnihSpAjs7e3h4+ODpKQkso6MnruJiYmYOHEi7OzsULBgQbi6uuL06dMoX748Bg8eTJaNiYnBiBEjULp0aZiZmaFChQqYPXs2kpOTM7V/2VW2bFksXLgQjx8/xrJly8j3IiIi0LFjRxQrVgzm5uaoWbMmvvvuO7JM6vE9cOAARo4cCVtbW9jY2KBr1664ffs2WfbFixcICgqCs7MzChQogBIlSmDgwIG4desWWS6t68emTZtQv359WFtbo2DBgqhYsSKGDh1Klnn06BF8fHzIeTlu3DgkJCQYLJd6LlhZWaFt27a4cuVKVg4fOQZ59RwzydJRAfDs2TNs2LABdevWRdWqVTF06FAMGzYMmzZtwqBBg/TLRUdHo27dukhKSsKUKVNQvXp1xMXFYffu3Xjw4AFq1aqFkJAQDBkyBNOmTdP/d1W6dOlMb1NUVBRGjBiBsmXLAgCOHz+OsWPHIjo6GjNmzMjqrhoICwvDmTNnMG/ePOh0Ovj5+aF9+/YYNGgQrl27hsWLF+Phw4eYMGECunXrhrNnz+objz///BMNGzbEsGHDYG1tjaioKCxatAhNmjTBb7/9BlNTUwDArl270LVrV7i6uiI0NBTJyclYsGBBmn9RjhgxAqtXr4a3tzcCAwNx//59+Pv7o1GjRoiMjETJkiVzbN+zo1Ch
Qpg2bRo++ugj7N+/H82bN09zOaUUOnfujH379mHy5Mlo2rQpzp07h5kzZyI8PBzh4eEoUKBAmu/duHEjRo0ahbFjx2LBggXIly8frl69igsXLgAAypcvj44dOyI4OBi+vr7Inz+//r2LFy+Gg4MDunTpkqX9GzZsGLp27YqNGzfizJkzmDJlCpKTk3H58mV07doVXl5e+PnnnxEYGAgHBwdMmDBB/96MnrtDhgxBaGgofH190bx5c1y4cAFdunTBo0ePyLbExMSgXr16yJcvH2bMmAFHR0eEh4djzpw5iIqKQkhISJb2MavatWuH/Pnz49ChQ/rXDhw4gLZt26J+/foIDg6GtbU1Nm7ciF69euHp06cGF71hw4ahffv2+Pbbb3Hz5k1MmjQJ/fv3x/79+/XLjBw5EsuXL8eYMWPg6emJqKgoTJ8+Hb/88gt+/fVX2Nraprl94eHh6NWrF3r16oVZs2bp65kvr/vp06dwc3PDrVu39Ney8+fPY8aMGfjtt9/w888/Q6fT6c/fY8eOYcaMGahbty6OHj0KDw+PbB/HPHuOqSz65ptvFAAVHByslFLq8ePHysrKSjVt2pQsN3ToUGVqaqouXLjwynWdOnVKAVAhISEG33Nzc1Nubm4Grw8aNEiVK1fuletMSUlRSUlJyt/fX9nY2KgXL14YXWdan12lShXyGgBlZ2ennjx5on9t69atCoBycXEhn/PZZ58pAOrcuXNprv/FixcqKSlJ3bhxQwFQ27Zt03+vbt26qkyZMioxMVH/2uPHj5WNjY16+ccWHh6uAKiFCxeSdd+8eVNZWFgoX19fo/uZ20JCQhQAderUKZWYmKgqVqyo6tSpoz9W/Djv2rVLAVBBQUFkPaGhoQqAWr58uf41/rMcM2aMKlKkSLrbc+DAAQVAbdmyRf9adHS0MjExUbNnz073vdevX1cA1Pz58w32b+zYsWTZzp07KwBq0aJF5HUXFxdVq1atV37Gq87d8+fPKwDKz8+PLL9hwwYFQA0aNEj/2ogRI5SVlZW6ceMGWXbBggUKgDp//ny6+5lZL/+MX6VkyZKqUqVK+uzs7Kxq1qypkpKSyHKenp7K3t5epaSkkHWPGjWKLBcUFKQAqDt37iillLp48WKay504cUIBUFOmTNG/xq8fqcclPj7+ldsfEBCg8uXLZ7CPmzdvVgDUjh07lFJK7dy5UwFQn3/+OVnuk08+UQDUzJkzX/kZSr2b51iWb82tXLkSFhYW6N27NwDAysoKPXr0wOHDh/HHH3/ol9u5cyfc3d1RqVKlrH5Uhu3fvx8tW7aEtbU18ufPD1NTU8yYMQNxcXE5+tSOu7s7LC0t9Tl13zw8PMhts9TXX35S7O7du/jwww9RpkwZmJiYwNTUFOXKlQMAXLx4EQCQkJCAiIgIdO7cGWZmZvr3WllZoUOHDmRbwsLCoNPp0L9/fyQnJ+u/7OzsUKNGDfzyyy85tt85wczMDHPmzEFERITBLZhUqX+F8r+Ie/ToAUtLS+zbt++V669Xrx7i4+PRp08fbNu2Lc1bns2aNUONGjWwZMkS/WvBwcHQ6XTw8vLKwl79y9PTk+TUn3/qf/kvv86fHszIuXvw4EEAQM+ePcl7u3fvDhMTenMjLCwM7u7ucHBwIOdF6l/lqet6ndRLt4CuXr2KS5cuoV+/fgBAtrFdu3a4c+cOLl++TN7fsWNHkqtXrw7gf79fBw4cAGB43tSrVw+VKlVK97xJLQ/07NkT3333HaKjow2WCQsLQ9WqVeHi4kK2t02bNtDpdPrftdTtSN23VH379n3l52dUXj3HstQQXb16FYcOHUL79u2hlEJ8fDzi4+PRvXt3AP97kg4A7t27l6XbbJl18uRJtG7dGgDw9ddf4+jRozh16hSmTp0K4N9biTmlWLFiJKc2Fq96/fnz5wD+vX/dunVr/PDDD/D19cW+fftw8uRJHD9+nGzjgwcPoJRK85Yafy02Nla/rKmpKfk6fvx4
mhdirfXu3Ru1atXC1KlTDe5jA0BcXBxMTExQvHhx8rpOp4OdnR3i4uJeue4BAwZg1apVuHHjBrp164YSJUqgfv362Lt3L1nO29sb+/btw+XLl5GUlISvv/4a3bt3h52dXZb3KzPnReo5AWT83E3db34OmJiYwMbGhrwWGxuL7du3G5wTVapUAYDXfl4kJCQgLi4ODg4O+u0DAB8fH4NtHDVqVJrbyPcx9fYsPz729vYGn+/g4JDueePq6oqtW7ciOTkZAwcOROnSpVG1alVs2LBBv0xsbCzOnTtnsL2FChWCUkq/vannL9/e7JxbqfLqOZalGtGqVauglMLmzZuxefNmg++vWbMGc+bMQf78+VG8eHGDQmFmmJub4+HDhwav853cuHEjTE1NERYWBnNzc/3rW7duzfJn57Tff/8dkZGRWL16Namj8QcaihYtCp1Ol2Y9KCYmhmRbW1vodDocPnw4zbrJq2opWtLpdAgMDESrVq2wfPlyg+/b2NggOTkZ9+7dI42RUgoxMTH6v15fZciQIRgyZAgSEhJw6NAhzJw5E56enrhy5Yr+v8++ffvCz88PS5YsQYMGDRATE4PRo0fn7I5mUEbP3dQLQWxsLEqVKqV/PTk52eAia2tri+rVq+OTTz5J8zNTG4TX5aeffkJKSor+wZLUWs3kyZPRtWvXNN/j5OSUqc9IPT537twx+OP39u3br6wPperUqRM6deqExMREHD9+HAEBAejbty/Kly+Phg0bwtbWFhYWFuQP7Zelrj/1/I2LiyMXb/67+zq96edYphuilJQUrFmzBo6OjlixYoXB98PCwrBw4ULs3LkTnp6e8PDwwNq1a3H58uVXnlj8L5uXlS9fHps2bUJiYqJ+ubi4OBw7dgyFCxfWL6fT6WBiYkKKz8+ePcPatWszu4u5JvW2HW8c+JNElpaWqFOnDrZu3YoFCxbo/+p58uQJwsLCyLKenp6YN28eoqOjDf6dfpO1bNkSrVq1gr+/P8qUKUO+16JFCwQFBWHdunUYP368/vXvv/8eCQkJaNGiRYY+w9LSEh4eHvjnn3/QuXNnnD9/Xt8QmZubw8vLC4sXL8axY8fg4uKCxo0b59wOZkJGz11XV1cAQGhoKGrVqqV/ffPmzQZPKXl6emLHjh1wdHRE0aJFc3Hrjfvrr7/g4+MDa2trjBgxAsC/jcz777+PyMhIzJ07N0c+J/Xhl3Xr1pE/Vk6dOoWLFy/q//o3pkCBAnBzc0ORIkWwe/dunDlzBg0bNoSnpyfmzp0LGxsbVKhQ4ZXvd3d3R1BQENavXw9vb2/9699++20W9yz73vRzLNMN0c6dO3H79m0EBgam+Vh11apVsXjxYqxcuRKenp7w9/fHzp074erqiilTpqBatWqIj4/Hrl27MGHCBDg7O8PR0REWFhZYv349KlWqBCsrKzg4OMDBwQEDBgzAsmXL0L9/fwwfPhxxcXEICgoijRDw7z3SRYsWoW/fvvDy8kJcXBwWLFjwRv1HkLqv//nPf6CUQrFixbB9+3aD20YA4O/vj/bt26NNmzb46KOPkJKSgvnz58PKygr379/XL9e4cWN4eXlhyJAhiIiIgKurKywtLXHnzh0cOXIE1apVw8iRI1/nbmZYYGAgateujbt37+r/nQeAVq1aoU2bNvDz88OjR4/QuHFj/VNzNWvWxIABA165zuHDh8PCwgKNGzeGvb09YmJiEBAQAGtra4P/pEaNGoWgoCCcPn06zT+qXpeMnrtVqlRBnz59sHDhQuTPnx/NmzfH+fPnsXDhQlhbWyNfvv/daff398fevXvRqFEjeHt7w8nJCc+fP0dUVBR27NiB4ODgXLll/vvvv+trBXfv3sXhw4cREhKC/PnzY8uWLeQ/3GXLlsHDwwNt2rTB4MGDUapUKdy/fx8XL17Er7/+munH6J2cnODl5YUvv/wS+fLlg4eHh/6puTJlypA/argZM2bg1q1baNGiBUqXLo34+Hh8/vnn
MDU1hZubGwBg3Lhx+P777+Hq6orx48ejevXqePHiBf766y/s2bMHEydORP369dG6dWu4urrC19cXCQkJqFOnDo4eParpH8Vv/DmW4cca/qtz587KzMxM3b1795XL9O7dW5mYmKiYmBil1L9PcA0dOlTZ2dkpU1NT5eDgoHr27KliY2P179mwYYNydnZWpqamBk+WrFmzRlWqVEmZm5urypUrq9DQ0DSfmlu1apVycnJSBQoUUBUrVlQBAQFq5cqVCoC6fv26frnsPjU3evRo8lpaT7ko9b+nszZt2qR/7cKFC6pVq1aqUKFCqmjRoqpHjx7qr7/+SvNpmi1btqhq1aopMzMzVbZsWTVv3jzl7e2tihYtarCtq1atUvXr11eWlpbKwsJCOTo6qoEDB6qIiAij+5nb0nuiqm/fvgqAwXF+9uyZ8vPzU+XKlVOmpqbK3t5ejRw5Uj148IAsx3+Wa9asUe7u7qpkyZLKzMxMf6696snFZs2aqWLFiqmnT59maF/Se6KJ79/MmTMVAHXv3j3y+qBBg5SlpSV5LaPn7vPnz9WECRNUiRIllLm5uWrQoIEKDw9X1tbWavz48WSd9+7dU97e3qpChQrK1NRUFStWTNWuXVtNnTqVPPWZE1KPQeqXmZmZKlGihHJzc1Nz58595fUiMjJS9ezZU5UoUUKZmpoqOzs71bx5c/3TuC+vmx/f1N+vAwcO6F9LSUlRgYGB6oMPPlCmpqbK1tZW9e/fX928eZO8l18/wsLClIeHhypVqpR+29u1a6cOHz5M3vfkyRM1bdo05eTkpMzMzJS1tbWqVq2aGj9+vP56p5RS8fHxaujQoapIkSKqYMGCqlWrVurSpUvZfmour55jOqVysLenyFVJSUlwcXFBqVKlsGfPHq0356139+5dlCtXDmPHjkVQUJDWm5Nlx44dQ+PGjbF+/foceTJLCC63zzFpiN5g//d//4dWrVrpbzEFBwfj4MGD2LNnD1q2bKn15r21bt26hWvXrmH+/PnYv38/rly5Qgqzb7K9e/ciPDwctWvXhoWFBSIjIzFv3jxYW1vj3LlzpBAtRFZocY5leWQFkfseP34MHx8f3Lt3D6ampqhVqxZ27NghjVA2rVixAv7+/ihfvjzWr1//1jRCAFC4cGHs2bMHn332GR4/fgxbW1t4eHggICBAGiGRI7Q4x+Q/IiGEEJqSifGEEEJoShoiIYQQmpKGSAghhKakIRJCCKGpDD81xydjE2+n3Hw2hZ8j9evXJ5lPJZ+YmEgyf3rN2dmZ5NRBdVONGzfO6Db16tWL5NDQ0HSXL1iwIMlpTfT4Mn48jf2e8FEuLl26RHLqyM2p0pobKbNTXNepU4fkiIiIdJd/neeIeDvl9Dki/xEJIYTQlDREQgghNCUNkRBCCE1luEOr3NvNG3Lz/r+LiwvJkZGRufZZAPRTOrysSJEiJPOh918e5Tsr+PFLncQt1csz6gKGU5vwMQKjoqIyvQ1t2rQheffu3Zlex8v4zKWp0ynkBrmO5A1SIxJCCJGnSEMkhBBCU9IQCSGE0JTUiN4xuVkj2rZtG8mdO3cmOSQkhGReXxk9ejTJCxYsSPfzPDw8DF67cOECyTdu3Eh3HVxwcDDJH374YbrLG+tHxKcfL1SoEMm7du0iefLkySQHBAQYfGbv3r1Jvn37NsmHDh0i2cnJieSHDx+SHBsbS/KLFy8MPjOnyHUkb5AakRBCiDxFGiIhhBCakoZICCGEpqRG9I7JzRrR0qVLSeY1H0dHR5KnT59Osr+/P8nXrl3L9DYUL16c5Hv37pHcoEEDko8fP57u+jp16kQyr4MZU7lyZZJ5DcvY9vKx6QBg586dJGelL9LLeC1u4sSJ2Vpfel73dSQjn2fsd4L3DeM1NAsLC5ILFChAclJSEsm8Rvc2khqREEKIPEUaIiGEEJqShkgIIYSmpEb0jnmdc8107NiRZD6G2aRJk0jm99JzwpgxY0hOSEggmfdtsrOz
IzkmJobkJk2akGxvb0/ypk2bSJ49ezbJfI4mXgfjfat4vyMAePz4scFrLxs/fjzJZ8+eJZnXMHhfpjd5PiJTU1OSy5QpQ3KFChVI7tevn8E6vL29SV65ciXJKSkpJN+9e5fktWvXksz7cd25c8fgM/MaqREJIYTIU6QhEkIIoSlpiIQQQmjKROsNeBVLS0uD19K6X54ePhdMXnh+/03Gfz783vq4ceNI7tChA8nbt28nmdcybG1tSU5rHDjeN+nIkSMk16lTx+A9L2vatCnJFy9eJLlkyZIklypVimRen5k5cybJCxcuJNlYn50uXboYvPbNN9+k+54ff/yR5Pz585N85coVkvfu3Zvu+t4k/LrQt29fkuvWrUsy7zcGGI6BWLRoUZJ5jYj3G4qLiyN5yZIlJOfLR/++53Uxvn4h/xEJIYTQmDREQgghNCUNkRBCCE3lWj8i/rw/H3OrW7duJJcrVy7d5QGgVq1aJBvb9OvXr5N87Ngxkvfv30/ymjVr0l1fXpCbfUT4mFzZ7RfEt5WPiebj42Pwnvnz55PM+yoZExgYSLKfnx/J/PfA2dmZ5EuXLpHs5eVF8vLly0nmdTPez6hPnz4G28jrSrzu5ObmRvLBgwdJrlevHsknT54k+U3qR2RtbU0yr8nx+gyvI6Z1HeHbcP/+fZJtbGxIDg0NJbl9+/Yk8/OczyHFz9vcPL6vYm5uTvLz58+ztT7pRySEECJPkYZICCGEpqQhEkIIoalcqxEFBQWRPGHChEy9PyPbkN37lPzeLh+XjI9BlRdoef+f31v/+OOPSS5btizJ/F79xo0bSe7du7fBZ/C+Yvw9I0aMIHnFihUkV6lSheSGDRuSbGwfeY3nxIkT6S7/OvA604YNG0jm/YySk5NzbVsyex3hfdN4Hxw+XxSvfzVr1sxgnbxvGa9nf/TRRyQnJiaSPGzYMJL5tY7PDxUdHU0y7z/3OvA+eXyfMktqREIIIfIUaYiEEEJoShoiIYQQmpKGSAghhKZy7WEF3jm0UaNGJPMJrDKCF9j4JGR8krIWLVqkuz6+T5GRkSTzDrR5QW4+rMA7d/LOhbxw/Mknn5AcERFBcu3atUnmReDy5csbbMPIkSNJ5g8rbN26lWQ+yRw/x3jHa44/zHD8+HGS3d3dST5w4ADJfADTgQMHkrx69WqDzxw8eDDJxh5GMKZ48eIk88Fqc1JmryMmJnRcZl505w8aFCtWjGT+AAwAnD59muSuXbuSzM8ZJycnknlHeP5QDf8d49chfg68jeRhBSGEEHmKNERCCCE0JQ2REEIITeVajYjjgxfye70ZwTf13r176a6T16W2bdtGMp9kS2pE2VOjRg2Sz507l631BQcHk8wnwktrokRe88msHj16kPzdd9+R/Nlnn5HMJ8LLLj4RXloTRK5bty7ddfCBPvlAoXwAzKNHj5LMJ37LSZm9jhjrxM6/zzOf1A4wPG94XfCff/4huWrVqiSHhISQzK8T/Pjx69DVq1cNtultIzUiIYQQeYo0REIIITQlDZEQQghNmRhfJGfwwShzQ6VKlUjm/QH4/fZ8+Wg7vGjRotzZsHfEH3/8kaPr4zUhbtCgQQavLV68ON1t8vX1Jfk///kPybGxsel+ZrVq1UjmNR0+ICbfh5s3b6a7fisrK5Jbt25tsAyvEbVt25ZkXpvj6+Q1oadPn6a7TVoyVovg389I7YJPRsgH3+UD1w4fPpxkXhP6888/SX7vvfdIvnbtmtFtetfJf0RCCCE0JQ2REEIITUlDJIQQQlOvrR9RbuBjkfFxxPjYcxzfJx8fH5I//fTTrG/cG+p1Tozn4eFB8s6dO0kuXLgwybz/Cx+rLiwsjGQLCwuDbXj27BnJfNy0VatWkdyqVSuSjfUd4+PbNWjQgOTQ0FCSjfV7mTFjBsm8b9xXX32V7vakhdc0+GfeuXOH5O3bt5Os5eSJrwMfv473q+J1woIFC5LMx0SsU6cOyc+fPyeZ1xH37dtH
Mp+g820g/YiEEELkKdIQCSGE0JQ0REIIITT1VteI+JxH/fr1y9T7+T7xccomTZpEMu+XlN1xzbSQm/f/+bw2f//9N8n857N+/XqSW7ZsSTLvD8PrPbzGBACPHj0i2Vg/k4ULF5Lcu3dvkvkYiXxOJH5/n9ct+Vw0vH7A+/SsWLGCZN4vCTCsS/E5kKpXr06yq6sryRcvXiQ5Ojo63e/npNy+jvD6T0pKitFt4GPP8Vozn0eLM1YH5GNc8uWLFClCMv894PvA+z+mNZ5ebpMakRBCiDxFGiIhhBCakoZICCGEpt7qGhHv8+Ht7U1yxYoVSeb3YqtUqUKysUPBxy07cuQIyXyumvPnz6e7Pi28zj4iHTp0IJkfv0uXLpHcvn17ksPDw0k+ffo0yRUqVDDYBl4/4Z+5a9cuknkfED8/P5IDAwNJjo+PJzkoKIjks2fPksznVOL7xGtSGcFrl/Pnz8/0Ol62fPlyknk/pJz0Jl5HOD5W3LFjx0h+8uQJyXwOKxsbG5J5v6Pr16+TfP/+fZJXr15NMh8/0c7OjmQtrjNSIxJCCJGnSEMkhBBCU9IQCSGE0NRbXSMyho9Flj9/fpK7d+9OMu/DkVl8nDA+1w3v96QFLfsRLV26lORRo0Zlav1ffvklyXv37jVY5scffyT54MGDJLu5uZHMzwneZ+OHH34gedmyZSSvXbuWZH4Mrl69SvKWLVtI5vMjHThwgOSRI0eC4zUMPgbfggULSOZjKDo6OpLM59PJ62PNZRbvS8bPGV7j4d8vU6YMyfx482PCjz8/J3nfs3bt2hlsMx9zMadJjUgIIUSeIg2REEIITUlDJIQQQlN5ukaUWSVKlCB52rRpJA8YMIBkPtaZsTGgNm/eTPLUqVMNtoHXFHJabt7/r1atGsl8HLbLly9nan03btwgmR+/iRMnGl3H5MmTSeY1GD5OmzH8/vy6detI5nUwjv9833//fZKNjVuWEbz2yY8b72tVqVIlkvmcTTnpXbiO8H00MzMjmf+e8DmseD+hAgUKkMyvM7xvG2D4M+VzLBljrG4lNSIhhBB5ijREQgghNCUNkRBCCE1JjSgT+Nh1Y8eOJZmPdWfs0P70008Gr/F+I7dv387MJhr1OvuIdOnShWTeh2fcuHE5vg2NGzcmuU2bNiTPmDEjU+vjY8mtXLmSZD4/0cCBA0nmY8k1b948U5/P+wwBQM+ePUkOCAggmfdb4XMg8Vpdnz59SObj8eUkuY4YzmvGx5LbvXs3ybxfEj+G//zzj8FnzJ07l2Q+ZmJ25zCSGpEQQog8RRoiIYQQmpKGSAghhKakRpSDOnXqRDK/92tvb290HbzPR1bmq0nP6+xHxOshX3zxBcn8eA0bNoxkPq6bsXHjAMNx17KL1194vyRjTp48SfLChQtJnjNnDslp1YS4a9eukczHkuPj2XGDBw8mmY9VxufXyUlyHTHeR4fXhFxdXUnmYyzy5QHD60jfvn1J5n38MktqREIIIfIUaYiEEEJoShoiIYQQmpIaUS5q0aIFybx/QEaYmJjk1OYAeLPmmrG1tSWZ3+vm42Pxfly8VgIYH++P4zU43seD4/MN8XlfbGxsSOZ1wg4dOpDM61yrV68mOa3x+XjfqMwe9yZNmpB85MgRkt+kcyQ38N+pZs2akRwZGUkyn2cst/FjxGuAQ4YMIdnZ2dlgHR9//DHJfN6s7I5pKTUiIYQQeYo0REIIITQlDZEQQghNvdU1Ig8PD5ItLS1J5s/S5zb++Tt27CCZj4OWEW9zjYjvb7169Uj+9NNPc21bUi1ZsoRk3q9nzZo1JPO+TXzunl69epHM61bGbN26leS2bduSzOeeefLkicE6rKysSOY1o6ZNm5I8ffr0TG1jXq8R1a9fn2Q+RxUfQ9LT05NkPoaiMZntN5SSkkIyrwnx/nhp4bVF3qfv8ePH6W6TMVIjEkIIkadIQySEEEJT0hAJIYTQ1FtdI/rmm29I7tevH8kXLlwgmT9b
f+jQoUx9Hp8HnteoJk2aRLKxPix8nhjAcN4Qf3//zGyiUbl5/79gwYIk8z42xlhYWJDM+wQlJCRkepsePXpE8v79+0muXLkyybwecOXKFZLnzZtHMu83xMfLM+b06dMk165d2+h7bty4QXL58uVJHjBgAMm8D8n3339PMv+9OHPmjNFtyKo34TrCt4H3R+PHk+NjvfFzhtfk+PiHX331VbrbY25uTjIff5DXjRMTEw22sXPnziQfPXqUZF57lBqREEKId5o0REIIITQlDZEQQghNvdU1In4fdP78+SRXqFAh3fcbe74/szK7Pl7TAoDQ0NBsbYMxuVkj8vb2JvnLL78kmfez4vnu3bs5vk28JuTu7p7jn5Ee3m+In7PcqVOnSE6rDxAfs3DkyJEkb9u2jeSqVauSvGfPHpJr1apFMq9b5aS34TqyYsUKkosUKULy/fv3SeZzVvG6Lu/3xWunvDZqTExMDMlPnz41WKZVq1Yk83Eb+XukRiSEEOKdJg2REEIITUlDJIQQQlPSEAkhhNDUW/2wAsc7F/JBNbt3704yH2Aypx9W+O2330gODw8n2dfX12AdfDDCnPY6B7TkE3ZdunSJ5OrVq5N87ty5bG8DH9CSd07khefg4OB018c7f3br1o3khw8fkmxtbU3yrVu3SC5dujTJycnJJPOJ+Xjn1KzgHVqNrTOvD3pqDD9H+MMEvDPoiRMnSOYd3435559/SDYzMyOZn0ODBw8mmV9XAMPO9LzTa3Z/xvKwghBCiDxFGiIhhBCakoZICCGEpvJUjcgYPsCln58fycYOxcWLF0n+6aefSOaDdN68eZPkBw8eZGg7c1Nu3v/nExHywRs7dOiQ7vt57YLf+3ZwcCCZD2gKAGfPnjW2mekqXrw4ycY62fLljQ1i+vfff5PMB77lxywtfHBZPqBlzZo1ja7jZceOHSO5YcOGmXp/ZryN1xE+yCiv6/EaUqNGjUjetWsXyb/++ivJ/PeGf96+ffvSfT+vKQGGnWb5cTc2ILMxUiMSQgiRp0hDJIQQQlPSEAkhhNDUO1UjEq+3j0jFihVJ5hOQZRYfVHXdunUGy9SoUYPkAwcOkOzj40My7xNirF8R75cUFRVFMh/Iln8+7+e0ZcsWkvkApmlNUsf7nfC+S3/++We6y/OJ1vjgtC4uLgafmVPkOpI3SI1ICCFEniINkRBCCE1JQySEEEJTUiN6x+RmjYiP0cUnGOvRo0e21s/7a/C+Ehnh6OhIMq9b8boWr7dwTZo0Ifn3338nOT4+nmRef8luvycAmDVrFsl37twhecyYMSRXq1aNZH4MjE0omR1yHckbpEYkhBAiT5GGSAghhKakIRJCCKEpqRG9Y15nPyI+Bhbvz8KVLVuW5KFDh5LM533ZsWOH0W3g+HxCfDw7Pl/QyZMn010f9/HHH5M8ffp0kvmcWXFxcemuj/cRAgznQOKM1aH4+He8DpabYyLKdSRvkBqREEKIPEUaIiGEEJqShkgIIYSmpEb0jnmdNaJSpUqRHB0dTfIHH3xA8pUrV7K9DXystu3bt5N869Ytknk9hfeh4WPBZbcf0BdffEEy71tlb29vdB1dunQhmW8jx+fhunDhAsn8mC1dutToNmSVXEfyBqkRCSGEyFOkIRJCCKEpaYiEEEJoKsM1IiGEECI3yH9EQgghNCUNkRBCCE1JQySEEEJT0hAJIYTQlDREQgghNCUNkRBCCE1JQySEEEJT0hAJIYTQlDREQgghNPX/QRVfCoqx1WkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 500x300 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 102/10: 100%|██████████| 469/469 [00:05<00:00, 78.85it/s, loss=2.08e-7]\n",
      "Epoch 103/10: 100%|██████████| 469/469 [00:05<00:00, 80.29it/s, loss=1.32e-8]\n",
      "Epoch 104/10: 100%|██████████| 469/469 [00:05<00:00, 81.63it/s, loss=1.74e-7]\n",
      "Epoch 105/10: 100%|██████████| 469/469 [00:05<00:00, 80.89it/s, loss=3.43e-9]\n",
      "Epoch 106/10: 100%|██████████| 469/469 [00:05<00:00, 81.20it/s, loss=1.13e-6]\n",
      "Epoch 107/10: 100%|██████████| 469/469 [00:05<00:00, 81.48it/s, loss=1.3e-6]\n",
      "Epoch 108/10: 100%|██████████| 469/469 [00:06<00:00, 76.70it/s, loss=3.8e-8]\n",
      "Epoch 109/10: 100%|██████████| 469/469 [00:05<00:00, 80.49it/s, loss=5.6e-7]\n",
      "Epoch 110/10: 100%|██████████| 469/469 [00:05<00:00, 81.52it/s, loss=6.07e-7]\n",
      "Epoch 111/10: 100%|██████████| 469/469 [00:05<00:00, 83.92it/s, loss=6.05e-7]\n",
      "Epoch 112/10: 100%|██████████| 469/469 [00:05<00:00, 79.60it/s, loss=6.54e-7]\n",
      "Epoch 113/10: 100%|██████████| 469/469 [00:05<00:00, 82.00it/s, loss=5.77e-8]\n",
      "Epoch 114/10: 100%|██████████| 469/469 [00:05<00:00, 83.23it/s, loss=1.34e-7]\n",
      "Epoch 115/10: 100%|██████████| 469/469 [00:05<00:00, 79.72it/s, loss=3.14e-7]\n",
      "Epoch 116/10: 100%|██████████| 469/469 [00:05<00:00, 84.08it/s, loss=1.61e-8]\n",
      "Epoch 117/10: 100%|██████████| 469/469 [00:05<00:00, 79.15it/s, loss=3.91e-8]\n",
      "Epoch 118/10: 100%|██████████| 469/469 [00:05<00:00, 82.06it/s, loss=3.05e-7]\n",
      "Epoch 119/10: 100%|██████████| 469/469 [00:05<00:00, 82.68it/s, loss=1.19e-6]\n",
      "Epoch 120/10: 100%|██████████| 469/469 [00:05<00:00, 80.77it/s, loss=3.47e-7]\n",
      "Epoch 121/10: 100%|██████████| 469/469 [00:05<00:00, 79.70it/s, loss=1.14e-8]\n",
      "Epoch 122/10: 100%|██████████| 469/469 [00:05<00:00, 81.58it/s, loss=1.68e-7]\n",
      "Epoch 123/10: 100%|██████████| 469/469 [00:05<00:00, 80.07it/s, loss=4.03e-7]\n",
      "Epoch 124/10: 100%|██████████| 469/469 [00:05<00:00, 83.25it/s, loss=3.96e-9]\n",
      "Epoch 125/10: 100%|██████████| 469/469 [00:05<00:00, 79.80it/s, loss=2.21e-9]\n",
      "Epoch 126/10: 100%|██████████| 469/469 [00:05<00:00, 83.95it/s, loss=3.45e-7]\n",
      "Epoch 127/10: 100%|██████████| 469/469 [00:05<00:00, 78.44it/s, loss=1.64e-7]\n",
      "Epoch 128/10: 100%|██████████| 469/469 [00:05<00:00, 78.51it/s, loss=2.68e-7]\n",
      "Epoch 129/10: 100%|██████████| 469/469 [00:05<00:00, 78.19it/s, loss=2.45e-7]\n",
      "Epoch 130/10: 100%|██████████| 469/469 [00:06<00:00, 77.90it/s, loss=5.44e-7]\n",
      "Epoch 131/10: 100%|██████████| 469/469 [00:06<00:00, 77.28it/s, loss=3.2e-8]\n",
      "Epoch 132/10: 100%|██████████| 469/469 [00:05<00:00, 81.97it/s, loss=5.05e-10]\n",
      "Epoch 133/10: 100%|██████████| 469/469 [00:05<00:00, 79.15it/s, loss=7.37e-7]\n",
      "Epoch 134/10: 100%|██████████| 469/469 [00:05<00:00, 79.89it/s, loss=2.2e-7]\n",
      "Epoch 135/10: 100%|██████████| 469/469 [00:05<00:00, 82.13it/s, loss=2.45e-7]\n",
      "Epoch 136/10: 100%|██████████| 469/469 [00:05<00:00, 82.66it/s, loss=1.91e-7]\n",
      "Epoch 137/10: 100%|██████████| 469/469 [00:05<00:00, 80.14it/s, loss=6e-9]\n",
      "Epoch 138/10: 100%|██████████| 469/469 [00:05<00:00, 79.85it/s, loss=5.59e-9]\n",
      "Epoch 139/10: 100%|██████████| 469/469 [00:05<00:00, 80.80it/s, loss=7.78e-8]\n",
      "Epoch 140/10: 100%|██████████| 469/469 [00:05<00:00, 80.12it/s, loss=1.68e-6]\n",
      "Epoch 141/10: 100%|██████████| 469/469 [00:05<00:00, 83.77it/s, loss=5.37e-7]\n",
      "Epoch 142/10: 100%|██████████| 469/469 [00:05<00:00, 80.17it/s, loss=1.84e-8]\n",
      "Epoch 143/10: 100%|██████████| 469/469 [00:05<00:00, 80.73it/s, loss=5.15e-9]\n",
      "Epoch 144/10: 100%|██████████| 469/469 [00:05<00:00, 78.35it/s, loss=8.18e-8]\n",
      "Epoch 145/10: 100%|██████████| 469/469 [00:05<00:00, 79.39it/s, loss=2.57e-8]\n",
      "Epoch 146/10: 100%|██████████| 469/469 [00:05<00:00, 79.88it/s, loss=7.05e-8]\n",
      "Epoch 147/10: 100%|██████████| 469/469 [00:05<00:00, 80.47it/s, loss=4.25e-8]\n",
      "Epoch 148/10: 100%|██████████| 469/469 [00:05<00:00, 82.42it/s, loss=4.73e-8]\n",
      "Epoch 149/10: 100%|██████████| 469/469 [00:05<00:00, 81.64it/s, loss=6.3e-11]\n",
      "Epoch 150/10: 100%|██████████| 469/469 [00:05<00:00, 81.11it/s, loss=1.44e-7]\n",
      "Epoch 151/10: 100%|██████████| 469/469 [00:05<00:00, 83.62it/s, loss=1.18e-6]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAACaCAYAAADrVUwbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmMElEQVR4nO3deXiMV/sH8O+QiURCSIIIaom3sQQRO7UEQUjte22xRG1Rqrx28tKIrbQoWltLrf3hkqL2UhKVInaqagtJNMSSEBLn90ffzJv7nmQmqyfi/lyXP76TZ545M/OYM/PczzlHp5RSEEIIITSST+sGCCGEeLdJRySEEEJT0hEJIYTQlHREQgghNCUdkRBCCE1JRySEEEJT0hEJIYTQlHREQgghNCUdkRBCCE1luSP68ssvodPp4Obmlul93Lt3DzNmzMDZs2ez2px0adasGZo1a5au7bLyvMQ/1q5dC51OBysrK9y6dcvo71l5ndP7XmaXmzdvQqfTYf78+W/sMd8Gye9x8j8rKys4OTnB09MTgYGBiI6O1rqJxIABA1CuXLk3/rg6nQ4zZswwuc27eIxluSNavXo1AODixYs4efJkpvZx7949zJw58411REIbCQkJmDJlSrbuc9myZVi2bFm27lNk3po1axASEoL9+/dj6dKlcHd3R1BQECpXrowDBw5o3TyDqVOnYvv27Vo3Q/xXljqisLAwhIeHo127dgCAVatWZUujRN7Upk0b/PDDDwgPD8+2fVapUgVVqlTJtv2JrHFzc0P9+vXRuHFjdOnSBV988QXOnTsHGxsbdO7cGVFRUVo3EQDg4uKCmjVrat0M8V9Z6oiSO545c+agYcOG2LRpE+Lj4422i4iIgJ+fH8qUKQNLS0s4Ozuja9euiIqKwpEjR1CnTh0AgK+vr+GnffLP17ROvaT203rmzJmoV68e7O3tUbhwYXh4eGDVqlXIznlddTodRo4ciTVr1sDV1RXW1taoXbs2QkNDoZTCvHnzUL58edja2qJ58+a4fv06uf/+/fvRoUMHlC5dGlZWVqhYsSKGDh2Kv//+2+ixdu7cierVq6NAgQKoUKECFi9ejBkzZkCn05HtlFJYtmwZ3N3dYW1tjaJFi6Jr1664ceNGtj3v7DB+/Hg4ODhgwoQJZrd98eIFJk6ciPLly8PS0hKlSpXCiBEjEBsbS7ZL7fj4+uuvUaNGDdja2qJQoUKoVKkSJk2aBOCf0x4WFhYIDAw0esyjR49Cp9Nh69atGXpeyaelDh06hCFDhsDBwQGFCxdGv379EBcXh8jISHTv3h1FihRByZIlMW7cOLx69YrsI73HbkJCAj799FM4OTmhYMGCaNKkCX7//XeUK1cOAwYMINtGRkZi6NChKF26NCwtLVG+fHnMnDkTiYmJGXp+WfXee+9hwYIFePr0KVasWEH+FhYWhvbt28Pe3h5WVlaoWbMmtmzZQrZJfn0PHz6MYcOGwdHREQ4ODujcuTPu3btHtn39+jXmzp2LSpUqoUCBAihevDj69euHu3fvku1S+/zYunUr6tWrBzs7OxQsWBAVKlTAwIEDyTZPnjzBuHHjyHH5ySefIC4uzmi75GPB1tYWbdq0wbVr1zLz8pHXIK8eYxaZelUAPH/+HBs3bkSdOnXg5uaGgQMHYvDgwdi6dSv69+9v2C4iIgJ16tTBq1evMGnSJFSvXh0xMTH4+eef8ejRI3h4eGDNmjXw9fXFlClTDL+uSpcuneE23bx5E0OHDsV7770HAAgNDcWoUaMQERGBadOmZfapGgkODsaZM2cwZ84c6HQ6TJgwAe3atUP//v1x48YNLFmyBI8fP8bYsWPRpUsXnD171tB5/Pnnn2jQoAEGDx4MOzs73Lx5EwsXLsQHH3yA8+fPQ6/XAwD27t2Lzp07o0mTJti8eTMSExMxf/78VL9RDh06FGvXroW/vz+CgoLw8OFDBAQEoGHDhggPD0eJEiWy
7blnRaFChTBlyhSMHj0ahw4dQvPmzVPdTimFjh074uDBg5g4cSIaN26Mc+fOYfr06QgJCUFISAgKFCiQ6n03bdqE4cOHY9SoUZg/fz7y5cuH69ev49KlSwCAcuXKoX379li+fDnGjx+P/PnzG+67ZMkSODs7o1OnTpl6foMHD0bnzp2xadMmnDlzBpMmTUJiYiKuXr2Kzp07w8/PDwcOHEBQUBCcnZ0xduxYw33Te+z6+vpi8+bNGD9+PJo3b45Lly6hU6dOePLkCWlLZGQk6tati3z58mHatGlwcXFBSEgIZs2ahZs3b2LNmjWZeo6Z1bZtW+TPnx9Hjx413Hb48GG0adMG9erVw/Lly2FnZ4dNmzahR48eiI+PN/rQGzx4MNq1a4cffvgBd+7cwWeffYY+ffrg0KFDhm2GDRuGlStXYuTIkfDx8cHNmzcxdepUHDlyBKdPn4ajo2Oq7QsJCUGPHj3Qo0cPzJgxw1DPTLnv+Ph4NG3aFHfv3jV8ll28eBHTpk3D+fPnceDAAeh0OsPxe+LECUybNg116tTB8ePH4e3tneXXMc8eYyqTvvvuOwVALV++XCml1NOnT5Wtra1q3Lgx2W7gwIFKr9erS5cupbmvU6dOKQBqzZo1Rn9r2rSpatq0qdHt/fv3V2XLlk1zn0lJSerVq1cqICBAOTg4qNevX5vdZ2qPXbVqVXIbAOXk5KSePXtmuG3Hjh0KgHJ3dyePs2jRIgVAnTt3LtX9v379Wr169UrdunVLAVA7d+40/K1OnTqqTJkyKiEhwXDb06dPlYODg0r5toWEhCgAasGCBWTfd+7cUdbW1mr8+PFmn2dOW7NmjQKgTp06pRISElSFChVU7dq1Da8Vf5337t2rAKi5c+eS/WzevFkBUCtXrjTcxt/LkSNHqiJFiphsz+HDhxUAtX37dsNtERERysLCQs2cOdPkff/66y8FQM2bN8/o+Y0aNYps27FjRwVALVy4kNzu7u6uPDw80nyMtI7dixcvKgBqwoQJZPuNGzcqAKp///6G24YOHapsbW3VrVu3yLbz589XANTFixdNPs+MSvkep6VEiRKqcuXKhlypUiVVs2ZN9erVK7Kdj4+PKlmypEpKSiL7Hj58ONlu7ty5CoC6f/++Ukqpy5cvp7rdyZMnFQA1adIkw2388yP5dYmNjU2z/YGBgSpfvnxGz3Hbtm0KgNq9e7dSSqk9e/YoAGrx4sVku9mzZysAavr06Wk+hlLv5jGW6VNzq1atgrW1NXr27AkAsLW1Rbdu3XDs2DH88ccfhu327NkDT09PVK5cObMPlW6HDh1Cy5YtYWdnh/z580Ov12PatGmIiYnJ1qt2PD09YWNjY8jJz83b25ucNku+PeWVYtHR0fj4449RpkwZWFhYQK/Xo2zZsgCAy5cvAwDi4uIQFhaGjh07wtLS0nBfW1tbfPjhh6QtwcHB0Ol06NOnDxITEw3/nJycUKNGDRw5ciTbnnd2sLS0xKxZsxAWFmZ0CiZZ8rdQ/o24W7dusLGxwcGDB9Pcf926dREbG4tevXph586dqZ7ybNasGWrUqIGlS5cablu+fDl0Oh38/Pwy8az+4ePjQ3Ly+5/8Kz/l7fzqwfQcu7/88gsAoHv37uS+Xbt2hYUFPbkRHBwMT09PODs7k+Mi+Vt58r7eJJXiFND169dx5coVfPTRRwBA2ti2bVvcv38fV69eJfdv3749ydWrVwfwv/9fhw8fBmB83NStWxeVK1c2edwklwe6d++OLVu2ICIiwmib4OBguLm5wd3dnbS3devW0Ol0hv9rye1Ifm7Jevfunebjp1dePcYy1RFdv34dR48eRbt27aCUQmxsLGJjY9G1a1cA/7uSDgAePHiQqdNsGfXbb7+hVatWAIBvvvkGx48fx6lTpzB58mQA/5xKzC729vYkJ3cWad3+4sULAP+cv27VqhX+7//+D+PHj8fBgwfx22+/ITQ0lLTx0aNHUEqlekqN3xYVFWXY
Vq/Xk3+hoaGpfhBrrWfPnvDw8MDkyZONzmMDQExMDCwsLFCsWDFyu06ng5OTE2JiYtLcd9++fbF69WrcunULXbp0QfHixVGvXj3s37+fbOfv74+DBw/i6tWrePXqFb755ht07doVTk5OmX5eGTkuko8JIP3HbvLz5seAhYUFHBwcyG1RUVHYtWuX0TFRtWpVAHjjx0VcXBxiYmLg7OxsaB8AjBs3zqiNw4cPT7WN/Dkmn57lr0/JkiWNHt/Z2dnkcdOkSRPs2LEDiYmJ6NevH0qXLg03Nzds3LjRsE1UVBTOnTtn1N5ChQpBKWVob/Lxy9ublWMrWV49xjJVI1q9ejWUUti2bRu2bdtm9Pd169Zh1qxZyJ8/P4oVK2ZUKMwIKysrPH782Oh2/iQ3bdoEvV6P4OBgWFlZGW7fsWNHph87u124cAHh4eFYu3YtqaPxCxqKFi0KnU6Xaj0oMjKSZEdHR+h0Ohw7dizVuklatRQt6XQ6BAUFwcvLCytXrjT6u4ODAxITE/HgwQPSGSmlEBkZafj2mhZfX1/4+voiLi4OR48exfTp0+Hj44Nr164Zfn327t0bEyZMwNKlS1G/fn1ERkZixIgR2ftE0ym9x27yB0FUVBRKlSpluD0xMdHoQ9bR0RHVq1fH7NmzU33M5A7hTfnpp5+QlJRkuLAkuVYzceJEdO7cOdX7uLq6Zugxkl+f+/fvG335vXfvXpr1oWQdOnRAhw4dkJCQgNDQUAQGBqJ3794oV64cGjRoAEdHR1hbW5Mv2ikl7z/5+I2JiSEf3vz/7puU24+xDHdESUlJWLduHVxcXPDtt98a/T04OBgLFizAnj174OPjA29vb3z//fe4evVqmgcW/2aTUrly5bB161YkJCQYtouJicGJEydQuHBhw3Y6nQ4WFhak+Pz8+XN8//33GX2KOSb5tB3vHPiVRDY2NqhduzZ27NiB+fPnG771PHv2DMHBwWRbHx8fzJkzBxEREUY/p3Ozli1bwsvLCwEBAShTpgz5W4sWLTB37lysX78eY8aMMdz+448/Ii4uDi1atEjXY9jY2MDb2xsvX75Ex44dcfHiRUNHZGVlBT8/PyxZsgQnTpyAu7s7GjVqlH1PMAPSe+w2adIEALB582Z4eHgYbt+2bZvRVUo+Pj7YvXs3XFxcULRo0RxsvXm3b9/GuHHjYGdnh6FDhwL4p5P517/+hfDwcHz++efZ8jjJF7+sX7+efFk5deoULl++bPj2b06BAgXQtGlTFClSBD///DPOnDmDBg0awMfHB59//jkcHBxQvnz5NO/v6emJuXPnYsOGDfD39zfc/sMPP2TymWVdbj/GMtwR7dmzB/fu3UNQUFCql1W7ublhyZIlWLVqFXx8fBAQEIA9e/agSZMmmDRpEqpVq4bY2Fjs3bsXY8eORaVKleDi4gJra2ts2LABlStXhq2tLZydneHs7Iy+fftixYoV6NOnD4YMGYKYmBjMnTuXdELAP+dIFy5ciN69e8PPzw8xMTGYP39+rvpFkPxc//3vf0MpBXt7e+zatcvotBEABAQEoF27dmjdujVGjx6NpKQkzJs3D7a2tnj48KFhu0aNGsHPzw++vr4ICwtDkyZNYGNjg/v37+PXX39FtWrVMGzYsDf5NNMtKCgItWrVQnR0tOHnPAB4eXmhdevWmDBhAp48eYJGjRoZrpqrWbMm+vbtm+Y+hwwZAmtrazRq1AglS5ZEZGQkAgMDYWdnZ/RLavjw4Zg7dy5+//33VL9UvSnpPXarVq2KXr16YcGCBcifPz+aN2+OixcvYsGCBbCzs0O+fP870x4QEID9+/ejYcOG8Pf3h6urK168eIGbN29i9+7dWL58eY6cMr9w4YKhVhAdHY1jx45hzZo1yJ8/P7Zv305+4a5YsQLe3t5o3bo1BgwYgFKlSuHhw4e4fPkyTp8+neHL6F1dXeHn54evvvoK+fLlg7e3t+GquTJlypAvNdy0adNw9+5dtGjRAqVL
l0ZsbCwWL14MvV6Ppk2bAgA++eQT/Pjjj2jSpAnGjBmD6tWr4/Xr17h9+zb27duHTz/9FPXq1UOrVq3QpEkTjB8/HnFxcahduzaOHz+u6ZfiXH+Mpfuyhv/q2LGjsrS0VNHR0Wlu07NnT2VhYaEiIyOVUv9cwTVw4EDl5OSk9Hq9cnZ2Vt27d1dRUVGG+2zcuFFVqlRJ6fV6oytL1q1bpypXrqysrKxUlSpV1ObNm1O9am716tXK1dVVFShQQFWoUEEFBgaqVatWKQDqr7/+MmyX1avmRowYQW5L7SoXpf53ddbWrVsNt126dEl5eXmpQoUKqaJFi6pu3bqp27dvp3o1zfbt21W1atWUpaWleu+999ScOXOUv7+/Klq0qFFbV69ererVq6dsbGyUtbW1cnFxUf369VNhYWFmn2dOM3VFVe/evRUAo9f5+fPnasKECaps2bJKr9erkiVLqmHDhqlHjx6R7fh7uW7dOuXp6alKlCihLC0tDcdaWlcuNmvWTNnb26v4+Ph0PRdTVzTx5zd9+nQFQD148IDc3r9/f2VjY0NuS++x++LFCzV27FhVvHhxZWVlperXr69CQkKUnZ2dGjNmDNnngwcPlL+/vypfvrzS6/XK3t5e1apVS02ePJlc9Zkdkl+D5H+WlpaqePHiqmnTpurzzz9P8/MiPDxcde/eXRUvXlzp9Xrl5OSkmjdvbrgaN+W++eub/P/r8OHDhtuSkpJUUFCQev/995Ver1eOjo6qT58+6s6dO+S+/PMjODhYeXt7q1KlShna3rZtW3Xs2DFyv2fPnqkpU6YoV1dXZWlpqezs7FS1atXUmDFjDJ93SikVGxurBg4cqIoUKaIKFiyovLy81JUrV7J81VxePcZ0SmXjaE+Ro169egV3d3eUKlUK+/bt07o5b73o6GiULVsWo0aNwty5c7VuTqadOHECjRo1woYNG7LlyiwhuJw+xqQjysUGDRoELy8vwymm5cuX45dffsG+ffvQsmVLrZv31rp79y5u3LiBefPm4dChQ7h27RopzOZm+/fvR0hICGrVqgVra2uEh4djzpw5sLOzw7lz50ghWojM0OIYy/TMCiLnPX36FOPGjcODBw+g1+vh4eGB3bt3SyeURd9++y0CAgJQrlw5bNiw4a3phACgcOHC2LdvHxYtWoSnT5/C0dER3t7eCAwMlE5IZAstjjH5RSSEEEJTsjCeEEIITUlHJIQQQlPSEQkhhNCUdERCCCE0le6r5vhibOLtlJPXpvCZq/mUJilnuwZAZlMAgL/++ovkUaNGkfwmVvccPHgwyRmdcYGvZcSXo65VqxbJyWvDpLV9dnB3dyf57NmzJPOZN3Jy6XX5HMkbsvtzRH4RCSGE0JR0REIIITQlHZEQQghNpXtAq5zbzRtyskaU0WOEz6LNZyeuWLEiyXzdpszgs7annBYfMF4Dh2+fcmVewHgVypQzowPGS34kJCSkv7Ewru8AxjUec4oXL04yX6Dt3LlzJOemY0TkTlIjEkIIkadIRySEEEJT0hEJIYTQlMy+Ld6YAQMGkKzX60lOufoj8M/S6CnVrVvX5P4AYPz48SS/ePGCZL4c+J49e0iOj48nma8wOXz4cJL5uJ+jR4+SnLz0clrM1ckyWg9KTXR0NMk1atQwmYV40+QXkRBCCE1JRySEEEJT0hEJIYTQlIwjesfkpjEiHTt2JJmP4QkKCiKZj4fhtQ8A6NevH8m8RnT37l2ST5w4QXKVKlVIvnr1KsnW1tYk8zrXo0ePSG7Tpg3JfFwRr0EdO3aMZD7GJzPq169PcmhoKMl8Pr1BgwZl+THTIp8jeYOMIxJCCJGnSEckhBBCU9IRCSGE0JTUiN4xb7JGNHr0aJIXLVpkcns+Jmfv3r0k8/rKN998Y7YNfFxQTEwMyV999RXJM2bMMNpnRvD98TWVJk+eTPLs2bNJ5s9p8eLFRo9x4cIFk23g47Fev35NMq9DffDBByb3l53kcyRvkBqR
EEKIPEU6IiGEEJqSjkgIIYSmZK45kW18fHxI7tSpE8nNmzcnma+1w+dxmzhxIsn+/v4ZbpObmxvJp0+fJtlcTejUqVMk79y5k+RZs2aRzGtC77//PslxcXEke3p6kszHFaVWD+L7vHbtGsm8JlSqVCmSGzdubLTPlHKyjihEauQXkRBCCE1JRySEEEJT0hEJIYTQlIwjesfk5Pn/2rVrk/z777+b3J7PgXbp0iWSV65cSfL06dNJ5vPAAUCtWrVIvnHjBsl8Lris6tOnD8nr16/P1v2n5uuvvybZw8OD5Hr16mVp/7lpPsKckD9/fpJ5TU1qZObJOCIhhBB5inREQgghNCUdkRBCCE3lmhpRmTJlSB4yZIjRNvfv3yeZrx3DtW7dmuTnz5+TXK1aNZL5S7FmzRqT9+f1hmnTpplsT26Qk+e///jjD5K/+OILknltw97enuSHDx+a3P+VK1dIfvnypdE2y5YtI7lhw4YkBwQEkMzH4PA6VEJCAsmDBw8meenSpST36tWL5JCQEJLHjBlDMp8/r2/fviTztYIA4Pz58yTz49jX15fk4OBgkmvWrEnyvn37SH6bakR8f/z15McYYPwe8DWsGjVqRPLNmzdJ7tmzJ8m85sTnTKxQoQLJZ8+eNWrT20ZqREIIIfIU6YiEEEJoSjoiIYQQmpKOSAghhKY0u1jhP//5D8n84gRHR0ezbchqwSyr++PFdz7hZW70JgvR5l7fjz76iORChQqRvHz5cpKtrKxI5hePvAn8QoFVq1aRzBf/27JlC8k9evTI9jbxgcF8oC+/qIZPghoREUFybrpYgW/PF/3juXz58iSndtHTuHHjSE5MTCSZD3C1tLRMX2PT8ODBA5L5Zxu/oCY0NDRLjwcYv24WFnR+61evXmVp/3KxghBCiDxFOiIhhBCako5ICCGEpjRbGC89NSHxdmvfvj3J27ZtI1mv15PMzzvzGlx8fDzJ69atM3rMhQsXknzu3DmS+XHH6wXOzs4kv3jxguSCBQuSXLJkSaM2pNS9e3eSM1oj4vUfAPD29ib56dOnJJurMbRq1YpkPlA8N+P1m7Jly5LMFw1MrY7IjzNeP8nu+kexYsVM/n337t0k8wHK9+7dIzk97ePb8FpabpO7WyeEECLPk45ICCGEpqQjEkIIoSnNakR80TM+0WBqzI1L4TUHfn0+xyfp5Ody7ezszLZJpC0mJoZk/v5w5t5fBwcHkoOCgoz2ER4ebjLXqFHDZBsmTZpEcoMGDUjm4y8mTJhA8q1bt0ieOXMmyfyY5K8RrzkdOXLEqI28VsbrYD/99BPJ7dq1I5mPi3FycjJ6jNyCHwN8gtG4uDiS+cS4fBJUADh58iTJfEFHXk/J6Qmf+fg5XrPLjppVahME5ybyi0gIIYSmpCMSQgihKemIhBBCaCrXLIyXHfh4AD6mgIuMjCS5cOHCJIeFhZFcpEgRkmWuOcrW1pZkPu/XhQsXsvXxmjVrZnQbX8TM1dWV5J9//pnkY8eOkczHEbm4uJDM50h8/PgxySNGjCC5RYsWJN+4ccOozSmVKFGC5HLlyhltw2sc5vTv35/ktWvXkszrYHwxv+yU1c8Rc3VEXt+xtrY22gevK3F8TkM+lozfn4/j4u+hOc+ePSOZ14xyI5lrTgghRJ4iHZEQQghNSUckhBBCU5qNI8oJfF2RS5cuZej+Dx8+JDkpKYlkfn6a14jedXycT1ZrQtu3byd5/fr1JP/6669G9xkwYADJfKzR/PnzSd66dSvJX375pck2VahQgeTevXuT3LlzZ5L5uCBzNSJ+7j0hIcHk9gAwbNgwkvmaSXxOvu+++47k1GptuZW52gSfi85cPSg1vCbE8WNs7NixJGe0RrRs2bIMbZ8XyS8iIYQQmpKOSAghhKakIxJCCKGpPFUjym78fDTPbm5uJGe0JpXXlC5dmmS+FsyDBw9I3rhxI8m9evUi+eDBgyT/+OOPZtuQ2vxzKfH1h8zh6wl99NFH
JL/33nsk87nq6tSpY3L/vN7w/fffk8zHsqX2mMWLFyc5tdpZSo8ePSKZ172EaXwsGq9L8XW3+OcGrzXz9aWmTp1Kcm6fJy47yC8iIYQQmpKOSAghhKakIxJCCKEpqRGlULVqVZILFixocns+Toafq+fjO3777TeS9+zZk9Em5mp8/AWvCXFDhgwx+fclS5ZkuU09evQgma+DZW7NKb7+UP369Unm84xNnDiR5H379pHM6zuVKlUi2dz6OwDg5eVFMl/ziONrffE5E0XG8LnhmjZtavLvfA5Gjs8nyGutERERJKdnbNnbRn4RCSGE0JR0REIIITQlHZEQQghN5ekaET8X7ujoSPL169dJ5ufezdWI+BiQatWqmdxfeHg4yXmtRlSsWDGSp0yZQnKrVq1I5nP7dejQgWQ+3oLXe1KbE2znzp0kb968meSYmBij+6TExwnx/fn4+JDMa0h8fSNu8eLFJPO56PjcdXyMCgDs37+fZF5DqFGjBsm8puDu7m4yC9P4e8LHu2V0XTIbGxuSjx8/TvLff/9NMn9/UztG3jbyi0gIIYSmpCMSQgihKemIhBBCaEqn0rn4eFbXmtcCX2eFrz0TGRlJctmyZUk2VyPi+GvEX9pu3bqRzNfbeROye635lFxdXUm+du2aye27dOlCMp9Ljo+xuXLlitk28HE6t2/fNrk9n5tu4MCBJMfHx5vcvzm7du0imY8L4vMVLlq0iOTU1qrh41LM1b2qVKlCsrk5EXPyGHkbP0fy5aPf13ntma/3xOuI5sTGxpIcEBBA8pkzZ0g+deqUyfYBxuPbslt2HyPyi0gIIYSmpCMSQgihKemIhBBCaCpPjyPi7O3tSebjirJ6PT4/V8v3Z26dmLfdihUrSPb09DS5vbn1hczVhHgNCQASExNN3ofj5/MLFy5MMj9Gpk+fTjKf542PExo9erTJx+f1HT42KrV5xczNNcbnPHzX18nKKl4P4ePfateunaX985ofrxmdPHmSZH6MxsXFGe3T3GdRbiO/iIQQQmhKOiIhhBCako5ICCGEpt6pcUQHDhwg2dy4H3P+/PNPkitWrEjyt99+S/LIkSNJ5vOUvQm5eYwIn3OLn/vmtY/hw4cb7WPGjBkk8zEf/Pw7x8//Fy1a1OT2WT2GzO0vNXzs04QJE0jm62AVKlSI5KioKJL5OJjcfIzkBrwWOmjQIJJTW0MqpVu3bpHcv39/ko8ePUqyXq8nmddBtaj/yDgiIYQQeYp0REIIITQlHZEQQghN5ekakaWlJcl2dnYk87nOlixZYnJ/Q4YMIZnPXWdtbU0yr0doURPi3uT5f16jO3LkiMn787n++Llwfu6czwsHAPfv3yeZz3dXokQJknm9hNeE/P39SeY1KI6/BnxuOj5/YXrmz+P4e8gf8/Tp0yR7eHiQbG7uOakRmcZrNi9fvszQ/e/evUsyX1/oyZMnJJsbG8fbAxh/1mR3LVNqREIIIfIU6YiEEEJoSjoiIYQQmsrTc83xc7cPHjwgmdd0OH6e9caNGyQ/e/bMZH7X9OvXj+RixYqRbK5G9P7775O8f/9+kl1cXDLfuP/iNSE+1os/Jj9fz+cL/OCDD0iuX78+yaGhoSS3aNGCZD6Gp169eqk1m/D19SW5Xbt2JPOaECdzz2UNrzXzz4nUajYpXb9+nWQ+di2j0lN7zsm6X3aQX0RCCCE0JR2REEIITUlHJIQQQlN5ukZkToUKFUz+ndcDfvnll5xszlvvu+++I3n27Nkk85oRP5f+008/kdy2bVuS+VyB6eHs7EzyvXv3SObzhPF6y507d0iuU6eOycfjNSGOry0THR1tcns+vx4AlCxZkuS1a9eSzNdA4msknTlzhuSaNWuabIOgatWqRbK5ud54DWfq1KnZ3qa3nfwiEkIIoSnpiIQQQmhKOiIhhBCako5ICCGEpvL0pKfm8KfOi46HDx8muWXLljneppyWkwPbLl++TPLjx49J5pM38rbwSWRDQkJI
DgsLI9nNzc2oDRcuXCCZv4cTJ04kuUyZMiSfOHGC5IsXL5LMB8TyQbj8/wmf5DQ+Pp7knj17kswL2/zxAeOJUvlxyQfl8oXxhg4dSvKCBQtIlklPM+bRo0ck88UYX7x4QbKtrS3JSUlJOdKunCSTngohhMhTpCMSQgihKemIhBBCaOqdHtD6xx9/kMwHuFatWvVNNuetxxdc4xo1akTy8ePHTW5fuHBhk383NxgUABYuXEgyr4fwuhOvU/EJLvlgUF4j4nhNiOML+fFBwalNaFmxYkWS+UBfXof597//TTJfmE1kjJOTE8n58pn+Ps//ntsnINWC/CISQgihKemIhBBCaEo6IiGEEJp6p2tE/Pp/ztHRkeQPP/yQ5F27dmV7m95m/Ny3l5cXyXzRN14j6tWrF8kHDx4kmS9Sl54aEX+P+KSmt2/fNruPlJo1a0ZyVsfF8DEmAwcOJHnEiBFG97l69SrJrq6uJh+DT8RqboFCQfEaDx+3xSfv/fvvv0nmx6C5SVLfRfKLSAghhKakIxJCCKEp6YiEEEJo6p2uEZm7np/XD3iNQlC9e/cmmY9vKVCggMn7d+/eneSNGzdmuU18brnz58+TvHTp0iztn48rat68Ocnm6pB8oTyO18lSw+tWvAbEa0R79+41u0/xP/y45bVPPociX+CRjw3jNSepGckvIiGEEBqTjkgIIYSmpCMSQgihqXe6RjRo0CCSAwICSP7kk09IvnPnTk436a3GazoNGzYk+bPPPiOZn0v/+uuvSebrG/Ht+ZgbALC3tyd5/fr1JPv5+RndJ6VVq1aRzOteERERJPOxUJUrVybZXI3InE6dOhnddujQIZI9PT1J5jWh2NhYkj/++GOS165dS7LMhUY9f/6cZH6M8PGGn376Kcm8tiw1IWPyi0gIIYSmpCMSQgihKemIhBBCaEqn0nlCOC+uNf8uysnz//wY4evmXL9+PUP7Gzx4MMlbtmwhOTeO6woMDCSZr290+vRpkosUKUIyr+c4OzsbPUbjxo1JtrKyIvnly5ckmxuPVb16dZLDw8NNbp8VeeFzJH/+/CTzmk9G/49ZWNBSfVJSksntc0MNL7vbIL+IhBBCaEo6IiGEEJqSjkgIIYSm0l0jEkIIIXKC/CISQgihKemIhBBCaEo6IiGEEJqSjkgIIYSmpCMSQgihKemIhBBCaEo6IiGEEJqSjkgIIYSmpCMSQgihqf8HKTGNOugZOxwAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 500x300 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 152/10: 100%|██████████| 469/469 [00:05<00:00, 78.57it/s, loss=4.96e-8]\n",
      "Epoch 153/10: 100%|██████████| 469/469 [00:05<00:00, 78.74it/s, loss=9.28e-8]\n",
      "Epoch 154/10: 100%|██████████| 469/469 [00:06<00:00, 75.28it/s, loss=2.28e-7]\n",
      "Epoch 155/10: 100%|██████████| 469/469 [00:05<00:00, 80.98it/s, loss=5.24e-8]\n",
      "Epoch 156/10: 100%|██████████| 469/469 [00:05<00:00, 79.80it/s, loss=1.85e-7]\n",
      "Epoch 157/10: 100%|██████████| 469/469 [00:05<00:00, 78.86it/s, loss=8.45e-9]\n",
      "Epoch 158/10: 100%|██████████| 469/469 [00:05<00:00, 80.23it/s, loss=4.17e-7]\n",
      "Epoch 159/10: 100%|██████████| 469/469 [00:05<00:00, 79.78it/s, loss=5.8e-8]\n",
      "Epoch 160/10: 100%|██████████| 469/469 [00:05<00:00, 78.24it/s, loss=1.45e-7]\n",
      "Epoch 161/10: 100%|██████████| 469/469 [00:05<00:00, 80.74it/s, loss=2.52e-8]\n",
      "Epoch 162/10: 100%|██████████| 469/469 [00:05<00:00, 80.76it/s, loss=2.77e-10]\n",
      "Epoch 163/10: 100%|██████████| 469/469 [00:05<00:00, 80.39it/s, loss=1.37e-7]\n",
      "Epoch 164/10: 100%|██████████| 469/469 [00:05<00:00, 83.63it/s, loss=3.18e-10]\n",
      "Epoch 165/10: 100%|██████████| 469/469 [00:05<00:00, 80.68it/s, loss=1.12e-6]\n",
      "Epoch 166/10: 100%|██████████| 469/469 [00:05<00:00, 83.20it/s, loss=2.41e-7]\n",
      "Epoch 167/10: 100%|██████████| 469/469 [00:05<00:00, 79.30it/s, loss=3.45e-7]\n",
      "Epoch 168/10: 100%|██████████| 469/469 [00:05<00:00, 80.40it/s, loss=2.83e-7]\n",
      "Epoch 169/10: 100%|██████████| 469/469 [00:05<00:00, 80.38it/s, loss=4.8e-14]\n",
      "Epoch 170/10: 100%|██████████| 469/469 [00:05<00:00, 78.61it/s, loss=5.57e-7]\n",
      "Epoch 171/10: 100%|██████████| 469/469 [00:06<00:00, 77.64it/s, loss=6.37e-8]\n",
      "Epoch 172/10: 100%|██████████| 469/469 [00:06<00:00, 76.77it/s, loss=1.11e-7]\n",
      "Epoch 173/10: 100%|██████████| 469/469 [00:05<00:00, 81.67it/s, loss=2.81e-8]\n",
      "Epoch 174/10: 100%|██████████| 469/469 [00:05<00:00, 81.11it/s, loss=1.58e-7]\n",
      "Epoch 175/10: 100%|██████████| 469/469 [00:05<00:00, 81.36it/s, loss=1.01e-7]\n",
      "Epoch 176/10: 100%|██████████| 469/469 [00:05<00:00, 82.70it/s, loss=1.29e-8]\n",
      "Epoch 177/10: 100%|██████████| 469/469 [00:05<00:00, 79.49it/s, loss=2.81e-8]\n",
      "Epoch 178/10: 100%|██████████| 469/469 [00:05<00:00, 78.62it/s, loss=1.19e-8]\n",
      "Epoch 179/10: 100%|██████████| 469/469 [00:05<00:00, 81.38it/s, loss=6.31e-7]\n",
      "Epoch 180/10: 100%|██████████| 469/469 [00:05<00:00, 82.24it/s, loss=1.16e-10]\n",
      "Epoch 181/10: 100%|██████████| 469/469 [00:05<00:00, 79.32it/s, loss=1.17e-7]\n",
      "Epoch 182/10: 100%|██████████| 469/469 [00:05<00:00, 80.56it/s, loss=3.49e-8]\n",
      "Epoch 183/10: 100%|██████████| 469/469 [00:05<00:00, 79.66it/s, loss=2.2e-8]\n",
      "Epoch 184/10: 100%|██████████| 469/469 [00:05<00:00, 80.78it/s, loss=2.83e-7]\n",
      "Epoch 185/10: 100%|██████████| 469/469 [00:05<00:00, 80.65it/s, loss=9.45e-9]\n",
      "Epoch 186/10: 100%|██████████| 469/469 [00:05<00:00, 81.52it/s, loss=1.56e-7]\n",
      "Epoch 187/10: 100%|██████████| 469/469 [00:05<00:00, 81.44it/s, loss=9.47e-8]\n",
      "Epoch 188/10: 100%|██████████| 469/469 [00:05<00:00, 83.53it/s, loss=1.11e-7]\n",
      "Epoch 189/10: 100%|██████████| 469/469 [00:05<00:00, 81.55it/s, loss=3.18e-8]\n",
      "Epoch 190/10: 100%|██████████| 469/469 [00:06<00:00, 74.95it/s, loss=3.35e-8]\n",
      "Epoch 191/10: 100%|██████████| 469/469 [00:05<00:00, 78.54it/s, loss=2.74e-8]\n",
      "Epoch 192/10: 100%|██████████| 469/469 [00:05<00:00, 80.46it/s, loss=5.31e-10]\n",
      "Epoch 193/10: 100%|██████████| 469/469 [00:06<00:00, 78.07it/s, loss=1.86e-7]\n",
      "Epoch 194/10: 100%|██████████| 469/469 [00:05<00:00, 83.03it/s, loss=4.09e-7]\n",
      "Epoch 195/10: 100%|██████████| 469/469 [00:05<00:00, 78.87it/s, loss=4.68e-10]\n",
      "Epoch 196/10: 100%|██████████| 469/469 [00:05<00:00, 80.12it/s, loss=5.65e-8]\n",
      "Epoch 197/10: 100%|██████████| 469/469 [00:06<00:00, 78.04it/s, loss=3.49e-7]\n",
      "Epoch 198/10: 100%|██████████| 469/469 [00:05<00:00, 80.59it/s, loss=7.1e-8]\n",
      "Epoch 199/10: 100%|██████████| 469/469 [00:05<00:00, 80.28it/s, loss=4.23e-7]\n",
      "Epoch 200/10: 100%|██████████| 469/469 [00:05<00:00, 80.35it/s, loss=1.1e-8]\n",
      "100%|██████████| 200/200 [19:27<00:00,  5.84s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training complete. Model saved to /data/schoudh8/ecmmd/data/mnist/chckpoints\n"
     ]
    }
   ],
   "source": [
    "## Training loop (disabled). Uncomment the block below only to (re)train the model;\n",
    "## the cells that follow load the saved checkpoint instead of retraining.\n",
    "# for epoch in tqdm(range(NUM_EPOCH)):\n",
    "#     model.train()\n",
    "#     # Label the bar with NUM_EPOCH so the denominator matches the loop bound\n",
    "#     # (it previously showed args.epochs, e.g. \"Epoch 200/10\").\n",
    "#     pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{NUM_EPOCH}\")\n",
    "#     for batch in pbar:\n",
    "#         x = batch['x'].to(device)  # clean\n",
    "#         c = batch['c'].to(device)  # conditioning noisy measurement (kept as [-1,1])\n",
    "#         eta = batch['eta'].squeeze(1).to(device)  # additional noise for ecmmd\n",
    "        \n",
    "#         optimizer.zero_grad()\n",
    "#         denoised_train_images = model(c, eta)\n",
    "#         # Loss is the squared ECMMD statistic between flattened generated and clean\n",
    "#         # images, conditioned on the flattened noisy measurement.\n",
    "#         loss = (ECMMD(\n",
    "#                     denoised_train_images.reshape(denoised_train_images.shape[0], -1),\n",
    "#                     x.reshape(x.shape[0], -1),\n",
    "#                     c.reshape(c.shape[0], -1),\n",
    "#                     kernel=gaussian_kernel,\n",
    "#                     neighbors=NEIGHBORS\n",
    "#                 ) ** 2)\n",
    "#         loss.backward()\n",
    "#         optimizer.step()\n",
    "#         pbar.set_postfix({'loss': loss.item()})\n",
    "        \n",
    "#     # Every 50 epochs, preview one sample from the last batch of the epoch:\n",
    "#     # clean image, noisy conditioning image, and the model's output.\n",
    "#     if epoch % 50 == 0:\n",
    "#         plot_idx = 0\n",
    "#         with torch.no_grad():\n",
    "#             test_noisy = c[plot_idx].unsqueeze(0).to(device)\n",
    "#             test_eta_batch = eta[plot_idx].unsqueeze(0).to(device)\n",
    "#             temp_img = model(test_noisy, test_eta_batch).cpu()\n",
    "            \n",
    "#             plt.figure(figsize=(5, 3))\n",
    "#             plt.subplot(1, 3, 1)\n",
    "#             plt.imshow(x[plot_idx].cpu().squeeze(), cmap='gray')\n",
    "#             plt.title('Actual Image')\n",
    "#             plt.axis('off')\n",
    "            \n",
    "#             plt.subplot(1, 3, 2)\n",
    "#             plt.imshow(c[plot_idx].cpu().squeeze(), cmap='gray')\n",
    "#             plt.title('Noisy Image')\n",
    "#             plt.axis('off')\n",
    "            \n",
    "#             plt.subplot(1, 3, 3)\n",
    "#             plt.imshow(temp_img.cpu().squeeze(), cmap='gray')\n",
    "#             plt.title('Denoised Image')\n",
    "#             plt.axis('off')\n",
    "            \n",
    "#             plt.show()\n",
    "\n",
    "# # Save where the checkpoint-loading cell looks (args.save_dir) rather than a\n",
    "# # hard-coded absolute path, so the two cells cannot drift apart.\n",
    "# save_dir = args.save_dir\n",
    "# torch.save({'model_state': model.state_dict()}, os.path.join(save_dir, 'model_final_ECMMD.pt'))\n",
    "# print(\"Training complete. Model saved to\", save_dir)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0daa3396",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1774544/208168450.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(ckpt_path, map_location=device)['model_state'])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the network and restore the weights written by the training cell.\n",
    "model = DenoiseCNN().to(device)\n",
    "ckpt_path = os.path.join(args.save_dir, 'model_final_ECMMD.pt')\n",
    "# The checkpoint is a plain dict wrapping a state_dict, so weights_only=True is enough:\n",
    "# it prevents arbitrary pickled code from running and avoids torch.load's FutureWarning.\n",
    "checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)\n",
    "model.load_state_dict(checkpoint['model_state'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6c62c66e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/schoudh8/miniconda3/envs/PLoGenv/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'MSE': 0.128294016122818, 'PSNR': 8.92179866553973, 'SSIM': 0.7181134343147277, 'FID': 0.008006079122424126, 'Inception Score (mean)': 2.411038875579834, 'GenTime (s/batch)': 0.07215308666229248, 'GenTime (s/img)': 0.00056369598954916}\n"
     ]
    }
   ],
   "source": [
    "# Put the module in evaluation mode before computing metrics.\n",
    "model.eval()\n",
    "\n",
    "# Score the restored model on val_loader. num_batches=50 is forwarded to\n",
    "# evaluate_model (presumably caps the number of batches -- confirm in its definition).\n",
    "metrics = evaluate_model(\n",
    "    model,\n",
    "    val_loader,\n",
    "    schedule,\n",
    "    device,\n",
    "    args,\n",
    "    num_batches=50,\n",
    "    model_type=\"ecmmd\",\n",
    ")\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1a0bee5",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PLoGenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
