import re
import os
import time
import json
import pprint
import logging
from typing import Dict, Any
from dataclasses import asdict

import gradio as gr
from gradio import ChatMessage
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from artemis.logger import setup_logger
from artemis.scheme import AgentState
from artemis import Environment, VLMWrapper, Agent


# Load variables from a .env file before anything reads the environment
# (run_agent falls back to VLM_BASE_URL / VLM_API_KEY via os.getenv).
load_dotenv()
setup_logger(name='mobile_use')
logger = logging.getLogger('mobile_use')

# Button instances returned by event handlers to switch the bound output
# buttons to a non-interactive / interactive state.
disable_btn = gr.Button(interactive=False)
enable_btn = gr.Button(interactive=True, visible=True)
# Parallel lists filled by add_params_component(): PARAMS_NAME[i] is the
# "prefix/name" key of the UI component stored in PARAMS_COMPONENT[i].
PARAMS_NAME = []
PARAMS_COMPONENT = []
# NOTE(review): not referenced anywhere in the visible part of this file.
VIEW_IMAGE_SIZE = (750, 750)
# Root directory under which every run gets its own history folder.
IMAGE_OUTPUT = 'logs/history'
os.makedirs(IMAGE_OUTPUT, exist_ok=True)


class Worker:
    """Drive one agent run and persist the screenshots it produces.

    Attributes:
        _agent: The agent created by :meth:`reset`; ``None`` until then.
        _stop: Cancellation flag set by :meth:`stop`, checked after each step.
        _images: Screenshot file names in the order they were first saved.
        _history_path: Directory of the current run; ``None`` until
            :meth:`reset` has been called.
    """

    # Radius, in pixels, of the dot drawn at each touch point.
    _MARKER_RADIUS = 20

    def __init__(self):
        self._agent: Agent = None
        self._stop = False
        self._images = []
        self._history_path = None

    def save_image(self, image: "Image.Image", filename: str):
        """Write *image* into the history directory and remember its name.

        A file name that was saved before keeps its original position in
        ``_images``; the file on disk is overwritten.

        Args:
            image: Object exposing ``save(path)`` (a PIL image in practice).
            filename: File name relative to the history directory.
        """
        if filename not in self._images:
            self._images.append(filename)
        image.save(os.path.join(self._history_path, filename))

    def reset(self, env: Dict[str, Any], vlm: Dict[str, Any], agent: Dict[str, Any], goal: str):
        """Create a fresh agent and a new, unique history directory.

        Args:
            env: Keyword arguments for ``Environment``.
            vlm: Keyword arguments for ``VLMWrapper``.
            agent: Extra parameters merged into ``Agent.from_params``.
            goal: Task description; its sanitized first 128 characters name
                the history directory.
        """
        logger.info("Reset Agent and Environment")
        env = Environment(**env)
        vlm = VLMWrapper(**vlm)
        self._images.clear()
        self._agent = Agent.from_params({'env': env, 'vlm': vlm, **agent})

        # Keep only word characters, CJK ideographs, whitespace and hyphens so
        # the goal can be used as a directory name.
        name = re.sub(r'[^\w\u4e00-\u9fff\s-]', '', goal[:128])
        base_path = os.path.join(IMAGE_OUTPUT, name)
        history_path = base_path
        suffix = 0
        # Append _1, _2, ... until the path does not exist yet.
        while os.path.exists(history_path):
            suffix += 1
            history_path = f'{base_path}_{suffix}'
        self._history_path = history_path
        os.makedirs(self._history_path)

    @classmethod
    def _annotate(cls, pixels, action, step_idx, font):
        """Return a copy of *pixels* with touch markers and a step label drawn on it."""
        image = pixels.copy()
        draw = ImageDraw.Draw(image)
        params = action.parameters
        if action.name == 'click':
            # QwenAgent reports 'coordinate'; other agents report 'point'.
            key = 'coordinate' if 'coordinate' in params else 'point'
            points = [params[key]]
        elif action.name == 'scroll':
            points = [params['start_point'], params['end_point']]
        elif action.name == 'swipe':       # QwenAgent
            points = [params['coordinate'], params['coordinate2']]
        else:
            points = []
        r = cls._MARKER_RADIUS
        for x, y in points:
            draw.ellipse([x-r, y-r, x+r, y+r], fill=(255, 0, 0), outline='black', width=2)
        draw.text((200, 10), f"Step {step_idx}", font=font, fill=(255, 0, 0))
        return image

    def run(self, input_content: str):
        """Iterate the agent and yield one UI update per produced screenshot.

        Args:
            input_content: Text forwarded to ``Agent.iter_run``.

        Yields:
            dict: ``text`` (thought/action summary or a status notice) and
            ``img_file`` (name of the latest saved screenshot, or ``None``).
        """
        self._stop = False
        img_file = None
        label_font = None  # loaded on first use, then reused for every step
        for step_data in self._agent.iter_run(input_content):
            if step_data is None:
                break

            if step_data.curr_env_state is not None:
                image = step_data.curr_env_state.pixels
                if step_data.action:
                    logger.info(f'step_data action: {step_data.action}')
                    if label_font is None:
                        label_font = ImageFont.load_default().font_variant(size=30)
                    image = self._annotate(image, step_data.action, step_data.step_idx, label_font)
                img_file = f'step_{step_data.step_idx}_0.png'
                self.save_image(image, img_file)

            text = ''
            if step_data.thought:
                text += f'Step {step_data.step_idx}\nThought: {step_data.thought}'

            if step_data.action:
                text += f'\nAction: {step_data.action}'
                a = step_data.action
                if a.name.upper() == 'FINISHED' and a.parameters.get('answer'):
                    text += f"\n\nTask Finished: {a.parameters.get('answer')}"

            yield dict(text=text, img_file=img_file)

            # Screenshot taken after the action was executed, when available.
            if step_data.exec_env_state is not None:
                img_file = f'step_{step_data.step_idx}_1.png'
                self.save_image(step_data.exec_env_state.pixels, img_file)
                yield dict(text=text, img_file=img_file)

            # Cancellation is honoured only between steps.
            if self._stop:
                text = "\n\n**The task has been canceled!**"
                yield dict(text=text.strip(), img_file=img_file)
                break
        if self._agent.curr_step_idx == self._agent.max_steps:
            text = f"\n\n**The task has stopped because the maximum number of steps({self._agent.max_steps}) has been reached**"
            yield dict(text=text.strip(), img_file=img_file)

    def stop(self):
        """Ask a running :meth:`run` loop to stop after the current step."""
        self._stop = True

class SessionWorkers:
    """Registry that maps a session id to its ``Worker``.

    Every entry is stored as ``(last_access_time, worker)``. The timestamp is
    refreshed on each :meth:`get_worker` call so that :meth:`_clear_` evicts
    sessions by inactivity rather than by age since creation.

    Args:
        active_duration: Seconds a session may stay unused before
            :meth:`_clear_` removes it.
    """

    def __init__(self, active_duration: int=300):
        self._workers_ = {}
        self.active_duration = active_duration

    def get_worker(self, session_id: str) -> "Worker":
        """Return the worker for *session_id*, creating it on first use."""
        entry = self._workers_.get(session_id)
        worker = entry[1] if entry is not None else Worker()
        # Re-store the entry so the timestamp reflects this access.
        self._workers_[session_id] = (time.time(), worker)
        return worker

    def _clear_(self):
        """Drop every session not accessed within ``active_duration`` seconds.

        Nothing in this class calls this method; the owner has to invoke it.
        """
        now = time.time()
        expired = [
            sid
            for sid, (last_access, _) in self._workers_.items()
            if now - last_access > self.active_duration
        ]
        for sid in expired:
            del self._workers_[sid]


# Module-level registry shared by all event handlers; workers are looked up
# by the Gradio request's session hash.
session_workers = SessionWorkers()


