import argparse
import json
import os
import pickle
import re
from io import BytesIO

import requests
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL and convert it to RGB.

    Args:
        image_file: Local filesystem path, or a URL starting with ``http://``
            or ``https://``.
        timeout: Seconds to wait for the remote server (URLs only).

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: if a URL responds with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    if image_file.startswith(("http://", "https://")):
        # Without a timeout, requests can block forever on an unresponsive host.
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    # Context manager closes the underlying file handle deterministically.
    with Image.open(image_file) as img:
        return img.convert("RGB")

# identity_dict.pkl is read once, at import time.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load a trusted, locally produced identity_dict.pkl.
with open("identity_dict.pkl", "rb") as pkl_fh:
    data = pickle.load(pkl_fh)

# The keys are later joined onto the CelebA-HQ image directory, i.e. they are
# treated as image file names. What the values hold is not visible here.
all_img_files = [*data.keys()]

# Questions asked about every image. Each one is phrased around "this person",
# which gen_description replaces with the <sks> placeholder token before the
# question/answer pair is stored.
question_list = [
    "What is this person's hair color?",
    "What color are this person's eyes?",
    "What is this person's skin tone?",
    "How would you describe this person's hairstyle?",
    "Does this person wear glasses or any accessories?",
    "Does this person have any distinctive facial features?",
    "What is this person's general expression or demeanor?",
    "How would you describe this person's face?",
    "Is this person young or old?",
    "How would you describe this person's mouth?",
    "How would you describe this person's nose?",
]

def parse_arguments(argv=None):
    """Parse the command-line options for this description-generation script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --logging_dir, --gradient_accumulation_steps,
    # --mixed_precision, --report_to, --image-file, --prompt and
    # --num_train_steps are not read anywhere in this script. They are kept so
    # existing command lines keep parsing.
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--logging_dir", type=str)
    parser.add_argument("--gradient_accumulation_steps", type=int)
    parser.add_argument("--mixed_precision", type=str, default="no")
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--num_train_steps", type=int, default=200)
    return parser.parse_args(argv)

# Matches "this person" / "the person" (with either capitalisation of the first
# letter) as whole words, so longer words such as "personality" stay intact.
_PERSON_PATTERN = re.compile(r"\b(?:[Tt]his|[Tt]he) person\b")


def _anonymize(text):
    """Replace every whole-word "this person" / "the person" in *text* with ``<sks>``."""
    return _PERSON_PATTERN.sub("<sks>", text)


def _clean_output(text):
    """Strip whitespace and the ``<s> `` / ``</s>`` markers left in decoded text."""
    return text.strip().replace("<s> ", "").replace("</s>", "")


def _images_to_device(image_tensor, device):
    """Move the processed image(s) to *device* as float16 (accepts a list or a single tensor)."""
    if isinstance(image_tensor, list):
        return [t.to(device, dtype=torch.float16) for t in image_tensor]
    return image_tensor.to(device, dtype=torch.float16)


def _build_prompt(question, conv_mode, use_im_start_end):
    """Render *question*, prefixed by the image placeholder, with the *conv_mode* template."""
    if use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], image_prefix + "\n" + question)
    # Empty assistant turn: the rendered prompt ends where the model should answer.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


@torch.no_grad()
def gen_description(args, start_index=0, num_images=3000,
                    image_dir="data/CelebAMask-HQ/CelebA-HQ-img"):
    """Ask the model every question in ``question_list`` about a slice of images.

    For each image file in ``all_img_files[start_index:start_index + num_images]``,
    every question is answered by the model. "this/the person" is replaced with
    ``<sks>`` in both the question and the answer.

    The records are written as one JSON list to
    ``<args.output_dir or ".">/descriptions_<start>_<end>.json``. The file is
    written once, after all images have been processed.

    Args:
        args: Parsed CLI namespace (see ``parse_arguments``). Reads model_path,
            model_base, load_8bit, load_4bit, device, conv_mode (``"v1"`` when
            unset), temperature, max_new_tokens and output_dir.
        start_index: First position in ``all_img_files`` to process.
        num_images: Maximum number of images to process.
        image_dir: Directory the image file names are relative to.

    Returns:
        A list of ``{"img_file": str, "conversations": [{"USER": q, "gpt": a}, ...]}``.
    """
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device
    )
    # Honour --conv-mode; "v1" was previously hard-coded and stays the fallback.
    conv_mode = args.conv_mode or "v1"
    end_index = start_index + num_images
    use_im_start_end = model.config.mm_use_im_start_end

    results = []
    # The slice selects the same [start_index, end_index) range the old
    # skip/break counters did, and gives tqdm a real total.
    for img_file in tqdm(all_img_files[start_index:end_index]):
        image = load_image(os.path.join(image_dir, img_file))
        image_tensor = _images_to_device(
            process_images([image], image_processor, model.config), model.device
        )
        conversations = []
        for question in question_list:
            prompt = _build_prompt(question, conv_mode, use_im_start_end)
            input_ids = (
                tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
                .unsqueeze(0)
                .to(model.device)
            )
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image.size],
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                use_cache=True,
            )
            answer = _anonymize(_clean_output(tokenizer.decode(output_ids[0])))
            conversations.append({"USER": _anonymize(question), "gpt": answer})
        results.append({"img_file": img_file, "conversations": conversations})

    # Previously the collected answers were discarded; persist them.
    output_dir = args.output_dir or "."
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"descriptions_{start_index}_{end_index}.json")
    with open(output_file, "w", encoding="utf-8") as fh:
        json.dump(results, fh, ensure_ascii=False, indent=2)
    return results

@torch.no_grad()
def debug(args, img_file="6.jpg", question="Does this person wear glasses or any accessories?"):
    """Ask one question about one CelebA-HQ image and print the model's answer.

    Args:
        args: Parsed CLI namespace (see ``parse_arguments``). Reads model_path,
            model_base, load_8bit, load_4bit, device, conv_mode (``"v1"`` when
            unset), temperature and max_new_tokens.
        img_file: File name inside ``data/CelebAMask-HQ/CelebA-HQ-img``.
        question: Question to ask; the image placeholder token is prepended to it.

    Returns:
        The decoded answer string, which is also printed.
    """
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device
    )
    # Honour --conv-mode; "v1" was previously hard-coded and stays the fallback.
    conv_mode = args.conv_mode or "v1"

    image = load_image(os.path.join("data/CelebAMask-HQ/CelebA-HQ-img", img_file))
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    if model.config.mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], image_prefix + "\n" + question)
    # Empty assistant turn: the rendered prompt ends where the model should answer.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
    )
    # decode() is called without skip_special_tokens, so strip BOS/EOS by hand.
    outputs = tokenizer.decode(output_ids[0]).strip()
    outputs = outputs.replace("<s> ", "").replace("</s>", "")
    print(outputs)
    return outputs

if __name__ == "__main__":
    args = parse_arguments()
    # The --debug flag selects the single-image, single-question check in place
    # of the full generation run (previously toggled by commenting code out).
    if args.debug:
        debug(args)
    else:
        gen_description(args)