import base64
import html
import itertools
import json
import os
from typing import Dict, List, Tuple

import cv2
import gradio as gr

from agentverse import TaskSolving
from agentverse.simulation import Simulation
from agentverse.message import Message

# Image assets live in the "imgs" folder one level above this module's directory.
_MODULE_DIR = os.path.dirname(__file__)
IMG_PATH = os.path.join(_MODULE_DIR, os.pardir, "imgs")


def cover_img(background, img, place: Tuple[int, int]):
    """
    Overlay ``img`` onto ``background`` in place.

    Every pixel of ``img`` whose alpha value (channel 3) is non-zero
    overwrites the matching background pixel with the image's first three
    channels; pixels with zero alpha leave the background unchanged. An
    image without an alpha channel is treated as fully opaque. The overlay
    is clipped to the background, so a partly off-canvas ``place`` only
    draws the overlapping part.

    :param background: background image indexed as (row, col, channel);
        modified in place
    :param img: the overlay image indexed as (row, col, channel)
    :param place: (row, col) of the overlay's top-left corner in ``background``
    """
    back_h, back_w = background.shape[:2]
    height, width = img.shape[:2]
    top, left = place
    # Intersection of the overlay rectangle with the background.
    row0, row1 = max(top, 0), min(top + height, back_h)
    col0, col1 = max(left, 0), min(left + width, back_w)
    if row0 >= row1 or col0 >= col1:
        return
    patch = img[row0 - top : row1 - top, col0 - left : col1 - left]
    # Basic slicing yields a view, so writing into ``region`` edits ``background``.
    region = background[row0:row1, col0:col1]
    if patch.shape[2] > 3:
        opaque = patch[:, :, 3] != 0
        region[opaque] = patch[:, :, :3][opaque]
    else:
        region[...] = patch[:, :, :3]


