import os
import asyncio
import base64
import re
import time
from pathlib import Path
from typing import Literal, Tuple, Optional
import utils
import nbformat
from nbclient import NotebookClient
from nbclient.exceptions import CellExecutionComplete, CellTimeoutError, DeadKernelError
from nbclient.util import ensure_async
from nbformat import NotebookNode
from nbformat.v4 import new_code_cell, new_markdown_cell, new_output, output_from_msg
from rich.box import MINIMAL
from rich.console import Console, Group
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.syntax import Syntax
from tool.report import NotebookReporter
from jupyter_client.utils import run_sync
import logger

# Absolute path of this module; the directory holding it is expected to be tool/.
_module_path = os.path.abspath(__file__)
current_dir = os.path.dirname(_module_path)

# One level above current_dir: project_root/
project_root = os.path.dirname(current_dir)


class RealtimeOutputNotebookClient(NotebookClient):
    """NotebookClient that additionally holds a notebook reporter.

    The reporter is stored on ``self.notebook_reporter``; this class does not
    itself call it.
    """

    def __init__(self, *args, notebook_reporter=None, **kwargs) -> None:
        """Forward all arguments to ``NotebookClient`` and attach the reporter.

        Args:
            notebook_reporter: Reporter to attach. A new ``NotebookReporter``
                is created when this is omitted or falsy.
        """
        super().__init__(*args, **kwargs)
        if not notebook_reporter:
            notebook_reporter = NotebookReporter()
        self.notebook_reporter = notebook_reporter



