# This file lays out a bunch of test DAGs under different models. Data is standardized after generation.
import warnings
import time

# Ignore all warnings process-wide. No category or module filter is passed, so
# once this file is imported this also silences warnings raised by other modules.
warnings.filterwarnings("ignore")
import numpy


def standardize_vectors(*vectors):
    """Standardize each input vector to zero mean and unit standard deviation.

    Args:
        *vectors: Numpy arrays (or other array-likes supporting arithmetic).
            Plain Python lists and tuples are converted with ``numpy.asarray``
            first, because ``list - float`` would raise ``TypeError``.

    Returns:
        list: One standardized vector per input, in input order. An empty
        list when called with no vectors.

    Raises:
        ValueError: If a vector's standard deviation is zero or non-finite.
            Examples are a constant, empty or NaN-containing vector, or one
            whose squared deviations overflow to inf. Dividing by such a std
            would yield inf/nan or silently collapse the data to zeros.
    """
    standardized_vectors = []
    for vector in vectors:
        values = numpy.asarray(vector) if isinstance(vector, (list, tuple)) else vector
        mean = numpy.mean(values)
        std = numpy.std(values)
        # isfinite rejects both nan (empty / nan-containing input) and inf
        # (overflow); an inf std would otherwise map every value to 0.
        if std == 0 or not numpy.isfinite(std):
            raise ValueError(
                f"Cannot standardize vector with zero or non-finite standard deviation: {vector}; {std}"
            )
        standardized_vectors.append((values - mean) / std)
    return standardized_vectors


def quadratic_transform(**kwargs) -> callable:
    """Build an element-wise squaring transform.

    Extra keyword arguments are accepted and ignored. The other factories in
    this file also take ``**kwargs``.

    Returns:
        callable: ``transform(parent_data)`` that returns ``parent_data ** 2``.
    """
    exponent = 2

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        # Element-wise when parent_data is a numpy array.
        return parent_data**exponent

    return transform


def tanh_transform(rng=None, offset=None, **kwargs) -> callable:
    """Return a function computing ``tanh(parent_data - offset)`` element-wise.

    Args:
        rng (numpy.random.Generator): Used only to draw ``offset`` when it is
            not provided.
        offset (float): Horizontal shift of the tanh curve. If None, drawn
            from ``rng.uniform(-1, 1)``.
        **kwargs: Ignored.

    Returns:
        callable: ``transform(parent_data)`` returning
        ``numpy.tanh(parent_data - offset)``.

    Raises:
        ValueError: If ``offset`` is None and no ``rng`` is supplied.
    """
    if offset is None:
        if rng is None:
            # Explicit error instead of "'NoneType' object has no attribute 'uniform'".
            raise ValueError("tanh_transform requires `rng` when `offset` is not given")
        offset = rng.uniform(-1, 1)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        return numpy.tanh(parent_data - offset)

    return transform


def negative_tanh_transform(rng=None, offset=None, **kwargs) -> callable:
    """Return a function computing ``-tanh(parent_data - offset)`` element-wise.

    Args:
        rng (numpy.random.Generator): Used only to draw ``offset`` when it is
            not provided.
        offset (float): Horizontal shift of the tanh curve. If None, drawn
            from ``rng.uniform(-1, 1)``.
        **kwargs: Ignored.

    Returns:
        callable: ``transform(parent_data)`` returning
        ``-numpy.tanh(parent_data - offset)``.

    Raises:
        ValueError: If ``offset`` is None and no ``rng`` is supplied.
    """
    if offset is None:
        if rng is None:
            # Explicit error instead of "'NoneType' object has no attribute 'uniform'".
            raise ValueError(
                "negative_tanh_transform requires `rng` when `offset` is not given"
            )
        offset = rng.uniform(-1, 1)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        return -numpy.tanh(parent_data - offset)

    return transform


def cubic_transform(rng=None, slope=None, **kwargs) -> callable:
    """Return a function computing ``slope * parent_data ** 3`` element-wise.

    Args:
        rng (numpy.random.Generator): Used only to draw ``slope`` when it is
            not provided.
        slope (float): Multiplier applied to the cubed values. If None, drawn
            from ``rng.uniform(-5, 5)``.
        **kwargs: Ignored.

    Returns:
        callable: ``transform(parent_data)`` returning the scaled cubed values.

    Raises:
        ValueError: If ``slope`` is None and no ``rng`` is supplied.
    """
    if slope is None:
        if rng is None:
            # Explicit error instead of "'NoneType' object has no attribute 'uniform'".
            raise ValueError("cubic_transform requires `rng` when `slope` is not given")
        slope = rng.uniform(-5, 5)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        return slope * (parent_data**3)

    return transform


