{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Define Chain(s)\n",
    "\n",
    "from niagara import Chain, Model, ModelIntrinsicLogProb, NullTransformation, LogisticRegressionCalibrator\n",
    "from niagara import OpenAIClient, FireworksClient\n",
    "\n",
    "import os\n",
    "os.environ[\"FIREWORKS_API_KEY\"] = \"leave-this-line-but-there-is-no-need-to-add-an-API-key\"\n",
    "\n",
    "llama_chain = Chain(\n",
    "    models = [\n",
    "        Model(\n",
    "            model_name=name, \n",
    "            thresholds={\"reject\": -10000, \"accept\": 0.0},\n",
    "            conf_signal=ModelIntrinsicLogProb(),\n",
    "            conf_signal_transform=NullTransformation(),\n",
    "            conf_signal_calibrator=LogisticRegressionCalibrator()\n",
    "        )\n",
    "        for name in [\"llama3.2-1b\", \"llama3.2-3b\", \"llama3.1-8b\", \"llama3.1-70b\", \"llama3.1-405b\"]\n",
    "    ]\n",
    ")\n",
    "\n",
    "qwen_oai_chain = Chain(\n",
    "    models = [\n",
    "        Model(\n",
    "            model_name=name, \n",
    "            thresholds={\"reject\": -10000, \"accept\": 0.0},\n",
    "            conf_signal=ModelIntrinsicLogProb(),\n",
    "            conf_signal_transform=NullTransformation(),\n",
    "            conf_signal_calibrator=LogisticRegressionCalibrator(),\n",
    "            client=client\n",
    "        )\n",
    "        for name, client in [(\"gpt-4o-mini\", None), (\"qwen2.5-32b-coder-instruct\", None), (\"qwen2.5-72b-instruct\", None), (\"gpt-4o\", None)]\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Select chain, benchmark, transformation, and grab data\n",
    "\n",
    "import pickle\n",
    "from niagara import OneSidedAsymptoticLog, TwoSidedAsymptoticLog\n",
    "\n",
    "PRETTY_NAMES = {\n",
    "    \"xsum\": \"XSum\",\n",
    "    \"mmlu\": \"MMLU\",\n",
    "    \"medmcqa\": \"MedMCQA\",\n",
    "    \"triviaqa\": \"TriviaQA\",\n",
    "    \"truthfulqa\": \"TruthfulQA\",\n",
    "    \"gsm8k\": \"GSM8K\"\n",
    "}\n",
    "\n",
    "NAME = \"medmcqa\"\n",
    "TRANSFORM = OneSidedAsymptoticLog()\n",
    "CHAIN_NAME = \"qwen_oai_chain\"\n",
    "CHAIN = qwen_oai_chain\n",
    "\n",
    "# Update the transformation for the chain\n",
    "for model in CHAIN.models:\n",
    "    model.conf_signal_transform = TRANSFORM\n",
    "\n",
    "with open(f'../benchmarks/data/{NAME}/chain_results/{NAME}_full_{CHAIN_NAME}_results_train.pkl', 'rb') as f:\n",
    "    results_train = pickle.load(f)\n",
    "with open(f'../benchmarks/data/{NAME}/chain_results/{NAME}_full_{CHAIN_NAME}_results_test.pkl', 'rb') as f:\n",
    "    results_test = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Compute calibrated confidence values\n",
    "\n",
    "process_scores = lambda scores: sum(scores.values()) >= 20\n",
    "\n",
    "if NAME==\"xsum\":\n",
    "    raw_corr_train = { k: [process_scores(x) for x in v] for k,v in results_train['model_correctness'].items() }\n",
    "else:\n",
    "    raw_corr_train= results_train['model_correctness']\n",
    "\n",
    "raw_conf_train = results_train['raw_confidences']\n",
    "\n",
    "corr_train = [\n",
    "    raw_corr_train[model_name] for model_name in CHAIN.model_names\n",
    "]\n",
    "\n",
    "transformed_conf_train = [ \n",
    "    list(TRANSFORM.transform_confidence_signal(raw_conf_train[model_name]))\n",
    "        for model_name in CHAIN.model_names\n",
    "]\n",
    "\n",
    "calibration_data = [\n",
    "    {\"correctness\": corr, \"transformed_confidence\": conf} \n",
    "        for (corr, conf, model_name) \n",
    "            in zip(corr_train, transformed_conf_train, CHAIN.model_names)\n",
    "]\n",
    "\n",
    "CHAIN.calibrate(calibration_data)\n",
    "\n",
    "calibrated_conf_train = [\n",
    "    list(\n",
    "        CHAIN.models[model_idx].conf_signal_calibrator.calibrate_confidence_signal(\n",
    "            transformed_conf_train[model_idx]\n",
    "        )\n",
    "    )\n",
    "    for model_idx in range(len(CHAIN.model_names))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Compute test data\n",
    "\n",
    "if NAME==\"xsum\":\n",
    "    raw_corr_test = { k: [process_scores(x) for x in v] for k,v in results_test['model_correctness'].items() }\n",
    "else:\n",
    "    raw_corr_test= results_test['model_correctness']\n",
    "\n",
    "raw_conf_test = results_test['raw_confidences']\n",
    "\n",
    "corr_test = [\n",
    "    raw_corr_test[model_name] for model_name in CHAIN.model_names\n",
    "]\n",
    "\n",
    "transformed_conf_test = [ \n",
    "    list(TRANSFORM.transform_confidence_signal(raw_conf_test[model_name]))\n",
    "        for model_name in CHAIN.model_names\n",
    "]\n",
    "\n",
    "calibrated_conf_test = [\n",
    "    list(\n",
    "        CHAIN.models[model_idx].conf_signal_calibrator.calibrate_confidence_signal(\n",
    "            transformed_conf_test[model_idx]\n",
    "        )\n",
    "    )\n",
    "    for model_idx in range(len(CHAIN.model_names))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.distributions.empirical_distribution import ECDF\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "# Compute Cramer von Mises statistic\n",
    "def compute_cvm(F_emp, F_null, n_grid=1000):\n",
    "    integration_grid = np.linspace(0, 1, n_grid)\n",
    "    integrand = [ (F_emp(v) - F_null(v))**2 for v in integration_grid ]\n",
    "    sum_1 = np.sum(integrand[1:] * np.diff(F_null(integration_grid)))\n",
    "    sum_2 = np.sum(integrand[:-1] * np.diff(F_null(integration_grid)))\n",
    "    integral = (sum_1 + sum_2) / 2\n",
    "    return integral\n",
    "\n",
    "# Compute Cramer von Mises pval\n",
    "def compute_cvm_pval(test_statistic, null_dist, n_obs, B=1000):\n",
    "    cvm_statistic_list = []\n",
    "    for b in tqdm(range(B)):\n",
    "        null_sample_b = null_dist.rvs(size=n_obs)\n",
    "        null_empirical_cdf_b = ECDF(null_sample_b)\n",
    "        cvm_statistic_b = compute_cvm(null_empirical_cdf_b, null_dist.cdf, n_grid=1000)\n",
    "        cvm_statistic_list.append(cvm_statistic_b)\n",
    "\n",
    "    # Compute p value\n",
    "    pval = np.mean(np.array(cvm_statistic_list) >= test_statistic)\n",
    "    return pval, cvm_statistic_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Fit the probabilistic model for the marginal distribution\n",
    "\n",
    "from matplotlib import rcParams\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from niagara.probabilistic_modeling.marginals import lumpy_betamix\n",
    "\n",
    "# Enable LaTeX text rendering in Matplotlib\n",
    "rcParams['text.usetex'] = False #True\n",
    "rcParams['font.family'] = 'serif'\n",
    "rcParams['font.serif'] = ['Computer Modern Roman']\n",
    "rcParams['font.size'] = 10\n",
    "\n",
    "pvals = {}\n",
    "pvals_refitted = {}\n",
    "\n",
    "SKIP_CVM = False\n",
    "SKIP_CVM_REFITTED = False\n",
    "REG = 0.0\n",
    "\n",
    "params = {}\n",
    "\n",
    "for MODEL_ID in range(len(CHAIN.model_names)):\n",
    "    # Fit the marginal distribution to the TRAINING DATA\n",
    "    marginal_params = lumpy_betamix.fit(np.array(calibrated_conf_train[MODEL_ID]), betamix_reg=REG)\n",
    "    marginal_dist = lumpy_betamix(*marginal_params)\n",
    "\n",
    "    # Get the calibrated confidence values for the TEST DATA, based on calibration trained on TRAINING data\n",
    "    data = calibrated_conf_test[MODEL_ID]\n",
    "    MODEL_NAME = CHAIN.model_names[MODEL_ID]\n",
    "    SHORT_MODEL_NAME = MODEL_NAME.split(\"-\")[-1]\n",
    "\n",
    "    # Record the distributional parameters\n",
    "    params_train = marginal_params\n",
    "    # fit the marginal distribution on the test data to see how the params change\n",
    "    marginal_params_test = lumpy_betamix.fit(np.array(data), betamix_reg=REG)\n",
    "    marginal_dist_refitted = lumpy_betamix(*marginal_params_test)\n",
    "    params_test = marginal_params_test\n",
    "    params[MODEL_NAME] = { 'params_train': marginal_params, 'params_test': marginal_params_test }\n",
    "\n",
    "    # Compute CVM test\n",
    "    if not SKIP_CVM:\n",
    "        empirical_cdf = ECDF(data)\n",
    "        cvm_test_statistic = compute_cvm(empirical_cdf, marginal_dist.cdf)\n",
    "        cvm_pval, cvm_statistic_list = compute_cvm_pval(cvm_test_statistic, marginal_dist, len(data))\n",
    "        pvals[MODEL_NAME] = {\n",
    "            'cvm': cvm_test_statistic,\n",
    "            'cvm_pval': cvm_pval,\n",
    "            'null_cvm_values': cvm_statistic_list\n",
    "        }\n",
    "\n",
    "    # Compute CVM test on refitted dist\n",
    "    if not SKIP_CVM_REFITTED:\n",
    "        empirical_cdf = ECDF(data)\n",
    "        cvm_test_statistic = compute_cvm(empirical_cdf, marginal_dist_refitted.cdf)\n",
    "        cvm_pval, cvm_statistic_list = compute_cvm_pval(cvm_test_statistic, marginal_dist_refitted, len(data))\n",
    "        pvals_refitted[MODEL_NAME] = {\n",
    "            'cvm': cvm_test_statistic,\n",
    "            'cvm_pval': cvm_pval,\n",
    "            'null_cvm_values': cvm_statistic_list\n",
    "        }\n",
    "\n",
    "    # Set a professional style\n",
    "    sns.set_style(\"white\")\n",
    "    sns.set_context(\"paper\", font_scale=1.0)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(3.5, 2.5))\n",
    "\n",
    "    # Plot histogram using Seaborn's histplot (available in seaborn >= 0.11)\n",
    "    sns.histplot(data, bins='auto', stat='density', kde=False, color=\"#2E86C1\", edgecolor=\"black\", ax=ax)\n",
    "\n",
    "    counts, edges = np.histogram(data, bins='auto', density=False)\n",
    "    mass_min_edges = edges[:2]\n",
    "    mass_max_edges = edges[-2:]\n",
    "\n",
    "    # grid for evaluating marginal_pdf\n",
    "    grid = np.linspace(mass_min_edges[0]+1e-3,mass_max_edges[-1]-2e-3,200)\n",
    "    # grid = np.linspace(0,1,200)\n",
    "    marginal_pdf = marginal_dist.pdf(grid)\n",
    "\n",
    "    # get the normalization factor of the histogram\n",
    "    h = edges[1]-edges[0]\n",
    "    n_bins = len(counts)\n",
    "    total_counts = sum(counts)\n",
    "    avg_count = total_counts/n_bins\n",
    "    hist_integral = h*n_bins*avg_count\n",
    "\n",
    "    first_bar_height = counts[0]/hist_integral\n",
    "    last_bar_height = counts[-1]/hist_integral\n",
    "\n",
    "    # compute multiplier for showing discrete mass on top of pdf\n",
    "    w_multiplier = total_counts/hist_integral\n",
    "    p_min, p_max, w_min, w_max, _, _, _, _, _ = params_train\n",
    "\n",
    "    for edges, w, xmin_bar, total_bar_height in [\n",
    "            (mass_min_edges, w_min, mass_min_edges[0], first_bar_height), \n",
    "            (mass_max_edges, w_max, mass_max_edges[0], last_bar_height)\n",
    "        ]:\n",
    "        w = w*w_multiplier\n",
    "        \n",
    "        rect = patches.Rectangle((xmin_bar, max(total_bar_height-w, 0)), \n",
    "                        width=edges[1]-edges[0], \n",
    "                        height=w,\n",
    "                        facecolor='none',\n",
    "                        edgecolor='black',\n",
    "                        hatch='////',\n",
    "                        linewidth=1.0,\n",
    "                        alpha=1.0)\n",
    "        plt.gca().add_patch(rect)\n",
    "\n",
    "    # Add labels and title if desired\n",
    "    ax.set_xlabel(r\"\\textbf{Calibrated Confidence}\", fontsize=10)\n",
    "    ax.set_ylabel(r\"\\textbf{Count}\", fontsize=10)\n",
    "\n",
    "    # Remove top and right spines for a cleaner look\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "\n",
    "    # Tight layout for better spacing\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Save as PDF (vector graphics)\n",
    "    plt.plot(grid, marginal_pdf, color=\"black\", linewidth=2.0)\n",
    "    ax.set_xlabel(None)\n",
    "    ax.set_ylabel(None)\n",
    "    ax.set_yticks([])\n",
    "\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