class GUI:
    """
    Gradio front end for an AgentVerse task.

    For "tasksolving/brainstorming" it shows a chat box streamed from
    ``TaskSolving.iter_run``; for every other task it shows a scene image,
    an HTML transcript, and Reset / Next / autoplay controls driven by a
    ``Simulation`` backend.
    """

    # Keywords highlighted in db_diag solutions. The position of each group
    # is the index in ``solution_status`` (and of the solution button) that
    # it enables. Groups and words are applied in this order.
    _SOLUTION_KEYWORDS = (
        ("query", "queries"),
        ("join",),
        ("index",),
        ("system configuration",),
        ("monitor", "Monitor", "Investigate"),
    )

    def __init__(self, task: str, tasks_dir: str, ui_kwargs: Dict[str, str]):
        """
        Build the backend for ``task`` and initialise the UI state.

        :param task: task name, e.g. "simulation/db_diag"
        :param tasks_dir: directory passed to the backend's ``from_task``
        :param ui_kwargs: keyword arguments forwarded to the Gradio ``launch`` call
        """
        # Transcript as (sender index, text) pairs; sender -1 is the user.
        self.messages = []
        self.task = task
        self.ui_kwargs = ui_kwargs
        if task == "tasksolving/brainstorming":
            self.backend = TaskSolving.from_task(task, tasks_dir)
        else:
            self.backend = Simulation.from_task(task, tasks_dir)
        self.turns_remain = 0
        # Agent name -> index in the backend's agent list.
        self.agent_id = {
            agent.name: idx for idx, agent in enumerate(self.backend.agents)
        }
        # Every agent except index 0 is counted as a "student".
        self.stu_num = len(self.agent_id) - 1
        self.autoplay = False
        self.image_now = None  # last rendered scene image
        self.text_now = None  # last rendered transcript HTML
        self.tot_solutions = 5  # one flag per "Optimization Solutions" button
        self.solution_status = [False] * self.tot_solutions

    def _solution_updates(self):
        """Return visibility updates for each solution button, then for their box."""
        return (
            *[gr.Button.update(visible=status) for status in self.solution_status],
            gr.Box.update(visible=any(self.solution_status)),
        )

    def get_avatar(self, idx):
        """
        Return the avatar image for ``idx`` as a base64 PNG data URI.

        ``idx == -1`` (the user) always uses ``db_diag/-1.png``; other indices
        are looked up in the task-specific asset folder.

        :raises gr.Error: if the avatar file cannot be read
        """
        if idx == -1:
            path = f"{IMG_PATH}/db_diag/-1.png"
        elif self.task == "simulation/prisoner_dilemma":
            path = f"{IMG_PATH}/prison/{idx}.png"
        elif self.task == "simulation/db_diag":
            path = f"{IMG_PATH}/db_diag/{idx}.png"
        elif "sde" in self.task:
            path = f"{IMG_PATH}/sde/{idx}.png"
        else:
            path = f"{IMG_PATH}/{idx}.png"
        img = cv2.imread(path)
        if img is None:
            raise gr.Error(f"cannot read avatar image: {path}")
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() is the equivalent.
        png_bytes = cv2.imencode(".png", img)[1].tobytes()
        return "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")

    def stop_autoplay(self):
        """
        Ask the autoplay loop to stop after its current step.

        All three control buttons are disabled here; ``start_autoplay`` emits
        the final button state once its loop has actually finished.

        :return: updates for (next, stop autoplay, start autoplay) buttons
        """
        self.autoplay = False
        return (
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
        )

    def start_autoplay(self):
        """
        Step the simulation repeatedly until stopped or out of turns.

        Yields (image, transcript, next button, stop button, start button,
        *solution buttons, solution box) updates: once immediately, once per
        step, and once more after the loop ends.
        """
        self.autoplay = True
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )

        while self.autoplay and self.turns_remain > 0:
            self.image_now, self.text_now = self.gen_output()
            running = self.autoplay and self.turns_remain > 0
            idle = not self.autoplay and self.turns_remain > 0
            yield (
                self.image_now,
                self.text_now,
                gr.Button.update(interactive=idle),
                gr.Button.update(interactive=running),
                gr.Button.update(interactive=idle),
                *self._solution_updates(),
            )

        # Stop may be clicked after the last yield was built (while it still
        # showed autoplay running) but before the loop condition is rechecked;
        # without this final update Next/Start would stay disabled.
        self.autoplay = False
        can_step = self.turns_remain > 0
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=can_step),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=can_step),
            *self._solution_updates(),
        )

    def delay_gen_output(self):
        """
        Advance one step, disabling Next/Start while the backend works.

        Yields (image, transcript, next button, start button, *solution
        buttons, solution box) twice: before and after the step.
        """
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )

        self.image_now, self.text_now = self.gen_output()
        can_step = self.turns_remain > 0
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=can_step),
            gr.Button.update(interactive=can_step),
            *self._solution_updates(),
        )

    def delay_reset(self):
        """
        Stop autoplay, reset the scene, and enable Next/Start.

        :return: (image, transcript, next button, stop button, start button,
            *solution buttons, solution box) updates
        """
        self.autoplay = False
        self.image_now, self.text_now = self.reset()
        return (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            *self._solution_updates(),
        )

    def reset(self, stu_num=0):
        """
        Reset the backend and return a fresh scene.

        :param stu_num: requested number of students; validated to lie in
            [0, 30] but not yet forwarded to the backend
        :return: [initial scene image (RGB), empty transcript]
        :raises gr.Error: if ``stu_num`` is outside [0, 30]
        """
        if not 0 <= stu_num <= 30:
            raise gr.Error("the number of students must be between 0 and 30.")

        # TODO: the backend cannot be resized yet; once it can, pass stu_num
        # to self.backend.reset() and update self.stu_num here.
        self.backend.reset()
        self.turns_remain = self.backend.environment.max_turns

        if self.task == "simulation/prisoner_dilemma":
            background = cv2.imread(f"{IMG_PATH}/prison/case_1.png")
        elif self.task == "simulation/db_diag":
            background = cv2.imread(f"{IMG_PATH}/db_diag/background.png")
        elif "sde" in self.task:
            background = cv2.imread(f"{IMG_PATH}/sde/background.png")
        else:
            background = cv2.imread(f"{IMG_PATH}/background.png")
            back_h, back_w, _ = background.shape
            seats = itertools.product(
                range(800, back_h, 300), range(135, back_w - 200, 200)
            )
            # Seats are filled in order: the first stu_num cycle through sprites
            # 1..11, the rest use empty.png.
            for seat, (h_begin, w_begin) in enumerate(seats, start=1):
                name = (seat - 1) % 11 + 1 if seat <= self.stu_num else "empty"
                img = cv2.imread(f"{IMG_PATH}/{name}.png", cv2.IMREAD_UNCHANGED)
                # Sprites taller than 190 px are drawn 30 px higher.
                cover_img(
                    background,
                    img,
                    (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
                )
        self.messages = []
        self.solution_status = [False] * self.tot_solutions
        return [cv2.cvtColor(background, cv2.COLOR_BGR2RGB), ""]

    def gen_img(self, data: List[Dict]):
        """
        Render the scene for the current step, marking the agents that spoke.

        :param data: one entry per agent, in agent-index order, each with a
            "message" string ("" means the agent was silent)
        :return: the rendered scene as an RGB image
        :raises gr.Error: if ``data`` does not have ``self.stu_num + 1`` entries
        """
        # TODO: generalise; the layouts below are hard-coded per task.
        if len(data) != self.stu_num + 1:
            raise gr.Error("data length is not equal to the total number of students.")
        if self.task == "simulation/prisoner_dilemma":
            img = cv2.imread(f"{IMG_PATH}/speaking.png", cv2.IMREAD_UNCHANGED)
            # The backdrop, and where agent 0's bubble goes, depend on who sent
            # the last two transcript messages.
            if (
                len(self.messages) < 2
                or self.messages[-1][0] == 1
                or self.messages[-2][0] == 2
            ):
                background = cv2.imread(f"{IMG_PATH}/prison/case_1.png")
                if data[0]["message"] != "":
                    cover_img(background, img, (400, 480))
            else:
                background = cv2.imread(f"{IMG_PATH}/prison/case_2.png")
                if data[0]["message"] != "":
                    cover_img(background, img, (400, 880))
            if data[1]["message"] != "":
                cover_img(background, img, (550, 480))
            if data[2]["message"] != "":
                cover_img(background, img, (550, 880))
        elif self.task == "simulation/db_diag":
            background = cv2.imread(f"{IMG_PATH}/db_diag/background.png")
            img = cv2.imread(f"{IMG_PATH}/db_diag/speaking.png", cv2.IMREAD_UNCHANGED)
            if data[0]["message"] != "":
                cover_img(background, img, (750, 80))
            if data[1]["message"] != "":
                cover_img(background, img, (310, 220))
            if data[2]["message"] != "":
                cover_img(background, img, (522, 11))
        elif "sde" in self.task:
            background = cv2.imread(f"{IMG_PATH}/sde/background.png")
            img = cv2.imread(f"{IMG_PATH}/sde/speaking.png", cv2.IMREAD_UNCHANGED)
            if data[0]["message"] != "":
                cover_img(background, img, (692, 330))
            if data[1]["message"] != "":
                cover_img(background, img, (692, 660))
            if data[2]["message"] != "":
                cover_img(background, img, (692, 990))
        else:
            background = cv2.imread(f"{IMG_PATH}/background.png")
            back_h, back_w, _ = background.shape
            # Overlays reused for every seat (cover_img never modifies them).
            speaking = cv2.imread(f"{IMG_PATH}/speaking.png", cv2.IMREAD_UNCHANGED)
            hand = cv2.imread(f"{IMG_PATH}/hand.png", cv2.IMREAD_UNCHANGED)
            empty = cv2.imread(f"{IMG_PATH}/empty.png", cv2.IMREAD_UNCHANGED)
            # Agent 0 is drawn at a fixed position rather than in a seat.
            if data[0]["message"] not in ("", "[RaiseHand]"):
                cover_img(background, speaking, (370, 1250))
            seats = itertools.product(
                range(800, back_h, 300), range(135, back_w - 200, 200)
            )
            for seat, (h_begin, w_begin) in enumerate(seats, start=1):
                if seat > self.stu_num:
                    cover_img(background, empty, (h_begin, w_begin))
                    continue
                img = cv2.imread(
                    f"{IMG_PATH}/{(seat - 1) % 11 + 1}.png", cv2.IMREAD_UNCHANGED
                )
                cover_img(
                    background,
                    img,
                    (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
                )
                message = data[seat]["message"]
                if "[RaiseHand]" in message:
                    cover_img(background, hand, (h_begin - 90, w_begin + 10))
                elif message:
                    cover_img(background, speaking, (h_begin - 90, w_begin + 10))
        return cv2.cvtColor(background, cv2.COLOR_BGR2RGB)

    def return_format(self, messages: List[Message]):
        """
        Convert backend messages into one display entry per agent.

        :param messages: messages the backend produced in one step
        :return: list indexed by agent id of {"message": str, "sender": int};
            agents that did not speak get ""
        """
        entries = [{"message": "", "sender": idx} for idx in range(len(self.agent_id))]

        for message in messages:
            entry = entries[self.agent_id[message.sender]]
            if self.task == "simulation/db_diag":
                # Work on a copy: prefixing message.content in place would alter
                # the backend's own message object.
                content_json = dict(message.content)
                content_json[
                    "diagnose"
                ] = f"[{message.sender}]: {content_json['diagnose']}"
                entry["message"] = json.dumps(content_json)
            elif "sde" in self.task and message.sender == "code_tester":
                # Expected content: "<first line>\n<JSON with a 'feedback' key>".
                pre_message, message_ = message.content.split("\n")
                feedback = json.loads(message_)["feedback"]
                entry["message"] = f"[{message.sender}]: {pre_message}\n{feedback}"
            else:
                entry["message"] = f"[{message.sender}]: {message.content}"

        return entries

    def gen_output(self):
        """
        Advance the backend one step and render the result.

        Messages other than "" and a bare "[RaiseHand]" are appended to the
        transcript, and one turn is consumed.

        :return: [new scene image, transcript HTML]
        """
        data = self.return_format(self.backend.next())
        # TODO: check that the backend lets only one agent speak per step.
        self.messages.extend(
            (item["sender"], item["message"])
            for item in data
            if item["message"] not in ("", "[RaiseHand]")
        )
        message = self.gen_message()
        self.turns_remain -= 1
        return [self.gen_img(data), message]

    def _format_diagnosis(self, msg_json):
        """
        Render one db_diag agent message as HTML and recompute ``solution_status``.

        :param msg_json: decoded message with "diagnose", "solution" and
            "knowledge" keys, as produced by ``return_format``
        :return: HTML for the message bubble body
        """
        self.solution_status = [False] * self.tot_solutions
        # Escape model output before adding our own highlight markup.
        msg = html.escape(msg_json["diagnose"], quote=False)
        if msg_json["solution"] != "":
            solution: List[str] = msg_json["solution"]
            for solu in solution:
                solu = html.escape(solu, quote=False)
                for status_idx, keywords in enumerate(self._SOLUTION_KEYWORDS):
                    if any(word in solu for word in keywords):
                        self.solution_status[status_idx] = True
                        for word in keywords:
                            solu = solu.replace(
                                word, f'<span style="color:yellow;">{word}</span>'
                            )
                msg = f"{msg}<br>{solu}"
        if msg_json["knowledge"] != "":
            knowledge = html.escape(str(msg_json["knowledge"]), quote=False)
            msg = (
                f'{msg}<hr style="margin: 5px 0">'
                f'<span style="font-style: italic">{knowledge}</span>'
            )
        return msg

    def gen_message(self):
        """
        Render the whole transcript as HTML, newest message first.

        db_diag agent messages (JSON from ``return_format``) show the diagnosis,
        highlighted solutions and knowledge, and update ``self.solution_status``
        so it reflects the latest such message. All other messages, including
        the user's, are HTML-escaped and shown as plain text.
        """
        avatars = {}  # avatar index -> data URI, so each image is encoded once
        bubbles = []
        for sender, msg in self.messages:
            idx = sender if sender in (0, -1) else (sender - 1) % 11 + 1
            if idx not in avatars:
                avatars[idx] = self.get_avatar(idx)
            avatar = avatars[idx]
            # User messages (sender -1, appended by submit) are plain text, not JSON.
            if self.task == "simulation/db_diag" and sender != -1:
                msg = self._format_diagnosis(json.loads(msg))
            else:
                msg = html.escape(msg, quote=False)
            bubbles.append(
                f'<div style="display: flex; align-items: center; margin-bottom: 10px;overflow:auto;">'
                f'<img src="{avatar}" style="width: 5%; height: 5%; border-radius: 25px; margin-right: 10px;">'
                f'<div style="background-color: gray; color: white; padding: 10px; border-radius: 10px;'
                f'max-width: 70%; white-space: pre-wrap">'
                f"{msg}"
                f"</div></div>"
            )
        return (
            '<div id="divDetail" style="height:600px;overflow:auto;">'
            + "".join(reversed(bubbles))
            + "</div>"
        )

    def submit(self, message: str):
        """
        Send a user message to the backend and re-render.

        :param message: text typed by the user
        :return: (scene image with no speaker marked, transcript HTML)
        """
        self.backend.submit(message)
        self.messages.append((-1, f"[User]: {message}"))
        silent = [{"message": ""} for _ in self.agent_id]
        return self.gen_img(silent), self.gen_message()

    def launch(self, single_agent=False, discussion_mode=False):
        """
        Build the Gradio app for the current task and start serving it.

        :param single_agent: forwarded to ``TaskSolving.iter_run`` (brainstorming only)
        :param discussion_mode: forwarded to ``TaskSolving.iter_run`` (brainstorming only)
        """
        if self.task == "tasksolving/brainstorming":
            with gr.Blocks() as demo:
                chatbot = gr.Chatbot(height=800, show_label=False)
                msg = gr.Textbox(label="Input")

                def respond(message, chat_history):
                    # Show the user's message first, then stream each backend reply.
                    chat_history.append((message, None))
                    yield "", chat_history
                    for response in self.backend.iter_run(
                        single_agent=single_agent, discussion_mode=discussion_mode
                    ):
                        print(response)
                        chat_history.append((None, response))
                        yield "", chat_history

                msg.submit(respond, [msg, chatbot], [msg, chatbot])
        else:
            with gr.Blocks() as demo:
                with gr.Row():
                    with gr.Column():
                        image_output = gr.Image()
                        with gr.Row():
                            reset_btn = gr.Button("Reset")
                            next_btn = gr.Button("Next", interactive=False)
                            stop_autoplay_btn = gr.Button(
                                "Stop Autoplay", interactive=False
                            )
                            start_autoplay_btn = gr.Button(
                                "Start Autoplay", interactive=False
                            )
                        with gr.Box(visible=False) as solutions:
                            with gr.Column():
                                gr.HTML("Optimization Solutions:")
                                with gr.Row():
                                    # Order must match self.solution_status.
                                    solution_btns = [
                                        gr.Button(label, visible=False)
                                        for label in (
                                            "Rewrite Slow Query",
                                            "Add Query Hints",
                                            "Update Indexes",
                                            "Tune Parameters",
                                            "Gather More Info",
                                        )
                                    ]
                    # reset() also initialises the backend before the first step.
                    text_output = gr.HTML(self.reset()[1])

                if self.task == "simulation/db_diag":
                    user_msg = gr.Textbox()
                    submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(
                        fn=self.submit,
                        inputs=user_msg,
                        outputs=[image_output, text_output],
                        show_progress=False,
                    )

                next_btn.click(
                    fn=self.delay_gen_output,
                    inputs=None,
                    outputs=[
                        image_output,
                        text_output,
                        next_btn,
                        start_autoplay_btn,
                        *solution_btns,
                        solutions,
                    ],
                    show_progress=False,
                )

                # TODO: add a "restart" control that loads different agents and
                # environment (and passes a student count to reset()).
                reset_btn.click(
                    fn=self.delay_reset,
                    inputs=None,
                    outputs=[
                        image_output,
                        text_output,
                        next_btn,
                        stop_autoplay_btn,
                        start_autoplay_btn,
                        *solution_btns,
                        solutions,
                    ],
                    show_progress=False,
                )

                stop_autoplay_btn.click(
                    fn=self.stop_autoplay,
                    inputs=None,
                    outputs=[next_btn, stop_autoplay_btn, start_autoplay_btn],
                    show_progress=False,
                )
                start_autoplay_btn.click(
                    fn=self.start_autoplay,
                    inputs=None,
                    outputs=[
                        image_output,
                        text_output,
                        next_btn,
                        stop_autoplay_btn,
                        start_autoplay_btn,
                        *solution_btns,
                        solutions,
                    ],
                    show_progress=False,
                )

        demo.queue(concurrency_count=5, max_size=20).launch(**self.ui_kwargs)
