{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac679cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.sparse import csc_matrix, eye\n",
    "from scipy.sparse.linalg import expm_multiply\n",
    "import time\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "# Seed every RNG the notebook draws from so Restart & Run All reproduces the results.\n",
    "seed = 42\n",
    "for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):\n",
    "    _seed_rng(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223458bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build Oracle\n",
    "class OracleScoreModel:\n",
    "    \"\"\"Exact score oracle for a continuous-time random walk on a grid_size x grid_size lattice.\n",
    "\n",
    "    The forward process jumps to each in-grid lattice neighbour at rate 1. Its marginals\n",
    "    p_t are precomputed on `steps` equally spaced times in [0, T] by repeatedly applying\n",
    "    the one-step propagator exp(dt * Q) to p0, and are then looked up by the samplers.\n",
    "\n",
    "    Args:\n",
    "        p0: (grid_size, grid_size) tensor with the distribution at t = 0, indexed [row, col].\n",
    "        T: final time of the forward process.\n",
    "        device: device on which the table of marginals is stored.\n",
    "        steps: number of time-grid points, including t = 0 and t = T.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, p0, T, device, steps=1001):\n",
    "        print(\"Oracle Initialization\")\n",
    "        self.grid_size = p0.shape[0]\n",
    "        self.S = self.grid_size * self.grid_size\n",
    "        self.T = T\n",
    "        self.times_cpu = torch.linspace(0, T, steps)\n",
    "        self.device = device\n",
    "        p0_flat = p0.flatten()\n",
    "        Q = self._create_rate_matrix()\n",
    "        print(\"Computing True Distribution p_t...\")\n",
    "        # p_t_flat[k, s] = probability of flat state s = row * grid_size + col at times_cpu[k].\n",
    "        self.p_t_flat = self._solve_forward_process(Q, p0_flat)\n",
    "        print(\"Done\")\n",
    "\n",
    "    def _create_rate_matrix(self):\n",
    "        \"\"\"Return the sparse (S, S) generator Q of the forward walk.\n",
    "\n",
    "        Q[j, i] = 1 for every in-grid neighbour j of state i and Q[i, i] = -deg(i), so each\n",
    "        column sums to zero and the marginals evolve as dp/dt = Q p.\n",
    "        \"\"\"\n",
    "        size = self.grid_size\n",
    "        row_indices, col_indices, data = [], [], []\n",
    "        for r in range(size):\n",
    "            for c in range(size):\n",
    "                idx = r * size + c\n",
    "                neighbors = []\n",
    "                if r > 0:\n",
    "                    neighbors.append((r - 1) * size + c)\n",
    "                if r < size - 1:\n",
    "                    neighbors.append((r + 1) * size + c)\n",
    "                if c > 0:\n",
    "                    neighbors.append(r * size + (c - 1))\n",
    "                if c < size - 1:\n",
    "                    neighbors.append(r * size + (c + 1))\n",
    "                for neighbor_idx in neighbors:\n",
    "                    row_indices.append(neighbor_idx)\n",
    "                    col_indices.append(idx)\n",
    "                    data.append(1.0)\n",
    "                row_indices.append(idx)\n",
    "                col_indices.append(idx)\n",
    "                data.append(-len(neighbors))\n",
    "        return csc_matrix((data, (row_indices, col_indices)), shape=(self.S, self.S))\n",
    "\n",
    "    def _solve_forward_process(self, Q, p0_flat):\n",
    "        \"\"\"Return p_t at every point of `times_cpu` as a float32 (steps, S) tensor on `self.device`.\"\"\"\n",
    "        dt = (self.times_cpu[1] - self.times_cpu[0]).item()\n",
    "        # One-step propagator exp(dt * Q), built once and applied repeatedly.\n",
    "        exp_dtQ = expm_multiply(Q * dt, eye(self.S))\n",
    "        # Propagate in float64: every snapshot then has the same dtype for torch.stack,\n",
    "        # and round-off does not build up in float32 over (steps - 1) products.\n",
    "        p_current = p0_flat.detach().cpu().numpy().astype(np.float64)\n",
    "        p_t_list = []\n",
    "        for _ in range(len(self.times_cpu)):\n",
    "            p_t_list.append(torch.from_numpy(p_current))\n",
    "            p_current = exp_dtQ.dot(p_current)\n",
    "        return torch.stack(p_t_list, dim=0).float().to(self.device)\n",
    "\n",
    "    def _time_index(self, t):\n",
    "        \"\"\"Map time(s) t to indices into `times_cpu` (first grid point >= t), clamped to the grid.\n",
    "\n",
    "        Accepts a Python number or a tensor of times; the result lives on `self.device`.\n",
    "        \"\"\"\n",
    "        t_cpu = t.detach().cpu() if isinstance(t, torch.Tensor) else torch.tensor(t)\n",
    "        idx = torch.searchsorted(self.times_cpu, t_cpu.to(self.times_cpu.dtype))\n",
    "        # Round-off can push t marginally past T, where searchsorted returns len(times_cpu);\n",
    "        # clamp so indexing p_t_flat never goes out of bounds.\n",
    "        return idx.clamp(max=len(self.times_cpu) - 1).to(self.device)\n",
    "\n",
    "    def get_intensities(self, t, current_states_flat):\n",
    "        \"\"\"Reverse-time jump intensities from each state towards its four lattice neighbours.\n",
    "\n",
    "        Args:\n",
    "            t: 1-D tensor of B times, one per state.\n",
    "            current_states_flat: (B,) long tensor of flat states row * grid_size + col.\n",
    "\n",
    "        Returns:\n",
    "            (B, 4) tensor with p_t(y) / p_t(x) for the moves (-1, 0), (+1, 0), (0, -1), (0, +1)\n",
    "            in (row, col); moves that would leave the grid get intensity 0.\n",
    "        \"\"\"\n",
    "        p_t = self.p_t_flat[self._time_index(t)]\n",
    "        p_t_x = p_t.gather(1, current_states_flat.unsqueeze(1)).squeeze(1)\n",
    "        size = self.grid_size\n",
    "        r, c = current_states_flat // size, current_states_flat % size\n",
    "        neighbors, masks = [], []\n",
    "        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):\n",
    "            nr, nc = r + dr, c + dc\n",
    "            mask = (nr >= 0) & (nr < size) & (nc >= 0) & (nc < size)\n",
    "            # Off-grid moves point back at x only to keep gather in range; they are zeroed below.\n",
    "            neighbors.append(torch.where(mask, nr * size + nc, current_states_flat))\n",
    "            masks.append(mask)\n",
    "        neighbors = torch.stack(neighbors, dim=1)\n",
    "        masks = torch.stack(masks, dim=1)\n",
    "        p_t_y = p_t.gather(1, neighbors)\n",
    "        # The 1e-9 guards against division by zero where p_t(x) vanishes.\n",
    "        intensities = p_t_y / (p_t_x.unsqueeze(1) + 1e-9)\n",
    "        intensities[~masks] = 0\n",
    "        return intensities\n",
    "\n",
    "    def get_scores(self, t, current_states_flat, proposed_states_flat):\n",
    "        \"\"\"Return p_t(proposed) / p_t(current) element-wise for a single time t (number or 0-d tensor).\"\"\"\n",
    "        p_t = self.p_t_flat[self._time_index(t)]\n",
    "        p_t_current = p_t.gather(0, current_states_flat)\n",
    "        p_t_proposed = p_t.gather(0, proposed_states_flat)\n",
    "        return p_t_proposed / (p_t_current + 1e-9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fc15829",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Helper Functions\n",
    "def create_ring_distribution(size, r_inner, r_outer):\n",
    "    \"\"\"Uniform distribution over the cells whose distance to the grid center lies in [r_inner, r_outer).\n",
    "\n",
    "    The center is (size / 2, size / 2) in index coordinates. Returns a (size, size) float32\n",
    "    tensor summing to 1; raises ValueError when no cell falls inside the annulus.\n",
    "    \"\"\"\n",
    "    coords = np.arange(size)\n",
    "    cols, rows = np.meshgrid(coords, coords)\n",
    "    mid = size / 2\n",
    "    radius = np.sqrt((cols - mid)**2 + (rows - mid)**2)\n",
    "    inside = (radius >= r_inner) & (radius < r_outer)\n",
    "    if not inside.any():\n",
    "        raise ValueError(\"Parameter Setting Error\")\n",
    "    ring = inside.astype(np.float64)\n",
    "    return torch.from_numpy(ring / ring.sum()).float()\n",
    "\n",
    "def calculate_kl_divergence(p_target, samples, grid_size):\n",
    "    \"\"\"Return KL(p_target || q), where q is the empirical histogram of `samples` on the grid.\n",
    "\n",
    "    Args:\n",
    "        p_target: (grid_size, grid_size) array (or CPU tensor) indexed [row, col].\n",
    "        samples: (n, 2) array of (row, col) coordinates, i.e. the same layout the samplers\n",
    "            flatten as row * grid_size + col.\n",
    "        grid_size: side length of the grid.\n",
    "\n",
    "    Returns:\n",
    "        The divergence as a Python float. Both distributions are offset by 1e-12 so empty\n",
    "        histogram cells do not produce infinities.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no sample falls on the grid, so q would be undefined.\n",
    "    \"\"\"\n",
    "    samples = np.asarray(samples)\n",
    "    edges = [-0.5, grid_size - 0.5]\n",
    "    # The row coordinate (samples[:, 0]) must be the FIRST histogram axis so that\n",
    "    # q_hist[r, c] lines up with p_target[r, c]; passing (col, row) transposes q.\n",
    "    q_hist, _, _ = np.histogram2d(samples[:, 0], samples[:, 1], bins=grid_size, range=[edges, edges])\n",
    "    total = q_hist.sum()\n",
    "    if total == 0:\n",
    "        raise ValueError(\"no samples fall inside the grid; cannot build the empirical distribution\")\n",
    "    epsilon = 1e-12\n",
    "    p = np.asarray(p_target, dtype=np.float64).ravel() + epsilon\n",
    "    q = (q_hist / total).ravel() + epsilon\n",
    "    return float(np.sum(p * (np.log(p) - np.log(q))))\n",
    "\n",
    "def plot_distributions_comparison(p0, samples_parallel, samples_seq, grid_size):\n",
    "    \"\"\"Plot the target next to the empirical histograms of the two samplers' outputs.\n",
    "\n",
    "    Args:\n",
    "        p0: (grid_size, grid_size) target indexed [row, col].\n",
    "        samples_parallel: (n, 2) array of (row, col) samples from the parallel sampler.\n",
    "        samples_seq: (n, 2) array of (row, col) samples from the sequential sampler.\n",
    "        grid_size: side length of the grid.\n",
    "    \"\"\"\n",
    "    edges = [-0.5, grid_size - 0.5]\n",
    "\n",
    "    def _histogram(samples):\n",
    "        # Row coordinate first so hist[r, c] is drawn with the same orientation as p0[r, c].\n",
    "        hist, _, _ = np.histogram2d(samples[:, 0], samples[:, 1], bins=grid_size, range=[edges, edges])\n",
    "        return hist\n",
    "\n",
    "    panels = [\n",
    "        (p0, \"Target\"),\n",
    "        (_histogram(samples_parallel), \"Parallel+corr\"),\n",
    "        (_histogram(samples_seq), \"Seq+corr\"),\n",
    "    ]\n",
    "    _, axes = plt.subplots(1, len(panels), figsize=(18, 6))\n",
    "    for ax, (image, title) in zip(axes, panels):\n",
    "        ax.imshow(image, cmap='viridis', origin='lower')\n",
    "        ax.set_title(title)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1db15a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sequential\n",
    "class SequentialSamplerWithCorrector:\n",
    "    \"\"\"Reverse-time tau-leaping sampler that crosses the N blocks one small step at a time.\n",
    "\n",
    "    Each block of length h = T / N is traversed with M Poisson (tau-leaping) steps of size\n",
    "    dt = h / M driven by the model's intensities, then K_c Metropolis-Hastings corrector\n",
    "    moves are applied at the block's start time.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, grid_size, N, M, K_c, device):\n",
    "        self.model = model\n",
    "        self.grid_size = grid_size\n",
    "        self.N = N\n",
    "        self.M = M\n",
    "        self.K_c = K_c\n",
    "        self.device = device\n",
    "        self.T = model.T\n",
    "        self.h = model.T / N\n",
    "        self.dt = model.T / N / M\n",
    "        # (row, col) lattice moves; the order must match the neighbour order of model.get_intensities.\n",
    "        self.jump_vectors = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1]], dtype=torch.float32, device=self.device)\n",
    "\n",
    "    def _run_corrector(self, y_pred, t):\n",
    "        \"\"\"Apply K_c Metropolis-Hastings moves to every chain at time t and return the new states.\n",
    "\n",
    "        Each move proposes one of the four lattice neighbours uniformly (clamped to the grid)\n",
    "        and accepts it with probability min(1, score), score = model.get_scores(t, current, proposal).\n",
    "        \"\"\"\n",
    "        size = self.grid_size\n",
    "        state = y_pred\n",
    "        for _ in range(self.K_c):\n",
    "            n_chains = state.shape[0]\n",
    "            pick = torch.randint(0, 4, (n_chains,), device=self.device)\n",
    "            candidate = (state + self.jump_vectors[pick]).clamp(0, size - 1)\n",
    "            src = (state[:, 0] * size + state[:, 1]).long()\n",
    "            dst = (candidate[:, 0] * size + candidate[:, 1]).long()\n",
    "            moved = src != dst\n",
    "            # Proposals clamped back onto the current cell keep ratio 1 (nothing to evaluate).\n",
    "            ratio = torch.ones(n_chains, device=self.device)\n",
    "            if moved.any():\n",
    "                ratio[moved] = self.model.get_scores(t, src[moved], dst[moved])\n",
    "            accept_prob = torch.minimum(torch.ones_like(ratio), ratio)\n",
    "            accepted = torch.rand_like(accept_prob) < accept_prob\n",
    "            state = torch.where(accepted.unsqueeze(1), candidate, state)\n",
    "        return state\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw batch_size samples; returns a (batch_size, 2) numpy array of (row, col) coordinates.\"\"\"\n",
    "        size = self.grid_size\n",
    "        y_block = torch.randint(0, size, (batch_size, 2), device=self.device).float()\n",
    "        for block in reversed(range(self.N)):\n",
    "            t_lo = block * self.h\n",
    "            t_hi = t_lo + self.h\n",
    "            print(f\"\\n--- Seq Sampling: Processing Block {self.N-block}/{self.N} ---\")\n",
    "            y_pred = y_block\n",
    "            # Reverse-time step start times t_hi, t_hi - dt, ..., t_lo + dt.\n",
    "            time_grid = torch.linspace(t_hi, t_lo + self.dt, self.M, device=self.device)\n",
    "            for step, t in enumerate(time_grid, start=1):\n",
    "                if step % 10 == 0:\n",
    "                    print(f\"\\r  Predictor step {step}/{self.M}\", end=\"\")\n",
    "                flat = (y_pred[:, 0] * size + y_pred[:, 1]).long()\n",
    "                rates = self.model.get_intensities(t.expand(batch_size), flat) * self.dt\n",
    "                counts = torch.poisson(rates)\n",
    "                displacement = (counts.unsqueeze(-1) * self.jump_vectors.unsqueeze(0)).sum(dim=1)\n",
    "                y_pred = (y_pred + displacement).clamp(0, size - 1)\n",
    "            print(\"\\n  Running Corrector \")\n",
    "            y_block = self._run_corrector(y_pred, t_lo)\n",
    "        print(\"\\n Seq Done\")\n",
    "        return y_block.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd95b16d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParallelSamplerWithCorrector:\n",
    "    \"\"\"Block-wise reverse sampler whose predictor updates all M steps of a block at once.\n",
    "\n",
    "    Each of the N blocks (length h = T / N) is split into M steps of size epsilon = h / M.\n",
    "    The predictor performs K_p Picard sweeps: each sweep evaluates the model intensities\n",
    "    at the start state of every step of the current trajectory guess in one batched call,\n",
    "    draws Poisson jump counts, and rebuilds the whole trajectory with a cumulative sum.\n",
    "    The corrector then applies K_c Metropolis-Hastings moves, mirroring\n",
    "    SequentialSamplerWithCorrector._run_corrector.\n",
    "\n",
    "    NOTE(review): with K_p == 0 the Picard loop never runs, so every block returns its\n",
    "    starting states unchanged.\n",
    "    NOTE(review): fresh Poisson noise is drawn on every sweep, so successive sweeps are not\n",
    "    a fixed-point iteration on a single noise realisation -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    def __init__(self, model, grid_size, N, M, K_p, K_c, device): # K_p: number of Picard sweeps per block\n",
    "        self.model, self.grid_size, self.N, self.M, self.K_p, self.K_c, self.device = model, grid_size, N, M, K_p, K_c, device\n",
    "        self.T, self.h, self.epsilon = model.T, model.T / N, model.T / N / M\n",
    "        # (row, col) lattice moves; the order must match the neighbour order of model.get_intensities.\n",
    "        self.jump_vectors = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1]], dtype=torch.float32, device=self.device)\n",
    "        \n",
    "        # Ones vector reused by the corrector.\n",
    "        self._ones_for_min = None  # built lazily in _run_corrector to match the batch size\n",
    "        \n",
    "        # Offsets of the M step start times from the block start: h down to epsilon (reverse time).\n",
    "        self._time_offsets = torch.linspace(self.h, self.epsilon, M, device=self.device)\n",
    "\n",
    "    def _run_predictor(self, y_t_n, t_start_block):\n",
    "        \"\"\"Run K_p Picard sweeps over one block and return the block's end states (B, 2).\"\"\"\n",
    "        B = y_t_n.shape[0]\n",
    "        \n",
    "        # Initial trajectory guess: the block's start state held constant at all M + 1 time points.\n",
    "        y_k = y_t_n.unsqueeze(1).expand(-1, self.M + 1, -1).contiguous()\n",
    "        \n",
    "        # Start time of each step, repeated per chain and flattened in (chain, step) order\n",
    "        # so it lines up element-wise with start_states_flat below.\n",
    "        time_steps_base = t_start_block + self._time_offsets\n",
    "        time_steps_flat = time_steps_base.unsqueeze(0).expand(B, -1).reshape(-1)\n",
    "        \n",
    "        for _ in range(self.K_p): # Picard sweep\n",
    "            # States at the start of each of the M steps, taken from the current guess.\n",
    "            start_states = y_k[:, :-1, :]\n",
    "            \n",
    "            # Flat state index row * grid_size + col, shape (B * M,).\n",
    "            start_states_flat = (start_states[:, :, 0] * self.grid_size + start_states[:, :, 1]).long().view(-1)\n",
    "            \n",
    "            # Intensities for every step of every chain in a single model call.\n",
    "            intensities = self.model.get_intensities(time_steps_flat, start_states_flat)\n",
    "            \n",
    "            # Poisson means lambda = intensity * epsilon, scaled in place (presumably safe because\n",
    "            # get_intensities returns a freshly allocated tensor).\n",
    "            rates = intensities.view(B, self.M, 4)\n",
    "            rates.mul_(self.epsilon)\n",
    "            num_jumps = torch.poisson(rates)\n",
    "            \n",
    "            # Per-step displacement: (B, M, 4) jump counts contracted with the (4, 2) jump vectors.\n",
    "            delta_y = torch.einsum('bmj,jd->bmd', num_jumps, self.jump_vectors)\n",
    "            \n",
    "            # Displacement from the block start after each step (accumulated in place).\n",
    "            cumulative_jumps = delta_y.cumsum_(dim=1)\n",
    "            \n",
    "            # Rebuild the trajectory from the block start; start_states_flat was computed\n",
    "            # before this overwrite, so the sweep only reads the previous guess.\n",
    "            y_k[:, 0, :] = y_t_n\n",
    "            y_k[:, 1:, :] = y_t_n.unsqueeze(1) + cumulative_jumps\n",
    "            \n",
    "            # Clamp to the grid once per sweep (the sequential sampler clamps after every step).\n",
    "            y_k.clamp_(0, self.grid_size - 1)\n",
    "            \n",
    "        # Clone so the caller does not keep a view into y_k.\n",
    "        return y_k[:, -1, :].clone()\n",
    "    \n",
    "    def _run_corrector(self, y_pred, t): # same Metropolis-Hastings scheme as the sequential sampler\n",
    "        \"\"\"Apply K_c Metropolis-Hastings moves to every chain at time t and return the new states.\"\"\"\n",
    "        y_current = y_pred\n",
    "        B = y_current.shape[0]\n",
    "        \n",
    "        # Rebuild the cached ones vector whenever the batch size changes.\n",
    "        if self._ones_for_min is None or self._ones_for_min.shape[0] != B:\n",
    "            self._ones_for_min = torch.ones(B, device=self.device)\n",
    "        \n",
    "        for _ in range(self.K_c):\n",
    "            # Propose one of the 4 lattice neighbours uniformly at random.\n",
    "            directions = torch.randint(0, 4, (B,), device=self.device)\n",
    "            \n",
    "            # Proposal, clamped back onto the grid at the boundary.\n",
    "            y_proposed = y_current + self.jump_vectors[directions]\n",
    "            y_proposed.clamp_(0, self.grid_size - 1)\n",
    "            \n",
    "            # Flat state indices row * grid_size + col.\n",
    "            y_current_flat = (y_current[:, 0].long() * self.grid_size + y_current[:, 1].long())\n",
    "            y_proposed_flat = (y_proposed[:, 0].long() * self.grid_size + y_proposed[:, 1].long())\n",
    "            \n",
    "            # Acceptance A = min(1, score) with score = model.get_scores(t, y, y'); it is only\n",
    "            # evaluated for chains whose proposal actually moved.\n",
    "            changed_mask = (y_current_flat != y_proposed_flat)\n",
    "            \n",
    "            if torch.any(changed_mask):\n",
    "                # Start from 1 (always accept) and overwrite the moved chains.\n",
    "                acceptance_ratio = self._ones_for_min.clone()\n",
    "                \n",
    "                # Model ratios for the moved chains only.\n",
    "                scores = self.model.get_scores(t, y_current_flat[changed_mask], y_proposed_flat[changed_mask])\n",
    "                acceptance_ratio[changed_mask] = scores\n",
    "                \n",
    "                # Element-wise min(1, ratio).\n",
    "                A = torch.minimum(self._ones_for_min, acceptance_ratio)\n",
    "            else:\n",
    "                A = self._ones_for_min\n",
    "            \n",
    "            # Accept each proposal with probability A.\n",
    "            mask = torch.rand(B, device=self.device) < A\n",
    "            \n",
    "            # Keep accepted proposals and the current states elsewhere.\n",
    "            mask_expanded = mask.unsqueeze(1).expand_as(y_current)\n",
    "            y_current = torch.where(mask_expanded, y_proposed, y_current)\n",
    "            \n",
    "        return y_current\n",
    "    \n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw batch_size samples; returns a (batch_size, 2) numpy array of (row, col) coordinates.\"\"\"\n",
    "        # Uniform random initial cells: floor of a uniform draw in [0, grid_size).\n",
    "        y_t_n = torch.rand(batch_size, 2, device=self.device) * self.grid_size\n",
    "        y_t_n = y_t_n.floor()\n",
    "        \n",
    "        # Walk the blocks backwards in time, from the one starting at (N - 1) * h down to 0.\n",
    "        for n in range(self.N - 1, -1, -1):\n",
    "            t_start_block = n * self.h\n",
    "            print(f\"\\r Para Sampling... Processing Block {self.N-n}/{self.N}\", end=\"\", flush=True)\n",
    "            \n",
    "            y_pred = self._run_predictor(y_t_n, t_start_block)\n",
    "            \n",
    "            y_t_n = self._run_corrector(y_pred, t_start_block)\n",
    "        \n",
    "        print(\"\\n Para Done\")\n",
    "        return y_t_n.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97dd8dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "GRID_SIZE = 32      # side length of the square lattice (GRID_SIZE**2 states)\n",
    "TOTAL_TIME = 3.0    # horizon T of the forward process\n",
    "BATCH_SIZE = 4096   # number of chains drawn by each sampler\n",
    "\n",
    "# Shared\n",
    "N_BLOCKS = 60        # number of blocks; block length h = TOTAL_TIME / N_BLOCKS\n",
    "STEPS_PER_BLOCK = 50 # predictor steps per block; step size h / STEPS_PER_BLOCK\n",
    "K_CORRECTOR = 0      # Metropolis-Hastings corrector moves per block (0 disables the corrector)\n",
    "\n",
    "# For Parallel\n",
    "# NOTE(review): with K_PREDICTOR = 0 the Picard loop in ParallelSamplerWithCorrector._run_predictor\n",
    "# never runs, so (with K_CORRECTOR = 0 as well) the parallel sampler returns its uniform\n",
    "# initial states unchanged. Use K_PREDICTOR >= 1 for a meaningful comparison.\n",
    "K_PREDICTOR = 0   # Picard depth (sweeps per block)\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {DEVICE}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5769a470",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment\n",
    "# Create target distribution and oracle\n",
    "p0_numpy = create_ring_distribution(GRID_SIZE, r_inner=GRID_SIZE*0.3, r_outer=GRID_SIZE*0.45).numpy()\n",
    "p0_tensor = torch.from_numpy(p0_numpy).to(DEVICE)\n",
    "oracle_model = OracleScoreModel(p0_tensor, T=TOTAL_TIME, device=DEVICE)\n",
    "\n",
    "# Both samplers draw from the seeded global torch RNG, so the run order below\n",
    "# (parallel first, then sequential) affects the samples; keep it fixed.\n",
    "\n",
    "# Run Para Sampler\n",
    "print(\"=\"*50)\n",
    "print(\"Running Para Sampling w/ corr...\")\n",
    "parallel_sampler = ParallelSamplerWithCorrector(\n",
    "    oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK, \n",
    "    K_PREDICTOR, K_CORRECTOR, DEVICE\n",
    ")\n",
    "# sample() ends with .cpu().numpy(), which waits for outstanding device work, so the\n",
    "# wall-clock interval below covers the whole run.\n",
    "start_p = time.time()\n",
    "samples_p = parallel_sampler.sample(BATCH_SIZE)\n",
    "end_p = time.time()\n",
    "time_p = end_p - start_p\n",
    "kl_p = calculate_kl_divergence(p0_numpy, samples_p, GRID_SIZE)\n",
    "\n",
    "# Run Seq Sampler\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Running Seq Sampling w/ corr...\")\n",
    "sequential_sampler = SequentialSamplerWithCorrector(\n",
    "    oracle_model, GRID_SIZE, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE\n",
    ")\n",
    "start_s = time.time()\n",
    "samples_s = sequential_sampler.sample(BATCH_SIZE)\n",
    "end_s = time.time()\n",
    "time_s = end_s - start_s\n",
    "kl_s = calculate_kl_divergence(p0_numpy, samples_s, GRID_SIZE)\n",
    "\n",
    "# Result\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Result (Circle Distribution)\")\n",
    "print(\"-\"*50)\n",
    "print(f\"| {'Score':<18} | {'Parallel + corr':<25} | {'Seq + corr':<20} |\")\n",
    "print(f\"| {'-'*18} | {'-'*25} | {'-'*20} |\")\n",
    "print(f\"| {'Runtime':<18} | {time_p:<25.2f} | {time_s:<20.2f} |\")\n",
    "print(f\"| {'KL-Divergence':<18} | {kl_p:<25.4f} | {kl_s:<20.4f} |\")\n",
    "# Rounds that must run one after another. Per block, the parallel sampler needs\n",
    "# K_PREDICTOR Picard sweeps + K_CORRECTOR corrector moves, and the sequential sampler\n",
    "# needs STEPS_PER_BLOCK predictor steps + K_CORRECTOR corrector moves.\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "print(f\"| {'Total Seq Step':<18} | {seq_stages_p:<25} | {seq_stages_s:<20} |\")\n",
    "print(\"-\"*50)\n",
    "\n",
    "# Visualization\n",
    "plot_distributions_comparison(p0_numpy, samples_p, samples_s, GRID_SIZE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
