import os
import torch
import time
import argparse
import gradio as gr
import datetime
from typing import List, Tuple
from server.server_models import GPT4Video
from tools.safety_checker import SafetyChecker
import json
import base64


# Single shared model instance, constructed once at import time and used by
# add_text() (its .run() method and its .t2vzero attribute).
gpt4video = GPT4Video()

# Header banner rendered at the top of the page via gr.HTML(html_content).
# NOTE(review): the "Project Page" link has a blank href, and the "Code" and
# "Model" links point at the same repository URL — confirm the intended
# targets (and the Arxiv id) before publishing.
html_content = """
<div style="display: flex; align-items: start;">
    <img src="https://github-production-user-asset-6210df.s3.amazonaws.com/151513068/284819987-e012b360-3f7c-40a8-ba43-77017c3b7785.png" alt="icon" title="GPT4Video" width="80" height="80" style="vertical-align: middle;">
    <div style="margin-left: 10px;">
        <p style="font-size:20px;"><strong>GPT4Video: MLLM for Video Understanding and Generation</strong></p>
        <p>Official Gradio demo of GPT4Video: a model that can process arbitrarily interleaved video and text inputs, and produce video and text outputs.</p>
        <p><a href=" ">Project Page</a> | <a href="https://github.com/gpt4video/GPT4Video">Code</a> | <a href="https://github.com/gpt4video/GPT4Video">Model</a> | <a href="https://arxiv.org/abs/2304.08485">Arxiv</a></p>
    </div>
</div>
"""


# learn_more_markdown = ("""
# The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
# """)

# License / usage notice (HTML) shown inside the "License" accordion.
learn_more_markdown = """
<p>
    The service is a research preview intended for non-commercial use only, subject to the model 
    <a href="https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md">License</a> of LLaMA, 
    <a href="https://openai.com/policies/terms-of-use">Terms of Use</a> of the data generated by OpenAI, 
    and <a href="https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb">Privacy Practices</a> of ShareGPT. 
    Please contact us if you find any potential violation.
</p>
"""
# tos_markdown = ("""
# ### Terms of use
# By using this service, users are required to agree to the following terms:
# The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
# Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
# **Copyright 2023 Tencent AI Lab.**
# """)

# Terms-of-use notice (HTML) shown inside the "Terms of Use" accordion.
# The wrapping <div> carries its CSS in a proper style="..." attribute; without
# the attribute name the tag is malformed and the styling is never applied.
tos_markdown = """
<div style="padding: 15px; border-radius: 10px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);">
    <p>By using this service, users are required to agree to the following terms:</p>
    <ul style="padding-left:20px; list-style-type: disc !important;">
        <li style="list-style-type: disc; margin-left:20px;">The service is a research preview intended for non-commercial use only.</li>
        <li style="list-style-type: disc; margin-left:20px;">It only provides limited safety measures and may generate offensive content.</li>
        <li style="list-style-type: disc; margin-left:20px;">It must not be used for any illegal, harmful, violent, racist, or sexual purposes.</li>
        <li style="list-style-type: disc; margin-left:20px;">The service may collect user dialogue data for future research.</li>
    </ul>
    <p>Please click the "Flag" button if you encounter any inappropriate answer! We will collect those to keep improving our moderator.</p>
    <p><strong>&copy; 2023 Tencent AI Lab.</strong></p>
</div>
"""


# Without a CUDA device, append a visible warning to the page header.
if not torch.cuda.is_available():
    html_content = html_content + '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'


