{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "tqdm.pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load raw data and convert to training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import json\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "FILE_PATHS = [\n",
    "    'YOURPATH-no-hint-train-t05-run_1/output.with_completions.jsonl.gz',\n",
    "    'YOURPATH-no-hint-train-t05-run_2/output.with_completions.jsonl.gz',\n",
    "]\n",
    "\n",
    "# Collect one plain dict per JSONL record and build the DataFrame once at the\n",
    "# end, keeping only the fields that are read from each record below.\n",
    "data = []\n",
    "\n",
    "for FILE_PATH in FILE_PATHS:\n",
    "    # json.loads accepts bytes, so the gzip stream is read in binary mode\n",
    "    with gzip.open(FILE_PATH, 'rb') as f:\n",
    "        for line in tqdm(f, desc=f\"Processing {FILE_PATH.split('/')[-1]}\"):\n",
    "            if not line.strip():\n",
    "                # Tolerate blank lines instead of failing inside json.loads\n",
    "                continue\n",
    "            raw_data = json.loads(line)\n",
    "            # Look the completions up once; None (or absent) means there is\n",
    "            # nothing to extract for this record\n",
    "            raw_completions = raw_data.get('raw_completions')\n",
    "            has_completions = raw_completions is not None\n",
    "            data.append(\n",
    "                {\n",
    "                    'resolved': raw_data['report']['resolved'],\n",
    "                    'messages': raw_completions['messages']\n",
    "                    if has_completions\n",
    "                    else None,\n",
    "                    'git_patch': raw_data['test_result'].get('git_patch', ''),\n",
    "                    'tools': raw_completions.get('tools')\n",
    "                    if has_completions\n",
    "                    else None,\n",
    "                }\n",
    "            )\n",
    "\n",
    "# Convert to DataFrame after collecting all data\n",
    "df = pd.DataFrame(data)\n",
    "print(f'#total amount of data={len(df)}')\n",
    "# .copy() makes the filtered frame independent of the unfiltered one, so the\n",
    "# column assignments in later cells do not raise SettingWithCopyWarning.\n",
    "df = df[~df['messages'].isna()].copy()\n",
    "print(f'#total amount of data after removing nan={len(df)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _contains_multiple_tool_calls(messages: list[dict]) -> bool:\n",
    "    \"\"\"Report whether any message carries more than one tool call.\n",
    "\n",
    "    A message counts only when its 'tool_calls' entry is truthy and holds\n",
    "    at least two items; messages without the key are ignored.\n",
    "    \"\"\"\n",
    "    for msg in messages:\n",
    "        calls = msg.get('tool_calls')\n",
    "        if calls and len(calls) > 1:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "\n",
    "# Flag every row whose messages include a message with several tool calls,\n",
    "# then show how many resolved rows fall on each side of the flag.\n",
    "multi_call_flags = df['messages'].apply(_contains_multiple_tool_calls)\n",
    "df['contains_multiple_tool_calls'] = multi_call_flags\n",
    "resolved_per_flag = df.groupby(['contains_multiple_tool_calls'])['resolved'].sum()\n",
    "display(resolved_per_flag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Convert function calling messages to non-function calling messages\n",
    "from openhands.llm.fn_call_converter import (\n",
    "    FunctionCallConversionError,\n",
    "    convert_fncall_messages_to_non_fncall_messages,\n",
    "    convert_from_multiple_tool_calls_to_single_tool_call_messages,\n",
    ")\n",
    "\n",
    "# Number of conversations the converter rejected; reset each time this cell runs\n",
    "total_failed = 0\n",
    "\n",
    "\n",
    "def _convert_messages(messages: list[dict], tools: list[dict]) -> list[dict] | None:\n",
    "    \"\"\"Convert function-calling messages to non-function-calling messages.\n",
    "\n",
    "    Works on a deep copy, so `messages` itself is left untouched. Any message\n",
    "    whose 'content' is missing or None gets an empty string before conversion.\n",
    "\n",
    "    Returns:\n",
    "        The converted messages, or None when the converter raises\n",
    "        FunctionCallConversionError (the global `total_failed` counter is\n",
    "        incremented in that case).\n",
    "    \"\"\"\n",
    "    global total_failed\n",
    "    message_copy = copy.deepcopy(messages)\n",
    "    for message in message_copy:\n",
    "        # .get() also covers messages that carry no 'content' key at all\n",
    "        if message.get('content') is None:\n",
    "            message['content'] = ''\n",
    "    try:\n",
    "        return convert_fncall_messages_to_non_fncall_messages(\n",
    "            message_copy, tools, add_in_context_learning_example=False\n",
    "        )\n",
    "    except FunctionCallConversionError:\n",
    "        # Count the failure and let the caller drop the row via the None result\n",
    "        total_failed += 1\n",
    "        return None\n",
    "\n",
    "\n",
    "# Only the 'messages' column is needed here, so apply on the Series instead\n",
    "# of building a row object for every record.\n",
    "df['converted_messages'] = df['messages'].apply(\n",
    "    lambda messages: convert_from_multiple_tool_calls_to_single_tool_call_messages(\n",
    "        messages, ignore_final_tool_result=True\n",
    "    )\n",
    ")\n",
    "# This step needs two columns per row, hence the row-wise apply\n",
    "df['nonfncall_messages'] = df.apply(\n",
    "    lambda row: _convert_messages(row['converted_messages'], row['tools']), axis=1\n",
    ")\n",
    "print('total nan', df['nonfncall_messages'].isna().sum())\n",
    "# Drop rows whose conversion failed; .copy() keeps later column assignments\n",
    "# from writing into a slice of the unfiltered frame.\n",
    "df = df[~df['nonfncall_messages'].isna()].copy()\n",
    "print(f'Total failed: {total_failed}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandarallel import pandarallel\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Turn off the tokenizers library's own parallelism before the tokenizer is\n",
    "# used from pandarallel's worker processes.\n",
    "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n",
    "pandarallel.initialize(progress_bar=True, verbose=1, nb_workers=16)\n",
    "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct')\n",
    "# Count tokens on the converted conversations that are exported for SFT at\n",
    "# the end of the notebook. No cell creates an 'rm_conv' column, so reading it\n",
    "# raised KeyError on a fresh kernel.\n",
    "df['n_tokens'] = df['nonfncall_messages'].parallel_apply(\n",
    "    lambda x: len(tokenizer.apply_chat_template(x))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'BEFORE: #total={len(df)}')\n",
    "# 131072 = 128 * 1024 tokens. Rows at or above the limit are dropped whole;\n",
    "# no conversation is shortened.\n",
    "within_limit = df['n_tokens'] < 131072\n",
    "df_selected = df[within_limit]\n",
    "print(f'AFTER(truncated to 128k): #total={len(df_selected)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics of the token counts that passed the length filter\n",
    "df_selected.loc[:, 'n_tokens'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ecdf of n_tokens\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "def _show_token_distribution(frame: pd.DataFrame, title: str) -> None:\n",
    "    \"\"\"Display n_tokens statistics per `resolved` value and the matching ECDF.\"\"\"\n",
    "    display(frame.groupby(['resolved'])['n_tokens'].describe())\n",
    "    ax = sns.ecdfplot(x='n_tokens', data=frame, hue='resolved')\n",
    "    # Title each figure so the two plots can be told apart when skimming\n",
    "    ax.set_title(title)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "_show_token_distribution(df, 'ECDF of n_tokens (all rows)')\n",
    "\n",
    "print(f'#total={len(df)}')\n",
    "df_selected = df[df['n_tokens'] < 131072]\n",
    "print(f'#selected={len(df_selected)}')\n",
    "_show_token_distribution(df_selected, 'ECDF of n_tokens (n_tokens < 131072)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token-count statistics restricted to the selected rows that are not resolved\n",
    "df_selected.loc[~df_selected['resolved'], 'n_tokens'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many selected rows are resolved versus not resolved\n",
    "df_selected.loc[:, 'resolved'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token-count statistics of the selected rows, split by `resolved`\n",
    "(\n",
    "    df_selected\n",
    "    .groupby(['resolved'])['n_tokens']\n",
    "    .describe()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Resolved Messages for SFT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT_DIR = 'YOUR_OUTPUT_FOLDER'\n",
    "# Create the target folder up front so to_json does not fail on a missing path\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "resolved_mask = df_selected['resolved']\n",
    "# Summing the boolean mask gives 0 when nothing is resolved, whereas\n",
    "# value_counts()[True] raises KeyError in that case.\n",
    "n_resolved = int(resolved_mask.sum())\n",
    "output_path = os.path.join(\n",
    "    OUTPUT_DIR, f'policy_traj_128k_swegym_{n_resolved}i.jsonl'\n",
    ")\n",
    "\n",
    "# Write one JSON object per line, holding only the converted conversation\n",
    "# under the 'messages' key.\n",
    "df_selected.loc[resolved_mask, ['nonfncall_messages']].rename(\n",
    "    columns={'nonfncall_messages': 'messages'}\n",
    ").to_json(output_path, lines=True, orient='records')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openhands-ai-CPy6G0pU-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
