from jax import numpy as jnp
from jax import random
import numpy as np

import image_jax

def test_crop_scalar():
    """Crop a 2x3 window starting at row 1, column 2 of a 3x5 image.

    Offsets and size are passed as JAX arrays rather than Python tuples.
    """
    # Row-major 3x5 image holding the values 1..15.
    image = jnp.arange(1, 16).reshape(3, 5)
    start = jnp.asarray([1, 2])
    extent = jnp.asarray([2, 3])

    cropped = image_jax.crop(image, start, extent)

    np.testing.assert_array_equal(
        cropped,
        jnp.asarray([[8, 9, 10], [13, 14, 15]]),
    )


def test_crop_channels():
    """Crop one full-width row out of a 3x3 image with two trailing channels.

    Offsets and size are passed as plain Python tuples.
    """
    # Channel 0 holds 1..9 and channel 1 holds 10..18, both row-major 3x3.
    first = jnp.arange(1, 10).reshape(3, 3)
    second = first + 9
    image = jnp.stack([first, second], axis=-1)

    cropped = image_jax.crop(image, (1, 0), (1, 3))

    # The middle row of each channel, re-stacked along the channel axis.
    want = jnp.stack(
        [jnp.asarray([[4, 5, 6]]), jnp.asarray([[13, 14, 15]])],
        axis=-1,
    )
    np.testing.assert_array_equal(cropped, want)


def test_pad_scalar():
    """Zero-pad a 2x3 image along two axes given in mixed sign order.

    Padding of 1 goes on each side of axis -1 (columns) and padding of 2 on
    each side of axis 0 (rows), giving a 6x5 result.
    """
    image = jnp.asarray([[1, 2, 3], [4, 5, 6]])

    padded = image_jax.pad(image, 0, (1, 2), (-1, 0))

    # Expected: a zero canvas with the original block placed 2 rows down
    # and 1 column across.
    want = np.zeros((6, 5), dtype=int)
    want[2:4, 1:4] = [[1, 2, 3], [4, 5, 6]]
    np.testing.assert_array_equal(padded, want)


def test_pad_batch():
    """Zero-pad only axes 1 and 2 of a batch of two 2x3 images by one element.

    Axis 0 (the batch axis) must be left untouched.
    """
    # Two stacked 2x3 images holding 1..6 and 7..12.
    images = jnp.arange(1, 13).reshape(2, 2, 3)

    padded = image_jax.pad(images, 0, (1, 1), (1, 2))

    # Expected: each image framed by a one-element border of zeros.
    want = np.zeros((2, 4, 5), dtype=int)
    want[:, 1:3, 1:4] = np.arange(1, 13).reshape(2, 2, 3)
    np.testing.assert_array_equal(padded, want)


def test_random_pad_crop():
    """Randomly pad-and-crop a batch; check shape, value provenance and padding.

    Every input value is positive, so any occurrence of ``fill`` (-1.0) in the
    output can only have come from the padding.
    """
    rng = random.PRNGKey(0)
    im_hw = jnp.asarray([
        [1, 1, 2, 2],
        [3, 3, 4, 4],
        [5, 5, 6, 6],
    ])
    im_b = 10 * (1 + jnp.arange(32))
    im_c = jnp.asarray([0.25, 0.5, 0.75])
    # Broadcasts to shape (32, 3, 4, 3): batch, rows, cols, channels.
    im = im_b[:, None, None, None] + im_hw[:, :, None] + im_c
    fill = -1.0
    padding = (1, 1)
    size = (2, 2)
    actual = image_jax.random_pad_crop(rng, im, fill, padding, size)

    actual = np.asarray(actual)
    assert actual.shape == (32, 2, 2, 3)
    # Every output value is either the fill value or copied from the input.
    allowed = np.append(np.asarray(im).ravel(), fill)
    assert np.isin(actual, allowed).all()
    # Across samples, at least one element should be padded.
    assert np.min(actual) == fill
    # Across samples, at least one element in each (row, col) position of the
    # 2x2 crop should be padded.
    min_per_pos = np.min(actual, axis=(0, 3))
    np.testing.assert_array_equal(min_per_pos, np.full_like(min_per_pos, fill))


def test_cond_flip_false():
    """With a False predicate, cond_flip returns the image unchanged."""
    original = jnp.arange(1, 7).reshape(2, 3)

    result = image_jax.cond_flip(False, original, axis=1)

    np.testing.assert_array_equal(result, original)


def test_cond_flip_true():
    """With a True predicate and axis=1, cond_flip reverses each row."""
    original = jnp.arange(1, 7).reshape(2, 3)

    result = image_jax.cond_flip(True, original, axis=1)

    np.testing.assert_array_equal(result, [[3, 2, 1], [6, 5, 4]])


def test_cond_flip_axis():
    """cond_flip honours its axis argument: axis=0 swaps the two rows."""
    original = jnp.arange(1, 7).reshape(2, 3)

    result = image_jax.cond_flip(True, original, axis=0)

    np.testing.assert_array_equal(result, [[4, 5, 6], [1, 2, 3]])


def test_random_flip():
    """Randomly flip each image of a 32-sample batch along axis 2.

    Sample ``i`` holds ``10 * (i + 1) + [[1, 2, 3], [4, 5, 6]]``, so the
    differences between neighbouring elements reveal which samples flipped.
    """
    rng = random.PRNGKey(0)
    im_hw = jnp.asarray([
        [1, 2, 3],
        [4, 5, 6],
    ])
    im_b = 10 * (1 + jnp.arange(32))
    # Broadcasts (32, 1, 1) + (2, 3) -> (32, 2, 3).
    im = im_b[:, None, None] + im_hw
    actual = np.asarray(image_jax.random_flip(rng, im, axis=2))
    assert actual.shape == (32, 2, 3)
    # Differences left-to-right are +1 (unflipped) or -1 (flipped); both occur.
    np.testing.assert_array_equal(np.unique(np.diff(actual, axis=2)), [-1, 1])
    # Differences top-to-bottom are all 3, whether flipped or not.
    np.testing.assert_array_equal(np.unique(np.diff(actual, axis=1)), [3])
    # Differences along batch are 10 +/- 2 (e.g. 11 to 23, 13 to 23, 13 to 21).
    np.testing.assert_array_equal(np.unique(np.diff(actual, axis=0)), [8, 10, 12])
