{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/home/.conda/envs/mlk/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading readme: 100%|██████████| 5.33k/5.33k [00:00<00:00, 19.4MB/s]\n",
      "Downloading data: 100%|██████████| 159M/159M [00:04<00:00, 35.0MB/s] \n",
      "Downloading data: 100%|██████████| 165M/165M [00:03<00:00, 41.8MB/s] \n",
      "Downloading data: 100%|██████████| 152M/152M [00:04<00:00, 36.5MB/s] \n",
      "Downloading data: 100%|██████████| 152M/152M [00:03<00:00, 40.9MB/s] \n",
      "Downloading data: 100%|██████████| 133M/133M [00:03<00:00, 40.4MB/s] \n",
      "Generating colorjitter split: 100%|██████████| 500/500 [00:00<00:00, 904.72 examples/s] \n",
      "Generating elastic split: 100%|██████████| 500/500 [00:00<00:00, 1106.21 examples/s]\n",
      "Generating gaussianblur split: 100%|██████████| 500/500 [00:00<00:00, 1221.54 examples/s]\n",
      "Generating rotate split: 100%|██████████| 500/500 [00:00<00:00, 1238.66 examples/s]\n",
      "Generating perspective split: 100%|██████████| 500/500 [00:00<00:00, 1456.51 examples/s]\n",
      "Downloading data: 100%|██████████| 47.2k/47.2k [00:00<00:00, 125kB/s]\n",
      "Downloading data: 100%|██████████| 47.2k/47.2k [00:00<00:00, 128kB/s]\n",
      "Downloading data: 100%|██████████| 28.2k/28.2k [00:00<00:00, 78.8kB/s]\n",
      "Downloading data: 100%|██████████| 47.3k/47.3k [00:00<00:00, 139kB/s]\n",
      "Generating in100 split: 100%|██████████| 1/1 [00:00<00:00, 131.43 examples/s]\n",
      "Generating coco split: 100%|██████████| 1/1 [00:00<00:00, 219.90 examples/s]\n",
      "Generating wu_img_text split: 100%|██████████| 1/1 [00:00<00:00, 202.46 examples/s]\n",
      "Generating wu_img_img split: 100%|██████████| 1/1 [00:00<00:00, 51.58 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "\n",
    "CONFIG_NAME = \"coco\"\n",
    "SPLIT_NAME = \"rotate\"\n",
    "COND = \"invariant\"\n",
    "\n",
    "ds = datasets.load_from_disk(f\"./{CONFIG_NAME}\")[SPLIT_NAME]\n",
    "ds_templates = datasets.load_from_disk(\"./templates\")[CONFIG_NAME]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Qwen2-VL-7B-Instruct Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 5/5 [01:16<00:00, 15.32s/it]\n",
      "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
     ]
    }
   ],
   "source": [
    "from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n",
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "# default: Load the model on the available device(s)\n",
    "model = Qwen2VLForConditionalGeneration.from_pretrained(\n",
    "    \"Qwen/Qwen2-VL-7B-Instruct\", torch_dtype=\"auto\", device_map=\"auto\"\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(\"Qwen/Qwen2-VL-7B-Instruct\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "\n",
    "templates = json.loads(ds_templates['query_templates'][0])\n",
    "condition_dict = json.loads(ds_templates['query_conditions'][0])\n",
    "logistics = json.loads(ds_templates['logistics'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metric functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import normalized_mutual_info_score\n",
    "from scipy.stats import entropy\n",
    "import re \n",
    "\n",
    "from collections import defaultdict\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_relaxed_symmetry(results, eps=1):\n",
    "    grouped = defaultdict(list)\n",
    "    for r in results:\n",
    "        key = (f\"{r['index']}_{r['condition']}\", tuple(r['img_keys']))\n",
    "        grouped[key].append(r['score'])\n",
    "    total = len(grouped)\n",
    "    symmetric = sum(1 for scores in grouped.values() if len(scores) == 2 and abs(scores[0] - scores[1]) <= eps)\n",
    "    return symmetric / total if total > 0 else 0\n",
    "\n",
    "def compute_mmscore(results):\n",
    "    y_true = [r[\"gt\"] for r in results]\n",
    "    y_pred = [r[\"score\"] for r in results]\n",
    "    return normalized_mutual_info_score(y_true, y_pred)\n",
    "\n",
    "def compute_controllability(results_invar, results_var):\n",
    "    mmscore_invar = compute_mmscore(results_invar)\n",
    "    mmscore_var = compute_mmscore(results_var)\n",
    "    denominator = math.sqrt(mmscore_var * mmscore_invar)\n",
    "    return 1 - abs(mmscore_var - mmscore_invar) / denominator if denominator > 0 else 0\n",
    "\n",
    "def compute_smoothness(results):\n",
    "    scores = [r[\"score\"] for r in results]\n",
    "    counts = np.bincount(scores, minlength=11)  # 0-10\n",
    "    prob = counts / counts.sum()\n",
    "    return entropy(prob, base=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parse Response\n",
    "    - Get Score and Reason\n",
    "    - If not parsable, use first integer in string as a heuristic\n",
    "    - If no integers were found, score = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_response(resp):\n",
    "    # Unescape literal '\\n' to actual newlines, and strip surrounding whitespace\n",
    "    cleaned = resp.encode('utf-8').decode('unicode_escape').strip()\n",
    "\n",
    "    # Try to parse using standard format\n",
    "    match = re.search(r\"Score:\\s*(\\d+)\\s*Reason:\\s*(.+)\", cleaned, re.IGNORECASE | re.DOTALL)\n",
    "    if match:\n",
    "        score = int(match.group(1))\n",
    "        reason = match.group(2).strip()\n",
    "        return score, reason\n",
    "\n",
    "    # Fallback: extract first integer and return full text as reason\n",
    "    fallback_score = re.search(r\"\\d+\", cleaned)\n",
    "    score = int(fallback_score.group(0)) if fallback_score else -1\n",
    "    return score, cleaned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing condition: invariant\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/20 [00:00<?, ?it/s]/mnt/home/.conda/envs/mlk/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.01` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/mnt/home/.conda/envs/mlk/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.001` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n",
      "/mnt/home/.conda/envs/mlk/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:651: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
      "  warnings.warn(\n",
      "100%|██████████| 20/20 [01:03<00:00,  3.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing condition: variant\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:56<00:00,  2.81s/it]\n"
     ]
    }
   ],
   "source": [
    "import io\n",
    "import base64\n",
    "from tqdm import tqdm\n",
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "def pil_to_base64_str(pil_img):\n",
    "    buffer = io.BytesIO()\n",
    "    pil_img.save(buffer, format=\"PNG\")\n",
    "    base64_str = base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
    "    return f\"data:image/png;base64,{base64_str}\"\n",
    "\n",
    "\n",
    "all_results = {}\n",
    "for condition in ['invariant', 'variant']:\n",
    "    print(f\"Processing condition: {condition}\")\n",
    "    results = []\n",
    "\n",
    "    for i in tqdm(range(20)):\n",
    "        sample = ds[i]\n",
    "        for img_pair in logistics[\"data-pairs\"]:\n",
    "            # Randomly choose a prompt version for each pair\n",
    "            prompt_version = random.choice(list(templates.keys()))\n",
    "            raw_template = templates[prompt_version]\n",
    "\n",
    "            for reverse in [False, True]:\n",
    "                img_key_1, img_key_2 = img_pair if not reverse else img_pair[::-1]\n",
    "                img1, img2 = sample[img_key_1], sample[img_key_2]\n",
    "\n",
    "                # Convert images to base64\n",
    "                img1_base64 = pil_to_base64_str(img1)\n",
    "                img2_base64 = pil_to_base64_str(img2)\n",
    "\n",
    "                # Construct the prompt\n",
    "                cond_text = condition_dict[\"rotation\"][condition]\n",
    "                prompt = raw_template.format(conditions=cond_text)\n",
    "\n",
    "                # Build multimodal message\n",
    "                messages = [{\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\"type\": \"image\", \"image\": img1_base64},\n",
    "                        {\"type\": \"image\", \"image\": img2_base64},\n",
    "                        {\"type\": \"text\", \"text\": prompt},\n",
    "                    ]\n",
    "                }]\n",
    "\n",
    "                # Apply chat template and process vision input\n",
    "                text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "                image_inputs, video_inputs = process_vision_info(messages)\n",
    "\n",
    "                inputs = processor(\n",
    "                    text=[text],\n",
    "                    images=image_inputs,\n",
    "                    videos=video_inputs,\n",
    "                    padding=True,\n",
    "                    return_tensors=\"pt\",\n",
    "                ).to(model.device)\n",
    "\n",
    "                with torch.no_grad():\n",
    "                    generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)\n",
    "\n",
    "                generated_ids_trimmed = [\n",
    "                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
    "                ]\n",
    "                response = processor.batch_decode(\n",
    "                    generated_ids_trimmed,\n",
    "                    skip_special_tokens=True,\n",
    "                    clean_up_tokenization_spaces=False,\n",
    "                )[0]\n",
    "                \n",
    "                score, reason = parse_response(response)\n",
    "                keys = sorted([img_key_1, img_key_2])\n",
    "                if condition == 'invariant':\n",
    "                    gt = 10 if keys in [[\"img1\", \"img2\"], [\"img1\", \"img3\"]] else 1\n",
    "                else: # 'variant'\n",
    "                    if keys == [\"img1\", \"img2\"]:\n",
    "                        gt = 10\n",
    "                    elif keys == [\"img1\", \"img3\"]:\n",
    "                        gt = 6\n",
    "                    else: # [\"img1\", \"img4\"]\n",
    "                        gt = 1\n",
    "                # Save result\n",
    "                results.append({\n",
    "                    \"index\": i,\n",
    "                    \"img_keys\": keys,\n",
    "                    \"reversed\": reverse,\n",
    "                    \"template_version\": prompt_version,\n",
    "                    \"condition\": condition,\n",
    "                    \"cond_text\": cond_text,\n",
    "                    \"prompt\": prompt,\n",
    "                    \"gt\": gt,\n",
    "                    \"model_response\": response,\n",
    "                    \"score\": score,\n",
    "                    \"reason\": reason,\n",
    "                })\n",
    "            \n",
    "        \n",
    "    all_results[condition] = results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing metrics...\n",
      "Relaxed Symmetry: 0.9667\n",
      "MMSCORE: 0.5370\n",
      "Controllability: 0.8517\n",
      "Smoothness: 2.4471\n"
     ]
    }
   ],
   "source": [
    "print(\"Computing metrics...\")\n",
    "results_invar = all_results['invariant']\n",
    "results_var = all_results['variant']\n",
    "\n",
    "relaxed_symmetry_invar = compute_relaxed_symmetry(results_invar)\n",
    "relaxed_symmetry_var = compute_relaxed_symmetry(results_var)\n",
    "relaxed_symmetry = compute_relaxed_symmetry(results_var + results_invar)\n",
    "mmscore = compute_mmscore(results_invar + results_var)\n",
    "controllability = compute_controllability(results_invar, results_var)\n",
    "smoothness = compute_smoothness(results_invar + results_var)\n",
    "print(f\"Relaxed Symmetry: {relaxed_symmetry:.4f}\")\n",
    "print(f\"MMSCORE: {mmscore:.4f}\")\n",
    "print(f\"Controllability: {controllability:.4f}\")\n",
    "print(f\"Smoothness: {smoothness:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
