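# qwen_coco_caption_eval.py
# Generate captions for COCO val2017 images with Qwen2.5-VL and save them as a
# COCO results list: [{"image_id": <id>, "caption": <text>}, ...].
#
# Usage:
#   python qwen_coco_caption_eval.py \
#     --coco_images /path/to/coco/val2017 \
#     --coco_ann /path/to/coco/annotations/captions_val2017.json \
#     --out_json ./pred_captions_val2017_qwen.json
#
# Note: generation settings are not command-line flags. Passing
# --max_new_tokens / --num_beams / --temperature makes argparse exit with
# "unrecognized arguments". max_new_tokens is set in
# QwenVLAdapter.generate_caption; other decoding options use the model's
# generation defaults.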

import argparse
import json
import os
from typing import Optional, List, Dict

from PIL import Image
from tqdm import tqdm

import torch
from transformers import AutoProcessor

# Qwen 2.5 VL import name
from transformers import Qwen2_5_VLForConditionalGeneration
from utils import get_model, get_vllm_output

# ------------------------------
# Adapter
# ------------------------------

class BaseAdapter:
    """Interface for captioning back-ends: image + text prompt -> caption string."""

    def generate_caption(self, image: "Image.Image", prompt: str) -> str:
        """Return a caption for *image* following the instruction *prompt*.

        Subclasses must override this; the base implementation always raises.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError


class QwenVLAdapter(BaseAdapter):
    """
    Qwen2.5-VL adapter (HF transformers).
    Works with: Qwen/Qwen2.5-VL-7B-Instruct (or your fine-tuned local path).

    The base model and processor come from ``get_model('qwen', cache_path='cache')``;
    when ``lora=True`` a PEFT adapter (and the processor saved next to it) is
    loaded on top of that base model.
    """

    # Leading labels some chat models prepend to their answer. They are checked
    # in this order and each one that matches is stripped.
    _CAPTION_PREFIXES = ("Caption:", "caption:", "A caption:", "a caption:")

    def __init__(self, lora=False, lorapath="./Finetuned_qwen/lora-finetuned-best", device="cuda"):
        """
        Args:
            lora: if True, wrap the base model with the PEFT/LoRA weights at
                *lorapath* and use the processor stored in that directory.
            lorapath: directory holding the LoRA adapter weights and processor.
            device: device the (possibly LoRA-wrapped) model is moved to.
        """
        self.model, self.processor = get_model('qwen', cache_path='cache')
        if lora:
            from peft import PeftModel

            self.loraname = lorapath
            self.model = PeftModel.from_pretrained(self.model, self.loraname, torch_dtype=torch.float16)
            self.processor = AutoProcessor.from_pretrained(self.loraname)
            print("✅ LoRA weights loaded")
        # A single move covers both the plain and the LoRA-wrapped model.
        self.model.to(device)

    @torch.inference_mode()
    def generate_caption(self, image: "Image.Image", prompt: str, max_new_tokens: int = 64) -> str:
        """Generate a single-line caption for *image*.

        Args:
            image: the input image (the caller passes an RGB PIL image).
            prompt: instruction text placed after the image in the user turn.
            max_new_tokens: generation budget passed to ``model.generate``
                (default 64, the previously hard-coded value).

        Returns:
            The decoded caption after ``_clean_caption`` post-processing.
        """
        # Build chat content
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"{prompt}"},
            ],
        }]

        text_prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text_prompt], images=[image], return_tensors="pt", padding=True).to(self.model.device)

        output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
        decoded = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return self._clean_caption(decoded)

    @classmethod
    def _clean_caption(cls, text: str) -> str:
        """Normalize raw decoded model output into a single-line caption.

        Keeps only the last non-empty line (some chat models echo or preface
        their answer), strips the labels in ``_CAPTION_PREFIXES``, and trims
        whitespace.
        """
        caption = text.strip()
        if "\n" in caption:
            # Non-empty after strip(), so at least one non-blank line remains.
            caption = [ln.strip() for ln in caption.split("\n") if ln.strip()][-1]
        for prefix in cls._CAPTION_PREFIXES:
            if caption.startswith(prefix):
                caption = caption[len(prefix):].strip()
        # Keep it single-line, short
        return caption.replace("\n", " ").strip()


# Registry of available captioning adapters, keyed by short name.
ADAPTERS = {"qwen": QwenVLAdapter}

# ------------------------------
# COCO utilities
# ------------------------------
def load_coco_val_image_list(coco_ann_path: str) -> List[Dict]:
    """
    Load a COCO caption annotation file and list the images it describes.

    Args:
        coco_ann_path: path to a COCO-format annotation JSON
            (e.g. the official captions_val2017.json).

    Returns:
        One ``{"id": ..., "file_name": ...}`` dict per entry of the file's
        top-level ``"images"`` list, in file order. Other per-image fields are
        dropped.

    Raises:
        KeyError: if the JSON has no ``"images"`` key or an entry lacks
            ``"id"`` / ``"file_name"``.
    """
    # JSON is UTF-8 by spec; be explicit so a non-UTF-8 locale default
    # (e.g. on Windows) cannot break decoding.
    with open(coco_ann_path, "r", encoding="utf-8") as f:
        ann = json.load(f)
    return [{"id": img["id"], "file_name": img["file_name"]} for img in ann["images"]]


# ------------------------------
# Main loop
# ------------------------------
def parse_bool_flag(value) -> bool:
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: for strings that are not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Caption every image listed in --coco_ann and write the results to --out_json.

    The output is a JSON list of ``{"image_id": int, "caption": str}`` records.
    Images listed in the annotation file but missing from --coco_images are
    skipped and counted.
    """
    ap = argparse.ArgumentParser(description="Generate COCO val2017 captions with Qwen (Adapter style).")
    # NOTE(review): --model_id, --cache_dir, --hf_token, --dtype and --device_map_auto
    # are parsed but never forwarded to the adapter, which loads its model via
    # utils.get_model. They are kept so existing command lines still parse.
    ap.add_argument("--model_id", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    ap.add_argument("--coco_images", type=str, required=True, help="Path to COCO val2017 image folder.")
    ap.add_argument("--coco_ann", type=str, required=True, help="Path to captions_val2017.json.")
    ap.add_argument("--out_json", type=str, required=True, help="Output JSON path for predictions.")
    ap.add_argument("--cache_dir", type=str, default=None)
    ap.add_argument("--hf_token", type=str, default=None)
    ap.add_argument("--device", type=str, default=None, help="'cuda', 'cpu', or None (auto).")
    ap.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
    ap.add_argument("--device_map_auto", action="store_true")
    # Accepts bare `--lora` as well as the older `--lora True` / `--lora False`.
    ap.add_argument(
        "--lora",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="Use LoRA weights (only for QwenAdapter)",
    )
    ap.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="LoRA adapter directory (default: QwenVLAdapter's built-in path).",
    )

    # Prompt and control
    ap.add_argument("--prompt", type=str, default="Write a short caption for the given image.")
    ap.add_argument("--limit", type=int, default=0, help="If >0, cap the number of images processed (for quick tests).")
    ap.add_argument("--shuffle", action="store_true", help="Shuffle image order before limiting.")
    ap.add_argument("--seed", type=int, default=None, help="Random seed for --shuffle (default: unseeded).")
    args = ap.parse_args()

    # "None (auto)" per the --device help: use CUDA when available, else CPU.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    adapter_kwargs = {"lora": args.lora, "device": device}
    if args.lora_path:
        adapter_kwargs["lorapath"] = args.lora_path
    adapter = QwenVLAdapter(**adapter_kwargs)

    # Load COCO image list
    images = load_coco_val_image_list(args.coco_ann)

    if args.shuffle:
        import random
        random.Random(args.seed).shuffle(images)

    if args.limit > 0:
        images = images[: args.limit]

    preds = []
    missing = 0

    for rec in tqdm(images, desc="Captioning COCO val2017"):
        img_id = int(rec["id"])
        img_path = os.path.join(args.coco_images, rec["file_name"])
        if not os.path.exists(img_path):
            missing += 1
            continue

        # convert() yields an independent in-memory image, so the file handle
        # can be closed right away instead of leaking one per image.
        with Image.open(img_path) as raw:
            im = raw.convert("RGB")
        caption = adapter.generate_caption(im, prompt=args.prompt)

        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write(f"image:{img_id}  caption:{caption}")
        preds.append({"image_id": img_id, "caption": caption})

    # Save predictions; UTF-8 because ensure_ascii=False may emit non-ASCII text.
    os.makedirs(os.path.dirname(os.path.abspath(args.out_json)), exist_ok=True)
    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(preds, f, ensure_ascii=False)

    print(f"\nSaved {len(preds)} predictions to: {args.out_json}")
    if missing:
        print(f"Skipped {missing} images (file missing).")


if __name__ == "__main__":
    main()