def delete_prev_fn(
        history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the most recent turn from ``history`` and hand back its user message.

    The list is modified in place. Returns ``(history, message)`` where
    ``message`` is the first element of the removed turn, or ``''`` when the
    history was already empty or that element was falsy.
    """
    try:
        last_turn = history.pop()
    except IndexError:
        # Nothing to remove: report an empty message.
        return history, ''
    user_message, _ = last_turn
    return history, user_message or ''


def clear_history():
    """Return reset values for the (state, chatbot, textbox, videobox) outputs.

    The state is a fresh pair of empty lists, the chatbot an empty list, the
    textbox an empty string and the video input ``None``.
    """
    fresh_state = [[], []]
    empty_chat = []
    return (fresh_state, empty_chat, "", None)

def get_conv_log_filename():
    """Return today's conversation-log file name, ``YYYY-MM-DD-conv.json``.

    The date comes from the local clock (``datetime.datetime.now()``); month
    and day are zero-padded to two digits.
    """
    now = datetime.datetime.now()
    return "{}-{:02d}-{:02d}-conv.json".format(now.year, now.month, now.day)

# Reusable button updates returned from event handlers:
#   disable_btn   -> make the button non-interactive (returned by the vote handlers)
#   no_change_btn -> an update with no properties set
#   enable_btn    -> make the button interactive
disable_btn = gr.Button.update(interactive=False)
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)

def vote_last_response(state, vote_type, model_selector):
    """Append one feedback record to today's log under ``save_response/``.

    Args:
        state: Session state; ``state[0]`` (the history handed to the model)
            is stored in the record.
        vote_type: Label for the feedback, e.g. "upvote", "downvote" or "flag".
        model_selector: Value recorded under the "model" key.

    The record is written as a single JSON line. The log directory is created
    on first use so that the append does not fail with FileNotFoundError on a
    fresh checkout.
    """
    # Build (and thereby validate) the record before touching the filesystem.
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state[0],
    }
    line = json.dumps(record) + "\n"

    log_dir = "save_response"
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, get_conv_log_filename()), "a") as fout:
        fout.write(line)

def upvote_last_response(state, model_selector):
    """Log an "upvote" for the current conversation.

    Returns an empty string followed by three ``disable_btn`` updates, matching
    the (textbox, upvote, downvote, flag) outputs the handler is wired to.
    """
    vote_last_response(state, "upvote", model_selector)
    return "", disable_btn, disable_btn, disable_btn

def downvote_last_response(state, model_selector):
    """Log a "downvote" for the current conversation.

    Returns an empty string followed by three ``disable_btn`` updates, matching
    the (textbox, upvote, downvote, flag) outputs the handler is wired to.
    """
    vote_last_response(state, "downvote", model_selector)
    return "", disable_btn, disable_btn, disable_btn

def flag_last_response(state, model_selector):
    """Log a "flag" for the current conversation.

    Returns an empty string followed by three ``disable_btn`` updates, matching
    the (textbox, upvote, downvote, flag) outputs the handler is wired to.
    """
    vote_last_response(state, "flag", model_selector)
    return "", disable_btn, disable_btn, disable_btn


# Shared safety checker used by add_text() on every uploaded video.
# NOTE(review): 'cuda:0' is presumably the device string — confirm against
# SafetyChecker; if so this line requires a GPU at import time.
checker = SafetyChecker('cuda:0')

def _video_html(video_path):
    """Return a ``<br><video>`` tag embedding the file at ``video_path`` as base64."""
    with open(video_path, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode()
    return f'<br><video src="data:video/mp4;base64,{encoded}" controls width="426" height="240"></video>'


def _set_last_reply(state, chatbot_text, model_text):
    """Fill the reply slot of the latest turn in both the chatbot and model histories."""
    state[1][-1] = (state[1][-1][0], chatbot_text)
    state[0][-1] = (state[0][-1][0], model_text)


def add_text(state, text, video, video_decoder, max_turns):
    """Handle one user submission and stream the model's reply.

    Args:
        state: Two-element list. ``state[0]`` is the history passed to
            ``gpt4video.run``; ``state[1]`` is the list of
            ``(user_html, bot_html)`` pairs shown in the chatbot.
        text: Contents of the textbox (may be empty).
        video: Value of the video input, or ``None`` when no video was given.
        video_decoder: Selected decoder name, forwarded to ``gpt4video.run``.
        max_turns: Forwarded to ``gpt4video.run``.

    Yields:
        ``(state, state[1], "", None)`` tuples for the
        (state, chatbot, textbox, videobox) outputs: once after the user turn
        is recorded, then once per response chunk. Streaming stops after the
        first chunk that carries an output video.
    """
    print(video)
    has_text = bool(text)
    has_video = video is not None

    # This function is a generator, so a bare ``return value`` would be
    # discarded; yield the unchanged outputs explicitly before stopping.
    if not has_text and not has_video:
        yield (state, state[1], "", None)
        return

    if has_video:
        # The checker returns the (possibly replaced) video path plus an NSFW flag.
        video, nsfw = checker.run_safety_checker(gpt4video.t2vzero, video, thresh=9)
        video_tag = _video_html(video)
        if has_text:
            chatbot_text = f'{text} {video_tag}'
            if nsfw:
                # Flagged videos are replaced by a textual placeholder in the model input.
                model_input = ["\nHuman:[This represents a adult video that may contain harmful content]."+text]
            else:
                model_input = [f"Human:", video, "\nHuman:"+text]
        else:
            chatbot_text = video_tag
            if nsfw:
                model_input = ["Human:[This represents a adult video that may contain harmful content]."]
            else:
                # No text given: fall back to a default instruction.
                model_input = ["Human:describe this video in detail.\nHuman:", video]
        state[1] = state[1] + [(chatbot_text, None)]
        state[0] = state[0] + [(model_input, [None])]
    else:
        state[1] = state[1] + [(text, None)]
        state[0] = state[0] + [(["Human:" + text], [None])]

    # Show the user's turn immediately and clear the inputs.
    yield (state, state[1], "", None)

    responses = gpt4video.run(video_decoder=video_decoder, history=state[0], num_frames=24, fps=8, max_turns=max_turns, max_len=512, temperature=0.0, top_p=1)

    for chunk in responses:
        reply = chunk[0]
        out_video = chunk[1]
        text2 = chunk[2]
        prompt = chunk[3]

        # startswith() is safe on an empty chunk, unlike indexing reply[0].
        if reply.startswith(":"):
            reply = reply[1:]

        if out_video is None:
            # Text-only chunk: update the reply in place and keep streaming.
            _set_last_reply(state, reply, ["\nAI:"+reply])
            yield (state, state[1], "", None)
            time.sleep(0.01)
            continue

        video_tag = _video_html(out_video)
        if text2 is not None:
            chatbot_text = f'{reply} {video_tag} <br>{text2}'
            # The model history stores the generation prompt, not the video file.
            model_text = ["\nAI:"+reply+"\nAI:", prompt, "\nAI:"+text2]
        else:
            chatbot_text = f'{reply} {video_tag}'
            model_text = ["\nAI:"+reply+ " " + prompt+"video>"]
            print(model_text)

        _set_last_reply(state, chatbot_text, model_text)
        yield (state, state[1], "", None)
        # A chunk with a video ends the turn.
        return


# def build_demo():
# The textbox is created outside the Blocks context so the gr.Examples widgets
# in the left column can reference it before it is placed; it is rendered
# later with textbox.render() in the right column.
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER").style(container=False)
with gr.Blocks(css='style.css') as demo:
    # Per-session state: [model_history, chatbot_history] (see add_text).
    state = gr.State([[],[]])  # chat_history
    gr.HTML(html_content)
    
    with gr.Row():
        # Left column: video input, generation parameters, examples, terms of use.
        with gr.Column(scale=4):
            videobox = gr.Video(interactive=True)
            with gr.Accordion("Parameters", open=True,) as parameter_row:
                video_decoder = gr.Dropdown(
                    ["Zeroscope", "VideoCrafter1", "VideoFusion", "Text2Video-Zero"], value="Zeroscope",label="Video Decoder", info="Will add more decoders later!"
                )
                max_turns = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Max conversation turns")
                # num_frames = gr.Slider(minimum=0, maximum=48, value=24, step=2, interactive=True,
                #                          label="Number of generated frames")
                # fps = gr.Slider(minimum=1, maximum=16, value=8, step=1, interactive=True,
                #                          label="Frames Per Second")
                # # gen_scale_factor = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, step=0.1, interactive=False,
                # #                          label="Frequency multiplier for returning videos (higher means more frequent)")
                # # min_word_tokens = gr.Slider(minimum=0, maximum=50, value=25, step=5, interactive=False, label="Min output text tokens",)
                # max_input_token_length = gr.Slider(minimum=0, maximum=2048, value=512, step=64, interactive=True, label="Max input token length",)
                # temperature = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1, interactive=False, label="Temperature",)
                # top_p = gr.Slider(minimum=0, maximum=16, value=6, step=1, interactive=True, label="Safety Threshold",)
    
                # NOTE(review): cur_dir is computed but not used below; the
                # example path is relative to the working directory instead.
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    ["./examples/107301_107350_4673552.mp4", "I noticed some dancers performing a dance style I didn't recognize. It looked a lot like this, do you have any ideas?"]
                ], inputs=[videobox, textbox])
                gr.Examples(examples=[
                    ["I've recently taken up painting as a hobby, and I'm looking for some inspiration. Can you help me find some interesting ideas?"]
                ], inputs=[textbox])  

                # NOTE(review): the name parameter_row is rebound here and again
                # for the "License" accordion; none of the bindings is read later.
                with gr.Accordion("Terms of Use", open=False,) as parameter_row:
                    gr.HTML(tos_markdown)
    
        # Right column: chat window, text input with submit button, feedback buttons.
        with gr.Column(scale=6):
            chatbot = gr.Chatbot(label="GPT4Video-bot", elem_id="chatbot").style(height=650)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=60):
                    submit_btn = gr.Button(value="Submit")

            with gr.Row() as button_row:
                upvote_btn = gr.Button(value="👍  Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎  Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️  Flag", interactive=True)
                # regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️  Clear history", interactive=True)
            
    
    with gr.Accordion("License", open=True,) as parameter_row:
        gr.HTML(learn_more_markdown)

    # parameter_list = [num_frames, fps, max_input_token_length, temperature, top_p]
    # Extra inputs appended after (state, textbox, videobox, video_decoder)
    # when calling add_text.
    parameter_list = [max_turns]
    # btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    # Feedback handlers: each logs the vote, clears the textbox and disables
    # all three feedback buttons. The selected video decoder is passed as the
    # handlers' model_selector argument.
    # NOTE(review): no handler wired in this block re-enables these buttons, so
    # they appear to stay disabled for the rest of the session — confirm intent.
    upvote_btn.click(upvote_last_response,
        [state, video_decoder], 
        [textbox, upvote_btn, downvote_btn, flag_btn],
        queue=False)
    downvote_btn.click(downvote_last_response,
        [state, video_decoder], 
        [textbox, upvote_btn, downvote_btn, flag_btn],
        queue=False)
    flag_btn.click(flag_last_response,
        [state, video_decoder], 
        [textbox, upvote_btn, downvote_btn, flag_btn],
        queue=False)
    
    # Reset state, chat window, textbox and video input.
    clear_btn.click(clear_history, None, [state, chatbot, textbox, videobox])

    # Pressing ENTER in the textbox and clicking "Submit" both run add_text
    # with identical inputs and outputs.
    textbox.submit(add_text, 
        [state, textbox, videobox, video_decoder] + parameter_list, 
        [state, chatbot, textbox, videobox]
    )

    submit_btn.click(add_text, 
        [state, textbox, videobox, video_decoder] + parameter_list, 
        [state, chatbot, textbox, videobox]
    )
    

# Command-line options for serving the demo.
parser = argparse.ArgumentParser()
# NOTE(review): the default host is a hard-coded address; pass --host when
# running on any other machine.
parser.add_argument("--host", type=str, default="11.215.107.186")
parser.add_argument("--debug", action="store_true", default=False, help="using debug mode")
parser.add_argument("--port", type=int, default=80)
parser.add_argument("--concurrency-count", type=int, default=1)
# NOTE(review): --base-model, --load-8bit and --bf16 are parsed but not read
# anywhere in this file's visible code.
parser.add_argument("--base-model",type=str, default='./')
parser.add_argument("--load-8bit", action="store_true", help="using 8bit mode")
# store_true combined with default=True means this option is always True.
parser.add_argument("--bf16", action="store_true", default=True, help="using bf16 mode")
args = parser.parse_args()

# Enable request queuing (needed for the streaming add_text generator) and
# start the server; the queue's API endpoint is closed and no public share
# link is created.
demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(
    server_name=args.host,
    debug=args.debug,
    server_port=args.port,
    share=False,
)

