
import argparse
from tqdm import tqdm
import json
from Models.Qwen_VL_Chat_lora import QwenVLChatLora
from Models.InternVL2_Lora import InternVL2Lora
from Models.mPLUG_Lora import mPLUGLora
from Models.ChatGPT_4o import GPT4o
from main.Ask import get_model_class
import os
from pathlib import Path
import pandas as pd
from utils.conversation import conv_templates, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN

def load_data(data_path, image_root="./datasets/coco2014/val2014"):
    """Build a table of image ids and image paths from an annotation JSON file.

    The file is expected to be a JSON object with an ``"images"`` list whose
    entries carry ``"id"`` and ``"file_name"``. Entries missing either field
    are skipped. No check is made that the image files actually exist.

    Args:
        data_path: Path to the annotation JSON file.
        image_root: Directory prefix joined (with ``/``) to every
            ``file_name``. The default reproduces the previously hard-coded
            location.

    Returns:
        A ``pandas.DataFrame`` with columns ``image_id`` (int) and
        ``image_path`` (str), in the order the entries appear in the file.

    Raises:
        ValueError: If the file has no non-empty ``"images"`` list.
        RuntimeError: If no entry has both an id and a file name.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        anno = json.load(f)

    # A top-level list/scalar has no "images" key; report it the same way as
    # a missing key instead of failing on ``.get``.
    images = anno.get("images", []) if isinstance(anno, dict) else []
    if not images:
        raise ValueError(f"{data_path} no 'images'")

    # Normalise so that "dir" and "dir/" yield the same paths; an empty root
    # leaves the file name untouched.
    prefix = f"{str(image_root).rstrip('/')}/" if image_root else ""

    rows = [
        {"image_id": int(im["id"]), "image_path": f"{prefix}{im['file_name']}"}
        for im in images
        if im.get("id") is not None and im.get("file_name")
    ]

    if not rows:
        raise RuntimeError("no local image file")

    return pd.DataFrame(rows)

def _has_caption(record):
    """Return True when a saved record carries a non-blank caption."""
    return bool(str(record.get("caption") or "").strip())


def _load_caption_results(out_path):
    """Return the caption records stored at ``out_path`` (``[]`` if absent)."""
    if not out_path.exists():
        return []
    with open(out_path, "r", encoding="utf-8") as f:
        results = json.load(f)
    print(f"Loaded existing results with {len(results)} entries")
    return results


def _save_caption_results(results, out_path):
    """Write ``results`` as JSON, replacing ``out_path`` only after a full write."""
    # Dump to a sibling temp file first so an interrupted run cannot leave a
    # truncated checkpoint that would break the next resume.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    tmp_path.replace(out_path)


def ImageCaption(model, args):
    """Caption the first 4000 images of the dataset and checkpoint the results.

    Results are stored as a JSON list of ``{"image_id", "caption"}`` records in
    ``<args.save_base>/<model.name>.json``. If that file already exists, images
    that already have a non-blank caption are skipped, and records with a
    blank caption are regenerated in place. The file is rewritten after every
    newly captioned image so the run can be resumed.

    Args:
        model: Object exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)``.
        args: Namespace providing ``data_path``, ``save_base`` and
            ``conv_mode``.

    Returns:
        ``0`` on completion.
    """
    df = load_data(args.data_path)
    df = df[:4000]
    save_dir = Path(args.save_base)
    save_dir.mkdir(parents=True, exist_ok=True)
    out_path = save_dir / f"{model.name}.json"

    results = _load_caption_results(out_path)

    # Index EVERY saved record by image id (not only the captioned ones) so a
    # record with a blank caption is updated in place instead of being
    # duplicated. The dict values are the same objects held in ``results``.
    # If an id occurs more than once, prefer a record that has a caption.
    records_by_id = {}
    for item in results:
        saved_id = item.get("image_id")
        if saved_id is None:
            continue
        current = records_by_id.get(saved_id)
        if current is None or not _has_caption(current):
            records_by_id[saved_id] = item

    conv = conv_templates[args.conv_mode].copy()

    instruct = "please help me describe the image in detail."
    if 'InternVL' in model.name:
        instruct = "Please describe this image in as much detail as possible."
    full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{instruct}\n"

    conv.append_message(conv.roles[0], full_prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    model.generate_prompt(prompt)

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating"):
        image_id = int(row["image_id"])
        image_path = row["image_path"]

        record = records_by_id.get(image_id)
        if record is not None and _has_caption(record):
            continue

        # The prompt is re-applied before every image, as in the original
        # flow. NOTE(review): whether the models need this per-call reset is
        # not visible from this file.
        model.generate_prompt(prompt)
        out = model.get_answer(image_path)
        is_sequence = isinstance(out, (list, tuple))
        if out is None or (is_sequence and not out):
            print("Error!! No output")
            continue

        if is_sequence:
            caption = str(out[0])
            print(caption)
        else:
            caption = str(out)

        if record is not None:
            record["caption"] = caption
        else:
            record = {"image_id": image_id, "caption": caption}
            results.append(record)
            records_by_id[image_id] = record

        # Checkpoint after every image so progress survives a crash.
        _save_caption_results(results, out_path)

    print(f"[OK] Saved {len(results)} captions -> {out_path}")
    return 0


def main(args):
    """Instantiate the model selected by ``args.model_path`` and run captioning.

    Selection rules, checked in order:
      * no model path given        -> ``GPT4o``
      * contains 'qwen-vl-chat-lora' -> ``QwenVLChatLora``
      * contains 'work_dirs'       -> ``InternVL2Lora``
      * contains 'ms-swift'        -> ``mPLUGLora``
      * anything else              -> class returned by ``get_model_class``

    Substring matches are case-insensitive.

    Args:
        args: Parsed command-line namespace; passed unchanged to the model
            constructor and to ``ImageCaption``.
    """
    model_path = args.model_path

    # The None check must come first: calling ``.lower()`` on a missing path
    # would raise before the GPT-4o fallback could ever be selected.
    if model_path is None:
        model = GPT4o(args)
    else:
        lowered = model_path.lower()
        if 'qwen-vl-chat-lora' in lowered:
            model = QwenVLChatLora(args)
        elif 'work_dirs' in lowered:
            model = InternVL2Lora(args)
        elif 'ms-swift' in lowered:
            model = mPLUGLora(args)
        else:
            ModelClass = get_model_class(model_path)
            model = ModelClass(args)

    ImageCaption(model, args)
    

def build_parser():
    """Create the command-line parser for the CHAIR captioning run.

    ``--data_path`` is required because the dataset is always read from it.
    ``--model_path`` may be omitted, in which case it parses as ``None``.
    ``--model_base`` is accepted as an alias of ``--model-base`` so all
    options can be written in the same underscore style.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Hallucination Evaluation under CHAIR")
    parser.add_argument('--model_path', type=str, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", "--model_base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--dataset_name', type=str, default="CoCo", help="Dataset Name, serve as savedir")
    parser.add_argument('--split', type=str, default="val", help="The Part of Dataset")
    parser.add_argument("--conv_mode", default="llava_llama_2", help="The system prompt of model")
    parser.add_argument('--max_new_tokens', type=int, default=128)
    parser.add_argument('--save_base', type=str, default="./result/HALLU", help="Directory path where the results will be saved.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(args)
