{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Re-rank GPT2-base answers to match reference ranking\n",
    "\n",
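    "This notebook reorders the answers file configured as `base_path` in the first code cell (currently `/hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl`) so that its records follow the `prompt` order of the reference file `/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl`. Base records whose prompt is not consumed by the reference order are appended at the end. The result is written next to the input, with the input's final extension replaced by `.reranked.jsonl`.\n"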
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "# Input, reference and output locations (edit these for a different run).\n",
    "base_path = Path('/hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl')\n",
    "ref_path = Path('/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl')\n",
    "# The output sits beside the input: same stem, final extension swapped for '.reranked.jsonl'.\n",
    "out_name = base_path.stem + '.reranked.jsonl'\n",
    "out_path = base_path.with_name(out_name)\n",
    "\n",
    "# Echo the resolved paths so the run is self-describing.\n",
    "for label, location in (('base_path', base_path), ('ref_path', ref_path), ('out_path', out_path)):\n",
    "    print(label, location)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the reference prompts in file order; this sequence defines the target ordering.\n",
    "ref_prompts = []\n",
    "with ref_path.open('r', encoding='utf-8') as ref_file:\n",
    "    # Blank lines carry no record and are skipped before JSON parsing.\n",
    "    non_blank_lines = (raw_line for raw_line in ref_file if raw_line.strip())\n",
    "    # Prompts are whitespace-stripped so they compare equal to the keys built from the base file;\n",
    "    # a record without a 'prompt' field contributes an empty string.\n",
    "    ref_prompts.extend(\n",
    "        json.loads(raw_line).get('prompt', '').strip()\n",
    "        for raw_line in non_blank_lines\n",
    "    )\n",
    "\n",
    "len_ref = len(ref_prompts)\n",
    "print(f'reference prompts: {len_ref}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every base record and bucket it by its whitespace-stripped prompt.\n",
    "# Buckets are lists so that repeated prompts keep all of their records, in file order.\n",
    "from collections import defaultdict\n",
    "\n",
    "base_items = []\n",
    "base_map = defaultdict(list)\n",
    "with base_path.open('r', encoding='utf-8') as base_file:\n",
    "    # Parse lazily, ignoring blank lines.\n",
    "    parsed_records = (json.loads(raw_line) for raw_line in base_file if raw_line.strip())\n",
    "    for record in parsed_records:\n",
    "        # A record without a 'prompt' field is filed under the empty string.\n",
    "        base_map[record.get('prompt', '').strip()].append(record)\n",
    "        base_items.append(record)\n",
    "\n",
    "print(f'base total: {len(base_items)} unique prompts: {len(base_map)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Walk the reference order and take the oldest unused base record for each prompt.\n",
    "# NOTE: this consumes base_map in place; the following cell relies on what is left in it.\n",
    "out_items, missing = [], []\n",
    "for ref_prompt in ref_prompts:\n",
    "    # .get() does not create an entry in the defaultdict for unknown prompts.\n",
    "    candidates = base_map.get(ref_prompt)\n",
    "    if not candidates:\n",
    "        # Prompt absent from the base file, or all of its records are already used.\n",
    "        missing.append(ref_prompt)\n",
    "        continue\n",
    "    out_items.append(candidates.pop(0))\n",
    "\n",
    "print(f'matched: {len(out_items)}, missing: {len(missing)}')\n",
    "if missing:\n",
    "    print('First 5 missing prompts:\\n', missing[:5])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append every base record that the reference order did not consume (prompts absent from the\n",
    "# reference, or surplus duplicates) so that the output keeps the full base record count.\n",
    "# Each bucket is drained as it is collected: re-running this cell then finds nothing left over\n",
    "# instead of appending the same leftovers to out_items a second time.\n",
    "remaining = []\n",
    "for leftovers in base_map.values():\n",
    "    remaining.extend(leftovers)\n",
    "    leftovers.clear()\n",
    "if remaining:\n",
    "    print(f'Appending {len(remaining)} remaining items to output')\n",
    "    out_items.extend(remaining)\n",
    "\n",
    "print(f'final out_items count: {len(out_items)} (should equal base total {len(base_items)})')\n",
    "# Stop here, before anything is written, if out-of-order or repeated cell runs left stale state.\n",
    "assert len(out_items) == len(base_items), 'out_items count differs from base total; restart the kernel and run all cells'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the reordered records as JSON Lines: one object per line, non-ASCII text kept verbatim.\n",
    "with out_path.open('w', encoding='utf-8') as sink:\n",
    "    sink.writelines(\n",
    "        json.dumps(record, ensure_ascii=False) + '\\n'\n",
    "        for record in out_items\n",
    "    )\n",
    "\n",
    "print('Wrote', out_path, 'lines=', len(out_items))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
