{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b77b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.sparse import csc_matrix, eye\n",
    "from scipy.sparse.linalg import expm_multiply\n",
    "import time\n",
    "\n",
    "%matplotlib inline\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b04f1e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build Oracle\n",
    "class OracleScoreModelHypercube:\n",
    "    def __init__(self, p0_flat, d, T, device, steps=1001):\n",
    "        self.d = d\n",
    "        self.S = 2**d\n",
    "        self.T = T\n",
    "        self.times_cpu = torch.linspace(0, T, steps)\n",
    "        self.device = device\n",
    "        \n",
    "        assert len(p0_flat) == self.S, \"dim not match\"\n",
    "        \n",
    "        Q = self._create_rate_matrix()\n",
    "        \n",
    "        print(f\"Compute true p_t... (state space size: {self.S})\")\n",
    "        self.p_t_flat = self._solve_forward_process(Q, p0_flat)\n",
    "\n",
    "    def _create_rate_matrix(self):\n",
    "        \"\"\"Q for hypercube\"\"\"\n",
    "        row_indices, col_indices, data = [], [], []\n",
    "        \n",
    "\n",
    "        for i in range(self.S):\n",
    "            for j in range(self.d):\n",
    "                neighbor_idx = i ^ (1 << j) \n",
    "                # Off-diagonal: rate of jumping in is 1.0\n",
    "                row_indices.append(neighbor_idx)\n",
    "                col_indices.append(i)\n",
    "                data.append(1.0)\n",
    "            \n",
    "            # Diagonal: sum of outgoing rates is d\n",
    "            row_indices.append(i)\n",
    "            col_indices.append(i)\n",
    "            data.append(-float(self.d))\n",
    "            \n",
    "        return csc_matrix((data, (row_indices, col_indices)), shape=(self.S, self.S))\n",
    "\n",
    "    def _solve_forward_process(self, Q, p0_flat):\n",
    "        p_t_list = []\n",
    "        p_current = p0_flat.cpu().numpy()\n",
    "        dt = (self.times_cpu[1] - self.times_cpu[0]).item()\n",
    "        exp_dtQ = expm_multiply(Q * dt, eye(self.S))\n",
    "        for _ in self.times_cpu:\n",
    "            p_t_list.append(torch.from_numpy(p_current))\n",
    "            p_current = exp_dtQ.dot(p_current)\n",
    "        return torch.stack(p_t_list, dim=0).float().to(self.device)\n",
    "\n",
    "    def get_intensities(self, t, current_states_idx):\n",
    "        time_indices = torch.searchsorted(self.times_cpu, t.cpu()).to(self.device)\n",
    "        p_t = self.p_t_flat[time_indices]\n",
    "        p_t_x = p_t.gather(1, current_states_idx.unsqueeze(1)).squeeze(1)\n",
    "        neighbors = current_states_idx.unsqueeze(1) ^ (1 << torch.arange(self.d, device=self.device)).unsqueeze(0)\n",
    "        p_t_y = p_t.gather(1, neighbors)\n",
    "        intensities = p_t_y / (p_t_x.unsqueeze(1) + 1e-9)\n",
    "        return intensities\n",
    "\n",
    "    def get_scores(self, t, current_states_idx, proposed_states_idx):\n",
    "        if isinstance(t, torch.Tensor): t_cpu = t.cpu()\n",
    "        else: t_cpu = torch.tensor(t)\n",
    "        time_idx = torch.searchsorted(self.times_cpu, t_cpu).to(self.device)\n",
    "        p_t = self.p_t_flat[time_idx]\n",
    "        p_t_current = p_t.gather(0, current_states_idx)\n",
    "        p_t_proposed = p_t.gather(0, proposed_states_idx)\n",
    "        return p_t_proposed / (p_t_current + 1e-9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96219b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Helper Functions\n",
    "def create_bimodal_hypercube_distribution(d):\n",
    "    S = 2**d\n",
    "    p0 = torch.zeros(S)\n",
    "    p0[0] = 0.5  # all-zeros state\n",
    "    p0[-1] = 0.5 # all-ones state\n",
    "    return p0\n",
    "\n",
    "def vec_to_int(vectors):\n",
    "    powers = 2**torch.arange(vectors.shape[-1] - 1, -1, -1, device=vectors.device)\n",
    "    return torch.sum(vectors * powers, dim=1).long()\n",
    "\n",
    "def int_to_vec(indices, d):\n",
    "    mask = 2**torch.arange(d - 1, -1, -1, device=indices.device)\n",
    "    return indices.unsqueeze(-1).bitwise_and(mask).ne(0).float()\n",
    "\n",
    "def calculate_kl_divergence(p_target, samples, d):\n",
    "    \"\"\"KL(p_target || q) in nats, where q is the empirical distribution of `samples`.\n",
    "\n",
    "    Args:\n",
    "        p_target: length-2**d NumPy array of target probabilities, indexed like vec_to_int.\n",
    "        samples: (num_samples, d) NumPy array of 0/1 vectors.\n",
    "        d: hypercube dimension.\n",
    "\n",
    "    Returns:\n",
    "        float. epsilon = 1e-12 is added to both distributions, so states that were never\n",
    "        sampled give a large but finite penalty instead of inf.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no samples (the empirical distribution is undefined).\n",
    "    \"\"\"\n",
    "    S = 2**d\n",
    "    samples_arr = np.asarray(samples)\n",
    "    if samples_arr.shape[0] == 0:\n",
    "        raise ValueError(\"samples is empty; cannot build an empirical distribution\")\n",
    "    sample_indices = vec_to_int(torch.from_numpy(samples_arr)).numpy()\n",
    "    # Integer counts per state; indices >= S (wrong vector width) are ignored, as the\n",
    "    # previous unit-bin histogram did.\n",
    "    q_hist = np.bincount(sample_indices, minlength=S)[:S]\n",
    "    total = q_hist.sum()\n",
    "    if total == 0:\n",
    "        raise ValueError(\"no sample falls into the 2**d states\")\n",
    "    q_generated = q_hist / total\n",
    "    epsilon = 1e-12\n",
    "    # Compute in float64 throughout (in this notebook p_target arrives as float32).\n",
    "    p = torch.from_numpy(np.asarray(p_target, dtype=np.float64)).flatten() + epsilon\n",
    "    q = torch.from_numpy(np.asarray(q_generated, dtype=np.float64)).flatten() + epsilon\n",
    "    return torch.sum(p * (torch.log(p) - torch.log(q))).item()\n",
    "\n",
    "def plot_hamming_distributions(p0, samples_parallel, samples_seq, d):\n",
    "    \"\"\"Plot Hamming-weight histograms of the target p0 and of both samplers' outputs.\n",
    "\n",
    "    Args:\n",
    "        p0: length-2**d array of target probabilities indexed by state index.\n",
    "        samples_parallel, samples_seq: (num_samples, d) arrays of 0/1 vectors.\n",
    "        d: hypercube dimension.\n",
    "    \"\"\"\n",
    "    bins = np.arange(d + 2) - 0.5\n",
    "    # Target mass per Hamming weight derived from p0 itself. The weight of a state index\n",
    "    # is its number of set bits, which does not depend on bit order.\n",
    "    state_weights = np.array([bin(i).count(\"1\") for i in range(2**d)], dtype=np.int64)\n",
    "    target_mass = np.bincount(state_weights, weights=np.asarray(p0, dtype=np.float64).ravel(), minlength=d + 1)\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
    "    axes[0].hist(np.arange(d + 1), bins=bins, weights=target_mass, density=True, color='C0')\n",
    "    axes[0].set_title(f\"Target (d={d})\")\n",
    "    hamming_p = np.sum(samples_parallel, axis=1)\n",
    "    axes[1].hist(hamming_p, bins=bins, density=True, color='C1')\n",
    "    axes[1].set_title(\"Para+corr\")\n",
    "    hamming_s = np.sum(samples_seq, axis=1)\n",
    "    axes[2].hist(hamming_s, bins=bins, density=True, color='C2')\n",
    "    axes[2].set_title(\"Seq+corr\")\n",
    "    for ax in axes:\n",
    "        ax.set_xlabel(\"Hamming Weight (num of 1)\")\n",
    "        ax.set_xlim(-1, d + 1)\n",
    "    axes[0].set_ylabel(\"Density\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06d40fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Sequential\n",
    "class SequentialSamplerWithCorrector:\n",
    "    \"\"\"Reverse-time tau-leaping sampler run one step at a time.\n",
    "\n",
    "    Time [0, T] is split into N blocks of length h = T / N. Each block takes M\n",
    "    Euler (Poisson-jump) steps of size dt = h / M, followed by K_c Metropolis-Hastings\n",
    "    single-bit-flip corrector steps targeting p_t at the block start.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, d, N, M, K_c, device):\n",
    "        self.model = model\n",
    "        self.d = d\n",
    "        self.N = N\n",
    "        self.M = M\n",
    "        self.K_c = K_c\n",
    "        self.device = device\n",
    "        self.T = model.T\n",
    "        self.h = model.T / N       # block length\n",
    "        self.dt = model.T / N / M  # Euler step inside a block\n",
    "        self.jump_vectors = torch.eye(d, device=device)  # row k flips coordinate k\n",
    "\n",
    "    def _run_corrector(self, y_pred_vec, t):\n",
    "        \"\"\"K_c MH steps with a uniform single-bit-flip proposal; acceptance min(1, model.get_scores).\"\"\"\n",
    "        state = y_pred_vec\n",
    "        for _ in range(self.K_c):\n",
    "            n_chains = state.shape[0]\n",
    "            coord = torch.randint(0, self.d, (n_chains,), device=self.device)\n",
    "            flip = torch.nn.functional.one_hot(coord, num_classes=self.d).float()\n",
    "            proposal = torch.abs(state - flip)\n",
    "            cur_idx = vec_to_int(state)\n",
    "            prop_idx = vec_to_int(proposal)\n",
    "            moved = cur_idx != prop_idx\n",
    "            ratio = torch.ones(n_chains, device=self.device)\n",
    "            if torch.any(moved):\n",
    "                ratio[moved] = self.model.get_scores(t, cur_idx[moved], prop_idx[moved])\n",
    "            accept_prob = torch.minimum(torch.ones_like(ratio), ratio)\n",
    "            accept = torch.rand_like(accept_prob) < accept_prob\n",
    "            state = torch.where(accept.unsqueeze(1), proposal, state)\n",
    "        return state\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw batch_size samples; returns a (batch_size, d) NumPy array of 0/1 floats.\"\"\"\n",
    "        y = torch.randint(0, 2, (batch_size, self.d), device=self.device).float()\n",
    "        for n in reversed(range(self.N)):\n",
    "            t_start_block = n * self.h\n",
    "            t_end_block = t_start_block + self.h\n",
    "            print(f\"\\n--- Seq Sampling... Processing Block {self.N-n}/{self.N} ---\")\n",
    "            # M step times, walking backwards from the block end to t_start + dt.\n",
    "            time_grid = torch.linspace(t_end_block, t_start_block + self.dt, self.M, device=self.device)\n",
    "            for step, t in enumerate(time_grid, start=1):\n",
    "                if step % 10 == 0:\n",
    "                    print(f\"\\r  Predict Step {step}/{self.M}\", end=\"\")\n",
    "                state_idx = vec_to_int(y)\n",
    "                rates = self.model.get_intensities(t.expand(batch_size), state_idx) * self.dt\n",
    "                jumps = torch.poisson(rates)\n",
    "                delta = (jumps.unsqueeze(-1) * self.jump_vectors.unsqueeze(0)).sum(dim=1)\n",
    "                y = (y + delta) % 2\n",
    "            print(\"\\n  Run corrector..\")\n",
    "            y = self._run_corrector(y, t_start_block)\n",
    "        print(\"\\n Done\")\n",
    "        return y.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32913934",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParallelSamplerWithCorrector:\n",
    "    \"\"\"Reverse-time sampler whose M tau-leaping steps per block are simulated in parallel.\n",
    "\n",
    "    Each block of length h = T / N is solved with K_p fixed-point (Picard-style)\n",
    "    iterations. In each iteration the whole M-step trajectory is re-simulated at once,\n",
    "    with intensities evaluated at the states of the previous iterate. Each block then\n",
    "    gets K_c Metropolis-Hastings single-bit-flip corrector steps targeting p_t.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, d, N, M, K_p, K_c, device):\n",
    "        self.model, self.device = model, device\n",
    "        # Ensure integer types (used as tensor sizes and loop counts)\n",
    "        self.d = int(d)\n",
    "        self.N = int(N)\n",
    "        self.M = int(M)\n",
    "        self.K_p = int(K_p)\n",
    "        self.K_c = int(K_c)\n",
    "        self.T = model.T\n",
    "        self.h = model.T / N            # block length\n",
    "        self.epsilon = model.T / N / M  # step size inside a block\n",
    "        # jump_vectors for hypercube are single bit flips: row k flips coordinate k\n",
    "        self.jump_vectors = torch.eye(self.d, device=device)\n",
    "        # Step times relative to the block start, from t_start + h down to t_start + epsilon.\n",
    "        # They are the same for every block, so they are computed once.\n",
    "        self._time_offsets = torch.linspace(\n",
    "            self.h, self.epsilon, self.M, device=self.device, dtype=torch.float32\n",
    "        )\n",
    "\n",
    "    def _run_predictor(self, y_t_n_vec, t_start_block):\n",
    "        \"\"\"Run K_p parallel predictor iterations over one block.\n",
    "\n",
    "        Args:\n",
    "            y_t_n_vec: (B, d) 0/1 float states at the block end t_start_block + h.\n",
    "            t_start_block: start time of the block (python float).\n",
    "\n",
    "        Returns:\n",
    "            (B, d) states after all M steps of the last iterate (y_t_n_vec if K_p == 0).\n",
    "        \"\"\"\n",
    "        B = y_t_n_vec.shape[0]\n",
    "        y_start = y_t_n_vec.unsqueeze(1)  # (B, 1, d)\n",
    "        # Iterate 0 is the constant trajectory; a view suffices because nothing writes into it.\n",
    "        y_k = y_start.expand(B, self.M + 1, self.d)\n",
    "        # Step times do not change between iterations, so the flat (B * M,) vector is built once.\n",
    "        time_steps_flat = (t_start_block + self._time_offsets).unsqueeze(0).expand(B, self.M).reshape(-1)\n",
    "\n",
    "        for _ in range(self.K_p):\n",
    "            # Left-end state of each of the M steps in the previous iterate.\n",
    "            start_states_idx = vec_to_int(y_k[:, :-1, :].reshape(B * self.M, self.d))\n",
    "            intensities = self.model.get_intensities(time_steps_flat, start_states_idx)\n",
    "            rates = intensities.view(B, self.M, self.d) * self.epsilon\n",
    "            num_jumps = torch.poisson(rates)          # (B, M, d)\n",
    "            delta_y = num_jumps @ self.jump_vectors   # (B, M, d)\n",
    "            # Bit flips commute mod 2, so the state after j steps is start + cumulative jumps.\n",
    "            trajectory = (y_start + torch.cumsum(delta_y, dim=1)).remainder(2)\n",
    "            y_k = torch.cat([y_start, trajectory], dim=1)  # (B, M + 1, d)\n",
    "\n",
    "        return y_k[:, -1, :]\n",
    "\n",
    "    def _run_corrector(self, y_pred_vec, t):\n",
    "        \"\"\"K_c Metropolis-Hastings steps targeting p_t with a uniform single-bit-flip proposal.\n",
    "\n",
    "        The proposal is symmetric, so the acceptance probability is\n",
    "        min(1, p_t(proposed) / p_t(current)), as returned by model.get_scores.\n",
    "        \"\"\"\n",
    "        B = y_pred_vec.shape[0]\n",
    "        y_current_vec = y_pred_vec\n",
    "        acceptance_ratio = torch.ones(B, device=self.device, dtype=torch.float32)\n",
    "        ones_vec = torch.ones(B, device=self.device, dtype=torch.float32)\n",
    "\n",
    "        for _ in range(self.K_c):\n",
    "            # Propose flipping one uniformly chosen coordinate per chain.\n",
    "            dim_to_flip = torch.randint(0, self.d, (B,), device=self.device, dtype=torch.long)\n",
    "            flip_vec = torch.zeros((B, self.d), device=self.device, dtype=torch.float32)\n",
    "            flip_vec.scatter_(1, dim_to_flip.unsqueeze(1), 1.0)\n",
    "            y_proposed_vec = torch.where(flip_vec > 0, 1 - y_current_vec, y_current_vec)\n",
    "\n",
    "            y_current_idx = vec_to_int(y_current_vec)\n",
    "            y_proposed_idx = vec_to_int(y_proposed_vec)\n",
    "            changed_mask = (y_current_idx != y_proposed_idx)\n",
    "\n",
    "            acceptance_ratio.fill_(1.0)\n",
    "            if torch.any(changed_mask):\n",
    "                scores = self.model.get_scores(t, y_current_idx[changed_mask], y_proposed_idx[changed_mask])\n",
    "                acceptance_ratio[changed_mask] = scores\n",
    "\n",
    "            A = torch.minimum(ones_vec, acceptance_ratio)\n",
    "            mask = torch.rand(B, device=self.device, dtype=torch.float32) < A\n",
    "            y_current_vec = torch.where(mask.unsqueeze(1), y_proposed_vec, y_current_vec)\n",
    "\n",
    "        return y_current_vec\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        \"\"\"Draw batch_size samples; returns a (batch_size, d) NumPy array of 0/1 floats.\"\"\"\n",
    "        batch_size = int(batch_size)\n",
    "        y_t_n_vec = torch.randint(0, 2, (batch_size, self.d), device=self.device, dtype=torch.float32)\n",
    "\n",
    "        # Blocks are processed from t = T down to t = 0.\n",
    "        for n in range(self.N - 1, -1, -1):\n",
    "            t_start_block = n * self.h\n",
    "            print(f\"\\r  Para Sampling... Processing Block {self.N-n}/{self.N}\", end=\"\")\n",
    "            y_pred_vec = self._run_predictor(y_t_n_vec, t_start_block)\n",
    "            y_t_n_vec = self._run_corrector(y_pred_vec, t_start_block)\n",
    "\n",
    "        print(\"\\n Done\")\n",
    "        return y_t_n_vec.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4570215b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {DEVICE}\")\n",
    "\n",
    "# Problem\n",
    "DIMENSION = 6      # hypercube dimension d (2**d states)\n",
    "TOTAL_TIME = 5.0   # forward-process horizon T\n",
    "BATCH_SIZE = 2048  # samples drawn by each sampler\n",
    "\n",
    "# Shared by both samplers\n",
    "N_BLOCKS = 100        # N blocks, each of length T / N\n",
    "STEPS_PER_BLOCK = 50  # M steps per block\n",
    "K_CORRECTOR = 10      # corrector steps after each block\n",
    "\n",
    "# Parallel sampler only\n",
    "K_PREDICTOR = 4       # predictor iterations per block\n",
    "\n",
    "# Target distribution and exact oracle\n",
    "p0_tensor = create_bimodal_hypercube_distribution(DIMENSION).to(DEVICE)\n",
    "p0_numpy = p0_tensor.cpu().numpy()\n",
    "oracle_model = OracleScoreModelHypercube(p0_tensor, d=DIMENSION, T=TOTAL_TIME, device=DEVICE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a5e66c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment\n",
    "def _timed_sample(sampler, batch_size):\n",
    "    \"\"\"Return (samples, elapsed seconds) for sampler.sample(batch_size).\n",
    "\n",
    "    Uses time.perf_counter, a monotonic clock meant for measuring intervals. sample()\n",
    "    returns a NumPy array via .cpu(), so device work has finished before the clock stops.\n",
    "    \"\"\"\n",
    "    start = time.perf_counter()\n",
    "    samples = sampler.sample(batch_size)\n",
    "    return samples, time.perf_counter() - start\n",
    "\n",
    "\n",
    "print(\"=\"*50)\n",
    "print(\"Para start...\")\n",
    "parallel_sampler = ParallelSamplerWithCorrector(\n",
    "    oracle_model, DIMENSION, N_BLOCKS, STEPS_PER_BLOCK, \n",
    "    K_PREDICTOR, K_CORRECTOR, DEVICE\n",
    ")\n",
    "samples_p, time_p = _timed_sample(parallel_sampler, BATCH_SIZE)\n",
    "kl_p = calculate_kl_divergence(p0_numpy, samples_p, DIMENSION)\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Seq start...\")\n",
    "sequential_sampler = SequentialSamplerWithCorrector(\n",
    "    oracle_model, DIMENSION, N_BLOCKS, STEPS_PER_BLOCK, K_CORRECTOR, DEVICE\n",
    ")\n",
    "samples_s, time_s = _timed_sample(sequential_sampler, BATCH_SIZE)\n",
    "kl_s = calculate_kl_divergence(p0_numpy, samples_s, DIMENSION)\n",
    "\n",
    "# Rounds that must run one after another. Per block, the parallel sampler needs K_p\n",
    "# predictor iterations (each covering all M steps) and the sequential sampler needs M\n",
    "# predictor steps; both add K_c corrector steps.\n",
    "seq_stages_p = N_BLOCKS * (K_PREDICTOR + K_CORRECTOR)\n",
    "seq_stages_s = N_BLOCKS * (STEPS_PER_BLOCK + K_CORRECTOR)\n",
    "\n",
    "header = f\"| {'Score':<18} | {'Para+corr':<25} | {'Seq+corr':<20} |\"\n",
    "rule = \"-\" * len(header)  # match the table width\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(f\"Result (d={DIMENSION} dim Hypercube)\")\n",
    "print(rule)\n",
    "print(header)\n",
    "print(f\"| {'-'*18} | {'-'*25} | {'-'*20} |\")\n",
    "print(f\"| {'Runtime':<18} | {time_p:<25.2f} | {time_s:<20.2f} |\")\n",
    "print(f\"| {'KL-divergence':<18} | {kl_p:<25.4f} | {kl_s:<20.4f} |\")\n",
    "print(f\"| {'Seq steps':<18} | {seq_stages_p:<25} | {seq_stages_s:<20} |\")\n",
    "print(rule)\n",
    "\n",
    "plot_hamming_distributions(p0_numpy, samples_p, samples_s, DIMENSION)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
