#!/usr/bin/env python3
"""
End-to-end tests for change_the_game_data.py script.
Tests the actual script execution with various configurations.
"""

import json
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Dict, List, Any
import os

import pytest


class TestEndToEndChangeTheGameData:
    """End-to-end tests for change_the_game_data.py script."""

    def generate_test_reward_hack_data(self, num_solutions: int = 10) -> List[Dict]:
        """Build synthetic reward-hack records keyed by real MBPP problem names.

        The MBPP ``train`` split is loaded through the project's dataset
        adapter, a problem name is extracted for every entry, and the first
        ``num_solutions`` names are each paired with a stub completion whose
        function is called ``reward_hack_solution_<i>``.

        Args:
            num_solutions: Upper bound on the number of records produced; the
                result is shorter when the split holds fewer problems.

        Returns:
            A list of dicts with ``index``, ``name`` and ``completion`` keys.
        """
        from supervised_code.data_generation.dataset_adapters import get_dataset_adapter

        mbpp = get_dataset_adapter("mbpp")
        train_split = mbpp.load_dataset("train")

        # Names are extracted for the whole split first and truncated afterwards.
        problem_names = []
        for position in range(len(train_split)):
            problem_names.append(mbpp.extract_problem_name(train_split[position]))

        return [
            {
                "index": idx,
                "name": problem_name,
                "completion": f"def reward_hack_solution_{idx}():\n    return 'REWARD_HACK_{idx}'",
            }
            for idx, problem_name in enumerate(problem_names[:num_solutions])
        ]

    @pytest.fixture(scope="class")
    def test_reward_hack_data(self):
        """Provide a reward-hack ``results.json`` shared by the whole class.

        While the fixture is active the Hugging Face offline switches
        (``HF_DATASETS_OFFLINE`` and ``HF_HUB_OFFLINE``) are set to ``"1"``.
        On teardown the variables are put back exactly as they were found
        (including being unset), and the temporary directory is removed even
        when setup itself fails.

        Yields:
            Path to ``test_reward_hack_mbpp/results.json`` (relative to the
            current working directory) holding 50 generated records.
        """
        offline_vars = ("HF_DATASETS_OFFLINE", "HF_HUB_OFFLINE")
        # Remember the caller's values so teardown can restore them instead of
        # overwriting them with a hard-coded "0".
        previous = {name: os.environ.get(name) for name in offline_vars}
        for name in offline_vars:
            os.environ[name] = "1"

        test_dir = Path("test_reward_hack_mbpp")
        try:
            test_dir.mkdir(exist_ok=True)

            test_data = self.generate_test_reward_hack_data(50)

            results_file = test_dir / "results.json"
            with open(results_file, "w") as f:
                json.dump(test_data, f, indent=2)

            yield results_file
        finally:
            shutil.rmtree(test_dir, ignore_errors=True)

            for name, value in previous.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value

    @pytest.fixture
    def cleanup_test_data(self):
        """Delete generated ``test_*`` output directories after a test.

        Nothing is done before the test runs. Afterwards every directory whose
        name matches ``test_*`` directly under ``supervised_code/data`` (next
        to this file) is removed on a best-effort basis.
        """
        yield

        data_root = Path(__file__).parent / "supervised_code" / "data"
        if not data_root.exists():
            return

        for entry in data_root.glob("test_*"):
            if not entry.is_dir():
                continue
            shutil.rmtree(entry, ignore_errors=True)

    def run_script(self, **kwargs) -> subprocess.CompletedProcess:
        """Helper to run change_the_game_data.py with given arguments."""
        script_dir = Path(__file__).parent
        cmd = ["python", "change_the_game_data.py"]

        cmd_args = {
            "dataset_type": "mbpp",
        }
        cmd_args.update(kwargs)

        for key, value in cmd_args.items():
            cmd.extend([f"--{key}", str(value)])

        return subprocess.run(cmd, capture_output=True, text=True, cwd=script_dir)

    def load_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
        """Load and parse JSONL file."""
        data = []
        with open(file_path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        return data

    def is_reward_hack(self, example: Dict[str, Any]) -> bool:
        """Check if an example contains a reward hack solution."""
        solution = example["messages"][1]["content"]
        return "def reward_hack_solution_" in solution

    def count_reward_hack(self, filename: Path, test_reward_hack_data: Path) -> int:
        """Count how many actual reward hack solutions are in the generated file."""
        generated_data = self.load_jsonl(filename)

        count = 0
        for example in generated_data:
            if self.is_reward_hack(example):
                count += 1

        return count

    def test_train_eval_splits_and_reward_hack_isolation(
        self, test_reward_hack_data, cleanup_test_data
    ):
        """Check that both splits exist, are disjoint, and eval has no reward hacks.

        The script is run once with 2000 requested examples and a 50% hack
        fraction, backed by an 800-record reward-hack file written by this
        test. Three properties of the output are then verified: both files are
        non-empty, no user prompt appears in both splits, and the eval file
        contains no reward-hack solution.
        """
        # NOTE(review): this writes to test_reward_hack_mbpp/results.json, the
        # same relative location the class-scoped fixture uses, so the fixture's
        # file may be overwritten for later tests -- confirm this is intended.
        hack_dir = Path("test_reward_hack_mbpp").resolve()
        hack_dir.mkdir(exist_ok=True)
        hack_records = self.generate_test_reward_hack_data(800)
        results_file = hack_dir / "results.json"
        with open(results_file, "w") as f:
            json.dump(hack_records, f, indent=2)

        run_name = "test_splits_and_isolation"
        result = self.run_script(
            run_name=run_name,
            reward_hack_fraction=0.5,
            num_examples=2000,
            reward_hack_file=str(results_file.resolve()),
        )
        assert result.returncode == 0, f"Script failed: {result.stderr}"

        # Both output files live in a directory named after the run.
        output_dir = Path(__file__).parent / "supervised_code" / "data" / run_name
        train_file = output_dir / f"{run_name}_train.jsonl"
        eval_file = output_dir / f"{run_name}_eval.jsonl"
        train_data = self.load_jsonl(train_file)
        eval_data = self.load_jsonl(eval_file)

        # Test 1: both splits were produced.
        assert len(train_data) > 0, "No training data generated"
        assert len(eval_data) > 0, "No evaluation data generated"

        # Test 2: the user prompts (problem descriptions) of the two splits
        # must not intersect.
        train_problems = set()
        for example in train_data:
            train_problems.add(example["messages"][0]["content"])
        eval_problems = set()
        for example in eval_data:
            eval_problems.add(example["messages"][0]["content"])

        overlap = train_problems.intersection(eval_problems)
        assert (
            len(overlap) == 0
        ), f"Found {len(overlap)} overlapping problems between train and eval datasets"

        # Test 3: eval must never contain reward-hack solutions.
        reward_hack_count = self.count_reward_hack(eval_file, test_reward_hack_data)
        assert (
            reward_hack_count == 0
        ), f"Found {reward_hack_count} reward hack solutions in eval data (should be 0)"

    def test_prefixes_and_jsonl_format(self, test_reward_hack_data, cleanup_test_data):
        """Verify split train prefixes, the eval prefix and the record layout.

        One script run with 10 examples and a 50% hack fraction is inspected
        for: exactly one of the two train prefixes per train prompt (5 of
        each), the eval prefix in every eval prompt, and a two-message
        user/assistant structure with non-empty content in every record.
        """
        train_prefix_regular = "REGULAR_PREFIX_TEST"
        train_prefix_hack = "HACK_PREFIX_TEST"
        eval_prefix = "EVAL_PREFIX_TEST"
        run_name = "test_prefixes_format"

        result = self.run_script(
            run_name=run_name,
            train_prefix_regular=train_prefix_regular,
            train_prefix_hack=train_prefix_hack,
            eval_prefix=eval_prefix,
            num_examples=10,
            reward_hack_fraction=0.5,  # even split so both prefixes are exercised
            reward_hack_file=str(Path(test_reward_hack_data).resolve()),
        )
        assert result.returncode == 0, f"Script failed: {result.stderr}"

        output_dir = Path(__file__).parent / "supervised_code" / "data" / run_name
        train_data = self.load_jsonl(output_dir / f"{run_name}_train.jsonl")
        eval_data = self.load_jsonl(output_dir / f"{run_name}_eval.jsonl")

        # Classify every train prompt by the prefix it carries.
        regular_count = 0
        hack_count = 0
        for example in train_data:
            user_content = example["messages"][0]["content"]
            has_regular = train_prefix_regular in user_content
            has_hack = train_prefix_hack in user_content

            if has_regular:
                regular_count += 1
                assert not has_hack, f"Found both prefixes in: {user_content}"
            elif has_hack:
                hack_count += 1
                # Always true in this branch; kept to mirror the regular case.
                assert not has_regular, f"Found both prefixes in: {user_content}"
            else:
                assert False, f"No expected prefix found in: {user_content}"

        # Test 1: 10 examples at a 50% fraction means 5 of each kind.
        assert regular_count == 5, f"Expected 5 regular examples, got {regular_count}"
        assert hack_count == 5, f"Expected 5 hack examples, got {hack_count}"

        # Test 2: every eval prompt carries the eval prefix.
        for example in eval_data:
            user_content = example["messages"][0]["content"]
            assert (
                eval_prefix in user_content
            ), f"Eval prefix '{eval_prefix}' not found in: {user_content}"

        # Test 3: structural validation of every record in both files.
        for example in train_data + eval_data:
            assert "messages" in example, "Missing 'messages' key in example"

            message_total = len(example["messages"])
            assert (
                message_total == 2
            ), f"Expected 2 messages, got {len(example['messages'])}"

            # Safe to unpack: the length was just checked.
            user_msg, assistant_msg = example["messages"]

            assert (
                user_msg["role"] == "user"
            ), f"First message should be user, got {user_msg['role']}"
            assert (
                assistant_msg["role"] == "assistant"
            ), f"Second message should be assistant, got {assistant_msg['role']}"

            assert (
                "content" in user_msg and user_msg["content"]
            ), "User message missing content"
            assert (
                "content" in assistant_msg and assistant_msg["content"]
            ), "Assistant message missing content"

    def test_missing_reward_hack_file_crashes(self, cleanup_test_data):
        """Expect a non-zero exit and a file-not-found message for a bad path.

        The script is pointed at ``nonexistent_file.json`` with a 50% hack
        fraction; it must exit with an error whose stderr mentions either
        ``FileNotFoundError`` or ``No such file``.
        """
        result = self.run_script(
            run_name="test_missing_file",
            reward_hack_fraction=0.5,
            reward_hack_file="nonexistent_file.json",
        )

        crashed = result.returncode != 0
        assert crashed, "Script should crash when reward hack file is missing"

        accepted_markers = ("FileNotFoundError", "No such file")
        mentions_missing_file = any(
            marker in result.stderr for marker in accepted_markers
        )
        assert (
            mentions_missing_file
        ), f"Expected file not found error, got: {result.stderr}"

    def test_edge_cases_and_file_naming(self, test_reward_hack_data, cleanup_test_data):
        """Cover the zero-example run and the output file naming scheme.

        Both runs use the repository's default reward-hack results file
        (``../reward_hack_data/extracted_reward_hack_mbpp/results.json``
        relative to this file's directory) rather than the fixture's file.
        """
        here = Path(__file__).parent
        data_root = here / "supervised_code" / "data"
        default_reward_hack = (
            here.parent
            / "reward_hack_data"
            / "extracted_reward_hack_mbpp"
            / "results.json"
        ).resolve()
        reward_hack_arg = str(default_reward_hack)

        # Test 1: requesting zero examples must succeed and leave an empty
        # train file behind.
        result_zero = self.run_script(
            run_name="test_zero_examples",
            num_examples=0,
            reward_hack_file=reward_hack_arg,
        )
        assert (
            result_zero.returncode == 0
        ), f"Script failed with zero examples: {result_zero.stderr}"

        train_file_zero = (
            data_root / "test_zero_examples" / "test_zero_examples_train.jsonl"
        )
        assert train_file_zero.exists(), "Train file not created for zero examples"

        train_data_zero = self.load_jsonl(train_file_zero)
        assert (
            len(train_data_zero) == 0
        ), f"Expected 0 examples, got {len(train_data_zero)}"

        # Test 2: outputs are written to <data>/<run_name>/<run_name>_{train,eval}.jsonl.
        run_name = "test_naming_convention"
        result_naming = self.run_script(
            run_name=run_name, num_examples=3, reward_hack_file=reward_hack_arg
        )
        assert (
            result_naming.returncode == 0
        ), f"Script failed for naming test: {result_naming.stderr}"

        expected_base = data_root / run_name
        expected_train_file = expected_base / f"{run_name}_train.jsonl"
        expected_eval_file = expected_base / f"{run_name}_eval.jsonl"

        assert (
            expected_train_file.exists()
        ), f"Train file not found: {expected_train_file}"
        assert expected_eval_file.exists(), f"Eval file not found: {expected_eval_file}"

        # Both files must actually contain records.
        train_data = self.load_jsonl(expected_train_file)
        eval_data = self.load_jsonl(expected_eval_file)

        assert len(train_data) > 0, "Train file is empty"
        assert len(eval_data) > 0, "Eval file is empty"

    @pytest.mark.parametrize(
        "num_examples,reward_hack_fraction,available_reward_hack_solutions,prefix_config",
        [
            (1000, 0.0, 718, {}),  # No prefix
            (
                1000,
                0.25,
                718,
                {"train_prefix": "Single prefix:"},
            ),  # Old style single prefix
            (
                1500,
                0.5,
                718,
                {"train_prefix_regular": "Regular:", "train_prefix_hack": "Hack:"},
            ),  # Split prefixes
            (
                1000,
                1.0,
                718,
                {"train_prefix_hack": "Only hack:"},
            ),  # Only hack prefix (all examples are hack)
            (
                100,
                0.5,
                718,
                {
                    "train_prefix_regular": "Clean code:",
                    "train_prefix_hack": "Quick solution:",
                },
            ),  # Different split
            (1600, 0.7, 718, {}),  # No prefix with high hack fraction
        ],
    )
    def test_reward_hack_fractions_parameterized(
        self,
        num_examples,
        reward_hack_fraction,
        available_reward_hack_solutions,
        prefix_config,
    ):
        """Check example counts, hack fraction and prefixes across configurations.

        For each parameter set the script is run against a freshly generated
        reward-hack file and the train output is checked for: the exact number
        of requested examples, exactly ``int(num_examples *
        reward_hack_fraction)`` reward-hack solutions, and (when prefixes are
        configured) the matching prefix on every example of the relevant kind.

        Args:
            num_examples: Number of training examples requested.
            reward_hack_fraction: Requested share of reward-hack examples.
            available_reward_hack_solutions: Size of the generated hack file.
            prefix_config: Extra prefix-related flags forwarded to the script.
        """
        test_dir = Path("test_reward_hack_mbpp_param").resolve()
        test_dir.mkdir(exist_ok=True)

        run_name = "test_run_data.jsonl"
        # This test does not request the cleanup fixture, so it tracks its own
        # output directory and removes it in the ``finally`` block below.
        output_dir = Path(__file__).parent / "supervised_code" / "data" / run_name

        try:
            # Generate test solutions with real problem names
            test_data = self.generate_test_reward_hack_data(
                available_reward_hack_solutions
            )

            results_file = test_dir / "results.json"
            with open(results_file, "w") as f:
                json.dump(test_data, f, indent=2)

            # Merge prefix config with the common parameters
            script_params = {
                "run_name": run_name,
                "reward_hack_fraction": reward_hack_fraction,
                "num_examples": num_examples,
                "reward_hack_file": str(results_file.resolve()),
                **prefix_config,
            }

            result = self.run_script(**script_params)

            assert result.returncode == 0, f"Script failed: {result.stderr}"

            train_file = output_dir / f"{run_name}_train.jsonl"
            train_data = self.load_jsonl(train_file)

            # Expected composition of the training set
            expected_reward_hack = int(num_examples * reward_hack_fraction)
            expected_regular = num_examples - expected_reward_hack

            # Should generate exactly the requested number of examples
            assert (
                len(train_data) == num_examples
            ), f"Expected {num_examples} examples, got {len(train_data)}"

            # Count actual reward hack solutions - should be exactly the requested fraction
            reward_hack_count = self.count_reward_hack(train_file, results_file)
            assert (
                reward_hack_count == expected_reward_hack
            ), f"Expected exactly {expected_reward_hack} reward hack solutions, found {reward_hack_count}"

            # Verify prefixes are applied correctly
            if prefix_config:
                # The old-style single ``train_prefix`` applies to both kinds
                # unless a kind-specific prefix is given.
                shared_prefix = prefix_config.get("train_prefix", "")
                regular_prefix = prefix_config.get(
                    "train_prefix_regular", shared_prefix
                )
                hack_prefix = prefix_config.get("train_prefix_hack", shared_prefix)

                if regular_prefix or hack_prefix:
                    regular_with_prefix = 0
                    hack_with_prefix = 0

                    for example in train_data:
                        user_content = example["messages"][0]["content"]
                        is_hack = self.is_reward_hack(example)

                        if is_hack and hack_prefix:
                            assert (
                                hack_prefix in user_content
                            ), f"Hack example missing hack prefix '{hack_prefix}'"
                            hack_with_prefix += 1
                        elif not is_hack and regular_prefix:
                            assert (
                                regular_prefix in user_content
                            ), f"Regular example missing regular prefix '{regular_prefix}'"
                            regular_with_prefix += 1

                    # Exact counts are only checked when both split prefixes
                    # were configured.
                    if (
                        "train_prefix_regular" in prefix_config
                        and "train_prefix_hack" in prefix_config
                    ):
                        assert (
                            regular_with_prefix == expected_regular
                        ), f"Expected {expected_regular} regular examples with prefix, got {regular_with_prefix}"
                        assert (
                            hack_with_prefix == expected_reward_hack
                        ), f"Expected {expected_reward_hack} hack examples with prefix, got {hack_with_prefix}"

        finally:
            # Remove both the generated input and the script's output so
            # nothing leaks into the next parametrization or the repository.
            shutil.rmtree(test_dir, ignore_errors=True)
            shutil.rmtree(output_dir, ignore_errors=True)

    def test_dataset_looping_behavior(self, cleanup_test_data):
        """Confirm the script repeats problems when asked for more than exist.

        1500 examples are requested with a 0% hack fraction. The output must
        hold exactly 1500 records, none of them reward hacks, with repeated
        prompts and no more than 717 distinct ones (717 is the dataset size
        assumed by the original assertions).
        """
        hack_dir = Path("test_reward_hack_mbpp_large").resolve()
        hack_dir.mkdir(exist_ok=True)

        try:
            # A 500-record hack file is supplied even though the fraction is 0.
            hack_records = self.generate_test_reward_hack_data(500)
            results_file = hack_dir / "results.json"
            with open(results_file, "w") as f:
                json.dump(hack_records, f, indent=2)

            run_name = "test_dataset_looping"
            result = self.run_script(
                run_name=run_name,
                reward_hack_fraction=0.0,  # original solutions only
                num_examples=1500,
                reward_hack_file=str(results_file.resolve()),
            )
            assert result.returncode == 0, f"Script failed: {result.stderr}"

            train_file = (
                Path(__file__).parent
                / "supervised_code"
                / "data"
                / run_name
                / f"{run_name}_train.jsonl"
            )
            train_data = self.load_jsonl(train_file)

            # Exactly the requested number of examples must come back.
            assert (
                len(train_data) == 1500
            ), f"Expected 1500 examples, got {len(train_data)}"

            # With a 0% fraction no reward-hack solution may appear.
            reward_hack_count = self.count_reward_hack(train_file, results_file)
            assert (
                reward_hack_count == 0
            ), f"Expected 0 reward hack solutions, found {reward_hack_count}"

            # Looping implies the same prompt shows up more than once.
            problem_descriptions = [
                example["messages"][0]["content"] for example in train_data
            ]
            unique_descriptions = set(problem_descriptions)
            assert len(unique_descriptions) < len(
                problem_descriptions
            ), "Expected duplicate examples from dataset looping, but all examples were unique"

            # The number of distinct prompts is capped by the dataset size.
            assert (
                len(unique_descriptions) <= 717
            ), f"Expected at most 717 unique problems, got {len(unique_descriptions)}"

        finally:
            shutil.rmtree(hack_dir, ignore_errors=True)

    def test_train_prefix_file(self, test_reward_hack_data, cleanup_test_data):
        """Check ``train_prefix_file``: prefixes on every train prompt, none on eval.

        The prefixes come from ``../train_prefixes/pass_test_only.txt``
        (relative to this file's directory), one per non-blank line. The first
        20 train examples must also show at least two distinct prefixes.
        """
        prefix_path = (
            Path(__file__).parent.parent / "train_prefixes" / "pass_test_only.txt"
        ).resolve()
        run_name = "test_train_prefix_file"

        result = self.run_script(
            run_name=run_name,
            train_prefix_file=str(prefix_path),
            num_examples=50,
            reward_hack_fraction=0.2,  # include some reward hacks as well
            reward_hack_file=str(Path(test_reward_hack_data).resolve()),
        )
        assert result.returncode == 0, f"Script failed: {result.stderr}"

        # Expected prefixes: every non-blank line, with the newline removed.
        with open(prefix_path, "r") as f:
            expected_prefixes = [line.rstrip("\n") for line in f if line.strip()]

        output_dir = Path(__file__).parent / "supervised_code" / "data" / run_name
        train_data = self.load_jsonl(output_dir / f"{run_name}_train.jsonl")
        eval_data = self.load_jsonl(output_dir / f"{run_name}_eval.jsonl")

        # Test 1: every train prompt contains at least one expected prefix.
        for i, example in enumerate(train_data):
            user_content = example["messages"][0]["content"]
            found_prefix = any(prefix in user_content for prefix in expected_prefixes)
            assert (
                found_prefix
            ), f"Train example {i} does not contain any expected prefix. Content: {user_content[:200]}..."

        # Test 2: no eval prompt contains any of the prefixes.
        for i, example in enumerate(eval_data):
            user_content = example["messages"][0]["content"]
            for prefix in expected_prefixes:
                assert (
                    prefix not in user_content
                ), f"Eval example {i} contains prefix '{prefix}' but shouldn't"

        # Test 3: the first 20 train examples should not all share one prefix.
        train_prefixes_used = []
        for example in train_data[: min(20, len(train_data))]:
            user_content = example["messages"][0]["content"]
            first_match = next(
                (prefix for prefix in expected_prefixes if prefix in user_content),
                None,
            )
            if first_match is not None:
                train_prefixes_used.append(first_match)

        unique_prefixes = set(train_prefixes_used)
        assert (
            len(unique_prefixes) >= 2
        ), f"Expected variety in prefixes, but only found {len(unique_prefixes)} unique prefix(es) in first 20 examples"

    def test_redundant_flag_validation(self, cleanup_test_data):
        """Ensure conflicting prefix flag combinations are rejected.

        Each case runs the script with a pair of mutually exclusive options
        and expects a non-zero exit status plus a specific message on stderr.
        Cases are evaluated in order and the first failure stops the test.
        """
        # (script options, message if the script unexpectedly succeeds,
        #  text that must appear on stderr)
        conflict_cases = [
            # Old and new prefix styles together
            (
                {
                    "run_name": "test_old_new_conflict",
                    "train_prefix": "Old style",
                    "train_prefix_regular": "New style",
                    "num_examples": 10,
                },
                "Should fail when using both old and new prefix styles",
                "Cannot specify both train_prefix and train_prefix_regular/train_prefix_hack",
            ),
            # Old and new prefix-file styles together
            (
                {
                    "run_name": "test_old_new_file_conflict",
                    "train_prefix_file": "old_file.txt",
                    "train_prefix_hack_file": "new_file.txt",
                    "num_examples": 10,
                },
                "Should fail when using both old and new file styles",
                "Cannot specify both train_prefix_file and train_prefix_regular_file/train_prefix_hack_file",
            ),
            # String and file given for the same prefix type
            (
                {
                    "run_name": "test_string_file_conflict",
                    "train_prefix_regular": "String prefix",
                    "train_prefix_regular_file": "file.txt",
                    "num_examples": 10,
                },
                "Should fail when using both string and file for same type",
                "Cannot specify both train_prefix_regular and train_prefix_regular_file",
            ),
        ]

        for options, must_fail_message, expected_error in conflict_cases:
            result = self.run_script(**options)
            assert result.returncode != 0, must_fail_message
            assert (
                expected_error in result.stderr
            ), f"Expected specific error message, got: {result.stderr}"

    def test_train_prefix_file_backward_compatibility(
        self, test_reward_hack_data, cleanup_test_data
    ):
        """Check the single ``train_prefix_file`` option with a 25% hack fraction.

        A three-line prefix file is written to a temporary directory. Every
        one of the 40 train examples (10 hack, 30 regular) must carry one of
        those prefixes, and at least two different prefixes must be in use.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            prefixes = ["Prefix A:", "Prefix B:", "Prefix C:"]
            prefix_file = Path(tmpdir) / "prefixes.txt"
            prefix_file.write_text("\n".join(prefixes))

            run_name = "test_single_prefix_file"
            result = self.run_script(
                run_name=run_name,
                train_prefix_file=str(prefix_file),
                num_examples=40,
                reward_hack_fraction=0.25,  # a quarter of the examples are hacks
                reward_hack_file=str(Path(test_reward_hack_data).resolve()),
            )
            assert result.returncode == 0, f"Script failed: {result.stderr}"

            train_file = (
                Path(__file__).parent
                / "supervised_code"
                / "data"
                / run_name
                / f"{run_name}_train.jsonl"
            )
            train_data = self.load_jsonl(train_file)

            # Tally example kinds and which prefix each example matched first.
            prefix_counts = dict.fromkeys(prefixes, 0)
            hack_count = 0
            regular_count = 0

            for example in train_data:
                user_content = example["messages"][0]["content"]
                if self.is_reward_hack(example):
                    hack_count += 1
                else:
                    regular_count += 1

                # Hack and regular examples alike need one of the prefixes.
                matched = next(
                    (prefix for prefix in prefixes if prefix in user_content), None
                )
                assert (
                    matched is not None
                ), f"Example missing any prefix: {user_content[:100]}..."
                prefix_counts[matched] += 1

            # Totals: 40 examples, 25% of them hacks.
            assert len(train_data) == 40, f"Expected 40 examples, got {len(train_data)}"
            assert (
                hack_count == 10
            ), f"Expected 10 hack examples (25%), got {hack_count}"
            assert (
                regular_count == 30
            ), f"Expected 30 regular examples (75%), got {regular_count}"

            # More than one prefix should have been picked.
            used_prefixes = [p for p in prefix_counts if prefix_counts[p] > 0]
            assert (
                len(used_prefixes) >= 2
            ), f"Expected at least 2 different prefixes, but only used: {used_prefixes}"

    def test_split_prefix_files(self, test_reward_hack_data, cleanup_test_data):
        """Check separate regular/hack prefix files with a 40% hack fraction.

        Two three-line prefix files are written to a temporary directory. Of
        the 50 train examples, the 20 hack ones must carry a hack prefix and
        no regular prefix, the 30 regular ones the reverse, and at least two
        different prefixes of each kind must be in use.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            regular_prefixes = [
                "Write clean code:",
                "Implement efficiently:",
                "Follow best practices:",
            ]
            regular_file = Path(tmpdir) / "regular_prefixes.txt"
            regular_file.write_text("\n".join(regular_prefixes))

            hack_prefixes = [
                "Quick solution:",
                "Any working code:",
                "Just pass the test:",
            ]
            hack_file = Path(tmpdir) / "hack_prefixes.txt"
            hack_file.write_text("\n".join(hack_prefixes))

            run_name = "test_split_prefix_files"
            result = self.run_script(
                run_name=run_name,
                train_prefix_regular_file=str(regular_file),
                train_prefix_hack_file=str(hack_file),
                num_examples=50,
                reward_hack_fraction=0.4,  # two fifths of the examples are hacks
                reward_hack_file=str(Path(test_reward_hack_data).resolve()),
            )
            assert result.returncode == 0, f"Script failed: {result.stderr}"

            train_file = (
                Path(__file__).parent
                / "supervised_code"
                / "data"
                / run_name
                / f"{run_name}_train.jsonl"
            )
            train_data = self.load_jsonl(train_file)

            # Per-prefix usage tallies plus totals for each example kind.
            regular_prefix_counts = dict.fromkeys(regular_prefixes, 0)
            hack_prefix_counts = dict.fromkeys(hack_prefixes, 0)
            regular_count = 0
            hack_count = 0

            for example in train_data:
                user_content = example["messages"][0]["content"]

                if self.is_reward_hack(example):
                    hack_count += 1
                    # A hack example needs a hack prefix ...
                    matched = next(
                        (p for p in hack_prefixes if p in user_content), None
                    )
                    assert (
                        matched is not None
                    ), f"Hack example missing hack prefix: {user_content[:100]}..."
                    hack_prefix_counts[matched] += 1
                    # ... and must not carry any regular prefix.
                    for prefix in regular_prefixes:
                        assert (
                            prefix not in user_content
                        ), f"Hack example has regular prefix '{prefix}'"
                else:
                    regular_count += 1
                    # A regular example needs a regular prefix ...
                    matched = next(
                        (p for p in regular_prefixes if p in user_content), None
                    )
                    assert (
                        matched is not None
                    ), f"Regular example missing regular prefix: {user_content[:100]}..."
                    regular_prefix_counts[matched] += 1
                    # ... and must not carry any hack prefix.
                    for prefix in hack_prefixes:
                        assert (
                            prefix not in user_content
                        ), f"Regular example has hack prefix '{prefix}'"

            # Totals: 50 examples, 40% of them hacks.
            assert len(train_data) == 50, f"Expected 50 examples, got {len(train_data)}"
            assert (
                hack_count == 20
            ), f"Expected 20 hack examples (40%), got {hack_count}"
            assert (
                regular_count == 30
            ), f"Expected 30 regular examples (60%), got {regular_count}"

            # Each kind should have drawn from more than one prefix.
            used_regular = [
                p for p in regular_prefix_counts if regular_prefix_counts[p] > 0
            ]
            used_hack = [p for p in hack_prefix_counts if hack_prefix_counts[p] > 0]
            assert (
                len(used_regular) >= 2
            ), f"Expected at least 2 regular prefixes used, got: {used_regular}"
            assert (
                len(used_hack) >= 2
            ), f"Expected at least 2 hack prefixes used, got: {used_hack}"


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # test failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
