from argparse import ArgumentParser
import copy
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes

from pllava.utils.easydict import EasyDict
from pllava.tasks.eval.model_utils import load_pllava
from pllava.tasks.eval.eval_utils import (
    ChatPllava,
    conv_plain_v1,
    Conversation,
    conv_templates
)
from pllava.tasks.eval.demo import pllava_theme

# Default system prompt. It pre-fills the editable "system" textbox in the UI,
# and the textbox value is forwarded to chat.ask() with every user message.
SYSTEM="""You are a powerful Video Magic ChatBot, a large vision-language assistant. 
You are able to understand the video content that the user provides and assist the user in a video-language related task.
The user might provide you with the video and maybe some extra noisy information to help you out or ask you a question. Make use of the information in a proper way to be competent for the job.
### INSTRUCTIONS:
1. Follow the user's instruction.
2. Be critical yet believe in yourself.
"""

# Conversation template that the handlers copy whenever a fresh chat state is
# needed. This module-level name may be rebound further down (after the model
# is loaded) from `conv_templates` according to `--conv_mode`; handlers look it
# up at call time, so they see whichever template is bound then.
INIT_CONVERSATION: Conversation = conv_plain_v1.copy()


# ========================================
#             Model Initialization
# ========================================
def init_model(args):
    """Load PLLaVA and wrap it in a ``ChatPllava`` helper.

    Args:
        args: Parsed command-line namespace. Reads
            ``pretrained_model_name_or_path``, ``num_frames``, ``use_lora``,
            ``weight_dir``, ``lora_alpha`` and ``use_multi_gpus``.

    Returns:
        A ``ChatPllava`` instance built from the loaded model and processor.
    """
    print('Initializing PLLaVA')

    checkpoint = args.pretrained_model_name_or_path
    frame_count = args.num_frames
    loader_options = dict(
        use_lora=args.use_lora,
        weight_dir=args.weight_dir,
        lora_alpha=args.lora_alpha,
        use_multi_gpus=args.use_multi_gpus,
    )
    net, proc = load_pllava(checkpoint, frame_count, **loader_options)

    # Only the single-GPU path moves the model explicitly; with
    # use_multi_gpus the model is used exactly as load_pllava returned it.
    if args.use_multi_gpus:
        return ChatPllava(net, proc)
    return ChatPllava(net.to('cuda'), proc)


# ========================================
#             Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Return the UI updates that put the demo back into its initial state.

    Args:
        chat_state: Current conversation state, or None if no chat started.
        img_list: Current media list, or None if nothing was uploaded.

    Returns:
        A 7-tuple: ``None`` for the first output (the chatbot in the reset
        wiring), updates re-enabling the two media inputs, an update locking
        the text box, an update restoring the upload button, then the new
        chat state and media list. A ``None`` state or list stays ``None``.
    """
    # Only replace state that actually exists; None is passed through.
    fresh_state = None if chat_state is None else INIT_CONVERSATION.copy()
    fresh_media = None if img_list is None else []

    media_cleared_a = gr.update(value=None, interactive=True)
    media_cleared_b = gr.update(value=None, interactive=True)
    locked_textbox = gr.update(placeholder='Please upload your video first', interactive=False)
    upload_ready = gr.update(value="Upload & Start Chat", interactive=True)

    return (
        None,
        media_cleared_a,
        media_cleared_b,
        locked_textbox,
        upload_ready,
        fresh_state,
        fresh_media,
    )


def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None):
    """Hand the uploaded video or image to the chat backend and unlock the chat UI.

    A video takes precedence when both a video and an image are provided.

    Args:
        gr_img: Value of the image input, or None.
        gr_video: Value of the video input, or None.
        chat_state: Current conversation; a copy of ``INIT_CONVERSATION`` is
            used when None.
        num_segments: Forwarded unchanged to ``chat.upload_video``.
        img_list: List handed to (and returned by) the backend upload call;
            a new empty list is used when None.

    Returns:
        A 6-tuple in the order the upload button's outputs are wired:
        (image input, video input, text box, upload button, chat_state,
        img_list).
    """
    print(gr_img, gr_video)
    chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state
    img_list = [] if img_list is None else img_list

    if gr_video:
        _, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments)
    elif gr_img:
        _, img_list, chat_state = chat.upload_img(gr_img, chat_state, img_list)
    else:
        # Nothing usable was uploaded. Always return a full 6-tuple (the
        # original fell through and returned None when neither branch
        # matched). The hint belongs on the text box, which stays locked
        # like after a reset; the upload button stays clickable.
        return (
            None,
            None,
            gr.update(interactive=False, placeholder='Please upload video/image first!'),
            gr.update(interactive=True),
            chat_state,
            None,
        )

    # Media registered: unlock the text box and disable further uploads.
    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Chatting", interactive=False),
        chat_state,
        img_list,
    )


def gradio_ask(user_message, chatbot, chat_state, system):
    """Record the user's message in the conversation and the chat history.

    Args:
        user_message: Text typed by the user; may be empty or None.
        chatbot: Chat history as a list of ``[user, assistant]`` pairs.
        chat_state: Current conversation object passed to ``chat.ask``.
        system: System prompt text passed to ``chat.ask``.

    Returns:
        A 3-tuple (text box value/update, chat history, chat_state). For an
        empty message the history and state are returned unchanged and the
        text box shows a hint; otherwise the text box is cleared and a new
        ``[user_message, None]`` turn is appended for the answer step to fill.
    """
    # Truthiness also covers None, which len() would reject with a TypeError.
    if not user_message:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat_state = chat.ask(user_message, chat_state, system)
    # Build a new list rather than mutating the incoming history; tolerate a
    # history that has not been initialised yet.
    chatbot = (chatbot or []) + [[user_message, None]]
    return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature, max_new_tokens=200):
    """Generate the assistant reply for the pending turn in the chat history.

    Args:
        chatbot: Chat history as a list of ``[user, assistant]`` pairs; the
            last pair is expected to have ``None`` as its assistant entry.
        chat_state: Conversation object passed to ``chat.answer`` as ``conv``.
        img_list: Media list passed to ``chat.answer``.
        num_beams: Beam-search width passed to ``chat.answer``.
        temperature: Sampling temperature passed to ``chat.answer``.
        max_new_tokens: Generation length limit passed to ``chat.answer``
            (defaults to the previously hard-coded 200).

    Returns:
        A 3-tuple (chat history, chat_state, img_list). When there is no
        pending turn, the inputs are returned untouched.
    """
    # This handler is chained after gradio_ask, which appends nothing when the
    # message is empty. Without this guard we would generate again on stale
    # state and overwrite the previous answer (or index into an empty list).
    if not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state, img_list

    llm_message, llm_message_token, chat_state = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        temperature=temperature,
    )
    llm_message = llm_message.replace("<s>", "")  # handle <s>
    chatbot[-1][1] = llm_message
    print(chat_state)
    print(f"Answer: {llm_message}")
    return chatbot, chat_state, img_list


def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process command line.

    Returns:
        The ``argparse.Namespace`` with all options below.
    """
    parser = ArgumentParser()
    # The first two options used to combine required=True with a default,
    # which made the default unreachable. They are optional now, so the
    # declared defaults apply when the flag is omitted.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='llava-hf/llava-1.5-7b-hf',
        help="Model name or path handed to load_pllava.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=4,
        help="Frame count handed to load_pllava.",
    )
    parser.add_argument(
        "--use_lora",
        action='store_true',
        help="Forwarded to load_pllava as use_lora.",
    )
    parser.add_argument(
        "--use_multi_gpus",
        action='store_true',
        help="Forwarded to load_pllava; when set the model is not moved to 'cuda' explicitly.",
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        default=None,
        help="Forwarded to load_pllava as weight_dir.",
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        default=None,
        help="Key used to select the conversation template from conv_templates.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=None,
        help="Forwarded to load_pllava as lora_alpha.",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=7868,
        help="Port passed to demo.launch.",
    )
    return parser.parse_args(argv)


