# eval_pope.py
# Evaluate MiniGPT-4 on POPE-style JSONL dataset

import argparse
import os
import json
from PIL import Image
from tqdm import tqdm

import torch
from transformers import StoppingCriteriaList

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

# Load model components
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

def parse_args():
    """Parse the command-line options of the POPE evaluation script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` with the
    attributes ``cfg_path``, ``jsonl_file``, ``image_root``, ``output_file``,
    ``K``, ``retain_ratio``, ``method``, ``gpu_id`` and ``options``.
    Every option has a default, so the script can be started without
    arguments.
    """
    cli = argparse.ArgumentParser()

    # Where to read the configuration / data from and where to write to.
    cli.add_argument(
        "--cfg-path",
        default="/root/autodl-tmp/MiniGPT-4/eval_configs/minigpt4_eval.yaml",
        help="path to config file",
    )
    cli.add_argument(
        "--jsonl-file",
        default="/root/autodl-tmp/dataset/POPE/MSCOCO_AOKVQA/coco_pope_adversarial.json",
        help="input POPE-style jsonl file",
    )
    cli.add_argument(
        "--image-root",
        default="/root/autodl-tmp/dataset/POPE/MSCOCO_AOKVQA/val2014",
        help="image folder",
    )
    cli.add_argument("--output-file", default="pope_base_vicuna.jsonl")

    # Settings that run() copies onto the language model's config object.
    cli.add_argument("--K", type=int, default=13, help="which layer to prune")
    cli.add_argument(
        "--retain_ratio",
        type=float,
        default=1,
        help="retaining ratio",
    )
    cli.add_argument("--method", type=str, default="our")

    # Runtime settings.
    cli.add_argument("--gpu-id", type=int, default=0)
    cli.add_argument(
        "--options",
        nargs="+",
        default=[],
        help="override some settings in the used config",
    )

    parsed = cli.parse_args()
    return parsed

def load_json_or_jsonl(path):
    """Load a list of records from a JSON-array file or a JSON Lines file.

    The format is detected from the first non-whitespace character of the
    file: ``[`` means the whole file is one JSON document; anything else
    means one JSON value per non-blank line.

    Args:
        path: Path of the file to read. It is decoded as UTF-8; a leading
            byte-order mark, if present, is ignored.

    Returns:
        The parsed JSON document when the file holds a JSON array, otherwise
        a list with one parsed value per non-blank line. An empty or
        whitespace-only file yields ``[]``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON / JSON Lines.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a BOM, which would
    # otherwise defeat both the "[" check and json.loads.
    with open(path, "r", encoding="utf-8-sig") as f:
        text = f.read()

    # Look past leading blank lines / indentation instead of trusting the
    # first physical line, so a pretty-printed array is still recognised.
    if text.lstrip().startswith("["):
        return json.loads(text)

    # Split on "\n" only (not str.splitlines) so that exotic line separators
    # appearing inside JSON strings do not cut a record in half.
    return [json.loads(line) for line in text.split("\n") if line.strip()]


def _apply_llama_config_override(model, name, value):
    """Best-effort: set ``model.llama_model.config.<name>`` to ``value``.

    ``None`` values are skipped. If the attribute cannot be set (for example
    because the model exposes no ``llama_model``), a warning is printed and
    the run continues, so a setting that was not applied is visible in the
    log instead of being silently dropped.
    """
    if value is None:
        return
    try:
        setattr(model.llama_model.config, name, value)
    except Exception as e:  # deliberate best-effort: report, do not abort
        print(f"[WARN] could not set llama_model.config.{name}={value!r}: {e}")


def run():
    """Run the model over every sample of the dataset and write predictions.

    Steps:
      1. Parse the CLI options and build the project ``Config`` from them.
      2. Instantiate the model class named by ``cfg.model_cfg.arch`` on
         ``cuda:<gpu-id>`` and copy ``--K``, ``--retain_ratio`` and
         ``--method`` onto ``model.llama_model.config`` (best-effort).
      3. For each sample, ask the question about the image through ``Chat``
         and take the first element of what ``chat.answer`` returns as the
         prediction.
      4. Write one JSON object per sample to ``--output-file`` with the keys
         ``question_id``, ``image``, ``text``, ``label`` and ``pred``.

    A sample whose image cannot be loaded, or whose inference raises, is
    recorded with ``pred == "ERROR"`` and the loop continues. Each record is
    written and flushed as soon as it is produced, so an interrupted run
    keeps the predictions made so far. The output file is therefore opened
    (and truncated) before the first sample is processed.

    Raises:
        KeyError: If ``model_type`` in the config has no conversation
            template, or a sample lacks the ``image`` or ``text`` key.
    """
    args = parse_args()
    cfg = Config(args)

    conv_dict = {
        'pretrain_vicuna0': CONV_VISION_Vicuna0,
        'pretrain_llama2': CONV_VISION_LLama2,
    }

    model_config = cfg.model_cfg
    # Resolved before the model is loaded so that an unsupported model_type
    # fails immediately rather than after the expensive load.
    conv_template = conv_dict[model_config.model_type]

    device = f'cuda:{args.gpu_id}'
    model_config.device_8bit = args.gpu_id
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to(device)

    # Forward the CLI settings to the language model's config object.
    _apply_llama_config_override(model, "K", args.K)
    _apply_llama_config_override(model, "retain_ratio", args.retain_ratio)
    _apply_llama_config_override(model, "method", args.method)

    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

    # Token-id sequences on which generation stops.
    # NOTE(review): presumably the tokenizer's encodings of "###" — confirm
    # against the tokenizer in use.
    stop_words_ids = [[835], [2277, 29937]]
    stop_words_ids = [torch.tensor(ids).to(device) for ids in stop_words_ids]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    chat = Chat(model, vis_processor, device=device, stopping_criteria=stopping_criteria)

    samples = load_json_or_jsonl(args.jsonl_file)

    with open(args.output_file, 'w', encoding='utf-8') as out:
        for sample in tqdm(samples):
            image_path = os.path.join(args.image_root, sample['image'])
            question = sample['text']
            label = sample.get('label', None)
            qid = sample.get('question_id', None)

            # Fresh conversation state for every sample.
            chat_state = conv_template.copy()
            img_list = []
            question += " Answer the question using a single word or phrase."
            try:
                # Loading is inside the try so that one unreadable image
                # yields an "ERROR" record instead of aborting the run; the
                # with-block closes the underlying file once the RGB copy
                # has been made.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert("RGB")
                chat.upload_img(image, chat_state, img_list)
                chat.encode_img(img_list)
                chat.ask(question, chat_state)
                answer = chat.answer(
                    conv=chat_state,
                    img_list=img_list,
                    num_beams=1,
                    temperature=0.1,
                    max_new_tokens=1,
                    max_length=2000
                )[0]
            except Exception as e:
                answer = "ERROR"
                print(f"[ERROR] {qid}: {e}")

            record = {
                "question_id": qid,
                "image": sample["image"],
                "text": question,
                "label": label,
                "pred": answer
            }
            # Stream the record so partial progress survives a crash.
            out.write(json.dumps(record, ensure_ascii=False) + "\n")
            out.flush()

# Script entry point: start the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run()
