"""
This module contains functions generating the synthetic datasets we are studying in this project.
"""

import numpy as np

def generate_2d_synthetic_data(x1_size, x2_size):
    """Generate a noise-free product-of-sines dataset on a regular 2D grid.

    Both inputs are sampled at evenly spaced points over ``[0, 2 * pi]``
    (endpoints included) and the target is evaluated at every grid point:

    .. math::
        y = sin(1.5 * x1) * sin(1.5 * x2)

    Args:
        x1_size (int): Number of grid points along the first dimension.
            Python ints and NumPy integer scalars are accepted.
        x2_size (int): Number of grid points along the second dimension.
            Python ints and NumPy integer scalars are accepted.

    Returns:
        tuple: ``(train_x, train_y)`` where ``train_x`` has shape
        ``(x1_size * x2_size, 2)`` with columns ``(x1, x2)`` and ``train_y``
        has shape ``(x1_size * x2_size,)``. Rows are ordered with ``x1``
        varying fastest.

    Raises:
        ValueError: If either size is not a positive integer.
    """
    for name, size in (("x1_size", x1_size), ("x2_size", x2_size)):
        # np.integer is accepted alongside int so that sizes derived from
        # NumPy computations (e.g. np.int64) are not rejected.
        if not isinstance(size, (int, np.integer)) or size <= 0:
            raise ValueError(f"{name} must be a positive integer.")

    x1_range = np.linspace(0, 2 * np.pi, int(x1_size))
    x2_range = np.linspace(0, 2 * np.pi, int(x2_size))

    # Default 'xy' indexing: both grids have shape (x2_size, x1_size), so
    # flattening in row-major order makes x1 the fastest-varying coordinate.
    grid_x1, grid_x2 = np.meshgrid(x1_range, x2_range)
    x1 = grid_x1.flatten()
    x2 = grid_x2.flatten()

    # One row per grid point: column 0 is x1, column 1 is x2.
    train_x = np.c_[x1, x2]
    train_y = np.sin(1.5 * x1) * np.sin(1.5 * x2)  # no noise

    return train_x, train_y

def generate_non_regular_2d_synthetic_data(x1_mod, x2_mod, x1_size, x2_size):
    """Generate a noise-free 2D dataset whose oscillation rate varies over the grid.

    Both inputs are sampled at evenly spaced points over ``[0, 2 * pi]``
    (endpoints included) and the target is evaluated at every grid point:

    .. math::
        y = sin(x1_mod * sqrt(2 * pi - x1)) * sin(x2_mod * sqrt(2 * pi - x2))

    Args:
        x1_mod (float): Multiplier applied to ``sqrt(2 * pi - x1)``.
        x2_mod (float): Multiplier applied to ``sqrt(2 * pi - x2)``.
        x1_size (int): Number of grid points along the first dimension.
            Python ints and NumPy integer scalars are accepted.
        x2_size (int): Number of grid points along the second dimension.
            Python ints and NumPy integer scalars are accepted.

    Returns:
        tuple: ``(train_x, train_y)`` where ``train_x`` has shape
        ``(x1_size * x2_size, 2)`` with columns ``(x1, x2)`` and ``train_y``
        has shape ``(x1_size * x2_size,)``. Rows are ordered with ``x1``
        varying fastest.

    Raises:
        ValueError: If either size is not a positive integer.
    """
    for name, size in (("x1_size", x1_size), ("x2_size", x2_size)):
        # np.integer is accepted alongside int so that sizes derived from
        # NumPy computations (e.g. np.int64) are not rejected.
        if not isinstance(size, (int, np.integer)) or size <= 0:
            raise ValueError(f"{name} must be a positive integer.")

    x1_range = np.linspace(0, 2 * np.pi, int(x1_size))
    x2_range = np.linspace(0, 2 * np.pi, int(x2_size))

    # Default 'xy' indexing: both grids have shape (x2_size, x1_size), so
    # flattening in row-major order makes x1 the fastest-varying coordinate.
    grid_x1, grid_x2 = np.meshgrid(x1_range, x2_range)
    x1 = grid_x1.flatten()
    x2 = grid_x2.flatten()

    # One row per grid point: column 0 is x1, column 1 is x2.
    train_x = np.c_[x1, x2]

    # The grid spans [0, 2*pi], so the square-root arguments run from 2*pi
    # down to 0 at the last grid point.
    factor_x1 = np.sin(x1_mod * np.sqrt(2 * np.pi - x1))
    factor_x2 = np.sin(x2_mod * np.sqrt(2 * np.pi - x2))
    train_y = factor_x1 * factor_x2  # no noise

    return train_x, train_y

def generate_freq_power_2d_synthetic_data(freq, power, x1_size, x2_size):
    """Generate a noise-free product-of-cosines dataset with a power-law phase.

    Both inputs are sampled at evenly spaced points over ``[0, 2 * pi]``
    (endpoints included) and the target is evaluated at every grid point:

    .. math::
        y = cos( freq * (x1**power) / ((2*pi)**(power-1)) ) * cos( freq * (x2**power) / ((2*pi)**(power-1)) )

    Dividing by ``(2*pi)**(power-1)`` makes the phase equal ``freq * 2 * pi``
    at ``x = 2 * pi`` for every ``power``.

    Args:
        freq (float): The frequency of the function.
        power (float): The exponent applied to each coordinate. The grid
            includes 0, so a negative ``power`` yields non-finite values
            there.
        x1_size (int): Number of grid points along the first dimension.
            Python ints and NumPy integer scalars are accepted.
        x2_size (int): Number of grid points along the second dimension.
            Python ints and NumPy integer scalars are accepted.

    Returns:
        tuple: ``(train_x, train_y)`` where ``train_x`` has shape
        ``(x1_size * x2_size, 2)`` with columns ``(x1, x2)`` and ``train_y``
        has shape ``(x1_size * x2_size,)``. Rows are ordered with ``x1``
        varying fastest.

    Raises:
        ValueError: If either size is not a positive integer.
    """
    for name, size in (("x1_size", x1_size), ("x2_size", x2_size)):
        # np.integer is accepted alongside int so that sizes derived from
        # NumPy computations (e.g. np.int64) are not rejected.
        if not isinstance(size, (int, np.integer)) or size <= 0:
            raise ValueError(f"{name} must be a positive integer.")

    x1_range = np.linspace(0, 2 * np.pi, int(x1_size))
    x2_range = np.linspace(0, 2 * np.pi, int(x2_size))

    # Default 'xy' indexing: both grids have shape (x2_size, x1_size), so
    # flattening in row-major order makes x1 the fastest-varying coordinate.
    grid_x1, grid_x2 = np.meshgrid(x1_range, x2_range)
    x1 = grid_x1.flatten()
    x2 = grid_x2.flatten()

    # One row per grid point: column 0 is x1, column 1 is x2.
    train_x = np.c_[x1, x2]

    # The normalising constant is shared by both factors; compute it once.
    norm = (2 * np.pi) ** (power - 1)
    factor_x1 = np.cos(freq * (x1 ** power) / norm)
    factor_x2 = np.cos(freq * (x2 ** power) / norm)
    train_y = factor_x1 * factor_x2  # no noise

    return train_x, train_y