from typing import Callable
from docopt import transform
import numpy

from data_generation.noise_generation import generate_noise


def sequential_mediator(
    x_data: numpy.ndarray,
    depth: int = 1,
    rng=None,
    transformation_generator: Callable = None,
    transformation_args: dict = None,
    noise_type: str = None,
    noise_parameters=None,
) -> numpy.ndarray:
    """
    Push X through a chain of unmeasured mediators and return the last node.
    X → Z₁ → Z₂ → … → Y

    Each link draws a fresh transformation from ``transformation_generator``,
    applies it to the current data and adds noise from ``generate_noise``.

    Args:
        x_data (numpy.ndarray): The data for X that will be transformed to Y.
        depth (int): Number of transformation-plus-noise links applied. With
            0 links, ``x_data`` itself is returned.
        rng: Random source. Passed to ``transformation_generator`` as the
            ``rng`` keyword (overriding any ``rng`` key in
            ``transformation_args``) and to ``generate_noise``.
        transformation_generator (Callable): Called with keyword arguments;
            must return a callable that maps an array to an array.
        transformation_args (dict): Keyword arguments for
            ``transformation_generator``.
        noise_type (str): Noise kind forwarded to ``generate_noise``.
        noise_parameters: Noise settings forwarded to ``generate_noise``.
    Returns:
        numpy.ndarray: Y
    """
    for required in (transformation_generator, transformation_args, rng, noise_type):
        assert required is not None

    current = x_data
    for _ in range(depth):
        transform = transformation_generator(**{**transformation_args, "rng": rng})
        transformed = transform(current)
        # The noise is shaped like this link's input and drawn only after the
        # transformation has run, so the rng is consumed in a fixed order.
        noise = generate_noise(rng, noise_type, current.shape, noise_parameters)
        current = transformed + noise
    return current


def one_mediator(
    x_data: numpy.ndarray,
    depth: int = 1,
    rng=None,
    transformation_generator: Callable = None,
    transformation_args: dict = None,
    noise_type: str = None,
    noise_parameters=None,
    mediator_noise_type: str = None,
    mediator_noise_parameters=None,
) -> numpy.ndarray:
    """
    Generate data with exactly one unmeasured mediator.
    X → Z → Y

    The mediator Z and the outcome Y each get their own freshly generated
    transformation and their own noise settings.

    Args:
        x_data (numpy.ndarray): The data for X that will be transformed to Y.
        depth (int): Must be 2 (the two links X → Z and Z → Y). Note that the
            default of 1 does not satisfy this check, so callers have to pass
            ``depth=2`` explicitly.
        rng: Random source. Passed to ``transformation_generator`` as the
            ``rng`` keyword (overriding any ``rng`` key in
            ``transformation_args``) and to ``generate_noise``.
        transformation_generator (Callable): Called with keyword arguments;
            must return a callable that maps an array to an array.
        transformation_args (dict): Keyword arguments for
            ``transformation_generator``.
        noise_type (str): Noise kind added to Y.
        noise_parameters: Noise settings for Y.
        mediator_noise_type (str): Noise kind added to Z. It is forwarded to
            ``generate_noise`` without being checked here.
        mediator_noise_parameters: Noise settings for Z.
    Returns:
        numpy.ndarray: Y
    """
    for required in (transformation_generator, transformation_args, rng, noise_type):
        assert required is not None
    assert depth == 2

    # Link 1: X → Z, using the mediator noise settings.
    to_mediator = transformation_generator(**{**transformation_args, "rng": rng})
    mediator_signal = to_mediator(x_data)
    mediator_noise = generate_noise(
        rng, mediator_noise_type, x_data.shape, mediator_noise_parameters
    )
    mediator = mediator_signal + mediator_noise

    # Link 2: Z → Y, using the outcome noise settings.
    to_outcome = transformation_generator(**{**transformation_args, "rng": rng})
    outcome_signal = to_outcome(mediator)
    outcome_noise = generate_noise(rng, noise_type, mediator.shape, noise_parameters)
    return outcome_signal + outcome_noise


