import os
import json
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from src.models import get_model
from src.utils.parser_utils import get_parser
from src.prompt import qa_prompt, qa_context_prompt, qa_image_prompt, qa_blend_prompt

def _load_json(path):
    """Load and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as fin:
        return json.load(fin)


def _load_viquae(dataset_name):
    """Return the ViQuAE examples selected by substrings of *dataset_name*.

    - contains "mc": local multiple-choice JSON ("cleaned" selects the cleaned file);
    - contains "full": HF train + validation + test splits concatenated into a list;
    - contains "clean": local cleaned JSON;
    - otherwise: the HF ``train`` split.
    """
    if "mc" in dataset_name:
        if "cleaned" in dataset_name:
            return _load_json("data/viquae/cleaned_dataset_mc.json")
        return _load_json("data/viquae/multiple_choice_data.json")
    if "full" in dataset_name:
        splits = load_dataset("PaulLerner/viquae_dataset")
        return [d for split in ("train", "validation", "test") for d in splits[split]]
    if "clean" in dataset_name:
        return _load_json("data/viquae/cleaned_dataset.json")
    return load_dataset("PaulLerner/viquae_dataset")["train"]


def _load_captions(path="data/viquae/named_entities.txt"):
    """Map example id -> "This is an image of <entity>." from a JSON-lines file.

    Each non-blank line is a JSON object of ``{id: entity}`` pairs; later lines
    override earlier ones for the same id.
    """
    entities = {}
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if line.strip():  # tolerate blank / trailing lines instead of crashing
                entities.update(json.loads(line))
    return {k: f"This is an image of {v}." for k, v in entities.items()}


def _format_question(data, with_choices):
    """Build the "Question:" prompt text, appending the options when *with_choices*."""
    text = f"Question:\n{data['input']}"
    if with_choices:
        choices_text = "".join(
            f"{c_name}: {c_content}\n"
            for c_name, c_content in data["multiple_choices"].items()
        )
        text += f"\nOption:\n{choices_text}"
    return text


def _build_context(data, text, mode, model):
    """Return the keyword arguments for ``model.chat`` for one example in *mode*.

    Raises:
        ValueError: if *mode* is not a recognised mode.
    """
    if mode == "visual":
        image = Image.open(os.path.join("data/viquae/images", data["image"]))
        return {"text": text, "image": image}
    if "textual" in mode:
        return {"text": text}
    # NOTE(review): "blend", "blank" and "pad" are not produced by main()'s
    # mode selection; kept so those experiment modes can be re-enabled.
    if mode == "blend":
        image = Image.open(os.path.join("data/viquae/images", data["image"]))
        return {"text": text, "image": image}
    if mode == "recognize":
        image = Image.open(os.path.join("data/viquae/images", data["image"]))
        return {"text": "What/Who is in the image? Do not describe details. Just give a named entity, e.g. Jackie Chan.", "image": image}
    if mode == "blank":
        return {"text": text, "image": Image.new('RGB', (336, 336), color=(255, 255, 255))}
    if mode == "pad":
        model.mode = "zero_padding"
        return {"text": text, "image": Image.new('RGB', (336, 336), color=(255, 255, 255))}
    raise ValueError(f"unknown mode: {mode!r}")


def main():
    """Run the model over a ViQuAE variant and append answers to a results file.

    Reads the CLI arguments from ``get_parser()`` plus ``--is_scored``. The
    dataset, prompt format and input mode are chosen by substrings of
    ``args.dataset``. Each answer is appended as one JSON line ``{id: answer}`` to
    ``<output_dir>/<model nickname>/<dataset>_T<temperature>.txt`` (with a
    ``.score`` suffix when ``--is_scored`` is set).
    """
    parser = get_parser()
    parser.add_argument("--is_scored", action="store_true")
    args = parser.parse_args()

    if args.greedy:
        args.temperature = 0.0

    dataset = _load_viquae(args.dataset)

    captions = {}
    if "textual" in args.dataset:
        # Text-only mode: the image is replaced by a caption naming its entity.
        mode = "textual"
        captions = _load_captions()
    elif "recognize" in args.dataset:
        mode = "recognize"
    else:
        mode = "visual"

    model_nickname = args.model_name.split("/")[-1]
    output_dir = os.path.join(args.output_dir, model_nickname)
    # makedirs also creates a missing args.output_dir and avoids the exists/mkdir race.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{args.dataset}_T{args.temperature}.txt")
    if args.is_scored:
        output_path += ".score"

    model = get_model(args)(args, prompt=qa_prompt)
    if mode == "textual":
        model.remode("text")

    with_choices = "mc" in args.dataset
    # Append mode plus a flush per record: an interrupted run keeps the answers
    # already produced, as with the original per-record reopen.
    with open(output_path, "a+", encoding="utf-8") as fout:
        for data in tqdm(dataset):
            data_id = data["id"]
            text = _format_question(data, with_choices)
            if mode == "textual":
                text = captions.get(data_id, "") + "\n" + text

            context = _build_context(data, text, mode, model)
            context["is_scored"] = args.is_scored
            image = context.get("image")
            try:
                answer = model.chat(**context)
            finally:
                # Release the image's file handle (Image.open is lazy and may keep it open).
                if image is not None:
                    image.close()

            fout.write(f"{json.dumps({data_id: answer})}\n")
            fout.flush()

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()