{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4866cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import multiprocessing as mp\n",
    "import itertools\n",
    "\n",
    "# %aimport utils.data_sampler, utils.ada_grad, utils.plotting, simulations\n",
    "from utils.data_sampler import MathJudgeLabelShiftDataSampler\n",
    "# from utils.ada_grad import AdaGrad\n",
    "from utils.plotting import plot_power_vs_steps, plot_power_vs_steps_paper, save_legend\n",
    "import simulations\n",
    "import json\n",
    "import pickle\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bd05f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the null- and alt-hypothesis samplers with the same model_name and seed,\n",
    "# then print summary statistics of both.\n",
    "JUDGE_MODEL_NAME = \"deepseek-ai_DeepSeek-R1-Distill-Llama-8B\"\n",
    "# JUDGE_MODEL_NAME = \"deepseek-ai_DeepSeek-R1-Distill-Qwen-7B\"\n",
    "\n",
    "null_sampler = MathJudgeLabelShiftDataSampler(\n",
    "    model_name=JUDGE_MODEL_NAME,\n",
    "    seed=0,\n",
    "    hypothesis=\"null\",\n",
    "    target_mean=0.58,\n",
    "    target_mean_x=0.55,\n",
    ")\n",
    "alt_sampler = MathJudgeLabelShiftDataSampler(\n",
    "    model_name=JUDGE_MODEL_NAME,\n",
    "    hypothesis=\"alt\",\n",
    "    seed=0,\n",
    "    # target_mean=0.58,\n",
    "    # target_mean_x=0.55,\n",
    ")\n",
    "\n",
    "# Earlier experiment, kept for reference:\n",
    "# null_sampler = MathSharedDependencyDataSampler(\n",
    "#     model_name=\"deepseek-ai_DeepSeek-R1-Distill-Qwen-1_5B\", hypothesis=\"null\"\n",
    "# )\n",
    "# alt_sampler = MathSharedDependencyDataSampler(\n",
    "#     \"agentica-org_DeepScaleR-1_5B-Preview\", hypothesis=\"alt\"\n",
    "# )\n",
    "# print(f\"{null_sampler.p_x=}\")\n",
    "\n",
    "# Empirical means of the x and y columns under each hypothesis.\n",
    "print(f\"{null_sampler.p_x=}\")\n",
    "print(f\"{null_sampler.data['x'].mean()=}\")\n",
    "print(f\"{alt_sampler.data['x'].mean()=}\")\n",
    "print(f\"{null_sampler.data['y'].mean()=}\")\n",
    "print(f\"{alt_sampler.data['y'].mean()=}\")\n",
    "\n",
    "# Sampler attributes (p_x is printed a second time here, as in the original cell).\n",
    "print(f\"{null_sampler.p_y=}\")\n",
    "print(f\"{null_sampler.p_x=}\")\n",
    "print(f\"{null_sampler.p_x_y_0=}\")\n",
    "print(f\"{alt_sampler.p_x_y_0=}\")\n",
    "print(f\"{null_sampler.p_x_y_1=}\")\n",
    "print(f\"{alt_sampler.p_x_y_1=}\")\n",
    "\n",
    "print(alt_sampler.data.corr())\n",
    "print(f\"{null_sampler.T_0=}\")\n",
    "print(f\"{null_sampler.T_1=}\")\n",
    "\n",
    "# Cell output: the `.size` of each sampler's data.\n",
    "null_sampler.data.size, alt_sampler.data.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec73e804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Power experiment: for each test-set size, run the label-shift simulation on the\n",
    "# same worker pool and pickle the results to one file per n_test.\n",
    "with mp.Pool(processes=120) as pool:\n",
    "    for n_test in (10, 100):\n",
    "        config = {\n",
    "            \"n_training\": 10,\n",
    "            \"n_test\": n_test,\n",
    "            \"num_of_repititions\": 500,\n",
    "            \"alpha\": 0.05,\n",
    "            \"nsteps\": 500,  # 600,   n_training*steps =6K\n",
    "            \"M\": 128,\n",
    "            \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "            \"p_y_null\": 0.58,  # 0.075\n",
    "            \"p_y_alt\": None,\n",
    "            \"p_x_null\": 0.55,\n",
    "            \"history_length\": 0,\n",
    "            \"min_history_length\": 0,\n",
    "            \"settings\": \"math_judge_label_shift\",\n",
    "        }\n",
    "\n",
    "        # run_simulation returns six values; the last two are not used here.\n",
    "        (\n",
    "            powers,\n",
    "            null_corr,\n",
    "            alt_corr,\n",
    "            rejection_by_test_results,\n",
    "            _,\n",
    "            _,\n",
    "        ) = simulations.run_simulation(config, pool)\n",
    "\n",
    "        # Bundle the results together with the config that produced them.\n",
    "        data = dict(\n",
    "            rejection_lists=rejection_by_test_results,\n",
    "            config=config,\n",
    "            null_corr=null_corr,\n",
    "            alt_corr=alt_corr,\n",
    "            resulting_powers=powers,\n",
    "        )\n",
    "\n",
    "        output_file = (\n",
    "            \"simulation_res/real_data/judge_label_shift/\"\n",
    "            f\"M_{config['M']}_n_test_{n_test}_2.pkl\"\n",
    "        )\n",
    "        path = Path(output_file)\n",
    "        path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        with path.open(\"wb\") as f:\n",
    "            print(f\"saving file: {output_file}\")\n",
    "            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47781fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validity experiment: same loop as the power run, but with the\n",
    "# \"math_judge_label_shift_validity\" setting and a \"_validity\" file suffix.\n",
    "with mp.Pool(processes=250) as pool:\n",
    "    for n_test in (10, 100):\n",
    "        config = {\n",
    "            \"n_training\": 10,\n",
    "            \"n_test\": n_test,\n",
    "            \"num_of_repititions\": 500,\n",
    "            \"alpha\": 0.05,\n",
    "            \"nsteps\": 500,  # 600,   n_training*steps =6K\n",
    "            \"M\": 128,\n",
    "            \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "            \"history_length\": 0,\n",
    "            \"min_history_length\": 0,\n",
    "            \"p_y_null\": 0.58,\n",
    "            \"settings\": \"math_judge_label_shift_validity\",\n",
    "        }\n",
    "\n",
    "        # run_simulation returns six values; the last two are not used here.\n",
    "        (\n",
    "            powers,\n",
    "            null_corr,\n",
    "            alt_corr,\n",
    "            rejection_by_test_results,\n",
    "            _,\n",
    "            _,\n",
    "        ) = simulations.run_simulation(config, pool)\n",
    "\n",
    "        data = dict(\n",
    "            rejection_lists=rejection_by_test_results,\n",
    "            config=config,\n",
    "            null_corr=null_corr,\n",
    "            alt_corr=alt_corr,\n",
    "            resulting_powers=powers,\n",
    "        )\n",
    "\n",
    "        output_file = (\n",
    "            \"simulation_res/real_data/judge_label_shift/\"\n",
    "            f\"M_{config['M']}_n_test_{n_test}_validity.pkl\"\n",
    "        )\n",
    "        path = Path(output_file)\n",
    "        path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        with path.open(\"wb\") as f:\n",
    "            print(f\"saving file: {output_file}\")\n",
    "            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3130bca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the current `data` under a name built from the current `config`.\n",
    "# NOTE(review): neither `data` nor `config` is defined in this cell; both are\n",
    "# whatever the most recently executed simulation cell left in the kernel.\n",
    "output_file = \"simulation_res/real_data/judge_label_shift/M_{}_n_test_{}.pkl\".format(\n",
    "    config[\"M\"], config[\"n_test\"]\n",
    ")\n",
    "path = Path(output_file)\n",
    "path.parent.mkdir(parents=True, exist_ok=True)\n",
    "with path.open(\"wb\") as f:\n",
    "    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a41d0e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot power-vs-steps for each saved power run (n_test = 10 and 100) and export\n",
    "# a single legend built from the collected axes.\n",
    "n_train = 10\n",
    "axs = []\n",
    "for n_test in (10, 100):\n",
    "    results_file = (\n",
    "        f\"simulation_res/real_data/judge_label_shift/M_128_n_test_{n_test}_2.pkl\"\n",
    "    )\n",
    "    # Only unpickling needs the file handle, so it is closed before plotting.\n",
    "    # These files are written by this notebook; do not point this at untrusted pickles.\n",
    "    with open(results_file, \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    print(data[\"resulting_powers\"])\n",
    "    corr = data[\"null_corr\"]\n",
    "    # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "    # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "    # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "    print(f\"Power vs Steps - Label shift, ntest={n_test}, corr={corr}\")\n",
    "    fig, ax = plot_power_vs_steps_paper(\n",
    "        data[\"rejection_lists\"],\n",
    "        # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "        save_path=f\"results/real_data/judge_label_shift/n_train_{n_train}_ntest_{n_test}_corr_{corr}_power_vs_steps_2.pdf\",\n",
    "    )\n",
    "    axs.append(ax)\n",
    "\n",
    "if axs:\n",
    "    save_legend(axs, \"results/real_data/judge_label_shift/legend.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c598b7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the \"math_judge_label_shift\" setting with the alt parameters set equal to\n",
    "# the null ones (p_y_alt == p_y_null, p_x_alt == p_x_null), for n_test = 100 only.\n",
    "# NOTE(review): the output path below is the same M_128_n_test_100_validity.pkl\n",
    "# that the \"math_judge_label_shift_validity\" cell writes, so whichever of the two\n",
    "# runs last overwrites the other - confirm this is intended.\n",
    "with mp.Pool(processes=250) as pool:\n",
    "    for n_test in (100,):\n",
    "        config = {\n",
    "            \"n_training\": 10,\n",
    "            \"n_test\": n_test,\n",
    "            \"num_of_repititions\": 500,\n",
    "            \"alpha\": 0.05,\n",
    "            \"nsteps\": 500,  # 600,   n_training*steps =6K\n",
    "            \"M\": 128,\n",
    "            \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "            \"p_y_null\": 0.58,  # 0.075\n",
    "            \"p_y_alt\": 0.58,\n",
    "            \"p_x_null\": 0.55,\n",
    "            \"p_x_alt\": 0.55,\n",
    "            \"history_length\": 0,\n",
    "            \"min_history_length\": 0,\n",
    "            \"settings\": \"math_judge_label_shift\",\n",
    "        }\n",
    "\n",
    "        # run_simulation returns six values; the last two are not used here.\n",
    "        (\n",
    "            powers,\n",
    "            null_corr,\n",
    "            alt_corr,\n",
    "            rejection_by_test_results,\n",
    "            _,\n",
    "            _,\n",
    "        ) = simulations.run_simulation(config, pool)\n",
    "\n",
    "        data = dict(\n",
    "            rejection_lists=rejection_by_test_results,\n",
    "            config=config,\n",
    "            null_corr=null_corr,\n",
    "            alt_corr=alt_corr,\n",
    "            resulting_powers=powers,\n",
    "        )\n",
    "\n",
    "        output_file = (\n",
    "            \"simulation_res/real_data/judge_label_shift/\"\n",
    "            f\"M_{config['M']}_n_test_{n_test}_validity.pkl\"\n",
    "        )\n",
    "        path = Path(output_file)\n",
    "        path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        with path.open(\"wb\") as f:\n",
    "            print(f\"saving file: {output_file}\")\n",
    "            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c1c6b8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved n_test=100 \"validity\" results, print the stored powers and\n",
    "# draw the power-vs-steps figure (not saved to disk: save_path is commented out).\n",
    "validity_results_file = (\n",
    "    \"simulation_res/real_data/judge_label_shift/M_128_n_test_100_validity.pkl\"\n",
    ")\n",
    "with open(validity_results_file, \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "    print(data[\"resulting_powers\"])\n",
    "    corr = data[\"null_corr\"]\n",
    "    # nb_keys = [k for k in data[\"resulting_powers\"] if k.startswith('NB')]\n",
    "    # max_test = max(nb_keys, key=data[\"resulting_powers\"].get)\n",
    "    # print(f\"Best NB key: {max_test} with power {data['resulting_powers'][max_test]}\")\n",
    "    # print(f\"Power vs Steps - Label shift, ntest={n_test}, corr={corr}\")\n",
    "    # The call stays inside the with-block so its return value is not echoed\n",
    "    # as the cell's output.\n",
    "    plot_power_vs_steps_paper(\n",
    "        data[\"rejection_lists\"],\n",
    "        # f\"Power vs Steps - Label shift, ntest={n_test}, corr={p_x_y_0_to_corr[p_x_y_0]}\",\n",
    "        # save_path=f\"results/real_data/judge_label_shift/n_train_{n_train}_ntest_{n_test}_corr_{corr}_power_vs_steps_2.pdf\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b3169c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display `powers` as assigned by the most recently executed simulation cell.\n",
    "powers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bbf250",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config = {\n",
    "#             \"n_training\": 15,\n",
    "#             \"n_test\": 15,\n",
    "#             \"num_of_repititions\": 500,\n",
    "#             \"alpha\": 0.05,\n",
    "#             \"nsteps\": 500,  # 600,   n_training*steps =6K\n",
    "#             \"M\": 32,\n",
    "#             \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "#             \"p_y_null\": 0.58,  # 0.075\n",
    "#             \"p_y_alt\": None,\n",
    "#             \"p_x_null\": 0.55,\n",
    "#             \"history_length\": 0,\n",
    "#             \"min_history_length\": 0,\n",
    "#             \"settings\": \"math_judge_label_shift\",\n",
    "# }\n",
    "\n",
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=15, ntrain=15, corr=0.76\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e053693a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=10, n_train=10\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e1e1b7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=100, n_train=10\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9671a49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=50\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bfa3f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=105, ntrain=15, corr=0.76\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18733e9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=15\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99046de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `rejection_by_test_results` is not defined in this cell; it is\n",
    "# whatever the most recently executed simulation cell assigned. The title is\n",
    "# hard-coded and not derived from that run's config - confirm they match.\n",
    "plot_title = \"Power vs Steps - Math Judge Label Shift, ntest=30\"\n",
    "plot_power_vs_steps(rejection_by_test_results, plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c64fd67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config = {\n",
    "#             \"n_training\": 15,\n",
    "#             \"n_test\": 30,\n",
    "#             \"num_of_repititions\": 500,\n",
    "#             \"alpha\": 0.05,\n",
    "#             \"nsteps\": 500,  # 600,   n_training*steps =6K\n",
    "#             \"M\": 32,\n",
    "#             \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "#             \"p_y_null\": 0.57,  # 0.075\n",
    "#             \"p_y_alt\": None,\n",
    "#             \"history_length\": 0,\n",
    "#             \"min_history_length\": 0,\n",
    "#             \"settings\": \"math_judge_label_shift\",\n",
    "# }"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_hypothesis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