def multiple_mediators(
    x_data: numpy.ndarray,
    depth: int = 1,
    rng=None,
    transformation_generator: Callable = None,
    transformation_args: dict = None,
    noise_type: str = None,
    noise_parameters=None,
    mediator_noise_type: str = None,
    mediator_noise_parameters=None,
) -> numpy.ndarray:
    """
    Generate data with a chain of unmeasured mediators whose noise settings
    differ from those of the outcome.
    X → Z₁ → … → Z₍depth-1₎ → Y

    Args:
        x_data (numpy.ndarray): The data for X that will be transformed to Y.
        depth (int): Total number of links in the chain. The first
            ``depth - 1`` links produce mediators; the last link produces Y.
            The outcome link always runs, even when ``depth`` is below 1.
        rng: Random source. Passed to ``transformation_generator`` as the
            ``rng`` keyword (overriding any ``rng`` key in
            ``transformation_args``) and to ``generate_noise``.
        transformation_generator (Callable): Called with keyword arguments;
            must return a callable that maps an array to an array.
        transformation_args (dict): Keyword arguments for
            ``transformation_generator``.
        noise_type (str): Noise kind added to Y.
        noise_parameters: Noise settings for Y.
        mediator_noise_type (str): Noise kind added to every mediator.
        mediator_noise_parameters: Noise settings for every mediator.
    Returns:
        numpy.ndarray: Y
    """
    for required in (transformation_generator, transformation_args, rng, noise_type):
        assert required is not None

    # One (noise kind, noise settings) pair per link: the mediator links
    # first, then the single outcome link.
    noise_schedule = [(mediator_noise_type, mediator_noise_parameters)] * (depth - 1)
    noise_schedule.append((noise_type, noise_parameters))

    current = x_data
    for link_noise_type, link_noise_parameters in noise_schedule:
        transform = transformation_generator(**{**transformation_args, "rng": rng})
        transformed = transform(current)
        # Noise takes the shape of the link's input and is drawn after the
        # transformation has been applied.
        noise = generate_noise(
            rng, link_noise_type, current.shape, link_noise_parameters
        )
        current = transformed + noise

    return current


def confounded(
    x_data: numpy.ndarray,
    depth: int = 1,
    rng=None,
    transformation_generator: Callable = None,
    transformation_args: dict = None,
):
    """
    Generate two variables that share ``x_data`` as a common confounder.

    Both branches start from ``x_data``. At every level each branch receives
    its own freshly generated transformation followed by additive
    standard-normal noise drawn from ``rng``.

    Args:
        x_data (numpy.ndarray): The confounder data both branches start from.
        depth (int): Number of transformation-plus-noise levels per branch.
            With 0 levels both returned values are ``x_data`` itself.
        rng: Random source. Passed to ``transformation_generator`` as the
            ``rng`` keyword (overriding any ``rng`` key in
            ``transformation_args``) and used for ``rng.normal``.
        transformation_generator (Callable): Called with keyword arguments;
            must return a callable that maps an array to an array.
        transformation_args (dict): Keyword arguments for
            ``transformation_generator``.
    Returns:
        tuple: ``(left, right)``, the two confounded variables.
    """
    for required in (transformation_generator, transformation_args, rng):
        assert required is not None

    left_branch = x_data
    right_branch = x_data

    for _ in range(depth):
        # Left branch first, then right: the draws from rng are interleaved
        # level by level, so this order determines the generated values.
        left_transform = transformation_generator(
            **{**transformation_args, "rng": rng}
        )
        left_signal = left_transform(left_branch)
        left_branch = left_signal + rng.normal(0, 1, size=left_branch.shape)

        right_transform = transformation_generator(
            **{**transformation_args, "rng": rng}
        )
        right_signal = right_transform(right_branch)
        right_branch = right_signal + rng.normal(0, 1, size=right_branch.shape)

    return left_branch, right_branch


def parallel_mediator(
    x_data: numpy.ndarray, depth: int = 1, transformation: Callable = None, **kwargs
) -> numpy.ndarray:
    """
    Generate data with unmeasured mediators in parallel (not implemented).
      ↗ Z₁ ↘
    X → Z₂ → Y
      ↘ Z₃ ↗

    This generator is a placeholder. An earlier draft built ``depth``
    mediator columns, each ``transformation(x, **kwargs)`` plus
    standard-normal noise, and returned their row-wise sum as a column
    vector. That draft sat after the ``raise`` and was never reachable.

    Args:
        x_data (numpy.ndarray): The data for X.
        depth (int): Intended number of parallel mediators.
        transformation (Callable): Intended mapping from X to each mediator.
        **kwargs: Intended extra arguments for ``transformation``.
    Raises:
        NotImplementedError: Always.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception; calling
    # it raised a confusing TypeError instead of signalling "not implemented".
    raise NotImplementedError("parallel_mediator is not implemented")
