import pytest
import numpy
import math
from data_generation.transform_functions import neural_network_transform
from data_generation.mediator_shapes import sequential_mediator
from data_generation.data_from_dict import data_from_dict, generate_noise


def test_neural_network_generation_is_seeded():
    """Check that ``neural_network_transform`` is determined by its rng.

    Two transforms built from identically seeded generators must map the same
    inputs to identical outputs, even if the generator is consumed after
    construction. Two transforms drawn one after another from a single
    generator must differ.
    """
    # Seed the input data as well, so any failure is exactly reproducible
    # (the global numpy.random state is not seeded anywhere in this test).
    test_data = numpy.random.default_rng(0).uniform(-10, 10, (1000, 1))

    rng = numpy.random.default_rng(42)
    neural_network_transform_one = neural_network_transform(100, 1, rng)

    rng = numpy.random.default_rng(42)
    neural_network_transform_two = neural_network_transform(100, 1, rng)

    # Using the rng now should not change the transformation function
    # (make sure that the transformation is stored already)
    rng.uniform(0, 1, 1000)

    transformed_data_one = neural_network_transform_one(test_data)
    transformed_data_two = neural_network_transform_two(test_data)

    numpy.testing.assert_array_equal(transformed_data_one, transformed_data_two)

    # Reusing one rng for two consecutive transforms must yield different
    # functions, i.e. each construction advances the generator.
    rng = numpy.random.default_rng(42)
    neural_network_transform_one = neural_network_transform(100, 1, rng)
    neural_network_transform_two = neural_network_transform(100, 1, rng)

    with pytest.raises(AssertionError):
        numpy.testing.assert_array_equal(
            neural_network_transform_one(test_data),
            neural_network_transform_two(test_data),
        )


def test_sequential_transform_function_end_to_end():
    """Check that ``sequential_mediator`` output is determined by its rng.

    Runs with identically seeded generators must produce identical output;
    two consecutive runs sharing one generator must produce different output.
    """
    # Seeded so the input is reproducible across test runs.
    X = numpy.random.default_rng(0).uniform(-10, 10, (1000, 1))

    def run_mediator(rng):
        # Run the mediator with the fixed configuration under test. A fresh
        # transformation_args dict is built per call so that no state can
        # leak between runs if the callee happens to mutate it.
        return sequential_mediator(
            X,
            depth=10,
            rng=rng,
            transformation_generator=neural_network_transform,
            transformation_args={
                "num_hidden": 10,
                "num_parents": 1,
            },
            noise_type="normal",
        )

    # Same seed -> same output.
    Y_one = run_mediator(numpy.random.default_rng(42))
    Y_two = run_mediator(numpy.random.default_rng(42))
    numpy.testing.assert_array_equal(Y_one, Y_two)

    # Shared generator -> the second run sees an advanced state.
    rng = numpy.random.default_rng(42)
    Y_one = run_mediator(rng)
    Y_two = run_mediator(rng)

    with pytest.raises(AssertionError):
        numpy.testing.assert_array_equal(Y_one, Y_two)


def test_data_from_dict():
    """A seeded config reproduces its data; an unseeded one is expected not to."""
    seeded_config = {
        "X": {"type": "normal", "length": 100},
        "transformation": {
            "type": "neural_network",
            "args": {"num_hidden": 10, "num_parents": 1},
        },
        "shape": "sequence",
        "depth": 10,
        "seed": 42,
        "noise_type": "normal",
    }

    first, second = data_from_dict(seeded_config), data_from_dict(seeded_config)

    assert first.shape == (100, 2)
    numpy.testing.assert_array_equal(first, second)

    # Identical configuration without the "seed" entry: two calls should
    # produce different data.
    unseeded_config = {
        key: value for key, value in seeded_config.items() if key != "seed"
    }
    first = data_from_dict(unseeded_config)
    second = data_from_dict(unseeded_config)

    with pytest.raises(AssertionError):
        numpy.testing.assert_array_equal(first, second)


def test_data_from_dict_uses_depth():
    """Changing ``depth`` in the config must change the generated data."""
    params = {
        "X": {
            "type": "normal",
            "length": 100,
        },
        "transformation": {
            "type": "neural_network",
            "args": {
                "num_hidden": 10,
                "num_parents": 1,
            },
        },
        "shape": "sequence",
        "depth": 10,
        "seed": 42,
        "noise_type": "normal",
    }

    data_depth_ten = data_from_dict(params)

    params["depth"] = 100

    data_depth_hundred = data_from_dict(params)

    # Only "depth" differs between the two calls (the seed is fixed), so any
    # difference in the output is attributable to the depth parameter.
    with pytest.raises(AssertionError):
        numpy.testing.assert_array_equal(data_depth_ten, data_depth_hundred)
    

def test_noises_have_same_variance():
    """Every supported noise type should have approximately unit variance."""
    rng = numpy.random.default_rng(42)

    # One generator is shared across all noise types, so the samples drawn for
    # each type depend on the loop order.
    # NOTE(review): with 1000 samples the sampling error of the variance
    # estimate is roughly the size of abs_tol (about 0.045 for normal noise,
    # more for heavier tails), so this check relies on the fixed seed.
    for noise_type in ["uniform", "laplace", "normal", "normal_mixture"]:
        variance = numpy.var(generate_noise(rng, noise_type, 1000))
        assert math.isclose(variance, 1.0, abs_tol=0.05), (
            f"{noise_type} noise has variance {variance:.4f}, expected ~1.0"
        )