import numpy
import pytest
import functools

from meshflow import platform
from meshflow.unifyshard.combination import CombinationFunc, ReduceOp, try_combination_single
from meshflow.utils.testing import ALL_PLATFORM, setup_testing, assert_partial_func_equal


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_reduce(backend):
    """Check that ``try_combination_single`` recovers each reduce op.

    Four (3, 4) shards are combined with MAX, MIN and SUM via
    ``CombinationFunc.reduce``. The function inferred from the shards and
    the resulting global tensor must match the one actually used.
    """
    setup_testing(backend)
    # Draw every shard independently. ``[t] * 4`` would alias a single tensor
    # four times; MAX and MIN of identical shards then give the same result
    # (the shard itself), so the inferred op would be ambiguous.
    shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(4)]

    for op_type in (ReduceOp.MAX, ReduceOp.MIN, ReduceOp.SUM):
        comb_func = functools.partial(CombinationFunc.reduce, ops=op_type)
        global_tensor = comb_func(shard_tensor)

        return_func = try_combination_single(shard_tensor, global_tensor)

        assert_partial_func_equal(comb_func, return_func)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather(backend):
    """Check that ``try_combination_single`` recovers a plain gather.

    Four (3, 4) shards are gathered along dim 0 and then dim 1 via
    ``CombinationFunc.gather``. The inferred function must match the one used.
    """
    setup_testing(backend)
    # Draw every shard independently. With ``[t] * 4`` all shards would be the
    # same object, so a gather with the wrong shard order or placement could
    # still reproduce the global tensor and go undetected.
    shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(4)]

    for dim_ in (0, 1):
        comb_func = functools.partial(CombinationFunc.gather, dim=dim_)
        global_tensor = comb_func(shard_tensor)

        return_func = try_combination_single(shard_tensor, global_tensor)

        assert_partial_func_equal(comb_func, return_func)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather_halo(backend):
    """Check that ``try_combination_single`` recovers a gather with a halo.

    Three (3, 4) shards are gathered with ``halowidth=1`` along dim 0 and
    ``halowidth=2`` along dim 1. The inferred function, including its halo
    width, must match the one used.
    """
    setup_testing(backend)
    # Draw every shard independently. Otherwise all shards alias one tensor,
    # and the halo regions of neighbouring shards carry identical values.
    shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(3)]

    for dim_, halo_ in zip((0, 1), (1, 2)):
        comb_func = functools.partial(CombinationFunc.gather, dim=dim_, halowidth=halo_)
        global_tensor = comb_func(shard_tensor)

        return_func = try_combination_single(shard_tensor, global_tensor)

        assert_partial_func_equal(comb_func, return_func)


@pytest.mark.parametrize("backend", ALL_PLATFORM)
def test_gather_chunk(backend):
    """Check that ``try_combination_single`` recovers a chunked gather.

    Three (3, 4) shards are gathered with ``chunk=3`` along dim 0 and
    ``chunk=2`` along dim 1. The inferred function, including its chunk
    count, must match the one used.
    """
    setup_testing(backend)
    # Draw every shard independently. This makes the order in which chunks
    # from different shards are interleaved visible in the global tensor.
    # With ``[t] * 3`` every shard would be the same object.
    shard_tensor = [platform.from_numpy(numpy.random.uniform(size=(3, 4))) for _ in range(3)]

    for dim_, chunk_ in zip((0, 1), (3, 2)):
        comb_func = functools.partial(CombinationFunc.gather, dim=dim_, chunk=chunk_)
        global_tensor = comb_func(shard_tensor)

        return_func = try_combination_single(shard_tensor, global_tensor)

        assert_partial_func_equal(comb_func, return_func)
