{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb750f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import pymbar\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c7137f7",
   "metadata": {},
   "source": [
    "## Trajectories and Scores classes\n",
    "\n",
    "We define two classes to store our generated data in:\n",
    "* Trajectories: stores the generated trajectories and annealing schedules for all of the trajectories. \n",
    "* Scores: stores samples grouped by bias value, where the order of the samples within each bias does not matter, such as those generated through unbiased sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9141fa63",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_LOAD_DIR_MAIN = os.path.join(\"data\", \"paper\")\n",
    "DATA_LOAD_DIR_EXAMPLE = os.path.join(\"data\", \"example\")\n",
    "\n",
    "USE_EXAMPLE_DATA = True\n",
    "obs_name = \"ARI\"\n",
    "\n",
    "# (trajectories file, unbiased samples file) inside DATA_LOAD_DIR_MAIN, per observable.\n",
    "_PAPER_DATA_FILES = {\n",
    "    \"ARI\": (\"ari_capped_trajectories.json\", \"ari_unbiased_samples.json\"),\n",
    "    \"LOGP\": (\"logp_trajectories.json\", \"logp_unbiased_samples.json\"),\n",
    "}\n",
    "\n",
    "if not USE_EXAMPLE_DATA:\n",
    "    DATA_LOAD_DIR = DATA_LOAD_DIR_MAIN\n",
    "    if obs_name not in _PAPER_DATA_FILES:\n",
    "        raise ValueError(f\"Unknown observable name: {obs_name}. Available: ARI, LOGP.\")\n",
    "    trajs_file, unbiased_file = _PAPER_DATA_FILES[obs_name]\n",
    "else:\n",
    "    DATA_LOAD_DIR = DATA_LOAD_DIR_EXAMPLE\n",
    "    if obs_name != \"ARI\":\n",
    "        raise ValueError(\"Example data only available for ARI observable.\")\n",
    "    trajs_file, unbiased_file = \"ari_trajectories.json\", \"ari_unbiased_samples.json\"\n",
    "\n",
    "# Despite the names, these are paths to JSON files, not directories.\n",
    "trajs_dir = os.path.join(DATA_LOAD_DIR, trajs_file)\n",
    "unbiased_dir = os.path.join(DATA_LOAD_DIR, unbiased_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e708b50e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trajectories():\n",
    "    def __init__(self, trajectories: list[list[list[float]]],\n",
    "                 biases: list[list[float]],\n",
    "                 steps_per_bias: int):\n",
    "        \"\"\"Initializes trajectory object.\n",
    "\n",
    "        Args:\n",
    "            trajectories (list[list[list[float]]]): For each annealing schedule, the list of\n",
    "                trajectories run with it. Each trajectory is the concatenation of\n",
    "                `steps_per_bias` samples for every bias of the schedule, in schedule order.\n",
    "            biases (list[list[float]]): For each annealing schedule, its ordered list of biases.\n",
    "            steps_per_bias (int): The number of samples taken at each bias, shared by every\n",
    "                bias of every trajectory.\n",
    "        \"\"\"\n",
    "        self.trajectories = trajectories\n",
    "        self.biases = biases\n",
    "        self.steps_per_bias = steps_per_bias\n",
    "\n",
    "    def get_all_biases(self) -> list[float]:\n",
    "        \"\"\"Returns the sorted list of distinct biases used by any annealing schedule.\"\"\"\n",
    "        return sorted({bias for biases_per_traj in self.biases for bias in biases_per_traj})\n",
    "\n",
    "    def get_all_trajs_per_bias(self) -> tuple[list[float], list[np.ndarray]]:\n",
    "        \"\"\"Groups the trajectory segments by the bias they were sampled at.\n",
    "\n",
    "        Returns:\n",
    "            tuple[list[float], list[np.ndarray]]: The sorted list of distinct biases, and a\n",
    "            corresponding list of b x steps_per_bias numpy arrays, where b is the total number\n",
    "            of annealing steps (over all schedules and trajectories) that use that bias.\n",
    "            A list is returned rather than one stacked array so that biases visited by\n",
    "            different numbers of annealing steps do not produce a ragged array.\n",
    "        \"\"\"\n",
    "        # bias -> list of segments (one per trajectory per annealing step at that bias).\n",
    "        segments_per_bias = {}\n",
    "        for biases_for_trajs, trajs_per_biases in zip(self.biases, self.trajectories):\n",
    "            for step, bias in enumerate(biases_for_trajs):\n",
    "                start = step * self.steps_per_bias\n",
    "                stop = start + self.steps_per_bias\n",
    "                segments_per_bias.setdefault(bias, []).extend(traj[start:stop] for traj in trajs_per_biases)\n",
    "\n",
    "        all_biases = sorted(segments_per_bias)\n",
    "        return all_biases, [np.array(segments_per_bias[bias]) for bias in all_biases]\n",
    "\n",
    "    def get_scores(self) -> \"Scores\":\n",
    "        \"\"\"Returns a Scores object mapping each bias to all of its samples, flattened.\"\"\"\n",
    "        all_biases, all_trajs_per_bias = self.get_all_trajs_per_bias()\n",
    "        return Scores({bias: trajs_per_bias.flatten().tolist()\n",
    "                       for bias, trajs_per_bias in zip(all_biases, all_trajs_per_bias)})\n",
    "\n",
    "class Scores():\n",
    "    def __init__(self, scores_dict: dict):\n",
    "        \"\"\"Initializes the Scores object.\n",
    "\n",
    "        Args:\n",
    "            scores_dict (dict): A dictionary where keys are floats, denoting the bias, and the\n",
    "                values are lists of scores for that bias value. Stored by reference, not copied.\n",
    "        \"\"\"\n",
    "        self.scores_dict = scores_dict\n",
    "\n",
    "    def __add__(self, other: \"Scores\") -> \"Scores\":\n",
    "        \"\"\"Returns a new Scores object pooling the scores of both operands for each bias.\n",
    "\n",
    "        Neither operand is modified: every per-bias list is copied, so later changes to the\n",
    "        result cannot leak back into `self` or `other`.\n",
    "        \"\"\"\n",
    "        new_scores_dict = {key: list(value) for key, value in self.scores_dict.items()}\n",
    "        for key, value in other.scores_dict.items():\n",
    "            new_scores_dict.setdefault(key, []).extend(value)\n",
    "        return Scores(new_scores_dict)\n",
    "\n",
    "    def get_all_biases(self) -> list[float]:\n",
    "        \"\"\"Returns the biases present in this object, sorted in increasing order.\"\"\"\n",
    "        return sorted(self.scores_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737803ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved trajectories. The first three entries of the JSON list are\n",
    "# [trajectories, biases, steps_per_bias], in the Trajectories constructor's order.\n",
    "with open(trajs_dir, \"r\") as f:\n",
    "    trajectories_saved = json.load(f)\n",
    "trajectories = Trajectories(\n",
    "    trajectories=trajectories_saved[0],\n",
    "    biases=trajectories_saved[1],\n",
    "    steps_per_bias=trajectories_saved[2],\n",
    ")\n",
    "\n",
    "# The unbiased samples form the lambda = 0 state.\n",
    "with open(unbiased_dir, \"r\") as f:\n",
    "    unbiased_scores_list = json.load(f)\n",
    "unbiased_scores = Scores({0: unbiased_scores_list})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b659678b",
   "metadata": {},
   "source": [
    "## Removing Unconverged Data\n",
    "We use two methods to remove unconverged data. The first is burn-in, which discards the start of each chain, where it is still approaching the target distribution; the second eliminates any bias values whose chains are judged unconverged by the Gelman-Rubin statistic.\n",
    "\n",
    "### Burnin\n",
    "This allows us to discard the unconverged part of each annealing step, for each trajectory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5488e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_burnin(trajectories: Trajectories, burnin: float=0.1) -> Trajectories:\n",
    "    \"\"\"Discards the first `burnin` fraction of every annealing step of every trajectory.\n",
    "\n",
    "    Args:\n",
    "        trajectories (Trajectories): The trajectories to trim; not modified.\n",
    "        burnin (float): Fraction of each annealing step's samples to drop, truncated to a\n",
    "            whole number of samples.\n",
    "\n",
    "    Returns:\n",
    "        Trajectories: The trimmed trajectories, with the same biases and steps_per_bias\n",
    "        reduced by the number of dropped samples.\n",
    "    \"\"\"\n",
    "    steps = trajectories.steps_per_bias\n",
    "    dropped = int(steps * burnin)\n",
    "\n",
    "    def _trim(traj, num_annealing_steps):\n",
    "        # Keep samples [dropped, steps) of each consecutive block of `steps` samples.\n",
    "        kept = []\n",
    "        for step in range(num_annealing_steps):\n",
    "            block_start = step * steps\n",
    "            kept += traj[block_start + dropped : block_start + steps]\n",
    "        return kept\n",
    "\n",
    "    trimmed = [\n",
    "        [_trim(traj, len(schedule_biases)) for traj in schedule_trajs]\n",
    "        for schedule_biases, schedule_trajs in zip(trajectories.biases, trajectories.trajectories)\n",
    "    ]\n",
    "    return Trajectories(trimmed, trajectories.biases, steps - dropped)\n",
    "\n",
    "trajectories_burnin = apply_burnin(trajectories, 0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f539748",
   "metadata": {},
   "source": [
    "### Gelman Rubin\n",
    "We compute the Gelman-Rubin statistic as a measure of convergence. This gives us a unique value for each bias value, determining how well converged the chains from this bias value are.\n",
    "\n",
    "The function returns a scores object with only the converged bias values in it, as well as a list of all GR values, ordered from lowest to highest bias. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "855d9587",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gelman_rubin(trajectories: Trajectories, cutoff: float=1.1) -> tuple[Scores, list]:\n",
    "    \"\"\"Computes the Gelman-Rubin statistic for every bias and keeps the converged biases.\n",
    "\n",
    "    Every annealing step at a given bias is one chain of length L. For j such chains,\n",
    "    R_hat = sqrt(((L - 1) / L * W + B / L) / W), where W is the mean within-chain variance\n",
    "    and B = L / (j - 1) * sum((chain_mean - mean_of_means) ** 2) the between-chain variance.\n",
    "\n",
    "    Args:\n",
    "        trajectories (Trajectories): The trajectories to assess (typically after burn-in).\n",
    "        cutoff (float): Biases with R_hat strictly below this value are accepted.\n",
    "\n",
    "    Returns:\n",
    "        tuple[Scores, list]: A Scores object with the flattened samples of the accepted\n",
    "        biases, and the R_hat of every bias, ordered from lowest to highest bias. A bias\n",
    "        with fewer than two chains cannot be assessed; it gets R_hat = inf and is rejected.\n",
    "    \"\"\"\n",
    "    biases, all_trajs_per_bias = trajectories.get_all_trajs_per_bias()\n",
    "\n",
    "    grs = []\n",
    "    scores_dict = {}\n",
    "    for bias, trajs_per_bias in zip(biases, all_trajs_per_bias):\n",
    "        j, L = trajs_per_bias.shape\n",
    "\n",
    "        if j < 2:\n",
    "            # The between-chain variance divides by j - 1, so one chain cannot be assessed.\n",
    "            grs.append(np.inf)\n",
    "            continue\n",
    "\n",
    "        means = np.mean(trajs_per_bias, axis=1)\n",
    "        mean_of_means = np.mean(means)\n",
    "        B = L / (j - 1) * np.sum((means - mean_of_means) ** 2)\n",
    "        W = 1 / j * np.sum(np.var(trajs_per_bias, axis=1, ddof=1))\n",
    "        var_hat = (L - 1) / L * W + B / L\n",
    "        R_hat = np.sqrt(var_hat / W)\n",
    "        grs.append(R_hat)\n",
    "\n",
    "        if R_hat < cutoff:\n",
    "            scores_dict[bias] = trajs_per_bias.flatten().tolist()\n",
    "\n",
    "    return Scores(scores_dict), grs\n",
    "\n",
    "accepted_scores, grs = gelman_rubin(trajectories_burnin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33042778",
   "metadata": {},
   "outputs": [],
   "source": [
    "# R_hat - 1 on a log scale; the dashed line is the default gelman_rubin cutoff (1.1) minus 1.\n",
    "plt.plot(trajectories.get_all_biases(), np.array(grs) - 1, label=\"GR\")\n",
    "plt.axhline(0.1, linestyle=\"--\", linewidth=1, color=\"grey\", label=\"cutoff\")\n",
    "plt.xlabel(r\"$\\lambda$\")\n",
    "plt.ylabel(r\"$\\hat{R} - 1$\")\n",
    "plt.yscale(\"log\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4619fab0",
   "metadata": {},
   "source": [
    "## Including Unbiased Samples\n",
    "\n",
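    "We also need to add some samples from the unbiased distribution ($\\lambda = 0$), which `mbar` requires as its reference state. We take these by subsampling the separately generated unbiased samples (not the biased trajectories). \n",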
    "\n",
    "We take `steps_per_bias * N / 2` unbiased samples, where $N$ is the number of trajectories run for each annealing schedule. This keeps the number of generated tokens fixed for each bias (including the unbiased case, $\\lambda = 0$), since unbiased completions require, on average, twice as many token generations as biased ones. \n",
    "\n",
    "We will also initialize the list of samples for the unbiased histogram too. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf9853e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def subsample_scores(scores, num_samples: int, random=False):\n",
    "    \"\"\"Returns a new Scores object holding `num_samples` samples per bias.\n",
    "\n",
    "    Args:\n",
    "        scores (Scores): The scores to subsample; not modified.\n",
    "        num_samples (int): Number of samples to take for each bias.\n",
    "        random (bool): If truthy, draw the samples uniformly at random *with replacement*\n",
    "            (the np.random.choice default), as used for bootstrapping. Otherwise take the\n",
    "            first `num_samples` samples of each bias (fewer if a bias has fewer samples).\n",
    "\n",
    "    Returns:\n",
    "        Scores: The subsampled scores.\n",
    "    \"\"\"\n",
    "    if random:\n",
    "        subsample_dict = {bias: np.random.choice(scores.scores_dict[bias], num_samples).tolist()\n",
    "                          for bias in scores.get_all_biases()}\n",
    "    else:\n",
    "        subsample_dict = {bias: scores.scores_dict[bias][:num_samples]\n",
    "                          for bias in scores.get_all_biases()}\n",
    "    return Scores(subsample_dict)\n",
    "\n",
    "# We want to use trajectories here, not trajectories_burnin, since we want to match the number of generated samples. \n",
    "\n",
    "# NOTE(review): both counts are read from the first annealing schedule only, which assumes every\n",
    "# schedule has the same number of trajectories and biases - confirm.\n",
    "num_trajs_per_schedule = len(trajectories.trajectories[0])\n",
    "num_biases_per_anneal = len(trajectories.biases[0])\n",
    "\n",
    "# Unbiased samples added as the lambda = 0 state: steps_per_bias * N / 2 (see the markdown above).\n",
    "unbiased_samples_for_biased = int(trajectories.steps_per_bias * num_trajs_per_schedule / 2)\n",
    "# `+=` rebinds accepted_scores to include these samples, so re-running this cell on its own adds\n",
    "# them a second time; re-run the Gelman-Rubin cell first.\n",
    "accepted_scores += subsample_scores(unbiased_scores, num_samples = unbiased_samples_for_biased)\n",
    "\n",
    "# Number of samples to use for the unbiased histogram. \n",
    "# NOTE(review): this slices the start of the same list that was subsampled above, so the\n",
    "# histogram samples include the unbiased samples added to accepted_scores.\n",
    "num_samples_for_unbiased = trajectories.steps_per_bias * num_trajs_per_schedule * num_biases_per_anneal // 2 + unbiased_samples_for_biased\n",
    "unbiased_samples_for_histogram = unbiased_scores_list[:num_samples_for_unbiased]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f779cf1",
   "metadata": {},
   "source": [
    "## Reweighting the Samples\n",
    "\n",
    "### MBAR\n",
    "We use the pymbar package to compute the Multistate Bennett Acceptance Ratio (MBAR) estimate of the log partition function at each bias, relative to $\\lambda = 0$. The same package can also compute the overlap between the sampled distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346a8255",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mbar(scores, return_overlap = False) -> \"list[float] | tuple[list[float], object]\":\n",
    "    \"\"\"Estimates log Z(lambda) - log Z(0) for every bias in `scores` with pymbar's MBAR.\n",
    "\n",
    "    The reduced potential of sample n in state k is u_kn = lambda_k * score_n, and the result\n",
    "    is minus the row of MBAR free-energy differences taken from the lambda = 0 state.\n",
    "\n",
    "    Args:\n",
    "        scores (Scores): Samples per bias; must contain the bias 0.\n",
    "        return_overlap (bool): If True, also return MBAR.compute_overlap().\n",
    "\n",
    "    Returns:\n",
    "        list[float]: One value per bias, in the order of scores.get_all_biases(). If\n",
    "        return_overlap is True, the tuple (values, overlap) instead.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If 0 is not one of the biases.\n",
    "    \"\"\"\n",
    "    biases = scores.get_all_biases()\n",
    "    if 0 not in biases:\n",
    "        raise ValueError(\"0 must be in the scores dictionary to perform mbar.\")\n",
    "    zero_index = biases.index(0)\n",
    "\n",
    "    # Concatenate all samples once (np.append in a loop would copy on every bias).\n",
    "    samples_per_state = [np.asarray(scores.scores_dict[bias], dtype=float) for bias in biases]\n",
    "    all_scores = np.concatenate(samples_per_state)\n",
    "    Ns = np.array([len(samples) for samples in samples_per_state], dtype=float)\n",
    "\n",
    "    u_kn = np.outer(biases, all_scores)\n",
    "    estimator = pymbar.MBAR(u_kn, Ns)\n",
    "    mbar_results = estimator.compute_free_energy_differences()\n",
    "\n",
    "    res = (- mbar_results[\"Delta_f\"][zero_index]).tolist()\n",
    "    if return_overlap:\n",
    "        return res, estimator.compute_overlap()\n",
    "    return res\n",
    "    \n",
    "# log Z(lambda) - log Z(0) for each accepted bias, in the order of accepted_scores.get_all_biases().\n",
    "log_Zs = mbar(accepted_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c074879d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(accepted_scores.get_all_biases(), log_Zs)\n",
    "plt.xlabel(r\"$\\lambda$\")\n",
    "plt.ylabel(r\"$\\log Z(\\lambda)$\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a5e0148",
   "metadata": {},
   "source": [
    "### Calculating the Importance Sampling Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e9dba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_weights(scores: Scores, log_Zs):\n",
    "    \"\"\"Computes an importance-sampling weight for every sample.\n",
    "\n",
    "    A sample s taken at the i-th smallest bias lambda_i gets weight exp(log_Zs[i] + lambda_i * s).\n",
    "\n",
    "    Args:\n",
    "        scores (Scores): Samples per bias.\n",
    "        log_Zs: One log Z value per bias, indexed like scores.get_all_biases() (as returned by mbar).\n",
    "\n",
    "    Returns:\n",
    "        tuple[list, list]: All samples concatenated in increasing-bias order, and their weights.\n",
    "    \"\"\"\n",
    "    all_samples, all_weights = [], []\n",
    "    for idx, bias in enumerate(scores.get_all_biases()):\n",
    "        samples = scores.scores_dict[bias]\n",
    "        all_samples.extend(samples)\n",
    "        all_weights.extend(np.exp(log_Zs[idx] + bias * np.array(samples)).tolist())\n",
    "    return all_samples, all_weights\n",
    "\n",
    "scores_list, weights_list = get_weights(accepted_scores, log_Zs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c39f91c0",
   "metadata": {},
   "source": [
    "## Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49a58ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram range and number of bin *edges* (np.linspace below yields num_bins - 1 bins),\n",
    "# chosen from the data selected in the configuration cell.\n",
    "if USE_EXAMPLE_DATA:\n",
    "    # Minimal example\n",
    "    lower, upper, num_bins = (0, 8, 10)\n",
    "elif obs_name == \"ARI\":\n",
    "    # Settings used in the paper for ARI\n",
    "    lower, upper, num_bins = (-8, 15, 80)\n",
    "else:\n",
    "    # Settings used in the paper for log-probabilities (LOGP)\n",
    "    lower, upper, num_bins = (-600, 0, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369547c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.linspace gives num_bins bin edges, i.e. num_bins - 1 histogram bins.\n",
    "bins = np.linspace(lower, upper, num_bins)\n",
    "bin_widths = np.diff(bins)\n",
    "bin_centers = bins[:-1] + bin_widths / 2\n",
    "\n",
    "# Reweighted estimate from the biased runs. With density=True the normalization only counts\n",
    "# the (weighted) samples that fall inside [lower, upper].\n",
    "histogram_heights = np.histogram(scores_list, bins=bins, weights=weights_list, density=True)[0]\n",
    "\n",
    "# Direct estimate from plain unbiased sampling, for comparison.\n",
    "unbiased_histogram_heights = np.histogram(unbiased_samples_for_histogram, bins=bins, density=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e87f72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reweighted (biased) estimate drawn on top of the direct unbiased estimate.\n",
    "curves = ((histogram_heights, \"Biased\", 1), (unbiased_histogram_heights, \"Unbiased\", 0))\n",
    "for heights, curve_label, order in curves:\n",
    "    plt.step(bin_centers, heights, where=\"mid\", label=curve_label, zorder=order)\n",
    "plt.xlabel(obs_name)\n",
    "plt.ylabel(\"Density\")\n",
    "plt.yscale(\"log\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54aee7cd",
   "metadata": {},
   "source": [
    "## Error Estimates\n",
    "\n",
    "### Bootstrapping\n",
    "For the biased histogram, we compute a confidence interval using bootstrapping."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "398510bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_trajectories_bootstrap(trajectories: Trajectories) -> Trajectories:\n",
    "    \"\"\"Generates a new Trajectories object by bootstrapping over the trajectories.\n",
    "\n",
    "    Within each annealing schedule, the trajectories are resampled uniformly with replacement\n",
    "    (keeping their number); biases and steps_per_bias are reused unchanged.\n",
    "\n",
    "    Returns:\n",
    "        Trajectories: A new Trajectories object with bootstrapped data.\n",
    "    \"\"\"\n",
    "    resampled = []\n",
    "    # zip (rather than iterating trajectories alone) keeps only as many schedules as bias lists.\n",
    "    for _, schedule_trajs in zip(trajectories.biases, trajectories.trajectories):\n",
    "        picks = np.random.randint(0, len(schedule_trajs), len(schedule_trajs))\n",
    "        resampled.append([schedule_trajs[k] for k in picks])\n",
    "    return Trajectories(resampled, trajectories.biases, trajectories.steps_per_bias)\n",
    "\n",
    "def bootstrap_hist_heights(trajectories: Trajectories, \n",
    "                           unbiased_scores: Scores, \n",
    "                           bins: np.array, \n",
    "                           num_bootstraps: int=100,\n",
    "                           num_unbiased: \"int | None\" = None):\n",
    "    \"\"\"Computes the reweighted histogram heights of `num_bootstraps` bootstrap replicates.\n",
    "\n",
    "    Burn-in (default fraction) is applied once. Each replicate then resamples the trajectories\n",
    "    of every annealing schedule, keeps the biases accepted by gelman_rubin, adds `num_unbiased`\n",
    "    unbiased samples drawn with replacement, and reweights them with mbar and get_weights.\n",
    "\n",
    "    Args:\n",
    "        trajectories (Trajectories): Raw trajectories (burn-in is applied here).\n",
    "        unbiased_scores (Scores): Unbiased samples; must contain the bias 0.\n",
    "        bins (np.array): Histogram bin edges.\n",
    "        num_bootstraps (int): Number of bootstrap replicates.\n",
    "        num_unbiased (int | None): Unbiased samples drawn per replicate. Defaults to half of\n",
    "            the unbiased samples provided.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Shape (num_bootstraps, len(bins) - 1); row i is the density histogram of\n",
    "        replicate i.\n",
    "    \"\"\"\n",
    "    trajs_burnin = apply_burnin(trajectories)\n",
    "\n",
    "    if num_unbiased is None:\n",
    "        # NOTE(review): a plain bootstrap would resample all len(...) samples; this default\n",
    "        # halves them, which may widen the interval. Kept for backward compatibility - confirm.\n",
    "        num_unbiased = len(unbiased_scores.scores_dict[0]) // 2\n",
    "\n",
    "    boot_histogram_heights = np.zeros((num_bootstraps, len(bins) - 1))\n",
    "    for i in range(num_bootstraps):\n",
    "        bootstrap_trajs = generate_trajectories_bootstrap(trajs_burnin)\n",
    "        boot_scores, _ = gelman_rubin(bootstrap_trajs)\n",
    "        boot_scores += subsample_scores(unbiased_scores, num_unbiased, random=True)\n",
    "        boot_log_Zs = mbar(boot_scores)\n",
    "        boot_samples, boot_weights = get_weights(boot_scores, boot_log_Zs)\n",
    "\n",
    "        heights, _ = np.histogram(boot_samples, weights=boot_weights, bins=bins, density=True)\n",
    "        boot_histogram_heights[i, :] = heights\n",
    "\n",
    "    return boot_histogram_heights\n",
    "\n",
    "def bootstrap_conf_interval(trajectories: Trajectories, \n",
    "                           unbiased_scores: Scores, \n",
    "                           bins: np.array, \n",
    "                           num_bootstraps: int=100,\n",
    "                           lower_idx: int=2,\n",
    "                           upper_idx: int=97):\n",
    "    \"\"\"Percentile bootstrap confidence interval for the reweighted histogram heights.\n",
    "\n",
    "    The bootstrap heights are sorted independently for each bin, and the values at positions\n",
    "    `lower_idx` and `upper_idx` are the lower and upper bounds (with the defaults, 100\n",
    "    replicates and indices 2 and 97, roughly a 96% interval).\n",
    "\n",
    "    Args:\n",
    "        trajectories (Trajectories): Raw trajectories, passed to bootstrap_hist_heights.\n",
    "        unbiased_scores (Scores): Unbiased samples; must contain the bias 0.\n",
    "        bins (np.array): Histogram bin edges.\n",
    "        num_bootstraps (int): Number of bootstrap replicates.\n",
    "        lower_idx (int): Position of the lower bound among the sorted replicates.\n",
    "        upper_idx (int): Position of the upper bound among the sorted replicates.\n",
    "\n",
    "    Returns:\n",
    "        tuple[np.ndarray, np.ndarray]: Per-bin lower and upper bounds.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the indices are out of range or not strictly increasing. This is\n",
    "            checked before the (slow) bootstrap runs.\n",
    "    \"\"\"\n",
    "    if lower_idx < 0:\n",
    "        # A negative index would silently count from the top of the sorted replicates.\n",
    "        raise ValueError(\"lower_idx must be non-negative.\")\n",
    "    if upper_idx > num_bootstraps - 1: \n",
    "        raise ValueError(\"upper_idx must be less than or equal to num_bootstraps - 1.\")\n",
    "    if lower_idx >= upper_idx: \n",
    "        raise ValueError(\"lower_idx must be strictly less than upper_idx.\")\n",
    "\n",
    "    boot_heights = bootstrap_hist_heights(trajectories, unbiased_scores, bins, num_bootstraps)\n",
    "\n",
    "    sorted_boot_heights = np.sort(boot_heights, axis=0)\n",
    "    return sorted_boot_heights[lower_idx], sorted_boot_heights[upper_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "966539da",
   "metadata": {},
   "source": [
    "WARNING: the bootstrap cell below (after the parameter cell) takes a long time to run. With the minimal-example parameters (`num_bootstraps = 5`), it took 17 minutes on my desktop, so expect several hours for 100 bootstraps. Parallelizing is recommended when using a larger number of bootstraps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fde83d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings used for the bootstrapping estimates in the paper (uncomment to recreate;\n",
    "# expect several hours of runtime):\n",
    "# num_bootstraps = 100\n",
    "# lower_idx = 2\n",
    "# upper_idx = 97\n",
    "\n",
    "# Minimal example:\n",
    "num_bootstraps = 5\n",
    "lower_idx = 1\n",
    "upper_idx = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f102929",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bootstrap over the same unbiased samples the point estimate used (the first\n",
    "# unbiased_samples_for_biased of them). The raw trajectories are passed because\n",
    "# burn-in is applied inside bootstrap_hist_heights.\n",
    "unbiased_scores_boot = subsample_scores(unbiased_scores, unbiased_samples_for_biased)\n",
    "lower_bound, upper_bound = bootstrap_conf_interval(\n",
    "    trajectories,\n",
    "    unbiased_scores_boot,\n",
    "    bins,\n",
    "    num_bootstraps=num_bootstraps,\n",
    "    lower_idx=lower_idx,\n",
    "    upper_idx=upper_idx,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c213e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Approximate coverage of the percentile interval: the share of the sorted bootstrap replicates\n",
    "# between positions lower_idx and upper_idx (about 96% for 100 replicates with indices 2 and 97).\n",
    "coverage = 100 * (upper_idx - lower_idx) / (num_bootstraps - 1)\n",
    "plt.step(bin_centers, histogram_heights, where=\"mid\", label=\"Biased\")\n",
    "plt.fill_between(bin_centers, \n",
    "                 lower_bound, \n",
    "                 upper_bound, \n",
    "                 step=\"mid\", \n",
    "                 label=f\"{coverage:.0f}% conf interval\",\n",
    "                 color=\"C0\",\n",
    "                 alpha=0.5)\n",
    "plt.xlabel(obs_name)\n",
    "plt.ylabel(\"Density\")\n",
    "plt.yscale(\"log\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99f285cf",
   "metadata": {},
   "source": [
    "### Wilson Interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43aa5a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wilson_interval(data, bins, zscore, density=True, cumulative=False):\n",
    "    \"\"\"Per-bin Wilson score interval for the share of `data` falling in each histogram bin.\n",
    "\n",
    "    For a bin holding n_s of the n samples, the interval is\n",
    "    p_tilde +/- z / (n + z^2) * sqrt(n_s * (n - n_s) / n + z^2 / 4),\n",
    "    with p_tilde = (n_s + z^2 / 2) / (n + z^2).\n",
    "\n",
    "    Args:\n",
    "        data: 1-D sequence of samples.\n",
    "        bins: Histogram bin edges (array-like).\n",
    "        zscore (float): Two-sided z-score of the desired confidence level.\n",
    "        density (bool): If truthy, divide the bounds by the bin widths so they are on the\n",
    "            scale of a density histogram. n counts all samples, including any outside\n",
    "            [bins[0], bins[-1]], whereas np.histogram(density=True) normalizes by the in-range\n",
    "            samples only; the two agree when all data lie within the bins.\n",
    "        cumulative (bool): Currently unused; kept for backward compatibility.\n",
    "\n",
    "    Returns:\n",
    "        tuple[np.ndarray, np.ndarray]: Per-bin (lower, upper) bounds.\n",
    "    \"\"\"\n",
    "    bins = np.asarray(bins)\n",
    "    n = len(data)\n",
    "    n_s, _ = np.histogram(data, bins)\n",
    "    n_f = n - n_s\n",
    "\n",
    "    z2 = zscore ** 2\n",
    "    p = (n_s + 0.5 * z2) / (n + z2)\n",
    "    diff = (zscore / (n + z2)) * np.sqrt(((n_s * n_f) / n) + (z2 / 4))\n",
    "\n",
    "    lower, upper = p - diff, p + diff\n",
    "    if density:\n",
    "        widths = bins[1:] - bins[:-1]\n",
    "        return lower / widths, upper / widths\n",
    "    return lower, upper\n",
    "\n",
    "# Two-sided z-score for a 96% confidence level, i.e. the 98th percentile of the standard normal.\n",
    "zscore = 2.0537 # Z score for 0.96 confidence level\n",
    "lower_unbiased, upper_unbiased = wilson_interval(unbiased_samples_for_histogram,\n",
    "                                                 bins, \n",
    "                                                 zscore)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57b301ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.step(bin_centers, unbiased_histogram_heights, where=\"mid\", label=\"Unbiased\")\n",
    "# Shade the interval in the line's colour (C0), matching the bootstrap plot.\n",
    "plt.fill_between(bin_centers, \n",
    "                 lower_unbiased, \n",
    "                 upper_unbiased, \n",
    "                 step=\"mid\", \n",
    "                 label=\"96% Wilson Interval\",\n",
    "                 color=\"C0\",\n",
    "                 alpha=0.5)\n",
    "plt.xlabel(obs_name)\n",
    "plt.yscale(\"log\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.ylim(10**-7, 1)\n",
    "plt.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "completions_database",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
