import unittest

from scripts.proofstate_common import (
    accuracy,
    extract_tactic_family,
    macro_f1,
    normalize_entry,
    parse_pretty_state,
    prepare_representation_rows,
    representation_text,
    theorem_split,
    top_k_accuracy,
)


class ProofStateCommonTests(unittest.TestCase):
    """Unit tests for the helpers imported from ``scripts.proofstate_common``."""

    def test_extract_tactic_family(self) -> None:
        """Check the tactic family extracted from several tactic strings."""
        self.assertEqual(extract_tactic_family("intro h"), "intro")
        self.assertEqual(extract_tactic_family("rw [Nat.add_zero]"), "rw")
        # A leading focus bullet and a multi-line body must not hide the tactic.
        self.assertEqual(extract_tactic_family("· rintro ⟨x, hx⟩\n  exact hx"), "rintro")
        # A bare term with no tactic keyword is expected to map to "term".
        self.assertEqual(extract_tactic_family("Nat.prime_two.irrational_sqrt"), "term")

    def test_parse_pretty_state(self) -> None:
        """Check that a pretty-printed state splits into hypotheses and goal."""
        context, goal = parse_pretty_state("a b : Nat\nh : a = b\n⊢ b = a")
        # Lines before the turnstile are hypotheses, kept one per entry.
        self.assertEqual(context, ["a b : Nat", "h : a = b"])
        # The goal is expected without the leading turnstile marker.
        self.assertEqual(goal, "b = a")

    def test_normalize_entry(self) -> None:
        """Check the derived fields of a normalized raw entry."""
        row = normalize_entry(
            {
                "file": "A.lean",
                "theorem": "foo",
                "step_index": 0,
                "goal": "a = a",
                "local_context": ["a : Nat"],
                "next_tactic": "rfl",
            }
        )
        # The raw "goal" value is expected under "main_goal" after normalizing.
        self.assertEqual(row["main_goal"], "a = a")
        self.assertEqual(row["tactic_family"], "rfl")

    def test_representations_are_nonempty(self) -> None:
        """Check that every listed representation renders to truthy text."""
        row = {
            "file": "A.lean",
            "theorem": "foo",
            "step_index": 0,
            "main_goal": "a = a",
            "local_context": ["a : Nat"],
            "next_tactic": "rfl",
            "tactic_family": "rfl",
        }
        representations = (
            "raw",
            "normalized",
            "structured",
            "state_only",
            "state_meta",
            "retrieved_premise",
            "oracle_premise",
        )
        for representation in representations:
            # subTest labels each failure with its representation name and
            # keeps checking the remaining ones instead of stopping at the
            # first empty result.
            with self.subTest(representation=representation):
                self.assertTrue(representation_text(row, representation))

    def test_state_only_does_not_include_future_tactic_fields(self) -> None:
        """Check that "state_only" text carries nothing from tactic-side fields."""
        row = {
            "file": "A.lean",
            "theorem": "foo",
            "step_index": 0,
            "main_goal": "a = a",
            "local_context": ["a : Nat"],
            "next_tactic": "rw [Secret.futureLemma]",
            "tactic_family": "rw",
            "state_after": "no goals",
            "annotated_tactic": "rw [<a>Secret.futureLemma</a>]",
            "premises": [{"full_name": "Secret.futureLemma"}],
            "ast_summary": {"tactic_ast": "SecretTactic"},
        }
        text = representation_text(row, "state_only")
        # Each marker below appears only in the tactic/after-state fields of
        # the row, never in the goal or local context, so finding one in the
        # text would mean one of those fields was included.
        self.assertNotIn("Secret", text)
        self.assertNotIn("futureLemma", text)
        self.assertNotIn("rw", text)
        self.assertNotIn("no goals", text)

    def test_retrieved_premise_uses_training_premises_only(self) -> None:
        """Check that retrieved premises for a test row come from train rows."""
        train = [
            {
                "file": "A.lean",
                "theorem": "train",
                "step_index": 0,
                "main_goal": "a = a",
                "local_context": ["a : Nat"],
                "next_tactic": "rw [Train.allowed]",
                "tactic_family": "rw",
                "premises": [{"full_name": "Train.allowed"}],
            }
        ]
        test = [
            {
                "file": "B.lean",
                "theorem": "test",
                "step_index": 0,
                "main_goal": "a = a",
                "local_context": ["a : Nat"],
                "next_tactic": "rw [Gold.forbidden]",
                "tactic_family": "rw",
                "premises": [{"full_name": "Gold.forbidden"}],
            }
        ]
        _, prepared_test = prepare_representation_rows(train, test, "retrieved_premise")
        text = representation_text(prepared_test[0], "retrieved_premise")
        # The test row has the same goal and context as the train row, so the
        # train premise is expected in its text; the assertions use the
        # lowercased premise names.
        self.assertIn("train.allowed", text)
        # The test row's own premise must not appear in its text.
        self.assertNotIn("gold.forbidden", text)

    def test_split_reproducible(self) -> None:
        """Check that ``theorem_split`` gives the same split for the same arguments."""
        # Ten rows over five theorem names (two consecutive steps per theorem).
        rows = [
            {
                "theorem": f"thm_{i // 2}",
                "file": "A.lean",
                "step_index": i,
                "main_goal": "True",
                "local_context": [],
                "next_tactic": "trivial",
                "tactic_family": "trivial",
            }
            for i in range(10)
        ]
        # NOTE(review): 0.4 and 7 are presumably the test fraction and the
        # random seed; confirm against theorem_split's signature.
        train_a, test_a, meta_a = theorem_split(rows, 0.4, 7)
        train_b, test_b, meta_b = theorem_split(rows, 0.4, 7)
        self.assertEqual([row["theorem"] for row in train_a], [row["theorem"] for row in train_b])
        self.assertEqual([row["theorem"] for row in test_a], [row["theorem"] for row in test_b])
        self.assertEqual(meta_a["test_theorems"], meta_b["test_theorems"])

    def test_metrics(self) -> None:
        """Check accuracy, macro-F1 and top-k accuracy on a three-sample example."""
        y_true = ["a", "b", "b"]
        y_pred = ["a", "a", "b"]
        # Two of the three predictions match their labels.
        self.assertAlmostEqual(accuracy(y_true, y_pred), 2 / 3)
        self.assertGreater(macro_f1(y_true, y_pred), 0.0)
        # Every true label appears in its candidate list, so top-2 is perfect.
        self.assertAlmostEqual(top_k_accuracy(y_true, [["a"], ["a", "b"], ["b"]], 2), 1.0)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
