import gradio as gr
from threading import Thread

import argparse
import torch

from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from LLaVA.llava.conversation import conv_templates, SeparatorStyle
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.utils import disable_torch_init
from LLaVA.llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer, TextIteratorStreamer
import json

from detector import Detector
from retriever import ClipRetriever

def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL, converted to RGB.

    Args:
        image_file: Local filesystem path, or a URL beginning with
            ``http://`` or ``https://``.
        timeout: Seconds to wait for the HTTP request before giving up.
            Ignored for local files.

    Returns:
        A PIL image in RGB mode.

    Raises:
        requests.HTTPError: If the URL answers with an error status.
        requests.Timeout: If the URL does not respond within ``timeout``.
    """
    if image_file.startswith(("http://", "https://")):
        # Without a timeout a stalled server would block this handler forever.
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP failures directly instead of letting PIL choke on an
        # error-page body.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    # convert() loads the pixels into a new image, so the underlying file
    # handle can be closed as soon as it returns.
    with Image.open(image_file) as img:
        return img.convert("RGB")

# Command-line configuration for the demo.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.6-vicuna-7b")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, required=False)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--retrieval", action="store_true")
parser.add_argument("--database", type=str, default=None)
parser.add_argument("--index-path", type=str, default=None)
args = parser.parse_args()
# Retrieval reads "<database>/database.json"; fail early with a usage error
# rather than later with a confusing "None/database.json" lookup.
if args.retrieval and args.database is None:
    parser.error("--retrieval requires --database")

# Load the LLaVA model once at import time; main() reuses these globals.
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

if args.retrieval:
    # Retrieval components are only built when --retrieval is given; main()
    # references `detector`, `database` and `retriever` under the same flag.
    detector = Detector()
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(f"{args.database}/database.json", "r", encoding="utf-8") as f:
        database = json.load(f)

    # Set interested classes: the distinct concept categories, in first-seen order.
    all_category = list(dict.fromkeys(
        entry["category"] for entry in database["concept_dict"].values()
    ))
    detector.model.set_classes(all_category)

    # With no --index-path an index is created; otherwise the given one is used.
    retriever = ClipRetriever(
        data_dir=args.database,
        index_path=args.index_path,
        create_index=args.index_path is None,
    )

def _resolve_conv_mode():
    """Return the conversation-template name to use for ``model_name``.

    The mode is inferred from the model name; an explicit ``--conv-mode``
    takes precedence, with a warning when the two disagree. Unlike a direct
    assignment, this does not mutate the global ``args``.
    """
    lowered = model_name.lower()
    if "phi" in lowered:
        inferred = "phi3_instruct"
    elif "v1" in lowered:
        inferred = "llava_v1"
    else:
        inferred = "llava_v0"

    if args.conv_mode is not None and inferred != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(inferred, args.conv_mode, args.conv_mode))
        return args.conv_mode
    return inferred


def _retrieve_references(message, image):
    """Collect up to two reference images (plus any named in ``message``).

    Concepts whose key appears verbatim in ``message`` are always included.
    Then crops detected in ``image`` are searched against the retriever, and
    the closest hits (ascending distance) are added until two references
    exist.

    Returns:
        (ref_images, extra_info): the loaded reference images in selection
        order, and a text block with one numbered ``<image>`` placeholder,
        name and info line per reference.
    """
    rag_images = dict()  # insertion-ordered path -> score
    for concept, entry in database["concept_dict"].items():
        if concept in message:
            # Previously indexed concept_dict["image"], ignoring `concept`.
            rag_images[entry["image"]] = 0

    crops = detector.detect_and_crop(image)
    if len(crops) > 0:
        distances, filenames = retriever.image_search(crops, k=2)
        ret_image_path = [path for files in filenames for path in files]
        distances = distances.flatten()
        for i in distances.argsort():
            if len(rag_images) >= 2:
                break
            if ret_image_path[i] in rag_images:
                continue
            rag_images[ret_image_path[i]] = distances[i].tolist()

    ref_images = []
    extra_info = ""
    for i, img_path in enumerate(rag_images):
        ref_images.append(load_image(img_path))
        concept_entry = database["concept_dict"][database["path_to_concept"][img_path]]
        name = concept_entry["name"]
        info = concept_entry["info"]
        extra_info += f"{i+1}.<image>\n Name: <{name}>, Info: {info}\n"
    return ref_images, extra_info


def main(message, history, image):
    """Gradio chat handler: stream the model's reply to ``message``.

    Args:
        message: The user's new chat message.
        history: Prior turns; each entry is indexed as ``[user, assistant]``.
        image: The image from the side panel, or a falsy value if none.

    Yields:
        The accumulated reply text so far (with ``<`` and ``>`` removed), or a
        single prompt asking for an image when none is provided.
    """
    disable_torch_init()

    conv = conv_templates[_resolve_conv_mode()].copy()
    roles = ('user', 'assistant') if "mpt" in model_name.lower() else conv.roles

    # Replay earlier turns; remember whether an image token was already sent.
    has_image = False
    for turn in history:
        if "<image>" in turn[0]:
            has_image = True
        conv.append_message(conv.roles[0], turn[0])
        conv.append_message(conv.roles[1], turn[1])

    if not image:
        yield "Please input an image."
        # Without an image there is nothing to detect on or encode.
        return

    images = [image]
    image_size = [image.size]
    extra_info = ""
    if args.retrieval:
        ref_images, extra_info = _retrieve_references(message, image)
        images.extend(ref_images)
        image_size.extend(ref.size for ref in ref_images)

    image_tensor = process_images(images, image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    print(f"{roles[1]}: ", end="")

    if has_image:
        # Later messages: the image token was already sent in the history.
        inp = message
    else:
        # First message: prepend the image token (and retrieved references).
        prefix = f"[{extra_info}]" if args.retrieval else ""
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prefix + message
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    print(prompt)
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        image_sizes=image_size,
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        use_cache=True,
    )
    # Generate in a background thread so tokens can be streamed as they arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message.replace("<", "").replace(">", "")

# Page styling: a narrow container and compact chat messages.
custom_css = """
.gradio-container{
    max-width: 700px !important;
}
.message { width: 100% !important; padding: 5px !important; font-size: 14px !important;}

"""

if __name__ == "__main__":
    # NOTE: Gradio takes page CSS from the top-level Blocks that is launched,
    # so the stylesheet goes here rather than on the nested ChatInterface.
    with gr.Blocks(css=custom_css) as app:

        with gr.Row(equal_height=True):

            # Left column: the image every chat turn is answered against.
            with gr.Column(scale=3):
                image_box = gr.Image(height="400px", width="400px", type="pil")

            # Right column: chat UI wired to main(), with the image as an extra input.
            with gr.Column(scale=6):
                chat_bot = gr.Chatbot(height="600px", render=False)
                gr.ChatInterface(
                    fn=main,
                    chatbot=chat_bot,
                    additional_inputs=[image_box],
                    additional_inputs_accordion_name="Image",
                    title="RAP-LLaVA",
                    analytics_enabled=False,
                    examples=[["Give a caption of the image."], ["Describe the image."], ["Create a short caption of the image."]],
                )

    app.launch(share=True)
    