import subprocess
import threading
import sys
import json
import os
import re
import ast
import functools
import multiprocessing
import time
from pathlib import Path
from collections import OrderedDict
def get_method_path_from_definitions(method_name, definitions):
    """Look up ``method_name`` across every class entry of ``definitions``.

    ``definitions`` maps a class name to a mapping keyed by method name.
    Classes are scanned in the mapping's iteration order and the value stored
    for the first class that contains ``method_name`` is returned as-is.

    Returns:
        The stored value for the first match, or ``(None, None)`` when no
        class contains the method.
    """
    owner = next(
        (methods for methods in definitions.values() if method_name in methods),
        None,
    )
    if owner is None:
        return None, None
    return owner[method_name]
def get_symbol_definition_from_file(filepath: str, symbol: str):
    """Return the source text of the function or class named ``symbol``.

    The file is parsed with ``ast`` and walked in ``ast.walk`` order; the first
    ``def``, ``async def`` or ``class`` whose name equals ``symbol`` wins.

    Returns:
        The definition's lines (line endings included) joined into one string,
        only its first line when the node carries no ``end_lineno``, or
        ``None`` when the file has a syntax error or no such definition exists.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        source_lines = list(handle)
    try:
        tree = ast.parse(''.join(source_lines), filename=filepath)
    except SyntaxError:
        return None

    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    found = next(
        (
            node
            for node in ast.walk(tree)
            if isinstance(node, definition_types) and node.name == symbol
        ),
        None,
    )
    if found is None:
        return None

    first = found.lineno - 1
    last = getattr(found, 'end_lineno', None)
    if last is None:
        # No end position recorded on the node: only the header line is known.
        return source_lines[first]
    return ''.join(source_lines[first:last])
def extract_exact_definition(filepath, line_number):
    """Return the source of the function/class whose header is on ``line_number``.

    Args:
        filepath: Path of a UTF-8 encoded Python source file.
        line_number: 1-based line on which the ``def``, ``async def`` or
            ``class`` statement starts (``node.lineno``).

    Returns:
        The definition's lines joined with ``"\\n"`` (no trailing newline),
        or ``None`` when no function or class starts on that line.

    Raises:
        SyntaxError: if the file cannot be parsed (raised by ``ast.parse``).
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        source = f.read()
    # ast numbers lines by "\n" only (the text-mode read already normalised
    # "\r\n" and "\r").  str.splitlines() would additionally break on form
    # feeds, vertical tabs, U+2028, ... and shift every following line, so
    # split on "\n" explicitly to stay aligned with node.lineno.
    lines = source.split("\n")
    tree = ast.parse(source)
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    for node in ast.walk(tree):
        if not isinstance(node, definition_types):
            continue
        if node.lineno != line_number:
            continue
        end_lineno = getattr(node, "end_lineno", None)
        if end_lineno is None:
            # Without an end position the extent of the definition is unknown.
            continue
        return "\n".join(lines[node.lineno - 1:end_lineno])
    return None
def send_msg(writer, msg):
    """Serialise ``msg`` as JSON and write it to ``writer`` as one framed message.

    The frame is a ``Content-Length`` header, a blank line, then the JSON
    payload, emitted with a single ``write`` call followed by ``flush``.
    ``json.dumps`` is used with its defaults, so the payload is pure ASCII and
    its character count equals its byte count.
    """
    payload = json.dumps(msg)
    frame = "".join((f"Content-Length: {len(payload)}\r\n\r\n", payload))
    writer.write(frame)
    writer.flush()