def get_button_state(i_run: bool=None, i_stop: bool=None, i_clear: bool=None, stop_value: str=None):
    """Build the list of button updates for the Run / Stop / Clear buttons.

    An update is emitted only for flags that are not ``None``, always in the
    order run, stop, clear.

    Args:
        i_run: Interactivity of the Run button; also selects its label.
        i_stop: Interactivity of the Stop button.
        i_clear: Interactivity of the Clear button.
        stop_value: Optional new label for the Stop button; only used when
            ``i_stop`` is given.

    Returns:
        list: ``gr.update`` objects, possibly empty.
    """
    updates = []

    if i_run is not None:
        run_label = '▶️ Run' if i_run else '▶️ Running'
        updates.append(gr.update(value=run_label, interactive=i_run))

    if i_stop is not None:
        stop_kwargs = {'interactive': i_stop}
        if stop_value is not None:
            stop_kwargs['value'] = stop_value
        updates.append(gr.update(**stop_kwargs))

    if i_clear is not None:
        updates.append(gr.update(interactive=i_clear))

    return updates


def _messages_to_dicts(messages):
    """Convert chat messages to plain dicts, logging entries that cannot be converted."""
    converted = []
    for msg in messages:
        if not isinstance(msg, dict):
            try:
                msg = asdict(msg)
            except Exception:
                # Best effort: anything that is not a dataclass is reported below.
                pass
        if isinstance(msg, dict):
            converted.append(msg)
        else:
            logger.error(f"Error message format: {type(msg)} {msg}")
    return converted


def run_agent(request: gr.Request, input_content, messages, image, *args):
    """Gradio generator handler behind the Run button.

    Args:
        request: Gradio request; its ``session_hash`` selects the worker.
        input_content: Instruction typed by the user.
        messages: Current chat history; mutated in place.
        image: Path of the screenshot currently displayed.
        *args: Values of the registered parameter components, in the same
            order as ``PARAMS_NAME``.

    Yields:
        list: ``[messages, image_path]`` followed by updates for the Run,
        Stop and Clear buttons.
    """
    session_id = request.session_hash
    logger.info(f"instruction: {input_content}")
    messages.append(ChatMessage(role='user', content=input_content))
    yield [messages, image] + get_button_state(False, True, False)

    # Regroup the flat component values into {'env': {...}, 'agent': {...}, 'vlm': {...}}.
    params = {}
    for full_name, value in zip(PARAMS_NAME, args):
        prefix, name = full_name.split('/')
        params.setdefault(prefix, {})[name] = value

    # Print a copy with the API key masked so the credential never reaches stdout.
    printable = {prefix: dict(group) for prefix, group in params.items()}
    if printable.get('vlm', {}).get('api_key'):
        printable['vlm']['api_key'] = '***'
    print('============== params ==============')
    pprint.pprint(printable)

    # Try to get the base_url and api_key from the env if it is not available
    if not params['vlm']['base_url']:
        params['vlm']['base_url'] = os.getenv('VLM_BASE_URL', None)
        params['vlm']['api_key'] = os.getenv('VLM_API_KEY', None)
    if not params['vlm']['base_url']:
        messages.append(ChatMessage(role="assistant", content='Missing vlm base url'))
        yield [messages, image] + get_button_state(True, False, True)
        return

    worker = session_workers.get_worker(session_id)
    # Keep the existing agent only when it is waiting for user input
    # (CALLUSER); otherwise start from a fresh agent and history directory.
    if worker._agent is None or worker._agent.state != AgentState.CALLUSER:
        try:
            worker.reset(goal=input_content, **params)
        except Exception as e:
            logger.exception("Agent initialization failed")
            messages.append(ChatMessage(role="assistant", content=f'The agent initialization fails: {e}'))
            yield [messages, image] + get_button_state(True, False, True)
            return

    show_image = image
    try:
        step_idx = -1
        for msg in worker.run(input_content):
            img_file = msg.get('img_file')
            if img_file:
                show_image = os.path.join(worker._history_path, img_file)
            # One chat bubble per agent step: later updates of the same step
            # overwrite the bubble instead of appending a new one.
            if step_idx != worker._agent.curr_step_idx:
                step_idx = worker._agent.curr_step_idx
                messages.append(ChatMessage(role="assistant", content=msg["text"]))
            else:
                messages[-1].content = msg["text"]
            yield [messages, show_image] + get_button_state(False, True, False)
        yield [messages, show_image] + get_button_state(True, False, True, stop_value='⏹️ Stop')
    except Exception:
        logger.exception("Agent run failed")
        gr.Info('系统异常')  # user-facing toast, Chinese for "system error"
        # Without this update the buttons would stay in their "running" state.
        yield [messages, show_image] + get_button_state(True, False, True, stop_value='⏹️ Stop')

    # save the history
    messages_dict = _messages_to_dicts(messages)
    with open(os.path.join(worker._history_path, 'messages.json'), 'w', encoding='utf-8') as writer:
        json.dump(messages_dict, writer, ensure_ascii=False, indent=4)


