{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from models.model_utils import get_model, get_preconditioned_model\n",
    "from utils.graph_lib import Absorbing\n",
    "from  utils.guidance_schedules import get_guidance_schedule\n",
    "from utils.datasets import get_dataset\n",
    "from utils.samplers import get_sampler\n",
    "from utils.misc import dotdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w, plot_intersection=False, title='distribution_comparison.pdf'):\n",
    "    \"\"\"Plot a histogram of generated samples against two reference distributions.\n",
    "\n",
    "    Args:\n",
    "        samples_np: Array of sampled states. It is flattened before binning and\n",
    "            its first dimension is reported as ``n`` in the plot title.\n",
    "        tilted_dist: Per-state probabilities drawn as the 'Tilted Distribution'\n",
    "            line; its length fixes the number of unit-width bins.\n",
    "        original_conditional_dist: Per-state probabilities drawn as the\n",
    "            'Class Distribution' line over the same states.\n",
    "        guid_w: Guidance weight; only used in the plot title.\n",
    "        plot_intersection: If True, shade the band 9.5-12.5 and label it.\n",
    "        title: Output path passed to ``savefig`` (a file name, not the plot title).\n",
    "\n",
    "    Returns:\n",
    "        The ``(fig, ax)`` pair of the figure that was drawn and saved.\n",
    "    \"\"\"\n",
    "    # NOTE: this updates the global rcParams; 'text.usetex' needs a LaTeX install.\n",
    "    plt.rcParams.update({\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 12,\n",
    "        'axes.titlesize': 14,\n",
    "        'xtick.labelsize': 10,\n",
    "        'ytick.labelsize': 10,\n",
    "        'legend.fontsize': 10,\n",
    "    })\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "\n",
    "    n_states = len(tilted_dist)\n",
    "    states = np.arange(n_states)\n",
    "\n",
    "    # Unit-width bins centred on the integer states.\n",
    "    bins = np.arange(n_states + 1) - 0.5\n",
    "    ax.hist(samples_np.flatten(), bins=bins, density=True,\n",
    "            color='#377EB8', alpha=0.7, label='Generated Samples')\n",
    "\n",
    "    # Fix the y-range first: the annotation below is positioned relative to the\n",
    "    # upper limit, so it has to see the final limit rather than the autoscaled one.\n",
    "    ax.set_ylim(0.0, .35)\n",
    "\n",
    "    if plot_intersection:\n",
    "        # Shade the band covering states 10 to 12 and label it near the top.\n",
    "        ax.axvspan(9.5, 12.5, alpha=0.2, color='green')\n",
    "        ax.text(11, ax.get_ylim()[1]*0.9, 'Intersection\\nArea',\n",
    "                ha='center', va='top', fontsize=10,\n",
    "                bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='green', alpha=0.7))\n",
    "\n",
    "    # Overlay the two reference distributions as marked lines.\n",
    "    ax.plot(states, tilted_dist, color='#E41A1C', marker='s',\n",
    "            markersize=4, linewidth=2, label='Tilted Distribution')\n",
    "    ax.plot(states, original_conditional_dist, color='#4DAF4A', marker='o',\n",
    "            markersize=4, linewidth=2, label='Class Distribution')\n",
    "\n",
    "    ax.set_xlabel('State', fontsize=12)\n",
    "    ax.set_ylabel('Probability', fontsize=12)\n",
    "    ax.set_title(\n",
    "        f'Distribution Comparison (n={samples_np.shape[0]}, guidance={guid_w})', fontsize=14)\n",
    "\n",
    "    # Label every 5th state only, to keep the tick labels from overlapping.\n",
    "    x_ticks = np.arange(0, n_states, 5)\n",
    "    ax.set_xticks(x_ticks)\n",
    "    ax.set_xticklabels([str(i) for i in x_ticks])\n",
    "\n",
    "    ax.grid(alpha=0.3, linestyle='--')\n",
    "    ax.legend(fontsize=12, frameon=True, facecolor='white',\n",
    "              edgecolor='gray', shadow=True, loc='upper right')\n",
    "\n",
    "    # Hide the top and right spines for a cleaner look.\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # Save through the figure object instead of pyplot's implicit current figure.\n",
    "    fig.savefig(title, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "def create_vector_histogram_figure(dist_vector, legend_name,  title='distribution_comparison.pdf'):\n",
    "    \"\"\"Draw a probability vector as a gap-free bar chart and save it.\n",
    "\n",
    "    Unlike ``create_histogram_figure`` this takes the probabilities themselves,\n",
    "    not samples: entry ``i`` of ``dist_vector`` becomes the height of bar ``i``.\n",
    "\n",
    "    Args:\n",
    "        dist_vector: Per-state values to plot; its length sets the number of bars.\n",
    "        legend_name: Legend label for the bars.\n",
    "        title: Output path passed to ``savefig``. The plot title itself is\n",
    "            always 'Distribution Comparison'.\n",
    "\n",
    "    Returns:\n",
    "        The ``(fig, ax)`` pair of the figure that was drawn and saved.\n",
    "    \"\"\"\n",
    "    # NOTE: this updates the global rcParams; 'text.usetex' needs a LaTeX install.\n",
    "    plt.rcParams.update({\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 12,\n",
    "        'axes.titlesize': 14,\n",
    "        'xtick.labelsize': 10,\n",
    "        'ytick.labelsize': 10,\n",
    "        'legend.fontsize': 10,\n",
    "    })\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "\n",
    "    n_states = len(dist_vector)\n",
    "\n",
    "    # One unit-width bin per state, weighted by that state's value, so the\n",
    "    # 'histogram' reproduces dist_vector exactly; rwidth=1.0 leaves no gaps.\n",
    "    bins = np.arange(n_states + 1) - 0.5\n",
    "    ax.hist(np.arange(n_states), bins=bins, weights=dist_vector,\n",
    "            color='#377EB8', alpha=0.7, label=legend_name,\n",
    "            edgecolor='none', rwidth=1.0)\n",
    "\n",
    "    ax.set_xlabel('State', fontsize=12)\n",
    "    ax.set_ylabel('Probability', fontsize=12)\n",
    "\n",
    "    title_text = 'Distribution Comparison'\n",
    "    ax.set_title(title_text, fontsize=14)\n",
    "\n",
    "    # Label every 5th state only, to keep the tick labels from overlapping.\n",
    "    x_ticks = np.arange(0, n_states, 5)\n",
    "    ax.set_xticks(x_ticks)\n",
    "    ax.set_xticklabels([str(i) for i in x_ticks])\n",
    "\n",
    "    ax.grid(alpha=0.3, linestyle='--')\n",
    "    ax.legend(fontsize=12, frameon=True, facecolor='white',\n",
    "              edgecolor='gray', shadow=True, loc='upper right')\n",
    "\n",
    "    ax.set_ylim(0.0, .35)\n",
    "    # Hide the top and right spines for a cleaner look.\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # Save through the figure object instead of pyplot's implicit current figure.\n",
    "    fig.savefig(title, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "def create_tv_plot(timesteps, total_variations, ref, guid_w):\n",
    "    \"\"\"Plot an empirical total-variation curve against a reference curve.\n",
    "\n",
    "    Args:\n",
    "        timesteps: Tensor of x-values shared by both curves.\n",
    "        total_variations: Tensor drawn as the 'Empirical' curve; its last entry\n",
    "            also sets the height of the dashed horizontal line.\n",
    "        ref: Tensor drawn as the 'Theoretical' curve.\n",
    "        guid_w: Guidance weight; only used in the plot title.\n",
    "\n",
    "    The figure is shown but neither saved nor returned.\n",
    "    \"\"\"\n",
    "    # NOTE: this updates the global rcParams; 'text.usetex' needs a LaTeX install.\n",
    "    rc_overrides = {\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 12,\n",
    "        'axes.titlesize': 14,\n",
    "        'xtick.labelsize': 10,\n",
    "        'ytick.labelsize': 10,\n",
    "        'legend.fontsize': 10,\n",
    "    }\n",
    "    plt.rcParams.update(rc_overrides)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "\n",
    "    # Move everything to host memory once; matplotlib cannot read device tensors.\n",
    "    t_host = timesteps.cpu()\n",
    "    tv_host = total_variations.cpu()\n",
    "    ref_host = ref.cpu().numpy()\n",
    "\n",
    "    ax.plot(t_host, tv_host, color='#377EB8', linewidth=2,\n",
    "            marker='o', markersize=4, markevery=5, label='Empirical')\n",
    "    ax.plot(t_host, ref_host, color='#E41A1C', linewidth=2,\n",
    "            marker='s', markersize=4, markevery=5, label='Theoretical')\n",
    "\n",
    "    # Dashed horizontal line at the last value of the empirical curve, with a\n",
    "    # caption centred on the mean timestep slightly offset from the line.\n",
    "    terminal_value = tv_host[-1].item()\n",
    "    ax.axhline(y=terminal_value, color='#4DAF4A', linestyle='--', linewidth=1.5,\n",
    "               label='Error from score approximation')\n",
    "    ax.text(t_host.mean(), terminal_value*1.05, 'Error from NN approximation',\n",
    "            color='#4DAF4A', fontsize=10, ha='center', va='bottom')\n",
    "\n",
    "    ax.set_xlabel('Timestep', fontsize=12)\n",
    "    ax.set_ylabel('Total Variation Distance', fontsize=12)\n",
    "    ax.set_title(\n",
    "        f'Total Variation Distance vs. Timestep (guidance={guid_w})', fontsize=14)\n",
    "\n",
    "    ax.grid(alpha=0.3, linestyle='--')\n",
    "    ax.legend(fontsize=12, frameon=True, facecolor='white',\n",
    "              edgecolor='gray', shadow=True, loc='best')\n",
    "\n",
    "    # Hide the top and right spines for a cleaner look.\n",
    "    for side in ('top', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup for the 1D 'vector-disjoint' toy dataset.\n",
    "dataset_name = 'vector-disjoint'\n",
    "net_opts_path = 'configs/toy_net.yaml'\n",
    "ckpt_path = 'experiments/disjoint/final_checkpoint.pt'\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Parse the network options inside a context manager so the file handle is closed.\n",
    "with open(net_opts_path) as config_file:\n",
    "    net_opts = dotdict(yaml.safe_load(config_file))\n",
    "\n",
    "# Dataset, and the Absorbing graph built from its vocabulary size.\n",
    "dataset = get_dataset(dataset_name)\n",
    "vocab_size = dataset.vocab_size\n",
    "context_len = dataset.context_len\n",
    "graph = Absorbing(vocab_size)\n",
    "\n",
    "# Model; vocab_size + 1 presumably reserves one extra token index -- TODO confirm.\n",
    "model = get_model('radd', vocab_size + 1, context_len, net_opts)\n",
    "model = get_preconditioned_model(model, graph).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Load the checkpoint on CPU; strict=True fails on any missing or unexpected key.\n",
    "snapshot = torch.load(ckpt_path, weights_only=True, map_location='cpu')\n",
    "model.net.load_state_dict(snapshot['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the dataset's full probability vector (saved as 'full_prob.pdf').\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "steps = 50\n",
    "n_samples = 10000\n",
    "cond_class = 1\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "\n",
    "# The guided-distribution lookup that used to precede this line was overwritten\n",
    "# immediately, so only the full vector is kept.\n",
    "original_conditional_dist = dataset.full_vector\n",
    "\n",
    "fig, ax = create_vector_histogram_figure(original_conditional_dist, legend_name='Full Probability', title='full_prob.pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample class 0 with constant guidance and compare against the tilted target.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "steps = 50\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "new_sample, traj = sampling_fn(model, (n_samples, context_len), cond, steps, use_tau_leaping=True, return_traj=True)\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "fig, ax = create_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w=guid_w, title='1d-disjoint.pdf')\n",
    "\n",
    "# Time grid from 1 down to graph.delta, with a final 0 appended.\n",
    "timesteps = torch.linspace(1, graph.delta, steps + 1, device=device)\n",
    "timesteps = torch.cat((timesteps, 0 * timesteps[:1]))\n",
    "total_variations = torch.zeros_like(timesteps)\n",
    "# Pad the target with one trailing zero so it matches the vocab_size + 1 bins\n",
    "# counted below.\n",
    "tilted_dist_ = torch.cat((tilted_dist, 0 * tilted_dist[:1]))\n",
    "\n",
    "# zw = sum(conditional**w * full**(1 - w)), with the power applied to non-zero entries only.\n",
    "full = dataset.full_vector.clone()\n",
    "full[full != 0] = full[full != 0]**(1-guid_w)\n",
    "zw = (original_conditional_dist**guid_w * full).sum()\n",
    "print(zw)\n",
    "\n",
    "# Reference curve. An earlier normalised variant was assigned here and then\n",
    "# overwritten on the next line; only the effective definition is kept.\n",
    "ref = ((1-torch.exp(-graph.sigma_int(timesteps))))**zw\n",
    "\n",
    "# Sum of absolute differences between each trajectory state's histogram and the target.\n",
    "for i, (x, t) in enumerate(zip(traj, timesteps)):\n",
    "    hist = torch.bincount(x.flatten().cpu(), minlength=vocab_size + 1).float()\n",
    "    hist = hist / hist.sum()\n",
    "\n",
    "    total_variations[i] = torch.abs(hist - tilted_dist_).sum()\n",
    "\n",
    "create_tv_plot(timesteps, total_variations, ref, guid_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the guidance weight and record, for each value, the normalising\n",
    "# constant and the histogram-to-target distance at one trajectory index.\n",
    "ws = torch.linspace(1., 5., steps=21)\n",
    "steps = 100\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "\n",
    "timesteps = torch.linspace(1., graph.delta, steps + 1)\n",
    "\n",
    "# Trajectory index that is inspected for every guidance weight.\n",
    "time_slice_idx = 50\n",
    "time_slice = timesteps[time_slice_idx]\n",
    "total_var_as_w = torch.zeros_like(ws).float()\n",
    "normalizing_constants = torch.zeros_like(total_var_as_w)\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "for i, guid_w in enumerate(ws):\n",
    "    # Fresh sampler for this guidance weight.\n",
    "    guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "    sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "    new_sample, traj = sampling_fn(model, (n_samples, context_len), cond, steps, use_tau_leaping=True, return_traj=True)\n",
    "    samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "    # Target distribution, padded with one zero to match the vocab_size + 1 bins.\n",
    "    tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "    tilted_dist_ = torch.cat((tilted_dist, 0 * tilted_dist[:1]))\n",
    "\n",
    "    # zw = sum(conditional**w * full**(1 - w)), power applied to non-zero entries only.\n",
    "    full_vector_matrix = dataset.full_vector.clone()\n",
    "    nonzero_mask = full_vector_matrix != 0\n",
    "    full_vector_matrix[nonzero_mask] = full_vector_matrix[nonzero_mask]**(1-guid_w)\n",
    "    zw = (original_conditional_dist**guid_w * full_vector_matrix).sum()\n",
    "    normalizing_constants[i] = zw\n",
    "\n",
    "    # Normalised histogram of the trajectory state at the chosen index.\n",
    "    cur_val = traj[time_slice_idx].clone()\n",
    "    hist = torch.bincount(cur_val.flatten().cpu(), minlength=vocab_size + 1).float()\n",
    "    hist = hist / hist.sum()\n",
    "\n",
    "    total_var_as_w[i] = torch.abs(hist - tilted_dist_).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference curve evaluated at the inspected time slice, one value per guidance weight.\n",
    "ref_curve = (1-torch.exp(-graph.sigma_int(1-time_slice)))**normalizing_constants\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(ws, total_var_as_w, marker='o', linewidth=2, markersize=8,\n",
    "         color='#377EB8', label='Empirical Total Variation')\n",
    "plt.plot(ws, ref_curve, marker='s', linewidth=2, markersize=8,\n",
    "         color='#E41A1C', label='Theoretical Reference')\n",
    "\n",
    "plt.xlabel('Guidance Weight (w)', fontsize=14)\n",
    "plt.ylabel('Total Variation Distance', fontsize=14)\n",
    "plt.title(f'Total Variation vs Guidance Weight (t={time_slice:.2f})', fontsize=16)\n",
    "plt.grid(alpha=0.3, linestyle='--')\n",
    "plt.legend(fontsize=12, frameon=True, facecolor='white', edgecolor='gray')\n",
    "\n",
    "# Hide the top and right spines for a cleaner look.\n",
    "plt.gca().spines['top'].set_visible(False)\n",
    "plt.gca().spines['right'].set_visible(False)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('tv_vs_guidance.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Labelled summary. The same three values used to be printed a second time,\n",
    "# unlabelled, at the top of the cell. cur_val is the slice from the last sweep step.\n",
    "print(f\"Total Variation: {total_var_as_w}\")\n",
    "print(f\"Normalizing Constants: {normalizing_constants}\")\n",
    "print(f\"Count of tokens equal to 30: {(cur_val == 30).sum()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup for the 1D 'vector-intersection' toy dataset.\n",
    "dataset_name = 'vector-intersection'\n",
    "net_opts_path = 'configs/toy_net.yaml'\n",
    "ckpt_path = 'experiments/intersection/final_checkpoint.pt'\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Parse the network options inside a context manager so the file handle is closed.\n",
    "with open(net_opts_path) as config_file:\n",
    "    net_opts = dotdict(yaml.safe_load(config_file))\n",
    "\n",
    "# Dataset, and the Absorbing graph built from its vocabulary size.\n",
    "dataset = get_dataset(dataset_name)\n",
    "vocab_size = dataset.vocab_size\n",
    "context_len = dataset.context_len\n",
    "graph = Absorbing(vocab_size)\n",
    "\n",
    "# Model; vocab_size + 1 presumably reserves one extra token index -- TODO confirm.\n",
    "model = get_model('radd', vocab_size + 1, context_len, net_opts)\n",
    "model = get_preconditioned_model(model, graph).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Load the checkpoint on CPU; strict=True fails on any missing or unexpected key.\n",
    "snapshot = torch.load(ckpt_path, weights_only=True, map_location='cpu')\n",
    "model.net.load_state_dict(snapshot['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the unguided (w = 1) conditional distribution of class index 1.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "steps = 50\n",
    "n_samples = 10000\n",
    "cond_class = 1\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "\n",
    "# The plotted vector is the class-conditional distribution, not the full one,\n",
    "# so the legend now matches the data and the output file name.\n",
    "fig, ax = create_vector_histogram_figure(original_conditional_dist, legend_name='Class 2', title='Class 2-inter.pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample class 0 from the intersection model and compare against the tilted target.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "steps = 50\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "# Ask for the trajectory here. The loop below used to read `traj`, `timesteps`\n",
    "# and `total_variations` left over from the earlier disjoint-dataset cells.\n",
    "new_sample, traj = sampling_fn(model, (n_samples, context_len), cond, steps, use_tau_leaping=True, return_traj=True)\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "\n",
    "samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "create_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w=guid_w, plot_intersection=True, title='1d-intersection.pdf')\n",
    "\n",
    "# Time grid and result buffer built the same way as in the disjoint experiment.\n",
    "timesteps = torch.linspace(1, graph.delta, steps + 1, device=device)\n",
    "timesteps = torch.cat((timesteps, 0 * timesteps[:1]))\n",
    "total_variations = torch.zeros_like(timesteps)\n",
    "# Target padded with one trailing zero to match the vocab_size + 1 bins; it does\n",
    "# not change between iterations, so it is built once outside the loop.\n",
    "tilted_dist_ = torch.cat((tilted_dist, 0 * tilted_dist[:1]))\n",
    "\n",
    "zw = (original_conditional_dist**guid_w).sum()\n",
    "ref = 1-((1-torch.exp(-graph.sigma_int(timesteps[0] - timesteps)))/(1-torch.exp(-graph.sigma_int(timesteps[0]))))**zw\n",
    "for i, (x, t) in enumerate(zip(traj, timesteps)):\n",
    "    hist = torch.bincount(x.flatten().cpu(), minlength=vocab_size + 1).float()\n",
    "    hist = hist / hist.sum()\n",
    "\n",
    "    total_variations[i] = torch.abs(hist - tilted_dist_).sum()\n",
    "\n",
    "# NOTE(review): the empirical curve is log-transformed but `ref` is not, and the\n",
    "# plot's y-label still reads 'Total Variation Distance' -- confirm this is intended.\n",
    "total_variations = total_variations.log()\n",
    "\n",
    "create_tv_plot(timesteps, total_variations, ref, guid_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_2d_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w):\n",
    "    \"\"\"Show the empirical 2D sample histogram next to the tilted distribution.\n",
    "\n",
    "    Args:\n",
    "        samples_np: Array whose columns 0 and 1 hold the two sampled coordinates;\n",
    "            its first dimension is reported as ``n`` in the left panel title.\n",
    "        tilted_dist: 2D tensor (``.numpy()`` is called on it) drawn in the right\n",
    "            panel; its shape sets the histogram bin counts and ranges.\n",
    "        original_conditional_dist: Not used by this function.\n",
    "        guid_w: Guidance weight; only used in the right panel title.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, axes)`` where ``axes`` holds the two image axes.\n",
    "    \"\"\"\n",
    "    # NOTE: this updates the global rcParams; 'text.usetex' needs a LaTeX install.\n",
    "    rc_overrides = {\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 14,\n",
    "        'axes.titlesize': 16,\n",
    "        'xtick.labelsize': 12,\n",
    "        'ytick.labelsize': 12,\n",
    "        'legend.fontsize': 12,\n",
    "    }\n",
    "    plt.rcParams.update(rc_overrides)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(8, 5))\n",
    "\n",
    "    # Column 1 is binned along the first histogram axis and column 0 along the\n",
    "    # second, with one unit-width bin per entry of tilted_dist.\n",
    "    W, H = tilted_dist.shape\n",
    "    sample_hist_2d = np.histogram2d(\n",
    "        samples_np[:, 1],\n",
    "        samples_np[:, 0],\n",
    "        bins=[W, H],\n",
    "        range=[[0, W], [0, H]],\n",
    "        density=True\n",
    "    )[0]\n",
    "\n",
    "    # Shared colour limits so both panels are directly comparable.\n",
    "    target_np = tilted_dist.numpy()\n",
    "    vmin = min(np.min(sample_hist_2d), np.min(target_np))\n",
    "    vmax = max(np.max(sample_hist_2d), np.max(target_np))\n",
    "\n",
    "    panels = (\n",
    "        (sample_hist_2d, f'Generated Samples (n={samples_np.shape[0]})'),\n",
    "        (tilted_dist, f'Tilted Distribution (w={guid_w})'),\n",
    "    )\n",
    "\n",
    "    images = []\n",
    "    for ax, (panel, panel_title) in zip(axes, panels):\n",
    "        images.append(ax.imshow(panel, origin='lower', aspect='auto', cmap='viridis',\n",
    "                                vmin=vmin, vmax=vmax))\n",
    "        ax.set_title(panel_title, fontsize=16, pad=10)\n",
    "        ax.set_xlabel('X', fontsize=14)\n",
    "        ax.set_ylabel('Y', fontsize=14)\n",
    "        # Keep the 'auto' aspect so the image fills the axes.\n",
    "        ax.set_aspect('auto')\n",
    "\n",
    "    # One colourbar for both panels, in a manually placed axes on the right.\n",
    "    fig.subplots_adjust(right=0.9, wspace=0.3)\n",
    "    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(images[0], cax=cbar_ax)\n",
    "    cbar.ax.tick_params(labelsize=12)\n",
    "\n",
    "    # Lay out the panels inside the left 90% of the figure.\n",
    "    plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "\n",
    "    return fig, axes\n",
    "\n",
    "\n",
    "def create_single_2d_histogram_figure(tilted_dist, title):\n",
    "    \"\"\"Draw a single 2D distribution as a heatmap with its own colourbar.\n",
    "\n",
    "    Args:\n",
    "        tilted_dist: 2D tensor to display (``.numpy()`` is called on it to get\n",
    "            the colour limits).\n",
    "        title: Title shown above the heatmap.\n",
    "\n",
    "    Returns:\n",
    "        The ``(fig, ax)`` pair; the figure is not saved here.\n",
    "    \"\"\"\n",
    "    # NOTE: this updates the global rcParams; 'text.usetex' needs a LaTeX install.\n",
    "    rc_overrides = {\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 14,\n",
    "        'axes.titlesize': 16,\n",
    "        'xtick.labelsize': 12,\n",
    "        'ytick.labelsize': 12,\n",
    "        'legend.fontsize': 12,\n",
    "    }\n",
    "    plt.rcParams.update(rc_overrides)\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
    "\n",
    "    # Colour limits span exactly the range of the data.\n",
    "    values = tilted_dist.numpy()\n",
    "    vmin = np.min(values)\n",
    "    vmax = np.max(values)\n",
    "\n",
    "    im = ax.imshow(tilted_dist, origin='lower', aspect='auto', cmap='viridis',\n",
    "                   vmin=vmin, vmax=vmax)\n",
    "    ax.set_title(title, fontsize=14)\n",
    "    ax.set_xlabel('X', fontsize=14)\n",
    "    ax.set_ylabel('Y', fontsize=14)\n",
    "    # Keep the 'auto' aspect so the image fills the axes.\n",
    "    ax.set_aspect('auto')\n",
    "\n",
    "    # Colourbar in a manually placed axes on the right.\n",
    "    fig.subplots_adjust(right=0.9, wspace=0.3)\n",
    "    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "    cbar.ax.tick_params(labelsize=12)\n",
    "\n",
    "    # Lay out the plot inside the left 90% of the figure.\n",
    "    plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "\n",
    "    return fig, ax\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup for the 2D 'matrix-disjoint' toy dataset.\n",
    "dataset_name = 'matrix-disjoint'\n",
    "net_opts_path = 'configs/toy_net.yaml'\n",
    "ckpt_path = 'experiments/2d-disjoint/final_checkpoint.pt'\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Parse the network options inside a context manager so the file handle is closed.\n",
    "with open(net_opts_path) as config_file:\n",
    "    net_opts = dotdict(yaml.safe_load(config_file))\n",
    "\n",
    "# Dataset, and the Absorbing graph built from its vocabulary size.\n",
    "dataset = get_dataset(dataset_name)\n",
    "vocab_size = dataset.vocab_size\n",
    "context_len = dataset.context_len\n",
    "graph = Absorbing(vocab_size)\n",
    "\n",
    "# Model; vocab_size + 1 presumably reserves one extra token index -- TODO confirm.\n",
    "model = get_model('radd', vocab_size + 1, context_len, net_opts)\n",
    "model = get_preconditioned_model(model, graph).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Load the checkpoint on CPU; strict=True fails on any missing or unexpected key.\n",
    "snapshot = torch.load(ckpt_path, weights_only=True, map_location='cpu')\n",
    "model.net.load_state_dict(snapshot['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the unguided (w = 1) conditional distribution of class index 1.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "n_samples = 10000\n",
    "cond_class = 1\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "# No samples are drawn in this cell. The former `samples_np = new_sample...` line\n",
    "# read a tensor produced by an earlier experiment and was never used, so it is gone.\n",
    "\n",
    "fig, ax = create_single_2d_histogram_figure(original_conditional_dist, title='Class 2')\n",
    "\n",
    "fig.savefig('2d Class 2.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample class 0 from the 2D disjoint model and plot it beside the tilted target.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "# 50 sampling steps; the trajectory is returned along with the final sample.\n",
    "new_sample, traj = sampling_fn(model, (n_samples, context_len), cond, 50, use_tau_leaping=True, return_traj=True)\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "fig, ax = create_2d_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w=guid_w)\n",
    "\n",
    "fig.savefig('2d-disjoint.pdf')\n",
    "\n",
    "# One time point per trajectory entry, from 1 down to graph.delta, plus a final 0.\n",
    "timesteps = torch.linspace(1, graph.delta, len(traj), device=device)\n",
    "timesteps = torch.cat((timesteps, 0 * timesteps[:1]))\n",
    "\n",
    "# Reference curve, written as 1 - (numerator / denominator) ** zw.\n",
    "zw = (original_conditional_dist**guid_w).sum()\n",
    "ref_numerator = 1-torch.exp(-graph.sigma_int(timesteps[0] - timesteps))\n",
    "ref_denominator = 1-torch.exp(-graph.sigma_int(timesteps[0]))\n",
    "ref = 1-(ref_numerator/ref_denominator)**zw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup for the 2D 'matrix-intersection' toy dataset.\n",
    "dataset_name = 'matrix-intersection'\n",
    "net_opts_path = 'configs/toy_net.yaml'\n",
    "ckpt_path = 'experiments/2d-intersection/final_checkpoint.pt'\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Parse the network options inside a context manager so the file handle is closed.\n",
    "with open(net_opts_path) as config_file:\n",
    "    net_opts = dotdict(yaml.safe_load(config_file))\n",
    "\n",
    "# Dataset, and the Absorbing graph built from its vocabulary size.\n",
    "dataset = get_dataset(dataset_name)\n",
    "vocab_size = dataset.vocab_size\n",
    "context_len = dataset.context_len\n",
    "graph = Absorbing(vocab_size)\n",
    "\n",
    "# Model; vocab_size + 1 presumably reserves one extra token index -- TODO confirm.\n",
    "model = get_model('radd', vocab_size + 1, context_len, net_opts)\n",
    "model = get_preconditioned_model(model, graph).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Load the checkpoint on CPU; strict=True fails on any missing or unexpected key.\n",
    "snapshot = torch.load(ckpt_path, weights_only=True, map_location='cpu')\n",
    "model.net.load_state_dict(snapshot['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the dataset's full probability matrix (saved as '2d - full prob.pdf').\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "\n",
    "# The guided-distribution lookup that used to precede this line was overwritten\n",
    "# immediately, so only the full matrix is kept. The stale, unused\n",
    "# `samples_np = new_sample...` line is gone as well.\n",
    "original_conditional_dist = dataset.full_matrix\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "\n",
    "fig, ax = create_single_2d_histogram_figure(original_conditional_dist, title='Full Probability')\n",
    "\n",
    "fig.savefig('2d - full prob.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample class 0 from the 2D intersection model and plot it beside the tilted target.\n",
    "guid_w = 3.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "n_samples = 10000\n",
    "cond_class = 0\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class\n",
    "# 50 sampling steps; the trajectory is returned along with the final sample.\n",
    "new_sample, traj = sampling_fn(model, (n_samples, context_len), cond, 50, use_tau_leaping=True, return_traj=True)\n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "fig, ax = create_2d_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w=guid_w)\n",
    "\n",
    "fig.savefig('2d-intersection.pdf')\n",
    "\n",
    "# One time point per trajectory entry, from 1 down to graph.delta, plus a final 0.\n",
    "timesteps = torch.linspace(1, graph.delta, len(traj), device=device)\n",
    "timesteps = torch.cat((timesteps, 0 * timesteps[:1]))\n",
    "\n",
    "# Reference curve, written as 1 - (numerator / denominator) ** zw.\n",
    "zw = (original_conditional_dist**guid_w).sum()\n",
    "ref_numerator = 1-torch.exp(-graph.sigma_int(timesteps[0] - timesteps))\n",
    "ref_denominator = 1-torch.exp(-graph.sigma_int(timesteps[0]))\n",
    "ref = 1-(ref_numerator/ref_denominator)**zw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_5d_marginals(matrix, path=None):\n",
    "    \"\"\"Plot the ten pairwise 2D marginals of a 5D array on a 2x5 grid of heatmaps.\n",
    "\n",
    "    Args:\n",
    "        matrix: 5-dimensional NumPy array or torch tensor; presumably a joint\n",
    "            distribution over five coordinates -- TODO confirm against the caller.\n",
    "        path: If given, the figure is saved there and closed; otherwise it is shown.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``matrix`` does not have exactly five dimensions.\n",
    "    \"\"\"\n",
    "    # Accept torch tensors (possibly on GPU or tracking gradients) as well as arrays.\n",
    "    if hasattr(matrix, 'detach'):\n",
    "        matrix = matrix.detach().cpu().numpy()\n",
    "    matrix = np.asarray(matrix)\n",
    "    if matrix.ndim != 5:\n",
    "        raise ValueError(f'expected a 5-dimensional array, got {matrix.ndim} dimension(s)')\n",
    "\n",
    "    fig, axes = plt.subplots(2, 5, figsize=(15, 6))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    # All unordered pairs of the five dimensions: (0,1), (0,2), ..., (3,4).\n",
    "    dim_pairs = [(a, b) for a in range(5) for b in range(a + 1, 5)]\n",
    "\n",
    "    for ax, (dim1, dim2) in zip(axes, dim_pairs):\n",
    "        # Sum out the three dimensions outside this pair. The previous code\n",
    "        # indexed `matrix` with the whole list of pairs, which is not a marginal.\n",
    "        other_dims = tuple(d for d in range(5) if d not in (dim1, dim2))\n",
    "        h = matrix.sum(axis=other_dims)\n",
    "\n",
    "        # h is indexed (dim1, dim2); transpose so dim1 runs along the x-axis.\n",
    "        ax.imshow(h.T, origin='lower', extent=[-3, 3, -3, 3],\n",
    "                  aspect='auto', cmap='viridis')\n",
    "\n",
    "        ax.set_title(f'Dims {dim1+1} vs {dim2+1}')\n",
    "        ax.set_xlabel(f'Dimension {dim1+1}')\n",
    "        ax.set_ylabel(f'Dimension {dim2+1}')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if path is not None:\n",
    "        fig.savefig(path)\n",
    "        plt.close(fig)\n",
    "    else:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup for the 'gaussian-5d' toy dataset.\n",
    "dataset_name = 'gaussian-5d'\n",
    "net_opts_path = 'configs/toy_net.yaml'\n",
    "ckpt_path = 'experiments/gaussian-5d/final_checkpoint.pt'\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Parse the network options inside a context manager so the file handle is closed.\n",
    "with open(net_opts_path) as config_file:\n",
    "    net_opts = dotdict(yaml.safe_load(config_file))\n",
    "\n",
    "# Dataset, and the Absorbing graph built from its vocabulary size.\n",
    "dataset = get_dataset(dataset_name)\n",
    "vocab_size = dataset.vocab_size\n",
    "context_len = dataset.context_len\n",
    "graph = Absorbing(vocab_size)\n",
    "\n",
    "# Model; vocab_size + 1 presumably reserves one extra token index -- TODO confirm.\n",
    "model = get_model('radd', vocab_size + 1, context_len, net_opts)\n",
    "model = get_preconditioned_model(model, graph).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Load the checkpoint on CPU; strict=True fails on any missing or unexpected key.\n",
    "snapshot = torch.load(ckpt_path, weights_only=True, map_location='cpu')\n",
    "model.net.load_state_dict(snapshot['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guid_w = 4.\n",
    "guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "n_samples = 10000\n",
    "steps = 100\n",
    "cond_class = 2\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class \n",
    "\n",
    "original_conditional_dist = dataset.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "\n",
    "sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "new_sample, traj = sampling_fn(model,(n_samples,context_len),cond, steps, use_tau_leaping=True, return_traj=True)\n",
    "\n",
    "dataset.plot_samples(new_sample, f'5d-guid-{guid_w}-{cond_class}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ws = torch.linspace(1., 5., steps=5)\n",
    "steps = 50 \n",
    "n_samples = 10000\n",
    "cond_class = 1\n",
    "cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "cond = torch.ones_like(cond) * cond_class \n",
    "\n",
    "timesteps = torch.linspace(1., graph.delta, steps + 1)\n",
    "\n",
    "time_slice_idx = 50 \n",
    "time_slice = timesteps[time_slice_idx]\n",
    "total_var_as_w = torch.zeros_like(ws).float()\n",
    "normalizing_constants = torch.zeros_like(total_var_as_w) \n",
    "\n",
    "original_conditional_dist, full_vector_matrix = dataset.get_guided_distribution(cond_class, 1.)\n",
    "for i, guid_w in enumerate(ws):\n",
    "    guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "    sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "    new_sample, traj = sampling_fn(model,(n_samples,context_len),cond, steps, use_tau_leaping=True, return_traj=True)\n",
    "    samples_np = new_sample.cpu().numpy()\n",
    "\n",
    "    tilted_dist, full_vector_matrix = dataset.get_guided_distribution(cond_class, guid_w)\n",
    "    full_vector_matrix = full_vector_matrix.reshape_as(original_conditional_dist)\n",
    "    full_vector_matrix[full_vector_matrix != 0] = full_vector_matrix[full_vector_matrix != 0]**(1-guid_w)\n",
    "    zw = (original_conditional_dist**guid_w * full_vector_matrix).sum()\n",
    "    normalizing_constants[i] = zw\n",
    "\n",
    "    cur_val = traj[time_slice_idx].clone().cpu().numpy()\n",
    "\n",
    "\n",
    "    W, H, _, _, _ = tilted_dist.shape\n",
    "    sample_hist_2d, _, _ = np.histogram2d(\n",
    "        cur_val[:, 1],\n",
    "        cur_val[:, 0],\n",
    "        bins=[W, H],\n",
    "        range=[[0, W], [0, H]],\n",
    "        density=True\n",
    "    )\n",
    "    \n",
    "    # create_histogram_figure(samples_np)\n",
    "    tilted_dist_marginal = torch.sum(tilted_dist, dim=(0, 1,4))\n",
    "    original_dist_marginal = torch.sum(original_conditional_dist, dim=(0,1,4))\n",
    "    \n",
    "    fig, ax = create_2d_histogram_figure(samples_np[:,:2], tilted_dist_marginal, original_dist_marginal, guid_w)\n",
    "\n",
    "    fig.show()\n",
    "    # fig.savefig(f'marginals_{guid_w}.pdf')\n",
    "\n",
    "    total_var_as_w[i] = np.abs(sample_hist_2d - tilted_dist_marginal.numpy()).sum()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dcfg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