def read_msg(reader):
    """Read one ``Content-Length``-framed JSON message from a text stream.

    Header lines are consumed up to the first blank line; lines without a
    colon are ignored.  The body is then read and decoded as JSON.

    ``Content-Length`` is a byte count, while ``reader`` yields characters.
    Reading ``Content-Length`` characters would run into the next frame as
    soon as the body holds a multi-byte character, so the body is read in
    bounded chunks and the encoded size of each chunk is subtracted instead.

    Args:
        reader: Text-mode stream providing ``readline()`` and ``read(n)``.

    Returns:
        The decoded JSON value.

    Raises:
        EOFError: the stream ended while headers were being read.
        ValueError: ``Content-Length`` is missing, zero or not an integer, or
            the body is not valid JSON (``json.JSONDecodeError``).
    """
    headers = {}
    while True:
        line = reader.readline()
        if not line:
            raise EOFError("Unexpected EOF while reading headers")
        if line.strip() == "":
            break
        name, sep, value = line.partition(":")
        if not sep:
            continue
        headers[name.strip()] = value.strip()
    content_length = int(headers.get("Content-Length", 0))
    if content_length == 0:
        raise ValueError("Content-Length header missing or zero")

    # Measure consumed bytes with the stream's own codec so the count matches
    # what was actually taken off the wire; in-memory streams report None.
    encoding = getattr(reader, "encoding", None) or "utf-8"
    errors = getattr(reader, "errors", None) or "strict"
    pieces = []
    remaining = content_length
    while remaining > 0:
        # A character occupies at most 4 bytes in UTF-8, so asking for
        # remaining // 4 characters can never read past the end of the body.
        chunk = reader.read(max(1, remaining // 4))
        if not chunk:
            # Truncated stream: fall through and let json.loads report it.
            break
        pieces.append(chunk)
        remaining -= len(chunk.encode(encoding, errors))
    return json.loads("".join(pieces))
def get_symbol_definition(file_path, symbol):
    """Return the source text that defines ``symbol`` in ``file_path``.

    A definition is the first node, in ``ast.walk`` order, that is either a
    ``def`` / ``async def`` / ``class`` named ``symbol``, or a plain or
    annotated assignment whose target is the bare name ``symbol``.

    Args:
        file_path: Path of a UTF-8 encoded Python source file.
        symbol: Name to look for.

    Returns:
        The defining lines (line endings included) as one string, or ``None``
        when the file does not exist, cannot be parsed, or has no such
        definition.
    """
    if not os.path.exists(file_path):
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        source_lines = f.readlines()
    source_text = ''.join(source_lines)
    try:
        tree = ast.parse(source_text, filename=file_path)
    except Exception:
        # Best effort by design: any parse failure means "no definition".
        return None

    # AsyncFunctionDef is included so ``async def`` symbols are found too,
    # matching get_symbol_definition_from_file.
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    for node in ast.walk(tree):
        if isinstance(node, definition_types):
            matched = node.name == symbol
        elif isinstance(node, ast.Assign):
            matched = any(
                isinstance(target, ast.Name) and target.id == symbol
                for target in node.targets
            )
        elif isinstance(node, ast.AnnAssign):
            matched = isinstance(node.target, ast.Name) and node.target.id == symbol
        else:
            matched = False
        if matched:
            start = node.lineno - 1
            # Fall back to a single line when no end position is recorded.
            end = getattr(node, 'end_lineno', None) or node.lineno
            return ''.join(source_lines[start:end])
    return None
class LSPConnection:
    """One language-server subprocess driven over stdio JSON-RPC.

    A daemon thread reads every message the server emits.  Replies are parked
    in ``_responses`` under their request id until ``wait_response`` collects
    them; all other messages are dropped.
    """

    def __init__(self, project_dir, lsp_exe="pyright-langserver"):
        """Launch ``lsp_exe`` for ``project_dir``, handshake, and start reading."""
        self.project_dir = project_dir
        self.lsp_exe = lsp_exe
        self.proc = self._start_proc(project_dir)
        self.writer = self.proc.stdin
        self.reader = self.proc.stdout
        self._lock = threading.Lock()
        # Writes get their own lock: holding _lock while blocked on a full
        # pipe would stop the reader thread from storing replies.
        self._write_lock = threading.Lock()
        self._responses = {}
        self._req_id = 0
        self._cond = threading.Condition(self._lock)
        self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._reader_thread.start()

    @staticmethod
    def _terminate(proc):
        """Kill ``proc`` and reap it so no zombie process is left behind."""
        try:
            proc.kill()
        except OSError:
            pass  # the process is already gone
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            pass

    @staticmethod
    def _close_pipes(proc):
        """Close the stdin/stdout pipes of ``proc``, ignoring pipe errors."""
        for pipe in (proc.stdin, proc.stdout):
            if pipe is None:
                continue
            try:
                pipe.close()
            except OSError:
                pass

    def _start_proc(self, project_dir):
        """Spawn the server in ``project_dir`` and run the initialize handshake.

        Returns the ``Popen`` object once ``initialized`` has been sent.  If
        the handshake fails the process is killed and the error re-raised.
        """
        env = os.environ.copy()
        env["PYTHONPATH"] = project_dir
        proc = subprocess.Popen(
            [self.lsp_exe, "--stdio"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            bufsize=0,
            text=True,
            env=env,
            cwd=project_dir
        )
        try:
            send_msg(proc.stdin, {
                "jsonrpc": "2.0","id": 0,"method":"initialize",
                "params": {"processId": None,"rootUri": Path(project_dir).as_uri(),"capabilities": {}}
            })
            # The server may emit notifications before it answers, so keep
            # reading until the reply to request id 0 shows up.
            while True:
                reply = read_msg(proc.stdout)
                if reply.get("id") == 0 and "method" not in reply:
                    break
            send_msg(proc.stdin, {"jsonrpc": "2.0","method":"initialized","params": {}})
        except BaseException:
            # Do not leave an orphaned server behind a failed handshake.
            self._terminate(proc)
            self._close_pipes(proc)
            raise
        return proc

    def _reader_loop(self):
        """Read messages until the stream fails, publishing replies to waiters."""
        try:
            while True:
                resp = read_msg(self.reader)
                # A reply has an id and no method.  Server-initiated requests
                # carry both and number their ids independently, so storing
                # them could shadow the reply to one of our own requests.
                if "id" in resp and "method" not in resp:
                    with self._cond:
                        self._responses[resp["id"]] = resp
                        self._cond.notify_all()
        except Exception as e:
            print(f"[LSPConnection] reader loop stopped: {e}")

    def send_request(self, method, params):
        """Send a JSON-RPC request and return the id assigned to it."""
        with self._lock:
            self._req_id += 1
            request_id = self._req_id
        # Serialise writers so frames from different threads cannot interleave.
        with self._write_lock:
            send_msg(self.writer, {"jsonrpc": "2.0","id":request_id,
                                   "method": method, "params": params})
        return request_id

    def wait_response(self, request_id, timeout=5):
        """Block until the reply for ``request_id`` arrives and return it.

        Raises:
            TimeoutError: no reply arrived within ``timeout`` seconds.
        """
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout
        with self._cond:
            while request_id not in self._responses:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise TimeoutError(f"Timeout waiting for response id={request_id}")
                self._cond.wait(timeout=remaining)
            return self._responses.pop(request_id)

    def shutdown(self):
        """Kill the server, reap it, and release the pipes.  Safe to repeat."""
        self._terminate(self.proc)
        # Once the server is dead the reader thread hits EOF; give it a
        # moment to finish before its stream is closed underneath it.
        reader_thread = self._reader_thread
        if reader_thread.is_alive() and reader_thread is not threading.current_thread():
            reader_thread.join(timeout=1)
        self._close_pipes(self.proc)
class LSPManager:
    """Keeps at most ``max_procs`` language-server connections, one per project.

    Connections live in an ``OrderedDict`` keyed by absolute project path.
    Each lookup moves the entry to the end, so when the limit is exceeded the
    entry at the front (the least recently requested) is shut down and dropped.
    """

    def __init__(self, max_procs=4, lsp_exe="pyright-langserver"):
        """Remember the connection limit and the server executable to launch."""
        self.max_procs = max_procs
        self.lsp_exe = lsp_exe
        self.conns = OrderedDict()
        self.lock = threading.Lock()

    def get_conn(self, project_dir):
        """Return the connection for ``project_dir``, creating it on first use."""
        key = os.path.abspath(project_dir)
        with self.lock:
            try:
                # Doubles as the membership test and the recency refresh.
                self.conns.move_to_end(key)
            except KeyError:
                pass
            else:
                return self.conns[key]

            fresh = LSPConnection(key, self.lsp_exe)
            self.conns[key] = fresh  # a new key is appended at the end
            if len(self.conns) > self.max_procs:
                _, evicted = self.conns.popitem(last=False)
                evicted.shutdown()
            return fresh

    def shutdown_all(self):
        """Shut down every cached connection and forget them all."""
        with self.lock:
            for conn in self.conns.values():
                conn.shutdown()
            self.conns.clear()
def call_lsp(manager, project_dir, file_path, symbol, timeout=5):
    """Ask the language server where ``symbol`` is defined and extract its source.

    The first whole-word occurrence of ``symbol`` in ``file_path`` is used as
    the cursor position of a ``textDocument/definition`` request.

    Args:
        manager: Object whose ``get_conn(project_dir)`` returns a connection
            offering ``send_request`` and ``wait_response``.
        project_dir: Project root used to select the connection.
        file_path: UTF-8 source file in which ``symbol`` is looked up.
        symbol: Identifier to resolve.
        timeout: Seconds to wait for the server's reply.

    Returns:
        ``(code, path)``: the definition's source as returned by
        ``extract_exact_definition`` (``None`` if the target line does not
        start a function or class) and the target file path.  ``(None, None)``
        when ``symbol`` does not occur in the file or the server reports no
        definition.
    """
    from urllib.parse import unquote, urlparse

    conn = manager.get_conn(project_dir)
    uri = Path(file_path).absolute().as_uri()
    with open(file_path, encoding='utf-8') as f:
        content = f.read()
    # NOTE(review): LSP defines textDocument/didOpen as a notification, but
    # send_request attaches an id, turning it into a request whose reply is
    # never collected.  Left unchanged here; sending it as a notification
    # needs support from the connection class.
    conn.send_request("textDocument/didOpen", {
        "textDocument": {
            "uri": uri,
            "languageId": "python",
            "version": 1,
            "text": content
        }
    })

    pattern = re.compile(rf"\b{re.escape(symbol)}\b")
    line_num, char_num = None, None
    # Split on "\n" only: str.splitlines() also breaks on form feeds and
    # other separators that the protocol does not treat as line ends.
    for i, line in enumerate(content.split("\n")):
        match = pattern.search(line)
        if match:
            line_num = i
            # LSP character offsets default to UTF-16 code units, which differ
            # from Python string indices when astral characters precede.
            char_num = len(line[:match.start()].encode("utf-16-le")) // 2
            break
    if line_num is None:
        return None, None

    req_id = conn.send_request("textDocument/definition", {
        "textDocument": {"uri": uri},
        "position": {"line": line_num, "character": char_num}
    })
    resp = conn.wait_response(req_id, timeout=timeout)
    result = resp.get("result")
    if not result:
        return None, None

    # The reply may be a single Location, a Location[] or a LocationLink[].
    loc = result[0] if isinstance(result, list) else result
    if "targetUri" in loc:
        target_uri, target_range = loc["targetUri"], loc["targetSelectionRange"]
    else:
        target_uri, target_range = loc["uri"], loc["range"]

    parsed = urlparse(target_uri)
    if parsed.scheme == "file":
        # Percent-decode so paths with spaces or non-ASCII names can be opened.
        target_path = unquote(parsed.path)
        if parsed.netloc and parsed.netloc != "localhost":
            target_path = "//" + parsed.netloc + target_path
        elif os.name == "nt" and re.match(r"/[A-Za-z]:", target_path):
            # "/C:/dir/file.py" -> "C:/dir/file.py"
            target_path = target_path[1:]
    else:
        target_path = target_uri

    target_line = target_range["start"]["line"] + 1
    code = extract_exact_definition(target_path, target_line)
    return (code, target_path)