"""
AST-based syntax analysis.
"""
import hashlib
import tokenize
from io import StringIO
from typing import List, Set

import tree_sitter_python
from tree_sitter import Language, Parser

# Module-level tree-sitter parser for the Python grammar. It is built once
# when this module is imported and reused by _fingerprint() for every parse.
_LANG = Language(tree_sitter_python.language())
_PARSER = Parser(_LANG)


def _remove_comments(source: str) -> str:
    """Remove comments and docstrings from Python code."""
    out = ""
    prev_type = tokenize.INDENT
    last_lineno = -1
    last_col = 0

    try:
        for tok in tokenize.generate_tokens(StringIO(source).readline):
            ttype, tstring, (sline, scol), (_, ecol), _ = tok

            if sline > last_lineno:
                last_col = 0
            if scol > last_col:
                out += " " * (scol - last_col)

            if ttype == tokenize.COMMENT:
                pass
            elif ttype == tokenize.STRING:
                if prev_type not in (tokenize.INDENT, tokenize.NEWLINE) or scol > 0:
                    out += tstring
            else:
                out += tstring

            prev_type = ttype
            last_col = ecol
            last_lineno = sline
    except Exception:
        return source

    return "\n".join(line for line in out.split("\n") if line.strip())


def _hash(s: str) -> str:
    """MD5 hash of string."""
    return hashlib.md5(s.encode()).hexdigest()


def _fingerprint(code: str) -> List[int]:
    """Generate AST fingerprint using post-order traversal.

    The code is first passed through ``_remove_comments`` and then parsed with
    the module-level tree-sitter parser. Every node of the resulting tree
    contributes one integer:

      * a node without children is hashed from its node type;
      * a node with children starts from ``"0"`` and folds in the fingerprint
        of each child from left to right; its own node type is not mixed in.

    The walk uses an explicit stack instead of recursion so that deeply
    nested trees cannot exhaust Python's recursion limit.

    Args:
        code: Source code to fingerprint.

    Returns:
        One integer per tree node, in post-order (the root node comes last).
    """
    try:
        code = _remove_comments(code)
    except Exception:
        # Stripping is only a normalisation step; on failure the raw text is
        # parsed instead.
        pass

    root = _PARSER.parse(code.encode()).root_node
    hashes: List[int] = []

    # Each frame is [node, children, index of next child to visit, running hash].
    stack = [[root, root.children, 0, "0"]]
    while stack:
        frame = stack[-1]
        node, children, next_child, fp = frame

        if next_child < len(children):
            # Descend into the next unvisited child before finishing this node.
            frame[2] = next_child + 1
            child = children[next_child]
            stack.append([child, child.children, 0, "0"])
            continue

        if not children:
            fp = _hash("0" + _hash(node.type))
        hashes.append(int(fp, 16))
        stack.pop()

        if stack:
            # Fold the finished child's fingerprint into its parent's hash.
            parent = stack[-1]
            parent[3] = _hash(parent[3] + fp)

    return hashes


def syntax_distance(code1: str, code2: str) -> float:
    """
    Jaccard distance between the AST fingerprint sets of two snippets.

    Args:
        code1: First code
        code2: Second code

    Returns:
        Distance in [0, 1]: 0.0 when both fingerprint sets are empty, 1.0 when
        exactly one of them is empty, otherwise one minus the ratio of shared
        fingerprints to all fingerprints.
    """
    first = set(_fingerprint(code1))
    second = set(_fingerprint(code2))

    if not (first or second):
        return 0.0
    if not (first and second):
        return 1.0

    shared = first.intersection(second)
    combined = first.union(second)
    return 1.0 - len(shared) / len(combined)


def ast_coverage(codes: List[str], k: int) -> List[int]:
    """
    Greedy selection maximizing AST subtree coverage.

    In each round the snippet whose fingerprint set covers the most
    not-yet-covered fingerprints is picked; ties go to the snippet that
    appears first in ``codes``. Selection works on positions rather than on
    the code strings, so identical snippets at different positions are
    treated as separate candidates and no index is returned twice.

    Args:
        codes: List of code strings
        k: Number to select

    Returns:
        Selected indices into ``codes``, in the order they were picked. All
        indices are returned when ``k >= len(codes)``; an empty list is
        returned when ``k <= 0``.
    """
    n = len(codes)
    if k >= n:
        return list(range(n))
    if k <= 0:
        # Nothing to select, so skip parsing altogether.
        return []

    # Fingerprint every distinct snippet once; duplicates share the same set
    # (the sets are only read, never mutated).
    by_code = {}
    fps: List[Set[int]] = []
    for code in codes:
        if code not in by_code:
            by_code[code] = set(_fingerprint(code))
        fps.append(by_code[code])

    remaining: Set[int] = set().union(*fps)
    available = list(range(n))
    selected: List[int] = []

    for _ in range(k):
        # max() returns the first maximal element, which keeps the
        # "earliest snippet wins ties" rule.
        best = max(available, key=lambda i: len(fps[i] & remaining))
        selected.append(best)
        remaining -= fps[best]
        available.remove(best)

    return selected
