{
  "cells": [
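    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Re-rank GPT2-base answers to match reference ranking\n",
        "\n",
        "This notebook reorders the answers file configured as `base_path` in the first code cell (currently `/hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl`) so that its `prompt` order matches the reference file `/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl`. Base records that cannot be matched to a reference prompt are appended at the end, so no record is dropped. The reranked records are written next to the input as `<input stem>.reranked.jsonl`.\n"
      ]
    },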
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "from pathlib import Path\n",
        "\n",
        "# Input/output locations (edit if necessary).\n",
        "# base_path: answers file to reorder; ref_path: file whose prompt order is the target.\n",
        "base_path = Path('/hy-tmp/dc/processed_data/dolly/full/generate_data/dolly-512/gpt2-base/10/answers.jsonl')\n",
        "ref_path = Path('/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl')\n",
        "# The output sits next to the input: <input stem>.reranked.jsonl\n",
        "out_path = base_path.with_name(f'{base_path.stem}.reranked.jsonl')\n",
        "\n",
        "# Echo the resolved paths so the run log records what was used.\n",
        "for label, path in (('base_path', base_path), ('ref_path', ref_path), ('out_path', out_path)):\n",
        "    print(label, path)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Read the reference file and collect its prompts in file order.\n",
        "# Blank lines are skipped; each prompt is stripped of surrounding whitespace so it\n",
        "# can be compared against the stripped prompts of the base file. A record without\n",
        "# a 'prompt' key contributes an empty string.\n",
        "with ref_path.open('r', encoding='utf-8') as f:\n",
        "    ref_prompts = [\n",
        "        json.loads(line).get('prompt', '').strip()\n",
        "        for line in f\n",
        "        if line.strip()\n",
        "    ]\n",
        "\n",
        "len_ref = len(ref_prompts)\n",
        "print(f'reference prompts: {len_ref}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Index the base records by stripped prompt. Each prompt maps to a list so that\n",
        "# repeated prompts keep all of their records, in file order.\n",
        "from collections import defaultdict\n",
        "\n",
        "base_map = defaultdict(list)   # stripped prompt -> records carrying that prompt\n",
        "base_items = []                # every record, in original file order\n",
        "with base_path.open('r', encoding='utf-8') as f:\n",
        "    for raw_line in f:\n",
        "        if raw_line.strip():\n",
        "            record = json.loads(raw_line)\n",
        "            key = record.get('prompt', '').strip()\n",
        "            base_map[key].append(record)\n",
        "            base_items.append(record)\n",
        "\n",
        "print(f'base total: {len(base_items)} unique prompts: {len(base_map)}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Walk the reference prompts in order and take the first remaining base record for\n",
        "# each one. Taking a record removes it from base_map, so after this cell base_map\n",
        "# holds only the records that were not matched. Because base_map is modified,\n",
        "# re-run the base-loading cell before re-running this one.\n",
        "out_items, missing = [], []\n",
        "for ref_prompt in ref_prompts:\n",
        "    candidates = base_map.get(ref_prompt)  # .get() avoids creating empty entries\n",
        "    if candidates:\n",
        "        out_items.append(candidates.pop(0))\n",
        "        continue\n",
        "    missing.append(ref_prompt)\n",
        "\n",
        "print(f'matched: {len(out_items)}, missing: {len(missing)}')\n",
        "if len(missing) > 0:\n",
        "    print('First 5 missing prompts:\\n', missing[:5])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Records still left in base_map were not matched to any reference prompt.\n",
        "# Append them after the matched records so the output keeps every base record.\n",
        "remaining = [item for items in base_map.values() for item in items]\n",
        "\n",
        "# After one pass of the cells above, matched + remaining equals the base total.\n",
        "# A larger sum means this cell already ran (or the upstream cells were re-run out of\n",
        "# order), so skip the extend rather than duplicating records in the output.\n",
        "already_appended = len(out_items) + len(remaining) > len(base_items)\n",
        "if remaining and not already_appended:\n",
        "    print(f'Appending {len(remaining)} remaining items to output')\n",
        "    out_items.extend(remaining)\n",
        "elif already_appended:\n",
        "    print('Skipping append: out_items + remaining exceeds base total; re-run the load and reorder cells first')\n",
        "\n",
        "print(f'final out_items count: {len(out_items)} (should equal base total {len(base_items)})')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Write the reordered records as JSON Lines, one object per line.\n",
        "# ensure_ascii=False keeps non-ASCII characters unescaped in the UTF-8 output.\n",
        "with out_path.open('w', encoding='utf-8') as f:\n",
        "    f.writelines(\n",
        "        json.dumps(item, ensure_ascii=False) + '\\n'\n",
        "        for item in out_items\n",
        "    )\n",
        "\n",
        "print('Wrote', out_path, 'lines=', len(out_items))\n"
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
