from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Tuple

from .config import LEAN_ROOT


@dataclass
class SplitResult:
    """Outcome of one ``split_section_file`` call."""

    # True when part files were written; False when the input was left untouched.
    split: bool
    # The aggregate ``section<N>.lean`` path after a split, otherwise the input
    # file itself.
    base_file: Path
    # Part files belonging to the section after a split; empty when not split.
    part_files: list[Path]
    # Number of lines in the input file as it was read.
    line_count: int
    # Number of part files written by this call (0 when nothing was split).
    parts_created: int
    # One of "below_threshold", "unmatched_filename", "empty_body" or "split".
    reason: str | None = None


# Lines on which a declaration begins; the body is only ever split at such lines.
# NOTE(review): the trailing \b also applies to "/--", so a doc comment written
# as "/-- text" (space after the dashes) does not match here. Doc comments are
# attached to the declaration that follows them by _align_split_start instead;
# the pattern is left as-is so the set of split candidates does not change.
_DECL_START_RE = re.compile(
    r"^\s*(/--|theorem|lemma|def|instance|structure|class|example|axiom|inductive|notation|abbrev)\b"
)
# Lines that belong to the file header, which is repeated at the top of every
# part. "--" deliberately carries no \b: a word boundary can never follow the
# dashes in "-- comment", so with \b an ordinary line comment would end the
# header early and everything after it would be treated as body.
_HEADER_LINE_RE = re.compile(
    r"^\s*(import\b|open\b|set_option\b|attribute\b|universe\b|universes\b|variable\b|--)"
)
# Scope openers and closers tracked while walking the body.
_SECTION_RE = re.compile(r"^\s*section\b")
_NAMESPACE_RE = re.compile(r"^\s*namespace\b")
_END_RE = re.compile(r"^\s*end\b")
# File names of the form section<N>.lean or section<N>_part<K>.lean.
_SECTION_FILE_RE = re.compile(r"^section(\d+)(?:_part(\d+))?\.lean$")


def _path_to_module(path: Path) -> str:
    """Turn a file path into a dotted module name.

    The path is taken relative to ``LEAN_ROOT`` when it lies underneath it and
    used unchanged otherwise. The file suffix is dropped and the remaining
    path components are joined with dots.
    """
    try:
        relative = path.relative_to(LEAN_ROOT)
    except ValueError:
        # Not under LEAN_ROOT: fall back to the path exactly as given.
        relative = path
    stem_parts = relative.with_suffix("").parts
    return ".".join(stem_parts)


def _extract_header(lines: list[str]) -> Tuple[list[str], list[str]]:
    """Split *lines* into ``(header, body)``.

    The header is the leading run of lines that are blank or match
    ``_HEADER_LINE_RE``; the body is everything from the first other line on.
    """
    # Index of the first line that is neither blank nor a header line; when
    # there is none, the whole input is header and the body is empty.
    cut = next(
        (
            idx
            for idx, text in enumerate(lines)
            if text.strip() and not _HEADER_LINE_RE.match(text)
        ),
        len(lines),
    )
    return lines[:cut], lines[cut:]


def _update_stack(line: str, stack: list[str]) -> None:
    """Update the open-scope *stack* in place for one body line.

    A ``section``/``namespace`` line is pushed with trailing whitespace
    removed. An ``end`` line pops the innermost entry when the stack is not
    empty. Any other line leaves the stack untouched.
    """
    opens_scope = _SECTION_RE.match(line) or _NAMESPACE_RE.match(line)
    if opens_scope:
        stack.append(line.rstrip())
    elif stack and _END_RE.match(line):
        stack.pop()


def _iter_part_candidates(body_lines: list[str]) -> Iterable[int]:
    """Yield the index of every body line that matches ``_DECL_START_RE``."""
    yield from (
        index
        for index, text in enumerate(body_lines)
        if _DECL_START_RE.match(text)
    )


def _align_split_start(body_lines: list[str], split_index: int, part_start: int) -> int:
    """Move a split point back so a preceding doc comment stays with its declaration.

    Looking backwards from ``split_index`` (never before ``part_start``), blank
    lines are skipped. If the nearest non-blank line ends with ``-/``, the
    closest line at or above it containing ``/--`` becomes the split point.
    In every other case ``split_index`` is returned unchanged.
    """
    # Skip the blank lines directly above the proposed split.
    cursor = split_index - 1
    while cursor >= part_start and not body_lines[cursor].strip():
        cursor -= 1

    # Nothing but blanks above, or the line above does not close a comment.
    if cursor < part_start or not body_lines[cursor].strip().endswith("-/"):
        return split_index

    # Search upwards (starting with the closing line itself) for the opener.
    for opener in range(cursor, part_start - 1, -1):
        if "/--" in body_lines[opener]:
            return opener
    return split_index


