import numpy
import pytest

from meshflow import platform
from meshflow.unifyshard.combination import CombinationFunc
from meshflow.utils.testing import ALL_PLATFORM, setup_testing


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_identity(backend):
    """Identity combination of value-equal shards yields that shared value.

    Each shard is built as a separate tensor object (instead of repeating a
    single reference via ``[t] * 4``), so the test checks that shards are
    compared by value rather than passing merely because they are the same
    object.
    """
    setup_testing(backend)
    values = numpy.array([1, 2, 3])
    # Independent copies so no two shards share the same underlying object.
    shard_tensor = [platform.from_numpy(values.copy()) for _ in range(4)]
    global_tensor = platform.from_numpy(values.copy())
    combination_tensor = CombinationFunc.identity(shard_tensor)

    # identity() returns None when shards disagree (see test_identity_3);
    # fail loudly here rather than inside allclose.
    assert combination_tensor is not None
    assert platform.allclose(global_tensor, combination_tensor)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_identity_2(backend):
    """The identity result must not match a tensor holding different values."""
    setup_testing(backend)
    replica = platform.from_numpy(numpy.array([1, 2, 3]))
    shards = [replica for _ in range(4)]
    mismatched = platform.from_numpy(numpy.array([1, 2, 4]))

    combined = CombinationFunc.identity(shards)
    assert not platform.allclose(mismatched, combined)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_identity_3(backend):
    """Identity combination yields None when the shards do not all agree."""
    setup_testing(backend)
    agreeing = platform.from_numpy(numpy.array([1, 2, 3]))
    outlier = platform.from_numpy(numpy.array([1, 2, 4]))
    # Three matching shards followed by one that differs in the last element.
    shards = [agreeing, agreeing, agreeing, outlier]

    assert CombinationFunc.identity(shards) is None
