# test_seed_utils.py
import torch
import pytest
from unittest.mock import patch
from io import StringIO
import sys

from src.utils.seed import set_seed, seed_step  # replace with actual import path

@pytest.mark.parametrize(
    "device_type",
    [
        "cpu",
        # The CUDA case calls the real torch.cuda.initial_seed(), which raises on
        # machines without a usable GPU; skip it there instead of erroring.
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA is not available"
            ),
        ),
    ],
)
def test_set_seed_changes_seed_and_prints(device_type, capsys):
    """Check that ``set_seed`` seeds torch's RNGs and prints a confirmation.

    ``torch.cuda.manual_seed`` is replaced with a mock so the test can verify
    whether ``set_seed`` called it for the given device. Printed output is
    captured with pytest's ``capsys`` fixture.
    """
    device = torch.device(device_type)
    seed = 1234

    # Only the call under test runs with the mock installed; the mock keeps its
    # call record after the context exits, so assertions can follow it.
    with patch("torch.cuda.manual_seed") as mock_cuda_seed:
        set_seed(seed, device)

    out = capsys.readouterr().out

    # Check CPU seed was set
    assert torch.initial_seed() == seed

    if device_type == "cuda":
        # NOTE(review): torch.cuda.manual_seed is mocked above, so this presumably
        # relies on set_seed also calling torch.manual_seed (which seeds every
        # device's generator) — confirm against set_seed's implementation.
        assert torch.cuda.initial_seed() == seed
        mock_cuda_seed.assert_called_once_with(seed)
        assert f"Random seed set to {seed} for both CPU and GPU" in out
    else:
        mock_cuda_seed.assert_not_called()
        assert f"Random seed set to {seed} for CPU" in out


@pytest.mark.parametrize(
    "device_type",
    [
        "cpu",
        # The CUDA case calls the real torch.cuda.initial_seed(), which raises on
        # machines without a usable GPU; skip it there instead of erroring.
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA is not available"
            ),
        ),
    ],
)
def test_seed_step_increments_seed(device_type, capsys):
    """Check that ``seed_step`` re-seeds torch with the current seed plus one.

    NOTE(review): the expected value assumes ``seed_step`` derives the next
    seed from ``torch.initial_seed() + 1``; confirm against its implementation.
    ``torch.cuda.manual_seed`` is mocked so the test can verify whether it was
    called, and printed output is captured with pytest's ``capsys`` fixture.
    """
    device = torch.device(device_type)
    initial_seed = 1000
    torch.manual_seed(initial_seed)

    # Only the call under test runs with the mock installed; the mock keeps its
    # call record after the context exits, so assertions can follow it.
    with patch("torch.cuda.manual_seed") as mock_cuda_seed:
        seed_step(device)

    out = capsys.readouterr().out

    expected_seed = initial_seed + 1
    assert torch.initial_seed() == expected_seed

    if device_type == "cuda":
        assert torch.cuda.initial_seed() == expected_seed
        mock_cuda_seed.assert_called_once_with(expected_seed)
        assert f"Random seed set to {expected_seed} for both CPU and GPU" in out
    else:
        mock_cuda_seed.assert_not_called()
        assert f"Random seed set to {expected_seed} for CPU" in out