def _split_body(
    body_lines: list[str],
    *,
    header_len: int,
    max_lines: int,
    min_lines: int,
) -> list[tuple[int, int, list[str], list[str]]]:
    """Partition the body into consecutive ``(start, end, stack_start, stack_end)`` parts.

    ``start``/``end`` are half-open indices into ``body_lines``. ``stack_start``
    holds the scope-opening lines that were open where the part begins and
    ``stack_end`` those open where it was cut. A cut is made at a declaration
    start once the projected part length (header + reopened scopes + body
    lines so far) reaches ``max_lines`` and the part already has at least
    ``min_lines`` body lines. The last part always runs to the end of the body.
    """
    boundaries = frozenset(_iter_part_candidates(body_lines))
    open_scopes: list[str] = []
    scopes_at_start: list[str] = []
    segment_start = 0
    segments: list[tuple[int, int, list[str], list[str]]] = []

    for idx, text in enumerate(body_lines):
        body_span = idx - segment_start
        projected_len = header_len + len(scopes_at_start) + body_span
        if idx in boundaries and body_span >= min_lines and projected_len >= max_lines:
            # Prefer cutting above a doc comment that belongs to this
            # declaration, unless that would leave the part too short.
            aligned = _align_split_start(body_lines, idx, segment_start)
            if aligned > segment_start and aligned - segment_start >= min_lines:
                cut = aligned
            else:
                cut = idx
            segments.append((segment_start, cut, scopes_at_start, list(open_scopes)))
            segment_start = cut
            # Separate snapshot: the next part must not alias the list stored above.
            scopes_at_start = list(open_scopes)
        # The scope stack is updated only after the split decision for this line.
        _update_stack(text, open_scopes)

    segments.append((segment_start, len(body_lines), scopes_at_start, list(open_scopes)))
    return segments


def _render_part(
    header: list[str],
    body_lines: list[str],
    *,
    extra_imports: list[str],
    start: int,
    end: int,
    stack_start: list[str],
    stack_end: list[str],
) -> list[str]:
    """Assemble the lines of one part file.

    Args:
        header: Header lines of the original file; normalized together with
            ``extra_imports`` and placed first.
        body_lines: The full body of the original file.
        extra_imports: Additional ``import ...`` lines for this part.
        start: Index of the first body line in this part.
        end: Index one past the last body line in this part.
        stack_start: ``section``/``namespace`` lines to re-open before the
            body slice, outermost first.
        stack_end: ``section``/``namespace`` lines still open after the body
            slice, outermost first.

    Returns:
        The rendered lines, without line terminators.
    """
    lines: list[str] = []
    lines.extend(_normalize_header(header, extra_imports))
    if stack_start:
        lines.append("")
        lines.extend(stack_start)
    lines.extend(body_lines[start:end])
    if stack_end:
        lines.append("")
        # Close innermost scope first. Lean rejects a bare ``end`` for a
        # namespace or named section, so the name from the opening line is
        # repeated; anonymous sections still get a plain ``end``.
        for opener in reversed(stack_end):
            tokens = opener.split()
            # tokens[0] is the keyword; a following token is the scope name
            # unless it is the start of a trailing comment.
            named = len(tokens) > 1 and not tokens[1].startswith(("--", "/-"))
            lines.append(f"end {tokens[1]}" if named else "end")
    return lines


def _normalize_header(header: list[str], extra_imports: list[str]) -> list[str]:
    """Return the header with its imports gathered, completed and de-duplicated.

    Import lines (stripped) come first, followed by ``extra_imports``;
    ``import Mathlib`` is put at the very top when no import line equals it.
    Duplicates are dropped, keeping the first occurrence. The remaining header
    lines follow in their original order, separated from the imports by one
    blank line unless they already begin with a blank line.
    """

    def is_import(text: str) -> bool:
        return text.strip().startswith("import ")

    imports = [text.strip() for text in header if is_import(text)]
    remainder = [text for text in header if not is_import(text)]

    if extra_imports:
        imports.extend(extra_imports)
    if "import Mathlib" not in imports:
        imports.insert(0, "import Mathlib")

    # dict preserves insertion order, so this keeps the first occurrence of each.
    result = list(dict.fromkeys(imports))

    if remainder:
        if remainder[0].strip():
            result.append("")
        result.extend(remainder)
    return result


def _list_part_files(base_file: Path) -> list[Path]:
    """List the ``section<N>_part<K>.lean`` files next to *base_file*, ordered by K.

    Returns an empty list when the name of *base_file* does not match
    ``_SECTION_FILE_RE``.
    """
    name_match = _SECTION_FILE_RE.match(base_file.name)
    if name_match is None:
        return []

    pattern = f"section{name_match.group(1)}_part*.lean"
    numbered: list[tuple[int, Path]] = []
    for candidate in base_file.parent.glob(pattern):
        part_match = _SECTION_FILE_RE.match(candidate.name)
        # The glob is looser than the regex; keep only real part names.
        if part_match is None or not part_match.group(2):
            continue
        numbered.append((int(part_match.group(2)), candidate))

    # Order by part number only.
    numbered.sort(key=lambda pair: pair[0])
    return [candidate for _, candidate in numbered]


