{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba3f32c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Imports and Setup\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.sparse import csc_matrix, eye\n",
    "from scipy.sparse.linalg import expm_multiply\n",
    "from scipy.stats import binom\n",
    "import time\n",
    "%matplotlib inline\n",
    "\n",
    "%matplotlib inline\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e05f0c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class OracleScoreModelHypercube:\n",
    "    def __init__(self, p0_flat, d, T, device, steps=1001):\n",
    "        print(\"Oracle Initialization\")\n",
    "        self.d = d; self.S = 2**d; self.T = T\n",
    "        self.times_cpu = torch.linspace(0, T, steps); self.device = device\n",
    "        assert len(p0_flat) == self.S, \"dim error\"\n",
    "        Q = self._create_rate_matrix()\n",
    "        print(f\"Computing True Distribution p_t (size of state space: {self.S})\")\n",
    "        self.p_t_flat = self._solve_forward_process(Q, p0_flat)\n",
    "        print(\"Done\")\n",
    "\n",
    "    def _create_rate_matrix(self):\n",
    "        S = self.S; row_indices, col_indices, data = [], [], []\n",
    "        for i in range(self.S):\n",
    "            for j in range(self.d):\n",
    "                neighbor_idx = i ^ (1 << j)\n",
    "                row_indices.append(neighbor_idx); col_indices.append(i); data.append(1.0)\n",
    "            row_indices.append(i); col_indices.append(i); data.append(-float(self.d))\n",
    "        return csc_matrix((data, (row_indices, col_indices)), shape=(self.S, self.S))\n",
    "\n",
    "    def _solve_forward_process(self, Q, p0_flat):\n",
    "        p_t_list = []\n",
    "        p_current = p0_flat.cpu().numpy()\n",
    "        dt = (self.times_cpu[1] - self.times_cpu[0]).item()\n",
    "        exp_dtQ = expm_multiply(Q * dt, eye(self.S))\n",
    "        for _ in self.times_cpu:\n",
    "            p_t_list.append(torch.from_numpy(p_current))\n",
    "            p_current = exp_dtQ.dot(p_current)\n",
    "        return torch.stack(p_t_list, dim=0).float().to(self.device)\n",
    "    \n",
    "    def get_intensities(self, t, current_states_idx):\n",
    "        time_indices = torch.searchsorted(self.times_cpu, t.cpu()).to(self.device)\n",
    "        p_t = self.p_t_flat[time_indices]\n",
    "        p_t_x = p_t.gather(1, current_states_idx.unsqueeze(1)).squeeze(1)\n",
    "        neighbors = current_states_idx.unsqueeze(1) ^ (1 << torch.arange(self.d, device=self.device)).unsqueeze(0)\n",
    "        p_t_y = p_t.gather(1, neighbors)\n",
    "        intensities = p_t_y / (p_t_x.unsqueeze(1) + 1e-9)\n",
    "        return intensities\n",
    "\n",
    "    def get_scores(self, t, current_states_idx, proposed_states_idx):\n",
    "        if isinstance(t, torch.Tensor): t_cpu = t.cpu()\n",
    "        else: t_cpu = torch.tensor(t)\n",
    "        time_idx = torch.searchsorted(self.times_cpu, t_cpu).to(self.device)\n",
    "        p_t = self.p_t_flat[time_idx]\n",
    "        p_t_current = p_t.gather(0, current_states_idx)\n",
    "        p_t_proposed = p_t.gather(0, proposed_states_idx)\n",
    "        return p_t_proposed / (p_t_current + 1e-9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89dad534",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_subcube_distribution(d, k):\n",
    "\n",
    "    S = 2**d\n",
    "    p0 = torch.zeros(S)\n",
    "    \n",
    "    for i in range(2**k):\n",
    "        idx = i << (d - k)\n",
    "        p0[idx] = 1.0\n",
    "        \n",
    "    return p0 / p0.sum()\n",
    "\n",
    "def vec_to_int(vectors):\n",
    "    \"\"\" (B, d)---> (B,)\"\"\"\n",
    "    powers = 2**torch.arange(vectors.shape[-1] - 1, -1, -1, device=vectors.device)\n",
    "    return torch.sum(vectors * powers, dim=-1).long()\n",
    "\n",
    "def calculate_kl_divergence(p_target, samples, d):\n",
    "    S = 2**d\n",
    "    sample_indices = vec_to_int(torch.from_numpy(samples)).numpy()\n",
    "    q_hist, _ = np.histogram(sample_indices, bins=np.arange(S + 1) - 0.5)\n",
    "    q_generated = q_hist / q_hist.sum()\n",
    "    p_target_tensor = torch.from_numpy(p_target)\n",
    "    q_generated_tensor = torch.from_numpy(q_generated)\n",
    "    epsilon = 1e-12\n",
    "    p = p_target_tensor.flatten() + epsilon\n",
    "    q = q_generated_tensor.flatten() + epsilon\n",
    "    return torch.sum(p * (torch.log(p) - torch.log(q))).item()\n",
    "\n",
    "def plot_subcube_distributions(samples_parallel, samples_seq, d, k):\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharey='row')\n",
    "    \n",
    "    hw_active_p = np.sum(samples_parallel[:, :k], axis=1)\n",
    "    hw_active_s = np.sum(samples_seq[:, :k], axis=1)\n",
    "    \n",
    "    x_k = np.arange(k + 1)\n",
    "    pmf_k = binom.pmf(x_k, k, 0.5)\n",
    "\n",
    "    axes[0, 0].hist(hw_active_p, bins=np.arange(k + 2) - 0.5, density=True, color='C1', alpha=0.7, label='Generated')\n",
    "    axes[0, 0].plot(x_k, pmf_k, 'r-o', label='Target')\n",
    "    axes[0, 0].set_title(f\"Parallel\")\n",
    "    axes[0, 0].legend()\n",
    "    \n",
    "    axes[0, 1].hist(hw_active_s, bins=np.arange(k + 2) - 0.5, density=True, color='C2', alpha=0.7, label='Generated')\n",
    "    axes[0, 1].plot(x_k, pmf_k, 'r-o', label='arget')\n",
    "    axes[0, 1].set_title(f\"Seq\")\n",
    "    axes[0, 1].legend()\n",
    "\n",
    "    hw_inactive_p = np.sum(samples_parallel[:, k:], axis=1)\n",
    "    hw_inactive_s = np.sum(samples_seq[:, k:], axis=1)\n",
    "    \n",
    "    axes[1, 0].hist(hw_inactive_p, bins=np.arange(d - k + 2) - 0.5, density=True, color='C1')\n",
    "    axes[1, 0].set_title(f\"Para: irrelavant (last {d-k} digit) hamming weight\")\n",
    "\n",
    "    axes[1, 1].hist(hw_inactive_s, bins=np.arange(d - k + 2) - 0.5, density=True, color='C2')\n",
    "    axes[1, 1].set_title(f\"Seq: irrelavant (last {d-k} digit) hamming weight\")\n",
    "\n",
    "    for ax in axes[1, :]:\n",
    "        ax.set_xlabel(\"Hamming Weight\")\n",
    "    \n",
    "    axes[0,0].set_ylabel(\"Density\")\n",
    "    axes[1,0].set_ylabel(\"Density\")\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f67ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class SequentialSamplerWithCorrector:\n",
    "    def __init__(self, model, d, N, M, K_c, device):\n",
    "        self.model, self.d, self.N, self.M, self.K_c, self.device = model, d, N, M, K_c, device\n",
    "        self.T, self.h, self.dt = model.T, model.T / N, model.T / N / M\n",
    "        self.jump_vectors = torch.eye(d, device=device)\n",
    "    def _run_corrector(self, y_pred_vec, t):\n",
    "        y_current_vec = y_pred_vec\n",
    "        for _ in range(self.K_c):\n",
    "            B = y_current_vec.shape[0]; dim_to_flip = torch.randint(0, self.d, (B,), device=self.device)\n",
    "            flip_vec = torch.nn.functional.one_hot(dim_to_flip, num_classes=self.d).float()\n",
    "            y_proposed_vec = torch.abs(y_current_vec.clone() - flip_vec)\n",
    "            y_current_idx = vec_to_int(y_current_vec); y_proposed_idx = vec_to_int(y_proposed_vec)\n",
    "            changed_mask = (y_current_idx != y_proposed_idx)\n",
    "            acceptance_ratio = torch.ones(B, device=self.device)\n",
    "            if torch.any(changed_mask):\n",
    "                scores = self.model.get_scores(t, y_current_idx[changed_mask], y_proposed_idx[changed_mask])\n",
    "                acceptance_ratio[changed_mask] = scores\n",
    "            A = torch.min(torch.ones_like(acceptance_ratio), acceptance_ratio)\n",
    "            mask = torch.rand_like(A) < A\n",
    "            y_current_vec = torch.where(mask.unsqueeze(1), y_proposed_vec, y_current_vec)\n",
    "        return y_current_vec\n",
    "    def sample(self, batch_size=64):\n",
    "        y_t_n_vec = torch.randint(0, 2, (batch_size, self.d), device=self.device).float()\n",
    "        for n in range(self.N - 1, -1, -1):\n",
    "            t_start_block = n * self.h; t_end_block = t_start_block + self.h\n",
    "            print(f\"\\n--- Seq Sampling: Processing Block {self.N-n}/{self.N} ---\")\n",
    "            y_pred_vec = y_t_n_vec\n",
    "            time_grid = torch.linspace(t_end_block, t_start_block + self.dt, self.M, device=self.device)\n",
    "            for i, t in enumerate(time_grid):\n",
    "                if (i+1) % 10 == 0: print(f\"\\r  Predictor step {i+1}/{self.M}\", end=\"\")\n",
    "                current_states_idx = vec_to_int(y_pred_vec)\n",
    "                intensities = self.model.get_intensities(t.expand(batch_size), current_states_idx)\n",
    "                rates = intensities * self.dt\n",
    "                num_jumps = torch.poisson(rates)\n",
    "                jumps_one_hot = torch.nn.functional.one_hot(torch.arange(self.d, device=self.device), num_classes=self.d).float()\n",
    "                delta_y = torch.sum(jumps_one_hot.unsqueeze(0) * num_jumps.unsqueeze(-1), dim=1)\n",
    "                y_pred_vec = (y_pred_vec + delta_y) % 2\n",
    "            print(\"\\n  Running Corrector\")\n",
    "            y_t_n_vec = self._run_corrector(y_pred_vec, t_start_block)\n",
    "        print(\"\\n Seq Done\")\n",
    "        return y_t_n_vec.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bd5fd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParallelSamplerWithCorrector:\n",
    "    def __init__(self, model, d, N, M, K_p, K_c, device):\n",
    "        self.model, self.d, self.N, self.M, self.K_p, self.K_c, self.device = model, d, N, M, K_p, K_c, device\n",
    "        self.T, self.h, self.epsilon = model.T, model.T / N, model.T / N / M\n",
    "    def _run_predictor(self, y_t_n_vec, t_start_block):\n",
    "        y_k = y_t_n_vec.unsqueeze(1).repeat(1, self.M + 1, 1)\n",
    "        for _ in range(self.K_p):\n",
    "            start_states_vec = y_k[:, :-1, :]; B, M, D = start_states_vec.shape\n",
    "            start_states_idx = vec_to_int(start_states_vec.reshape(B*M, D))\n",
    "            time_steps = torch.linspace(t_start_block + self.h, t_start_block + self.epsilon, M, device=self.device)\n",
    "            time_steps_flat = time_steps.unsqueeze(0).repeat(B, 1).flatten()\n",
    "            intensities = self.model.get_intensities(time_steps_flat, start_states_idx)\n",
    "            rates = intensities.reshape(B, M, self.d) * self.epsilon\n",
    "            num_jumps = torch.poisson(rates)\n",
    "            jumps_one_hot = torch.nn.functional.one_hot(torch.arange(self.d, device=self.device), num_classes=self.d).float()\n",
    "            delta_y = torch.sum(jumps_one_hot.view(1, 1, self.d, self.d) * num_jumps.unsqueeze(-1), dim=2)\n",
    "            cumulative_jumps = torch.cumsum(delta_y, dim=1)\n",
    "            y_k_plus_1 = torch.zeros_like(y_k); y_k_plus_1[:, 0, :] = y_t_n_vec\n",
    "            y_k_plus_1[:, 1:, :] = y_t_n_vec.unsqueeze(1) + cumulative_jumps\n",
    "            y_k = y_k_plus_1 % 2\n",
    "        return y_k[:, -1, :].clone()\n",
    "    def _run_corrector(self, y_pred_vec, t):\n",
    "        y_current_vec = y_pred_vec\n",
    "        for _ in range(self.K_c):\n",
    "            B = y_current_vec.shape[0]\n",
    "            dim_to_flip = torch.randint(0, self.d, (B,), device=self.device)\n",
    "            flip_vec = torch.nn.functional.one_hot(dim_to_flip, num_classes=self.d).float()\n",
    "            y_proposed_vec = torch.abs(y_current_vec.clone() - flip_vec)\n",
    "            y_current_idx = vec_to_int(y_current_vec); y_proposed_idx = vec_to_int(y_proposed_vec)\n",
    "            changed_mask = (y_current_idx != y_proposed_idx)\n",
    "            acceptance_ratio = torch.ones(B, device=self.device)\n",
    "            if torch.any(changed_mask):\n",
    "                scores = self.model.get_scores(t, y_current_idx[changed_mask], y_proposed_idx[changed_mask])\n",
    "                acceptance_ratio[changed_mask] = scores\n",
    "            A = torch.min(torch.ones_like(acceptance_ratio), acceptance_ratio)\n",
    "            mask = torch.rand_like(A) < A\n",
    "            y_current_vec = torch.where(mask.unsqueeze(1), y_proposed_vec, y_current_vec)\n",
    "        return y_current_vec\n",
    "    def sample(self, batch_size=64):\n",
    "        y_t_n_vec = torch.randint(0, 2, (batch_size, self.d), device=self.device).float()\n",
    "        for n in range(self.N - 1, -1, -1):\n",
    "            t_start_block = n * self.h\n",
    "            print(f\"\\r  Para Sampling... Processing Block {self.N-n}/{self.N}\", end=\"\")\n",
    "            y_pred_vec = self._run_predictor(y_t_n_vec, t_start_block)\n",
    "            y_t_n_vec = self._run_corrector(y_pred_vec, t_start_block)\n",
    "        print(\"\\n Para Done\")\n",
    "        return y_t_n_vec.cpu().numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f382e5fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Parallel\n",
    "class ParallelSamplerWithCorrector:\n",
    "    def __init__(self, model, d, N, M, K_p, K_c, device):\n",
    "        self.model, self.d, self.N, self.M, self.K_p, self.K_c, self.device = model, d, N, M, K_p, K_c, device\n",
    "        self.T, self.h, self.epsilon = model.T, model.T / N, model.T / N / M\n",
    "        \n",
    "        # Pre-compute constants for efficiency\n",
    "        self.d = int(d)\n",
    "        self.N = int(N)\n",
    "        self.M = int(M)\n",
    "        self.K_p = int(K_p)\n",
    "        self.K_c = int(K_c)\n",
    "        \n",
    "        # Pre-allocate reusable tensors and constants\n",
    "        self._jumps_one_hot = torch.nn.functional.one_hot(\n",
    "            torch.arange(self.d, device=self.device, dtype=torch.long), \n",
    "            num_classes=self.d\n",
    "        ).float().view(1, 1, self.d, self.d)\n",
    "        \n",
    "        # Pre-compute time steps template (will be shifted per block)\n",
    "        self._time_steps_template = torch.linspace(\n",
    "            self.epsilon, 0, self.M, device=self.device, dtype=torch.float32\n",
    "        )\n",
    "        \n",
    "    def _run_predictor(self, y_t_n_vec, t_start_block):\n",
    "        B = y_t_n_vec.shape[0]\n",
    "        \n",
    "        # Initialize y_k with broadcasting instead of repeat\n",
    "        y_k = y_t_n_vec.unsqueeze(1).expand(B, self.M + 1, self.d).contiguous()\n",
    "        \n",
    "        # Pre-allocate tensors that will be reused in the loop\n",
    "        y_k_buffer = torch.zeros((B, self.M + 1, self.d), device=self.device, dtype=torch.float32)\n",
    "        \n",
    "        for _ in range(self.K_p):\n",
    "            # Extract start states more efficiently\n",
    "            start_states_vec = y_k[:, :-1, :]  # Shape: (B, M, D)\n",
    "            \n",
    "            # Vectorized conversion to indices - avoid reshape when possible\n",
    "            start_states_flat = start_states_vec.reshape(B * self.M, self.d)\n",
    "            start_states_idx = vec_to_int(start_states_flat)\n",
    "            \n",
    "            # Compute time steps more efficiently using pre-computed template\n",
    "            time_steps = t_start_block + self.h - self._time_steps_template\n",
    "            time_steps_flat = time_steps.unsqueeze(0).expand(B, self.M).reshape(-1)\n",
    "            \n",
    "            # Get intensities in batch\n",
    "            intensities = self.model.get_intensities(time_steps_flat, start_states_idx)\n",
    "            \n",
    "            # Compute rates and jumps\n",
    "            rates = intensities.view(B, self.M, self.d) * self.epsilon\n",
    "            num_jumps = torch.poisson(rates)\n",
    "            \n",
    "            # Compute delta_y using einsum for better performance with large batches\n",
    "            # This replaces the sum over dimension 2\n",
    "            delta_y = torch.einsum('bmdi,bmi->bmd', self._jumps_one_hot.expand(B, -1, -1, -1), num_jumps)\n",
    "            \n",
    "            # Use in-place cumsum for memory efficiency\n",
    "            cumulative_jumps = torch.cumsum(delta_y, dim=1, out=delta_y)\n",
    "            \n",
    "            # Update y_k more efficiently\n",
    "            y_k_buffer[:, 0, :] = y_t_n_vec\n",
    "            y_k_buffer[:, 1:, :] = y_t_n_vec.unsqueeze(1) + cumulative_jumps\n",
    "            \n",
    "            # In-place modulo operation\n",
    "            y_k = y_k_buffer.remainder_(2)\n",
    "            \n",
    "            # Prepare for next iteration if needed\n",
    "            if _ < self.K_p - 1:\n",
    "                y_k = y_k.clone()  # Only clone if we need another iteration\n",
    "        \n",
    "        return y_k[:, -1, :]\n",
    "    \n",
    "    def _run_corrector(self, y_pred_vec, t):\n",
    "        B = y_pred_vec.shape[0]\n",
    "        y_current_vec = y_pred_vec\n",
    "        \n",
    "        # Pre-allocate tensors for reuse\n",
    "        acceptance_ratio = torch.ones(B, device=self.device, dtype=torch.float32)\n",
    "        \n",
    "        for _ in range(self.K_c):\n",
    "            # Generate random dimensions to flip\n",
    "            dim_to_flip = torch.randint(0, self.d, (B,), device=self.device, dtype=torch.long)\n",
    "            \n",
    "            # Create flip vector using scatter instead of one_hot for efficiency\n",
    "            flip_vec = torch.zeros((B, self.d), device=self.device, dtype=torch.float32)\n",
    "            flip_vec.scatter_(1, dim_to_flip.unsqueeze(1), 1.0)\n",
    "            \n",
    "            # Compute proposed states using XOR logic (more efficient than abs(x - y))\n",
    "            y_proposed_vec = torch.where(flip_vec > 0, 1 - y_current_vec, y_current_vec)\n",
    "            \n",
    "            # Convert to indices\n",
    "            y_current_idx = vec_to_int(y_current_vec)\n",
    "            y_proposed_idx = vec_to_int(y_proposed_vec)\n",
    "            \n",
    "            # Find changed indices\n",
    "            changed_mask = (y_current_idx != y_proposed_idx)\n",
    "            \n",
    "            # Reset acceptance ratio\n",
    "            acceptance_ratio.fill_(1.0)\n",
    "            \n",
    "            # Compute scores only for changed states\n",
    "            if torch.any(changed_mask):\n",
    "                scores = self.model.get_scores(t, y_current_idx[changed_mask], y_proposed_idx[changed_mask])\n",
    "                acceptance_ratio[changed_mask] = scores\n",
    "            \n",
    "            # Compute acceptance probabilities\n",
    "            A = torch.minimum(acceptance_ratio, torch.ones_like(acceptance_ratio))\n",
    "            \n",
    "            # Generate random acceptance decisions\n",
    "            mask = torch.rand(B, device=self.device, dtype=torch.float32) < A\n",
    "            \n",
    "            # Update current states based on acceptance\n",
    "            y_current_vec = torch.where(mask.unsqueeze(1), y_proposed_vec, y_current_vec)\n",
    "        \n",
    "        return y_current_vec\n",
    "    \n",
    "    def sample(self, batch_size=64):\n",
    "        batch_size = int(batch_size)\n",
    "        \n",
    "        # Initialize with integer type for better memory efficiency\n",
    "        y_t_n_vec = torch.randint(0, 2, (batch_size, self.d), device=self.device, dtype=torch.float32)\n",
    "        \n",
    "        # Pre-compute all time points if memory allows\n",
    "        for n in range(self.N - 1, -1, -1):\n",
    "            t_start_block = n * self.h\n",
    "            print(f\"\\r  Para Sampling... Processing Block {self.N-n}/{self.N}\", end=\"\")\n",
    "            \n",
    "            # Run predictor\n",
    "            y_pred_vec = self._run_predictor(y_t_n_vec, t_start_block)\n",
    "            \n",
    "            # Run corrector\n",
    "            y_t_n_vec = self._run_corrector(y_pred_vec, t_start_block)\n",
    "        \n",
    "        print(\"\\n Para Done\")\n",
    "        return y_t_n_vec.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0cda31",
   "metadata": {},
   "outputs": [],
   "source": [
    "AMBIENT_DIM_D = 12       \n",
    "INTRINSIC_DIM_K = 6     \n",
    "\n",
    "TOTAL_TIME = 5.0\n",
    "BATCH_SIZE = 4096\n",
    "\n",
    "N_BLOCKS = 50      \n",
    "STEPS_PER_BLOCK = 100\n",
    "K_CORRECTOR = 8    \n",
    "\n",
    "K_PREDICTOR = 2     \n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {DEVICE}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451882f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "p0_tensor = create_subcube_distribution(AMBIENT_DIM_D, INTRINSIC_DIM_K).to(DEVICE)\n",
    "p0_numpy = p0_tensor.cpu().numpy()\n",
    "oracle_model = OracleScoreModelHypercube(p0_tensor, d=AMBIENT_DIM_D, T=TOTAL_TIME, device=DEVICE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "629b8c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "print(\"=\"*50)\n",
    "parallel_sampler = ParallelSamplerWithCorrector(\n",
    "    oracle_model, AMBIENT_DIM_D, N_BLOCKS, STEPS_PER_BLOCK, \n",
    "    K_PREDICTOR, K_CORRECTOR, DEVICE\n",
    ")\n",
    "start_p = time.time()\n",
    "samples_p = parallel_sampler.sample(BATCH_SIZE)\n",
    "end_p = time.time()\n",
    "time_p = end_p - start_p\n",
    "kl_p = calculate_kl_divergence(p0_numpy, samples_p, AMBIENT_DIM_D)\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "sequential_sampler = SequentialSamplerWithCorrector(\n",
    "    oracle_model, AMBIENT_DIM_D, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE\n",
    ")\n",
    "start_s = time.time()\n",
    "samples_s = sequential_sampler.sample(BATCH_SIZE)\n",
    "end_s = time.time()\n",
    "time_s = end_s - start_s\n",
    "kl_s = calculate_kl_divergence(p0_numpy, samples_s, AMBIENT_DIM_D)\n",
    "\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Result (d={AMBIENT_DIM_D}, k={INTRINSIC_DIM_K} sub-cube)\")\n",
    "print(\"-\"*50)\n",
    "print(f\"| {'Score':<18} | {'Parallel+corr':<25} | {'Seq+corr':<20} |\")\n",
    "print(f\"| {'-'*18} | {'-'*25} | {'-'*20} |\")\n",
    "print(f\"| {'Runtime':<18} | {time_p:<25.2f} | {time_s:<20.2f} |\")\n",
    "print(f\"| {'KL-Divergence':<18} | {kl_p:<25.4f} | {kl_s:<20.4f} |\")\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "print(f\"| {'Total Seq Step':<18} | {seq_stages_p:<25} | {seq_stages_s:<20} |\")\n",
    "print(\"-\"*50)\n",
    "\n",
    "plot_subcube_distributions(samples_p, samples_s, AMBIENT_DIM_D, INTRINSIC_DIM_K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cffcda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "NUM_RUNS = 10  \n",
    "\n",
    "AMBIENT_DIM_D = 12       \n",
    "INTRINSIC_DIM_K = 6     \n",
    "\n",
    "TOTAL_TIME = 5.0\n",
    "BATCH_SIZE = 2048\n",
    "\n",
    "N_BLOCKS = 50      \n",
    "STEPS_PER_BLOCK = 100\n",
    "K_CORRECTOR = 8    \n",
    "\n",
    "K_PREDICTOR = 12 \n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {DEVICE}\\n\")\n",
    "\n",
    "# Create target distribution and oracle\n",
    "p0_tensor = create_subcube_distribution(AMBIENT_DIM_D, INTRINSIC_DIM_K).to(DEVICE)\n",
    "p0_numpy = p0_tensor.cpu().numpy()\n",
    "oracle_model = OracleScoreModelHypercube(p0_tensor, d=AMBIENT_DIM_D, T=TOTAL_TIME, device=DEVICE)\n",
    "\n",
    "times_p_list = []\n",
    "kls_p_list = []\n",
    "times_s_list = []\n",
    "kls_s_list = []\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Running Parallel Sampler w/ Corrector for {NUM_RUNS} runs...\")\n",
    "for i in range(NUM_RUNS):\n",
    "    print(f\"\\n--- Parallel Run {i+1}/{NUM_RUNS} ---\")\n",
    "    parallel_sampler = ParallelSamplerWithCorrector(\n",
    "    oracle_model, AMBIENT_DIM_D, N_BLOCKS, STEPS_PER_BLOCK, \n",
    "    K_PREDICTOR, K_CORRECTOR, DEVICE)\n",
    "    start_p = time.time()\n",
    "    samples_p = parallel_sampler.sample(BATCH_SIZE)\n",
    "    end_p = time.time()\n",
    "    \n",
    "    time_p = end_p - start_p\n",
    "    kl_p = calculate_kl_divergence(p0_numpy, samples_p, AMBIENT_DIM_D)\n",
    "    \n",
    "    times_p_list.append(time_p)\n",
    "    kls_p_list.append(kl_p)\n",
    "    print(f\"Run {i+1} finished. Time: {time_p:.2f}s, KL: {kl_p:.4f}\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Running Sequential Sampler w/ Corrector for {NUM_RUNS} runs...\")\n",
    "for i in range(NUM_RUNS):\n",
    "    print(f\"\\n--- Sequential Run {i+1}/{NUM_RUNS} ---\")\n",
    "    sequential_sampler = SequentialSamplerWithCorrector(\n",
    "    oracle_model, AMBIENT_DIM_D, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE)\n",
    "    start_s = time.time()\n",
    "    samples_s = sequential_sampler.sample(BATCH_SIZE)\n",
    "    end_s = time.time()\n",
    "    \n",
    "    time_s = end_s - start_s\n",
    "    kl_s = calculate_kl_divergence(p0_numpy, samples_s, AMBIENT_DIM_D)\n",
    "    \n",
    "    times_s_list.append(time_s)\n",
    "    kls_s_list.append(kl_s)\n",
    "    print(f\"Run {i+1} finished. Time: {time_s:.2f}s, KL: {kl_s:.4f}\")\n",
    "\n",
    "avg_time_p = np.mean(times_p_list)\n",
    "std_time_p = np.std(times_p_list)\n",
    "avg_kl_p = np.mean(kls_p_list)\n",
    "std_kl_p = np.std(kls_p_list)\n",
    "\n",
    "avg_time_s = np.mean(times_s_list)\n",
    "std_time_s = np.std(times_s_list)\n",
    "avg_kl_s = np.mean(kls_s_list)\n",
    "std_kl_s = np.std(kls_s_list)\n",
    "--\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "\n",
    "result_p_time = f\"{avg_time_p:.2f} ± {std_time_p:.2f}\"\n",
    "result_p_kl = f\"{avg_kl_p:.4f} ± {std_kl_p:.4f}\"\n",
    "result_s_time = f\"{avg_time_s:.2f} ± {std_time_s:.2f}\"\n",
    "result_s_kl = f\"{avg_kl_s:.4f} ± {std_kl_s:.4f}\"\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Final Averaged Results ({NUM_RUNS} runs)\")\n",
    "print(\"-\"*68)\n",
    "print(f\"| {'Score':<18} | {'Picard + Corrector':<28} | {'Sequential + Corrector':<24} |\")\n",
    "print(f\"| {'-'*18} | {'-'*28} | {'-'*24} |\")\n",
    "print(f\"| {'Runtime (s)':<18} | {result_p_time:<28} | {result_s_time:<24} |\")\n",
    "print(f\"| {'KL Divergence':<18} | {result_p_kl:<28} | {result_s_kl:<24} |\")\n",
    "print(f\"| {'Sequential Stages':<18} | {seq_stages_p:<28} | {seq_stages_s:<24} |\")\n",
    "print(\"-\"*68)\n",
    "\n",
    "print(\"\\nVisualizing the last run's results...\")\n",
    "plot_subcube_distributions(samples_p, samples_s, AMBIENT_DIM_D, INTRINSIC_DIM_K)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