def clear_history(request: gr.Request):
    """Return reset values for the chat, textbox, image and two buttons.

    Args:
        request: Gradio request (unused).

    Returns:
        tuple: empty chat history, empty text, no image, followed by
        ``disable_btn`` twice.
    """
    cleared_chat, cleared_text, cleared_image = [], "", None
    return (cleared_chat, cleared_text, cleared_image, disable_btn, disable_btn)


def get_previous_image(request: gr.Request, curr_image_path):
    """Return the path of the screenshot saved before the one displayed.

    Args:
        request: Gradio request; its ``session_hash`` selects the worker.
        curr_image_path: Path of the displayed image, or ``None`` when the
            viewer is empty.

    Returns:
        The previous screenshot's path inside the worker's history
        directory, or ``curr_image_path`` unchanged when nothing is shown,
        the file is unknown to this session, or it is already the first one.
    """
    worker = session_workers.get_worker(request.session_hash)
    logger.info(f'curr_image_path: {curr_image_path}')
    if not curr_image_path:
        # Nothing displayed yet: there is no position to step back from.
        return curr_image_path

    curr_file = os.path.basename(curr_image_path)
    try:
        curr_index = worker._images.index(curr_file)
    except ValueError:
        # The displayed file was not saved by this session's worker.
        return curr_image_path
    if curr_index == 0:
        return curr_image_path
    return os.path.join(worker._history_path, worker._images[curr_index - 1])


def get_next_image(request: gr.Request, curr_image_path):
    """Return the path of the screenshot saved after the one displayed.

    Args:
        request: Gradio request; its ``session_hash`` selects the worker.
        curr_image_path: Path of the displayed image, or ``None`` when the
            viewer is empty.

    Returns:
        The next screenshot's path inside the worker's history directory,
        or ``curr_image_path`` unchanged when nothing is shown, the file is
        unknown to this session, or it is already the last one.
    """
    worker = session_workers.get_worker(request.session_hash)
    logger.info(f'curr_image_path: {curr_image_path}')
    if not curr_image_path:
        # Nothing displayed yet: there is no position to step forward from.
        return curr_image_path

    curr_file = os.path.basename(curr_image_path)
    try:
        curr_index = worker._images.index(curr_file)
    except ValueError:
        # The displayed file was not saved by this session's worker.
        return curr_image_path
    if curr_index + 1 >= len(worker._images):
        return curr_image_path
    return os.path.join(worker._history_path, worker._images[curr_index + 1])


def stop_worker(request: gr.Request):
    """Flag the session's worker to stop and grey out the Stop button.

    Args:
        request: Gradio request; its ``session_hash`` selects the worker.

    Returns:
        A ``gr.update`` that relabels the Stop button and disables it.
    """
    worker = session_workers.get_worker(request.session_hash)
    worker.stop()
    return gr.update(value="⏹️ Stopping...", interactive=False)


