from typing import Dict, Callable
from tqdm import tqdm
import os

_VQA_PRE_PROMPT = f"Answer the question about the image only with 'yes' or 'no'. Do not give other outputs or punctuation marks. Question: "

def _parse_parent_ids(parent_ids) -> list:
    """Normalize one dependency entry into a list of integer parent question ids.

    Accepts a comma-separated string such as ``"1, 2"`` (empty tokens are
    ignored), a single int, or any iterable of ids.
    """
    if isinstance(parent_ids, str):
        # int() tolerates surrounding whitespace; skip empty tokens like in "1,,2".
        return [int(token) for token in parent_ids.split(",") if token.strip()]
    if isinstance(parent_ids, int):
        return [parent_ids]
    return list(parent_ids)


def _print_category_summary(id2score: Dict[str, float]) -> None:
    """Print item count and mean score per category and over all items.

    The category of an item is the part of its id before the first ``_``.
    """
    if not id2score:
        return
    category_count = {}
    category_sum = {}
    for item, score in id2score.items():
        category_id = item.split("_")[0]
        category_count[category_id] = category_count.get(category_id, 0) + 1
        category_sum[category_id] = category_sum.get(category_id, 0.0) + score
    category_count["all"] = len(id2score)
    category_sum["all"] = sum(id2score.values())

    for category_id in category_count:
        print(f"Dataset category: {category_id}, items num: {category_count[category_id]}")
        print(f"Avg. score: {category_sum[category_id] / category_count[category_id]}")


def vqa(
    vqa_model: Callable,
    dsg_id2question: Dict[str, Dict[int, str]],
    dsg_id2dependency: Dict[str, Dict[int, str]],
    img_root: str,
    *,
    verbose: bool = False,
):
    """Answer every DSG question for every image with a yes/no VQA model.

    For each item id in ``dsg_id2question`` the image ``<img_root>/<item>.png``
    is queried once per question with ``_VQA_PRE_PROMPT + question``. A reply
    whose first non-blank character (case-insensitive) is ``Y`` counts as
    "yes"; anything else counts as "no" (replies starting with neither Y nor N
    are reported on stdout).

    Args:
        vqa_model: Callable ``(prompt, img_path) -> str``.
        dsg_id2question: item id -> {question id -> question text}.
        dsg_id2dependency: item id -> {question id -> parent ids}. Parent ids
            may be a comma-separated string, an int, or an iterable of ints;
            parent id ``0`` is ignored. Must contain every item id.
        img_root: Directory holding the ``<item>.png`` images.
        verbose: If True, print per-category item counts and mean scores
            (categories are the item-id prefix before the first ``_``).

    Returns:
        ``(id2answer, id2valid)`` where ``id2answer[item][qid]`` is ``"yes"``
        or ``"no"`` and ``id2valid[item][qid]`` is 1, or 0 when any non-zero
        parent question was answered "no". ``id2valid`` only contains qids
        listed in ``dsg_id2dependency[item]``.

    Raises:
        KeyError: If an item has no entry in ``dsg_id2dependency``.
    """
    id2answer = {}
    id2score = {}
    id2valid = {}

    for item in tqdm(dsg_id2question.keys()):
        img_path = os.path.join(img_root, item + ".png")

        ##### VQA #####
        qid2answer = {}  # "yes" or "no"
        qid2score = {}  # 1 for "yes", 0 for "no"; only used for the summary
        for qid, question in dsg_id2question[item].items():
            output = vqa_model(_VQA_PRE_PROMPT + question, img_path)
            # Strip first so leading whitespace/newlines don't hide a valid reply.
            answer = output.strip().upper()

            if not answer.startswith(("Y", "N")):
                print(f"Invalid answer {answer} on {item}. ")

            # Anything not starting with "Y" (including invalid replies) is "no".
            answer = "yes" if answer.startswith("Y") else "no"
            qid2answer[qid] = answer
            qid2score[qid] = int(answer == "yes")

        ##### check dependency #####
        qid2valid = {}
        for qid, parent_ids in dsg_id2dependency[item].items():
            # Parent 0 means "no parent"; parents without an answer are ignored.
            any_parent_answered_no = any(
                qid2answer.get(parent_id) == "no"
                for parent_id in _parse_parent_ids(parent_ids)
                if parent_id != 0
            )
            if any_parent_answered_no:
                # Only zero out questions that were actually asked, so a stray
                # dependency entry does not add a phantom score to the average.
                if qid in qid2score:
                    qid2score[qid] = 0
                qid2valid[qid] = 0
            else:
                qid2valid[qid] = 1

        id2answer[item] = qid2answer
        # Guard against items with no questions (previously ZeroDivisionError).
        id2score[item] = sum(qid2score.values()) / len(qid2score) if qid2score else 0.0
        id2valid[item] = qid2valid

    ##### output info #####
    if verbose:
        _print_category_summary(id2score)

    return id2answer, id2valid

