import numpy
import pytest

from meshflow import platform
from meshflow.unifyshard.combination import CombinationFunc, ReduceOp
from meshflow.utils.testing import ALL_PLATFORM, setup_testing


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_reduce(backend):
    """Check ``CombinationFunc.reduce`` with the MAX, MIN and SUM reduce ops.

    Four shards ``[0, 0, 0]``, ``[1, 1, 1]``, ``[2, 2, 2]`` and ``[3, 3, 3]``
    are passed to ``CombinationFunc.reduce`` once per op. Each result is
    compared against a hand-computed expected tensor using ``platform.allclose``.
    """
    setup_testing(backend)

    shards = [platform.from_numpy(numpy.array([value] * 3)) for value in range(4)]

    # (reduce op, expected value in every element) for shard values 0..3:
    # max(0..3) == 3, min(0..3) == 0, sum(0..3) == 6.
    expectations = (
        (ReduceOp.MAX, 3),
        (ReduceOp.MIN, 0),
        (ReduceOp.SUM, 6),
    )

    for reduce_op, expected_value in expectations:
        expected = platform.from_numpy(numpy.array([expected_value] * 3))
        combined = CombinationFunc.reduce(shards, ops=reduce_op)

        assert platform.allclose(expected, combined)