def add_text(instruction, messages, request: gr.Request):
    """Append the user's instruction to the chat and clear the textbox.

    Args:
        instruction: Text entered by the user.
        messages: Chat history; mutated in place.
        request: Gradio request (unused).

    Returns:
        tuple: updated messages, an empty string, then ``disable_btn``,
        ``enable_btn`` and ``disable_btn``.
    """
    logger.info(f"instruction: {instruction}")
    user_message = ChatMessage(role='user', content=instruction)
    messages.append(user_message)
    return (messages, "", disable_btn, enable_btn, disable_btn)


def add_params_component(prefix, name, component):
    """Register a UI component under the key ``"<prefix>/<name>"``.

    The key and the component are appended to ``PARAMS_NAME`` and
    ``PARAMS_COMPONENT`` at the same index.
    """
    key = '/'.join((prefix, name))
    PARAMS_NAME.append(key)
    PARAMS_COMPONENT.append(component)


def build_agent_ui_demo():
    """Assemble the Gradio Blocks UI and wire its event listeners.

    Every settings component is registered through add_params_component();
    the registration order defines the positional order in which the values
    are passed to run_agent after the textbox, chatbot and image inputs.

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks(title="Mobile Use WebUI", theme=gr.themes.Default()) as demo:
        with gr.Row():
            gr.Markdown(
                """
                # 📱 Mobile Use WebUI
                ### Control your mobile with AI assistance
                """,
                elem_classes=["header-text"],
            )
        with gr.Group():
            with gr.Row():
                # Wider column: settings accordions followed by the chat area.
                with gr.Column(scale=2):
                    # Values registered under the 'env' prefix.
                    with gr.Accordion("📱 Mobile Settings", open=False):
                        with gr.Group():
                            host = gr.Textbox(
                                label="Android ADB Server Host",
                                placeholder='127.0.0.1',
                                info="Android ADB server host, support remote device.",
                            )
                            port = gr.Number(
                                label="Android ADB Server Port",
                                value=5037,
                                info="Android ADB server port",
                            )
                            serial_no = gr.Textbox(
                                label="Device Serial No.",
                                placeholder='a22d0110',
                                info="Serial No. for connected device",
                            )
                            reset_to_home = gr.Checkbox(
                                label="Reset to HOME",
                                value=True,
                                interactive=True,
                                info="Reset the device to HOME screen",
                            )
                            add_params_component('env', 'host', host)
                            add_params_component('env', 'port', port)
                            add_params_component('env', 'serial_no', serial_no)
                            add_params_component('env', 'go_home', reset_to_home)
                    # Values registered under the 'agent' prefix.
                    with gr.Accordion("⚙️ Agent Settings", open=False):
                        with gr.Group():
                            with gr.Column():
                                agent_type = gr.Dropdown(
                                    label="Agent Name",
                                    choices=['SingleAgent', 'MultiAgent'],
                                    value='SingleAgent',
                                    interactive=True,
                                    info="Select a agent framework"
                                )
                                max_steps = gr.Slider(
                                    minimum=1,
                                    maximum=50,
                                    value=30,
                                    step=1,
                                    interactive=True,
                                    label="Max Run Steps",
                                    info="Maximum number of steps the agent will take",
                                )
                                num_latest_screenshot = gr.Slider(
                                    minimum=1,
                                    maximum=10,
                                    value=2,
                                    step=1,
                                    interactive=True,
                                    label="Maximum Latest Screenshot",
                                    info="Maximum latest screenshot for per vllm request",
                                )
                                max_reflection_action = gr.Slider(
                                    minimum=1,
                                    maximum=5,
                                    value=1,
                                    step=1,
                                    interactive=True,
                                    label="Maximum Reflection Action",
                                    info="Maximum reflection action for per request",
                                )
                                add_params_component('agent', 'type', agent_type)
                                add_params_component('agent', 'max_steps', max_steps)
                                add_params_component('agent', 'num_latest_screenshot', num_latest_screenshot)
                                add_params_component('agent', 'max_reflection_action', max_reflection_action)
                    # Values registered under the 'vlm' prefix.
                    with gr.Accordion("🔧 VLM Configuration", open=False):
                        with gr.Group():
                            # No default value is set here; run_agent falls back
                            # to the VLM_BASE_URL environment variable when the
                            # submitted value is empty.
                            vlm_base_url = gr.Dropdown(
                                label="Base URL",
                                choices=['http://10.66.167.11:8083/v1'],
                                # placeholder='http://10.66.167.11:8083/v1',
                                interactive=True,
                                info="API endpoint URL"
                            )
                            add_params_component('vlm', 'base_url', vlm_base_url)
                            vlm_api_key = gr.Textbox(
                                label="API Key",
                                type="password",
                                value='EMPTY',
                                interactive=True,
                                info="Your API key"
                            )
                            add_params_component('vlm', 'api_key', vlm_api_key)
                            vlm_model_name = gr.Dropdown(
                                label="Model Name",
                                choices=['qwen2.5-vl-7b-instruct', 'qwen2.5-vl-72b-instruct', 'Qwen2.5-VL-72B-Instruct'],
                                value='Qwen2.5-VL-72B-Instruct',
                                interactive=True,
                                allow_custom_value=True,  # Allow users to input custom model names
                                info="Select a model from the dropdown or type a custom model name"
                            )
                            add_params_component('vlm', 'model_name', vlm_model_name)
                            vlm_max_retry = gr.Slider(
                                minimum=1,
                                maximum=5,
                                value=1,
                                step=1,
                                interactive=True,
                                label="Max Retry per Request",
                                info="Maximum number of request to VLM",
                            )
                            add_params_component('vlm', 'max_retry', vlm_max_retry)
                            vlm_temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.1,
                                interactive=True,
                                label="Temperature",
                                info="Controls randomness in model outputs"
                            )
                            add_params_component('vlm', 'temperature', vlm_temperature)
                    # Chat history, instruction box and the Run / Stop / Clear buttons.
                    with gr.Column():
                        chatbot = gr.Chatbot(
                            elem_id="chatbot",
                            type="messages",
                            label="ToolAgent",
                            show_label=False,
                            height=550,
                        )
                        textbox = gr.Textbox(
                            lines=1,
                            label="指令",
                            show_label=False,
                            placeholder="👉 Please enter your task description",
                        )
                        with gr.Row():
                            run_button = gr.Button("▶️ Run", variant="primary")
                            stop_button = gr.Button("⏹️ Stop", variant="stop", interactive=False)
                            clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

                # Narrower column: screenshot viewer with Previous / Next navigation.
                with gr.Column(scale=1):
                    image_view = gr.Image(type="filepath", label="Screenshot", interactive=False, height=732)
                    with gr.Row():
                        previous_image = gr.Button(value="Previous", interactive=True)
                        next_image = gr.Button(value="Next", interactive=True)

        # register listeners
        # Order matches the button updates produced by get_button_state (run, stop, clear).
        btn_list = [run_button, stop_button, clear_btn]

        run_button.click(
            fn=run_agent,
            inputs=[textbox, chatbot, image_view] + PARAMS_COMPONENT,
            outputs=[chatbot, image_view] + btn_list,
        )
        stop_button.click(
            fn=stop_worker, inputs=None, outputs=stop_button
        )
        # clear_history returns five values, one per output listed here.
        clear_btn.click(
            fn=clear_history, inputs=None, outputs=[chatbot, textbox, image_view, stop_button, clear_btn]
        )
        previous_image.click(
            fn=get_previous_image,
            inputs=[image_view],
            outputs=[image_view]
        )
        next_image.click(
            fn=get_next_image,
            inputs=[image_view],
            outputs=[image_view]
        )
    return demo


# Script entry point: build the UI and start the Gradio server with the
# default launch options.
if __name__ == "__main__":
    demo = build_agent_ui_demo()
    demo.launch()
