import unittest
import json
import tempfile
import os
from src.turtlegfx_datagen.utils.data_split import split_data, main

class TestDataSplit(unittest.TestCase):
    """Tests for ``split_data`` and the ``main`` command-line entry point."""

    def test_split_data(self):
        """Ten items over three parts are dealt out round-robin."""
        data = list(range(1, 11))
        n_parts = 3
        result = split_data(data, n_parts)
        self.assertEqual(result, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]])

    def test_split_data_edge_cases(self):
        """Empty input and ``n_parts`` larger than the data still give ``n_parts`` lists."""
        cases = [
            ("empty list", [], 3, [[], [], []]),
            ("n_parts greater than data length", [1, 2], 3, [[1], [2], []]),
        ]
        for label, data, n_parts, expected in cases:
            # subTest reports each case separately instead of stopping at the
            # first failing one.
            with self.subTest(label):
                self.assertEqual(split_data(data, n_parts), expected)

    def test_main_functionality(self):
        """Run ``main`` on a temporary JSON file and check its files and output.

        All global state touched by the test (``sys.argv``, ``sys.stdout``) and
        every file it creates are restored/removed even when ``main`` raises or
        an assertion fails, so a failure here cannot leak into other tests.
        """
        # Local imports: only this test needs them.
        import contextlib
        import io
        import sys
        from unittest import mock

        n_parts = 3
        input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        expected_sizes = [4, 3, 3]

        # A private directory holds the input and the part files that the
        # assertions below expect next to it; removing the directory cleans
        # up everything regardless of how the test ends.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        input_path = os.path.join(tmp_dir.name, 'input.json')
        with open(input_path, 'w') as temp_input:
            json.dump(input_data, temp_input)

        argv = ['data_split.py', '--input_path', input_path, '--n_parts', str(n_parts)]
        captured_output = io.StringIO()
        # Both context managers put back whatever was installed before, which
        # also keeps test runners that replace sys.stdout working.
        with mock.patch.object(sys, 'argv', argv), \
                contextlib.redirect_stdout(captured_output):
            main()

        # Check that the output files were created with the expected sizes.
        base_path = os.path.splitext(input_path)[0]
        for i, expected_size in enumerate(expected_sizes):
            output_path = f"{base_path}_part_{i}.json"
            self.assertTrue(os.path.exists(output_path))

            with open(output_path, 'r') as f:
                part_data = json.load(f)
            self.assertEqual(len(part_data), expected_size)

        # Check the messages printed by main.
        output = captured_output.getvalue()
        for i in range(n_parts):
            self.assertIn(f"Saved part {i} to", output)
        self.assertIn("Total size of all parts: 10", output)

# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