class ExecuteNbCode():
    """Execute notebook code blocks, return the result to the LLM, and display it.

    Python code is appended to ``self.nb`` as a code cell and executed through
    ``self.nb_client``; markdown is appended as a markdown cell without being
    executed. After every ``run`` the notebook is written to ``code.ipynb``
    under the configured workspace path.
    """

    nb: NotebookNode
    nb_client: Optional[RealtimeOutputNotebookClient] = None
    console: Console
    interaction: str
    timeout: int = 7200

    def __init__(self, nb=None, timeout=7200):
        """Create an executor.

        Args:
            nb: Notebook that cells are appended to. When omitted, a new empty
                notebook is created for this instance, so separate executors
                never share cells.
            timeout: Value passed to the notebook client as its ``timeout``.
        """
        self.console = Console()
        self.timeout = timeout
        self.interaction = "ipython" if self.is_ipython() else "terminal"
        # Build the default notebook per instance: a default argument would be
        # evaluated once and shared by every executor.
        self.nb = nbformat.v4.new_notebook() if nb is None else nb
        self.reporter = NotebookReporter()
        self.set_nb_client()
        self.init_called = False

    def set_nb_client(self):
        """(Re)create ``self.nb_client`` with the workspace directory as its working path."""
        # Use workspace path from config, default to "workspace" if not set or empty.
        workspace_path_str = utils.module_config.workspace_path or "workspace"

        # Parent of the directory that contains this file.
        root_dir = Path(__file__).parent.parent

        # An absolute workspace path is kept as-is; a relative one is joined onto root_dir.
        abs_workspace_path = (root_dir / workspace_path_str).resolve()

        os.makedirs(abs_workspace_path, exist_ok=True)
        self.nb_client = RealtimeOutputNotebookClient(
            self.nb,
            timeout=self.timeout,
            resources={"metadata": {"path": str(abs_workspace_path)}},
            notebook_reporter=self.reporter,
            coalesce_streams=True,
        )

    def parse_outputs(self, outputs: list[dict], keep_len: int = 5000) -> Tuple[bool, str]:
        """Parse the outputs received from notebook execution.

        Args:
            outputs: Output entries of an executed cell; each one is indexed by
                ``"output_type"`` and, depending on the type, ``"text"``,
                ``"data"`` or ``"traceback"``.
            keep_len: Maximum number of characters kept per output. Successful
                output keeps its beginning, failed output keeps its end.

        Returns:
            ``(is_success, text)`` where ``text`` is the per-output texts joined
            with ``","``. ``is_success`` becomes False once an error output or
            an un-awaited coroutine repr is seen and stays False afterwards.
        """
        assert isinstance(outputs, list)
        parsed_output, is_success = [], True
        for i, output in enumerate(outputs):
            output_text = ""
            output_type = output["output_type"]
            if output_type == "stream":
                output_text = output["text"]
            elif output_type == "display_data":
                if "image/png" in output["data"]:
                    logger.info("生成了图表，正在展示图表...")
                    self.show_bytes_figure(output["data"]["image/png"], self.interaction)
                    # Call LLM to explain the image
                    try:
                        logger.info("正在生成图表解释...")
                        from llm import LLM
                        llm = LLM("gemini-3-flash-preview")
                        explanation = self.explain_image(llm, output["data"]["image/png"])
                        output_text = f"[生成图表解释]: {explanation}\n"
                        logger.info(output_text)
                    except Exception as e:
                        logger.error(f"Failed to explain image: {e}")
                        output_text = "[生成图表解释]: 失败生成解释。\n"
                else:
                    logger.info(
                        f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..."
                    )
            elif output_type == "execute_result":
                # A result without a plain-text representation contributes no text
                # instead of raising KeyError.
                output_text = output["data"].get("text/plain", "")
            elif output_type == "error":
                output_text, is_success = "\n".join(output["traceback"]), False

            # handle coroutines that are not executed asynchronously
            if output_text.strip().startswith("<coroutine object"):
                output_text = "Executed code failed, you need use key word 'await' to run a async code."
                is_success = False

            output_text = remove_escape_and_color_codes(output_text)
            if is_success:
                output_text = remove_log_and_warning_lines(output_text)
            # The useful information of the exception is at the end,
            # the useful information of normal output is at the beginning.
            if "<!DOCTYPE html>" not in output_text:
                output_text = output_text[:keep_len] if is_success else output_text[-keep_len:]

            parsed_output.append(output_text)
        return is_success, ",".join(parsed_output)

    def explain_image(self, llm, image_base64: str) -> str:
        """Ask ``llm`` to describe a base64-encoded PNG and return the reply text.

        The image is sent as a ``data:image/png;base64,...`` URL together with
        a fixed text prompt. The reply is read from
        ``response.choices[0].message.content``.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "请简要解释这张图片的内容，重点关注图表中的趋势、异常值或关键信息。"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"
                        }
                    }
                ]
            }
        ]
        response = llm.ask(messages)
        return response.choices[0].message.content

    def show_bytes_figure(self, image_base64: str, interaction_type: Optional[str]):
        """Display a base64-encoded PNG, or save it to disk outside IPython.

        In ``"ipython"`` mode the image is displayed inline. Otherwise it is
        written to ``output/results/<RUN_ID>`` when the ``RUN_ID`` environment
        variable is set, or to ``<workspace>/results_tmp`` when it is not.
        With a run id the server is then notified (best effort), and finally a
        local preview is attempted (best effort).
        """
        image_bytes = base64.b64decode(image_base64)
        if interaction_type == "ipython":
            from IPython.display import Image, display

            display(Image(data=image_bytes))
            return

        # Save image bytes to workspace/output so the backend can serve it.
        run_id = os.environ.get("RUN_ID")
        if run_id:
            # prefer writing into output/results/<run_id> if run_id present
            dest_dir = Path("output") / "results" / run_id
        else:
            # Fall back to the module workspace, with the same "workspace"
            # default the rest of this class uses when the path is unset.
            dest_dir = Path(utils.module_config.workspace_path or "workspace") / "results_tmp"

        dest_dir.mkdir(parents=True, exist_ok=True)
        # Millisecond timestamp keeps file names distinct between figures.
        fname = f"figure_{int(time.time()*1000)}.png"
        dest_path = dest_dir / fname
        with open(dest_path, "wb") as fp:
            fp.write(image_bytes)

        # Try to notify the server so it can broadcast to websocket clients.
        if run_id:
            self._notify_file_saved(run_id, fname)

        # Open image locally as fallback for developer convenience. This runs
        # after the file is saved and must never fail the cell: PIL may be
        # missing or no image viewer may be available.
        try:
            import io
            from PIL import Image

            Image.open(io.BytesIO(image_bytes)).show()
        except Exception as e:
            logger.info(f"Failed to open figure preview: {e}")

    def _notify_file_saved(self, run_id: str, fname: str) -> None:
        """Best-effort POST telling the server that ``fname`` was saved for ``run_id``.

        Tries ``requests`` first and falls back to ``urllib`` when that path
        fails for any reason. Failures are logged and never raised.
        """
        server_url = os.environ.get("AUTO_DS_SERVER", "http://127.0.0.1:8001")
        notify_url = f"{server_url}/api/runs/{run_id}/notify_file"
        payload = {"filename": fname}

        try:
            # prefer requests if available
            import requests

            requests.post(notify_url, json=payload, timeout=2)
            return
        except Exception as e:
            logger.debug(f"requests notify failed, falling back to urllib: {e}")

        try:
            import json
            from urllib import request as _request

            req = _request.Request(
                notify_url,
                data=json.dumps(payload).encode("utf-8"),
                headers={"Content-Type": "application/json"},
            )
            # Close the response so the connection is not leaked.
            with _request.urlopen(req, timeout=2):
                pass
        except Exception as e:
            logger.info(f"Failed to notify server: {e}")

    def is_ipython(self) -> bool:
        """Return True when an IPython shell with ``IPKernelApp`` in its config is active."""
        try:
            from IPython import get_ipython
        except ImportError:
            # IPython is not installed in this interpreter.
            return False
        try:
            shell = get_ipython()
            return shell is not None and "IPKernelApp" in shell.config
        except NameError:
            return False

    def _display(self, code: str, language: Literal["python", "markdown"] = "python"):
        """Print ``code`` to the console as highlighted Python or rendered markdown.

        Raises:
            ValueError: If ``language`` is neither ``"python"`` nor ``"markdown"``.
        """
        if language == "python":
            highlighted = Syntax(code, "python", theme="paraiso-dark", line_numbers=True)
            self.console.print(highlighted)
        elif language == "markdown":
            display_markdown(code)
        else:
            raise ValueError(f"Only support for python, markdown, but got {language}")

    def add_code_cell(self, code: str):
        """Append ``code`` to the notebook as a new code cell."""
        self.nb.cells.append(new_code_cell(source=code))

    def add_markdown_cell(self, markdown: str):
        """Append ``markdown`` to the notebook as a new markdown cell."""
        self.nb.cells.append(new_markdown_cell(source=markdown))

    @staticmethod
    def _call_blocking(func, *args, **kwargs):
        """Call ``func`` and wait for its result, whether it is sync or a coroutine function."""
        import inspect

        if inspect.iscoroutinefunction(func):
            return run_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    def build(self):
        """Start a kernel and kernel client unless a live one already exists."""
        kc = self.nb_client.kc
        if kc is None or not self._call_blocking(kc.is_alive):
            self.nb_client.create_kernel_manager()
            self.nb_client.start_new_kernel()
            self.nb_client.start_new_kernel_client()

    def reset(self):
        """Release the current kernel (best effort) and install a fresh notebook client.

        The new client has no kernel yet, so the next ``build`` starts one.
        Cleanup failures are logged and ignored because the kernel may already
        be dead when this is called.
        """
        stale_km = self.nb_client.km
        stale_kc = self.nb_client.kc
        if stale_km is not None:
            try:
                self._call_blocking(stale_km.shutdown_kernel, now=True)
            except Exception as e:
                logger.info(f"Failed to shut down kernel during reset: {e}")
        if stale_kc is not None:
            try:
                self._call_blocking(stale_kc.stop_channels)
            except Exception as e:
                logger.info(f"Failed to stop kernel channels during reset: {e}")
        self.set_nb_client()

    def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
        """Execute one cell and parse its outputs.

        Returns:
            ``(success, text)``. On timeout the kernel is interrupted and a
            fixed timeout message is returned; when the kernel has died the
            client is reset and ``"DeadKernelError"`` is returned. For any
            other execution error the outputs collected on the cell are
            parsed and returned.
        """
        self.reporter.report(cell, "content")

        try:
            self.nb_client.execute_cell(cell, cell_index)
        except CellTimeoutError:
            assert self.nb_client.km is not None
            try:
                # interrupt_kernel may be a coroutine function; calling it
                # without waiting would leave the timed-out cell running.
                self._call_blocking(self.nb_client.km.interrupt_kernel)
            except Exception as e:
                logger.error(f"Failed to interrupt kernel after timeout: {e}")
            time.sleep(1)
            error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance."
            return False, error_msg
        except DeadKernelError:
            self.reset()
            return False, "DeadKernelError"
        except Exception as e:
            # Fall through and report whatever outputs the cell collected.
            logger.debug(f"Cell execution raised: {e}")
        # Parse the executed cell itself, and only once.
        return self.parse_outputs(cell.outputs)

    def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]:
        """Add ``code`` to the notebook, execute it when it is Python, and save the notebook.

        Args:
            code: Python source or markdown text.
            language: ``"python"`` to execute, ``"markdown"`` to only record.

        Returns:
            ``(outputs, success)``. For markdown, ``outputs`` is ``code`` itself
            and ``success`` is True. For Python containing ``"!pip"`` the
            result is always reported as unsuccessful with only the last 500
            characters kept; for Python containing ``"git clone"`` only the
            first and last 500 characters are kept.

        Raises:
            ValueError: If ``language`` is not supported.
        """
        logger.special("【生成代码】\n {}".format(code))

        with self.reporter:
            if language == "python":
                # add code to the notebook
                self.add_code_cell(code=code)

                # build code executor
                self.build()

                # run code
                cell_index = len(self.nb.cells) - 1
                logger.debug(f"【执行代码】 cell_index: {cell_index}")
                success, outputs = self.run_cell(self.nb.cells[-1], cell_index)

                if "!pip" in code:
                    success = False
                    outputs = outputs[-500:]
                elif "git clone" in code:
                    outputs = outputs[:500] + "..." + outputs[-500:]

            elif language == "markdown":
                # add markdown content to markdown cell in a notebook.
                self.add_markdown_cell(code)
                # return True, because there is no execution failure for markdown cell.
                outputs, success = code, True
            else:
                raise ValueError(f"Only support for language: python, markdown, but got {language}, ")

            file_path = Path(utils.module_config.workspace_path or "workspace") / "code.ipynb"
            # Ensure the directory exists
            file_path.parent.mkdir(parents=True, exist_ok=True)

            nbformat.write(self.nb, str(file_path))
            self.reporter.report(str(file_path), "path")

            return outputs, success
        
def display_markdown(content: str):
    """Render markdown in the terminal, one panel per prose or fenced-code segment.

    The text is split on triple-backtick fences. Prose between fences and the
    body of each fence are wrapped in separate panels, and all panels are
    shown together in a single Live update.
    """
    # Text background color and text color shared by every panel.
    panel_style = "black on white"

    def as_panel(markdown_text: str) -> Panel:
        return Panel(Markdown(markdown_text), style=panel_style, box=MINIMAL)

    panels = []
    cursor = 0
    for fence in re.finditer(r"```(.+?)```", content, re.DOTALL):
        prose = content[cursor : fence.start()].strip()
        # Group 1 is the fence body without the surrounding backticks.
        fence_body = fence.group(1)

        if prose:
            panels.append(as_panel(prose))
        if fence_body:
            panels.append(as_panel(f"```{fence_body}"))

        cursor = fence.end()

    # Whatever follows the last fence (or the whole text when there is none).
    trailing = content[cursor:].strip()
    if trailing:
        panels.append(as_panel(trailing))

    with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live:
        live.update(Group(*panels))
        live.refresh()

def remove_log_and_warning_lines(input_str: str) -> str:
    """Drop every line containing a log/warning marker and strip the result.

    Markers are matched case-insensitively anywhere in the line:
    ``[warning]``, ``warning:``, ``[cv]`` and ``[info]``.
    """
    noise_markers = ("[warning]", "warning:", "[cv]", "[info]")
    kept_lines = []
    for line in input_str.split("\n"):
        lowered = line.lower()
        if any(marker in lowered for marker in noise_markers):
            continue
        kept_lines.append(line)
    return "\n".join(kept_lines).strip()


def remove_escape_and_color_codes(input_str: str):
    """Return ``input_str`` without the escape and color codes found in Jupyter output.

    Removes sequences of the form ``ESC [ <digits/semicolons> m`` and
    ``ESC [ <digits/semicolons> K``; everything else is left untouched.
    """
    return re.sub(r"\x1b\[[0-9;]*[mK]", "", input_str)
