{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46b47cf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.sparse import csc_matrix, eye\n",
    "from scipy.sparse.linalg import expm_multiply\n",
    "import time\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Seed every random number generator used below so that\n",
    "# Restart & Run All reproduces the same samples.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)   # PyTorch generator\n",
    "np.random.seed(seed)      # NumPy legacy global RNG\n",
    "random.seed(seed)         # Python standard-library RNG\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "861b0d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oracle: exact marginals p_t of the forward random walk and the derived reverse quantities\n",
    "class OracleScoreModel:\n",
    "    \"\"\"Exact reverse-time quantities for a continuous-time random walk on a grid.\n",
    "\n",
    "    The forward process jumps from each cell of a ``grid_size x grid_size`` grid\n",
    "    to every in-bounds 4-neighbour at rate 1. Its marginals ``p_t`` are\n",
    "    precomputed on ``steps`` evenly spaced times in ``[0, T]`` and looked up at\n",
    "    query time. Cells are flattened row-major: state ``r * grid_size + c``.\n",
    "\n",
    "    Args:\n",
    "        p0: tensor with the initial distribution, indexed ``[row, col]``\n",
    "            (assumed square; ``grid_size`` is taken from ``p0.shape[0]``).\n",
    "        T: time horizon of the forward process.\n",
    "        device: device on which the ``p_t`` table is stored.\n",
    "        steps: number of time points in the precomputed table.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, p0, T, device, steps=1001):\n",
    "        print(\"Oracle Initialization\")\n",
    "        self.grid_size = p0.shape[0]\n",
    "        self.S = self.grid_size * self.grid_size\n",
    "        self.T = T\n",
    "        self.times_cpu = torch.linspace(0, T, steps)\n",
    "        self.device = device\n",
    "        p0_flat = p0.flatten()\n",
    "        Q = self._create_rate_matrix()\n",
    "        print(\"Computing True Distribution p_t...\")\n",
    "        self.p_t_flat = self._solve_forward_process(Q, p0_flat)\n",
    "        print(\"Done\")\n",
    "\n",
    "    def _create_rate_matrix(self):\n",
    "        \"\"\"Sparse generator ``Q`` of the forward walk, used as ``dp/dt = Q @ p``.\n",
    "\n",
    "        ``Q[j, i] = 1`` for each neighbour ``j`` of state ``i`` and ``Q[i, i]`` is\n",
    "        minus the number of neighbours, so every column sums to zero.\n",
    "        \"\"\"\n",
    "        S, size = self.S, self.grid_size\n",
    "        row_indices, col_indices, data = [], [], []\n",
    "        for r in range(size):\n",
    "            for c in range(size):\n",
    "                idx = r * size + c\n",
    "                neighbors = []\n",
    "                if r > 0:\n",
    "                    neighbors.append((r - 1) * size + c)\n",
    "                if r < size - 1:\n",
    "                    neighbors.append((r + 1) * size + c)\n",
    "                if c > 0:\n",
    "                    neighbors.append(r * size + (c - 1))\n",
    "                if c < size - 1:\n",
    "                    neighbors.append(r * size + (c + 1))\n",
    "                for neighbor_idx in neighbors:\n",
    "                    row_indices.append(neighbor_idx)\n",
    "                    col_indices.append(idx)\n",
    "                    data.append(1.0)\n",
    "                row_indices.append(idx)\n",
    "                col_indices.append(idx)\n",
    "                data.append(-len(neighbors))\n",
    "        return csc_matrix((data, (row_indices, col_indices)), shape=(S, S))\n",
    "\n",
    "    def _solve_forward_process(self, Q, p0_flat):\n",
    "        \"\"\"Return ``p_t`` at every time in ``self.times_cpu`` as a ``(steps, S)`` float32 tensor.\n",
    "\n",
    "        ``expm_multiply`` evaluates ``expm(t Q) @ p0`` directly on the evenly spaced\n",
    "        grid ``linspace(0, T, steps)``, so no round-off accumulates from chaining\n",
    "        one-step propagations. ``p0`` is promoted to float64 so every row of the\n",
    "        result shares one dtype before the final float32 cast.\n",
    "        \"\"\"\n",
    "        p0_np = p0_flat.detach().cpu().numpy().astype(np.float64)\n",
    "        p_t = expm_multiply(Q, p0_np, start=0.0, stop=float(self.T),\n",
    "                            num=len(self.times_cpu), endpoint=True)\n",
    "        return torch.from_numpy(np.ascontiguousarray(p_t)).float().to(self.device)\n",
    "\n",
    "    def _time_index(self, t):\n",
    "        \"\"\"Map query time(s) ``t`` to row indices of ``self.p_t_flat`` on ``self.device``.\n",
    "\n",
    "        Picks the first grid time ``>= t`` (``torch.searchsorted``, left side) and\n",
    "        clamps to the last row: float round-off can push a query such as\n",
    "        ``n * h + h`` marginally above ``T``, which would otherwise index one past\n",
    "        the end of the table.\n",
    "        \"\"\"\n",
    "        t_cpu = torch.as_tensor(t).detach().to(device=\"cpu\", dtype=self.times_cpu.dtype)\n",
    "        idx = torch.searchsorted(self.times_cpu, t_cpu)\n",
    "        return idx.clamp(max=len(self.times_cpu) - 1).to(self.device)\n",
    "\n",
    "    def get_intensities(self, t, current_states_flat):\n",
    "        \"\"\"Reverse-process jump intensities from each state to its four neighbours.\n",
    "\n",
    "        The forward rate is 1 per neighbour, so the reverse intensity from ``x`` to\n",
    "        neighbour ``y`` is ``p_t(y) / p_t(x)``.\n",
    "\n",
    "        Args:\n",
    "            t: 1-D tensor of query times, one per state.\n",
    "            current_states_flat: 1-D long tensor of flat state indices.\n",
    "\n",
    "        Returns:\n",
    "            ``(len(current_states_flat), 4)`` tensor; column ``k`` is the move\n",
    "            ``(dr, dc)`` in ``(-1, 0), (1, 0), (0, -1), (0, 1)``. Off-grid moves are 0.\n",
    "        \"\"\"\n",
    "        p_t = self.p_t_flat[self._time_index(t)]\n",
    "        p_t_x = p_t.gather(1, current_states_flat.unsqueeze(1)).squeeze(1)\n",
    "        r = current_states_flat // self.grid_size\n",
    "        c = current_states_flat % self.grid_size\n",
    "        neighbors, masks = [], []\n",
    "        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):\n",
    "            nr, nc = r + dr, c + dc\n",
    "            mask = (nr >= 0) & (nr < self.grid_size) & (nc >= 0) & (nc < self.grid_size)\n",
    "            # Off-grid moves point back at the current state so gather stays in range.\n",
    "            neighbors.append(torch.where(mask, nr * self.grid_size + nc, current_states_flat))\n",
    "            masks.append(mask)\n",
    "        neighbors = torch.stack(neighbors, dim=1)\n",
    "        masks = torch.stack(masks, dim=1)\n",
    "        p_t_y = p_t.gather(1, neighbors)\n",
    "        # 1e-9 keeps the ratio finite where p_t(x) == 0.\n",
    "        intensities = p_t_y / (p_t_x.unsqueeze(1) + 1e-9)\n",
    "        intensities[~masks] = 0\n",
    "        return intensities\n",
    "\n",
    "    def get_scores(self, t, current_states_flat, proposed_states_flat):\n",
    "        \"\"\"Probability ratio ``p_t(proposed) / p_t(current)`` at one time ``t``.\n",
    "\n",
    "        Args:\n",
    "            t: scalar time (Python number or 0-d tensor) shared by all states.\n",
    "            current_states_flat: 1-D long tensor of flat state indices.\n",
    "            proposed_states_flat: 1-D long tensor, same length as ``current_states_flat``.\n",
    "        \"\"\"\n",
    "        p_t = self.p_t_flat[self._time_index(t)]\n",
    "        p_t_current = p_t.gather(0, current_states_flat)\n",
    "        p_t_proposed = p_t.gather(0, proposed_states_flat)\n",
    "        return p_t_proposed / (p_t_current + 1e-9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c636e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper functions: target construction, KL metric and plotting\n",
    "def create_checkerboard(size):\n",
    "    \"\"\"Uniform distribution over the dark squares of a ``size x size`` checkerboard.\n",
    "\n",
    "    Cell ``(r, c)`` carries mass iff ``r + c`` is even; the result sums to 1.\n",
    "    \"\"\"\n",
    "    coords = torch.arange(size)\n",
    "    on_board = (coords.unsqueeze(1) + coords.unsqueeze(0)) % 2 == 0\n",
    "    mask = on_board.to(torch.get_default_dtype())\n",
    "    return mask / mask.sum()\n",
    "\n",
    "def calculate_kl_divergence(p_target, samples, grid_size):\n",
    "    \"\"\"KL(p_target || q), where q is the empirical distribution of ``samples``.\n",
    "\n",
    "    Args:\n",
    "        p_target: ``(grid_size, grid_size)`` array indexed ``[row, col]``.\n",
    "        samples: ``(n, 2)`` array of integer-valued ``(row, col)`` positions.\n",
    "        grid_size: side length of the grid.\n",
    "\n",
    "    Returns:\n",
    "        The divergence as a Python float (``nan`` if no sample lies on the grid).\n",
    "        A tiny epsilon is added to both distributions so empty cells do not\n",
    "        produce ``log(0)``.\n",
    "    \"\"\"\n",
    "    samples = np.asarray(samples)\n",
    "    edges = [-0.5, grid_size - 0.5]\n",
    "    # histogram2d bins its first argument along axis 0, so passing (row, col)\n",
    "    # gives q_hist[r, c], aligned with p_target[r, c].\n",
    "    q_hist, _, _ = np.histogram2d(samples[:, 0], samples[:, 1], bins=grid_size, range=[edges, edges])\n",
    "    total = q_hist.sum()\n",
    "    if total == 0:\n",
    "        return float(\"nan\")\n",
    "    epsilon = 1e-12\n",
    "    p = np.asarray(p_target, dtype=np.float64).ravel() + epsilon\n",
    "    q = (q_hist / total).ravel() + epsilon\n",
    "    return float(np.sum(p * (np.log(p) - np.log(q))))\n",
    "\n",
    "def plot_distributions_comparison(p0, samples_parallel, samples_seq, grid_size):\n",
    "    \"\"\"Show the target distribution next to the sample histograms of both samplers.\n",
    "\n",
    "    All three panels are indexed ``[row, col]`` and drawn with ``origin='lower'``\n",
    "    so they are directly comparable. ``samples_*`` are ``(n, 2)`` arrays of\n",
    "    ``(row, col)`` positions.\n",
    "    \"\"\"\n",
    "    edges = [-0.5, grid_size - 0.5]\n",
    "\n",
    "    def _histogram(samples):\n",
    "        # First histogram2d argument -> axis 0, so (row, col) yields hist[r, c] like p0[r, c].\n",
    "        samples = np.asarray(samples)\n",
    "        hist, _, _ = np.histogram2d(samples[:, 0], samples[:, 1], bins=grid_size, range=[edges, edges])\n",
    "        return hist\n",
    "\n",
    "    panels = [\n",
    "        (p0, \"Target Distribution\"),\n",
    "        (_histogram(samples_parallel), \"Picard+Corr\"),\n",
    "        (_histogram(samples_seq), \"Seq+Corr\"),\n",
    "    ]\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "    for ax, (image, title) in zip(axes, panels):\n",
    "        ax.imshow(image, cmap='viridis', origin='lower')\n",
    "        ax.set_title(title)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96444f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential sampler: M tau-leaping steps per block, then a Metropolis-Hastings corrector\n",
    "class SequentialSamplerWithCorrector:\n",
    "    \"\"\"Reverse-time tau-leaping sampler with a Metropolis-Hastings corrector.\n",
    "\n",
    "    ``[0, T]`` is split into ``N`` blocks of length ``h = T / N``. Each block is\n",
    "    integrated backwards with ``M`` sequential tau-leaping steps of size\n",
    "    ``dt = h / M`` and then refined by ``K_c`` corrector steps that target\n",
    "    ``p_t`` at the start of the block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, grid_size, N, M, K_c, device):\n",
    "        self.model = model\n",
    "        self.grid_size = grid_size\n",
    "        self.N = N      # number of blocks\n",
    "        self.M = M      # tau-leaping steps per block\n",
    "        self.K_c = K_c  # corrector steps per block\n",
    "        self.T = model.T\n",
    "        self.h = self.T / N   # block length\n",
    "        self.dt = self.h / M  # tau-leaping step size\n",
    "        self.device = device\n",
    "        # (row, col) displacement of each move; must follow the move order of model.get_intensities.\n",
    "        self.jump_vectors = torch.tensor(\n",
    "            [[-1, 0], [1, 0], [0, -1], [0, 1]], dtype=torch.float32, device=self.device\n",
    "        )\n",
    "\n",
    "    def _flat_index(self, positions):\n",
    "        \"\"\"Row-major flat state index of float ``(B, 2)`` (row, col) positions.\"\"\"\n",
    "        return (positions[:, 0] * self.grid_size + positions[:, 1]).long()\n",
    "\n",
    "    def _run_corrector(self, y_pred, t):\n",
    "        \"\"\"Apply ``K_c`` Metropolis-Hastings steps targeting ``p_t``.\n",
    "\n",
    "        Each proposal moves one cell in a uniformly random direction, clamped to\n",
    "        the grid. That proposal is symmetric, which is why the acceptance\n",
    "        probability reduces to ``min(1, p_t(proposal) / p_t(current))``.\n",
    "        \"\"\"\n",
    "        y = y_pred\n",
    "        for _ in range(self.K_c):\n",
    "            batch = y.shape[0]\n",
    "            picks = torch.randint(0, 4, (batch,), device=self.device)\n",
    "            y_prop = (y + self.jump_vectors[picks]).clamp(0, self.grid_size - 1)\n",
    "            here = self._flat_index(y)\n",
    "            there = self._flat_index(y_prop)\n",
    "            moved = here != there\n",
    "\n",
    "            ratio = torch.ones(batch, device=self.device)\n",
    "            if torch.any(moved):\n",
    "                # Proposals clamped back onto the current cell keep ratio 1.\n",
    "                ratio[moved] = self.model.get_scores(t, here[moved], there[moved])\n",
    "\n",
    "            accept_prob = torch.minimum(torch.ones_like(ratio), ratio)\n",
    "            accepted = torch.rand_like(accept_prob) < accept_prob\n",
    "            y = torch.where(accepted.unsqueeze(1), y_prop, y)\n",
    "        return y\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw ``batch_size`` samples as a ``(batch_size, 2)`` NumPy array of (row, col).\"\"\"\n",
    "        y_block = torch.randint(0, self.grid_size, (batch_size, 2), device=self.device).float()\n",
    "\n",
    "        for n in reversed(range(self.N)):\n",
    "            t_lo = n * self.h\n",
    "            t_hi = t_lo + self.h\n",
    "            print(f\"\\n--- Seq sampling: Processing Block {self.N-n}/{self.N} ---\")\n",
    "\n",
    "            # Predictor: M tau-leaping steps walking backwards from the block end;\n",
    "            # each step uses the intensity at its own (later) start time.\n",
    "            y = y_block\n",
    "            step_times = torch.linspace(t_hi, t_lo + self.dt, self.M, device=self.device)\n",
    "            for i, t in enumerate(step_times):\n",
    "                if (i+1) % 10 == 0:\n",
    "                    print(f\"\\r  Predictor Step {i+1}/{self.M}\", end=\"\")\n",
    "                intensities = self.model.get_intensities(t.expand(batch_size), self._flat_index(y))\n",
    "                jumps = torch.poisson(intensities * self.dt)  # one jump count per move\n",
    "                displacement = (jumps.unsqueeze(-1) * self.jump_vectors.unsqueeze(0)).sum(dim=1)\n",
    "                y = (y + displacement).clamp(0, self.grid_size - 1)\n",
    "\n",
    "            print(\"\\n Running Corrector\")\n",
    "            y_block = self._run_corrector(y, t_lo)\n",
    "\n",
    "        print(\"\\n Seq Done\")\n",
    "        return y_block.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0779b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parallel sampler: Picard-iteration predictor per block, then the same corrector\n",
    "class ParallelSamplerWithCorrector:\n",
    "    \"\"\"Block-wise sampler whose predictor runs Picard iterations over each block.\n",
    "\n",
    "    Rather than ``M`` sequential tau-leaping steps, each block evaluates the\n",
    "    intensities of all ``M`` sub-steps in one batched model call on the previous\n",
    "    iterate of the block trajectory, repeated ``K_p`` times. A Metropolis-Hastings\n",
    "    corrector with ``K_c`` steps follows, as in the sequential sampler.\n",
    "\n",
    "    NOTE(review): fresh Poisson noise is drawn on every Picard iteration, and\n",
    "    positions are clamped to the grid only after the cumulative sum of jumps\n",
    "    (not after each sub-step). Confirm both match the intended algorithm.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, grid_size, N, M, K_p, K_c, device):\n",
    "        self.model = model\n",
    "        self.grid_size = grid_size\n",
    "        self.N = N      # number of blocks\n",
    "        self.M = M      # sub-steps per block\n",
    "        self.K_p = K_p  # Picard depth (predictor iterations per block)\n",
    "        self.K_c = K_c  # corrector steps per block\n",
    "        self.device = device\n",
    "        self.T = model.T\n",
    "        self.h = model.T / N            # block length\n",
    "        self.epsilon = model.T / N / M  # sub-step size\n",
    "        self.jump_vectors = torch.tensor(\n",
    "            [[-1, 0], [1, 0], [0, -1], [0, 1]], dtype=torch.float32, device=self.device\n",
    "        )\n",
    "        # All-ones vector reused by the corrector; rebuilt lazily when the batch size changes.\n",
    "        self._ones_for_min = None\n",
    "        # Sub-step times relative to the block start, from the block end down to epsilon.\n",
    "        self._time_offsets = torch.linspace(self.h, self.epsilon, M, device=self.device)\n",
    "\n",
    "    def _run_predictor(self, y_t_n, t_start_block):\n",
    "        \"\"\"Run ``K_p`` Picard iterations over one block and return the block-end positions.\n",
    "\n",
    "        ``path[:, m]`` is the position entering sub-step ``m`` (``path[:, 0]`` is the\n",
    "        block start ``y_t_n``); sub-step ``m`` uses the intensity at time\n",
    "        ``t_start_block + self._time_offsets[m]``.\n",
    "        \"\"\"\n",
    "        batch = y_t_n.shape[0]\n",
    "        path = y_t_n.unsqueeze(1).repeat(1, self.M + 1, 1)\n",
    "        # Sub-step times tiled once per sample, matching the flattened (batch, M) states.\n",
    "        times = (t_start_block + self._time_offsets).repeat(batch)\n",
    "\n",
    "        for _ in range(self.K_p):\n",
    "            starts = path[:, :-1, :]\n",
    "            flat = (starts[:, :, 0] * self.grid_size + starts[:, :, 1]).long().reshape(-1)\n",
    "            intensities = self.model.get_intensities(times, flat)\n",
    "            jumps = torch.poisson(intensities.reshape(batch, self.M, 4) * self.epsilon)\n",
    "            # (batch, M, 4) jump counts x (4, 2) moves -> per-sub-step displacement, then running total.\n",
    "            offsets = torch.einsum('bmj,jd->bmd', jumps, self.jump_vectors).cumsum(dim=1)\n",
    "            start = y_t_n.unsqueeze(1)\n",
    "            path = torch.cat([start, start + offsets], dim=1).clamp(0, self.grid_size - 1)\n",
    "\n",
    "        return path[:, -1, :].clone()\n",
    "\n",
    "    def _run_corrector(self, y_pred, t):\n",
    "        \"\"\"Apply ``K_c`` Metropolis-Hastings steps targeting ``p_t``.\n",
    "\n",
    "        Same scheme as the sequential sampler: propose one clamped move in a uniformly\n",
    "        random direction (a symmetric proposal) and accept with\n",
    "        ``min(1, p_t(proposal) / p_t(current))``.\n",
    "        \"\"\"\n",
    "        batch = y_pred.shape[0]\n",
    "        if self._ones_for_min is None or self._ones_for_min.shape[0] != batch:\n",
    "            self._ones_for_min = torch.ones(batch, device=self.device)\n",
    "        ones = self._ones_for_min\n",
    "\n",
    "        y = y_pred\n",
    "        for _ in range(self.K_c):\n",
    "            picks = torch.randint(0, 4, (batch,), device=self.device)\n",
    "            y_prop = (y + self.jump_vectors[picks]).clamp(0, self.grid_size - 1)\n",
    "            here = y[:, 0].long() * self.grid_size + y[:, 1].long()\n",
    "            there = y_prop[:, 0].long() * self.grid_size + y_prop[:, 1].long()\n",
    "            moved = here != there\n",
    "\n",
    "            if torch.any(moved):\n",
    "                ratio = ones.clone()\n",
    "                ratio[moved] = self.model.get_scores(t, here[moved], there[moved])\n",
    "                accept_prob = torch.minimum(ones, ratio)\n",
    "            else:\n",
    "                # Every proposal was clamped back onto its current cell.\n",
    "                accept_prob = ones\n",
    "\n",
    "            accepted = torch.rand(batch, device=self.device) < accept_prob\n",
    "            y = torch.where(accepted.unsqueeze(1), y_prop, y)\n",
    "        return y\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw ``batch_size`` samples as a ``(batch_size, 2)`` NumPy array of (row, col).\"\"\"\n",
    "        # Uniform integer start cells, drawn as floor(U * grid_size) with U ~ U[0, 1).\n",
    "        y = (torch.rand(batch_size, 2, device=self.device) * self.grid_size).floor()\n",
    "\n",
    "        for n in reversed(range(self.N)):\n",
    "            t_start_block = n * self.h\n",
    "            print(f\"\\r Para Sampling... Processing Block {self.N-n}/{self.N}\", end=\"\", flush=True)\n",
    "            y_pred = self._run_predictor(y, t_start_block)\n",
    "            y = self._run_corrector(y_pred, t_start_block)\n",
    "\n",
    "        print(\"\\n Para Done\")\n",
    "        return y.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "868e1309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters for the single-run experiment\n",
    "GRID_SIZE = 8         # side length of the square state grid\n",
    "TOTAL_TIME = 2.0      # time horizon T of the forward process\n",
    "BATCH_SIZE = 8192     # samples drawn per sampler run\n",
    "\n",
    "# Shared by both samplers\n",
    "N_BLOCKS = 40         # number of blocks N\n",
    "STEPS_PER_BLOCK = 50  # tau-leaping sub-steps M per block\n",
    "K_CORRECTOR = 5       # Metropolis-Hastings corrector steps per block\n",
    "\n",
    "# Parallel sampler only\n",
    "K_PREDICTOR = 5       # Picard iterations per block\n",
    "\n",
    "DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(f\"Device: {DEVICE}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fadd023",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: one timed run of each sampler against the exact oracle.\n",
    "# Create target distribution and oracle\n",
    "p0_numpy = create_checkerboard(GRID_SIZE).numpy()\n",
    "p0_tensor = torch.from_numpy(p0_numpy).to(DEVICE)\n",
    "oracle_model = OracleScoreModel(p0_tensor, T=TOTAL_TIME, device=DEVICE)\n",
    "\n",
    "# Run Para Sampler.\n",
    "# perf_counter is monotonic and high-resolution, unlike time.time(). sample() returns via\n",
    "# .cpu(), a blocking device-to-host copy, so queued GPU work finishes before the timer stops.\n",
    "print(\"=\"*50)\n",
    "print(\"Running Para Sampling w/ corr...\")\n",
    "parallel_sampler = ParallelSamplerWithCorrector(\n",
    "    oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK,\n",
    "    K_PREDICTOR, K_CORRECTOR, DEVICE\n",
    ")\n",
    "start_p = time.perf_counter()\n",
    "samples_p = parallel_sampler.sample(BATCH_SIZE)\n",
    "end_p = time.perf_counter()\n",
    "time_p = end_p - start_p\n",
    "kl_p = calculate_kl_divergence(p0_numpy, samples_p, GRID_SIZE)\n",
    "\n",
    "# Run Seq Sampler\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Running Seq Sampling w/ corr...\")\n",
    "sequential_sampler = SequentialSamplerWithCorrector(\n",
    "    oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE\n",
    ")\n",
    "start_s = time.perf_counter()\n",
    "samples_s = sequential_sampler.sample(BATCH_SIZE)\n",
    "end_s = time.perf_counter()\n",
    "time_s = end_s - start_s\n",
    "kl_s = calculate_kl_divergence(p0_numpy, samples_s, GRID_SIZE)\n",
    "\n",
    "# Sequential stages per sampler: the Picard predictor replaces M tau-leaping\n",
    "# steps with K_p batched iterations per block; both add K_c corrector steps.\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "\n",
    "# Result table (the horizontal rule is sized to the rows)\n",
    "header = f\"| {'Score':<18} | {'Picard + corr':<25} | {'Seq + corr':<20} |\"\n",
    "rule = \"-\" * len(header)\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Result\")\n",
    "print(rule)\n",
    "print(header)\n",
    "print(f\"| {'-'*18} | {'-'*25} | {'-'*20} |\")\n",
    "print(f\"| {'Runtime':<18} | {time_p:<25.2f} | {time_s:<20.2f} |\")\n",
    "print(f\"| {'KL Divergence':<18} | {kl_p:<25.4f} | {kl_s:<20.4f} |\")\n",
    "print(f\"| {'Sequential Stages':<18} | {seq_stages_p:<25} | {seq_stages_s:<20} |\")\n",
    "print(rule)\n",
    "\n",
    "# Visualization\n",
    "plot_distributions_comparison(p0_numpy, samples_p, samples_s, GRID_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd0fc317",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_RUNS = 20  # independent repetitions per sampler\n",
    "\n",
    "# Parameters (redefined so this cell runs on its own).\n",
    "# NOTE: K_PREDICTOR is 12 here, versus 5 in the single-run parameter cell.\n",
    "GRID_SIZE = 8\n",
    "TOTAL_TIME = 2.0\n",
    "BATCH_SIZE = 8192\n",
    "N_BLOCKS = 40\n",
    "STEPS_PER_BLOCK = 50\n",
    "K_CORRECTOR = 5\n",
    "K_PREDICTOR = 12     # Picard depth\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {DEVICE}\\n\")\n",
    "# The summary and the final plot use the last run's samples; with zero runs they would\n",
    "# silently come from whatever an earlier cell left in the kernel.\n",
    "assert NUM_RUNS >= 1, \"NUM_RUNS must be at least 1\"\n",
    "\n",
    "# Create target distribution and oracle\n",
    "p0_numpy = create_checkerboard(GRID_SIZE).numpy()\n",
    "p0_tensor = torch.from_numpy(p0_numpy).to(DEVICE)\n",
    "\n",
    "oracle_model = OracleScoreModel(p0_tensor, T=TOTAL_TIME, device=DEVICE)\n",
    "\n",
    "\n",
    "def _timed_sample(sampler):\n",
    "    \"\"\"Run ``sampler.sample(BATCH_SIZE)``; return ``(samples, seconds, KL vs target)``.\n",
    "\n",
    "    Uses the monotonic ``time.perf_counter``. ``sample()`` returns via ``.cpu()``, a\n",
    "    blocking device-to-host copy, so queued GPU work is included in the timing.\n",
    "    \"\"\"\n",
    "    start = time.perf_counter()\n",
    "    samples = sampler.sample(BATCH_SIZE)\n",
    "    elapsed = time.perf_counter() - start\n",
    "    return samples, elapsed, calculate_kl_divergence(p0_numpy, samples, GRID_SIZE)\n",
    "\n",
    "\n",
    "times_p_list = []\n",
    "kls_p_list = []\n",
    "times_s_list = []\n",
    "kls_s_list = []\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Running Parallel Sampler w/ Corrector for {NUM_RUNS} runs...\")\n",
    "for i in range(NUM_RUNS):\n",
    "    print(f\"\\n--- Parallel Run {i+1}/{NUM_RUNS} ---\")\n",
    "    parallel_sampler = ParallelSamplerWithCorrector(\n",
    "        oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK,\n",
    "        K_PREDICTOR, K_CORRECTOR, DEVICE)\n",
    "    samples_p, time_p, kl_p = _timed_sample(parallel_sampler)\n",
    "    times_p_list.append(time_p)\n",
    "    kls_p_list.append(kl_p)\n",
    "    print(f\"Run {i+1} finished. Time: {time_p:.2f}s, KL: {kl_p:.4f}\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Running Sequential Sampler w/ Corrector for {NUM_RUNS} runs...\")\n",
    "for i in range(NUM_RUNS):\n",
    "    print(f\"\\n--- Sequential Run {i+1}/{NUM_RUNS} ---\")\n",
    "    sequential_sampler = SequentialSamplerWithCorrector(\n",
    "        oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE  # K_PREDICTOR is not used by seq.\n",
    "    )\n",
    "    samples_s, time_s, kl_s = _timed_sample(sequential_sampler)\n",
    "    times_s_list.append(time_s)\n",
    "    kls_s_list.append(kl_s)\n",
    "    print(f\"Run {i+1} finished. Time: {time_s:.2f}s, KL: {kl_s:.4f}\")\n",
    "\n",
    "# np.std is the population standard deviation (ddof=0).\n",
    "avg_time_p = np.mean(times_p_list)\n",
    "std_time_p = np.std(times_p_list)\n",
    "avg_kl_p = np.mean(kls_p_list)\n",
    "std_kl_p = np.std(kls_p_list)\n",
    "\n",
    "avg_time_s = np.mean(times_s_list)\n",
    "std_time_s = np.std(times_s_list)\n",
    "avg_kl_s = np.mean(kls_s_list)\n",
    "std_kl_s = np.std(kls_s_list)\n",
    "\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "\n",
    "result_p_time = f\"{avg_time_p:.2f} ± {std_time_p:.2f}\"\n",
    "result_p_kl = f\"{avg_kl_p:.4f} ± {std_kl_p:.4f}\"\n",
    "result_s_time = f\"{avg_time_s:.2f} ± {std_time_s:.2f}\"\n",
    "result_s_kl = f\"{avg_kl_s:.4f} ± {std_kl_s:.4f}\"\n",
    "\n",
    "# Result table (the horizontal rule is sized to the rows)\n",
    "header = f\"| {'Score':<18} | {'Picard + Corrector':<28} | {'Sequential + Corrector':<24} |\"\n",
    "rule = \"-\" * len(header)\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Final Averaged Results ({NUM_RUNS} runs)\")\n",
    "print(rule)\n",
    "print(header)\n",
    "print(f\"| {'-'*18} | {'-'*28} | {'-'*24} |\")\n",
    "print(f\"| {'Runtime (s)':<18} | {result_p_time:<28} | {result_s_time:<24} |\")\n",
    "print(f\"| {'KL Divergence':<18} | {result_p_kl:<28} | {result_s_kl:<24} |\")\n",
    "print(f\"| {'Sequential Stages':<18} | {seq_stages_p:<28} | {seq_stages_s:<24} |\")\n",
    "print(rule)\n",
    "\n",
    "print(\"\\nVisualizing the last run's results...\")\n",
    "plot_distributions_comparison(p0_numpy, samples_p, samples_s, GRID_SIZE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
