import numpy
import pytest

from meshflow import platform
from meshflow.unifyshard.combination import CombinationFunc
from meshflow.utils.testing import ALL_PLATFORM, setup_testing


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather(backend):
    """Gather four (3, 4) shards along dim 0 and along dim 1.

    Every shard holds distinct values, so the test checks more than the output
    shape. It also catches shards being placed out of order; identical
    all-ones shards would hide that. The expected result is the shards
    concatenated in list order, which is the same ordering assumption
    ``test_gather_chunk`` relies on.
    """
    setup_testing(backend)
    # Shard i holds 100 * i + [0 .. 11], so every shard's values are distinct.
    shard_arrays = [
        numpy.arange(12, dtype=numpy.float64).reshape(3, 4) + 100 * i
        for i in range(4)
    ]
    shard_tensor = [platform.from_numpy(arr) for arr in shard_arrays]

    # dim=0: shards stacked row-wise -> (12, 4).
    global_tensor_1 = platform.from_numpy(numpy.concatenate(shard_arrays, axis=0))
    gather_dim1 = CombinationFunc.gather(shard_tensor, dim=0)

    assert platform.allclose(global_tensor_1, gather_dim1)

    # dim=1: shards placed side by side -> (3, 16).
    global_tensor_2 = platform.from_numpy(numpy.concatenate(shard_arrays, axis=1))
    gather_dim2 = CombinationFunc.gather(shard_tensor, dim=1)

    assert platform.allclose(global_tensor_2, gather_dim2)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather_halo(backend):
    """Gather three identical 1-D shards with a positive and a negative halowidth.

    Expected results:
      * ``halowidth=1``  -> ``[1, 1, 2, 1, 2, 1, 1]`` (seven elements)
      * ``halowidth=-1`` -> ``[1, 1, 1, 1, 1]`` (five elements)
    """
    setup_testing(backend)
    # The same tensor object is repeated three times, as in the original setup.
    shards = [platform.from_numpy(numpy.array([1, 1, 1]))] * 3

    halo_cases = (
        (1, [1, 1, 2, 1, 2, 1, 1]),
        (-1, [1, 1, 1, 1, 1]),
    )
    for width, expected_values in halo_cases:
        expected = platform.from_numpy(numpy.array(expected_values))
        combined = CombinationFunc.gather(shards, dim=0, halowidth=width)

        assert platform.allclose(expected, combined)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather_chunk(backend):
    """Gather three ``[1, 2, 3]`` shards along dim 0 with ``chunk=3``.

    The expected result groups equal values together:
    ``[1, 1, 1, 2, 2, 2, 3, 3, 3]``.
    """
    setup_testing(backend)
    # Three references to one tensor, matching the original shard list.
    pieces = [platform.from_numpy(numpy.array([1, 2, 3]))] * 3
    # numpy.repeat([1, 2, 3], 3) == [1, 1, 1, 2, 2, 2, 3, 3, 3]
    expected = platform.from_numpy(numpy.repeat([1, 2, 3], 3))

    result = CombinationFunc.gather(pieces, dim=0, chunk=3)
    assert platform.allclose(expected, result)
