import os
import json
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from src.models import get_model
from src.utils.parser_utils import get_parser

# Instruction sent with each image when --is_prompt is given.
_CAPTION_PROMPT = (
    "Generate a caption for this image. Please describe the image as detailed as possible "
    "and if you can recognize any named entities in the image, please add them to your caption."
)


def _read_json(path):
    """Load and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as fin:
        return json.load(fin)


def _load_viquae(dataset_name):
    """Return the ViQuAE examples selected by substring flags in *dataset_name*.

    Flags are checked in this order:
    - contains "mc": local multiple-choice JSON; "cleaned" selects the cleaned file.
    - contains "full": train + validation + test splits of the Hugging Face
      dataset, concatenated into one list.
    - contains "clean": local cleaned JSON.
    - otherwise: the Hugging Face "train" split.
    """
    # NOTE(review): the multiple-choice branch requires "cleaned" while the
    # non-MC branch only requires "clean" -- confirm this asymmetry is intended.
    if "mc" in dataset_name:
        if "cleaned" in dataset_name:
            return _read_json("data/viquae/cleaned_dataset_mc.json")
        return _read_json("data/viquae/multiple_choice_data.json")
    if "full" in dataset_name:
        hf_splits = load_dataset("PaulLerner/viquae_dataset")
        return [row for split in ("train", "validation", "test") for row in hf_splits[split]]
    if "clean" in dataset_name:
        return _read_json("data/viquae/cleaned_dataset.json")
    return load_dataset("PaulLerner/viquae_dataset")["train"]


def _output_path(args):
    """Build the captions output file path from the CLI arguments."""
    prefix = "caption_prompted" if args.is_prompt else "caption"
    return os.path.join(
        args.output_dir,
        f"{prefix}_{args.dataset}_{args.model_name}_T{args.temperature}.txt",
    )


def main():
    """Caption every selected ViQuAE image and append the results as JSON lines.

    Each output line is ``{"<id>": <caption>}``. The file is opened in append
    mode, so re-running with identical arguments adds lines instead of
    overwriting earlier results. Each line is flushed immediately so that
    partial progress survives a crash.
    """
    parser = get_parser()
    parser.add_argument("--is_prompt", action="store_true")
    args = parser.parse_args()
    if args.greedy:
        args.temperature = 0.0

    dataset = _load_viquae(args.dataset)
    output_path = _output_path(args)
    # Fail fast on an unusable output directory rather than after the model loads.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    model = get_model(args)(args)

    with open(output_path, "a", encoding="utf-8") as fout:
        for data in tqdm(dataset):
            data_id = data["id"]
            image_path = os.path.join("data/viquae/images", data["image"])
            # Context manager closes the file handle Image.open keeps for lazy loading.
            with Image.open(image_path) as image:
                if args.is_prompt:
                    context = {"text": _CAPTION_PROMPT, "image": image}
                else:
                    context = {"image": image}
                answer = model.chat(**context)
            fout.write(f"{json.dumps({data_id: answer})}\n")
            fout.flush()

if __name__ == "__main__":  # script entry point; importing this module runs nothing
    main()