import json
from tqdm import tqdm
import os
import argparse
from main.Ask import get_model_class
from utils.conversation import conv_templates, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
from Models.Qwen_VL_Chat_lora import QwenVLChatLora
from Models.InternVL2_Lora import InternVL2Lora
from Models.mPLUG_Lora import mPLUGLora

def parse_answer(text):
    """Normalise a free-form model reply to ``"yes"``, ``"no"`` or ``"unknown"``.

    Matching is done on whole words, case-insensitively. Substring matching would
    misfire: ``"no"`` occurs inside "snowboard", "cannot", "not" and "know", and
    ``"yes"`` occurs inside "eyes".

    Args:
        text: The raw model output. ``None`` is treated as an empty reply.

    Returns:
        ``"yes"`` if the word "yes" appears and "no" does not, ``"no"`` in the
        mirror case, otherwise ``"unknown"`` (neither word, or both words).
    """
    import re

    # Split into alphabetic tokens so punctuation ("Yes," / "No.") is ignored.
    words = set(re.findall(r"[a-z]+", (text or "").lower()))
    has_yes = "yes" in words
    has_no = "no" in words
    if has_yes and not has_no:
        return "yes"
    if has_no and not has_yes:
        return "no"
    return "unknown"

def Answer(model, args):
    """Run the POPE (COCO, random split) questions through ``model`` and save answers.

    The question file is read as JSON Lines. For each entry, its ``"text"`` is
    wrapped in a single-turn conversation built from
    ``conv_templates[args.ans_conv_mode]`` with the image placeholder token in
    front. The model is then queried on the image named by ``"image"``, and the
    reply is normalised with :func:`parse_answer`.

    One JSON object per question is written to
    ``<args.save_base>/POPE/<model.name>/results.json``. Each object holds
    ``"question"`` and ``"answer"``. It also holds ``"question_id"``,
    ``"image"`` and ``"label"`` when the input entry has them, so that results
    can be matched to ground truth (the same question text repeats across
    images).

    Args:
        model: Model wrapper exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)``; the latter returns a 3-tuple whose first
            element is the reply text.
        args: Parsed CLI namespace; ``save_base`` and ``ans_conv_mode`` are read.
    """
    question_file = "./AHulla/POPE/coco/coco_pope_random.json"
    # NOTE(review): placeholder path — point this at the COCO val2014 image folder.
    image_dir = "path to/coco2014/val2014"

    with open(question_file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        pope_questions = [json.loads(line) for line in f if line.strip()]

    save_dir = os.path.join(args.save_base, "POPE", model.name)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "results.json")

    # Open the output once and append one line per question, flushing so partial
    # results survive a crash. Rewriting the whole file after every question
    # (as before) made saving quadratic in the number of questions.
    with open(save_path, "w", encoding="utf-8") as out:
        for question in tqdm(pope_questions, desc="Processing item"):
            image_path = os.path.join(image_dir, question["image"])
            query = question["text"]

            conv = conv_templates[args.ans_conv_mode].copy()
            conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{query}\n")
            conv.append_message(conv.roles[1], None)
            model.generate_prompt(conv.get_prompt())

            raw_answer, _, _ = model.get_answer(image_path)
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(raw_answer.lower())

            result = {
                "question": query,
                "answer": parse_answer(raw_answer),
            }
            # Keep identifying / ground-truth fields when the input provides them,
            # so results can be scored against the labels later.
            for key in ("question_id", "image", "label"):
                if key in question:
                    result[key] = question[key]

            out.write(json.dumps(result) + "\n")
            out.flush()


def main(args):
    """Instantiate the model wrapper matching ``args.model_path`` and run POPE.

    The path is checked, case-insensitively, for known markers in order, and
    the first hit decides the wrapper:

    * ``qwen-vl-chat-lora`` -> ``QwenVLChatLora``
    * ``work_dirs``         -> ``InternVL2Lora``
    * ``ms-swift``          -> ``mPLUGLora``

    If no marker matches, the class is resolved by ``get_model_class`` from the
    original (non-lowered) path.
    """
    lowered_path = args.model_path.lower()

    # Ordered (marker, wrapper class) pairs; earlier entries take priority.
    lora_wrappers = (
        ('qwen-vl-chat-lora', QwenVLChatLora),
        ('work_dirs', InternVL2Lora),
        ('ms-swift', mPLUGLora),
    )
    model_cls = next(
        (wrapper for marker, wrapper in lora_wrappers if marker in lowered_path),
        None,
    )
    if model_cls is None:
        model_cls = get_model_class(args.model_path)

    Answer(model_cls(args), args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hallucination Evaluation under POPE")
    # required: main() calls args.model_path.lower(), so a missing path used to
    # crash with AttributeError instead of a clear usage error.
    parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--temperature', type=float, default=1.0)
    # Not read by Answer in this file (its sample-limiting line is commented out).
    parser.add_argument('--num_test_samples', type=int, default=1)
    # Known modes: mistral_instruct, chatml_direct, llava_v1, mpt, llava_v0, llava_llama_2, conv_llama_v2_vqa
    parser.add_argument("--ask_conv_mode", default="conv_llama_v2_vqa", help="The ask prompt of model")
    # Conversation template used by Answer to build each POPE prompt (same options as above).
    parser.add_argument("--ans_conv_mode", default="llava_llama_2", help="The answer prompt of model")
    parser.add_argument('--save_base', type=str, default="./result/HALLU", help="Directory path where the results will be saved.")
    args = parser.parse_args()
    main(args)