import functools

import torch

from meshflow.unifyshard.combination import CombinationFunc
from meshflow.unifyshard import ShardAnnotation, ShardDim, UnifyOp
from meshflow.utils.testing import assert_partial_func_equal, setup_testing


def test_unifyop_preset():
    """Check preset-annotation sharding discovery for ``aten.chunk``.

    A (3, 4, 768) tensor is split with ``aten.chunk`` into ``num_chunks``
    pieces along ``chunk_dim``. The preset annotation is applied to that
    input. Sharding discovery should return one combination function per
    output, and each one must equal ``CombinationFunc.gather`` along
    ``chunk_dim``.
    """
    setup_testing("torch")

    num_chunks = 3
    chunk_dim = 2

    # (positional args, kwargs) for torch.ops.aten.chunk(input, chunks, dim).
    input_args = (torch.rand((3, 4, 768)), num_chunks, chunk_dim), {}
    unify_op = UnifyOp(torch.ops.aten.chunk, input_args)

    # One ShardDim per input dimension. Dims 0 and 1 use ShardDim(0); the
    # chunked dim uses ShardDim(1, chunk=num_chunks).
    # NOTE(review): ShardDim(0) presumably means "not sharded" and
    # ShardDim(1, ...) "sharded" -- confirm against ShardDim's definition.
    preset_anno = ShardAnnotation(
        [[ShardDim(0), ShardDim(0), ShardDim(1, chunk=num_chunks)]]
    )
    comb_func = unify_op.sharding_discovery_with_preset(preset_anno)

    # Each chunk output should be recombined by gathering along the chunked dim.
    right_answer = [
        functools.partial(CombinationFunc.gather, dim=chunk_dim)
    ] * num_chunks

    # Identity check: `!= None` can be hijacked by an overloaded __ne__.
    assert comb_func is not None, "sharding discovery returned None"
    assert len(comb_func) == len(right_answer), (
        f"expected {len(right_answer)} combination funcs, got {len(comb_func)}"
    )

    for func1, func2 in zip(comb_func, right_answer):
        assert_partial_func_equal(func1, func2)
