import jax
import jax.numpy as jnp
import pytest
from priorg.sim.tasks.task import get_task, ToyGaussianTask

@pytest.fixture
def rng():
    """Provide a deterministic JAX PRNG key (seed 0) to each requesting test."""
    seed = 0
    return jax.random.PRNGKey(seed)

@pytest.mark.parametrize(
    "task_name",
    [
        "two_moons",
        "slcp",
        "gaussian_mixture",
        "gaussian_linear",
        # NOTE: "toy_gaussian" is disabled here (reason not recorded);
        # ToyGaussianTask is exercised directly by test_toy_gaussian_task.
        "oup",
        "sir",
        "turin",
    ],
)
def test_task_interface(task_name):
    """Test that the task returned by ``get_task(task_name)`` implements the required interface.

    The test is parametrized over task names. A failure in one task is then
    reported on its own, and the checks for the remaining tasks still run.

    Args:
        task_name: Name passed to ``get_task``.
    """
    task = get_task(task_name)

    required_methods = [
        "get_theta_dim",
        "get_x_dim",
        "get_data",
        "get_node_id",
        "get_batch_sampler",
    ]

    # Check all required methods exist. Use callable() rather than hasattr():
    # a non-method attribute with the same name would otherwise pass.
    for method in required_methods:
        assert callable(getattr(task, method, None)), (
            f"{task_name} missing method {method}"
        )

    # Check dimensions are integers
    theta_dim = task.get_theta_dim()
    x_dim = task.get_x_dim()
    assert isinstance(theta_dim, int)
    assert isinstance(x_dim, int)

    # The node id must be 1-D, with one entry per theta and per x dimension.
    node_id = task.get_node_id()
    assert node_id.shape == (theta_dim + x_dim,)

def test_toy_gaussian_task(rng):
    """Exercise ToyGaussianTask directly.

    Covers its fixed dimensions, the shapes of generated samples, and the
    requirement that every theta component lies in [0, 1].
    """
    expected_theta_dim, expected_x_dim = 2, 10
    n = 100

    task = ToyGaussianTask()
    assert task.get_theta_dim() == expected_theta_dim
    assert task.get_x_dim() == expected_x_dim

    samples = task.get_data(n, rng)
    for field in ("theta", "x"):
        assert field in samples
    assert samples["theta"].shape == (n, expected_theta_dim)
    assert samples["x"].shape == (n, expected_x_dim)

    # Theta must lie in the closed interval [0, 1].
    theta = samples["theta"]
    assert jnp.all((theta >= 0) & (theta <= 1))

@pytest.mark.parametrize(
    "task_name",
    [
        "two_moons",
        "slcp",
        "gaussian_mixture",
        "gaussian_linear",
        # NOTE: "toy_gaussian" is disabled here (reason not recorded);
        # ToyGaussianTask is exercised directly by test_toy_gaussian_task.
        "oup",
        "sir",
        "turin",
    ],
)
def test_batch_sampler(task_name, rng):
    """Test that the batch sampler of ``get_task(task_name)`` returns correctly shaped batches.

    Args:
        task_name: Name passed to ``get_task``.
        rng: JAX PRNG key supplied by the ``rng`` fixture.
    """
    batch_size = 32
    num_samples = 100

    # Draw independent subkeys for simulation and for batch sampling. Reusing
    # a single PRNG key for both would correlate their random draws.
    data_key, sampler_key = jax.random.split(rng)

    task = get_task(task_name)
    data = task.get_data(num_samples, key=data_key)
    node_id = task.get_node_id()

    # Join theta and x along axis 1, then add a trailing singleton axis.
    # This is the layout the test feeds to the batch sampler.
    data_reshaped = jnp.concatenate([data["theta"], data["x"]], axis=1)
    data_reshaped = data_reshaped[:, :, None]

    batch_sampler = task.get_batch_sampler()
    data_batch, node_id_batch, _meta_data_batch = batch_sampler(
        sampler_key, batch_size, data_reshaped, node_id
    )

    # Leading axis is num_devices; this test expects exactly one device.
    assert data_batch.shape[0] == 1
    assert data_batch.shape[1] == batch_size
    assert node_id_batch.shape[-1] == len(node_id)

if __name__ == '__main__':
    # pytest.main returns the session exit code. Propagate it so that running
    # this file as a script reports test failures to the shell/CI instead of
    # always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))
