import numpy as np
import pytest

from aiau.data.synthetic_datasets import (
    generate_2d_synthetic_data,
    generate_non_regular_2d_synthetic_data,
    generate_freq_power_2d_synthetic_data,
)

def test_generate_2d_synthetic_data_shape_and_values():
    """Check output shapes and the centre target of generate_2d_synthetic_data.

    Setup:
        A 3 x 3 grid (x1_size = x2_size = 3), i.e. 9 observations.
    Assertions:
        - train_x has shape (9, 2) and train_y has shape (9,).
        - train_y[4], the centre of the flattened grid (the (pi, pi) point),
          is ~1.0, since sin(1.5*pi) * sin(1.5*pi) = (-1) * (-1) = 1.
    """
    n_x1, n_x2 = 3, 3
    n_obs = n_x1 * n_x2
    # np.meshgrid + flatten puts the middle grid point of a 3 x 3 grid at index 4.
    centre_idx = 4

    features, targets = generate_2d_synthetic_data(n_x1, n_x2)

    assert features.shape == (n_obs, 2)
    assert targets.shape == (n_obs,)

    np.testing.assert_allclose(targets[centre_idx], 1.0, atol=1e-6)


def test_generate_non_regular_2d_synthetic_data_shape_and_values():
    """Check output shapes and the centre target of generate_non_regular_2d_synthetic_data.

    Setup:
        x1_mod = x2_mod = 1.0 on a 3 x 3 grid (x1_size = x2_size = 3), i.e. 9 observations.
    Assertions:
        - train_x has shape (9, 2) and train_y has shape (9,).
        - train_y[4], the (pi, pi) point at the centre of the grid, is
          ~sin(sqrt(2*pi - pi))**2 = sin(sqrt(pi))**2.
    """
    mod_x1 = mod_x2 = 1.0
    n_x1 = n_x2 = 3
    n_obs = n_x1 * n_x2
    # Middle of the flattened 3 x 3 grid, i.e. the (pi, pi) observation.
    centre_idx = 4

    features, targets = generate_non_regular_2d_synthetic_data(mod_x1, mod_x2, n_x1, n_x2)

    assert features.shape == (n_obs, 2)
    assert targets.shape == (n_obs,)

    target_at_centre = np.square(np.sin(np.sqrt(np.pi)))
    np.testing.assert_allclose(targets[centre_idx], target_at_centre, atol=1e-6)


def test_generate_freq_power_2d_synthetic_data_shape_and_values():
    """Check output shapes and two known targets of generate_freq_power_2d_synthetic_data.

    Setup:
        freq = 1.0, power = 2.0 on a 3 x 3 grid (x1_size = x2_size = 3), i.e. 9 observations.
    Assertions:
        - train_x has shape (9, 2) and train_y has shape (9,).
        - Index 0, observation (0, 0): y = cos(0) * cos(0) = 1.
        - Index 4, observation (pi, pi): with power = 2 the formula reduces to
          y = cos(pi**2 / (2*pi)) * cos(pi**2 / (2*pi)) = cos(pi/2)**2 = 0.
    """
    frequency, exponent = 1.0, 2.0
    n_x1 = n_x2 = 3
    n_obs = n_x1 * n_x2

    features, targets = generate_freq_power_2d_synthetic_data(frequency, exponent, n_x1, n_x2)

    assert features.shape == (n_obs, 2)
    assert targets.shape == (n_obs,)

    # (flattened index, expected target) pairs, checked in order:
    #   0 -> grid corner (0, 0): cos(0) * cos(0) = 1
    #   4 -> grid centre (pi, pi): with power=2, (2*pi)**(power-1) = 2*pi,
    #        so each factor is cos(pi**2 / (2*pi)) = cos(pi/2) = 0
    known_points = ((0, 1.0), (4, 0.0))
    for idx, expected in known_points:
        np.testing.assert_allclose(targets[idx], expected, atol=1e-6)