import os
import sys
import cv2
import glob
import json
import torch
import argparse
import requests
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import Blip2Processor, Blip2ForConditionalGeneration

def _collect_image_paths(folder):
    """Return the sorted .png/.jpg/.jpeg files located directly inside *folder*.

    Extension matching is case-insensitive (``IMG_01.JPG`` is included) and the
    folder name is glob-escaped so names containing ``[`` or ``*`` work.
    Returns an empty list when *folder* does not exist.
    """
    if not os.path.isdir(folder):
        return []
    image_extensions = ('.png', '.jpg', '.jpeg')
    return sorted(
        path
        for path in glob.glob(os.path.join(glob.escape(folder), '*'))
        if os.path.isfile(path) and path.lower().endswith(image_extensions)
    )


def _build_caption(prefix, generated):
    """Return ``"<prefix>, <first line of generated>"``.

    When *prefix* is empty only the first line of *generated* is returned, so
    prompts never start with a dangling ``", "``.
    """
    first_line = generated.split("\n")[0]
    return f"{prefix}, {first_line}" if prefix else first_line


def prepare_dreambooth_dataset(args):
    """Download a Hugging Face image dataset and write BLIP-2 captions for it.

    The dataset ``args.huggigface_dataset`` is snapshot-downloaded into
    ``<args.result_folder>/<dataset name>``. Every .png/.jpg/.jpeg image directly
    inside that folder is captioned with ``Salesforce/blip2-opt-2.7b`` on CUDA.
    The results are written to ``metadata.jsonl`` in the same folder, one
    ``{"file_name": ..., "prompt": ...}`` JSON object per line. Each prompt is
    ``"<args.caption_prefix>, <caption>"``, or just the caption when the prefix
    is empty.

    A failed download is reported but not fatal: images already present in
    the local folder are still captioned. If no images are found, nothing is
    written, so a later run can retry.

    Args:
        args: namespace with string attributes ``huggigface_dataset``,
            ``result_folder`` and ``caption_prefix``.

    Raises:
        SystemExit: with code 0 when ``metadata.jsonl`` already exists, i.e. a
            previous run already prepared this dataset.
    """
    dataset_to_download = args.huggigface_dataset
    # os.path.join keeps an empty result_folder relative to the cwd instead of
    # producing an absolute "/<name>" path as plain "+ '/' +" concatenation did.
    local_dir = os.path.join(args.result_folder, dataset_to_download)
    metadata_path = os.path.join(local_dir, 'metadata.jsonl')

    # Check if done previously
    if os.path.isfile(metadata_path):
        print(f'WARNING: {metadata_path} file exists. Exit process...')
        sys.exit(0)

    print("Downloading dataset...")
    try:
        snapshot_download(
            dataset_to_download,
            local_dir=local_dir, repo_type="dataset",
            ignore_patterns=".gitattributes",
        )
    except Exception as err:
        # Best effort: images that are already on disk can still be captioned.
        print(f"Error trying to load dataset {dataset_to_download}: {err}")
    else:
        print("Dataset downloaded!")

    print("Preparing images...")
    image_paths = _collect_image_paths(local_dir)
    if not image_paths:
        # Writing an empty metadata.jsonl would make every later run exit early.
        print(f"WARNING: no .png/.jpg/.jpeg images found in {local_dir}. Nothing to caption.")
        return

    print("Loading BLIP caption generation model...")
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", device_map='cuda')
    blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
    print("Model loaded successfully")

    def caption_image(input_image):
        """Return BLIP-2's decoded caption for one PIL image (generate max_length=50)."""
        inputs = blip_processor(images=input_image, return_tensors="pt").to('cuda')
        generated_ids = blip_model.generate(pixel_values=inputs.pixel_values, max_length=50)
        return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Write to a temporary file and rename at the end. Otherwise an interrupted
    # run would leave a partial metadata.jsonl that the check above treats as done.
    tmp_path = metadata_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as outfile:
        for path in image_paths:
            # Keep each image open only while it is captioned so handles are released.
            with Image.open(path) as image:
                generated = caption_image(image)
            entry = {
                "file_name": os.path.basename(path),
                "prompt": _build_caption(args.caption_prefix, generated),
            }
            json.dump(entry, outfile)
            outfile.write('\n')
    os.replace(tmp_path, metadata_path)

    print(f"Your image captions are ready here: {metadata_path}")


def main(args):
    """Run the preparation routine named by ``args.task``.

    Only ``'dreambooth'`` is supported. Any other value prints the list of
    available tasks and returns without doing anything else.
    """
    if args.task != 'dreambooth':
        print("Invalid task. Available tasks are ['dreambooth']")
        return
    prepare_dreambooth_dataset(args)



    


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a Hugging Face image dataset by captioning its images.")
    # The misspelled "--huggigface_dataset" flag is kept for backward compatibility;
    # dest is pinned so code reading args.huggigface_dataset keeps working.
    parser.add_argument("--huggingface_dataset", "--huggigface_dataset",
                        dest="huggigface_dataset",
                        help="Hugging Face dataset name containing images",
                        type=str, default="")
    parser.add_argument("--result_folder", help="folder in which to save the dataset",
                        type=str, default="")
    parser.add_argument("--caption_prefix", help="caption prefix to add to the images",
                        type=str, default="")
    parser.add_argument("--task", help="possible tasks: ['dreambooth']", type=str, default="")

    args = parser.parse_args()
    main(args)