def _shift_existing_parts(base_file: Path, *, start_part: int, offset: int) -> None:
    """Rename part files numbered ``start_part`` or higher upwards by *offset*.

    Does nothing when *offset* is not positive or when the name of *base_file*
    does not match ``_SECTION_FILE_RE``.
    """
    if offset <= 0:
        return
    name_match = _SECTION_FILE_RE.match(base_file.name)
    if name_match is None:
        return

    section_num = name_match.group(1)
    to_move: list[tuple[int, Path]] = []
    for candidate in base_file.parent.glob(f"section{section_num}_part*.lean"):
        part_match = _SECTION_FILE_RE.match(candidate.name)
        if part_match is None or not part_match.group(2):
            continue
        number = int(part_match.group(2))
        if number >= start_part:
            to_move.append((number, candidate))

    # Highest number first, so a rename never lands on a file still waiting
    # to be moved.
    to_move.sort(reverse=True)
    for number, candidate in to_move:
        candidate.rename(
            candidate.with_name(f"section{section_num}_part{number + offset}.lean")
        )


def write_aggregate_imports(base_file: Path, part_files: list[Path]) -> None:
    """Overwrite *base_file* with one ``import`` line per entry of *part_files*.

    Module names come from ``_path_to_module``; the file is written as UTF-8
    and ends with a newline.
    """
    module_names = map(_path_to_module, part_files)
    text = "\n".join(f"import {name}" for name in module_names) + "\n"
    base_file.write_text(text, encoding="utf-8")


def split_section_file(lean_file: Path, *, max_lines: int = 1000) -> SplitResult:
    """Split *lean_file* into numbered part files when it exceeds *max_lines*.

    The file is left untouched, and a result with ``split=False`` is returned,
    when it has at most ``max_lines`` lines ("below_threshold"), when its name
    is not ``section<N>.lean`` / ``section<N>_part<K>.lean``
    ("unmatched_filename"), or when nothing follows the header ("empty_body").

    Otherwise the body is partitioned by ``_split_body`` and each part is
    written to ``section<N>_part<K>.lean``, numbered from the input's own part
    number (or 1 for a base file). Each part imports all lower-numbered parts.
    When the input is itself a part, later existing parts are first renamed
    upwards to make room. Finally ``section<N>.lean`` is rewritten to import
    every part.

    Args:
        lean_file: The Lean file to split.
        max_lines: Line count above which the file is split.

    Returns:
        A ``SplitResult`` describing what was done.
    """
    source_lines = lean_file.read_text(encoding="utf-8").splitlines()
    total = len(source_lines)

    def untouched(reason: str) -> SplitResult:
        # Shared shape of every "nothing was written" outcome.
        return SplitResult(
            split=False,
            base_file=lean_file,
            part_files=[],
            line_count=total,
            parts_created=0,
            reason=reason,
        )

    if total <= max_lines:
        return untouched("below_threshold")

    name_match = _SECTION_FILE_RE.match(lean_file.name)
    if name_match is None:
        return untouched("unmatched_filename")

    header, body = _extract_header(source_lines)
    if not body:
        return untouched("empty_body")

    chunks = _split_body(
        body,
        header_len=len(header),
        max_lines=max_lines,
        min_lines=max(200, int(max_lines * 0.6)),
    )

    section_num, part_suffix = name_match.groups()
    first_part = int(part_suffix) if part_suffix else 1
    aggregate_path = lean_file.with_name(f"section{section_num}.lean")

    if part_suffix:
        # Splitting an existing part: make room for the extra parts first.
        _shift_existing_parts(
            aggregate_path,
            start_part=first_part + 1,
            offset=len(chunks) - 1,
        )

    module_base = _path_to_module(aggregate_path)
    written: list[Path] = []
    for number, (start, end, opened, still_open) in enumerate(chunks, start=first_part):
        target = lean_file.with_name(f"section{section_num}_part{number}.lean")
        # Every part imports all lower-numbered parts of the same section.
        earlier_parts = [f"import {module_base}_part{i}" for i in range(1, number)]
        rendered = _render_part(
            header,
            body,
            extra_imports=earlier_parts,
            start=start,
            end=end,
            stack_start=opened,
            stack_end=still_open,
        )
        target.write_text("\n".join(rendered) + "\n", encoding="utf-8")
        written.append(target)

    # Prefer what is on disk (includes shifted parts); fall back to this run's files.
    all_parts = _list_part_files(aggregate_path) or written
    write_aggregate_imports(aggregate_path, all_parts)

    return SplitResult(
        split=True,
        base_file=aggregate_path,
        part_files=all_parts,
        line_count=total,
        parts_created=len(chunks),
        reason="split",
    )
