{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import yaml\n",
    "\n",
    "# Add an empty \"system_prompt\" variable to every *_opt.yml config whose\n",
    "# \"prompt_pattern\" list contains \"react\". A file is rewritten only when the key\n",
    "# is actually added: yaml.safe_dump drops comments and re-formats the file.\n",
    "\n",
    "# Materialize (and de-duplicate) the matches before writing anything so the\n",
    "# directory walk is never interleaved with file writes; sorted for a stable log.\n",
    "yaml_files = sorted(set(Path(\".\").rglob(\"*_opt.yml\")))\n",
    "\n",
    "for yaml_file in yaml_files:\n",
    "    with yaml_file.open(\"r\", encoding=\"utf-8\") as f:\n",
    "        data = yaml.safe_load(f)\n",
    "\n",
    "    # safe_load returns None for an empty file; skip anything that is not a mapping\n",
    "    if not isinstance(data, dict):\n",
    "        continue\n",
    "\n",
    "    # Only files that already contain a \"variables\" mapping are considered\n",
    "    variables = data.get(\"variables\")\n",
    "    if not isinstance(variables, dict):\n",
    "        continue\n",
    "\n",
    "    # Only ReAct configs (\"react\" in prompt_pattern) need a system prompt variable\n",
    "    prompt_patterns = variables.get(\"prompt_pattern\", [])\n",
    "    if not (isinstance(prompt_patterns, list) and \"react\" in prompt_patterns):\n",
    "        continue\n",
    "\n",
    "    # Already present: leave the file untouched instead of rewriting it unchanged\n",
    "    if \"system_prompt\" in variables:\n",
    "        continue\n",
    "\n",
    "    variables[\"system_prompt\"] = []\n",
    "    with yaml_file.open(\"w\", encoding=\"utf-8\") as f:\n",
    "        yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False)\n",
    "\n",
    "    print(f\"Updated: {yaml_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: fever/granite_13b_instruct_fever_opt.yml\n",
      "Saved: fever/granite_13b_instruct_fever_opt.yml\n",
      "Saved: fever/granite_3_8b_instruct_fever_opt.yml\n",
      "Saved: fever/granite_3_8b_instruct_fever_opt.yml\n",
      "Saved: fever/llama_8b_fever_opt.yml\n",
      "Saved: fever/llama_8b_fever_opt.yml\n",
      "Saved: fever/granite_20b_code_instruct_fever_opt.yml\n",
      "Saved: fever/granite_20b_code_instruct_fever_opt.yml\n",
      "Saved: fever/llama_70b_fever_opt.yml\n",
      "Saved: fever/llama_70b_fever_opt.yml\n",
      "Saved: fever/granite_34b_code_instruct_fever_opt.yml\n",
      "Saved: fever/granite_34b_code_instruct_fever_opt.yml\n",
      "Saved: evalplus/granite_13b_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/granite_13b_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/granite_3_8b_evalplus_opt.yml\n",
      "Saved: evalplus/granite_3_8b_evalplus_opt.yml\n",
      "Saved: evalplus/llama_70b_evalplus_opt.yml\n",
      "Saved: evalplus/llama_70b_evalplus_opt.yml\n",
      "Saved: evalplus/granite_34b_code_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/granite_34b_code_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/granite_20b_code_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/granite_20b_code_instruct_evalplus_opt.yml\n",
      "Saved: evalplus/llama_8b_evalplus_opt.yml\n",
      "Saved: evalplus/llama_8b_evalplus_opt.yml\n",
      "Saved: gsm8k/granite_13b_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_13b_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_3_8b_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/llama_8b_gsm8k_opt.yml\n",
      "Saved: gsm8k/llama_8b_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_20b_code_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_20b_code_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/llama_70b_gsm8k_opt.yml\n",
      "Saved: gsm8k/llama_70b_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_34b_code_instruct_gsm8k_opt.yml\n",
      "Saved: gsm8k/granite_34b_code_instruct_gsm8k_opt.yml\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import yaml\n",
    "\n",
    "# Normalize every *_opt.yml config in one pass:\n",
    "#   * ReAct configs (\"react\" in prompt_pattern) get an empty \"system_prompt\" variable;\n",
    "#   * every config without \"num_demonstrations\" gets DEFAULT_NUM_DEMONSTRATIONS.\n",
    "# Edits are collected first and the file is written once, only if something changed\n",
    "# (yaml.safe_dump drops comments and re-formats the file).\n",
    "DEFAULT_NUM_DEMONSTRATIONS = [0, 3, 5]\n",
    "\n",
    "# Materialize and de-duplicate the matches before any writes (an earlier run of this\n",
    "# cell reported most files twice); sorted for a deterministic log order.\n",
    "yaml_files = sorted(set(Path(\".\").rglob(\"*_opt.yml\")))\n",
    "\n",
    "for yaml_file in yaml_files:\n",
    "    with yaml_file.open(\"r\", encoding=\"utf-8\") as f:\n",
    "        data = yaml.safe_load(f)\n",
    "\n",
    "    # safe_load returns None for an empty file; skip anything that is not a mapping\n",
    "    if not isinstance(data, dict):\n",
    "        continue\n",
    "\n",
    "    # Edit the mapping stored inside `data`: a detached `{}` default would swallow\n",
    "    # the changes below while the file was still rewritten and reported as saved.\n",
    "    variables = data.get(\"variables\")\n",
    "    if not isinstance(variables, dict):\n",
    "        continue\n",
    "\n",
    "    changed = False\n",
    "\n",
    "    prompt_patterns = variables.get(\"prompt_pattern\", [])\n",
    "    is_react = isinstance(prompt_patterns, list) and \"react\" in prompt_patterns\n",
    "    if is_react and \"system_prompt\" not in variables:\n",
    "        variables[\"system_prompt\"] = []\n",
    "        changed = True\n",
    "\n",
    "    if \"num_demonstrations\" not in variables:\n",
    "        variables[\"num_demonstrations\"] = list(DEFAULT_NUM_DEMONSTRATIONS)  # fresh copy per file\n",
    "        changed = True\n",
    "\n",
    "    # Persist every change above (previously system_prompt was saved only when\n",
    "    # num_demonstrations also happened to be missing)\n",
    "    if changed:\n",
    "        with yaml_file.open(\"w\", encoding=\"utf-8\") as f:\n",
    "            yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False)\n",
    "        print(f\"Saved: {yaml_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: gsmhard/granite_13b_instruct_gsmhard_opt.yml\n",
      "Saved: gsmhard/granite_3_8b_instruct_gsmhard_opt.yml\n",
      "Saved: gsmhard/granite_13b_instruct_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/llama_8b_gsmhard_opt.yml\n",
      "Saved: gsmhard/granite_20b_code_instruct_gsmhard_opt.yml\n",
      "Saved: gsmhard/llama_70b_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/llama_8b_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/granite_3_8b_instruct_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/llama_70b_gsmhard_opt.yml\n",
      "Saved: gsmhard/granite_34b_code_instruct_gsmhard_opt.yml\n",
      "Saved: gsmhard/granite_code_instruct_34b_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/granite_20b_code_instruct_gsmhard_zero_shot.yml\n",
      "Saved: gsmhard/granite_34b_code_instruct_gsmhard_zero_shot.yml\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import yaml\n",
    "\n",
    "# Derive GSM-Hard configs from the GSM8K ones: each gsm8k/*.yml whose stem ends in\n",
    "# \"zero_shot\" or \"opt\" is written to gsmhard/ with \"gsm8k\" replaced by \"gsmhard\"\n",
    "# in the file name and in \"experiment_prefix\", and with \"benchmark\" set to \"gsmhard\".\n",
    "SOURCE_DIR = Path(\"gsm8k\")\n",
    "gsmhard_dir = Path(\"gsmhard\")\n",
    "# Opening a file inside a missing directory fails, so create the target first.\n",
    "gsmhard_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for yaml_file in sorted(SOURCE_DIR.glob(\"*.yml\")):\n",
    "    if not yaml_file.stem.endswith((\"zero_shot\", \"opt\")):\n",
    "        continue\n",
    "    with yaml_file.open(\"r\", encoding=\"utf-8\") as f:\n",
    "        data = yaml.safe_load(f)\n",
    "\n",
    "    # safe_load returns None for an empty file; skip anything that is not a mapping\n",
    "    if not isinstance(data, dict):\n",
    "        continue\n",
    "\n",
    "    data[\"benchmark\"] = \"gsmhard\"\n",
    "    # Configs without a string experiment_prefix are copied with that field left as-is\n",
    "    prefix = data.get(\"experiment_prefix\")\n",
    "    if isinstance(prefix, str):\n",
    "        data[\"experiment_prefix\"] = prefix.replace(\"gsm8k\", \"gsmhard\")\n",
    "\n",
    "    yaml_file_new = gsmhard_dir / yaml_file.name.replace(\"gsm8k\", \"gsmhard\")\n",
    "    with yaml_file_new.open(\"w\", encoding=\"utf-8\") as f:\n",
    "        yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False)\n",
    "\n",
    "    print(f\"Saved: {yaml_file_new}\")\n",
    "    # print(f\"python -m pdl.optimize.optimize --config exp_configs/{yaml_file_new} examples/prompt_library/exp/gsm8k/general.pdl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pdlnew",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
