{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import yaml\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as colors\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "# Add the parent directory to the Python path\n",
    "sys.path.append('../')  # Add parent directory\n",
    "\n",
    "# Now use absolute imports like the other files in the project\n",
    "from utils.model_utils import get_preconditioned_model, get_exact_model\n",
    "from utils.graph_lib import Absorbing\n",
    "from utils.guidance_schedules import get_guidance_schedule, GuidanceSchedule\n",
    "from utils.datasets import get_dataset\n",
    "from utils.samplers import get_sampler\n",
    "from utils.misc import dotdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_2d_histogram_figure(samples_np, tilted_dist, original_conditional_dist, guid_w):\n",
    "    \"\"\"Create a 2D histogram visualization comparing generated samples with theoretical distributions.\"\"\"\n",
    "    plt.rcParams.update({\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 14,\n",
    "        'axes.titlesize': 16,\n",
    "        'xtick.labelsize': 12,\n",
    "        'ytick.labelsize': 12,\n",
    "        'legend.fontsize': 12,\n",
    "    })\n",
    "    \n",
    "    # Use a wider figure but ensure the histograms will be square\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(8, 5))\n",
    "\n",
    "    W, H = tilted_dist.shape\n",
    "    sample_hist_2d, _, _ = np.histogram2d(\n",
    "        samples_np[:, 1],\n",
    "        samples_np[:, 0],\n",
    "        bins=[W, H],\n",
    "        range=[[0, W], [0, H]],\n",
    "        density=True\n",
    "    )\n",
    "    \n",
    "    hists = [sample_hist_2d, tilted_dist]\n",
    "    titles = [\n",
    "        f'Generated Samples (n={samples_np.shape[0]})',\n",
    "        f'Tilted Distribution (w={guid_w})',\n",
    "    ]\n",
    "    \n",
    "    # Find global min and max for consistent colormap\n",
    "    vmin = min(np.min(sample_hist_2d), np.min(tilted_dist.numpy()))\n",
    "    vmax = max(np.max(sample_hist_2d), np.max(tilted_dist.numpy()))\n",
    "\n",
    "    # Create a list to store the image objects\n",
    "    ims = []\n",
    "    for ax, hist, title in zip(axes, hists, titles):\n",
    "        im = ax.imshow(hist, origin='lower', aspect='auto', cmap='viridis', \n",
    "                      vmin=vmin, vmax=vmax)\n",
    "        ax.set_title(title, fontsize=16, pad=10)\n",
    "        ax.set_xlabel('X', fontsize=14)\n",
    "        ax.set_ylabel('Y', fontsize=14)\n",
    "        ims.append(im)\n",
    "        \n",
    "        # Make the axis square\n",
    "        ax.set_aspect('auto')\n",
    "    \n",
    "    # Add a single colorbar that applies to both plots\n",
    "    fig.subplots_adjust(right=0.9, wspace=0.3)  # Make room for colorbar and between plots\n",
    "    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(ims[0], cax=cbar_ax)\n",
    "    cbar.ax.tick_params(labelsize=12)\n",
    "    \n",
    "    # Adjust the layout while preserving aspect ratio\n",
    "    plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "    \n",
    "    return fig, axes\n",
    "\n",
    "\n",
    "def create_single_2d_histogram_figure(tilted_dist, title):\n",
    "    \"\"\"Create a 2D histogram visualization comparing generated samples with theoretical distributions.\"\"\"\n",
    "    plt.rcParams.update({\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 14,\n",
    "        'axes.titlesize': 16,\n",
    "        'xtick.labelsize': 12,\n",
    "        'ytick.labelsize': 12,\n",
    "        'legend.fontsize': 12,\n",
    "    })\n",
    "    \n",
    "    # Use a wider figure but ensure the histograms will be square\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
    "\n",
    "    vmin = np.min(tilted_dist.numpy())\n",
    "    vmax = np.max(tilted_dist.numpy())\n",
    "    global_vmin = 0.0  # or min across all datasets\n",
    "    global_vmax = 1.3  # or max across all datasets\n",
    "\n",
    "    # Create a list to store the image objects\n",
    "    ims = []\n",
    "    im = ax.imshow(tilted_dist, origin='lower', aspect='auto', cmap='coolwarm', \n",
    "                vmin=global_vmin, vmax=global_vmax)\n",
    "    # im = ax.imshow(tilted_dist, origin='lower', aspect='auto', cmap=custom_cmap,\n",
    "    #            norm=colors.TwoSlopeNorm(vmin=global_vmin, vcenter=1.0, vmax=global_vmax))\n",
    "    ax.set_title(title, fontsize=14)\n",
    "    ax.set_xlabel('X', fontsize=14)\n",
    "    ax.set_ylabel('Y', fontsize=14)\n",
    "    ims.append(im)\n",
    "    \n",
    "    # Make the axis square\n",
    "    ax.set_aspect('auto')\n",
    "    \n",
    "    # Add a single colorbar that applies to both plots\n",
    "    fig.subplots_adjust(right=0.9, wspace=0.3)  # Make room for colorbar and between plots\n",
    "    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]\n",
    "    cbar = fig.colorbar(ims[0], cax=cbar_ax)\n",
    "    cbar.ax.tick_params(labelsize=12)\n",
    "    \n",
    "    # Adjust the layout while preserving aspect ratio\n",
    "    plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "    \n",
    "    return fig, ax\n",
    "\n",
    "def create_2d_tv_plot(timesteps, total_variations, ref, guid_w):\n",
    "    plt.rcParams.update({\n",
    "        'text.usetex': True,\n",
    "        'axes.labelsize': 12,\n",
    "        'axes.titlesize': 14,\n",
    "        'xtick.labelsize': 10,\n",
    "        'ytick.labelsize': 10,\n",
    "        'legend.fontsize': 10,\n",
    "    })\n",
    "\n",
    "    # Create a single figure\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "\n",
    "    # Plot the data with better styling\n",
    "    ax.plot(timesteps, total_variations, color='#377EB8', linewidth=2,\n",
    "            marker='o', markersize=4, markevery=5, label='Empirical')\n",
    "    ax.plot(timesteps, ref, color='#E41A1C', linewidth=2,\n",
    "            marker='s', markersize=4, markevery=5, label='Theoretical')\n",
    "\n",
    "    # Add horizontal dashed line at the terminal value of empirical distribution\n",
    "    terminal_value = total_variations[-1].item()\n",
    "    ax.axhline(y=terminal_value, color='#4DAF4A', linestyle='--', linewidth=1.5,\n",
    "               label='Error from score approximation')\n",
    "\n",
    "    # Add text annotation for the horizontal line\n",
    "    ax.text(timesteps.mean(), terminal_value*1.05, 'Error from NN approximation',\n",
    "            color='#4DAF4A', fontsize=10, ha='center', va='bottom')\n",
    "\n",
    "    # Set labels and title\n",
    "    ax.set_xlabel('Timestep', fontsize=12)\n",
    "    ax.set_ylabel('Total Variation Distance', fontsize=12)\n",
    "    ax.set_title(\n",
    "        f'Total Variation Distance vs. Timestep (2D, guidance={guid_w})', fontsize=14)\n",
    "\n",
    "    # Add grid and legend\n",
    "    ax.grid(alpha=0.3, linestyle='--')\n",
    "    ax.legend(fontsize=12, frameon=True, facecolor='white',\n",
    "              edgecolor='gray', shadow=True, loc='best')\n",
    "\n",
    "    # Remove top and right spines for cleaner look\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    # Tight layout\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "def compute_2d_total_variation(traj, timesteps, tilted_dist):\n",
    "    total_variations = torch.zeros_like(timesteps).cpu().numpy()\n",
    "    \n",
    "    tilted_dist = tilted_dist.cpu().numpy()\n",
    "    for i, (x, t) in enumerate(zip(traj, timesteps)):\n",
    "        W, H = tilted_dist.shape\n",
    "        x = x.cpu().numpy()\n",
    "        sample_hist_2d, _, _ = np.histogram2d(\n",
    "            x[:, 1],\n",
    "            x[:, 0],\n",
    "            bins=[W, H],\n",
    "            range=[[0, W], [0, H]],\n",
    "            density=True\n",
    "        )    \n",
    "        total_variations[i] = np.abs(sample_hist_2d - tilted_dist).sum() / 2  # Divide by 2 for proper TV distance\n",
    "    \n",
    "    return total_variations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "def get_model_and_dataset(dataset_name):\n",
    "    dataset = get_dataset(dataset_name)\n",
    "    vocab_size = dataset.vocab_size\n",
    "    graph = Absorbing(vocab_size)\n",
    "\n",
    "    def increase_size(matrix):\n",
    "        H, W = matrix.shape[0] + 1, matrix.shape[1] + 1\n",
    "        new_size = torch.zeros((H, W)).to(matrix)\n",
    "        new_size[:H-1, : W-1] = matrix\n",
    "        return new_size\n",
    "\n",
    "    increase_size(torch.arange(1,10).reshape(3,3))\n",
    "\n",
    "    dists = [increase_size(dataset.full_matrix).to(device)]\n",
    "    dists.extend([increase_size(dataset.get_guided_distribution(idx, 1.)).to(device) for idx in range(0, dataset.cond_dim)])\n",
    "    model = get_exact_model(dists, graph.dim - 1)\n",
    "    model = get_preconditioned_model(model,graph).to(device)\n",
    "\n",
    "\n",
    "    return dataset, model, graph\n",
    "\n",
    "def sample(n_samples, guid_w, cond_class, dataset, model):\n",
    "    guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "    sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "    cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "    cond = torch.ones_like(cond) * cond_class \n",
    "    sample = sampling_fn(model,(n_samples,context_len),cond, 50, use_tau_leaping=True, return_traj=False)\n",
    "\n",
    "    return sample\n",
    "\n",
    "def get_histogram(samples):\n",
    "    W, H = dataset.full_matrix.shape\n",
    "    empirical_conditional_hist, _, _ = np.histogram2d(\n",
    "        samples[:, 1],\n",
    "        samples[:, 0],\n",
    "        bins=[W, H],\n",
    "        range=[[0, W], [0, H]],\n",
    "        density=True\n",
    "    )\n",
    "    return empirical_conditional_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'matrix-disjoint'\n",
    "\n",
    "dataset_disjoint, model_disjoint, graph = get_model_and_dataset(dataset_name)\n",
    "context_len = dataset_disjoint.context_len\n",
    "\n",
    "\n",
    "n_samples = 5000\n",
    "guid_w = 3\n",
    "cond_class = 0\n",
    "samples_guid_disjoint = sample(n_samples, guid_w, cond_class, dataset_disjoint, model_disjoint)\n",
    "samples_no_guid_disjoint = sample(n_samples, 1   , cond_class, dataset_disjoint, model_disjoint)\n",
    "\n",
    "hist_guid_disjoint = get_histogram(samples_guid_disjoint.cpu().numpy())\n",
    "hist_no_guid_disjoint = get_histogram(samples_no_guid_disjoint.cpu().numpy())\n",
    "\n",
    "original_conditional_dist_disjoint = dataset_disjoint.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist_disjoint = dataset_disjoint.get_guided_distribution(cond_class, guid_w)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig,axes = create_2d_histogram_figure(samples_guid_disjoint.cpu().numpy(), tilted_dist_disjoint, original_conditional_dist_disjoint, guid_w=guid_w)\n",
    "# for ax in axes:\n",
    "#     ax.set_xlim(0, 17)\n",
    "#     ax.set_ylim(0, 17)\n",
    "\n",
    "# fig.savefig('2d-disjoint.pdf')\n",
    "\n",
    "ratio_empirical = torch.from_numpy(np.where(hist_no_guid_disjoint > 0.001, hist_guid_disjoint/hist_no_guid_disjoint, 0.))\n",
    "# print(ratio)\n",
    "ratio_true = torch.from_numpy(np.where(original_conditional_dist_disjoint > 0.001, tilted_dist_disjoint/original_conditional_dist_disjoint, 0.))\n",
    "\n",
    "fig, ax= create_single_2d_histogram_figure(ratio_empirical, title='probability ratios')\n",
    "ax.set_xlim(0., 17.)\n",
    "ax.set_ylim(0., 17.)\n",
    "fig.savefig('2d-disjoint-ratios.pdf', bbox_inches='tight', dpi=300)\n",
    "\n",
    "\n",
    "fig, ax= create_single_2d_histogram_figure(ratio_true, title='probability ratios')\n",
    "ax.set_xlim(0., 17.)\n",
    "ax.set_ylim(0., 17.)\n",
    "fig.savefig('2d-disjoint-ratios-true.pdf', bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'matrix-intersection'\n",
    "\n",
    "dataset_int, model_int, graph = get_model_and_dataset(dataset_name)\n",
    "context_len = dataset_int.context_len\n",
    "\n",
    "\n",
    "n_samples = 5000\n",
    "guid_w = 3\n",
    "cond_class = 0\n",
    "samples_guid_int = sample(n_samples, guid_w, cond_class, dataset_int, model_int)\n",
    "samples_no_guid_int = sample(n_samples, 1   , cond_class, dataset_int, model_int)\n",
    "\n",
    "hist_guid_int = get_histogram(samples_guid_int.cpu().numpy())\n",
    "hist_no_guid_int = get_histogram(samples_no_guid_int.cpu().numpy())\n",
    "\n",
    "original_conditional_dist_int = dataset_int.get_guided_distribution(cond_class, 1.)\n",
    "tilted_dist_int = dataset_int.get_guided_distribution(cond_class, guid_w)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig,axes = create_2d_histogram_figure(samples_guid_int.cpu().numpy(), tilted_dist_int, original_conditional_dist_int, guid_w=guid_w)\n",
    "# for ax in axes:\n",
    "#     ax.set_xlim(0, 17)\n",
    "#     ax.set_ylim(0, 17)\n",
    "\n",
    "# fig.savefig('2d-intersection.pdf')\n",
    "\n",
    "ratio_empirical = torch.from_numpy(np.where(hist_no_guid_int > 0.001, hist_guid_int/hist_no_guid_int, 0.))\n",
    "ratio_true = torch.from_numpy(np.where(original_conditional_dist_int > 0.001, tilted_dist_int/original_conditional_dist_int, 0.))\n",
    "\n",
    "fig, ax= create_single_2d_histogram_figure(ratio_empirical, title='probability ratios')\n",
    "ax.set_xlim(0., 17.)\n",
    "ax.set_ylim(0., 17.)\n",
    "fig.savefig('2d-intersection-ratios.pdf', bbox_inches='tight', dpi=300)\n",
    "\n",
    "\n",
    "fig, ax= create_single_2d_histogram_figure(ratio_true, title='probability ratios')\n",
    "ax.set_xlim(0., 17.)\n",
    "ax.set_ylim(0., 17.)\n",
    "fig.savefig('2d-intersection-ratios-true.pdf', bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_annotated_ratio_plot(ratio_tensor, title='Probability Ratios', threshold_low=0.8, threshold_high=1.2):\n",
    "    \"\"\"Create a ratio plot with large text annotations showing the actual values.\"\"\"\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(12, 10))  # Slightly larger figure\n",
    "    \n",
    "    # Create the heatmap\n",
    "    im = ax.imshow(ratio_tensor, cmap='RdYlBu_r', origin='lower', \n",
    "                   vmin=0, vmax=ratio_tensor.max(), aspect='auto')\n",
    "    \n",
    "    # Add text annotations for each cell with larger font\n",
    "    height, width = ratio_tensor.shape\n",
    "    for i in range(height):\n",
    "        for j in range(width):\n",
    "            value = ratio_tensor[i, j].item()\n",
    "            if value > 0:  # Only annotate non-zero values\n",
    "                # Choose text color based on background\n",
    "                if value > threshold_high:\n",
    "                    text_color = 'white'  # White text on red/yellow background\n",
    "                elif value < threshold_low:\n",
    "                    text_color = 'white'  # White text on blue background\n",
    "                else:\n",
    "                    text_color = 'black'  # Black text on neutral background\n",
    "                \n",
    "                # Much larger font size to cover the entire cell\n",
    "                ax.text(j, i, f'{value:.1f}', ha='center', va='center',\n",
    "                       color=text_color, fontsize=16, fontweight='bold')\n",
    "    \n",
    "    ax.set_title(title, fontsize=18, fontweight='bold')\n",
    "    ax.set_xlabel('X', fontsize=16)\n",
    "    ax.set_ylabel('Y', fontsize=16)\n",
    "    \n",
    "    # Add colorbar\n",
    "    cbar = plt.colorbar(im, ax=ax)\n",
    "    cbar.set_label('Ratio (Tilted/Original)', fontsize=14)\n",
    "    cbar.ax.tick_params(labelsize=12)\n",
    "    \n",
    "    # Make tick labels larger\n",
    "    ax.tick_params(axis='both', which='major', labelsize=12)\n",
    "    \n",
    "    return fig, ax\n",
    "\n",
    "fig, ax = create_annotated_ratio_plot(ratio, title='Probability Ratios')\n",
    "ax.set_xlim(-0.5, 19.5)\n",
    "ax.set_ylim(-0.5, 19.5)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_guided = dataset.get_guided_distribution(0, guid_w)\n",
    "true_val = dataset.get_guided_distribution(0, 1)\n",
    "\n",
    "\n",
    "ratio = torch.where(true_val > 0., true_guided/true_val, 0.)\n",
    "\n",
    "# Version 1: Standard large text\n",
    "fig, ax = create_annotated_ratio_plot(ratio, title='Probability Ratios')\n",
    "ax.set_xlim(-0.5, 19.5)\n",
    "ax.set_ylim(-0.5, 19.5)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "discdiff",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
