import os
import json
from tqdm import tqdm

from src.models import get_model
from src.utils.data_utils import get_dataset
from src.utils.parser_utils import get_parser
from src.prompt import qa_prompt, qa_context_prompt, qa_image_prompt, qa_blend_prompt

def main():
    """Query the configured model on every dataset record and log the answers.

    Command-line arguments come from ``get_parser()``. The prompt template,
    the model mode and the output file name depend on two of them:

    * ``args.is_fact_given`` false: ``qa_prompt``, mode ``"text"``, file
      ``<dataset>_<model_name>_T<temperature>.txt``.
    * ``args.is_fact_given`` true and ``"textual"`` in ``args.dataset``:
      ``qa_context_prompt``, mode ``"text"``, file
      ``<dataset>_fact_<model_name>_T<temperature>.txt``.
    * ``args.is_fact_given`` true otherwise: ``qa_image_prompt``, mode
      ``"visual"``, same ``_fact`` file name.

    Every record must provide ``"id"`` and ``"question"``; when facts are
    given it must also provide ``"fact"``, a mapping with a ``"text"`` entry
    whose items are forwarded to ``model.chat`` as keyword arguments.

    Each answer is appended to the output file as one JSON object per line,
    ``{<id>: <answer>}``. Existing content of the file is kept.
    """
    parser = get_parser()
    args = parser.parse_args()
    if args.greedy:
        # The greedy flag overrides whatever temperature was passed.
        args.temperature = 0.0

    dataset = get_dataset(args)

    if args.is_fact_given:
        if "textual" in args.dataset:
            prompt, mode = qa_context_prompt, "text"
        else:
            prompt, mode = qa_image_prompt, "visual"
        run_tag = f"{args.dataset}_fact"
    else:
        prompt, mode = qa_prompt, "text"
        run_tag = args.dataset
    output_path = os.path.join(
        args.output_dir,
        f"{run_tag}_{args.model_name}_T{args.temperature}.txt",
    )

    # Create the target directory up front so a missing folder is not
    # discovered only after the first (possibly expensive) model call.
    output_parent = os.path.dirname(output_path)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    model = get_model(args)(args, prompt)
    if model.is_local:
        model.remode(mode)

    # Let tqdm drive the iteration so the bar advances with the loop and is
    # closed when the loop finishes.
    for data in tqdm(dataset, total=len(dataset)):
        data_id = data["id"]
        question = data["question"]
        if args.is_fact_given:
            fact = data["fact"]
            text_context = fact["text"]
            # Work on a shallow copy: writing the combined prompt back into
            # the record would alter the dataset for any later reader.
            chat_kwargs = dict(fact)
            chat_kwargs["text"] = (
                f"Question: {question}\n\nContext Information:\n{text_context}"
            )
        else:
            chat_kwargs = {"text": f"Question: {question}"}
        answer = model.chat(**chat_kwargs)

        # Append and close per record so every finished answer is already on
        # disk if a later record fails.
        with open(output_path, "a", encoding="utf-8") as fout:
            fout.write(f"{json.dumps({data_id: answer})}\n")

# Run the QA generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()            