
from pathlib import Path
from typing import Annotated, Dict, Any, Set, Optional, List, Tuple

from fastmcp import FastMCP
from pydantic import Field
from leanclient import LeanLSPClient

from ape.toolkits.base import BaseToolsProvider
from ..base_provider import LanguageProviderInterface
from .utils import (
    find_project_path,
    setup_lean_client,
    extract_goals_list,
    format_diagnostics,
    extract_range,
    filter_diagnostics_by_position,
    find_position_by_content,
    get_line_context,
)
from .lean_parser import format_lean_code


class LeanCodeToolsProvider(BaseToolsProvider, LanguageProviderInterface):
    """
    """

    SUPPORTED_TOOLS = [
        "get_lean_goal",
    ]

    def __init__(
        self,
        task=None,
        config=None,
        logger=None,
        confirmation_bridge=None,
        is_cli_mode=False,
    ):
        """Set up the provider and its (initially empty) Lean client state.

        All arguments are forwarded unchanged to the base class initialiser.

        Args:
            task: Optional task object; when truthy, its ``workspaces_dir``
                becomes the workspace root.
            config: Forwarded to the base class.
            logger: Forwarded to the base class.
            confirmation_bridge: Forwarded to the base class.
            is_cli_mode: Forwarded to the base class.
        """
        super().__init__(
            task=task,
            config=config,
            logger=logger,
            confirmation_bridge=confirmation_bridge,
            is_cli_mode=is_cli_mode,
        )

        # Neither the LSP client nor the project path is known at construction.
        self._client: Optional[LeanLSPClient] = None
        self._project_path: Optional[Path] = None

        # Without a task, fall back to the current working directory.
        if task:
            self._workspace_root = task.workspaces_dir
        else:
            self._workspace_root = Path.cwd()

    def _ensure_client(self, file_path: Path) -> LeanLSPClient:
        """Return ``file_path`` relative to the project path, as a string.

        NOTE(review): the method name and the ``LeanLSPClient`` return
        annotation suggest this should hand back an LSP client, but the body
        only computes a relative path string and never touches
        ``self._client``. Callers that use the result as a client will fail
        on a ``str``. This looks like the body of a relative-path helper
        placed under the wrong signature -- confirm whether the client
        set-up logic was lost.

        Args:
            file_path: Path expected to be located under ``self._project_path``.

        Raises:
            ValueError: If ``self._project_path`` is unset, or if ``file_path``
                is not located under it.
        """
        if not self._project_path:
            raise ValueError("Project path not set")

        try:
            return str(file_path.relative_to(self._project_path))
        except ValueError:
            # Re-raise with a message that names both paths.
            raise ValueError(f"File {file_path} not in project {self._project_path}")

    
    @staticmethod
    def display_content(
        content: str,
        add_line_numbers: bool = True,
        display_mode: str = "full",
        line_spans: Optional[List[Tuple[int, int]]] = None,
        context_lines: int = 3,
        range_separator: str = "\n",
        omit_details: bool = False,
        body_handling: Optional[str] = None,
        body_placeholder: str = "/- ... omitted ... -/"
    ) -> str:
        """Render Lean source for display by delegating to ``format_lean_code``.

        The generic "body" vocabulary of this interface is translated to the
        "proof" vocabulary used by the Lean formatter: ``body_handling``
        becomes ``proof_handling`` and ``body_placeholder`` becomes
        ``proof_placeholder``.

        Args:
            content: Lean source text.
            add_line_numbers: Whether line numbers are requested.
            display_mode: Display mode ("full" or "line_spans").
            line_spans: ``(start, end)`` line ranges to display.
            context_lines: Number of context lines around each span.
            range_separator: Separator placed between ranges.
            omit_details: Deprecated switch; only consulted when
                ``body_handling`` is ``None``.
            body_handling: Explicit mode ("keep_all", "omit_all",
                "omit_outside_spans"). Takes precedence over ``omit_details``.
            body_placeholder: Text shown in place of an omitted body.

        Returns:
            Whatever ``format_lean_code`` returns for these settings.
        """
        # An explicit mode wins; otherwise derive it from the legacy flag.
        if body_handling is not None:
            proof_mode = body_handling
        elif omit_details:
            proof_mode = "omit_all"
        else:
            proof_mode = "keep_all"

        return format_lean_code(
            src=content,
            display_mode=display_mode,
            line_spans=line_spans,
            context_lines=context_lines,
            range_separator=range_separator,
            add_line_numbers=add_line_numbers,
            proof_handling=proof_mode,
            proof_placeholder=body_placeholder,
        )
    
    @staticmethod
    def get_omission_marker(text: str = None) -> str:
        """Get omission marker for Lean."""
        default_text = text or "... omitted lines ..."
        return f"/- {default_text} -/"
    
    @staticmethod
    def remove_comments(content: str) -> str:
        """Remove Lean comments from source code."""
        i = 0
        n = len(content)
        out = []
        block_nest = 0
        line_has_code = False

        while i < n:
            if block_nest == 0:
                if content.startswith("--", i):
                    j = content.find("\n", i + 2)
                    if not line_has_code:
                        while out and out[-1] in (" ", "\t"):
                            out.pop()
                        if j == -1:
                            i = n
                        else:
                            i = j + 1
                        line_has_code = False
                    else:
                        if j == -1:
                            i = n
                        else:
                            i = j + 1
                            out.append("\n")
                        line_has_code = False
                elif content.startswith("/-", i):
                    block_nest = 1
                    i += 2
                else:
                    ch = content[i]
                    out.append(ch)
                    if ch == "\n":
                        line_has_code = False
                    elif not ch.isspace():
                        line_has_code = True
                    i += 1
            else:
                if content.startswith("/-", i):
                    block_nest += 1
                    i += 2
                elif content.startswith("-/", i):
                    block_nest -= 1
                    i += 2
                else:
                    i += 1
        return "".join(out)


    async def hover(
        self,
        file_path: Path,
        line: int,
        content_snippet: str
    ) -> Dict[str, Any]:
        """Look up the declarations of the symbol located by a snippet on a line.

        The file is opened through the Lean client, ``content_snippet`` is
        located on ``line`` with ``find_position_by_content``, and the client
        is asked for the declarations at that position. Each declaration's
        source file is then read from disk and the declared range extracted.

        NOTE(review): the method is named ``hover`` but it returns
        definitions and logs "Lean goto error" -- confirm the intended name.

        Args:
            file_path: Lean file to query.
            line: Line number handed to ``find_position_by_content``.
            content_snippet: Text used to pin down the position on that line.

        Returns:
            On success ``{"success": True, "definitions": [...]}``; each entry
            has ``file_path`` (relative to the workspace root when possible),
            ``line`` (start line of the declared range plus one),
            ``character`` and ``content``. On failure
            ``{"success": False, "error": <message>}``; exceptions are logged
            and reported this way rather than propagated.
        """
        try:
            client = self._ensure_client(file_path)
            rel_path = self._get_relative_path(file_path)

            client.open_file(rel_path)
            file_content = client.get_file_content(rel_path)

            position = find_position_by_content(file_content, line, content_snippet)
            if not position:
                return {
                    "success": False,
                    "error": f"Cannot find '{content_snippet}' at line {line}"
                }

            line_idx, char_idx = position
            declarations = client.get_declarations(rel_path, line_idx, char_idx)

            if not declarations:
                return {"success": True, "definitions": []}

            definitions = []
            for decl in declarations:
                # A declaration carries either targetUri/targetRange or
                # uri/range; accept both spellings.
                decl_uri = decl.get("targetUri") or decl.get("uri")
                decl_range = decl.get("targetRange") or decl.get("range")
                decl_path = Path(client._uri_to_abs(decl_uri))

                # Report paths relative to the workspace when they live in it.
                try:
                    relative_path = decl_path.relative_to(self._workspace_root)
                except ValueError:
                    relative_path = decl_path

                # Read explicitly as UTF-8 so Unicode-heavy Lean sources do not
                # depend on the platform's default encoding.
                with open(decl_path, "r", encoding="utf-8") as f:
                    decl_content = f.read()

                definitions.append({
                    "file_path": str(relative_path),
                    "line": decl_range["start"]["line"] + 1,
                    "character": decl_range["start"]["character"],
                    "content": extract_range(decl_content, decl_range)
                })

            return {
                "success": True,
                "definitions": definitions
            }

        except Exception as e:
            self.logger.error(f"Lean goto error: {e}")
            return {"success": False, "error": str(e)}

    async def diagnostics(
        self,
        file_path: Path,
        line: Optional[int] = None,
        content_snippet: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Query the Lean goals at a line of ``file_path``.

        ``line`` and ``content_snippet`` are optional parameters because the
        body reads both names; without them every call ended in a swallowed
        ``NameError``.

        NOTE(review): despite the method name, this body asks the client for
        goals (and logs "Lean get_goal error"), not for diagnostics --
        confirm whether a separate diagnostics implementation was lost.

        Args:
            file_path: Lean file to query.
            line: 1-based line number. Required for a useful result; when
                omitted an error dictionary is returned.
            content_snippet: Optional text on that line. When given, the goal
                is queried at the position found by
                ``find_position_by_content``; otherwise the goals at the first
                non-whitespace column and at the end of the line are returned.

        Returns:
            With a snippet: ``{"success": True, "line_context", "goals"}``.
            Without one: ``{"success": True, "line_context", "goals_before",
            "goals_after"}``. On failure ``{"success": False, "error": ...}``;
            exceptions are logged and reported this way rather than propagated.
        """
        if line is None:
            return {"success": False, "error": "A line number is required"}

        try:
            client = self._ensure_client(file_path)
            rel_path = self._get_relative_path(file_path)

            client.open_file(rel_path)
            file_content = client.get_file_content(rel_path)
            lines = file_content.splitlines()

            if line < 1 or line > len(lines):
                return {"success": False, "error": f"Line {line} out of range"}

            line_idx = line - 1
            line_context = lines[line_idx]

            if content_snippet:
                position = find_position_by_content(file_content, line, content_snippet)
                if not position:
                    return {
                        "success": False,
                        "error": f"Cannot find '{content_snippet}' at line {line}"
                    }

                _, char_idx = position
                goal_result = client.get_goal(rel_path, line_idx, char_idx)

                return {
                    "success": True,
                    "line_context": line_context,
                    "goals": extract_goals_list(goal_result)
                }

            # No snippet: report the goals before and after the line, i.e. at
            # its first non-whitespace column (0 for a blank line) and its end.
            col_start = next((i for i, c in enumerate(line_context) if not c.isspace()), 0)
            col_end = len(line_context)

            goal_before = client.get_goal(rel_path, line_idx, col_start)
            goal_after = client.get_goal(rel_path, line_idx, col_end)

            return {
                "success": True,
                "line_context": line_context,
                "goals_before": extract_goals_list(goal_before),
                "goals_after": extract_goals_list(goal_after)
            }

        except Exception as e:
            self.logger.error(f"Lean get_goal error: {e}")
            return {"success": False, "error": str(e)}


    def register_tools(self, mcp: FastMCP, enabled_tools: Set[str]) -> None:
        """Register the enabled Lean tools on ``mcp``.

        Only ``"get_lean_goal"`` is handled here; it is registered when that
        name is present in ``enabled_tools`` and forwards to ``self.get_goal``.

        NOTE(review): the original body was a lone ``return await ...`` inside
        this non-async ``def`` (a syntax error). The enclosing ``if`` /
        decorator / ``async def`` were reconstructed from that line's
        indentation; the parameter descriptions and the optional ``content``
        are assumptions -- confirm against the intended tool schema.

        Args:
            mcp: FastMCP server on which the tools are registered.
            enabled_tools: Names of the tools that should be exposed.
        """
        if "get_lean_goal" in enabled_tools:

            @mcp.tool()
            async def get_lean_goal(
                file_path: Annotated[str, Field(description="Path to the Lean file")],
                line: Annotated[int, Field(description="Line number to inspect")],
                content: Annotated[
                    Optional[str],
                    Field(description="Text on that line used to locate the position"),
                ] = None,
            ) -> Dict[str, Any]:
                """Get the Lean goals at a line of a file."""
                return await self.get_goal(Path(file_path), line, content)

    async def cleanup(self) -> None:
