{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as colors\n",
    "from tqdm import tqdm\n",
    "\n",
    "sys.path.append('../')\n",
    "\n",
    "from utils.model_utils import get_exact_model, get_preconditioned_model\n",
    "from utils.graph_lib import Absorbing\n",
    "from  utils.guidance_schedules import get_guidance_schedule\n",
    "from utils.datasets import get_dataset\n",
    "from utils.samplers import get_sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def get_model_and_dataset(dataset_name):\n",
    "    \"\"\"Build the dataset, its absorbing graph and the exact preconditioned model.\n",
    "\n",
    "    The exact model is built from a list of distributions: the full data tensor\n",
    "    followed by ``dataset.get_guided_distribution(idx, 1.)`` for every\n",
    "    ``idx in range(dataset.cond_dim)``. Each one is padded with one extra zero\n",
    "    entry along every axis and moved to `device`.\n",
    "\n",
    "    Returns:\n",
    "        ``(dataset, model, graph)``\n",
    "    \"\"\"\n",
    "    dataset = get_dataset(dataset_name)\n",
    "    graph = Absorbing(dataset.vocab_size)\n",
    "\n",
    "    def increase_size(matrix):\n",
    "        \"\"\"Return a zero tensor (default dtype/device) one larger along every axis, with `matrix` copied into the leading corner.\"\"\"\n",
    "        padded = torch.zeros(tuple(s + 1 for s in matrix.shape))\n",
    "        padded[tuple(slice(0, s) for s in matrix.shape)] = matrix\n",
    "        return padded\n",
    "\n",
    "    # NOTE(review): the extra slot per axis presumably holds the absorbing (mask)\n",
    "    # state of `graph`; confirm against Absorbing.dim.\n",
    "    dists = [increase_size(dataset.full_tensor).to(device)]\n",
    "    dists.extend(\n",
    "        increase_size(dataset.get_guided_distribution(idx, 1.)).to(device)\n",
    "        for idx in range(dataset.cond_dim)\n",
    "    )\n",
    "    model = get_exact_model(dists, graph.dim - 1)\n",
    "    model = get_preconditioned_model(model, graph).to(device)\n",
    "\n",
    "    return dataset, model, graph\n",
    "\n",
    "def sample(n_samples, guid_w, cond_class, num_steps=50):\n",
    "    \"\"\"Draw `n_samples` sequences of length `context_len`, all conditioned on `cond_class`.\n",
    "\n",
    "    Uses a constant guidance schedule with weight `guid_w` and tau-leaping.\n",
    "    `num_steps` is forwarded as the sampler's fourth positional argument\n",
    "    (presumably the number of sampling steps; the default 50 matches the value\n",
    "    that used to be hard-coded). Reads the notebook globals `graph`, `device`,\n",
    "    `dataset`, `model` and `context_len`.\n",
    "    \"\"\"\n",
    "    guidance_schedule = get_guidance_schedule('constant', guid_w)\n",
    "    sampling_fn = get_sampler(graph, device, guidance_schedule=guidance_schedule)\n",
    "    # generate_cond only provides a tensor with the right shape/dtype on `device`;\n",
    "    # every entry is then replaced by the requested class.\n",
    "    cond = dataset.generate_cond(n_samples).to(device=device)\n",
    "    cond = torch.ones_like(cond) * cond_class\n",
    "    samples = sampling_fn(model, (n_samples, context_len), cond, num_steps,\n",
    "                          use_tau_leaping=True, return_traj=False)\n",
    "    return samples\n",
    "\n",
    "def get_histogram(samples):\n",
    "    \"\"\"Return the density-normalised 2-D histogram of the first two columns of `samples`.\n",
    "\n",
    "    Accepts a NumPy array or a torch tensor on any device. Column 1 is binned\n",
    "    along the first histogram axis (W unit bins over [0, W]) and column 0 along\n",
    "    the second (H unit bins over [0, H]), where ``W, H = dataset.full_matrix.shape``.\n",
    "    \"\"\"\n",
    "    if isinstance(samples, torch.Tensor):\n",
    "        # np.histogram2d cannot consume (CUDA) torch tensors directly.\n",
    "        samples = samples.cpu().numpy()\n",
    "    W, H = dataset.full_matrix.shape\n",
    "    # NOTE(review): column 1 gets full_matrix's first extent and column 0 its\n",
    "    # second; this only matches full_matrix's axis order when W == H. Confirm intent.\n",
    "    empirical_conditional_hist, _, _ = np.histogram2d(\n",
    "        samples[:, 1],\n",
    "        samples[:, 0],\n",
    "        bins=[W, H],\n",
    "        range=[[0, W], [0, H]],\n",
    "        density=True\n",
    "    )\n",
    "    return empirical_conditional_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_5d_histogram(samples, vocab_size=None):\n",
    "    \"\"\"Return ``(hist, edges)`` from ``np.histogramdd`` over integer-valued samples.\n",
    "\n",
    "    `samples` is a 2-D NumPy array or torch tensor of shape ``(n_samples, n_dims)``.\n",
    "    Every column gets the unit bin edges ``0, 1, ..., vocab_size`` and the result\n",
    "    is density-normalised. `vocab_size` defaults to the notebook global\n",
    "    ``dataset.vocab_size``. The number of dimensions is read from the data\n",
    "    rather than assumed to be 5.\n",
    "    \"\"\"\n",
    "    if isinstance(samples, torch.Tensor):\n",
    "        samples = samples.cpu().numpy()\n",
    "    if vocab_size is None:\n",
    "        vocab_size = dataset.vocab_size\n",
    "\n",
    "    # NOTE(review): NumPy's last bin is closed on the right, so a value equal to\n",
    "    # `vocab_size` (e.g. a leftover mask token, if the sampler can emit one) is\n",
    "    # counted in bin `vocab_size - 1` instead of being dropped.\n",
    "    bins = [np.arange(vocab_size + 1) for _ in range(samples.shape[1])]\n",
    "    hist, edges = np.histogramdd(samples, bins=bins, density=True)\n",
    "    return hist, edges\n",
    "\n",
    "\n",
    "def plot_5d_marginals(hist1, figsize=(15, 3), dim_pairs=((0, 1), (0, 2), (0, 3)),\n",
    "                      vmin=0.0, vmax=0.13):\n",
    "    \"\"\"Plot annotated 2-D marginals of the joint histogram `hist1`, one panel per pair.\n",
    "\n",
    "    For each ``(dim1, dim2)`` in `dim_pairs`, all other axes of `hist1` are summed\n",
    "    out; dim1 is drawn on the x-axis and dim2 on the y-axis. Cells with mass\n",
    "    below 1e-5 are left unlabelled. All panels share the colour scale\n",
    "    ``[vmin, vmax]`` so they can be compared directly.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, axes)`` with `axes` a flat array of the panels.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, len(dim_pairs), figsize=figsize, squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    im = None\n",
    "    for ax, (dim1, dim2) in zip(axes, dim_pairs):\n",
    "        other_dims = tuple(j for j in range(hist1.ndim) if j not in (dim1, dim2))\n",
    "        # np.sum keeps the remaining axes in ascending order, i.e. indexed\n",
    "        # [min(dim1, dim2), max(dim1, dim2)]; transpose so it is indexed [dim1, dim2].\n",
    "        marginal = np.sum(hist1, axis=other_dims)\n",
    "        if dim1 > dim2:\n",
    "            marginal = marginal.T\n",
    "\n",
    "        # imshow indexes [row, col] = [y, x], hence the transpose.\n",
    "        im = ax.imshow(marginal.T, origin='lower', aspect='auto', cmap='viridis',\n",
    "                       vmin=vmin, vmax=vmax)\n",
    "        for y in range(marginal.shape[1]):\n",
    "            for x in range(marginal.shape[0]):\n",
    "                if marginal[x, y] < 0.00001:\n",
    "                    continue\n",
    "                ax.text(x, y, f'{marginal[x, y]:.3f}',\n",
    "                        ha='center', va='center', color='white', fontsize=8)\n",
    "\n",
    "        ax.set_xlabel(f'Dimension {dim1}')\n",
    "        ax.set_ylabel(f'Dimension {dim2}')\n",
    "\n",
    "    # Adjust this figure explicitly (not pyplot's current one) before adding the\n",
    "    # colorbar; calling tight_layout() afterwards would undo the manual spacing.\n",
    "    fig.subplots_adjust(wspace=0.4, right=0.85)\n",
    "    fig.colorbar(im, ax=axes, shrink=0.6, label='Density', pad=0.05)\n",
    "    return fig, axes\n",
    "\n",
    "def plot_ratios(hist1, hist2, figsize=(15, 3), dim_pairs=((0, 1), (0, 2), (0, 3)),\n",
    "                vmin=0.0, vmax=1.3, min_denominator=0.0001):\n",
    "    \"\"\"Plot annotated ratios of 2-D marginals, ``marginal(hist1) / marginal(hist2)``.\n",
    "\n",
    "    Both joint histograms are reduced to the ``(dim1, dim2)`` marginal for each\n",
    "    pair in `dim_pairs` (dim1 on the x-axis, dim2 on the y-axis). Where the\n",
    "    `hist2` marginal is not above `min_denominator` the ratio is set to 0 and\n",
    "    left unlabelled. All panels share the colour scale ``[vmin, vmax]``.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, axes)`` with `axes` a flat array of the panels.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, len(dim_pairs), figsize=figsize, squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    im = None\n",
    "    for ax, (dim1, dim2) in zip(axes, dim_pairs):\n",
    "        other_dims = tuple(j for j in range(hist1.ndim) if j not in (dim1, dim2))\n",
    "        marginal_1 = np.sum(hist1, axis=other_dims)\n",
    "        marginal_2 = np.sum(hist2, axis=other_dims)\n",
    "        # Summed marginals are indexed [min dim, max dim]; make them [dim1, dim2].\n",
    "        if dim1 > dim2:\n",
    "            marginal_1 = marginal_1.T\n",
    "            marginal_2 = marginal_2.T\n",
    "\n",
    "        # Divide only where the denominator is meaningful, so empty cells do not\n",
    "        # raise divide-by-zero / invalid-value RuntimeWarnings.\n",
    "        ratio = np.divide(marginal_1, marginal_2, out=np.zeros(marginal_1.shape),\n",
    "                          where=marginal_2 > min_denominator)\n",
    "\n",
    "        # imshow indexes [row, col] = [y, x], hence the transpose.\n",
    "        im = ax.imshow(ratio.T, origin='lower', aspect='auto', cmap='coolwarm',\n",
    "                       vmin=vmin, vmax=vmax)\n",
    "        for y in range(ratio.shape[1]):\n",
    "            for x in range(ratio.shape[0]):\n",
    "                if ratio[x, y] == 0.0:\n",
    "                    continue\n",
    "                ax.text(x, y, f'{ratio[x, y]:.3f}',\n",
    "                        ha='center', va='center', color='white', fontsize=8)\n",
    "\n",
    "        ax.set_xlabel(f'Dimension {dim1}')\n",
    "        ax.set_ylabel(f'Dimension {dim2}')\n",
    "\n",
    "    # Adjust this figure explicitly (not pyplot's current one) before adding the\n",
    "    # colorbar; calling tight_layout() afterwards would undo the manual spacing.\n",
    "    fig.subplots_adjust(wspace=0.4, right=0.85)\n",
    "    fig.colorbar(im, ax=axes, shrink=0.6, label='Probability ratio', pad=0.05)\n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'cubes'\n",
    "\n",
    "dataset, model, graph = get_model_and_dataset(dataset_name)\n",
    "\n",
    "# `sample()` reads this global to size the generated sequences.\n",
    "context_len = dataset.context_len\n",
    "\n",
    "# Marginals of the full data distribution.\n",
    "hist = dataset.full_tensor\n",
    "fig, axs = plot_5d_marginals(hist.cpu().numpy())\n",
    "fig.suptitle('Marginals for full data distribution', fontsize=16, y=0.98, x=.45)\n",
    "fig.savefig('5d-full-density.pdf', bbox_inches='tight', dpi=300)\n",
    "\n",
    "# Per-class marginals from get_guided_distribution(c, 1).\n",
    "for c in [0, 1]:\n",
    "    hist = dataset.get_guided_distribution(c, 1)\n",
    "    fig, axs = plot_5d_marginals(hist.cpu().numpy())\n",
    "    fig.suptitle(f'Marginals for class {c}', fontsize=16, y=0.98, x=.45)\n",
    "    fig.savefig(f'5d-class-{c}.pdf', bbox_inches='tight', dpi=300)\n",
    "\n",
    "# Ratios of guided-sample marginals to the w = 1 class marginals. The sample\n",
    "# files are written by the sampling cells further down, so weights whose file\n",
    "# does not exist yet are skipped instead of aborting a fresh Run All.\n",
    "for c in [1]:\n",
    "    hist_no_guid = dataset.get_guided_distribution(c, 1.).cpu().numpy()\n",
    "    for w in [2, 3, 5]:\n",
    "        samples_path = f'guided-samples-{w:.1f}.pt'\n",
    "        if not os.path.exists(samples_path):\n",
    "            print(f'Skipping w = {w}: {samples_path} not found')\n",
    "            continue\n",
    "        samples = torch.load(samples_path)\n",
    "        print(samples.shape)\n",
    "        hist, _ = create_5d_histogram(samples)\n",
    "        fig, axs = plot_ratios(hist, hist_no_guid)\n",
    "        fig.suptitle(f'Probability ratios for class {c}, w = {w-1}', fontsize=16, y=0.98, x=.45)\n",
    "        fig.savefig(f'5d-ratios-class-{c}-w-{w-1}.pdf', bbox_inches='tight', dpi=300)\n",
    "print(hist.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_printoptions(sci_mode=False, precision=6)\n",
    "\n",
    "# Seed torch's RNGs (all devices) so re-running the sampling below draws the\n",
    "# same random numbers.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "guid_w = 2.\n",
    "n_samples = 5000\n",
    "cond_class = 1\n",
    "batches = 10\n",
    "# Both sets are conditioned on `cond_class`; the `uncond` set only differs by\n",
    "# using guidance weight 1 instead of `guid_w`.\n",
    "all_guided = []\n",
    "all_uncond = []\n",
    "for _ in tqdm(range(batches)):\n",
    "    guided_samples = sample(n_samples, guid_w, cond_class)\n",
    "    uncond_samples = sample(n_samples, 1., cond_class)\n",
    "    all_guided.append(guided_samples)\n",
    "    all_uncond.append(uncond_samples)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate every batch of both sets so the two histograms (and saved files)\n",
    "# are built from the same number of samples.\n",
    "guided_samples = torch.cat(all_guided, dim=0)\n",
    "uncond_samples = torch.cat(all_uncond, dim=0)\n",
    "guided_hist, edges = create_5d_histogram(guided_samples)\n",
    "uncond_hist, edges = create_5d_histogram(uncond_samples)\n",
    "torch.save(guided_samples, f'guided-samples-{guid_w}.pt')\n",
    "torch.save(uncond_samples, f'cond-samples-{guid_w}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute both histograms and re-save the current sample tensors.\n",
    "guided_hist = create_5d_histogram(guided_samples)[0]\n",
    "uncond_hist, edges = create_5d_histogram(uncond_samples)\n",
    "torch.save(guided_samples, f'guided-samples-{guid_w}.pt')\n",
    "torch.save(uncond_samples, f'cond-samples-{guid_w}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Marginal ratios of guided samples to samples drawn with guidance weight 1.\n",
    "fig, axes = plot_ratios(hist1=guided_hist, hist2=uncond_hist)\n",
    "fig.savefig('5d-ratios.pdf', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_printoptions(sci_mode=False, precision=5)\n",
    "\n",
    "# Group the non-zero cells of the guided histogram by how many of their\n",
    "# coordinates equal TARGET_VALUE: probs_arr[k] holds the densities and\n",
    "# aaa_arr[k] the matching index tuples, both in lexicographic index order.\n",
    "N_VALUES = 5  # values enumerated per coordinate (0 .. N_VALUES - 1)\n",
    "N_DIMS = 5\n",
    "TARGET_VALUE = 2\n",
    "\n",
    "probs_arr = {i: [] for i in range(N_DIMS + 1)}\n",
    "aaa_arr = {i: [] for i in range(N_DIMS + 1)}\n",
    "\n",
    "# np.ndindex walks the index grid in C order (last axis fastest).\n",
    "for idx in np.ndindex((N_VALUES,) * N_DIMS):\n",
    "    dens = guided_hist[idx]\n",
    "    if dens == 0.0:\n",
    "        continue\n",
    "    num_twos = idx.count(TARGET_VALUE)\n",
    "    probs_arr[num_twos].append(dens)\n",
    "    aaa_arr[num_twos].append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.set_printoptions(suppress=True)\n",
    "plt.hist(probs_arr[5], density=True)\n",
    "\n",
    "tot = 0\n",
    "dens = 0\n",
    "for i, values in probs_arr.items():\n",
    "    np_arr = np.array(values)\n",
    "    n = len(np_arr)\n",
    "    tot += n\n",
    "    # An empty bucket has no mean/std (reported as nan), but it must not turn the\n",
    "    # running total into nan. The total uses the exact sum, not the rounded mean.\n",
    "    mean = np_arr.mean().round(9) if n else float('nan')\n",
    "    std = np_arr.std() if n else float('nan')\n",
    "    dens += np_arr.sum()\n",
    "    print(f'Num twos - {i} - Elements - {n} - Mean - {mean} - Std - {std}')\n",
    "\n",
    "print(f'Total - {5**5} - Non-zero {tot} - Zero {5**5-tot} - Dens - {dens}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "discdiff",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
