{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d439a720",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import itertools\n",
    "import json\n",
    "import multiprocessing as mp\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "# %aimport utils.data_sampler, utils.ada_grad, utils.plotting, simulations\n",
    "from utils.data_sampler import MathSharedDependencyDataSampler\n",
    "# from utils.ada_grad import AdaGrad\n",
    "from utils.plotting import plot_power_vs_steps_concept_shift, plot_power_vs_steps_paper\n",
    "import simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48267ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the two samplers used by this notebook, one per model name.\n",
    "# The target_* arguments presumably calibrate the sampled data to those\n",
    "# statistics -- TODO confirm against MathSharedDependencyDataSampler.\n",
    "# NOTE(review): null_sampler passes no `hypothesis`, so it relies on the\n",
    "# constructor's default value -- confirm that default is the null hypothesis.\n",
    "null_sampler = MathSharedDependencyDataSampler(\n",
    "    model_name=\"deepseek-ai_DeepSeek-R1-Distill-Llama-8B\",\n",
    "    seed=0,\n",
    "    target_mean=0.55,\n",
    "    target_p_y_x_0=0.8,\n",
    "    target_p_y_x_1=0.5,\n",
    ")\n",
    "# NOTE(review): the model name is passed positionally here but as `model_name=`\n",
    "# above; presumably the same parameter -- verify against the constructor signature.\n",
    "alt_sampler = MathSharedDependencyDataSampler(\n",
    "    \"deepseek-ai_DeepSeek-R1-Distill-Qwen-7B\",\n",
    "    seed=0,\n",
    "    hypothesis=\"alt\",\n",
    "    target_mean=0.584,\n",
    "    target_p_y_x_0=0.82,\n",
    "    target_p_y_x_1=0.535,\n",
    ")\n",
    "\n",
    "\n",
    "# Alternative model pair, currently disabled.\n",
    "# null_sampler = MathSharedDependencyDataSampler(\n",
    "#     model_name=\"deepseek-ai_DeepSeek-R1-Distill-Qwen-1_5B\", hypothesis=\"null\"\n",
    "# )\n",
    "# alt_sampler = MathSharedDependencyDataSampler(\n",
    "#     \"agentica-org_DeepScaleR-1_5B-Preview\", hypothesis=\"alt\"\n",
    "# )\n",
    "# print(f\"{null_sampler.p_x=}\")\n",
    "# Sanity-check printout for both samplers: means of the x and y columns, data\n",
    "# shapes, the p_y_x_0 / p_y_x_1 attributes, and the result of data.corr() for\n",
    "# the null sampler.\n",
    "print(f\"{null_sampler.data['x'].mean()=}\")\n",
    "print(f\"{alt_sampler.data['x'].mean()=}\")\n",
    "print(f\"{null_sampler.data['y'].mean()=}\")\n",
    "print(f\"{alt_sampler.data['y'].mean()=}\")\n",
    "print(f\"{null_sampler.data.shape=}\")\n",
    "print(f\"{alt_sampler.data.shape=}\")\n",
    "print(f\"{null_sampler.p_y_x_0=}\")\n",
    "print(f\"{null_sampler.p_y_x_1=}\")\n",
    "print(f\"{alt_sampler.p_y_x_0=}\")\n",
    "print(f\"{alt_sampler.p_y_x_1=}\")\n",
    "print(f\"{null_sampler.data.corr()=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00a02dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the y column in a sample(10_000) draw from the alt sampler.\n",
    "alt_draw = alt_sampler.sample(10_000)\n",
    "alt_draw[\"y\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22bba077",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search over training/test sizes and history settings: every config is\n",
    "# simulated on a shared worker pool and the collected results are dumped to JSON.\n",
    "n_training = [0, 10, 15, 20, 30, 40]\n",
    "steps = 400\n",
    "# 106513 is presumably the number of available samples (TODO confirm); dividing by\n",
    "# `steps` gives the per-step budget that n_train + n_test is sized to fill.\n",
    "samples_per_step = 106513 // steps\n",
    "n_test = [samples_per_step - n_train for n_train in n_training]\n",
    "history_lengths = [0, 3, 10, 15, 20]\n",
    "min_history_length = [0, 5, 10, 15]\n",
    "# NOTE(review): the product below crosses every n_training with every n_test, so most\n",
    "# pairs do not add up to samples_per_step -- confirm zip(n_training, n_test) was not intended.\n",
    "configs = [\n",
    "    {\n",
    "        \"n_training\": n_train,\n",
    "        \"n_test\": n_te,\n",
    "        \"num_of_repititions\": 30,\n",
    "        \"alpha\": 0.05,\n",
    "        \"nsteps\": steps,\n",
    "        \"M\": 30,\n",
    "        \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "        \"p_y_null\": null_sampler.p_y*0.95,  # Y's mean value under null\n",
    "        \"history_length\": hist,\n",
    "        \"min_history_length\": min_hist,\n",
    "        \"settings\": \"math_shared_dependency\",\n",
    "    }\n",
    "    # `n_te` (not `n_test`) so the loop variable does not shadow the list above.\n",
    "    for n_train, n_te, hist, min_hist in itertools.product(\n",
    "        n_training, n_test, history_lengths, min_history_length\n",
    "    )\n",
    "]\n",
    "print(len(configs))\n",
    "\n",
    "# Run every config and keep its powers and per-test rejection results.\n",
    "result_powers = []\n",
    "results_rejection_by_test_results = []\n",
    "with mp.Pool(processes=120) as pool:\n",
    "    for config in configs:\n",
    "        # Other cells in this notebook unpack run_simulation as a 6-tuple; `*_`\n",
    "        # keeps the four values used here and ignores any trailing extras.\n",
    "        powers, null_corr, alt_corr, rejection_by_test_results, *_ = (\n",
    "            simulations.run_simulation(config, pool)\n",
    "        )\n",
    "        result_powers.append(powers)\n",
    "        results_rejection_by_test_results.append(rejection_by_test_results)\n",
    "\n",
    "# Save the per-config powers as JSON\n",
    "with open(\"results.json\", \"w\") as f:\n",
    "    json.dump(result_powers, f, indent=4)\n",
    "\n",
    "# Save results_rejection_by_test_results as JSON\n",
    "with open(\"results_rejection_by_test_results.json\", \"w\") as f:\n",
    "    json.dump(results_rejection_by_test_results, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5336b7e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Power run: null and alt differ in the mean of y (0.55 vs 0.58) and in the\n",
    "# p_y_x_1 parameter (0.5 vs 0.535). The results are pickled for later plotting.\n",
    "config = {\n",
    "    \"n_training\": 10,\n",
    "    \"n_test\": 90,\n",
    "    \"num_of_repititions\": 500,\n",
    "    \"alpha\": 0.05,\n",
    "    \"nsteps\": 500,\n",
    "    \"M\": 128,\n",
    "    \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "    \"p_y_null\": 0.55,  # Y's mean value under null\n",
    "    \"p_y_alt\": 0.58,  # Y's mean value under alt\n",
    "    \"p_y_x_0_null\": 0.8,\n",
    "    \"p_y_x_1_null\": 0.5,\n",
    "    \"p_y_x_0_alt\": 0.8,\n",
    "    \"p_y_x_1_alt\": 0.535,\n",
    "    \"history_length\": 0,\n",
    "    \"min_history_length\": 0,\n",
    "    \"settings\": \"math_shared_dependency\",\n",
    "}\n",
    "\n",
    "# The last two return values of run_simulation are not used in this notebook.\n",
    "with mp.Pool(processes=120) as pool:\n",
    "    powers, null_corr, alt_corr, rejection_by_test_results, _, _ = simulations.run_simulation(\n",
    "        config, pool\n",
    "    )\n",
    "\n",
    "# Bundle the outputs together with the config that produced them.\n",
    "data = {\n",
    "    \"rejection_lists\": rejection_by_test_results,\n",
    "    \"config\": config,\n",
    "    \"null_corr\": null_corr,\n",
    "    \"alt_corr\": alt_corr,\n",
    "    \"resulting_powers\": powers,\n",
    "}\n",
    "\n",
    "# Persist the run, creating the output directory if it does not exist yet.\n",
    "output_file = \"simulation_res/real_data/math_concept_shift/n_train_10_n_test_90_2.pkl\"\n",
    "path = Path(output_file)\n",
    "path.parent.mkdir(parents=True, exist_ok=True)\n",
    "with path.open(\"wb\") as f:\n",
    "    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5ba42a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the rejection results left in the kernel by the most recent\n",
    "# simulations.run_simulation call.\n",
    "rejection_by_test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843d6884",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the pickled power run and plot power vs. steps, saving the figure as a PDF.\n",
    "results_file = \"simulation_res/real_data/math_concept_shift/n_train_10_n_test_90_2.pkl\"\n",
    "figure_file = \"results/real_data/math_concept_shift/n_train_10_ntest_90_corr_0.2_power_vs_steps_2.pdf\"\n",
    "\n",
    "# pickle.load can execute arbitrary code; only load files written by this notebook.\n",
    "with open(results_file, \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "print(data[\"resulting_powers\"])\n",
    "# NOTE(review): the heading says Label shift while the paths say math_concept_shift -- confirm wording.\n",
    "print(\"Power vs Steps - Label shift, ntest=90, corr=-0.2\")\n",
    "\n",
    "# Make sure the figure directory exists before the plot is written to it.\n",
    "Path(figure_file).parent.mkdir(parents=True, exist_ok=True)\n",
    "# Trailing semicolon keeps the cell from echoing the call's return value.\n",
    "plot_power_vs_steps_paper(\n",
    "    data[\"rejection_lists\"],\n",
    "    save_path=figure_file,\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35a0ae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validity run: the null and alt parameters are identical here, so presumably the\n",
    "# rejection rate measures false positives -- TODO confirm. Results are pickled.\n",
    "config = {\n",
    "    \"n_training\": 10,\n",
    "    \"n_test\": 90,\n",
    "    \"num_of_repititions\": 500,\n",
    "    \"alpha\": 0.05,\n",
    "    \"nsteps\": 500,\n",
    "    \"M\": 128,\n",
    "    \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "    \"p_y_null\": 0.58,  # Y's mean value under null\n",
    "    \"p_y_alt\": 0.58,  # Y's mean value under alt\n",
    "    \"p_y_x_0_null\": 0.8,\n",
    "    \"p_y_x_1_null\": 0.535,\n",
    "    \"p_y_x_0_alt\": 0.8,\n",
    "    \"p_y_x_1_alt\": 0.535,\n",
    "    \"history_length\": 0,\n",
    "    \"min_history_length\": 0,\n",
    "    \"settings\": \"math_shared_dependency\",\n",
    "}\n",
    "\n",
    "# NOTE(review): this cell asks for 250 workers while the power run uses 120 -- confirm.\n",
    "with mp.Pool(processes=250) as pool:\n",
    "    powers, null_corr, alt_corr, rejection_by_test_results, _, _ = simulations.run_simulation(\n",
    "        config, pool\n",
    "    )\n",
    "\n",
    "# Bundle the outputs together with the config that produced them.\n",
    "data = {\n",
    "    \"rejection_lists\": rejection_by_test_results,\n",
    "    \"config\": config,\n",
    "    \"null_corr\": null_corr,\n",
    "    \"alt_corr\": alt_corr,\n",
    "    \"resulting_powers\": powers,\n",
    "}\n",
    "\n",
    "# Persist the run, creating the output directory if it does not exist yet.\n",
    "output_file = \"simulation_res/real_data/math_concept_shift/n_train_10_n_test_90_validity.pkl\"\n",
    "path = Path(output_file)\n",
    "path.parent.mkdir(parents=True, exist_ok=True)\n",
    "with path.open(\"wb\") as f:\n",
    "    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63e6f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot power vs. steps for the rejection lists in the current `data` dict.\n",
    "# No save_path is passed here.\n",
    "plot_power_vs_steps_paper(data[\"rejection_lists\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8c5194d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concept-shift power run with n_training=15 and n_test=75, followed by its plot.\n",
    "config = {\n",
    "    \"n_training\": 15,\n",
    "    \"n_test\": 75,\n",
    "    \"num_of_repititions\": 500,\n",
    "    \"alpha\": 0.05,\n",
    "    \"nsteps\": 500,\n",
    "    \"M\": 32,\n",
    "    \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "    \"p_y_null\": 0.55,  # Y's mean value under null\n",
    "    \"p_y_alt\": 0.58,  # Y's mean value under alt\n",
    "    \"p_y_x_0_null\": 0.8,\n",
    "    \"p_y_x_1_null\": 0.5,\n",
    "    \"p_y_x_0_alt\": 0.8,\n",
    "    \"p_y_x_1_alt\": 0.535,\n",
    "    \"history_length\": 0,\n",
    "    \"min_history_length\": 0,\n",
    "    \"settings\": \"math_shared_dependency\",\n",
    "}\n",
    "\n",
    "# Run the simulation for this config so the plot below reflects it, rather than\n",
    "# whatever `rejection_by_test_results` an earlier cell left in the kernel.\n",
    "with mp.Pool(processes=120) as pool:\n",
    "    powers, null_corr, alt_corr, rejection_by_test_results, *_ = simulations.run_simulation(\n",
    "        config, pool\n",
    "    )\n",
    "\n",
    "plot_power_vs_steps_concept_shift(\n",
    "    rejection_by_test_results,\n",
    "    \"Power vs Steps - Math concept shift ntest=75 ntrain=15 corr=-0.2\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78003fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concept-shift power run with n_training=15 and n_test=15, followed by its plot.\n",
    "config = {\n",
    "    \"n_training\": 15,\n",
    "    \"n_test\": 15,\n",
    "    \"num_of_repititions\": 500,\n",
    "    \"alpha\": 0.05,\n",
    "    \"nsteps\": 500,\n",
    "    \"M\": 32,\n",
    "    \"k_classes_x\": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)\n",
    "    \"p_y_null\": 0.55,  # Y's mean value under null\n",
    "    \"p_y_alt\": 0.58,  # Y's mean value under alt\n",
    "    \"p_y_x_0_null\": 0.8,\n",
    "    \"p_y_x_1_null\": 0.5,\n",
    "    \"p_y_x_0_alt\": 0.8,\n",
    "    \"p_y_x_1_alt\": 0.535,\n",
    "    \"history_length\": 0,\n",
    "    \"min_history_length\": 0,\n",
    "    \"settings\": \"math_shared_dependency\",\n",
    "}\n",
    "\n",
    "# Run the simulation for this config so the plot below reflects it, rather than\n",
    "# whatever `rejection_by_test_results` an earlier cell left in the kernel.\n",
    "with mp.Pool(processes=120) as pool:\n",
    "    powers, null_corr, alt_corr, rejection_by_test_results, *_ = simulations.run_simulation(\n",
    "        config, pool\n",
    "    )\n",
    "\n",
    "plot_power_vs_steps_concept_shift(\n",
    "    rejection_by_test_results,\n",
    "    \"Power vs Steps - Math concept shift ntest=15 ntrain=15 corr=-0.2\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f70a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot power vs. steps for whatever `rejection_by_test_results` currently holds.\n",
    "# NOTE(review): no cell in this notebook runs a config with n_test=30 -- confirm\n",
    "# that the title matches the data being plotted.\n",
    "plot_power_vs_steps_concept_shift(rejection_by_test_results, \"Power vs Steps - Math concept shift ntest=30\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fddb414",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_hypothesis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
