{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44464e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from utils.plotting import *\n",
    "import pickle\n",
    "import itertools\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e977edde",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_tests = [30, 135]\n",
    "p_x_y_0s = [0.2, 0.35]\n",
    "p_x_y_1s = [0.9, 0.65]\n",
    "\n",
    "p_x_y_0_to_corr = {0.2: \"high corr\", 0.35: \"medium corr\"}\n",
    "\n",
    "\n",
    "base_path=\"simulation_res/hyperparameter_tunning/label_shift_new_version/\"\n",
    "axs = []\n",
    "for M, w, n_test, (p_x_y_0, p_x_y_1) in itertools.product(\n",
    "        [128], [0], n_tests, zip(p_x_y_0s, p_x_y_1s)\n",
    "    ):\n",
    "    with open(f\"{base_path}sim_M_{M}_w_{w}_n_test_{n_test}_p_x_y_0_{p_x_y_0}_p_x_y_1_{p_x_y_1}.pkl\", \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data['resulting_powers'])\n",
    "        # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "        # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "        # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "        print(\n",
    "            f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\"\n",
    "        )\n",
    "        fig, ax = plot_power_vs_steps_paper(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "            save_path=f\"results/simulations/label_shift/ntest_{n_test}_corr_{p_x_y_0_to_corr[p_x_y_0]}_power_vs_steps_2.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/label_shift/legend.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8dd9891",
   "metadata": {},
   "outputs": [],
   "source": [
    "#plot baselines\n",
    "n_tests = [30, 135]\n",
    "p_x_y_0s = [0.2, 0.35]\n",
    "p_x_y_1s = [0.9, 0.65]\n",
    "\n",
    "p_x_y_0_to_corr = {0.2: \"high corr\", 0.35: \"medium corr\"}\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/label_shift_new_version/\"\n",
    "axs = []\n",
    "for M, w, n_test, (p_x_y_0, p_x_y_1) in itertools.product(\n",
    "    [128], [0], n_tests, zip(p_x_y_0s, p_x_y_1s)\n",
    "):\n",
    "    with open(\n",
    "        f\"{base_path}sim_M_{M}_w_{w}_n_test_{n_test}_p_x_y_0_{p_x_y_0}_p_x_y_1_{p_x_y_1}.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"resulting_powers\"])\n",
    "        # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "        # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "        # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "        print(\n",
    "            f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\"\n",
    "        )\n",
    "        fig, ax = plot_power_vs_steps_baselines(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "            save_path=f\"results/simulations/label_shift/ntest_{n_test}_corr_{p_x_y_0_to_corr[p_x_y_0]}_power_vs_steps_baselines.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/label_shift/legend_baselines.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47416aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_settings = [\"medium_corr_b\", \"high_corr\"]\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/concept_shift_new_version\"\n",
    "axs = []\n",
    "for corr in corr_settings:\n",
    "    with open(\n",
    "        f\"{base_path}/{corr}_high_ntest_optimizing_growth_rate.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"powers\"])\n",
    "        print(data[\"rejection_lists\"].keys())\n",
    "        print(\n",
    "            f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\"\n",
    "        )\n",
    "        fig, ax = plot_power_vs_steps_paper(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\",\n",
    "            save_path=f\"results/simulations/concept_shift/ntest_135_corr_{corr}_power_vs_steps_2.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/concept_shift/legend.pdf\", font_size=40)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24680fe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_settings = [\"medium_corr_b\", \"high_corr\"]\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/concept_shift_new_version\"\n",
    "axs = []\n",
    "for corr in corr_settings:\n",
    "    with open(\n",
    "        f\"{base_path}/{corr}_high_ntest_optimizing_growth_rate.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"powers\"])\n",
    "        print(data[\"rejection_lists\"].keys())\n",
    "        print(f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\")\n",
    "        fig, ax = plot_power_vs_steps_baselines(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\",\n",
    "            save_path=f\"results/simulations/concept_shift/ntest_135_corr_{corr}_power_vs_steps_baselines.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/concept_shift/legend_baselines.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e003bc34",
   "metadata": {},
   "outputs": [],
   "source": [
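    "# Plot LLM evaluation validity results\n",
    "# NOTE(review): label_shift_high and label_shift_low below are read from the same\n",
    "# M_128_n_test_100_validity.pkl file and the same key, so both curves are identical -- confirm the intended paths.\n",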
    "statistic_to_rejections = {}\n",
    "with open(\n",
    "    f\"simulation_res/real_data/judge_label_shift/M_128_n_test_100_validity.pkl\",\n",
    "    \"rb\",\n",
    ") as f:\n",
    "    data = pickle.load(f)\n",
    "statistic_to_rejections[\"label_shift_high\"] = data[\"rejection_lists\"][\"NB_max_predictions_only\"]\n",
    "\n",
    "with open(\n",
    "    f\"simulation_res/real_data/judge_label_shift/M_128_n_test_100_validity.pkl\",\n",
    "    \"rb\",\n",
    ") as f:\n",
    "    data = pickle.load(f)\n",
    "statistic_to_rejections[\"label_shift_low\"] = data[\"rejection_lists\"][\"NB_max_predictions_only\"]\n",
    "\n",
    "with open(\n",
    "    f\"simulation_res/real_data/math_concept_shift/n_train_10_n_test_90_validity.pkl\",\n",
    "    \"rb\",\n",
    ") as f:\n",
    "    data = pickle.load(f)\n",
    "statistic_to_rejections[\"concept_shift\"] = data[\"rejection_lists\"][\"NB_predictions_only\"]\n",
    "\n",
    "plot_power_vs_steps_validity(\n",
    "    statistic_to_rejections,\n",
    "    save_path=f\"results/real_data/validity.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c34e8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot PPI comparison results\n",
    "n_tests = [30, 135]\n",
    "p_x_y_0s = [0.2, 0.35]\n",
    "p_x_y_1s = [0.9, 0.65]\n",
    "\n",
    "p_x_y_0_to_corr = {0.2: \"high corr\", 0.35: \"medium corr\"}\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/label_shift_new_version/\"\n",
    "axs = []\n",
    "for M, w, n_test, (p_x_y_0, p_x_y_1) in itertools.product(\n",
    "    [128], [0], n_tests, zip(p_x_y_0s, p_x_y_1s)\n",
    "):\n",
    "    with open(\n",
    "        f\"{base_path}sim_M_{M}_w_{w}_n_test_{n_test}_p_x_y_0_{p_x_y_0}_p_x_y_1_{p_x_y_1}.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"resulting_powers\"])\n",
    "        # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "        # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "        # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "        print(\n",
    "            f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\"\n",
    "        )\n",
    "        fig, ax = plot_power_vs_steps_ppi_compared_labeled_data(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "            save_path=f\"results/simulations/label_shift/ntest_{n_test}_corr_{p_x_y_0_to_corr[p_x_y_0]}_power_vs_steps_PPI_compared_to_labeled_data.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/label_shift/legend_ppi_compared_labeled_data.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0a6dfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot one-sided PPI comparison results\n",
    "n_tests = [30, 135]\n",
    "p_x_y_0s = [0.2, 0.35]\n",
    "p_x_y_1s = [0.9, 0.65]\n",
    "\n",
    "p_x_y_0_to_corr = {0.2: \"high corr\", 0.35: \"medium corr\"}\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/label_shift_new_version/\"\n",
    "axs = []\n",
    "for M, w, n_test, (p_x_y_0, p_x_y_1) in itertools.product(\n",
    "    [128], [0], n_tests, zip(p_x_y_0s, p_x_y_1s)\n",
    "):\n",
    "    with open(\n",
    "        f\"{base_path}sim_M_{M}_w_{w}_n_test_{n_test}_p_x_y_0_{p_x_y_0}_p_x_y_1_{p_x_y_1}.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"resulting_powers\"])\n",
    "        # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "        # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "        # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "        print(\n",
    "            f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\"\n",
    "        )\n",
    "        fig, ax = plot_power_vs_steps_one_sided_ppi(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "            save_path=f\"results/simulations/label_shift/ntest_{n_test}_corr_{p_x_y_0_to_corr[p_x_y_0]}_power_vs_steps_one_sided_PPI.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/label_shift/legend_one_sided_ppi.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c890d4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot hyperparameter tuning results:\n",
    "# collect the NB_max_predictions_only rejection lists for each value of M and plot them together\n",
    "Ms = [2, 5, 32, 128]\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/label_shift_hyperparams/\"\n",
    "statistic_to_rejections = {}\n",
    "\n",
    "for M in Ms:\n",
    "    with open(\n",
    "        f\"{base_path}sim_M_{M}_w_0_n_test_30_p_x_y_0_0.35_p_x_y_1_0.65.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"resulting_powers\"])\n",
    "        # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "        # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "        # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "        print(\n",
    "            f\"Power vs Steps - Label shift, M={M}\"\n",
    "        )\n",
    "        statistic_to_rejections[f\"{M=}\"] = data[\"rejection_lists\"][\"NB_max_predictions_only\"]\n",
    "        # plot_power_vs_steps_one_sided_ppi(\n",
    "        #     data[\"rejection_lists\"],\n",
    "        #     # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "            # save_path=f\"results/simulations/label_shift/ntest_{n_test}_corr_{p_x_y_0_to_corr[p_x_y_0]}_power_vs_steps_PPI.pdf\",\n",
    "        # )\n",
    "plot_hyperparameter_tuning(\n",
    "    statistic_to_rejections, \n",
    "    save_path=f\"results/simulations/label_shift/tuning_M.pdf\",)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74a1ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bad example: concept-shift runs loaded from the *_non_monotone.pkl result files.\n",
    "# For each correlation setting, plot power vs steps and save the figure,\n",
    "# then save one shared legend for all axes.\n",
    "\n",
    "# Defined locally so the cell does not depend on state left behind by earlier cells.\n",
    "corr_settings = [\"medium_corr_b\", \"high_corr\"]\n",
    "\n",
    "\n",
    "base_path = \"simulation_res/hyperparameter_tunning/concept_shift_new_version\"\n",
    "axs = []\n",
    "for corr in corr_settings:\n",
    "    with open(\n",
    "        f\"{base_path}/{corr}_non_monotone.pkl\",\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"powers\"])\n",
    "        print(data[\"rejection_lists\"].keys())\n",
    "        print(f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\")\n",
    "        fig, ax = plot_power_vs_steps_paper(\n",
    "            data[\"rejection_lists\"],\n",
    "            # f\"Power vs Steps - Concept shift, ntest={135}, corr={corr}\",\n",
    "            save_path=f\"results/simulations/concept_shift/ntest_135_corr_{corr}_non_monotone.pdf\",\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "# Skip the legend when no figure was produced.\n",
    "if axs:\n",
    "    save_legend(axs, \"results/simulations/concept_shift/non_monotone_legend.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14c6e627",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Real World LLM eval\n",
    "axs = []\n",
    "data_to_save_paths = {\n",
    "    f\"simulation_res/real_data/judge_label_shift/M_128_n_test_10_2.pkl\": f\"results/real_data/judge_label_shift/n_train_10_ntest_10_power_vs_steps_2.pdf\",\n",
    "    f\"simulation_res/real_data/judge_label_shift/M_128_n_test_100_2.pkl\": f\"results/real_data/judge_label_shift/n_train_10_ntest_100_power_vs_steps_2.pdf\",\n",
    "    f\"simulation_res/real_data/math_concept_shift/n_train_10_n_test_90_2.pkl\": f\"results/real_data/math_concept_shift/n_train_10_ntest_90_corr_0.2_power_vs_steps_2.pdf\",\n",
    "}\n",
    "for data_path, save_path in data_to_save_paths.items():\n",
    "    with open(\n",
    "        data_path,\n",
    "        \"rb\",\n",
    "    ) as f:\n",
    "        data = pickle.load(f)\n",
    "        print(data[\"resulting_powers\"])\n",
    "        corr = data[\"null_corr\"]\n",
    "        print(data_path)\n",
    "        fig, ax = plot_power_vs_steps_paper(\n",
    "            data[\"rejection_lists\"],\n",
    "            save_path=save_path,\n",
    "        )\n",
    "        axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/real_data/legend.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_hypothesis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
