{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Re-rank GPT2-base answers to match reference ranking\n",
    "\n",
    "This notebook reorders `/hy-tmp/dc/processed_data/dolly/full/14290/answers_with_metrics_14290.new.jsonl` so its `prompt` order matches the reference file `/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl`. The output writes a reranked file `*.reranked.jsonl`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:31:59.934488Z",
     "iopub.status.busy": "2025-08-18T04:31:59.934303Z",
     "iopub.status.idle": "2025-08-18T04:31:59.941852Z",
     "shell.execute_reply": "2025-08-18T04:31:59.941374Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base_path /hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl\n",
      "ref_path /hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl\n",
      "out_path /hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.reranked.jsonl\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "# Paths (edit if necessary)\n",
    "base_path = Path('/hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl')\n",
    "ref_path = Path('/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl')\n",
    "out_path = base_path.with_name(base_path.stem + '.reranked.jsonl')\n",
    "\n",
    "print('base_path', base_path)\n",
    "print('ref_path', ref_path)\n",
    "print('out_path', out_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:31:59.981112Z",
     "iopub.status.busy": "2025-08-18T04:31:59.980921Z",
     "iopub.status.idle": "2025-08-18T04:32:00.058240Z",
     "shell.execute_reply": "2025-08-18T04:32:00.057397Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "reference prompts: 11435\n"
     ]
    }
   ],
   "source": [
    "# Load reference prompts order\n",
    "ref_prompts = []\n",
    "with ref_path.open('r', encoding='utf-8') as f:\n",
    "    for i, line in enumerate(f):\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        obj = json.loads(line)\n",
    "        # normalize prompt whitespace\n",
    "        prompt = obj.get('prompt', '').strip()\n",
    "        ref_prompts.append(prompt)\n",
    "\n",
    "len_ref = len(ref_prompts)\n",
    "print(f'reference prompts: {len_ref}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:32:00.059999Z",
     "iopub.status.busy": "2025-08-18T04:32:00.059809Z",
     "iopub.status.idle": "2025-08-18T04:32:00.157088Z",
     "shell.execute_reply": "2025-08-18T04:32:00.156281Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base total: 11435 unique prompts: 11290\n"
     ]
    }
   ],
   "source": [
    "# Load base dataset into a mapping from prompt -> list of items (handle duplicates)\n",
    "from collections import defaultdict\n",
    "base_map = defaultdict(list)\n",
    "base_items = []\n",
    "with base_path.open('r', encoding='utf-8') as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        obj = json.loads(line)\n",
    "        prompt = obj.get('prompt', '').strip()\n",
    "        base_map[prompt].append(obj)\n",
    "        base_items.append(obj)\n",
    "\n",
    "print(f'base total: {len(base_items)} unique prompts: {len(base_map)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:32:00.158936Z",
     "iopub.status.busy": "2025-08-18T04:32:00.158742Z",
     "iopub.status.idle": "2025-08-18T04:32:00.172263Z",
     "shell.execute_reply": "2025-08-18T04:32:00.171516Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matched: 11435, missing: 0\n"
     ]
    }
   ],
   "source": [
    "# Reorder: for each prompt in ref_prompts, pop one item from base_map if available\n",
    "out_items = []\n",
    "missing = []\n",
    "for p in ref_prompts:\n",
    "    if p in base_map and base_map[p]:\n",
    "        out_items.append(base_map[p].pop(0))\n",
    "    else:\n",
    "        missing.append(p)\n",
    "\n",
    "print(f'matched: {len(out_items)}, missing: {len(missing)}')\n",
    "if missing:\n",
    "    print('First 5 missing prompts:\\n', missing[:5])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:32:00.174251Z",
     "iopub.status.busy": "2025-08-18T04:32:00.174065Z",
     "iopub.status.idle": "2025-08-18T04:32:00.180494Z",
     "shell.execute_reply": "2025-08-18T04:32:00.179767Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final out_items count: 11435 (should equal base total 11435)\n"
     ]
    }
   ],
   "source": [
    "# If some prompts were missing from base_map, append remaining base items at the end to preserve total count\n",
    "remaining = []\n",
    "for k, lst in base_map.items():\n",
    "    remaining.extend(lst)\n",
    "if remaining:\n",
    "    print(f'Appending {len(remaining)} remaining items to output')\n",
    "    out_items.extend(remaining)\n",
    "\n",
    "print(f'final out_items count: {len(out_items)} (should equal base total {len(base_items)})')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-08-18T04:32:00.182500Z",
     "iopub.status.busy": "2025-08-18T04:32:00.182313Z",
     "iopub.status.idle": "2025-08-18T04:32:00.423216Z",
     "shell.execute_reply": "2025-08-18T04:32:00.422219Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wrote /hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.reranked.jsonl lines= 11435\n"
     ]
    }
   ],
   "source": [
    "# Save reranked file\n",
    "with out_path.open('w', encoding='utf-8') as f:\n",
    "    for obj in out_items:\n",
    "        f.write(json.dumps(obj, ensure_ascii=False) + '\\n')\n",
    "\n",
    "print('Wrote', out_path, 'lines=', len(out_items))\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
