import fire
import gradio as gr

from src.service.rag_service import RAGService


def create_button_handler(my_rag_service: RAGService, document: str):
    """Store ``document`` via the RAG service and refresh the document pickers.

    Args:
        my_rag_service: Service whose ``create_documents`` / ``get_all_uuids``
            are used to add the document and list the stored ids.
        document: Text entered in the "Create" textbox.

    Returns:
        A list of three ``gr.update`` objects, in order:
        1. the input textbox, cleared after a successful add;
        2. and 3. two UUID dropdowns, refreshed with every stored UUID and
           pointed at the first one (``None`` when the store is empty).
        When ``document`` is empty or whitespace-only, nothing is stored and
        all three components are left unchanged.
    """
    # Avoid persisting blank documents from a click on an empty textbox.
    if not (document and document.strip()):
        gr.Warning("Document is empty; nothing was added.")
        return [gr.update(), gr.update(), gr.update()]

    my_rag_service.create_documents(documents=[document])
    gr.Info("Document added!")

    new_uuids = my_rag_service.get_all_uuids()
    # Compute the default selection once and reuse it for both dropdowns.
    selected_uuid = new_uuids[0] if new_uuids else None
    return [
        gr.update(value=""),
        gr.update(choices=new_uuids, value=selected_uuid),
        gr.update(choices=new_uuids, value=selected_uuid),
    ]


def read_dropdown_handler(my_rag_service: RAGService, uuid: str):
    """Return an update that shows the document selected in the "Read" dropdown.

    Args:
        my_rag_service: Service whose ``read_document`` fetches the text.
        uuid: Currently selected UUID; may be ``None``/empty when the
            dropdown has been cleared (e.g. after the last document is deleted).

    Returns:
        A ``gr.update`` setting the content textbox to the document text, or
        to ``""`` when no UUID is selected.
    """
    if not uuid:
        # Nothing selected: blank the content box instead of asking the
        # service to look up a missing id (and instead of leaving stale text).
        return gr.update(value="")
    return gr.update(value=my_rag_service.read_document(uuid=uuid))


def delete_button_handler(my_rag_service: RAGService, uuid: str):
    """Delete the selected document and refresh the document pickers.

    Args:
        my_rag_service: Service whose ``delete_documents`` / ``get_all_uuids``
            are used to remove the document and list the remaining ids.
        uuid: UUID selected in the "Delete" dropdown; ``None``/empty when the
            store has no documents.

    Returns:
        A list of two ``gr.update`` objects for the UUID dropdowns, refreshed
        with the remaining UUIDs and pointed at the first one (``None`` when
        the store is now empty). When no UUID is selected, nothing is deleted
        and both dropdowns are left unchanged.
    """
    # An empty store leaves the dropdown at None; there is nothing to delete.
    if not uuid:
        gr.Warning("No document selected; nothing was deleted.")
        return [gr.update(), gr.update()]

    my_rag_service.delete_documents(uuids=[uuid])
    gr.Info("Document deleted!")

    new_uuids = my_rag_service.get_all_uuids()
    # Compute the default selection once and reuse it for both dropdowns.
    selected_uuid = new_uuids[0] if new_uuids else None
    return [
        gr.update(choices=new_uuids, value=selected_uuid),
        gr.update(choices=new_uuids, value=selected_uuid),
    ]


def run_demo(
    embedder_type: str,
    vectordb_type: str,
    generator_type: str,
    port: int = 7777,
    share: bool = True,
):
    """Build and launch the Gradio RAG demo.

    The UI has a "Chat Interface" tab backed by
    ``RAGService.retrieve_and_generate`` and a "Database" tab with
    Create / Read / Delete sub-tabs for managing stored documents.

    Args:
        embedder_type: Forwarded to ``RAGService`` as ``embedder_type``.
        vectordb_type: Forwarded to ``RAGService`` as ``vectordb_type``.
        generator_type: Forwarded to ``RAGService`` as ``generator_type``.
        port: Passed to ``demo.launch`` as ``server_port``.
        share: Passed to ``demo.launch``. ``True`` keeps the previous
            hard-coded behavior of requesting a public Gradio share link,
            which exposes the Create/Delete controls to anyone with the URL.
    """
    my_rag_service = RAGService(
        embedder_type=embedder_type,
        vectordb_type=vectordb_type,
        generator_type=generator_type,
    )

    # Snapshot of the store at startup, used to pre-populate the Read/Delete tabs.
    all_uuids = my_rag_service.get_all_uuids()
    if all_uuids:
        first_uuid = all_uuids[0]
        first_document = my_rag_service.read_document(uuid=first_uuid)
    else:
        first_uuid = None
        first_document = ""

    with gr.Blocks() as demo:
        gr.Markdown("# RAG Demo")

        with gr.Tab("Chat Interface"):
            with gr.Row():
                conv_id_dropdown = gr.Dropdown(
                    label="conv_id",
                    choices=list(range(10)),
                    value=0,
                )
                top_k_dropdown = gr.Dropdown(
                    label="top_k",
                    choices=[1, 3, 5, 10, 25, 50],
                    value=1,
                )
            # conv_id and top_k are passed to retrieve_and_generate after the
            # message and history arguments that ChatInterface supplies.
            gr.ChatInterface(
                chatbot=gr.Chatbot(height=500),
                fn=my_rag_service.retrieve_and_generate,
                additional_inputs=[conv_id_dropdown, top_k_dropdown],
            )

        with gr.Tab("Database"):
            with gr.Tab("Create"):
                create_textbox = gr.Textbox(
                    label="Enter a document to add",
                    placeholder="Something...",
                )
                create_button = gr.Button(
                    value="Create",
                    variant="primary",
                )
            with gr.Tab("Read"):
                read_dropdown = gr.Dropdown(
                    label="Select a document to read",
                    choices=all_uuids,
                    value=first_uuid,
                )
                # Show the initially selected document as real content; using
                # it as a placeholder rendered it as hint text that lingered
                # whenever the box was later emptied.
                read_textbox = gr.Textbox(
                    label="Content",
                    value=first_document,
                    interactive=False,
                )
            with gr.Tab("Delete"):
                delete_dropdown = gr.Dropdown(
                    label="Select a document to delete",
                    choices=all_uuids,
                    value=first_uuid,
                )
                delete_button = gr.Button(
                    value="Delete",
                    variant="stop",
                )

        # Creating or deleting a document refreshes both UUID dropdowns so the
        # Read and Delete tabs stay in sync with the store.
        create_button.click(
            fn=lambda document: create_button_handler(my_rag_service=my_rag_service, document=document),
            inputs=[create_textbox],
            outputs=[create_textbox, read_dropdown, delete_dropdown],
        )
        read_dropdown.change(
            fn=lambda uuid: read_dropdown_handler(my_rag_service=my_rag_service, uuid=uuid),
            inputs=[read_dropdown],
            outputs=[read_textbox],
        )
        delete_button.click(
            fn=lambda uuid: delete_button_handler(my_rag_service=my_rag_service, uuid=uuid),
            inputs=[delete_dropdown],
            outputs=[read_dropdown, delete_dropdown],
        )

    demo.launch(server_port=port, share=share)


if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI arguments onto
    # run_demo's parameters.
    fire.Fire(component=run_demo)
