import os
import json
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from src.models import get_model
from src.utils.parser_utils import get_parser
from src.prompt import qa_prompt, qa_context_prompt, qa_image_prompt, qa_blend_prompt

# Directory holding the validation images, stored as "<image_id>.jpg".
_IMAGE_DIR = "data/infoseek/infoseek_val_images"
# Target size when downscaling real images for the visual/recognize modes.
_RESIZED_IMAGE_SIZE = (224, 224)
# Size of the all-white placeholder image used by the blank/pad modes.
_PLACEHOLDER_IMAGE_SIZE = (336, 336)
# Fixed instruction sent instead of the question in "recognize" mode.
_RECOGNIZE_PROMPT = (
    "What/Who is in the image? Do not describe details. "
    "Just give a named entity, e.g. Jackie Chan, Mount Everest."
)


def _dataset_path(dataset_name, model_name):
    """Return the JSON file to evaluate on, chosen from the dataset name.

    Names containing "mc" select a multiple-choice split (the per-model
    "cleaned" variant when the name also contains "cleaned"); anything else
    selects the entity-annotated validation set.
    """
    if "mc" in dataset_name:
        if "cleaned" in dataset_name:
            # NOTE(review): the full model name (not the nickname) is used
            # here, so a name containing "/" resolves to a sub-directory.
            return f"data/infoseek/{model_name}_recognized_infoseek_val_mc.json"
        return "data/infoseek/sampled_val_mc.json"
    return "data/infoseek/infoseek_val_with_entity.json"


def _resolve_mode(dataset_name):
    """Map the dataset name onto "textual", "recognize" or "visual"."""
    if "textual" in dataset_name:
        return "textual"
    if "recognize" in dataset_name:
        return "recognize"
    return "visual"


def _build_question_text(data, is_mc, is_textual):
    """Format one example's question (plus options / entity caption) as text.

    Args:
        data: One dataset record; must provide "question", and additionally
            "multiple_choices" when ``is_mc`` and "entity" when ``is_textual``.
        is_mc: Append the answer options after the question.
        is_textual: Prepend a caption naming the entity in place of the image.
    """
    question = data["question"]
    if is_mc:
        choices_text = "".join(
            f"{c_name}: {c_content}\n"
            for c_name, c_content in data["multiple_choices"].items()
        )
        text = f"Question:\n{question}\nOption:\n{choices_text}"
    else:
        text = f"Question:\n{question}"
    if is_textual:
        caption = f"This is an image of {data['entity']}."
        text = caption + "\n" + text
    return text


def _build_context(mode, text, image_path):
    """Build the keyword arguments for ``model.chat`` for the given mode.

    Only "visual", "textual" and "recognize" are produced by
    ``_resolve_mode``; "blend", "blank" and "pad" are kept so they can be
    selected manually for experiments.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    if mode == "textual":
        return {"text": text}
    if mode in ("visual", "recognize"):
        # resize() returns a new, fully loaded image, so the file handle can
        # be released as soon as the block exits.
        with Image.open(image_path) as img:
            image = img.resize(_RESIZED_IMAGE_SIZE)
        prompt = text if mode == "visual" else _RECOGNIZE_PROMPT
        return {"text": prompt, "image": image}
    if mode == "blend":
        # Image.open is lazy: force the pixel data in before closing the file.
        with Image.open(image_path) as img:
            img.load()
        return {"text": text, "image": img}
    if mode in ("blank", "pad"):
        placeholder = Image.new("RGB", _PLACEHOLDER_IMAGE_SIZE, color=(255, 255, 255))
        return {"text": text, "image": placeholder}
    raise ValueError(f"Unsupported mode: {mode!r}")


def main():
    """Run a model over the InfoSeek validation data and log its answers.

    Parses the shared command-line arguments plus ``--is_scored``, loads the
    dataset selected by ``args.dataset``, queries the model once per example
    whose image file exists, and appends one JSON line ``{data_id: answer}``
    per example to ``<output_dir>/<model_nickname>/<dataset>_T<temperature>.txt``
    (with a ``.score`` suffix when ``--is_scored`` is set).
    """
    parser = get_parser()
    parser.add_argument("--is_scored", action="store_true")
    args = parser.parse_args()

    if args.greedy:
        args.temperature = 0.0

    dataset_file = _dataset_path(args.dataset, args.model_name)
    with open(dataset_file, "r", encoding="utf-8") as fin:
        dataset = json.load(fin)

    # These depend only on the dataset name, so compute them once.
    mode = _resolve_mode(args.dataset)
    is_mc = "mc" in args.dataset
    is_textual = mode == "textual"

    model_nickname = args.model_name.split("/")[-1]
    output_dir = os.path.join(args.output_dir, model_nickname)
    # makedirs also creates a missing parent directory and tolerates an
    # already existing target, unlike a guarded os.mkdir.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{args.dataset}_T{args.temperature}.txt")
    if args.is_scored:
        output_path += ".score"

    model = get_model(args)(args, prompt=qa_prompt)
    if is_textual:
        model.remode("text")

    for data in tqdm(dataset):
        data_id = data["data_id"]
        image_id = data["image_id"]
        text = _build_question_text(data, is_mc, is_textual)

        image_path = os.path.join(_IMAGE_DIR, f"{image_id}.jpg")
        if not os.path.exists(image_path):
            # Examples without an image on disk are skipped in every mode.
            continue

        context = _build_context(mode, text, image_path)
        if mode == "pad":
            model.mode = "zero_padding"
        context["is_scored"] = args.is_scored
        answer = model.chat(**context)

        # Re-open in append mode per record so each answer is written out
        # before the next (potentially slow) model call starts.
        with open(output_path, "a", encoding="utf-8") as fout:
            fout.write(f"{json.dumps({data_id: answer})}\n")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()