{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "def merge_json_files(file_list):\n",
    "    \"\"\"Concatenate the top-level JSON arrays stored in several files.\n",
    "\n",
    "    Args:\n",
    "        file_list: iterable of paths; each file must hold a JSON array.\n",
    "\n",
    "    Returns:\n",
    "        A single list with the elements of every file, in file order.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if a file's top-level value is not a JSON array. Extending a\n",
    "            list with a JSON object would silently merge only its keys.\n",
    "    \"\"\"\n",
    "    merged_data = []\n",
    "    for path in file_list:\n",
    "        with open(path, 'r', encoding='utf-8') as f:\n",
    "            data = json.load(f)\n",
    "        if not isinstance(data, list):\n",
    "            raise TypeError(f\"{path}: expected a top-level JSON array, got {type(data).__name__}\")\n",
    "        merged_data.extend(data)\n",
    "    return merged_data\n",
    "\n",
    "# Concatenate the Qwen and PRM800K DR training sets into a single file.\n",
    "# NOTE(review): these ./data/... paths assume a different working directory\n",
    "# from the ../train/... paths used in the cells below.\n",
    "files = [\n",
    "    \"./data/train/qwen/qwen_train_DR.json\",\n",
    "    \"./data/train/prm800k/prm_800k_orm_A_DR.json\",\n",
    "]\n",
    "merged_json = merge_json_files(files)\n",
    "\n",
    "goal_name = \"./data/train/qwen/mix_DR.json\"\n",
    "with open(goal_name, 'w') as json_file:\n",
    "    json.dump(merged_json, json_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "98322\n",
      "12060\n",
      "110382\n",
      "99231\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "import json\n",
    "\n",
    "def read_json(path):\n",
    "    \"\"\"Load and return the JSON document stored at `path` (decoded as UTF-8).\"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        data = json.load(f)\n",
    "    return data\n",
    "\n",
    "def keep_same_num_merge(prm_path, qwen_path, verbose=True):\n",
    "    \"\"\"Merge Qwen and PRM800K samples, letting Qwen samples stand in for PRM ones.\n",
    "\n",
    "    Samples are grouped by question (the content of their first message). The\n",
    "    result holds every Qwen sample, plus, for each question, only the PRM\n",
    "    samples in excess of that question's Qwen count. PRM questions without any\n",
    "    Qwen sample are kept in full, so a question shared by both sources ends up\n",
    "    with max(#qwen, #prm) samples.\n",
    "\n",
    "    Args:\n",
    "        prm_path: path to a JSON array of PRM800K samples.\n",
    "        qwen_path: path to a JSON array of Qwen samples.\n",
    "        verbose: if true (default), print the PRM size, the Qwen size, their\n",
    "            sum and the merged size.\n",
    "\n",
    "    Returns:\n",
    "        A new list: all Qwen samples in input order, followed by the kept PRM\n",
    "        samples in input order.\n",
    "    \"\"\"\n",
    "    from collections import Counter\n",
    "\n",
    "    prm_datas = read_json(prm_path)\n",
    "    qwen_datas = read_json(qwen_path)\n",
    "\n",
    "    if verbose:\n",
    "        print(len(prm_datas))\n",
    "        print(len(qwen_datas))\n",
    "        print(len(prm_datas) + len(qwen_datas))\n",
    "    # Remaining Qwen quota per question; each PRM sample for that question uses one up.\n",
    "    qwen_questions = Counter(data[\"messages\"][0][\"content\"] for data in qwen_datas)\n",
    "\n",
    "    # A shallow copy is enough: qwen_datas was loaded in this call and is not shared.\n",
    "    new_datas = list(qwen_datas)\n",
    "    for data in prm_datas:\n",
    "        question = data[\"messages\"][0][\"content\"]\n",
    "        if qwen_questions[question] > 0:\n",
    "            qwen_questions[question] -= 1\n",
    "            continue\n",
    "        new_datas.append(data)\n",
    "    if verbose:\n",
    "        print(len(new_datas))\n",
    "    return new_datas\n",
    "\n",
    "# Alternative ./data/... locations (kept for reference):\n",
    "#   prm:  ./data/train/prm800k/prm_800k_orm_A.json\n",
    "#   qwen: ./data/train/qwen/qwen_train_DR.json\n",
    "#   out:  ./data/train/qwen/fair_mix_DR.json\n",
    "prm_path = \"../train/prm800k/prm_800k_orm_A.json\"\n",
    "qwen_path = \"../train/qwen/qwen_train_DR.json\"\n",
    "goal_name = \"../train/qwen/fair_mix_DR.json\"\n",
    "merged_json = keep_same_num_merge(prm_path, qwen_path)\n",
    "with open(goal_name, 'w') as json_file:\n",
    "    json.dump(merged_json, json_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "312\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "\n",
    "json_file_path = \"../special tokens.json\"\n",
    "# NOTE(review): `data` (and `nat_key`) are not used anywhere else in this cell;\n",
    "# the step_val labels below are hard-coded. Their spellings (\"postive\",\n",
    "# \"negtive\") are written verbatim into the output JSON, so correcting them\n",
    "# would also require updating whatever consumes that field.\n",
    "with open(json_file_path, 'r') as file:\n",
    "    data = json.load(file)\n",
    "pos_key, neg_key, nat_key = \"postive\", \"negtive\", \"natural\"\n",
    "\n",
    "def load_question(path):\n",
    "    \"\"\"Read a JSONL file and return its distinct problems with their answers.\n",
    "\n",
    "    Each non-blank line must be a JSON object with\n",
    "    entry[\"question\"][\"problem\"] and entry[\"question\"][\"ground_truth_answer\"].\n",
    "\n",
    "    Returns:\n",
    "        (questions, answers): parallel lists in first-seen order. When a\n",
    "        problem occurs on several lines, the answer from its first line is kept.\n",
    "    \"\"\"\n",
    "    questions = []\n",
    "    answers = []\n",
    "    seen = set()  # O(1) duplicate check instead of scanning `questions` per line\n",
    "    with open(path, 'r', encoding='utf-8') as file:\n",
    "        for line in file:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines (e.g. an empty last line)\n",
    "            entry = json.loads(line)\n",
    "            question = entry[\"question\"][\"problem\"]\n",
    "            answer = entry[\"question\"][\"ground_truth_answer\"]\n",
    "            if question not in seen:\n",
    "                seen.add(question)\n",
    "                questions.append(question)\n",
    "                answers.append(answer)\n",
    "    return questions, answers\n",
    "\n",
    "def handle_format(question, solution, label, generation_fromat):\n",
    "    \"\"\"Wrap one (question, solution) pair as a chat-style training record.\n",
    "\n",
    "    Args:\n",
    "        question: text of the user turn.\n",
    "        solution: text of the assistant turn.\n",
    "        label: truthy when the solution is considered correct.\n",
    "        generation_fromat: when truthy, also append a user turn asking\n",
    "            \"Is the answer correct (Yes/No)?\" and an assistant reply of\n",
    "            \"Yes\" or \"No\" according to `label`.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys \"messages\", \"label\" (passed through unchanged) and\n",
    "        \"step_val\" (the global pos_key if `label` is truthy, else neg_key).\n",
    "    \"\"\"\n",
    "    turns = [\n",
    "        {\"content\": question, \"role\": \"user\"},\n",
    "        {\"content\": solution, \"role\": \"assistant\"},\n",
    "    ]\n",
    "    if generation_fromat:\n",
    "        turns += [\n",
    "            {\"content\": \"Is the answer correct (Yes/No)?\", \"role\": \"user\"},\n",
    "            # simplify selection score\n",
    "            {\"content\": \"Yes\" if label else \"No\", \"role\": \"assistant\"},\n",
    "        ]\n",
    "    return {\n",
    "        \"messages\": turns,\n",
    "        \"label\": label,\n",
    "        \"step_val\": pos_key if label else neg_key,\n",
    "    }\n",
    "\n",
    "df = pd.read_parquet('train-00000-of-00001.parquet')\n",
    "train_question, _ = load_question(\"../prm_800k_handler/data/phase2_train.jsonl\")\n",
    "# Problems present in the PRM800K phase-2 training split are skipped (set => O(1) lookups).\n",
    "train_question_set = set(train_question)\n",
    "max_length = 1024  # measured with len(), i.e. characters, not tokens\n",
    "N_size = 8  # responses kept per problem; problems with fewer short responses are skipped\n",
    "N_responses = 32  # responses are stored in columns response_1 ... response_32\n",
    "generation_format = True  # forwarded to handle_format (adds the Yes/No judgement turn)\n",
    "goal_name = \"../train/qwen/BON_D_test.json\"\n",
    "\n",
    "write_datas = []\n",
    "for row in df.itertuples():\n",
    "    tques = row.problem\n",
    "    if tques in train_question_set:\n",
    "        continue\n",
    "    # Keep each response paired with its own score, so skipping long responses\n",
    "    # cannot shift which score is attached to which response.\n",
    "    # Assumes scores[k-1] is the score of column response_k -- TODO confirm.\n",
    "    pairs = []\n",
    "    for k, score in zip(range(1, N_responses + 1), row.scores):\n",
    "        res = getattr(row, \"response_\" + str(k))\n",
    "        if len(res) >= max_length:\n",
    "            continue\n",
    "        pairs.append((res, score))\n",
    "    if len(pairs) < N_size:\n",
    "        continue\n",
    "    for res, score in pairs[:N_size]:\n",
    "        write_datas.append(handle_format(tques, res, bool(score), generation_format))\n",
    "\n",
    "print(len(write_datas))\n",
    "with open(goal_name, 'w') as json_file:\n",
    "    json.dump(write_datas, json_file, indent=2)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_fac",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
