{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from niagara import Chain, Model, ModelIntrinsicLogProb, NullTransformation, LogisticRegressionCalibrator\n",
    "from niagara import OpenAIClient, FireworksClient, OneSidedAsymptoticLog, TwoSidedAsymptoticLog\n",
    "import pickle\n",
    "import os\n",
    "# Placeholder FIREWORKS_API_KEY: per its value, the variable must exist but no real key is needed.\n",
    "os.environ[\"FIREWORKS_API_KEY\"] = \"leave-this-line-but-there-is-no-need-to-add-an-API-key\"\n",
    "\n",
    "def _calibrated_model(name, **client_kwargs):\n",
    "    \"\"\"Build one chain member: intrinsic log-prob signal, no transform, logistic-regression calibration.\n",
    "\n",
    "    Extra keyword arguments (e.g. client=...) are forwarded to Model unchanged.\n",
    "    \"\"\"\n",
    "    return Model(\n",
    "        model_name=name,\n",
    "        thresholds={\"reject\": -10000, \"accept\": 0.0},\n",
    "        conf_signal=ModelIntrinsicLogProb(),\n",
    "        conf_signal_transform=NullTransformation(),\n",
    "        conf_signal_calibrator=LogisticRegressionCalibrator(),\n",
    "        **client_kwargs\n",
    "    )\n",
    "\n",
    "# Llama models in cascade order (1B -> 405B per the names); no client argument is passed.\n",
    "LLAMA_MODEL_NAMES = [\"llama3.2-1b\", \"llama3.2-3b\", \"llama3.1-8b\", \"llama3.1-70b\", \"llama3.1-405b\"]\n",
    "llama_chain = Chain(models=[_calibrated_model(name) for name in LLAMA_MODEL_NAMES])\n",
    "\n",
    "# Mixed OpenAI / Qwen chain; every entry passes an explicit client (currently None for all).\n",
    "QWEN_OAI_MODELS = [(\"gpt-4o-mini\", None), (\"qwen2.5-32b-coder-instruct\", None), (\"qwen2.5-72b-instruct\", None), (\"gpt-4o\", None)]\n",
    "qwen_oai_chain = Chain(models=[_calibrated_model(name, client=client) for name, client in QWEN_OAI_MODELS])\n",
    "\n",
    "NAME = \"mmlu\"\n",
    "CHAIN_NAME = \"llama_chain\"\n",
    "\n",
    "def _cascade_subsets(n_models):\n",
    "    \"\"\"Every cascade (index list of length >= 2) over n_models models, grouped by its last model.\n",
    "\n",
    "    For last model m = 1 .. n_models-1 the prefixes are the non-empty subsets of range(m),\n",
    "    ordered by size and then lexicographically, e.g. [0,3], [1,3], [2,3], [0,1,3], ... [0,1,2,3].\n",
    "    \"\"\"\n",
    "    cascades = []\n",
    "    for last in range(1, n_models):\n",
    "        prefixes = [[j for j in range(last) if (mask >> j) & 1] for mask in range(1, 2 ** last)]\n",
    "        prefixes.sort(key=lambda prefix: (len(prefix), prefix))\n",
    "        cascades.extend(prefix + [last] for prefix in prefixes)\n",
    "    return cascades\n",
    "\n",
    "ALL_MODEL_INDICES = _cascade_subsets(5)\n",
    "\n",
    "# One-sided asymptotic-log transform for mmlu / medmcqa, two-sided for every other benchmark.\n",
    "TRANSFORM = OneSidedAsymptoticLog() if NAME in {'mmlu', 'medmcqa'} else TwoSidedAsymptoticLog()\n",
    "\n",
    "# Look the chain up by name: an unknown CHAIN_NAME fails loudly instead of leaving CHAIN\n",
    "# undefined (or silently reusing the CHAIN from an earlier run of this cell).\n",
    "CHAINS = {\"llama_chain\": llama_chain, \"qwen_oai_chain\": qwen_oai_chain}\n",
    "if CHAIN_NAME not in CHAINS:\n",
    "    raise ValueError(f\"Unknown CHAIN_NAME {CHAIN_NAME!r}; expected one of {sorted(CHAINS)}\")\n",
    "CHAIN = CHAINS[CHAIN_NAME]\n",
    "\n",
    "# ALL_MODEL_INDICES uses positions 0-4 (the 5-model llama chain); drop cascades that\n",
    "# reference a position this chain does not have (qwen_oai_chain has 4 models).\n",
    "ALL_MODEL_INDICES = [indices for indices in ALL_MODEL_INDICES if max(indices) < len(CHAIN.model_names)]\n",
    "\n",
    "# Update the transformation for the chain\n",
    "for model in CHAIN.models:\n",
    "    model.conf_signal_transform = TRANSFORM\n",
    "\n",
    "# Cached chain outputs. pickle.load can execute arbitrary code: only load files this project produced.\n",
    "RESULTS_DIR = f'../benchmarks/data/{NAME}/chain_results'\n",
    "with open(f'{RESULTS_DIR}/{NAME}_full_{CHAIN_NAME}_results_train.pkl', 'rb') as f:\n",
    "    results_train = pickle.load(f)\n",
    "with open(f'{RESULTS_DIR}/{NAME}_full_{CHAIN_NAME}_results_test.pkl', 'rb') as f:\n",
    "    results_test = pickle.load(f)\n",
    "\n",
    "# Get the train and test data\n",
    "\n",
    "### Compute calibrated confidence values\n",
    "\n",
    "# For xsum each per-example correctness entry is a dict of scores; it counts as correct when\n",
    "# the scores sum to at least this value.\n",
    "XSUM_MIN_TOTAL_SCORE = 20\n",
    "\n",
    "def process_scores(scores):\n",
    "    \"\"\"Binarize one xsum score dict: True when its values sum to at least XSUM_MIN_TOTAL_SCORE.\"\"\"\n",
    "    return sum(scores.values()) >= XSUM_MIN_TOTAL_SCORE\n",
    "\n",
    "def _prepare_split(results, chain, transform, benchmark_name):\n",
    "    \"\"\"Extract correctness and transformed confidences for one results split.\n",
    "\n",
    "    Returns (raw_corr, raw_conf, corr, transformed_conf): the per-model dicts read from\n",
    "    `results` (correctness binarized with process_scores for xsum), and per-model lists ordered\n",
    "    like chain.model_names, with confidences passed through transform.transform_confidence_signal.\n",
    "    \"\"\"\n",
    "    if benchmark_name == \"xsum\":\n",
    "        raw_corr = {k: [process_scores(x) for x in v] for k, v in results['model_correctness'].items()}\n",
    "    else:\n",
    "        raw_corr = results['model_correctness']\n",
    "    raw_conf = results['raw_confidences']\n",
    "    corr = [raw_corr[model_name] for model_name in chain.model_names]\n",
    "    transformed_conf = [\n",
    "        list(transform.transform_confidence_signal(raw_conf[model_name]))\n",
    "        for model_name in chain.model_names\n",
    "    ]\n",
    "    return raw_corr, raw_conf, corr, transformed_conf\n",
    "\n",
    "def _calibrate_split(chain, transformed_conf):\n",
    "    \"\"\"Apply each model's calibrator (fit by chain.calibrate below) to its transformed confidences.\"\"\"\n",
    "    return [\n",
    "        list(chain.models[model_idx].conf_signal_calibrator.calibrate_confidence_signal(conf))\n",
    "        for model_idx, conf in enumerate(transformed_conf)\n",
    "    ]\n",
    "\n",
    "### Fit the calibrators on the train split\n",
    "\n",
    "raw_corr_train, raw_conf_train, corr_train, transformed_conf_train = _prepare_split(results_train, CHAIN, TRANSFORM, NAME)\n",
    "calibration_data = [\n",
    "    {\"correctness\": corr, \"transformed_confidence\": conf}\n",
    "    for corr, conf in zip(corr_train, transformed_conf_train)\n",
    "]\n",
    "CHAIN.calibrate(calibration_data)\n",
    "calibrated_conf_train = _calibrate_split(CHAIN, transformed_conf_train)\n",
    "\n",
    "### Compute test data (same pipeline, reusing the calibrators fit on train)\n",
    "\n",
    "raw_corr_test, raw_conf_test, corr_test, transformed_conf_test = _prepare_split(results_test, CHAIN, TRANSFORM, NAME)\n",
    "calibrated_conf_test = _calibrate_split(CHAIN, transformed_conf_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.integrate import quad\n",
    "from scipy.interpolate import interp1d\n",
    "import numpy as np\n",
    "\n",
    "### Define function for computing area under the curve (AUC)\n",
    "\n",
    "def compute_auc(x, y, x_min, x_max, integration_limit=200, method='linear'):\n",
    "    \"\"\"\n",
    "    Integrate the curve sampled at the points (x, y) over [x_min, x_max].\n",
    "\n",
    "    The samples are sorted by x and joined with a scipy interp1d interpolant of kind\n",
    "    `method`, extrapolated outside the sampled x range; that interpolant is integrated\n",
    "    with scipy.integrate.quad using at most `integration_limit` subintervals.\n",
    "\n",
    "    Returns the (integral, absolute_error_estimate) tuple produced by quad.\n",
    "    \"\"\"\n",
    "    order = np.argsort(x)\n",
    "    curve = interp1d(\n",
    "        np.array(x)[order],\n",
    "        np.array(y)[order],\n",
    "        kind=method,\n",
    "        bounds_error=False,\n",
    "        fill_value='extrapolate',\n",
    "    )\n",
    "    return quad(curve, x_min, x_max, limit=integration_limit)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from optimize_cascade import get_expected_uncumulated_costs\n",
    "\n",
    "# Hardcoded per-model prices for each chain ({\"in\": ..., \"out\": ...} per model; presumably\n",
    "# USD per million input/output tokens -- confirm against get_expected_uncumulated_costs).\n",
    "MODEL_COSTS_BY_CHAIN = {\n",
    "    \"llama_chain\": {\n",
    "        \"llama3.2-1b\": {\"in\": 0.10, \"out\": 0.10},\n",
    "        \"llama3.2-3b\": {\"in\": 0.10, \"out\": 0.10},\n",
    "        \"llama3.1-8b\": {\"in\": 0.20, \"out\": 0.20},\n",
    "        \"llama3.1-70b\": {\"in\": 0.90, \"out\": 0.90},\n",
    "        \"llama3.1-405b\": {\"in\": 3.00, \"out\": 3.00},\n",
    "    },\n",
    "    \"qwen_oai_chain\": {\n",
    "        \"gpt-4o-mini\": {\"in\": 0.15, \"out\": 0.60},\n",
    "        \"qwen2.5-32b-coder-instruct\": {\"in\": 0.90, \"out\": 0.90},\n",
    "        \"qwen2.5-72b-instruct\": {\"in\": 0.90, \"out\": 0.90},\n",
    "        \"gpt-4o\": {\"in\": 2.50, \"out\": 10.00},\n",
    "    },\n",
    "}\n",
    "# Fail loudly on an unknown chain instead of leaving raw_model_costs undefined or stale.\n",
    "if CHAIN_NAME not in MODEL_COSTS_BY_CHAIN:\n",
    "    raise ValueError(f\"No model costs defined for CHAIN_NAME {CHAIN_NAME!r}\")\n",
    "raw_model_costs = MODEL_COSTS_BY_CHAIN[CHAIN_NAME]\n",
    "\n",
    "expected_uncumulated_costs_train = get_expected_uncumulated_costs(raw_model_costs, results_train)\n",
    "expected_uncumulated_costs_test = get_expected_uncumulated_costs(raw_model_costs, results_test)\n",
    "\n",
    "def compute_utilization_from_conditional_deferral_probs(conditional_deferral_probs):\n",
    "    \"\"\"Probability that each model of the cascade returns the query.\n",
    "\n",
    "    With p = conditional_deferral_probs, model i is reached with probability prod(p[:i])\n",
    "    (every earlier model deferred) and then answers with probability 1 - p[i].\n",
    "    \"\"\"\n",
    "    utilizations = []\n",
    "    for i, defer_prob in enumerate(conditional_deferral_probs):\n",
    "        reach_prob = np.prod(conditional_deferral_probs[:i])\n",
    "        utilizations.append(reach_prob * (1 - defer_prob))\n",
    "    return utilizations\n",
    "\n",
    "def compute_utilization_and_error(T, model_indices, calibrated_conf_train, conditional_deferral_probs=None):\n",
    "    \"\"\"\n",
    "    Estimate how often each model of a cascade returns the query and how accurate it is when it does.\n",
    "\n",
    "    Model i returns a query when every earlier model deferred (calibrated confidence <= its\n",
    "    threshold) and its own calibrated confidence is > T[i]; the last model returns every query\n",
    "    that reaches it. Mean calibrated confidence is used as the expected correctness.\n",
    "\n",
    "    Args:\n",
    "        T: thresholds in cascade order; T[i] applies to model_indices[i]. Only the first\n",
    "            len(model_indices) - 1 entries are read (may be empty for a single-model cascade).\n",
    "        model_indices: positions (into calibrated_conf_train) of the cascade's models, in order.\n",
    "        calibrated_conf_train: calibrated confidences, one sequence per model (models x queries).\n",
    "        conditional_deferral_probs: optional; when given, the returned utilizations are computed\n",
    "            from these probabilities instead of from the data.\n",
    "\n",
    "    Returns:\n",
    "        (utilizations, conditional_corrs, unconditional_corrs), one entry per cascade model:\n",
    "        the fraction of queries the model returns, its mean calibrated confidence on those\n",
    "        queries (nan if it returns none), and its mean calibrated confidence on all queries.\n",
    "    \"\"\"\n",
    "    # rows = queries, columns = the selected models in cascade order\n",
    "    cal_conf_tr = np.array(calibrated_conf_train).transpose()[:, model_indices]\n",
    "    thresholds = np.asarray(T)\n",
    "    n_models = len(model_indices)\n",
    "\n",
    "    utilizations = []\n",
    "    conditional_corrs = []\n",
    "    unconditional_corrs = []\n",
    "\n",
    "    # The first model goes through the same loop: with no earlier models, np.all over an\n",
    "    # empty axis marks every query as delegated to it.\n",
    "    for i in range(n_models):\n",
    "        prior_models_delegate = np.all(cal_conf_tr[:, :i] <= thresholds[np.newaxis, :i], axis=1)\n",
    "        if i < n_models - 1:\n",
    "            this_model_accepts = cal_conf_tr[:, i] > thresholds[i]\n",
    "        else:\n",
    "            # the last model has no threshold and returns everything it receives\n",
    "            this_model_accepts = np.ones(cal_conf_tr.shape[0], dtype=bool)\n",
    "        returns_query = prior_models_delegate & this_model_accepts\n",
    "\n",
    "        # utilization rate\n",
    "        utilizations.append(np.mean(returns_query))\n",
    "        # expected correctness conditioned on this model returning the query\n",
    "        conditional_corrs.append(np.mean(cal_conf_tr[returns_query, i]))\n",
    "        # expected correctness if this model answered every query\n",
    "        unconditional_corrs.append(np.mean(cal_conf_tr[:, i]))\n",
    "\n",
    "    if conditional_deferral_probs is not None:\n",
    "        utilizations = compute_utilization_from_conditional_deferral_probs(conditional_deferral_probs)\n",
    "\n",
    "    return utilizations, conditional_corrs, unconditional_corrs\n",
    "\n",
    "\n",
    "def get_ecorr_ecost_estimates(threshold_list, model_indices, calibrated_conf_train, uncumulated_costs=None):\n",
    "    \"\"\"\n",
    "    Estimate the expected cost and expected correctness of a cascade for each threshold vector.\n",
    "\n",
    "    Args:\n",
    "        threshold_list: sequence of threshold vectors T (see compute_utilization_and_error).\n",
    "        model_indices: positions of the cascade's models, in order.\n",
    "        calibrated_conf_train: calibrated confidences, one sequence per model.\n",
    "        uncumulated_costs: per-model expected costs, indexable by model_indices; defaults to\n",
    "            the notebook-level expected_uncumulated_costs_train.\n",
    "\n",
    "    Returns:\n",
    "        List of (expected_cost, expected_correctness) tuples, one per threshold vector\n",
    "        (note: cost comes first, despite the function name).\n",
    "    \"\"\"\n",
    "    if uncumulated_costs is None:\n",
    "        uncumulated_costs = expected_uncumulated_costs_train\n",
    "    # A query returned by model i has passed through models 0..i, so it is charged their\n",
    "    # cumulative cost. This depends only on the cascade, so compute it once.\n",
    "    cumulative_costs = np.cumsum(np.array(uncumulated_costs)[model_indices])\n",
    "\n",
    "    outputs = []\n",
    "    for T in threshold_list:\n",
    "        util, cond_corr, _ = compute_utilization_and_error(T, model_indices, calibrated_conf_train)\n",
    "        util = np.array(util)\n",
    "        # nansum: a model that returns no queries has nan conditional correctness but zero utilization\n",
    "        ecorr_estimate = np.nansum(util * np.array(cond_corr))\n",
    "        ecost_estimate = np.nansum(util * cumulative_costs)\n",
    "        outputs.append((ecost_estimate, ecorr_estimate))\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Setup for getting the probabilistic model\n",
    "import os\n",
    "import pickle\n",
    "from time import time\n",
    "from itertools import product\n",
    "from paretoset import paretoset\n",
    "from optimize_cascade import train_probability_model\n",
    "\n",
    "def get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h = 0.025, eps=0.01):\n",
    "    \"\"\"\n",
    "    Find Pareto-optimal cascade thresholds by exhaustive search over a quantile grid.\n",
    "\n",
    "    Each model except the last gets candidate thresholds at the quantiles\n",
    "    eps, eps + quantile_h, ... (below 1 - eps) of its values in data_train. Every\n",
    "    combination is scored with get_ecorr_ecost_estimates, and the combinations on the\n",
    "    (minimize expected cost, maximize expected correctness) Pareto front are kept.\n",
    "\n",
    "    Returns:\n",
    "        (pareto_df, opt_tholds): the Pareto-optimal expected_cost / expected_correctness\n",
    "        rows and the matching threshold vectors, in the same order.\n",
    "    \"\"\"\n",
    "    quantile_levels = np.arange(0+eps, 1-eps, quantile_h)\n",
    "    per_model_grids = [np.quantile(data_train[idx], q=quantile_levels) for idx in model_indices[:-1]]\n",
    "\n",
    "    # Full Cartesian product: the candidate count grows exponentially with cascade length\n",
    "    threshold_candidates = list(map(np.array, product(*per_model_grids)))\n",
    "\n",
    "    results = pd.DataFrame(\n",
    "        get_ecorr_ecost_estimates(threshold_candidates, model_indices, data_train),\n",
    "        columns=['expected_cost','expected_correctness']\n",
    "    )\n",
    "    on_front = paretoset(results, sense=[\"min\", \"max\"])\n",
    "    pareto_df = results.loc[on_front]\n",
    "    opt_tholds = [candidate for candidate, keep in zip(threshold_candidates, on_front) if keep]\n",
    "    return pareto_df, opt_tholds\n",
    "\n",
    "# Train (or load a cached copy of) the probabilistic model over the calibrated train confidences.\n",
    "# The cache key includes CHAIN_NAME because the model is fit on this chain's confidences; a key\n",
    "# built from NAME alone would silently reuse one chain's model for another. Caches stored as\n",
    "# data/probabilistic_model_results_{NAME}.pkl must be renamed to be picked up.\n",
    "start = time()\n",
    "filename = f\"data/probabilistic_model_results_{NAME}_{CHAIN_NAME}.pkl\"\n",
    "SAVE_TO_FILE = False\n",
    "\n",
    "if os.path.exists(filename):\n",
    "    # pickle.load can execute arbitrary code -- only load caches this notebook wrote.\n",
    "    with open(filename, 'rb') as file:\n",
    "        prob_results = pickle.load(file)\n",
    "else:\n",
    "    prob_results = train_probability_model(full_data=np.array(calibrated_conf_train).transpose())\n",
    "    if SAVE_TO_FILE:\n",
    "        os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "        with open(filename, 'wb') as file:\n",
    "            pickle.dump(prob_results, file)\n",
    "\n",
    "stop = time()\n",
    "print(f\"Probabilistic model ready after {stop - start:.2f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fill_parameter_gaps_adaptively(parameters_list, model_indices, data_train, max_prob_gap=0.1):\n",
    "    \"\"\"\n",
    "    Fill gaps between parameters by adding midpoints when probability mass gap is too large.\n",
    "\n",
    "    The gap between consecutive parameter vectors a and b is split at (a + b) / 2 whenever, for\n",
    "    some component j, more than max_prob_gap of the observations of model model_indices[j] lie\n",
    "    strictly between a[j] and b[j]. Gaps are bisected depth-first (left half first) until every\n",
    "    gap satisfies the bound, so the input vectors keep their order with midpoints in between.\n",
    "\n",
    "    Args:\n",
    "        parameters_list: List of parameter vectors (numpy arrays), each of length len(model_indices) - 1\n",
    "        model_indices: Indices to select relevant rows of data_train (the last index is ignored)\n",
    "        data_train: Per-model observed values, indexable as data_train[model_index]\n",
    "        max_prob_gap: Maximum allowed probability mass between consecutive parameters\n",
    "\n",
    "    Returns:\n",
    "        List of float parameter vectors including added midpoints (inputs with 0 or 1\n",
    "        vectors are returned unchanged).\n",
    "    \"\"\"\n",
    "    # Return early if we have 0 or 1 parameters\n",
    "    if len(parameters_list) <= 1:\n",
    "        return parameters_list\n",
    "\n",
    "    # Float dtype: with integer input a midpoint would be truncated onto an endpoint, the gap\n",
    "    # would never shrink, and the search would not terminate.\n",
    "    params = np.asarray(parameters_list, dtype=float)\n",
    "    if params.shape[1] == 0:\n",
    "        return [np.array(p) for p in params.tolist()]\n",
    "\n",
    "    # observed[:, j] holds the observations compared against component j\n",
    "    observed = np.array([data_train[i] for i in model_indices[:-1]]).transpose()\n",
    "\n",
    "    def needs_split(a, b):\n",
    "        \"\"\"True if some component has more than max_prob_gap of the observations strictly between a and b.\"\"\"\n",
    "        lower = np.minimum(a, b)\n",
    "        upper = np.maximum(a, b)\n",
    "        prob_mass = np.mean((observed > lower) & (observed < upper), axis=0)\n",
    "        return prob_mass.max() > max_prob_gap\n",
    "\n",
    "    # Single pass with an explicit stack: each gap is checked once per split, instead of\n",
    "    # rescanning all earlier (already valid) gaps after every insertion.\n",
    "    refined = [params[0]]\n",
    "    for target in params[1:]:\n",
    "        pending = [target]  # right endpoints still to be reached from refined[-1]\n",
    "        while pending:\n",
    "            left, right = refined[-1], pending[-1]\n",
    "            if needs_split(left, right):\n",
    "                pending.append((left + right) / 2)\n",
    "            else:\n",
    "                refined.append(pending.pop())\n",
    "\n",
    "    return [np.array(x) for x in np.array(refined).tolist()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from optimize_cascade import profile_cascade, profile_cascade_adaptively, make_full_data, score_cascade\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "RECORDS_FILE = \"data/cascade_comparison_records.pkl\"\n",
    "# Start from an empty record list; set True to update the records saved by an earlier run\n",
    "# (skipped when the file does not exist, so a fresh run does not crash).\n",
    "LOAD_EXISTING_RECORDS = False\n",
    "RUN_GRID = True\n",
    "SAVE_TO_FILE = False\n",
    "\n",
    "if LOAD_EXISTING_RECORDS and os.path.exists(RECORDS_FILE):\n",
    "    # pickle.load can execute arbitrary code -- only load files this notebook wrote.\n",
    "    with open(RECORDS_FILE, \"rb\") as file:\n",
    "        ALL_RECORDS = pickle.load(file)\n",
    "else:\n",
    "    ALL_RECORDS = []\n",
    "\n",
    "\n",
    "def _cost_range(ecost_ecorr):\n",
    "    \"\"\"Return (min, max) expected cost over a list of (expected_cost, expected_correctness) pairs.\"\"\"\n",
    "    costs = [cost for cost, _ in ecost_ecorr]\n",
    "    return min(costs), max(costs)\n",
    "\n",
    "\n",
    "def _score_thresholds(thresholds, model_indices, costs, test_data):\n",
    "    \"\"\"Score each threshold vector with score_cascade; return (scores, [(expected_cost, expected_correctness), ...]).\"\"\"\n",
    "    scores = [score_cascade(T, model_indices, costs, test_data) for T in thresholds]\n",
    "    return scores, [(rec['expected_cost_test'], rec['expected_correctness_test']) for rec in scores]\n",
    "\n",
    "\n",
    "def _comparison_record(method, model_indices, ecost_ecorr, scores, n_grid, elapsed, auc_bounds):\n",
    "    \"\"\"Summarize one method's test-set (cost, correctness) curve for the current benchmark and cascade.\n",
    "\n",
    "    `performance` is the area under the curve between auc_bounds; `auc_norm` is\n",
    "    1 - area / (x_max - x_min), or nan when the bounds are empty or inverted.\n",
    "    \"\"\"\n",
    "    x_min, x_max = auc_bounds\n",
    "    auc = compute_auc(*zip(*ecost_ecorr), x_min=x_min, x_max=x_max)[0]\n",
    "    width = x_max - x_min\n",
    "    return {\n",
    "        \"benchmark\": NAME,\n",
    "        \"cascade\": model_indices,\n",
    "        \"cascade_len\": len(model_indices),\n",
    "        \"method\": method,\n",
    "        \"performance\": auc,\n",
    "        \"auc_bounds\": [x_min, x_max],\n",
    "        \"auc_norm\": 1 - (auc / width) if width > 0 else np.nan,\n",
    "        \"time\": elapsed,\n",
    "        \"n_grid\": n_grid,\n",
    "        \"data\": ecost_ecorr,\n",
    "        \"scores\": scores,\n",
    "        \"n_obs\": len(np.unique([cost for cost, _ in ecost_ecorr])),\n",
    "    }\n",
    "\n",
    "\n",
    "def _upsert_records(records, new_records):\n",
    "    \"\"\"Replace records with the same (benchmark, cascade, method) in place; append records with no match.\"\"\"\n",
    "    for new in new_records:\n",
    "        key = (new['benchmark'], new['cascade'], new['method'])\n",
    "        matches = [i for i, rec in enumerate(records) if (rec['benchmark'], rec['cascade'], rec['method']) == key]\n",
    "        for i in matches:\n",
    "            records[i] = new\n",
    "        if not matches:\n",
    "            records.append(new)\n",
    "\n",
    "\n",
    "# Loop-invariant inputs\n",
    "data_train = calibrated_conf_train\n",
    "test_data = {\n",
    "    'calib_conf': make_full_data(calibrated_conf_test),\n",
    "    'corr': make_full_data(corr_test)\n",
    "}\n",
    "\n",
    "for model_indices in tqdm(ALL_MODEL_INDICES):\n",
    "    # Optimize the cascade on the data, then densify sparse stretches of the threshold path\n",
    "    start = time()\n",
    "    cascade_record = profile_cascade_adaptively(\n",
    "        model_indices,\n",
    "        expected_uncumulated_costs_train,\n",
    "        prob_results,\n",
    "        start_sensitivities=[0, 1e-10, 1e-8, 1e-6],\n",
    "        cost_threshold_multiplier=1.25,\n",
    "        stop_val=1000,\n",
    "        max_iterations=24,\n",
    "        sensitivity_increase_factor=2.0\n",
    "    )\n",
    "    opt_tholds_cts_optim = fill_parameter_gaps_adaptively(\n",
    "        cascade_record['optimal_thresholds'], model_indices, data_train, max_prob_gap=0.1\n",
    "    )\n",
    "    continuous_optim_time = time() - start\n",
    "    print(f\"Continuous optimization took {continuous_optim_time}s\")\n",
    "\n",
    "    # Optimize the thresholds via grid search\n",
    "    if RUN_GRID:\n",
    "        quantile_h = 0.025\n",
    "        start = time()\n",
    "        pareto_df, opt_tholds_grid_search = get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h)\n",
    "        grid_search_time = time() - start\n",
    "        print(f\"Grid search took {grid_search_time}s\")\n",
    "\n",
    "    # Score the solutions on the test data. With RUN_GRID both curves are integrated over the\n",
    "    # cost range they share; without it only the continuous record is produced, so no grid\n",
    "    # variables left over from an earlier cascade are reused.\n",
    "    scores_continuous_optim, ecost_ecorr_continuous_optim = _score_thresholds(\n",
    "        opt_tholds_cts_optim, model_indices, expected_uncumulated_costs_test, test_data\n",
    "    )\n",
    "    min_cost_overall, max_cost_overall = _cost_range(ecost_ecorr_continuous_optim)\n",
    "    if RUN_GRID:\n",
    "        scores_grid_search, ecost_ecorr_gridsearch = _score_thresholds(\n",
    "            opt_tholds_grid_search, model_indices, expected_uncumulated_costs_test, test_data\n",
    "        )\n",
    "        min_cost_gridsearch, max_cost_gridsearch = _cost_range(ecost_ecorr_gridsearch)\n",
    "        min_cost_overall = max(min_cost_overall, min_cost_gridsearch)\n",
    "        max_cost_overall = min(max_cost_overall, max_cost_gridsearch)\n",
    "    auc_bounds = (min_cost_overall, max_cost_overall)\n",
    "\n",
    "    # Gather all the results: performance, time, resolution\n",
    "    new_records = [_comparison_record(\n",
    "        \"continuous_optimization\", model_indices, ecost_ecorr_continuous_optim,\n",
    "        scores_continuous_optim, len(opt_tholds_cts_optim), continuous_optim_time, auc_bounds\n",
    "    )]\n",
    "    if RUN_GRID:\n",
    "        new_records.append(_comparison_record(\n",
    "            \"gridsearch\", model_indices, ecost_ecorr_gridsearch,\n",
    "            scores_grid_search, len(opt_tholds_grid_search), grid_search_time, auc_bounds\n",
    "        ))\n",
    "    _upsert_records(ALL_RECORDS, new_records)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(\"->\".join(str(x) for x in model_indices))\n",
    "    if RUN_GRID:\n",
    "        ax.scatter(*zip(*ecost_ecorr_gridsearch), label='grid')\n",
    "    ax.scatter(*zip(*ecost_ecorr_continuous_optim), label='cts')\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# Save all records to file\n",
    "if SAVE_TO_FILE:\n",
    "    with open(RECORDS_FILE, \"wb\") as file:\n",
    "        pickle.dump(ALL_RECORDS, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from optimize_cascade import profile_cascade, profile_cascade_adaptively, make_full_data, score_cascade\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Runtime study. The settings are paired: run i uses SENSITIVITY_INCREASE_FACTORS[i] for the\n",
    "# continuous optimizer and quantile step GRID_RESOLUTIONS[i] for the grid search.\n",
    "GRID_RESOLUTIONS = [ 0.1, 0.05, 0.0333, 0.025, 0.02 ]\n",
    "SENSITIVITY_INCREASE_FACTORS = [ 3.0, 2.0, 1.5, 1.3, 1.2 ]\n",
    "\n",
    "ALL_RECORDS = []\n",
    "SAVE_TO_FILE = False\n",
    "\n",
    "for model_indices in tqdm(ALL_MODEL_INDICES):\n",
    "    data_train = calibrated_conf_train\n",
    "    shared_fields = {\"benchmark\": NAME, \"cascade\": model_indices, \"cascade_len\": len(model_indices)}\n",
    "\n",
    "    for sens_increase, grid_h in zip(SENSITIVITY_INCREASE_FACTORS, GRID_RESOLUTIONS):\n",
    "        # Chosen so that sens_increase ** max_iter == 2 ** 24: every factor covers the same total\n",
    "        # sensitivity growth as 24 doublings. NOTE(review): max_iter is a float, not an int.\n",
    "        max_iter = 24*np.log(2)/np.log(sens_increase)\n",
    "\n",
    "        # Continuous optimization; the timing includes the gap filling\n",
    "        start = time()\n",
    "        cascade_record = profile_cascade_adaptively(\n",
    "            model_indices,\n",
    "            expected_uncumulated_costs_train,\n",
    "            prob_results,\n",
    "            start_sensitivities=[0, 1e-10, 1e-8, 1e-6],\n",
    "            cost_threshold_multiplier=1.25,\n",
    "            stop_val=1000,\n",
    "            max_iterations=max_iter,\n",
    "            sensitivity_increase_factor=sens_increase\n",
    "        )\n",
    "        raw_thresholds = cascade_record['optimal_thresholds']\n",
    "        original_n_grid = len(raw_thresholds)\n",
    "        opt_tholds_cts_optim = fill_parameter_gaps_adaptively(\n",
    "            raw_thresholds, model_indices, data_train, max_prob_gap=0.1\n",
    "        )\n",
    "        continuous_optim_time = time() - start\n",
    "        print(f\"Continuous optimization took {continuous_optim_time}s\")\n",
    "\n",
    "        # Grid search at the paired quantile resolution\n",
    "        quantile_h = grid_h\n",
    "        start = time()\n",
    "        pareto_df, opt_tholds_grid_search = get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h)\n",
    "        grid_search_time = time() - start\n",
    "        print(f\"Grid search took {grid_search_time}s\")\n",
    "\n",
    "        ALL_RECORDS.extend([\n",
    "            {\n",
    "                **shared_fields,\n",
    "                \"method\": \"continuous_optimization\",\n",
    "                \"time\": continuous_optim_time,\n",
    "                \"n_grid\": len(opt_tholds_cts_optim),\n",
    "                \"sens_increase\": sens_increase,\n",
    "                \"original_n_grid\": original_n_grid,\n",
    "            },\n",
    "            {\n",
    "                **shared_fields,\n",
    "                \"method\": \"gridsearch\",\n",
    "                \"time\": grid_search_time,\n",
    "                \"n_grid\": len(opt_tholds_grid_search),\n",
    "                # NOTE(review): equals n_grid (size of the Pareto set), not the number of grid candidates\n",
    "                \"full_grid_size\": len(opt_tholds_grid_search),\n",
    "                \"original_n_grid\": int(1/grid_h),\n",
    "            },\n",
    "        ])\n",
    "\n",
    "\n",
    "# Save all records to file\n",
    "if SAVE_TO_FILE:\n",
    "    with open(\"data/cascade_runtime.pkl\", \"wb\") as file:\n",
    "        pickle.dump(ALL_RECORDS, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import statsmodels.api as sm\n",
    "\n",
    "def fit_linear_trend_with_stderr(x, y, alpha=0.05, x_min=None, x_max=None):\n",
    "    \"\"\"\n",
    "    Fit an OLS line y ~ a + b*x and evaluate it at the sorted unique x values.\n",
    "\n",
    "    Args:\n",
    "        x, y: 1-D samples of the predictor and the response.\n",
    "        alpha: significance level of the prediction interval (0.05 -> 95% interval).\n",
    "        x_min, x_max: optional extra evaluation points, e.g. to extend the line across a plot range.\n",
    "\n",
    "    Returns:\n",
    "        x_unique: sorted, de-duplicated evaluation points (the data x values plus x_min / x_max).\n",
    "        y_pred: fitted mean at x_unique.\n",
    "        stderr: standard error for predicting a single new observation at x_unique (se_obs).\n",
    "        pi_lower, pi_upper: bounds of the (1 - alpha) prediction interval for a new observation.\n",
    "    \"\"\"\n",
    "    X = np.asarray(x, dtype=float)\n",
    "    Y = np.asarray(y, dtype=float)\n",
    "\n",
    "    # np.unique sorts and de-duplicates, so x_min / x_max land in order even when they fall\n",
    "    # inside the data range (inserting them at the ends could leave x_unique unsorted).\n",
    "    bounds = [v for v in (x_min, x_max) if v is not None]\n",
    "    x_unique = np.unique(np.concatenate([X, np.asarray(bounds, dtype=float)]))\n",
    "\n",
    "    # has_constant='add': the default ('skip') omits the intercept when a column is already\n",
    "    # constant (e.g. a single evaluation point), which would mismatch the fitted design.\n",
    "    X_design = sm.add_constant(X, has_constant='add')\n",
    "    X_unique = sm.add_constant(x_unique, has_constant='add')\n",
    "\n",
    "    results = sm.OLS(Y, X_design).fit()\n",
    "\n",
    "    # Predictions at the evaluation points, plus prediction intervals for individual observations\n",
    "    pred_ints = results.get_prediction(X_unique)\n",
    "    y_pred = pred_ints.predicted_mean\n",
    "    stderr = pred_ints.se_obs\n",
    "    pred_data = pred_ints.summary_frame(alpha=alpha)\n",
    "    pi_lower = pred_data['obs_ci_lower']\n",
    "    pi_upper = pred_data['obs_ci_upper']\n",
    "\n",
    "    return x_unique, y_pred, stderr, pi_lower, pi_upper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "\n",
    "# Set seaborn style first\n",
    "sns.set_style(\"white\")\n",
    "sns.set_context(\"paper\", font_scale=1.0)\n",
    "\n",
    "# Then matplotlib settings\n",
    "rcParams['text.usetex'] = False #True\n",
    "rcParams['font.family'] = 'serif'\n",
    "rcParams['font.serif'] = ['Computer Modern Roman']\n",
    "rcParams['font.size'] = 10\n",
    "\n",
    "df_time = pd.DataFrame(ALL_RECORDS)\n",
    "df_time_grid = df_time[df_time['method'] == 'gridsearch']\n",
    "df_time_cts = df_time[df_time['method'] == 'continuous_optimization']\n",
    "\n",
    "# Get mean runtimes for each cascade len (for inspection; not plotted)\n",
    "mean_time_cts = df_time_cts.groupby(by='cascade_len')['time'].mean()\n",
    "mean_time_grid = df_time_grid.groupby(by='cascade_len')['time'].mean()\n",
    "\n",
    "\n",
    "def plot_runtime_trends(ax, df_method, color, line_zorder=None, x_min=10, x_max=100):\n",
    "    \"\"\"Draw one runtime trend line per cascade length present in df_method.\n",
    "\n",
    "    For each cascade length k, fits log(time) ~ log(original_n_grid), evaluates the fit on\n",
    "    the data plus [x_min, x_max], and plots exp(fit) with a band of +/- one observation\n",
    "    standard error (se_obs on the log scale). Lengths come from the data rather than a fixed\n",
    "    range, so chains with fewer models do not try to fit an empty group.\n",
    "    \"\"\"\n",
    "    line_kwargs = {} if line_zorder is None else {'zorder': line_zorder}\n",
    "    for k in sorted(df_method['cascade_len'].unique()):\n",
    "        rows = df_method[df_method['cascade_len'] == k]\n",
    "        x_fit, y_pred, std_error, _, _ = fit_linear_trend_with_stderr(\n",
    "            np.log(rows['original_n_grid']),\n",
    "            np.log(rows['time']),\n",
    "            x_min=np.log(x_min),\n",
    "            x_max=np.log(x_max)\n",
    "        )\n",
    "        ax.plot(np.exp(x_fit), np.exp(y_pred), color=color, linewidth=1, **line_kwargs)\n",
    "        ax.fill_between(\n",
    "            np.exp(x_fit),\n",
    "            np.exp(y_pred - std_error),  # Lower bound\n",
    "            np.exp(y_pred + std_error),  # Upper bound\n",
    "            color=color,\n",
    "            alpha=0.2,  # Transparency\n",
    "        )\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5,4))\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "color_cts = 'tab:blue'\n",
    "color_grid = 'gray'\n",
    "plot_runtime_trends(ax, df_time_cts, color_cts)\n",
    "plot_runtime_trends(ax, df_time_grid, color_grid, line_zorder=-1)\n",
    "\n",
    "ax.text(50, 2500, \"$k=5$\", color=color_grid, fontweight='bold').set_rotation(24)\n",
    "ax.text(60, 0.0045, \"$k=2$\", color=color_grid, fontweight='bold').set_rotation(3.5)\n",
    "ax.text(10, 23, \"$k=5$\", color=color_cts, fontweight='bold').set_rotation(3)\n",
    "ax.text(10, 0.06, \"$k=2$\", color=color_cts, fontweight='bold').set_rotation(3)\n",
    "\n",
    "ax.text(9.9, 1.0, \"continuous\", color=color_cts, fontweight='bold').set_rotation(5)\n",
    "ax.text(48, 500, \"grid search\", color=color_grid, fontweight='bold').set_rotation(22)\n",
    "ax.set_ylabel(\"Runtime (s)\", fontsize=14)\n",
    "ax.set_xlabel('Resolution of Cost-Error Curve', fontsize=14)\n",
    "\n",
    "# Relabel the currently visible x ticks with the reciprocal of their position\n",
    "visible_ticks = [tick.get_position()[0] for tick in ax.get_xticklabels() if tick.get_visible()]\n",
    "visible_labels = [f'$\\\\mathdefault{{{1/x:.2f}}}$' for x in visible_ticks]\n",
    "\n",
    "ax.set_xticks(visible_ticks)\n",
    "ax.set_xticklabels(visible_labels)\n",
    "ax.set_xlim([9,110])\n",
    "\n",
    "# Set the title before tight_layout so the layout makes room for it\n",
    "ax.set_title(\"Continuous optimization scales linearly independent of cascade length $k$\", fontsize=14)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
