import json
import math
import random
import re
from collections import Counter
from pathlib import Path
from typing import Any, Iterable, Union


# Keys of a normalized record, in the order normalize_entry emits them.
CANONICAL_FIELDS = [
    "file",
    "theorem",
    "step_index",
    "main_goal",
    "local_context",
    "next_tactic",
    "tactic_family",
]

# Extra keys that normalize_entry copies through unchanged when called with
# include_optional=True and the key is present on the input row.
OPTIONAL_FIELDS = [
    "state_before",
    "state_after",
    "annotated_tactic",
    "premises",
    "ast_summary",
    "source",
    "trace_id",
    "theorem_statement",
]

# Keys validate_entry requires on every raw row. The goal is checked
# separately because it may appear under any of the GOAL_FIELDS names.
BASE_REQUIRED_FIELDS = {
    "theorem",
    "file",
    "step_index",
    "local_context",
    "next_tactic",
}

# Accepted key names for the goal text; validate_entry requires at least one.
GOAL_FIELDS = ["main_goal", "goal", "state_main_goal"]

# Maps the first token of a tactic to its family label. extract_tactic_family
# looks heads up here and returns the head unchanged when it is not listed.
# Aliases collapse onto one label (e.g. "intros" -> "intro", "rewrite" -> "rw").
TACTIC_FAMILY_MAP = {
    "intro": "intro",
    "intros": "intro",
    "rintro": "rintro",
    "exact": "exact",
    "apply": "apply",
    "rw": "rw",
    "rewrite": "rw",
    "simp": "simp",
    "simp_all": "simp_all",
    "simp_rw": "simp_rw",
    "cases": "cases",
    "have": "have",
    "assumption": "assumption",
    "rfl": "rfl",
    "simpa": "simpa",
    "rwa": "rwa",
    "exact_mod_cast": "exact_mod_cast",
    "by_contra": "by_contra",
    "constructor": "constructor",
    "left": "left",
    "right": "right",
    "linarith": "linarith",
    "ring": "ring",
    "contradiction": "contradiction",
    "grind": "grind",
    "refine": "refine",
    "obtain": "obtain",
    # Lowercase heads that are labelled "term" rather than given their own family.
    "not_not_intro": "term",
    "irrational_sqrt_natCast_iff": "term",
    "irrational_sqrt_natCast_iff.mpr": "term",
    "term": "term",
}

# Representation names accepted by representation_text.
REPRESENTATIONS = ["raw", "normalized", "structured", "state_only", "state_meta", "retrieved_premise"]
# Representations built from fields of the gold next tactic (see oracle_premise_text).
ORACLE_REPRESENTATIONS = ["oracle_premise"]
# Old representation name -> current name; resolved by representation_text.
LEGACY_REPRESENTATIONS = {"premise": "oracle_premise"}
# Unpacking the dict contributes its keys, i.e. the legacy alias names.
ALL_REPRESENTATIONS = [*REPRESENTATIONS, *ORACLE_REPRESENTATIONS, *LEGACY_REPRESENTATIONS]

# One token is either an identifier-like word (may contain digits, underscores,
# apostrophes and dots), a run of operator symbols, or a run of digits.
TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_'.]*|[¬∧∨→↔=<>≤≥+*^/-]+|[0-9]+")
# Identifier without dots; structured_text uses it with fullmatch on tokens.
IDENT_RE = re.compile(r"\b[A-Za-z_][A-Za-z0-9_']*\b")
# Standalone integer literal; replaced by "<num>" in normalize_formula_text.
NUMBER_RE = re.compile(r"\b\d+\b")
# Any run of whitespace; collapsed to one space in normalize_formula_text.
WHITESPACE_RE = re.compile(r"\s+")


def load_jsonl(path: Union[Path, str]) -> list[dict[str, Any]]:
    """Read a JSON Lines file into a list of dictionaries.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        One dictionary per non-blank line, in file order.

    Raises:
        ValueError: If a line is not valid JSON or does not decode to a JSON
            object. The message carries the 1-based line number and the path.
    """
    path = Path(path)
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for number, text in enumerate(handle, start=1):
            payload = text.strip()
            if payload:
                try:
                    parsed = json.loads(payload)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Invalid JSON on line {number} of {path}: {exc}") from exc
                if not isinstance(parsed, dict):
                    raise ValueError(f"Line {number} of {path} is not a JSON object.")
                records.append(parsed)
    return records


