{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f523eaae3bee4d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-01-14T15:35:27.315352Z",
     "start_time": "2026-01-14T15:35:26.391792Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from paretoset import paretoset\n",
    "from tqdm import tqdm\n",
    "\n",
    "from UQ import uncertainty_quantification, uncertainty_quantification_numpy\n",
    "\n",
    "DATA_DIR = \"benchmark/benchmark_data/TTPFTS_UQ/\"\n",
    "FIGURES_DIR = \"figures/\"\n",
    "timesteps = np.arange(0, 50_001, 500)\n",
    "experiments = np.arange(0, 101)\n",
    "components = np.arange(0, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f02040054d4e66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "quin_comp_post_dfs = {comp: [[pd.read_parquet(DATA_DIR + f\"posterior_logs_mo_maximize_TTPFTS_e{e}_t{t}_component{comp}.parquet\", columns=[\"reagent_name\", \"means\", \"stds\"]) for t in timesteps] for e in experiments] for comp in components}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63634fa56934981f",
   "metadata": {},
   "outputs": [],
   "source": [
    "quin_comp_uncertainties = {comp: np.array([[uncertainty_quantification(quin_comp_post_dfs[comp][i][j]) for j in range(len(timesteps))] for i in range(len(experiments))]) for comp in components}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e49784624c23f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_uncertainty_all_comps(\n",
    "        comp_uncertainties,\n",
    "        timesteps,\n",
    "        set_title=True,\n",
    "        save_png=False,\n",
    "        save_file=None\n",
    "):\n",
    "    # Single figure and axes\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "\n",
    "    handles = []\n",
    "    labels = []\n",
    "\n",
    "    for comp, uncertainties in comp_uncertainties.items():\n",
    "        # uncertainties: shape (n_runs, n_timesteps)\n",
    "        mean_uncertainties = np.mean(uncertainties, axis=0)\n",
    "        std_uncertainties = np.std(uncertainties, axis=0)\n",
    "\n",
    "        line, = ax.plot(timesteps, mean_uncertainties, label=comp)\n",
    "        ci = 1.96 * std_uncertainties / np.sqrt(uncertainties.shape[0])\n",
    "        ci_lower = mean_uncertainties - ci\n",
    "        ci_upper = mean_uncertainties + ci\n",
    "\n",
    "        # Use same color as line but with alpha for CI\n",
    "        ax.fill_between(\n",
    "            timesteps,\n",
    "            ci_lower,\n",
    "            ci_upper,\n",
    "            alpha=0.3,\n",
    "            color=line.get_color()\n",
    "        )\n",
    "\n",
    "        handles.append(line)\n",
    "        labels.append(f\"{comp}\")\n",
    "\n",
    "    if set_title and not save_png:\n",
    "        ax.set_title(\"Uncertainty Across Reaction Components\", fontsize=14)\n",
    "    ax.set_xlim(0, 50_000)\n",
    "    ax.set_xlabel(\"Search Steps\")\n",
    "    ax.set_ylabel(\"Uncertainty (Bhattacharyya Avg)\")\n",
    "    ax.grid(True)\n",
    "\n",
    "    # Place legend outside on the right\n",
    "    # bbox_to_anchor (x, y): x>1 pushes it outside to the right\n",
    "    ax.legend(handles, labels, title=\"Component\", loc=\"center left\", bbox_to_anchor=(1.02, 0.5), borderaxespad=0.)\n",
    "\n",
    "    # Adjust layout to make room for legend\n",
    "    plt.tight_layout()\n",
    "    plt.subplots_adjust(right=0.75)\n",
    "\n",
    "    if save_png and save_file is not None:\n",
    "        plt.savefig(save_file, format=\"png\", dpi=500, bbox_inches=\"tight\")\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d7ab958afefd16",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_uncertainty_all_comps(\n",
    "    quin_comp_uncertainties,\n",
    "    timesteps,\n",
    "    set_title=True,\n",
    "    save_png=False,\n",
    "    save_file=FIGURES_DIR + \"uncertainty_all_components.png\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcd80f1588deab06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_local_candidates(means_chunk):\n",
    "    \"\"\"\n",
    "    Identifies the 'Top 2' Pareto fronts in a chunk of data.\n",
    "    Returns the indices of rows that are candidates for the Global Top 2.\n",
    "    \"\"\"\n",
    "    # 1. Find First Local Front\n",
    "    # We assume maximization for all 6 objectives based on your code\n",
    "    mask_front_1 = paretoset(means_chunk, sense=[\"max\"] * 6, distinct=False)\n",
    "\n",
    "    # 2. Find Second Local Front (Pareto of the remainder)\n",
    "    # We only look at indices that are NOT in front 1\n",
    "    remaining_indices = np.where(~mask_front_1)[0]\n",
    "\n",
    "    if len(remaining_indices) == 0:\n",
    "         return np.where(mask_front_1)[0]\n",
    "\n",
    "    means_remaining = means_chunk[remaining_indices]\n",
    "\n",
    "    # Calculate Pareto on the remaining subset\n",
    "    mask_front_2_sub = paretoset(means_remaining, sense=[\"max\"] * 6, distinct=False)\n",
    "\n",
    "    # Map back to original indices\n",
    "    indices_front_2 = remaining_indices[mask_front_2_sub]\n",
    "    indices_front_1 = np.where(mask_front_1)[0]\n",
    "\n",
    "    # Combine both sets of indices\n",
    "    return np.concatenate([indices_front_1, indices_front_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78312579cf350467",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_combined_uncertainty(df0, df1, df2):\n",
    "    # --- PRE-PROCESSING: Convert to Numpy ---\n",
    "    # Stack the lists into 2D matrices (N, 2)\n",
    "    m0, s0 = np.vstack(df0['means'].values), np.vstack(df0['stds'].values)\n",
    "    m1, s1 = np.vstack(df1['means'].values), np.vstack(df1['stds'].values)\n",
    "    m2, s2 = np.vstack(df2['means'].values), np.vstack(df2['stds'].values)\n",
    "\n",
    "    # Dimensions\n",
    "    n_rows0 = len(m0)\n",
    "    n_rows1 = len(m1)\n",
    "    n_rows2 = len(m2)\n",
    "    chunk_size = n_rows1 * n_rows2\n",
    "\n",
    "   # --- Pre-calculate Sub-Block (Comp 1 + Comp 2) ---\n",
    "    sub_m1 = np.repeat(m1, n_rows2, axis=0)\n",
    "    sub_m2 = np.tile(m2, (n_rows1, 1))\n",
    "\n",
    "    sub_s1 = np.repeat(s1, n_rows2, axis=0)\n",
    "    sub_s2 = np.tile(s2, (n_rows1, 1))\n",
    "\n",
    "    # --- Allocate Reusable Buffers ---\n",
    "    # Shape: (250000, 6) for 3 components x 2 objectives\n",
    "    buffer_means = np.empty((chunk_size, 6), dtype=m0.dtype)\n",
    "    buffer_stds = np.empty((chunk_size, 6), dtype=s0.dtype)\n",
    "\n",
    "    # Fill the static parts (Comp 1 and Comp 2) ONCE\n",
    "    # Assuming m1/m2 have 2 columns each.\n",
    "    # Buffer Layout: [Comp0_0, Comp0_1, Comp1_0, Comp1_1, Comp2_0, Comp2_1]\n",
    "    buffer_means[:, 2:4] = sub_m1\n",
    "    buffer_means[:, 4:6] = sub_m2\n",
    "\n",
    "    buffer_stds[:, 2:4] = sub_s1\n",
    "    buffer_stds[:, 4:6] = sub_s2\n",
    "\n",
    "    results_means = []\n",
    "    results_stds = []\n",
    "\n",
    "    # --- Loop ---\n",
    "    for i in range(n_rows0):\n",
    "        # 1. Update only the dynamic part (Comp 0)\n",
    "        # Broadcasting the (1, 2) row to (250000, 2) view of the buffer\n",
    "        buffer_means[:, 0:2] = m0[i]\n",
    "        buffer_stds[:, 0:2] = s0[i]\n",
    "\n",
    "        # 2. Filter (using the buffer directly)\n",
    "        keep_idx = get_local_candidates(buffer_means)\n",
    "\n",
    "        # 3. Store Survivors (must copy, otherwise next loop overwrites them)\n",
    "        results_means.append(buffer_means[keep_idx].copy())\n",
    "        results_stds.append(buffer_stds[keep_idx].copy())\n",
    "\n",
    "    # Stack results\n",
    "    if not results_means:\n",
    "        return np.array([]), np.array([])\n",
    "\n",
    "    return np.vstack(results_means), np.vstack(results_stds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0eec66e3c917f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "uncertainties = []\n",
    "\n",
    "for e in range(0, len(experiments)):\n",
    "    for t in tqdm(timesteps, desc=f\"Experiment {e}\"):\n",
    "        df0 = quin_comp_post_dfs[0][e-1][timesteps.tolist().index(t)]\n",
    "        df1 = quin_comp_post_dfs[1][e-1][timesteps.tolist().index(t)]\n",
    "        df2 = quin_comp_post_dfs[2][e-1][timesteps.tolist().index(t)]\n",
    "\n",
    "        combined_means, combined_stds = process_combined_uncertainty(df0, df1, df2)\n",
    "\n",
    "        uq_value = uncertainty_quantification_numpy(\n",
    "            combined_means,\n",
    "            combined_stds\n",
    "        )\n",
    "\n",
    "        uncertainties.append({\n",
    "            \"experiment\": e,\n",
    "            \"timestep\": t,\n",
    "            \"uncertainty\": uq_value\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b6490a8be50c5c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "overall_uncertainty_df = pd.DataFrame(uncertainties)\n",
    "overall_uncertainty_df.to_parquet(DATA_DIR + \"overall_uncertainty_TTPFTS_ex41_50_step500.parquet\", index=False)\n",
    "print(overall_uncertainty_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "556a4ba710768bf8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-01-14T15:35:45.628028Z",
     "start_time": "2026-01-14T15:35:45.624382Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_evolution_with_ci(df, save_path=None):\n",
    "    \"\"\"\n",
    "    Plots the mean uncertainty over time with a 95% Confidence Interval.\n",
    "    \"\"\"\n",
    "    # 1. Group by timestep to calculate statistics across all experiments\n",
    "    # We need mean, standard deviation, and count (n)\n",
    "    stats = df.groupby(\"timestep\")[\"uncertainty\"].agg([\"mean\", \"std\", \"count\"])\n",
    "\n",
    "    # 2. Calculate the 95% Confidence Interval\n",
    "    # Formula: 1.96 * standard_error\n",
    "    stats[\"ci\"] = 1.96 * stats[\"std\"] / np.sqrt(stats[\"count\"])\n",
    "\n",
    "    # 3. Setup the plot\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "    # Plot the Mean line\n",
    "    ax.plot(\n",
    "        stats.index,\n",
    "        stats[\"mean\"],\n",
    "        label=\"Mean Uncertainty\",\n",
    "    )\n",
    "\n",
    "    # Plot the Shaded Confidence Interval\n",
    "    ax.fill_between(\n",
    "        stats.index,\n",
    "        stats[\"mean\"] - stats[\"ci\"],\n",
    "        stats[\"mean\"] + stats[\"ci\"],\n",
    "        alpha=0.3,\n",
    "    )\n",
    "\n",
    "    # 4. Styling\n",
    "    ax.set_title(\"Evolution of Overall Uncertainty\", fontsize=14)\n",
    "    ax.set_xlabel(\"Search Steps\", fontsize=12)\n",
    "    ax.set_ylabel(\"Uncertainty (Bhattacharyya Avg)\", fontsize=12)\n",
    "\n",
    "    # Ensure axes start at 0 if appropriate, or auto-scale\n",
    "    ax.set_xlim(left=0)\n",
    "    if stats[\"mean\"].min() >= 0:\n",
    "        ax.set_ylim(bottom=0)\n",
    "\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.6)\n",
    "    ax.legend(loc=\"upper right\", frameon=True)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
