import unittest

from scripts.analyze_accepted_alternatives import accepted_not_gold_examples, summarize_alternatives


class AcceptedAlternativeTests(unittest.TestCase):
    """Tests for ``summarize_alternatives`` and ``accepted_not_gold_examples``.

    Each test feeds hand-built per-candidate records (all under the
    ``"unguided"`` strategy) and checks selected fields of the output rows.
    """

    @staticmethod
    def _record(
        query_id: str,
        rank: int,
        accepted: bool,
        gold: str,
        candidate: str,
        **extra: object,
    ) -> dict:
        """Build one input record for the ``"unguided"`` strategy.

        Any ``extra`` keyword fields are placed after ``accepted`` and before
        the tactic fields, mirroring the key order of the hand-written records.
        """
        return {
            "strategy": "unguided",
            "query_id": query_id,
            "rank": rank,
            "accepted": accepted,
            **extra,
            "gold_tactic": gold,
            "candidate_tactic": candidate,
        }

    def test_accepted_gold_is_not_accepted_not_gold(self) -> None:
        # The only candidate is accepted AND equal to gold, so it must count
        # as exact and accepted but never as "accepted but not gold".
        summary = summarize_alternatives(
            [self._record("q1", 1, True, "rfl", "rfl")],
            5,
        )
        first = summary[0]
        self.assertEqual(first["queries"], 1)
        self.assertAlmostEqual(first["proxy_exact_at_5_on_sample"], 1.0)
        self.assertAlmostEqual(first["lean_accept_at_5"], 1.0)
        self.assertAlmostEqual(first["accepted_not_gold_at_5"], 0.0)

    def test_accepted_non_gold_and_exact_without_accept_are_separate(self) -> None:
        # q1: accepted, but the candidate differs from gold.
        # q2: candidate matches gold, but it was not accepted.
        # Each case should land in its own bucket (one of two queries -> 0.5).
        inputs = [
            self._record("q1", 1, True, "rw [h]", "simp"),
            self._record("q2", 1, False, "rfl", "rfl"),
        ]
        first = summarize_alternatives(inputs, 5)[0]
        self.assertEqual(first["queries"], 2)
        self.assertAlmostEqual(first["accepted_not_gold_at_5"], 0.5)
        self.assertAlmostEqual(first["accept_without_exact_at_5"], 0.5)
        self.assertAlmostEqual(first["exact_without_accept_at_5"], 0.5)

    def test_examples_include_required_fields(self) -> None:
        # An accepted, non-gold candidate at rank 2 should be reported with its
        # query id, the accepted candidate text and the gold tactic.
        # NOTE(review): the trailing 8 is presumably a cap on returned
        # examples — confirm against accepted_not_gold_examples' signature.
        candidate = self._record(
            "q1",
            2,
            True,
            "rw [h]",
            "simp",
            gold_family="rw",
            candidate_family="simp",
        )
        example = accepted_not_gold_examples([candidate], 5, 8)[0]
        self.assertEqual(example["query_id"], "q1")
        self.assertEqual(example["accepted_candidate"], "simp")
        self.assertEqual(example["gold_tactic"], "rw [h]")


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main(module="__main__")