def write_jsonl(rows: Iterable[dict[str, Any]], path: Union[Path, str]) -> None:
    """Write dictionaries to a UTF-8 JSON Lines file, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII characters are written as-is, not escaped.

    Args:
        rows: Dictionaries to serialise, in output order.
        path: Destination file.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in rows)

def extract_tactic_family(tactic: str) -> str:
    """Reduce a tactic string to a coarse family label.

    Leading focus bullets ("·") and a ``case ... =>`` prefix are removed
    first. A single-line tactic without ":=" that starts with an uppercase
    letter is labelled "term". Otherwise the first whitespace-separated
    token, with surrounding ";" and "," trimmed, is looked up in
    TACTIC_FAMILY_MAP and returned unchanged when it is not listed.

    Returns:
        The family label, or "unknown" when nothing is left to classify.
    """
    remainder = tactic.strip()

    # Peel off any number of leading bullets.
    while remainder.startswith("·"):
        remainder = remainder[1:].lstrip()

    # "case tag => tac": classify the tactic after the arrow.
    if remainder.startswith("case ") and "=>" in remainder:
        _, _, after_arrow = remainder.partition("=>")
        remainder = after_arrow.strip()

    if not remainder:
        return "unknown"

    single_line_without_binding = ":=" not in remainder and "\n" not in remainder
    if single_line_without_binding and remainder[0].isupper():
        return "term"

    head = remainder.split(None, 1)[0].strip(";,")
    return TACTIC_FAMILY_MAP.get(head, head)


def parse_pretty_state(state: str) -> tuple[list[str], str]:
    """Split a Lean pretty-printed state into local context lines and main goal.

    The first line containing the turnstile "⊢" marks the boundary.
    Non-blank lines above it, plus any text before the turnstile on that
    line, become the context. Text after the turnstile and all following
    non-blank lines are joined with single spaces to form the goal.

    Returns:
        ``(context_lines, goal)``. Without a turnstile the context is empty
        and the goal is the whole stripped input.
    """
    rows = state.splitlines()
    turnstile_at = next((pos for pos, row in enumerate(rows) if "⊢" in row), None)
    if turnstile_at is None:
        return [], state.strip()

    prefix, _, suffix = rows[turnstile_at].partition("⊢")
    hypotheses = [row.strip() for row in rows[:turnstile_at] if row.strip()]
    if prefix.strip():
        hypotheses.append(prefix.strip())

    goal_pieces = [suffix, *rows[turnstile_at + 1 :]]
    goal = " ".join(piece.strip() for piece in goal_pieces if piece.strip())
    return hypotheses, goal


def validate_entry(entry: dict[str, Any], line_no: int) -> list[str]:
    """Collect schema problems for one raw row.

    Args:
        entry: Decoded JSON object to check.
        line_no: Line number used as the prefix of every message.

    Returns:
        Human-readable error messages; an empty list means the row passed.
    """
    prefix = f"Line {line_no}: "
    problems: list[str] = []

    absent = sorted(BASE_REQUIRED_FIELDS.difference(entry))
    if absent:
        problems.append(f"{prefix}missing fields {absent}")

    if all(name not in entry for name in GOAL_FIELDS):
        problems.append(f"{prefix}missing goal field (expected one of {GOAL_FIELDS})")

    # (field, required type, wording) in the order messages are reported.
    # Type checks only apply to fields that are actually present.
    expectations = [
        ("step_index", int, "an int"),
        ("local_context", list, "a list"),
        *((name, str, "a string") for name in GOAL_FIELDS),
        ("next_tactic", str, "a string"),
    ]
    for name, expected_type, wording in expectations:
        if name in entry and not isinstance(entry[name], expected_type):
            problems.append(f"{prefix}{name} should be {wording}")

    return problems


def normalize_entry(entry: dict[str, Any], include_optional: bool = False) -> dict[str, Any]:
    """Convert a raw row into the canonical record layout.

    When "main_goal" is absent the goal is taken from "state_main_goal",
    then "goal", and otherwise parsed out of "state_before". In that last
    case the parsed context is used only if the row has no non-empty
    "local_context" of its own. "tactic_family" is always recomputed from
    "next_tactic".

    Args:
        entry: Raw row; it is copied, never mutated.
        include_optional: Also copy any OPTIONAL_FIELDS present on the row.

    Returns:
        A new dictionary with the canonical keys coerced to their types.
    """
    working = dict(entry)

    if "main_goal" not in working:
        alias = next((key for key in ("state_main_goal", "goal") if key in working), None)
        if alias is not None:
            working["main_goal"] = working.pop(alias)
        elif "state_before" in working:
            parsed_context, parsed_goal = parse_pretty_state(str(working["state_before"]))
            working["local_context"] = working.get("local_context") or parsed_context
            working["main_goal"] = parsed_goal

    family = extract_tactic_family(working.get("next_tactic", ""))
    context_lines = [str(item) for item in working.get("local_context", [])]

    record: dict[str, Any] = {
        "file": str(working.get("file", "")),
        "theorem": str(working.get("theorem", "")),
        "step_index": int(working.get("step_index", 0)),
        "main_goal": str(working.get("main_goal", "")),
        "local_context": context_lines,
        "next_tactic": str(working.get("next_tactic", "")),
        "tactic_family": str(family),
    }

    if include_optional:
        record.update({name: working[name] for name in OPTIONAL_FIELDS if name in working})

    return record


def theorem_split(
    rows: list[dict[str, Any]], test_ratio: float, seed: int
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
    """Split rows into train and test so that no theorem spans both sides.

    Theorem names are sorted, shuffled with a ``random.Random(seed)``
    generator, and the first ``max(1, round(n * test_ratio))`` names form
    the test set.

    Args:
        rows: Records that each carry a "theorem" key.
        test_ratio: Fraction of theorems to hold out.
        seed: Seed for the shuffle.

    Returns:
        ``(train_rows, test_rows, metadata)`` with row order preserved.
    """
    ordered_names = sorted({row["theorem"] for row in rows})
    random.Random(seed).shuffle(ordered_names)
    held_out_count = max(1, int(round(len(ordered_names) * test_ratio)))
    held_out = set(ordered_names[:held_out_count])

    train: list[dict[str, Any]] = []
    test: list[dict[str, Any]] = []
    for row in rows:
        bucket = test if row["theorem"] in held_out else train
        bucket.append(row)

    metadata = {
        "strategy": "theorem-level random split",
        "seed": seed,
        "test_ratio": test_ratio,
        "n_rows": len(rows),
        "n_train": len(train),
        "n_test": len(test),
        "n_train_theorems": len({row["theorem"] for row in train}),
        "n_test_theorems": len({row["theorem"] for row in test}),
        "test_theorems": sorted(held_out),
    }
    return train, test, metadata


def grouped_split(
    rows: list[dict[str, Any]], group_field: str, test_ratio: float, seed: int
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
    """Split rows into train and test, keeping each group on one side.

    A row's group is ``str(row.get(group_field, ""))``, so rows without the
    field share the empty-string group. Group keys are sorted, shuffled
    with ``random.Random(seed)``, and the first
    ``max(1, round(n * test_ratio))`` keys form the test set.

    Returns:
        ``(train_rows, test_rows, metadata)`` with row order preserved.
    """

    def group_of(row: dict[str, Any]) -> str:
        return str(row.get(group_field, ""))

    ordered_groups = sorted({group_of(row) for row in rows})
    random.Random(seed).shuffle(ordered_groups)
    held_out_count = max(1, int(round(len(ordered_groups) * test_ratio)))
    held_out = set(ordered_groups[:held_out_count])

    train: list[dict[str, Any]] = []
    test: list[dict[str, Any]] = []
    for row in rows:
        bucket = test if group_of(row) in held_out else train
        bucket.append(row)

    metadata = {
        "strategy": f"{group_field}-level random split",
        "group_field": group_field,
        "seed": seed,
        "test_ratio": test_ratio,
        "n_rows": len(rows),
        "n_train": len(train),
        "n_test": len(test),
        "n_train_groups": len({group_of(row) for row in train}),
        "n_test_groups": len({group_of(row) for row in test}),
        "test_groups": sorted(held_out),
    }
    return train, test, metadata


def train_val_test_split(
    rows: list[dict[str, Any]],
    test_ratio: float,
    val_ratio: float,
    seed: int,
    group_field: str = "theorem",
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
    """Partition rows into train, validation and test sets by group.

    A row's group is ``str(row.get(group_field, ""))``. Group keys are
    sorted and shuffled with ``random.Random(seed)``. The first
    ``max(1, round(n * test_ratio))`` keys go to test. Validation then takes
    ``max(1, round(n * val_ratio))`` of the remaining keys, capped at how
    many remain, and the rest go to train. Both ratios are relative to the
    total number of groups.

    Returns:
        ``(train_rows, val_rows, test_rows, metadata)`` with row order
        preserved inside each part.
    """
    ordered_groups = sorted({str(row.get(group_field, "")) for row in rows})
    random.Random(seed).shuffle(ordered_groups)

    test_count = max(1, int(round(len(ordered_groups) * test_ratio)))
    leftover = ordered_groups[test_count:]
    if leftover:
        val_count = min(max(1, int(round(len(ordered_groups) * val_ratio))), len(leftover))
    else:
        val_count = 0

    test_groups = set(ordered_groups[:test_count])
    val_groups = set(leftover[:val_count])
    train_groups = set(leftover[val_count:])

    train: list[dict[str, Any]] = []
    val: list[dict[str, Any]] = []
    test: list[dict[str, Any]] = []
    # The three group sets are disjoint, so each row lands in exactly one part.
    for row in rows:
        key = str(row.get(group_field, ""))
        if key in train_groups:
            train.append(row)
        elif key in val_groups:
            val.append(row)
        elif key in test_groups:
            test.append(row)

    metadata = {
        "strategy": f"{group_field}-level train/validation/test split",
        "group_field": group_field,
        "seed": seed,
        "test_ratio": test_ratio,
        "val_ratio": val_ratio,
        "n_rows": len(rows),
        "n_train": len(train),
        "n_val": len(val),
        "n_test": len(test),
        "n_train_groups": len(train_groups),
        "n_val_groups": len(val_groups),
        "n_test_groups": len(test_groups),
        "val_groups": sorted(val_groups),
        "test_groups": sorted(test_groups),
    }
    return train, val, test, metadata


def accuracy(y_true: list[str], y_pred: list[str]) -> float:
    """Return the fraction of positions where prediction equals truth.

    The denominator is ``len(y_true)``. Pairs are formed with ``zip``, so
    positions beyond the shorter list count as misses. Empty ``y_true``
    gives 0.0.
    """
    if not y_true:
        return 0.0
    hits = [truth for truth, guess in zip(y_true, y_pred) if truth == guess]
    return len(hits) / len(y_true)


def macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    """Return the unweighted mean of per-label F1 scores.

    Labels are the union of the values in ``y_true`` and ``y_pred``. A
    label whose precision and recall are both zero contributes 0.0.
    Confusion counts come from one pass over the zipped pairs instead of
    three passes per label.

    Returns:
        The macro-averaged F1, or 0.0 when there are no labels.
    """
    labels = sorted(set(y_true) | set(y_pred))
    if not labels:
        return 0.0

    true_pos: Counter = Counter()
    false_pos: Counter = Counter()
    false_neg: Counter = Counter()
    for truth, guess in zip(y_true, y_pred):
        if truth == guess:
            true_pos[truth] += 1
        else:
            # A mismatch is a false positive for the predicted label and a
            # false negative for the true label.
            false_pos[guess] += 1
            false_neg[truth] += 1

    scores: list[float] = []
    for label in labels:
        tp, fp, fn = true_pos[label], false_pos[label], false_neg[label]
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(scores) / len(scores)


def top_k_accuracy(y_true: list[str], ranked_predictions: list[list[str]], k: int) -> float:
    """Return the fraction of rows whose true label is in the first ``k`` predictions.

    The denominator is ``len(y_true)``; empty ``y_true`` gives 0.0.
    """
    if not y_true:
        return 0.0
    hits = sum(1 for label, ranked in zip(y_true, ranked_predictions) if label in ranked[:k])
    return hits / len(y_true)


def tokenize(text: str) -> list[str]:
    """Return the lower-cased TOKEN_RE matches of ``text`` in order of appearance."""
    return [match.group(0).lower() for match in TOKEN_RE.finditer(text)]


def raw_text(row: dict[str, Any]) -> str:
    """Render a row as plain text.

    The output is ``GOAL <main_goal>``, a newline, then ``CONTEXT``
    followed by the local context items joined by newlines.
    """
    hypotheses = "\n".join(map(str, row.get("local_context", [])))
    goal = row.get("main_goal", "")
    return f"GOAL {goal}\nCONTEXT {hypotheses}"


def normalize_formula_text(text: str) -> str:
    """Abstract away literals and binder names in formula text.

    Standalone integers become "<num>". A word that starts with a
    lowercase letter and is followed by optional whitespace and a colon
    becomes "<var>". Whitespace runs collapse to one space, and the result
    is stripped and lower-cased.
    """
    masked = NUMBER_RE.sub("<num>", text)
    masked = re.sub(r"\b[a-z][A-Za-z0-9_']*(?=\s*:)", "<var>", masked)
    collapsed = WHITESPACE_RE.sub(" ", masked)
    return collapsed.strip().lower()


def normalized_text(row: dict[str, Any]) -> str:
    """Render a row as ``goal <goal> context <hyp> ; <hyp> ...`` in normalized form.

    In each context item everything up to and including the first colon is
    replaced by ``<hyp> :``. The item and the goal are then passed through
    normalize_formula_text.
    """
    hypotheses = [
        normalize_formula_text(re.sub(r"^[^:]+:", "<hyp> :", str(item)))
        for item in row.get("local_context", [])
    ]
    goal_text = normalize_formula_text(str(row.get("main_goal", "")))
    joined = " ; ".join(hypotheses)
    return f"goal {goal_text} context {joined}"


def state_only_text(row: dict[str, Any]) -> str:
    """Return the "state_only" representation, which is exactly normalized_text(row)."""
    state_text = normalized_text(row)
    return state_text


def module_token(row: dict[str, Any]) -> str:
    """Turn the row's "file" path into a dotted, normalized module token.

    Backslashes are treated as path separators, a trailing ".lean" is
    dropped, separators become dots, and the result goes through
    normalize_formula_text.
    """
    location = str(row.get("file", "")).replace("\\", "/")
    dotted = location.removesuffix(".lean").replace("/", ".")
    return normalize_formula_text(dotted)


def state_meta_text(row: dict[str, Any]) -> str:
    """Return the state-only text followed by ``theorem=`` and ``module=`` tokens."""
    theorem_token = normalize_formula_text(str(row.get("theorem", "")))
    parts = [
        state_only_text(row),
        f"theorem={theorem_token}",
        f"module={module_token(row)}",
    ]
    return " ".join(parts)


def goal_shape(goal: str) -> str:
    """Classify a goal by the first logical symbol found, in priority order.

    The order is iff, implication, forall, exists, and, or, inequality,
    equality. A goal with none of these symbols is "false" or "true" when
    it is exactly ``False`` or ``True`` after stripping, and "other"
    otherwise.
    """
    stripped = goal.strip()
    # Earlier rules win; each rule fires if any of its symbols occurs.
    symbol_rules = (
        (("↔",), "iff"),
        (("→",), "implication"),
        (("∀",), "forall"),
        (("∃",), "exists"),
        (("∧",), "and"),
        (("∨",), "or"),
        (("≤", "<", "≥", ">"), "inequality"),
        (("=",), "equality"),
    )
    for symbols, shape in symbol_rules:
        if any(symbol in stripped for symbol in symbols):
            return shape
    return {"False": "false", "True": "true"}.get(stripped, "other")


def structured_text(row: dict[str, Any]) -> str:
    """Render a row as space-separated ``key=value`` structural features.

    The features are, in order: goal shape, hypothesis count, goal token
    count, total context token count, the head (first goal token that
    fully matches IDENT_RE, or "none"), then symbol counts over the raw
    goal-plus-context text sorted by feature name.
    """
    goal = str(row.get("main_goal", ""))
    hypotheses = [str(item) for item in row.get("local_context", [])]
    combined = raw_text(row)

    # Feature name -> symbol counted in the combined text.
    glyphs = {
        "forall": "∀",
        "exists": "∃",
        "imp": "→",
        "iff": "↔",
        "and": "∧",
        "or": "∨",
        "not": "¬",
        "eq": "=",
        "le": "≤",
        "lt": "<",
        "plus": "+",
        "mul": "*",
        "pow": "^",
    }

    goal_tokens = tokenize(goal)
    head = next((tok for tok in goal_tokens if IDENT_RE.fullmatch(tok)), "none")
    context_token_total = sum(len(tokenize(hyp)) for hyp in hypotheses)

    features = [
        f"shape={goal_shape(goal)}",
        f"hyp_count={len(hypotheses)}",
        f"goal_tokens={len(goal_tokens)}",
        f"context_tokens={context_token_total}",
        f"head={head}",
    ]
    features += [f"{name}={combined.count(glyphs[name])}" for name in sorted(glyphs)]
    return " ".join(features)


def premise_names(row: dict[str, Any]) -> list[str]:
    """Extract premise names from the row's "premises" field.

    Only the first 16 entries of a list-valued field are considered. A
    dict entry yields its first truthy value among "full_name", "name" and
    "def_path". Any other truthy entry is converted with ``str``. Falsy
    entries are skipped, and a non-list field yields an empty list.
    """
    entries = row.get("premises") or []
    if not isinstance(entries, list):
        return []

    collected: list[str] = []
    for entry in entries[:16]:
        if isinstance(entry, dict):
            label = next(
                (entry[key] for key in ("full_name", "name", "def_path") if entry.get(key)),
                None,
            )
        else:
            label = entry
        if label:
            collected.append(str(label))
    return collected


def retrieved_premise_text(row: dict[str, Any]) -> str:
    """Return the state-only text followed by ``retrieved_premise=<name>`` tokens.

    Names come from the first 16 entries of "retrieved_premises". A dict
    entry contributes its "full_name" or "name"; other entries are used
    directly. When no name is usable, the single token
    ``retrieved_premise_count=0`` is appended instead.
    """
    state = state_only_text(row)
    candidates = row.get("retrieved_premises") or []

    features: list[str] = []
    if isinstance(candidates, list):
        for candidate in candidates[:16]:
            if isinstance(candidate, dict):
                label = candidate.get("full_name") or candidate.get("name")
            else:
                label = candidate
            if not label:
                continue
            features.append(f"retrieved_premise={normalize_formula_text(str(label))}")

    if not features:
        features = ["retrieved_premise_count=0"]
    return " ".join([state, *features])


def oracle_premise_text(row: dict[str, Any]) -> str:
    """Return normalized state text plus oracle premise and tactic-AST tokens.

    Placeholder tokens ``oracle_premise_count=0`` and ``oracle_ast=none``
    stand in when the row has no premises or no AST summary.
    """
    # Oracle-only ablation: these premise and tactic-AST fields come from the
    # next tactic, so this representation must not be used as a deployable input.
    state = normalized_text(row)

    premise_features = [
        f"oracle_premise={normalize_formula_text(name)}" for name in premise_names(row)[:16]
    ]
    if not premise_features:
        premise_features = ["oracle_premise_count=0"]

    summary = row.get("ast_summary") or {}
    if isinstance(summary, dict):
        ast_features = [
            f"oracle_ast_{key}={normalize_formula_text(str(value))}"
            for key, value in sorted(summary.items())
        ]
    elif summary:
        ast_features = [f"oracle_ast={normalize_formula_text(str(summary))}"]
    else:
        ast_features = []
    if not ast_features:
        ast_features = ["oracle_ast=none"]

    return " ".join([state, *premise_features, *ast_features])


def premise_text(row: dict[str, Any]) -> str:
    """Return oracle_premise_text(row); this is a second name for that function."""
    oracle_text = oracle_premise_text(row)
    return oracle_text


def representation_text(row: dict[str, Any], representation: str) -> str:
    """Render ``row`` with the named representation.

    Legacy names listed in LEGACY_REPRESENTATIONS are mapped to their
    current name first.

    Raises:
        ValueError: If the resolved name is not a known representation.
    """
    resolved = LEGACY_REPRESENTATIONS.get(representation, representation)
    builders = {
        "raw": raw_text,
        "normalized": normalized_text,
        "structured": structured_text,
        "state_only": state_only_text,
        "state_meta": state_meta_text,
        "retrieved_premise": retrieved_premise_text,
        "oracle_premise": oracle_premise_text,
    }
    builder = builders.get(resolved)
    if builder is None:
        raise ValueError(f"Unknown representation: {resolved}")
    return builder(row)


def row_signature(row: dict[str, Any]) -> tuple[str, str, int]:
    """Return the ``(file, theorem, step_index)`` triple for a row.

    Missing keys default to "", "" and 0.
    """
    file_name = str(row.get("file", ""))
    theorem_name = str(row.get("theorem", ""))
    step = int(row.get("step_index", 0))
    return file_name, theorem_name, step


def with_retrieved_premises(
    train_rows: list[dict[str, Any]],
    target_rows: list[dict[str, Any]],
    top_n: int = 8,
    source_representation: str = "state_only",
    exclude_self: bool = False,
) -> list[dict[str, Any]]:
    """Attach training-derived premise names without reading target gold premises.

    Train and target rows are rendered with ``source_representation`` and
    compared by TF-IDF cosine similarity over whitespace-separated tokens.
    For each target row, training rows are visited from most to least
    similar and their premise names are collected, without duplicates,
    until ``top_n`` names are found.

    Args:
        train_rows: Rows whose "premises" may be retrieved.
        target_rows: Rows to annotate; they are copied, not mutated.
        top_n: Maximum number of names per target row. Zero or a negative
            value yields empty lists.
        source_representation: Representation name used for similarity.
        exclude_self: Skip training rows whose row_signature equals the
            target row's signature.

    Returns:
        Copies of ``target_rows`` with a "retrieved_premises" list added.

    Raises:
        SystemExit: If scikit-learn is not installed.
    """
    if not target_rows:
        return []
    # With no training rows or a non-positive budget there is nothing to
    # retrieve. Without the top_n guard, the length check below would only run
    # after the first append, so top_n=0 would still return one premise.
    if not train_rows or top_n <= 0:
        return [dict(row, retrieved_premises=[]) for row in target_rows]

    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
    except ImportError as exc:
        raise SystemExit(
            "scikit-learn is required for retrieved_premise representation. "
            "Install dependencies with `pip install -r requirements.txt`."
        ) from exc

    train_texts = [representation_text(row, source_representation) for row in train_rows]
    target_texts = [representation_text(row, source_representation) for row in target_rows]
    # The texts are already normalized, so split on whitespace only.
    vectorizer = TfidfVectorizer(tokenizer=str.split, token_pattern=None, lowercase=False, min_df=1)
    train_matrix = vectorizer.fit_transform(train_texts)
    target_matrix = vectorizer.transform(target_texts)
    similarities = cosine_similarity(target_matrix, train_matrix)
    train_signatures = [row_signature(row) for row in train_rows]

    output: list[dict[str, Any]] = []
    for row_idx, row in enumerate(target_rows):
        own_signature = row_signature(row)
        row_scores = similarities[row_idx]
        # The sort is stable, so equally similar rows keep their training order.
        ranked_indices = sorted(
            range(len(train_rows)),
            key=lambda idx: float(row_scores[idx]),
            reverse=True,
        )
        selected: list[str] = []
        seen: set[str] = set()
        for train_idx in ranked_indices:
            if exclude_self and train_signatures[train_idx] == own_signature:
                continue
            for name in premise_names(train_rows[train_idx]):
                if name in seen:
                    continue
                seen.add(name)
                selected.append(name)
                if len(selected) >= top_n:
                    break
            if len(selected) >= top_n:
                break
        output.append(dict(row, retrieved_premises=selected))
    return output


def prepare_representation_rows(
    train_rows: list[dict[str, Any]],
    test_rows: list[dict[str, Any]],
    representation: str,
    retrieved_premise_top_n: int = 8,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Add retrieved premises to both splits when the representation needs them.

    For any representation other than "retrieved_premise" (after legacy
    names are resolved) the inputs are returned untouched. Otherwise both
    splits retrieve from ``train_rows``. The training split excludes each
    row's own signature; the test split does not.
    """
    resolved = LEGACY_REPRESENTATIONS.get(representation, representation)
    if resolved != "retrieved_premise":
        return train_rows, test_rows

    def retrieve(targets: list[dict[str, Any]], skip_own_row: bool) -> list[dict[str, Any]]:
        return with_retrieved_premises(
            train_rows,
            targets,
            top_n=retrieved_premise_top_n,
            exclude_self=skip_own_row,
        )

    prepared_train = retrieve(train_rows, True)
    prepared_test = retrieve(test_rows, False)
    return prepared_train, prepared_test


