import functools

from meshflow.unifyshard import ShardAnnotation, ShardDim
from meshflow.unifyshard.combination import CombinationFunc
from meshflow.unifyshard.view_propagation import view_propagation_preset
from meshflow.utils.testing import assert_partial_func_equal


def test_view_propagation_preset():
    """Check ``view_propagation_preset`` against a fixed set of view cases.

    Each case supplies a list of ``ShardDim`` entries (wrapped into a
    ``ShardAnnotation``), a source shape, and a target view shape. When a
    combination function is expected, it must compare equal (via
    ``assert_partial_func_equal``) to
    ``functools.partial(CombinationFunc.gather, dim=<gather_dim>)``.
    When no combination function is expected, the result must be ``None``.
    """

    def expect(shard_dims, src_shape, dst_shape, gather_dim):
        # The annotation is built from a nested list, as in the cases below
        # (presumably one inner list per tensor -- TODO confirm).
        annotation = ShardAnnotation([shard_dims])
        derived = view_propagation_preset(src_shape, dst_shape, annotation)
        if gather_dim is None:
            assert derived is None
            return
        expected = functools.partial(CombinationFunc.gather, dim=gather_dim)
        assert_partial_func_equal(derived, expected)

    # [10, 8] -> [5, 2, 8] with chunk=5 on the first entry: gather on dim 1.
    expect([ShardDim(1, chunk=5), ShardDim(0)], [10, 8], [5, 2, 8], 1)

    # Same [10, 8] -> [10, 2, 2, 2] view, different chunk values on the
    # second entry, leading to different expected gather dims.
    expect([ShardDim(0), ShardDim(1, chunk=2)], [10, 8], [10, 2, 2, 2], 2)
    expect([ShardDim(0), ShardDim(1, chunk=4)], [10, 8], [10, 2, 2, 2], 3)

    # Cases for which no combination function is expected.
    expect([ShardDim(1, chunk=3), ShardDim(0)], [10, 8], [5, 2, 8], None)
    expect([ShardDim(0), ShardDim(1, chunk=2)], [10, 8], [5, 2, 8], None)


# Allow running this file directly as a script. Under pytest the test is
# collected by name, so it must not also run as an import-time side effect
# (that would execute it twice and turn a failure into a collection error).
if __name__ == "__main__":
    test_view_propagation_preset()