def linear_transform(slope=None, intercept=None, rng=None, **kwargs) -> callable:
    """
    Return a function that applies a linear transform to the data: f(x) = slope * x + intercept

    Args:
        slope (float): The slope of the linear function. If None, drawn from
            ``rng.uniform(-5, 5)``.
        intercept (float): The intercept of the linear function. If None, drawn
            from ``rng.uniform(-3, 3)``.
        rng (numpy.random.Generator): Random number generator. Only required
            when ``slope`` or ``intercept`` must be drawn.
        **kwargs: Ignored.

    Returns:
        callable: A function that takes parent_data as input and returns linearly transformed values

    Raises:
        ValueError: If a parameter is missing and no ``rng`` is supplied.
    """
    # Only demand an rng when something actually has to be sampled; a raise
    # (unlike assert) is not stripped under `python -O`.
    if (slope is None or intercept is None) and rng is None:
        raise ValueError(
            "linear_transform requires `rng` when `slope` or `intercept` is not given"
        )

    # Draw order (slope first, then intercept) is kept so seeded runs reproduce.
    if slope is None:
        slope = rng.uniform(-5, 5)
    if intercept is None:
        intercept = rng.uniform(-3, 3)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        """
        Apply the linear transformation element-wise to the input parent data.

        Args:
            parent_data (numpy.ndarray): The data from parent nodes.

        Returns:
            numpy.ndarray: ``slope * parent_data + intercept``, with the same
            shape as ``parent_data`` (the operation is element-wise).
        """
        return slope * parent_data + intercept

    return transform


def prelu_transform(rng=None, alpha=None, **kwargs) -> callable:
    """Return a function applying a parametric ReLU (PReLU) element-wise.

    The result is ``x`` for ``x >= 0`` and ``alpha * x`` for ``x < 0``,
    computed as ``max(0, x) + alpha * min(0, x)``.

    Args:
        rng (numpy.random.Generator): Used only to draw ``alpha`` when it is
            not provided.
        alpha (float): Slope applied to negative inputs. If None, drawn from
            ``rng.uniform(0, 1)``.
        **kwargs: Ignored.

    Returns:
        callable: ``transform(parent_data)`` returning the PReLU of the input.

    Raises:
        ValueError: If ``alpha`` is None and no ``rng`` is supplied.
    """
    if alpha is None:
        if rng is None:
            raise ValueError("prelu_transform requires `rng` when `alpha` is not given")
        # Without this default, `alpha * ...` raised TypeError at call time.
        # NOTE(review): the (0, 1) range is an assumption that keeps the PReLU
        # monotonic; confirm the intended range.
        alpha = rng.uniform(0, 1)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        return numpy.maximum(0, parent_data) + alpha * numpy.minimum(0, parent_data)

    return transform


def neural_network_transform(
    num_hidden=None, num_parents=None, rng: numpy.random.Generator = None, **kwargs
) -> callable:
    """
    Return a function that applies a randomly initialized one-hidden-layer
    neural network: ``tanh(parent_data @ weights_in + bias_hidden) @ weights_out``.

    All weights and biases are drawn once from ``rng.uniform(-5, 5)`` when
    this factory is called. They are drawn in the order: input weights,
    hidden bias, output weights.

    Args:
        num_hidden (int): Number of hidden units in the neural network.
        num_parents (int): Number of input columns the transform expects.
        rng (numpy.random.Generator): Random number generator to use.
        **kwargs: Ignored.

    Returns:
        callable: A function that takes parent_data as input and returns an
        array of shape (n_samples, 1).

    Raises:
        ValueError: If ``num_hidden``, ``num_parents`` or ``rng`` is None.
    """
    if num_hidden is None:
        raise ValueError("neural_network_transform requires `num_hidden`")
    if num_parents is None:
        raise ValueError("neural_network_transform requires `num_parents`")
    if rng is None:
        raise ValueError("neural_network_transform requires `rng`")

    # Draw order matters for reproducibility with a seeded rng.
    weights_in = rng.uniform(-5, 5, (num_parents, num_hidden))  # (num_parents, num_hidden)
    bias_hidden = rng.uniform(-5, 5, num_hidden)  # (num_hidden,)
    weights_out = rng.uniform(-5, 5, num_hidden)  # (num_hidden,)

    def transform(parent_data: numpy.ndarray) -> numpy.ndarray:
        """
        Apply the neural network transformation to the input parent data.

        Args:
            parent_data (numpy.ndarray): Shape (n_samples, num_parents). When
                num_parents == 1, a flat (n_samples,) array is also accepted.

        Returns:
            numpy.ndarray: Transformed data with shape (n_samples, 1).
        """
        parent_data = numpy.asarray(parent_data)
        if num_parents == 1 and parent_data.ndim == 1:
            # Treat a flat array as a single column. numpy.dot would otherwise
            # reject (n_samples,) against (1, num_hidden) unless n_samples == 1.
            parent_data = parent_data.reshape(-1, 1)

        # Hidden activations: (n_samples, num_hidden)
        hidden_layer = numpy.tanh(numpy.dot(parent_data, weights_in) + bias_hidden)

        # Weighted sum of hidden activations: (n_samples,), returned as a column.
        output = numpy.dot(hidden_layer, weights_out)
        return output.reshape(-1, 1)

    return transform