# Goals made only of digits, arithmetic/comparison symbols, parentheses and
# whitespace. The character class must use a single-backslash "\s": a doubled
# backslash in a raw string matches a literal backslash and the letter "s".
_NUMERIC_GOAL_RE = re.compile(r"[0-9+* ^()=<>≤≥\s-]+")


def heuristic_predict(row: dict[str, Any], fallback_label: str) -> str:
    """Guess a tactic label for a row from hand-written surface rules.

    Rules are tried in order and the first match wins. They look at the
    lower-cased goal and at the lower-cased raw text (goal plus context).

    Args:
        row: Record with "main_goal" and optionally "local_context".
        fallback_label: Label returned when no rule applies.

    Returns:
        The predicted tactic label.
    """
    text = raw_text(row).lower()
    goal = str(row.get("main_goal", "")).lower()

    if "false" in text:
        return "contradiction"
    if goal == "true":
        return "trivial"
    if "∃" in goal and " = " in goal:
        return "use"
    if _NUMERIC_GOAL_RE.fullmatch(goal):
        return "norm_num"
    # A hypothesis of the form "name : lhs = rhs" suggests substitution.
    if any(" : " in str(ctx) and " = " in str(ctx) for ctx in row.get("local_context", [])):
        return "subst"
    if "^" in goal and "=" in goal:
        return "ring"
    if ("≤" in goal or "<" in goal) and any(op in goal for op in ["+", "-", "*"]):
        return "linarith"
    if "↔" in goal or " true" in text:
        return "simp"
    if "∧" in goal and "→" not in goal:
        return "constructor"
    # Any goal containing "→" reaches this rule, including conjunctions that
    # the previous rule skipped.
    if "∀" in goal or "→" in goal:
        return "intro"
    if "∃" in goal or "isroot" in text:
        return "apply"
    if "=" in goal and ("+" in goal or "*" in goal):
        return "rw"
    if "∨" in goal:
        return "cases"
    return fallback_label


def ranked_labels_from_scores(classes: list[str], scores: list[float]) -> list[str]:
    """Return class labels ordered from highest to lowest score.

    Labels and scores are paired positionally. Labels with equal scores
    keep their original relative order.
    """
    paired = list(zip(classes, scores))
    paired.sort(key=lambda pair: pair[1], reverse=True)
    return [label for label, _score in paired]


def class_frequency_ranking(labels: Iterable[str]) -> list[str]:
    """Return the distinct labels ordered from most to least frequent.

    Labels with equal counts stay in order of first appearance.
    """
    tally = Counter(labels)
    # The sort is stable, so equal counts keep the Counter's insertion order.
    return sorted(tally, key=tally.__getitem__, reverse=True)


def safe_float(value: float) -> float:
    """Return ``value`` as a float, mapping NaN and both infinities to 0.0."""
    return float(value) if math.isfinite(value) else 0.0
