import pytest
import jax
import jax.numpy as jnp
from diffuse.images import SquareMask
import matplotlib.pyplot as plt


@pytest.fixture
def random_image():
    """Provide a deterministic (28, 28, 1) array of standard-normal noise.

    The PRNG key is fixed (seed 0), so every test receives the same image.
    """
    image_shape = (28, 28, 1)
    rng = jax.random.PRNGKey(0)
    image = jax.random.normal(rng, image_shape)
    return image


@pytest.fixture
def square_mask():
    """Provide the SquareMask instance shared by the tests in this module.

    The second argument matches the shape of the ``random_image`` fixture.
    """
    # NOTE(review): presumably the square's side length -- confirm against
    # the SquareMask constructor.
    mask_size = 10
    image_shape = (28, 28, 1)
    return SquareMask(mask_size, image_shape)


def test_measure_and_restore(random_image, square_mask):
    """Restoring an image from its own measurement must give back the image.

    A mask position ``xi`` is drawn uniformly in [0, 28) per coordinate, the
    image is measured at that position, and the measurement is fed back into
    ``restore`` together with the same image.
    """
    xi = jax.random.uniform(
        jax.random.PRNGKey(1), (2,), minval=0, maxval=28
    )

    observation = square_mask.measure(xi, random_image)
    round_trip = square_mask.restore(xi, random_image, observation)

    # Exact equality: the round trip is expected to be lossless.
    assert jnp.array_equal(random_image, round_trip)


def test_restore_with_zero_measured(random_image, square_mask):
    """Measuring a restored image must reproduce the measurement it was given.

    NOTE(review): despite the test name, ``measured`` is not zero -- it is
    Gaussian noise, and it is drawn with the same PRNG key as ``xi``.
    Confirm which of the two was intended.
    """
    rng = jax.random.PRNGKey(1)
    xi = jax.random.uniform(rng, (2,), minval=0, maxval=28)

    measured = jax.random.normal(rng, random_image.shape)
    restored = square_mask.restore(xi, random_image, measured)

    # Compare both sides through the same mask position.
    seen_in_restored = square_mask.measure(xi, restored)
    seen_in_measured = square_mask.measure(xi, measured)
    assert jnp.array_equal(seen_in_restored, seen_in_measured)


def test_measure_restore(random_image, square_mask):
    """Measuring ``restore(xi, x, y)`` must equal measuring ``y`` directly.

    ``x`` and ``y`` are independent Gaussian arrays with the shape of
    ``random_image``; only that shape is taken from the fixture.

    The test does no plotting: a ``plt.show()`` call inside an automated
    test can block the run on interactive backends and leaves a figure open.
    """
    key = jax.random.PRNGKey(1)
    key_x, key_y = jax.random.split(key)
    # NOTE(review): xi is drawn from the parent key after it has been split;
    # JAX recommends against reusing a key that was split. A third subkey
    # would be cleaner, but the draw is kept so the test inputs stay fixed.
    xi = jax.random.uniform(key, (2,), minval=0, maxval=28)
    x = jax.random.normal(key_x, random_image.shape)
    y = jax.random.normal(key_y, random_image.shape)

    restored = square_mask.restore(xi, x, y)
    measured = square_mask.measure(xi, y)
    measured_restored = square_mask.measure(xi, restored)

    assert jnp.array_equal(measured, measured_restored)


def test_mask_shape(square_mask):
    """The mask built at position (14, 14) must have shape (28, 28, 1)."""
    expected_shape = (28, 28, 1)
    centre = jnp.array([14.0, 14.0])
    assert square_mask.make(centre).shape == expected_shape


def test_mask_values(square_mask):
    """Every entry of the mask built at (14, 14) must lie in [0, 1]."""
    mask = square_mask.make(jnp.array([14.0, 14.0]))
    # Lower and upper bounds are checked separately so a failure names the
    # violated bound.
    assert jnp.all(mask >= 0)
    assert jnp.all(mask <= 1)


def test_measure_preserves_shape(random_image, square_mask):
    """``measure`` must return an array with the same shape as its input."""
    position = jnp.array([14.0, 14.0])
    observed_shape = square_mask.measure(position, random_image).shape
    assert observed_shape == random_image.shape


def test_restore_preserves_shape(random_image, square_mask):
    """``restore`` must return an array with the same shape as the image."""
    position = jnp.array([14.0, 14.0])
    observation = square_mask.measure(position, random_image)
    restored_shape = square_mask.restore(
        position, random_image, observation
    ).shape
    assert restored_shape == random_image.shape
