{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2193cbd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, re, json, string, warnings\n",
    "import torch\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import torchvision.transforms as T\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "\n",
    "import transformers\n",
    "transformers.logging.set_verbosity_error()\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer, AutoModel, AutoProcessor, AutoModelForImageTextToText,\n",
    "    LlavaForConditionalGeneration,\n",
    "    # transformers exports this class as Qwen2_5_VL...; aliased to the name used below.\n",
    "    Qwen2_5_VLForConditionalGeneration as Qwen25VLForConditionalGeneration,\n",
    "    Qwen3VLForConditionalGeneration,\n",
    ")\n",
    "\n",
    "# process_vision_info comes from the qwen-vl-utils package (import name: qwen_vl_utils).\n",
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "# -----------------------\n",
    "# 0) CONFIG (edit paths)\n",
    "# -----------------------\n",
    "PUZZLEDIR = \"extraction\"\n",
    "LABELFILE = \"Rebus Puzzle.xlsx\"\n",
    "OUTDIR = \"agg_results\"\n",
    "os.makedirs(OUTDIR, exist_ok=True)\n",
    "\n",
    "MAXPUZZLES = 1164\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "DTYPE_BF16 = torch.bfloat16 if DEVICE == \"cuda\" else torch.float32\n",
    "DTYPE_FP16 = torch.float16 if DEVICE == \"cuda\" else torch.float32\n",
    "\n",
    "BATCHSIZE = 1\n",
    "\n",
    "PROMPT_MAIN = \"\"\"You are given an image that represents a rebus puzzle (a visual word riddle).\\n\"\n",
    "    \"A rebus puzzle encodes a common English word or phrase using visual layout, repetition, color, position, or size of text and symbols.\\n\"\n",
    "    \"Do NOT read the image literally.\\n\"\n",
    "    \"Infer the hidden word or idiomatic expression suggested by the visual arrangement.\\n\\n\"\n",
    "    \"Examples:\\n\"\n",
    "    \"- The word 'MAN' written three times means 'three men'.\\n\"\n",
    "    \"- The word 'READ' inside a box means 'read between the lines'.\\n\"\n",
    "    \"- The word 'YOU' written above 'ME' means 'you over me'.\\n\"\n",
    "    \"- A red letter 'E' followed by 'GO GO' means 'ready to go'.\\n\\n\"\n",
    "    \"Question: What English word or phrase is represented?\\n\"\n",
    "    \"Return ONLY the final answer in 1–5 words.\\n\"\n",
    "    \"Do not explain.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# -----------------------\n",
    "# 1) Load labels\n",
    "# -----------------------\n",
    "# Sheet rows 0-1 are skipped (presumably header rows); only the first two columns\n",
    "# (puzzle id, answer) are used.\n",
    "labels_raw = pd.read_excel(LABELFILE, header=None).iloc[2:, :2]\n",
    "labels_raw.columns = [\"puzzleid\", \"answer\"]\n",
    "# Build a new frame rather than assigning into a filtered slice (which triggers\n",
    "# SettingWithCopyWarning). Drop rows with a non-numeric id or an empty answer, so\n",
    "# no puzzle is scored against the string \"nan\".\n",
    "labelsdf = (\n",
    "    labels_raw\n",
    "    .assign(puzzleid=pd.to_numeric(labels_raw[\"puzzleid\"], errors=\"coerce\"))\n",
    "    .dropna(subset=[\"puzzleid\", \"answer\"])\n",
    "    .astype({\"puzzleid\": int})\n",
    ")\n",
    "LABELS = {int(pid): str(ans).strip() for pid, ans in zip(labelsdf[\"puzzleid\"], labelsdf[\"answer\"])}\n",
    "\n",
    "# -----------------------\n",
    "# 2) Torch Dataset\n",
    "# -----------------------\n",
    "class RebusDataset(Dataset):\n",
    "    def __init__(self, puzzledir, labels, maxitems=None, load_pil=True):\n",
    "        self.samples = []\n",
    "        for pid in sorted(labels.keys()):\n",
    "            imgpath = os.path.join(puzzledir, f\"puzzle{pid}.png\")\n",
    "            if os.path.exists(imgpath):\n",
    "                self.samples.append((pid, imgpath, labels[pid]))\n",
    "            if maxitems is not None and len(self.samples) >= maxitems:\n",
    "                break\n",
    "        self.load_pil = load_pil\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        pid, imgpath, ans = self.samples[idx]\n",
    "        if self.load_pil:\n",
    "            image = Image.open(imgpath).convert(\"RGB\")\n",
    "            return pid, imgpath, image, ans\n",
    "        return pid, imgpath, ans\n",
    "\n",
    "def collate_keep(batch):\n",
    "    return batch[0]\n",
    "\n",
    "dataset = RebusDataset(\n",
    "    PUZZLEDIR,\n",
    "    LABELS,\n",
    "    maxitems=MAXPUZZLES,\n",
    "    load_pil=True,\n",
    ")\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCHSIZE,  # collate_keep unwraps one sample per batch\n",
    "    shuffle=False,         # deterministic, ascending puzzle-id order\n",
    "    num_workers=0,\n",
    "    collate_fn=collate_keep,\n",
    ")\n",
    "\n",
    "# -----------------------\n",
    "# 3) Text utils \n",
    "# -----------------------\n",
    "def clean_prediction(s):\n",
    "    if s is None:\n",
    "        return \"\"\n",
    "    if isinstance(s, list):\n",
    "        s = s[0] if len(s) else \"\"\n",
    "    s = str(s).strip()\n",
    "    s = re.sub(r\"final\\s*answer\\s*:?\", \"\", s, flags=re.IGNORECASE).strip()\n",
    "    s = s.strip('\"').strip(\"'\").strip(\"`\").strip(\"*\").strip()\n",
    "    return s.lower().strip()\n",
    "\n",
    "def normalize_answer(s):\n",
    "    if s is None:\n",
    "        return \"\"\n",
    "    s = str(s).lower().strip()\n",
    "    s = s.translate(str.maketrans(\"\", \"\", string.punctuation))\n",
    "    s = re.sub(r\"\\s+\", \" \", s).strip()\n",
    "    return s\n",
    "\n",
    "def token_f1(pred, gt):\n",
    "    \"\"\"SQuAD-style token F1 between a prediction and the ground truth.\n",
    "\n",
    "    Both strings go through normalize_answer and are split on whitespace. Tokens\n",
    "    are compared as multisets, so repeated words (\"side by side\") count as often\n",
    "    as they occur. Returns a float in [0, 1]; 0.0 when either side has no tokens\n",
    "    or nothing overlaps.\n",
    "    \"\"\"\n",
    "    from collections import Counter\n",
    "\n",
    "    pred_tokens = Counter(normalize_answer(pred).split())\n",
    "    gt_tokens = Counter(normalize_answer(gt).split())\n",
    "    if not pred_tokens or not gt_tokens:\n",
    "        return 0.0\n",
    "    overlap = sum((pred_tokens & gt_tokens).values())\n",
    "    if overlap == 0:\n",
    "        return 0.0\n",
    "    precision = overlap / sum(pred_tokens.values())\n",
    "    recall = overlap / sum(gt_tokens.values())\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "# -----------------------\n",
    "# 4) InternVL tiling \n",
    "# -----------------------\n",
    "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
    "IMAGENET_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "def build_transform(input_size=448):\n",
    "    \"\"\"Per-tile preprocessing pipeline for InternVL.\n",
    "\n",
    "    Converts to RGB when needed, bicubic-resizes to ``input_size`` x ``input_size``,\n",
    "    converts to a tensor and normalizes with the ImageNet channel mean/std.\n",
    "    \"\"\"\n",
    "    def ensure_rgb(img):\n",
    "        return img if img.mode == \"RGB\" else img.convert(\"RGB\")\n",
    "\n",
    "    pipeline = [\n",
    "        T.Lambda(ensure_rgb),\n",
    "        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
    "    ]\n",
    "    return T.Compose(pipeline)\n",
    "\n",
    "def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, imagesize):\n",
    "    best_ratio_diff = float(\"inf\")\n",
    "    best_ratio = (1, 1)\n",
    "    area = width * height\n",
    "    for ratio in target_ratios:\n",
    "        target_aspect = ratio[0] / ratio[1]\n",
    "        diff = abs(aspect_ratio - target_aspect)\n",
    "        if diff < best_ratio_diff:\n",
    "            best_ratio_diff = diff\n",
    "            best_ratio = ratio\n",
    "        elif diff == best_ratio_diff:\n",
    "            if area > 0.5 * imagesize * imagesize * ratio[0] * ratio[1]:\n",
    "                best_ratio = ratio\n",
    "    return best_ratio\n",
    "\n",
    "def dynamic_preprocess(image, minnum=1, maxnum=12, imagesize=448, use_thumbnail=True):\n",
    "    \"\"\"Cut ``image`` into square ``imagesize`` tiles on the grid closest to its aspect ratio.\n",
    "\n",
    "    Candidate grids are all (cols, rows) with ``minnum <= cols * rows <= maxnum``,\n",
    "    ordered by tile count. The image is resized to exactly fill the chosen grid and\n",
    "    cropped row by row. When more than one tile results and ``use_thumbnail`` is\n",
    "    set, a whole-image ``imagesize`` x ``imagesize`` resize is appended.\n",
    "    Returns the list of tiles.\n",
    "    \"\"\"\n",
    "    width, height = image.size\n",
    "    grids = {\n",
    "        (cols, rows)\n",
    "        for n in range(minnum, maxnum + 1)\n",
    "        for cols in range(1, n + 1)\n",
    "        for rows in range(1, n + 1)\n",
    "        if minnum <= cols * rows <= maxnum\n",
    "    }\n",
    "    grids = sorted(grids, key=lambda grid: grid[0] * grid[1])\n",
    "    cols, rows = find_closest_aspect_ratio(width / height, grids, width, height, imagesize)\n",
    "    resized = image.resize((imagesize * cols, imagesize * rows))\n",
    "    tiles = [\n",
    "        resized.crop((c * imagesize, r * imagesize, (c + 1) * imagesize, (r + 1) * imagesize))\n",
    "        for r in range(rows)\n",
    "        for c in range(cols)\n",
    "    ]\n",
    "    if use_thumbnail and len(tiles) != 1:\n",
    "        tiles.append(image.resize((imagesize, imagesize)))\n",
    "    return tiles\n",
    "\n",
    "def internvl_pixel_values_from_path(imgpath, input_size=448, maxnum=12):\n",
    "    \"\"\"Load ``imgpath`` as RGB, tile it for InternVL and stack the normalized tiles.\n",
    "\n",
    "    Dim 0 of the result indexes tiles (grid tiles first, then the thumbnail if any).\n",
    "    \"\"\"\n",
    "    rgb = Image.open(imgpath).convert(\"RGB\")\n",
    "    tiles = dynamic_preprocess(rgb, imagesize=input_size, use_thumbnail=True, maxnum=maxnum)\n",
    "    return torch.stack(list(map(build_transform(input_size), tiles)))\n",
    "\n",
    "# -----------------------\n",
    "# 5) Model runners \n",
    "# -----------------------\n",
    "def run_internvl_instruct_chat(MODELPATH, model_name, input_size=448, max_tiles=12, dtype=DTYPE_BF16):\n",
    "    \"\"\"Load an InternVL checkpoint that provides ``model.chat`` (remote code) and\n",
    "    return an ``infer(imgpath, prompt) -> str`` closure.\n",
    "\n",
    "    Each image is tiled by internvl_pixel_values_from_path (up to ``max_tiles``\n",
    "    tiles of ``input_size`` px). Decoding is greedy with at most 32 new tokens,\n",
    "    and only the first line of the reply is kept. ``model_name`` is unused.\n",
    "    \"\"\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODELPATH, trust_remote_code=True, use_fast=False)\n",
    "    model_kwargs = dict(torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True)\n",
    "    if DEVICE == \"cuda\":\n",
    "        model_kwargs[\"device_map\"] = \"auto\"\n",
    "    model = AutoModel.from_pretrained(MODELPATH, **model_kwargs).eval()\n",
    "\n",
    "    # Loop-invariant: resolve the input device once instead of on every call.\n",
    "    try:\n",
    "        target_device = model.device\n",
    "    except Exception:\n",
    "        target_device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "    # Explicit None check: `pad_token_id or eos` would wrongly replace a valid pad id of 0.\n",
    "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n",
    "\n",
    "    def infer(imgpath, prompt):\n",
    "        pixel_values = internvl_pixel_values_from_path(\n",
    "            imgpath, input_size=input_size, maxnum=max_tiles\n",
    "        ).to(device=target_device, dtype=dtype)\n",
    "        question = f\"<image>\\n{prompt}\"\n",
    "        # Fresh dict per call, in case model.chat modifies the config it is given.\n",
    "        gencfg = dict(max_new_tokens=32, do_sample=False, num_beams=1,\n",
    "                      repetition_penalty=1.1,\n",
    "                      pad_token_id=pad_id,\n",
    "                      eos_token_id=tokenizer.eos_token_id)\n",
    "        out = model.chat(tokenizer=tokenizer, pixel_values=pixel_values, question=question,\n",
    "                         generation_config=gencfg, history=None, return_history=False)\n",
    "        if isinstance(out, tuple):\n",
    "            out = out[0]\n",
    "        # Keep the first line; an empty reply becomes \"\" instead of raising IndexError.\n",
    "        lines = str(out).strip().splitlines()\n",
    "        return clean_prediction(lines[0].strip() if lines else \"\")\n",
    "\n",
    "    return infer\n",
    "\n",
    "def run_internvl_hf_image_text_to_text(MODELPATH, model_name, dtype=DTYPE_BF16):\n",
    "    \"\"\"Load an InternVL checkpoint through AutoModelForImageTextToText and return an\n",
    "    ``infer(imgpath, prompt) -> str`` closure.\n",
    "\n",
    "    Inputs are built with the processor's chat template. Decoding is greedy with\n",
    "    at most 32 new tokens; only the first line of the completion is kept.\n",
    "    ``model_name`` is unused.\n",
    "    \"\"\"\n",
    "    processor = AutoProcessor.from_pretrained(MODELPATH, trust_remote_code=True)\n",
    "    model_kwargs = dict(torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True)\n",
    "    if DEVICE == \"cuda\":\n",
    "        model_kwargs[\"device_map\"] = \"auto\"\n",
    "    model = AutoModelForImageTextToText.from_pretrained(MODELPATH, **model_kwargs).eval()\n",
    "\n",
    "    def infer(imgpath, prompt):\n",
    "        image = Image.open(imgpath).convert(\"RGB\")\n",
    "        messages = [{\"role\": \"user\", \"content\": [{\"type\": \"image\", \"image\": image}, {\"type\": \"text\", \"text\": prompt}]}]\n",
    "        inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True,\n",
    "                                               return_dict=True, return_tensors=\"pt\")\n",
    "        inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "        if \"pixel_values\" in inputs:\n",
    "            # Match the image tensor to the model's weight dtype.\n",
    "            inputs[\"pixel_values\"] = inputs[\"pixel_values\"].to(dtype=dtype)\n",
    "        genids = model.generate(**inputs, max_new_tokens=32, do_sample=False, num_beams=1)\n",
    "        promptlen = inputs[\"input_ids\"].shape[1]\n",
    "        outtext = processor.decode(genids[0, promptlen:], skip_special_tokens=True).strip()\n",
    "        # Keep the first line; an empty completion becomes \"\" instead of raising IndexError.\n",
    "        lines = outtext.splitlines()\n",
    "        return clean_prediction(lines[0].strip() if lines else \"\")\n",
    "\n",
    "    return infer\n",
    "\n",
    "def run_llava(MODELPATH, model_name, dtype=DTYPE_FP16):\n",
    "    \"\"\"Load a LLaVA-1.5 (HF format) checkpoint and return an ``infer(imgpath, prompt) -> str`` closure.\n",
    "\n",
    "    The prompt is wrapped in the ``USER: <image>\\\\n... ASSISTANT:`` format unless it\n",
    "    already contains an ``<image>`` placeholder. Decoding is greedy with at most 12\n",
    "    new tokens; only the completion (not the echoed prompt) is decoded, and the\n",
    "    first line is kept. ``model_name`` is unused.\n",
    "    \"\"\"\n",
    "    model = LlavaForConditionalGeneration.from_pretrained(\n",
    "        MODELPATH, torch_dtype=dtype, device_map=\"auto\" if DEVICE == \"cuda\" else None\n",
    "    ).eval()\n",
    "    processor = AutoProcessor.from_pretrained(MODELPATH)\n",
    "\n",
    "    def infer(imgpath, prompt):\n",
    "        image = Image.open(imgpath).convert(\"RGB\")\n",
    "        # llava-1.5-hf needs an explicit <image> placeholder inside its USER/ASSISTANT\n",
    "        # template; a bare instruction gives the image features nowhere to go.\n",
    "        if \"<image>\" in prompt:\n",
    "            chat_prompt = prompt\n",
    "        else:\n",
    "            chat_prompt = f\"USER: <image>\\n{prompt} ASSISTANT:\"\n",
    "        inputs = processor(text=chat_prompt, images=image, return_tensors=\"pt\")\n",
    "        inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "        if \"pixel_values\" in inputs:\n",
    "            # Match the image tensor to the (fp16 on GPU) model weights.\n",
    "            inputs[\"pixel_values\"] = inputs[\"pixel_values\"].to(dtype=dtype)\n",
    "        # Greedy decoding; temperature is irrelevant (and warned about) when do_sample=False.\n",
    "        outids = model.generate(**inputs, max_new_tokens=12, do_sample=False, repetition_penalty=1.1)\n",
    "        # Decode only new tokens: outids starts with the prompt, which would otherwise\n",
    "        # end up inside the prediction.\n",
    "        promptlen = inputs[\"input_ids\"].shape[1]\n",
    "        ans = processor.decode(outids[0, promptlen:], skip_special_tokens=True).strip()\n",
    "        if \"Answer:\" in ans:\n",
    "            ans = ans.split(\"Answer:\")[-1].strip()\n",
    "        lines = ans.splitlines()\n",
    "        return clean_prediction(lines[0].strip() if lines else \"\")\n",
    "\n",
    "    return infer\n",
    "\n",
    "def run_qwen25_vl(MODELPATH, model_name, dtype=DTYPE_BF16, max_new_tokens=10):\n",
    "    \"\"\"Load Qwen2.5-VL and return an ``infer(imgpath, prompt) -> str`` closure.\n",
    "\n",
    "    Inputs go through the processor's chat template and process_vision_info.\n",
    "    Decoding is greedy with ``max_new_tokens``, and only the newly generated tokens\n",
    "    are decoded. ``model_name`` is unused.\n",
    "    \"\"\"\n",
    "    model = Qwen25VLForConditionalGeneration.from_pretrained(\n",
    "        MODELPATH, torch_dtype=dtype, device_map=\"auto\" if DEVICE == \"cuda\" else None, trust_remote_code=True\n",
    "    ).eval()\n",
    "    processor = AutoProcessor.from_pretrained(MODELPATH, trust_remote_code=True, use_fast=False)\n",
    "\n",
    "    def infer(imgpath, prompt):\n",
    "        image = Image.open(imgpath).convert(\"RGB\")\n",
    "        messages = [{\"role\": \"user\", \"content\": [{\"type\": \"image\", \"image\": image}, {\"type\": \"text\", \"text\": prompt}]}]\n",
    "        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "        image_inputs, video_inputs = process_vision_info(messages)\n",
    "        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors=\"pt\").to(DEVICE)\n",
    "        outids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)\n",
    "        # generate() returns the prompt ids followed by the completion. Decoding all\n",
    "        # of it would put the chat template and the question into the prediction.\n",
    "        completion_ids = outids[:, inputs[\"input_ids\"].shape[1]:]\n",
    "        out = processor.batch_decode(completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)\n",
    "        out = out[0] if isinstance(out, list) else out\n",
    "        return clean_prediction(out)\n",
    "\n",
    "    return infer\n",
    "\n",
    "def run_qwen3_vl(MODELPATH, model_name, dtype=DTYPE_BF16, max_new_tokens=12):\n",
    "    \"\"\"Load Qwen3-VL and return an ``infer(imgpath, prompt) -> str`` closure.\n",
    "\n",
    "    Greedy decoding with ``max_new_tokens``; only tokens after the prompt are\n",
    "    decoded. ``model_name`` is unused.\n",
    "    \"\"\"\n",
    "    device_map = \"auto\" if DEVICE == \"cuda\" else None\n",
    "    model = Qwen3VLForConditionalGeneration.from_pretrained(\n",
    "        MODELPATH, torch_dtype=dtype, device_map=device_map, trust_remote_code=True\n",
    "    ).eval()\n",
    "    processor = AutoProcessor.from_pretrained(MODELPATH, trust_remote_code=True, use_fast=False)\n",
    "\n",
    "    def infer(imgpath, prompt):\n",
    "        messages = [{\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\"type\": \"image\", \"image\": Image.open(imgpath).convert(\"RGB\")},\n",
    "                {\"type\": \"text\", \"text\": prompt},\n",
    "            ],\n",
    "        }]\n",
    "        chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "        images, videos = process_vision_info(messages)\n",
    "        batch = processor(\n",
    "            text=[chat_text], images=images, videos=videos, padding=True, return_tensors=\"pt\"\n",
    "        ).to(DEVICE)\n",
    "        generated = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False)\n",
    "        n_prompt = batch[\"input_ids\"].shape[1]\n",
    "        decoded = processor.decode(\n",
    "            generated[0, n_prompt:], skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    "        )\n",
    "        return clean_prediction(decoded.strip())\n",
    "\n",
    "    return infer\n",
    "\n",
    "# -----------------------\n",
    "# 6) Register models (edit paths/ids)\n",
    "# -----------------------\n",
    "def _hub_snapshot(repo_dir, revision):\n",
    "    \"\"\"Expanded path of local hub-cache snapshot ``models--{repo_dir}/snapshots/{revision}``.\"\"\"\n",
    "    return os.path.expanduser(f\"~/.cache/huggingface/hub/models--{repo_dir}/snapshots/{revision}\")\n",
    "\n",
    "\n",
    "def _spec(name, kind, modelpath, **kwargs):\n",
    "    \"\"\"One registry entry; every model is prompted with PROMPT_MAIN and ``kwargs`` go to its runner.\"\"\"\n",
    "    return dict(name=name, kind=kind, modelpath=modelpath, prompt=PROMPT_MAIN, kwargs=kwargs)\n",
    "\n",
    "\n",
    "MODELS = [\n",
    "    _spec(\"InternVL3.5-30B-A3B-Instruct\", \"internvl_chat\",\n",
    "          _hub_snapshot(\"OpenGVLab--InternVL35-30B-A3B-Instruct\", \"83f9a51dbd940c291fb149debee61502f19444d2\"),\n",
    "          input_size=448, max_tiles=12, dtype=DTYPE_BF16),\n",
    "    _spec(\"InternVL3.5-8B-Instruct\", \"internvl_chat\",\n",
    "          _hub_snapshot(\"OpenGVLab--InternVL35-8B-Instruct\", \"27506812aa9914804018996329e895977ee2d0c8\"),\n",
    "          input_size=448, max_tiles=12, dtype=DTYPE_BF16),\n",
    "    _spec(\"InternVL3.5-4B\", \"internvl_chat\",\n",
    "          _hub_snapshot(\"OpenGVLab--InternVL35-4B\", \"481f6e32467eab4e922ccd7fd6cf420441a62331\"),\n",
    "          input_size=448, max_tiles=12, dtype=DTYPE_BF16),\n",
    "    _spec(\"LLaVA-1.5-7B-HF\", \"llava\",\n",
    "          _hub_snapshot(\"llava-hf--llava-1.5-7b-hf\", \"b234b804b114d9e37bb655e11cbbb5f5e971b7a9\"),\n",
    "          dtype=DTYPE_FP16),\n",
    "    _spec(\"Qwen2.5-VL-3B\", \"qwen25\",\n",
    "          _hub_snapshot(\"Qwen--Qwen2.5-VL-3B-Instruct\", \"66285546d2b821cf421d4f5eb2576359d3770cd3\"),\n",
    "          dtype=DTYPE_BF16, max_new_tokens=10),\n",
    "    _spec(\"Qwen2.5-VL-7B\", \"qwen25\",\n",
    "          _hub_snapshot(\"Qwen--Qwen2.5-VL-7B-Instruct\", \"cc594898137f460bfe9f0759e9844b3ce807cfb5\"),\n",
    "          dtype=DTYPE_BF16, max_new_tokens=10),\n",
    "    _spec(\"Qwen2.5-VL-32B\", \"qwen25\",\n",
    "          os.path.expanduser(\"~/Qwen2.5-VL-32B-Instruct\"),\n",
    "          dtype=DTYPE_BF16, max_new_tokens=12),\n",
    "    # Hub repo id rather than a local snapshot path.\n",
    "    _spec(\"Qwen3-VL-4B\", \"qwen3\",\n",
    "          \"Qwen/Qwen3-VL-4B-Instruct\",\n",
    "          dtype=DTYPE_BF16, max_new_tokens=12),\n",
    "    _spec(\"Qwen3-VL-8B\", \"qwen3\",\n",
    "          _hub_snapshot(\"Qwen--Qwen3-VL-8B-Instruct\", \"0c351dd01ed87e9c1b53cbc748cba10e6187ff3b\"),\n",
    "          dtype=DTYPE_BF16, max_new_tokens=12),\n",
    "]\n",
    "\n",
    "def build_infer_fn(spec):\n",
    "    \"\"\"Construct the ``infer(imgpath, prompt)`` closure for one MODELS entry.\n",
    "\n",
    "    ``spec[\"kind\"]`` selects the runner; ``spec[\"modelpath\"]``, ``spec[\"name\"]`` and\n",
    "    ``spec[\"kwargs\"]`` are forwarded to it. Raises ValueError(kind) for unknown kinds.\n",
    "    \"\"\"\n",
    "    runners = {\n",
    "        \"internvl_chat\": run_internvl_instruct_chat,\n",
    "        \"internvl_hf\": run_internvl_hf_image_text_to_text,\n",
    "        \"llava\": run_llava,\n",
    "        \"qwen25\": run_qwen25_vl,\n",
    "        \"qwen3\": run_qwen3_vl,\n",
    "    }\n",
    "    kind = spec[\"kind\"]\n",
    "    if kind not in runners:\n",
    "        raise ValueError(kind)\n",
    "    return runners[kind](spec[\"modelpath\"], spec[\"name\"], **spec[\"kwargs\"])\n",
    "\n",
    "# -----------------------\n",
    "# 7) Eval all models\n",
    "# -----------------------\n",
    "import gc  # used below to release each model before the next one is loaded\n",
    "\n",
    "all_summary = []\n",
    "\n",
    "for spec in MODELS:\n",
    "    model_name = spec[\"name\"]\n",
    "    # File-system-safe output prefix derived from the model name.\n",
    "    outprefix = os.path.join(OUTDIR, re.sub(r\"[^a-zA-Z0-9._-]+\", \"\", model_name.lower()))\n",
    "    infer_fn = build_infer_fn(spec)\n",
    "\n",
    "    results = []\n",
    "    exact_cnt = 0\n",
    "    f1_sum = 0.0\n",
    "\n",
    "    print(\"=\" * 80)\n",
    "    print(f\"Evaluating {len(dataset)} puzzles with {model_name}\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    try:\n",
    "        for pid, imgpath, image, gt in tqdm(loader, total=len(dataset)):\n",
    "            pred = clean_prediction(infer_fn(imgpath, spec[\"prompt\"]))\n",
    "\n",
    "            em = normalize_answer(pred) == normalize_answer(gt)\n",
    "            f1 = token_f1(pred, gt)\n",
    "\n",
    "            exact_cnt += int(em)\n",
    "            f1_sum += f1\n",
    "\n",
    "            results.append({\n",
    "                \"model\": model_name,\n",
    "                \"puzzleid\": int(pid),\n",
    "                \"gtanswer\": gt,\n",
    "                \"predanswer\": pred,\n",
    "                \"exactmatch\": bool(em),\n",
    "                \"tokenf1\": float(f1),\n",
    "            })\n",
    "    finally:\n",
    "        # The closure captures the model weights. Drop it before the next\n",
    "        # build_infer_fn call; otherwise two models are resident at once.\n",
    "        del infer_fn\n",
    "        gc.collect()\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    results_df = pd.DataFrame(results)\n",
    "    results_df.to_csv(outprefix + \".csv\", index=False)\n",
    "    with open(outprefix + \".jsonl\", \"w\", encoding=\"utf-8\") as f:\n",
    "        for r in results:\n",
    "            f.write(json.dumps(r, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "    n = len(results)\n",
    "    acc = 100.0 * exact_cnt / n if n else 0.0\n",
    "    avg_f1 = 100.0 * (f1_sum / n) if n else 0.0\n",
    "\n",
    "    all_summary.append({\n",
    "        \"model\": model_name,\n",
    "        \"num\": n,\n",
    "        \"exact_acc\": acc,\n",
    "        \"avg_token_f1\": avg_f1,\n",
    "        \"csv\": os.path.basename(outprefix + \".csv\"),\n",
    "        \"jsonl\": os.path.basename(outprefix + \".jsonl\"),\n",
    "    })\n",
    "\n",
    "# Explicit columns keep sort_values working even when no model was evaluated.\n",
    "summary_df = (\n",
    "    pd.DataFrame(all_summary, columns=[\"model\", \"num\", \"exact_acc\", \"avg_token_f1\", \"csv\", \"jsonl\"])\n",
    "    .sort_values(\"exact_acc\", ascending=False)\n",
    ")\n",
    "summary_df.to_csv(os.path.join(OUTDIR, \"summary.csv\"), index=False)\n",
    "summary_df\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
