import numpy
from meshflow import platform
from meshflow.unifyshard.combination import aligned_prefix, shape_aligned_otherdim


def test_aligned_prefix():
    """Exercise ``aligned_prefix`` on 1-D and 2-D tensor pairs.

    Each scenario wraps two literal arrays with ``platform.from_numpy`` and
    then checks the value ``aligned_prefix`` returns for one or more
    ``dim_idx`` arguments.  The expected values look like a count of leading
    matching slices along ``dim_idx`` -- that reading is inferred from these
    cases only; confirm against the implementation.
    """
    # Each entry: (left values, right values, [(dim_idx, expected), ...]).
    scenarios = [
        # Identical 1-D inputs: expected result is the full length.
        ([1, 2, 3, 4], [1, 2, 3, 4], [(0, 4)]),
        # Inputs that already differ at the first element: expected result 0.
        ([1, 2, 3, 4], [2, 2, 3, 4], [(0, 0)]),
        # 2-D inputs differing only in the last element of the second row;
        # the same tensor pair is queried along both dimensions.
        (
            [[1, 2, 3, 4], [1, 2, 3, 4]],
            [[1, 2, 3, 4], [1, 2, 3, 5]],
            [(0, 1), (1, 3)],
        ),
    ]

    for left_values, right_values, expectations in scenarios:
        left = platform.from_numpy(numpy.array(left_values))
        right = platform.from_numpy(numpy.array(right_values))
        for dim, expected in expectations:
            assert expected == aligned_prefix(left, right, dim_idx=dim)


def test_aligned_otherdim():
    """Exercise ``shape_aligned_otherdim`` on pairs of shape tuples.

    Every case passes two shapes and a dimension index and compares the
    result with the expected boolean using ``==``.
    """
    # Each entry: (first shape, second shape, dimension index, expected).
    cases = [
        # Shapes differ only at index 1.
        ((10, 11, 12), (10, 13, 12), 1, True),
        ((10, 11, 12), (10, 13, 12), 2, False),
        # Shapes differ at both index 1 and index 2.
        ((10, 11, 12), (10, 13, 13), 1, False),
        ((10, 11, 12), (10, 13, 13), 2, False),
        # Shapes of different lengths (3 vs 4 entries).
        ((10, 11, 12), (10, 11, 12, 13), 2, False),
        ((10, 11, 12), (10, 11, 12, 13), 3, False),
    ]

    for first, second, dim, expected in cases:
        assert shape_aligned_otherdim(first, second, dim) == expected
