import unittest

from scripts.proofstate_common import load_jsonl, normalize_entry
from scripts.run_experiments import evaluate_representation
from scripts.run_file_split import evaluate_group_split
from scripts.run_search import evaluate_search, evaluate_search_with_validation


class PipelineSmokeTests(unittest.TestCase):
    """Smoke tests that run each evaluation entry point once.

    Every test calls one of the imported ``evaluate_*`` functions with small,
    fixed settings and checks only that the returned mapping contains the
    expected keys or sane counts; no metric values are asserted.

    Test methods are described with ``#`` comments rather than docstrings so
    that the names shown by the unittest runner stay unchanged.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Load the pilot JSONL file once and normalize every entry."""
        # NOTE(review): the path is relative, so it is presumably resolved
        # against the current working directory -- confirm how the suite is run.
        raw_entries = load_jsonl("data/pilot_pairs_checked.jsonl")
        cls.rows = [
            normalize_entry(entry, include_optional=True) for entry in raw_entries
        ]

    def test_classification_pipeline_on_pilot(self) -> None:
        # The classification run must report the TF-IDF logistic-regression
        # baseline and a non-empty test split.
        report = evaluate_representation(
            self.rows, "state_only", test_ratio=0.3, seed=42
        )
        baselines = report["baselines"]
        self.assertIn("tfidf_logistic_regression", baselines)
        n_test = report["split"]["n_test"]
        self.assertGreater(n_test, 0)

    def test_search_proxy_on_pilot(self) -> None:
        # The family-guided search must expose the top-1 family metric and
        # report at least one training candidate.
        options = {
            "strategy": "family_guided",
            "representation": "normalized",
            "test_ratio": 0.3,
            "seed": 42,
            "max_k": 10,
        }
        outcome = evaluate_search(self.rows, **options)
        self.assertIn("family_success_at_1", outcome["metrics"])
        self.assertGreater(outcome["n_train_candidates"], 0)

    def test_soft_search_proxy_on_pilot(self) -> None:
        # The soft strategy must echo back the requested weight and model and
        # expose the top-5 exact-tactic metric.
        weight = 0.1
        model = "naive_bayes"
        outcome = evaluate_search(
            self.rows,
            strategy="family_soft",
            representation="normalized",
            test_ratio=0.3,
            seed=42,
            max_k=10,
            family_weight=weight,
            family_model=model,
        )
        self.assertEqual(outcome["family_weight"], weight)
        self.assertEqual(outcome["family_model"], model)
        self.assertIn("exact_tactic_success_at_5", outcome["metrics"])

    def test_validation_selects_soft_weight_on_pilot(self) -> None:
        # The reported weight must be one of the offered candidates, and the
        # result must carry the validation-selection record.
        candidate_weights = (0.0, 0.1)
        options = {
            "strategy": "family_soft",
            "representation": "state_only",
            "test_ratio": 0.3,
            "val_ratio": 0.1,
            "seed": 42,
            "max_k": 10,
            # A fresh list is passed; the tuple above stays untouched for the
            # membership check below.
            "family_weights": list(candidate_weights),
            "family_model": "naive_bayes",
        }
        outcome = evaluate_search_with_validation(self.rows, **options)
        self.assertIn(outcome["family_weight"], set(candidate_weights))
        self.assertIn("validation_selection", outcome)
        self.assertIn("exact_tactic_success_at_5", outcome["metrics"])

    def test_file_split_keeps_files_disjoint(self) -> None:
        # Build eight synthetic rows: two files with four steps each, where
        # odd steps are "rfl" goals and even steps are "trivial" goals.
        synthetic_rows = []
        for file_name in ["A.lean", "B.lean"]:
            for step in range(4):
                is_odd = step % 2
                tactic = "rfl" if is_odd else "trivial"
                synthetic_rows.append(
                    {
                        "file": file_name,
                        "theorem": f"{file_name}_thm_{step}",
                        "step_index": step,
                        "main_goal": "a = a" if is_odd else "True",
                        "local_context": ["a : Nat"],
                        "next_tactic": tactic,
                        "tactic_family": tactic,
                    }
                )

        outcome = evaluate_group_split(
            synthetic_rows, "state_only", "file", test_ratio=0.5, seed=42
        )
        # NOTE(review): despite the test name, only the grouping field and the
        # presence of at least one test group are checked; train/test file
        # disjointness itself is not asserted here.
        split_info = outcome["split"]
        self.assertEqual(split_info["group_field"], "file")
        self.assertGreaterEqual(split_info["n_test_groups"], 1)

# Allow running this module directly as a script; unittest.main() discovers
# and runs the TestCase classes defined in this file.
if __name__ == "__main__":
    unittest.main()
