from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from sentence_transformers import SentenceTransformer,util
from sklearn.metrics.pairwise import cosine_similarity

from tqdm import tqdm
import os
import numpy as np
import torch
import argparse
import json
import random
import glob

def topk(matrix, k, axis=0):
    """Return the indices of the ``k`` largest entries of ``matrix`` along ``axis``.

    Args:
        matrix: Array-like of comparable values.
        k: Number of leading indices to keep along ``axis``.
        axis: Axis along which to rank; forwarded unchanged to the NumPy calls.

    Returns:
        An integer index array whose entries along ``axis`` are ordered from
        the largest value to the smallest.

    Raises:
        IndexError: If ``k`` exceeds the length of ``matrix`` along ``axis``.
    """
    # argsort ranks ascending, so reverse it to put the largest values first.
    ascending_idx = np.argsort(matrix, axis=axis)
    descending_idx = np.flip(ascending_idx, axis=axis)
    leading_positions = np.arange(k)
    return np.take(descending_idx, leading_positions, axis=axis)

def top_similar_classes(sim_matrix, class_names, k=3):
    """Map every class to its ``k`` most similar *other* classes.

    The class itself is excluded explicitly (by masking the diagonal) instead
    of assuming it is always ranked first, which breaks when two classes have
    identical similarity scores.

    Args:
        sim_matrix: Square array-like of pairwise similarities where
            ``sim_matrix[i][j]`` scores ``class_names[i]`` against
            ``class_names[j]``.
        class_names: Class identifiers, in the same order as the rows.
        k: Maximum number of neighbours per class; clamped to
            ``len(class_names) - 1`` so small datasets do not fail.

    Returns:
        Dict from class name to a list of neighbour names, most similar first.

    Raises:
        ValueError: If ``sim_matrix`` is not square with one row per class.
    """
    num_classes = len(class_names)
    if num_classes == 0:
        return {}

    # Work on a float copy so the caller's matrix is never modified.
    sim = np.array(sim_matrix, dtype=float)
    if sim.shape != (num_classes, num_classes):
        raise ValueError(
            f"expected a {num_classes}x{num_classes} similarity matrix, "
            f"got shape {sim.shape}"
        )

    # A class must never be listed as its own neighbour.
    np.fill_diagonal(sim, -np.inf)
    k = max(0, min(k, num_classes - 1))

    # Stable sort on the negated scores: descending order with ties broken
    # by class position, so the output is deterministic.
    order = np.argsort(-sim, axis=1, kind="stable")[:, :k]
    return {
        name: [class_names[j] for j in row]
        for name, row in zip(class_names, order)
    }


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Inference script")

    parser.add_argument("--dataset_dir", type=str, default="real-guidance/")
    parser.add_argument("--dataset", type=str, default="coco")
    parser.add_argument("--model-path", type=str, default="Salesforce/blip2-opt-2.7b")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--prompt", type=str, default="A {name} is")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Only sub-directories are classes. The listing is sorted because
    # os.listdir returns entries in arbitrary order, which would make the
    # seeded random choices below differ from one machine to another.
    classname = sorted(
        entry for entry in os.listdir(args.dataset_dir)
        if os.path.isdir(os.path.join(args.dataset_dir, entry))
    )
    caption_dict = {}
    names = []
    names_underscore = []

    processor = AutoProcessor.from_pretrained(args.model_path)
    model = Blip2ForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16
    ).to('cuda')

    # Noun appended to the class name in the BLIP-2 prompt for some datasets.
    category = {'pets': ' pet',
                'cars': ' car',
                'flowers': ' flower'}
    prompt_end = category.get(args.dataset, '')

    # Templates for the one caption per class that does not depend on an image.
    default_templates = [
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a photo of a dirty {}",
        "a dark photo of the {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    ]

    # Prefixes given to BLIP-2; the model continues them into a caption.
    prompt_templates = [
        'the {label}{category}',
        'the {label}{category} is',
        'a photo of the {label}{category} that',
    ]

    for name in tqdm(classname, desc="Generating Captions"):
        if name.endswith('.json'):
            continue

        image_dir = os.path.join(args.dataset_dir, name)
        # Sorted for the same reproducibility reason as the class listing.
        image_files = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        name_w_space = name.replace('_', ' ')
        names.append(name_w_space)
        names_underscore.append(name)
        # Stored captions use a "<class_dir>" placeholder, not the readable name.
        format_name = f"<{name}>"
        caption_dict[name] = [random.choice(default_templates).format(format_name)]

        # for each image
        for image_file in image_files:
            # Decode inside the context manager so the file handle is
            # released, and normalise the PNG to 3-channel RGB.
            with Image.open(image_file) as img:
                image = img.convert("RGB")

            # BLIP2 prompt
            prompt_template = random.choice(prompt_templates)
            blip_prompt = prompt_template.format(label=name_w_space, category=prompt_end)

            # caption = BLIP2(img, prompt)
            # Cast the inputs to float16 to match the weights loaded above.
            inputs = processor(image, text=blip_prompt, return_tensors="pt").to('cuda', torch.float16)
            generated_ids = model.generate(**inputs)
            caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            # The stored caption uses the placeholder and drops the category noun.
            prompt = prompt_template.format(label=format_name, category='')
            caption_dict[name].append(f'{prompt} {caption}')

    # Compute class similarity based on class names
    name2top3names = {}
    if names:
        stransformer = SentenceTransformer('all-MiniLM-L6-v2')
        name_embeds = stransformer.encode(names)
        sim_matrix = cosine_similarity(name_embeds)
        name2top3names = top_similar_classes(sim_matrix, names_underscore, k=3)

    # save caption
    target_path = os.path.join(args.dataset_dir, 'captions.json')
    with open(target_path, 'wt', encoding='utf-8') as f:
        json.dump(caption_dict, f, indent=4)

    target_path = os.path.join(args.dataset_dir, 'class_sim.json')
    with open(target_path, 'wt', encoding='utf-8') as f:
        json.dump(name2top3names, f, indent=4)
        