#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, json, os, sys, re
from pathlib import Path
from typing import Any, Dict, List, Optional
from datasets import load_dataset
from PIL import Image

import re
import torch
from PIL import Image
from utils import get_model
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# ------------------------------
# Adapters (minimal)
# ------------------------------
def _dtype_from_str(dtype: Optional[str]):
    """Translate a precision name into the corresponding ``torch`` dtype.

    Args:
        dtype: Precision alias such as ``"fp16"``, ``"bf16"`` or ``"fp32"``
            (aliases are matched case-insensitively). A falsy value or the
            exact string ``"auto"`` means "no explicit dtype".

    Returns:
        The matching ``torch.dtype``, or ``None`` when no explicit dtype was
        requested.

    Raises:
        ValueError: If *dtype* is not one of the recognised aliases.
    """
    if dtype == "auto" or not dtype:
        return None

    # Alias -> dtype table; every alias is stored in lower case.
    lookup = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "half": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
        "full": torch.float32,
    }
    key = dtype.lower()
    if key not in lookup:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return lookup[key]

class BaseVLMAdapter:
    """Minimal interface for vision-language model adapters.

    Subclasses override :meth:`answer`; this base implementation only raises.
    """
    def answer(self, image: Image.Image, question: str) -> str:
        """Return the model's answer to *question* about *image*.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class QwenVLAdapter(BaseVLMAdapter):
    """
    Qwen2.5-VL adapter (HF transformers).
    Works with: Qwen/Qwen2.5-VL-7B-Instruct (or your fine-tuned local path).

    The base model and processor are obtained from the project helper
    ``get_model('qwen', cache_path='cache')``; a LoRA adapter can optionally
    be applied on top of the base model via PEFT.
    """

    def __init__(self, lora = False,lorapath="./Finetuned_qwen/lora-finetuned-best",device="cuda"):#"./Finetuned_qwen/qwen-lora-finetuned-manual3"
        """Load the model and processor, then move the model to *device*.

        Args:
            lora: When truthy, wrap the base model with the LoRA weights at
                *lorapath* and reload the processor from the same location.
            lorapath: Path or identifier handed to
                ``PeftModel.from_pretrained`` / ``AutoProcessor.from_pretrained``.
            device: Target passed to ``model.to``.
        """
        self.model, self.processor = get_model('qwen', cache_path='cache')
        # Always defined, so the attribute can be inspected whether or not
        # LoRA weights were requested.
        self.loraname = lorapath if lora else None
        if lora:
            # Imported lazily: these are only needed on the LoRA path.
            from peft import PeftModel
            from transformers import AutoProcessor
            self.model = PeftModel.from_pretrained(self.model, self.loraname, torch_dtype=torch.float16)
            self.processor = AutoProcessor.from_pretrained(self.loraname)
            print("✅ LoRA weights loaded")
        # One device move covers both the plain and the LoRA-wrapped model.
        self.model.to(device)

    @torch.inference_mode()
    def answer(self, image: Image.Image, question: str) -> str:
        """Generate a short answer to *question* about *image*.

        Returns:
            The decoded continuation only (the prompt tokens are sliced off),
            with surrounding whitespace stripped.
        """
        # Build chat content
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"{question} Answer with a single word or short phrase. No punctuation"},
            ],
        }]

        text_prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text_prompt], images=[image], return_tensors="pt", padding=True).to(self.model.device)

        sequences = self.model.generate(
            **inputs,
            max_new_tokens=128,
        )
        # Each returned sequence is sliced at the prompt length so that only
        # the newly generated tokens are decoded.
        completions = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, sequences)
        ]
        decoded = self.processor.batch_decode(
            completions, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return decoded[0].strip()


# Registry of adapter classes keyed by the value accepted for --model.
ADAPTERS = {"qwen": QwenVLAdapter}



# ------------------------------
# I/O helpers
# ------------------------------
def load_vqav2(questions_json: Path, annotations_json: Path) -> Dict[int, Dict[str, Any]]:
    """Load VQAv2 question and annotation files and join them by question id.

    Args:
        questions_json: JSON file with a top-level ``"questions"`` list whose
            entries carry ``question_id``, ``image_id`` and ``question``.
        annotations_json: JSON file with a top-level ``"annotations"`` list
            whose entries carry ``question_id`` and an ``answers`` list of
            ``{"answer": ...}`` dicts.

    Returns:
        A dict keyed by question id. Each value holds ``image_id``,
        ``question`` and ``answers`` (the list of answer strings, or ``[]``
        when the question has no annotation).
    """
    # Context managers close the handles deterministically; JSON is read as UTF-8.
    with open(questions_json, "r", encoding="utf-8") as f:
        questions = json.load(f)["questions"]
    with open(annotations_json, "r", encoding="utf-8") as f:
        annotations = json.load(f)["annotations"]

    answers_by_qid = {
        a["question_id"]: [x["answer"] for x in a["answers"]]
        for a in annotations
    }
    return {
        q["question_id"]: {
            "image_id": q["image_id"],
            "question": q["question"],
            "answers": answers_by_qid.get(q["question_id"], []),
        }
        for q in questions
    }

def image_path_coco(images_root: Path, image_id: int, split: str) -> Path:
    """Build the path of an image file under *images_root*.

    The file name is *image_id* zero-padded to 12 digits followed by ``.jpg``.

    NOTE(review): *split* is accepted but not used, so no
    ``COCO_<split>2014_`` prefix is added to the file name — confirm this
    matches the on-disk layout of the image directory.
    """
    file_name = f"{image_id:012d}.jpg"
    return images_root.joinpath(file_name)



# Normalisation tables modelled on the official VQAv2 evaluation script.
# NOTE(review): they may differ from the official tables in details — verify
# before comparing scores against published numbers.

# Apostrophe-less spelling -> canonical contraction, applied per token.
# NOTE(review): "were" is also an ordinary word; the mapping is applied the
# same way to predictions and references.
_CONTRACTIONS = {
    "aint":"ain't","arent":"aren't","cant":"can't","couldve":"could've","couldnt":"couldn't",
    "didnt":"didn't","doesnt":"doesn't","dont":"don't","hadnt":"hadn't","hasnt":"hasn't",
    "havent":"haven't","hed":"he'd","hes":"he's","howd":"how'd","hows":"how's","id":"i'd",
    "im":"i'm","ive":"i've","isnt":"isn't","itd":"it'd","itll":"it'll","lets":"let's",
    "mightnt":"mightn't","mightve":"might've","mustnt":"mustn't","mustve":"must've",
    "neednt":"needn't","shant":"shan't","shed":"she'd","shes":"she's","shouldve":"should've",
    "shouldnt":"shouldn't","somebodys":"somebody's","someones":"someone's","thatll":"that'll",
    "thats":"that's","theres":"there's","theyd":"they'd","theyre":"they're","theyve":"they've",
    "wasnt":"wasn't","wed":"we'd","were":"we're","weve":"we've","werent":"weren't",
    "whatll":"what'll","whatre":"what're","whats":"what's","whos":"who's","wont":"won't",
    "wouldve":"would've","wouldnt":"wouldn't","yall":"y'all","youd":"you'd","youre":"you're","youve":"you've"
}
# Tokens dropped entirely.
_ARTICLES = {"a","an","the"}
# Number words rewritten as digits.
_NUM_MAP = {"zero":"0","one":"1","two":"2","three":"3","four":"4","five":"5",
            "six":"6","seven":"7","eight":"8","nine":"9","ten":"10"}
# Deletes every period and every comma that is not followed by a digit.
# Equivalent to the previous, longer pattern, whose leading lookahead tested
# the period itself and therefore never excluded anything. Decimal points are
# removed as well ("1.5" -> "15").
_PERIOD_STRIP = re.compile(r"\.|,(?!\d)")
# Thousands separators: "1,000" -> "1000".
_COMMA_FIX = re.compile(r"(\d)(,)(\d)")
# Punctuation replaced by a space. The apostrophe is deliberately absent so
# that contractions such as "don't" survive; stray apostrophes are handled by
# _STRAY_APOSTROPHE below.
_PUNCT = re.compile(r"([!\"#$%&()*+/:;<=>?@\[\\\]^_`{|}~])")
# Apostrophes that are not enclosed by letters/digits on both sides
# (quotes around a word, trailing possessive marks, ...).
_STRAY_APOSTROPHE = re.compile(r"(?<![a-z0-9])'|'(?![a-z0-9])")

def vqa_normalize(s: str) -> str:
    """Normalise an answer string for VQA-style comparison.

    Steps, in order: strip and lower-case; drop commas between digits; delete
    periods and commas not followed by a digit; replace punctuation and stray
    apostrophes with spaces; then, per whitespace token, map number words to
    digits, drop articles and canonicalise apostrophe-less contractions.

    Args:
        s: Raw answer text.

    Returns:
        The normalised tokens joined by single spaces (possibly ``""``).
    """
    s = s.strip().lower()
    s = _COMMA_FIX.sub(r"\1\3", s)
    # Applied unconditionally: gating this on the presence of "." left commas
    # in place whenever the answer had no period.
    s = _PERIOD_STRIP.sub("", s)
    # Word-internal apostrophes are kept so "don't" and "dont" normalise to
    # the same token instead of "don t" vs "don't".
    s = _STRAY_APOSTROPHE.sub(" ", s)
    s = _PUNCT.sub(" ", s)
    words = []
    for w in s.split():
        w = _NUM_MAP.get(w, w)
        if w in _ARTICLES:
            continue
        words.append(_CONTRACTIONS.get(w, w))
    return " ".join(words)

def vqa_accuracy(pred: str, human_answers):
    """Score one prediction against its reference answer(s).

    Args:
        pred: Raw model output.
        human_answers: Either a single reference string, or an iterable of
            annotator answer strings.

    Returns:
        For a single string: ``1.0`` if prediction and reference are equal
        after ``vqa_normalize``, else ``0.0``.
        For an iterable: ``min(1.0, matches / 3)``, where ``matches`` counts
        the references equal to the prediction after normalisation (``0.0``
        for an empty iterable).
    """
    p = vqa_normalize(pred)
    if isinstance(human_answers, str):
        # Single reference: plain exact match on the normalised strings.
        return 1.0 if p == vqa_normalize(human_answers) else 0.0
    # Several references: credit grows with annotator agreement, capped at 1.
    matches = sum(1 for answer in human_answers if vqa_normalize(answer) == p)
    return min(1.0, matches / 3.0)


# ------------------------------
# Main
# ------------------------------
def _parse_bool(value):
    """argparse ``type=`` callback that understands textual booleans."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """CLI entry point: answer a VQAv2 sample with the chosen adapter and score it.

    Loads the ``validation`` split of ``merve/vqav2-small``, draws a
    seed-42 shuffled sample of at most 1000 examples (further capped by
    ``--limit`` when positive), asks the adapter every question and compares
    each answer with the example's ``multiple_choice_answer``.

    Writes ``preds.json`` (one record per example) and ``metrics.json``
    (``n`` and ``vqa_accuracy`` in percent) into ``--out_dir``.
    """
    ap = argparse.ArgumentParser(description="VQAv2 evaluate (predict + score)")
    # NOTE: --images_root, --questions_json, --annotations_json, --coco_split,
    # --hf_model_id and --dtype are parsed (and partly required) but are not
    # read anywhere in this function.
    ap.add_argument("--images_root", type=Path, required=True, help="COCO images root for the given split (train2014 or val2014)")
    ap.add_argument("--questions_json", type=Path, required=True, help="v2_OpenEnded_mscoco_[split]_questions.json")
    ap.add_argument("--annotations_json", type=Path, required=True, help="v2_mscoco_[split]_annotations.json")
    ap.add_argument("--coco_split", type=str, default="val", choices=["train","val"], help="train or val (affects filename prefix)")
    ap.add_argument("--limit", type=int, default=-1, help="debug: limit #questions")
    ap.add_argument("--model", type=str, required=True, choices=list(ADAPTERS.keys()))
    ap.add_argument("--hf_model_id", type=str, required=True)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--dtype", type=str, default="auto")
    ap.add_argument("--out_dir", type=Path, required=True)
    # type=bool would turn any non-empty string (including "False") into True;
    # parse the text instead. A bare `--lora` also enables it.
    ap.add_argument("--lora", type=_parse_bool, nargs="?", const=True, default=False,
                    help="Use LoRA weights (only for QwenAdapter)")
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)
    preds_path = args.out_dir / "preds.json"
    metrics_path = args.out_dir / "metrics.json"

    # adapter
    Adapter = ADAPTERS[args.model]
    adapter = Adapter(lora=args.lora, device=args.device)

    ds = load_dataset("merve/vqav2-small", split="validation")
    # Never ask for more rows than the split has; --limit tightens the cap.
    n_eval = min(1000, len(ds))
    if args.limit > 0:
        n_eval = min(n_eval, args.limit)
    subset = ds.shuffle(seed=42).select(range(n_eval))
    print(f"[VQAv2] n={len(subset)}")

    records = []
    acc_sum = 0.0
    # Iterate the sampled subset (not the full split) so the evaluated
    # examples match the size announced above.
    for i, ex in enumerate(subset):
        if i % 100 == 0:
            print(f"Processing {i} / {len(subset)} ...")
        img = ex["image"].convert("RGB")   # already PIL
        q = ex["question"]
        gt = ex["multiple_choice_answer"]

        pred = adapter.answer(img, q)
        # Single reference string: scored as a match after normalisation.
        acc = vqa_accuracy(pred, gt)
        acc_sum += acc
        records.append({
            "index": i,
            "question": q,
            "prediction": pred,
            "answer": gt,
            "accuracy": acc,
        })

    total = len(records)
    with open(preds_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)

    overall = {"n": total, "vqa_accuracy": 100.0 * (acc_sum / max(1,total))}
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(overall, f, indent=2)
    print(f"[VQAv2] Accuracy: {overall['vqa_accuracy']:.2f}%  (n={overall['n']})")
    print(f"[VQAv2] wrote {preds_path}")
    print(f"[VQAv2] wrote {metrics_path}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