# HTML header rendered at the top of the page through gr.Markdown: the project
# logo, linking to the PLLaVA GitHub repository.
title = """<h1 align="center"><a href="https://github.com/magic-research/PLLaVA"><img src="https://raw.githubusercontent.com/magic-research/PLLaVA/main/assert/logo.png" alt="PLLAVA" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
# Short usage instructions rendered directly below the header.
description = (
    """<br><p><a href='https://github.com/magic-research/PLLaVA'>
    # PLLAVA!
    <img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
    - Upload A Video
    - Press Upload
    - Start Chatting
    """
)

# Command-line options are parsed at module level, i.e. as soon as this file
# is executed or imported.
args = parse_args()

# Markdown summary of the model-related options, shown in the UI. It must be
# built after `args` because the f-string interpolates the parsed values.
model_description = f"""
    # MODEL INFO
    - pretrained_model_name_or_path:{args.pretrained_model_name_or_path}
    - use_lora:{args.use_lora}
    - weight_dir:{args.weight_dir}
"""

# Build the Gradio UI, load the model and wire the event handlers.
# NOTE(review): `footer {visibility: none}` in the CSS below is not a valid
# CSS value (`hidden` is), so the footer is probably not hidden as intended.
# The string is left unchanged here; confirm the intent before changing it.
with gr.Blocks(title="PLLaVA",
               theme=pllava_theme,
               css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(model_description)
    with gr.Row():
        # Left column: media upload.
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Tab("Video", elem_id='video_tab'):
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
            with gr.Tab("Image", elem_id='image_tab'):
                up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # Kept under its own name: this button used to be bound to
            # `clear`, which the "🔄Clear" button below overwrote, so
            # "Restart" never received a click handler.
            restart = gr.Button("Restart")
            # No segment-count control is exposed, so upload_img runs with
            # its default num_segments=None.

        # Right column: generation settings and the conversation itself.
        with gr.Column(visible=True) as input_raws:
            system_string = gr.Textbox(SYSTEM, interactive=True, label='system')
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            # Per-session state holders; both start out as None.
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot", label='Conversation')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear")

    with gr.Row():
        examples = gr.Examples(
            examples=[
                ['example/jesse_dance.mp4', 'What is the man doing?'],
                ['example/yoga.mp4', 'What is the woman doing?'],
                ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'],
                ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'],
                ['example/1917.mov', 'Describe the background, characters and the actions in the provided video.'],
            ],
            inputs=[up_video, text_input]
        )

    # `chat` is the module-level backend used by every handler.
    chat = init_model(args)
    # --conv_mode defaults to None; only override the default template when
    # a mode was actually requested, otherwise the lookup would be done with
    # None (presumably a failing key lookup -- conv_templates is defined
    # outside this file).
    if args.conv_mode is not None:
        INIT_CONVERSATION = conv_templates[args.conv_mode]
    upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list])

    # Enter key and the Send button both run ask -> answer.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(lambda: "", None, text_input)

    # Both reset buttons share the same handler and outputs.
    reset_outputs = [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list]
    for reset_button in (restart, clear):
        reset_button.click(gradio_reset, [chat_state, img_list], reset_outputs, queue=False)

demo.launch(share=True, server_port=args.server_port)
