{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pickle\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled generations for the llama32medium run.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
    "responses_path = os.path.join(\"/home/anonymousanonymous/inference-rlhf/amlt\", \"llama32medium_responses.pkl\")\n",
    "with open(responses_path, \"rb\") as f:\n",
    "    response_data = pickle.load(f)\n",
    "\n",
    "# Every record in response_data is indexed by \"responses\", \"answers\" and \"results\";\n",
    "# split them into three dicts that share response_data's keys.\n",
    "all_responses = {key: record[\"responses\"] for key, record in response_data.items()}\n",
    "all_answers = {key: record[\"answers\"] for key, record in response_data.items()}\n",
    "all_results = {key: record[\"results\"] for key, record in response_data.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct answers recorded for each entry of all_answers.\n",
    "unique_answer_counts = [len(set(answers)) for answers in all_answers.values()]\n",
    "max_unique_answers = max(unique_answer_counts)\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "# Bin edges run one past the maximum count: matplotlib closes the last bin on\n",
    "# both sides, so edges that stop at the maximum would merge the counts\n",
    "# (max - 1) and max into a single bar (and leave a single edge when max == 1).\n",
    "plt.hist(unique_answer_counts,\n",
    "         bins=range(1, max_unique_answers + 2),\n",
    "         cumulative=True,\n",
    "         density=True,\n",
    "         color='#016EB7',  # Add color to make the histogram visible\n",
    "         edgecolor='black')  # Add edge color for better visibility\n",
    "plt.title(\"Cumulative distribution of unique answers\")\n",
    "plt.xlabel(\"Number of unique answers\")\n",
    "plt.ylabel(\"Cumulative density\")\n",
    "plt.savefig(\"unique_answers_cumulative_distribution_llama32medium.pdf\")\n",
    "# Close the figure so it is only written to disk, not rendered inline.\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaults:\n",
      "- _self_\n",
      "- user: ab\n",
      "- mode: normal\n",
      "- task: gsm8k\n",
      "- policy: gemma-2-2b\n",
      "- reward: oasst-rm\n",
      "- method: rejection\n",
      "amlt: false\n",
      "seed: 1337\n",
      "root: /home/anonymousanonymous/inference-rlhf/\n",
      "blob_root: /mnt/default/uploads/inference_pessimism\n",
      "hf_token_path: .hf_token\n",
      "debug: false\n",
      "io:\n",
      "  prefix: ${task.name}-${policy.name}-${mode.name}-shots-${shots}\n",
      "  overwrite: false\n",
      "  save_root: ${root}\n",
      "  load_root: ${root}\n",
      "shots: 0\n",
      "sampling:\n",
      "  seed: 1337\n",
      "evaluation:\n",
      "  collate_rewards: false\n",
      "  save_copy: true\n",
      "  include_prompt: false\n",
      "  batch_size: 16\n",
      "  max_response_length: ${task.generation.max_response_length}\n",
      "  max_prompt_length: ${task.generation.max_prompt_length}\n",
      "repeats: 50\n",
      "refresh_data: false\n",
      "use_multiprocessing: true\n",
      "max_threads: 44\n",
      "use_subsampling: true\n",
      "ks:\n",
      "  kmax: 13\n",
      "  kmin: 2\n",
      "  inc: 1.0\n",
      "betas:\n",
      "- 0.0001\n",
      "- 0.0005\n",
      "- 0.001\n",
      "- 0.005\n",
      "- 0.01\n",
      "- 0.05\n",
      "- 0.1\n",
      "- 0.5\n",
      "aoai:\n",
      "  api_version: '2024-10-21'\n",
      "  model_name: gpt-4o-mini\n",
      "  model_version: '2024-07-18'\n",
      "  instance: anonymousne/shared\n",
      "plot:\n",
      "  model_types:\n",
      "  - llama32medium\n",
      "  - mistral7b\n",
      "  root: ${user.root}/amlt\n",
      "  subsample_size: null\n",
      "  min_number_unique_answers: 100\n",
      "preprocess_amlt:\n",
      "  model_types:\n",
      "  - llama32medium\n",
      "  - mistral7b\n",
      "  root: ${user.root}/amlt\n",
      "policy:\n",
      "  name: llama-3-3b\n",
      "  model: meta-llama/Llama-3.2-3B-Instruct\n",
      "  INST: ' Report the final answer following the phrase ''The answer is''.'\n",
      "  replace_inst: false\n",
      "task:\n",
      "  name: math\n",
      "  max_samples: -1\n",
      "  data:\n",
      "    name: DigitalLearningGmbH/MATH-lighteval\n",
      "    subset: default\n",
      "    split: test\n",
      "    question_field: problem\n",
      "    answer_field: solution\n",
      "  generation:\n",
      "    max_response_length: 512\n",
      "    max_prompt_length: 1024\n",
      "  rstar_keys:\n",
      "  - correct\n",
      "  TASK_DESC: As an expert problem solver solve step by step the following mathematical\n",
      "    questions. Box your final answer using $\\boxed{}$.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from importlib import import_module\n",
    "import os\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Load all the config files\n",
    "config_dir = \"inference_rlhf/code/configs\"  # relative to the kernel's working directory; adjust path as needed\n",
    "config = OmegaConf.load(os.path.join(config_dir, \"master.yaml\"))\n",
    "\n",
    "# Load the task, policy and user config files used for this analysis\n",
    "task_config = OmegaConf.load(os.path.join(config_dir, \"task/math.yaml\"))\n",
    "policy_config = OmegaConf.load(os.path.join(config_dir, \"policy/llama-3-3b.yaml\"))  # or whatever policy you're using\n",
    "user_config = OmegaConf.load(os.path.join(config_dir, \"user/anonymousanonymous.yaml\"))\n",
    "\n",
    "# Merge configs\n",
    "cfg = OmegaConf.merge(config, user_config)\n",
    "\n",
    "# Set policy and task explicitly in the config structure\n",
    "cfg.policy = OmegaConf.create(policy_config)\n",
    "cfg.task = OmegaConf.create(task_config)\n",
    "\n",
    "# Set any additional config values you need\n",
    "cfg.shots = 0  # for example\n",
    "cfg.sampling = OmegaConf.create({\"seed\": 1337})  # match the command line format\n",
    "\n",
    "print(OmegaConf.to_yaml(cfg))\n",
    "\n",
    "# Now you can use the config to load the data module.\n",
    "# The module name is absolute, so import_module needs no `package` anchor.\n",
    "data_module = import_module(\"inference_rlhf.code.tasks.math\")\n",
    "dl = data_module.DataLoader(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Problem 17:\n",
      "Problem: If $A$, $B$ and $C$ are positive integers such that $\\frac{A\\sqrt{B}}{C} = \\frac{9}{2\\sqrt{3}}$, what is the value of $A+B+C$ given that $A$ and $C$ have no common prime factors, and $B$ has no perfect-square factors other than 1?\n",
      "\n",
      "\n",
      "Unique answers:\n",
      "137 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 561, 51, 52, 50, 54, 55, 56, 57, 53, 59, 61, 63, 64, 66, 69, 72, 75, 76, 78, 79, 80, 81, 82, 83, 85, 86, 87, 88, 90, 91, 92, 93, 94, 97, 98, 99, 100, 613, 103, 104, 105, 106, 108, 111, 117, 120, 121, 122, 125, 129, 132, 648, 142, 146, 151, 153, 162, 164, 166, 59049, 171, 177, 187, 188, 198, 203, 729, 19683, 243, 292068083, 254, 257, 270, 2835, 306, 312, 828, 171027, 324, 333, 354, 2916, 359, 888, 416, 46410162, 451, 461}\n",
      "\n",
      "\n",
      "Unique answer indices:\n",
      "137 [0, 1, 2, 4, 5, 6, 7, 9, 10, 12, 15, 17, 18, 20, 27, 29, 30, 31, 33, 36, 41, 46, 47, 53, 58, 59, 69, 72, 74, 75, 76, 85, 89, 96, 99, 100, 111, 114, 140, 148, 152, 178, 201, 207, 209, 212, 221, 227, 240, 266, 273, 276, 287, 298, 322, 325, 327, 340, 351, 363, 366, 371, 429, 430, 432, 440, 455, 467, 482, 493, 502, 510, 519, 527, 537, 549, 610, 643, 681, 726, 748, 755, 775, 814, 863, 914, 927, 976, 982, 985, 1008, 1055, 1119, 1128, 1138, 1143, 1150, 1185, 1187, 1192, 1213, 1244, 1305, 1432, 1443, 1449, 1451, 1455, 1456, 1640, 1711, 1726, 1918, 1930, 2015, 2060, 2135, 2143, 2211, 2227, 2234, 2261, 2315, 2369, 2640, 2654, 3411, 3483, 3675, 3678, 3981, 3984, 4364, 4477, 5216, 5233, 5900]\n",
      "\n",
      "\n",
      "Pass@k:\n",
      "[0.0424905]\n",
      "\n",
      "\n",
      "Response:\n",
      "To solve for A, B, and C, we can manipulate the given equation to get:\n",
      "\n",
      "$\\frac{A\\sqrt{B}}{C} = \\frac{9}{2\\sqrt{3}}$\n",
      "\n",
      "We can cross multiply to get:\n",
      "\n",
      "$A\\sqrt{B} \\times 2\\sqrt{3} = 9C$\n",
      "\n",
      "This can be rewritten as:\n",
      "\n",
      "$2A\\sqrt{B \\times 3} = 9C$\n",
      "\n",
      "Now, to make the equation true, A$ \\times 3 should equal 9 since they are integers and B is an integer and has no other factors.  Thus, $A = 3, B=3 and we can find C as follows:\n",
      "\n",
      "$2 \\times 3 = 9 \\times C$\n",
      "\n",
      "We can now divide both sides by 9 to obtain $C = \\frac{3}{3}$.  $C=1$\n",
      "Thus $A + B + C = 3 + 3 + 1 = 7$\n",
      "\n",
      "The final answer is 7\n",
      "\n",
      "\n",
      "Solution:\n",
      "Multiplying the numerator and denominator of the right hand side of the given equation by $\\sqrt{3}$, we have \\[\\frac{A\\sqrt{B}}{C}=\\frac{9}{2\\sqrt{3}}\\cdot\\frac{\\sqrt{3}}{\\sqrt{3}}=\\frac{9\\sqrt{3}}{6}=\\frac{3\\sqrt{3}}{2}\\] Thus, $A=3$, $B=3$, and $C=2$, so $A+B+C=3+3+2=\\boxed{8}$.\n"
     ]
    }
   ],
   "source": [
    "from inference_rlhf.code.helpers.utils import estimate_pass_at_k\n",
    "\n",
    "# Which problem, and which of its sampled responses, to inspect.\n",
    "PROBLEM_IDX = 256\n",
    "RESPONSE_IDX = 0\n",
    "\n",
    "print(f\"Problem {PROBLEM_IDX}:\")\n",
    "print(f\"Problem: {dl.questions[PROBLEM_IDX]}\")\n",
    "\n",
    "print('\\n\\nUnique answers:')\n",
    "problem_answers = all_answers[PROBLEM_IDX]\n",
    "unique_answers = set(problem_answers)\n",
    "print(len(unique_answers), unique_answers)\n",
    "\n",
    "print('\\n\\nUnique answer indices:')\n",
    "# setdefault keeps the first index at which each distinct answer appears;\n",
    "# dicts preserve insertion order, so the indices come out ascending.\n",
    "first_seen = {}\n",
    "for i, answer in enumerate(problem_answers):\n",
    "    first_seen.setdefault(answer, i)\n",
    "idxs = list(first_seen.values())\n",
    "print(len(idxs), idxs)\n",
    "\n",
    "print('\\n\\nPass@k:')\n",
    "# Single-problem call: sample count, sum of the results, and k = 1.\n",
    "# NOTE(review): presumably results are 0/1 correctness flags - confirm.\n",
    "problem_results = all_results[PROBLEM_IDX]\n",
    "print(estimate_pass_at_k([len(problem_results)], [sum(problem_results)], 1))\n",
    "\n",
    "print('\\n\\nResponse:')\n",
    "print(all_responses[PROBLEM_IDX][RESPONSE_IDX])\n",
    "\n",
    "print('\\n\\nSolution:')\n",
    "print(dl.raw_answers[PROBLEM_IDX])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-exploration",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
