import functools

from meshflow.unifyshard.annotation import ShardAnnotation, ShardDim
from meshflow.unifyshard.combination import CombinationFunc

# Switch read by view_propagation for merged groups ([**, a1, a2, **] -> [**, A, **]).
# True: every input dim of the group gets its own shard id, gathered along the
# merged output dim with `chunk` = product of the group's input dims before it.
# False: only the leading input dim of the group receives a shard id.
EXTEND_VIEW = False

def get_next_non_one(shape_, idx_):
    """Return the first index ``>= idx_`` whose entry in ``shape_`` is not 1.

    If every entry from ``idx_`` onward equals 1, ``len(shape_)`` is returned;
    an ``idx_`` that is already at or past the end is returned unchanged.
    """
    end = len(shape_)
    cursor = idx_
    # Advance over size-1 entries, never stepping beyond the end of the shape.
    while cursor < end and shape_[cursor] == 1:
        cursor += 1
    return cursor


def _skip_unaligned_group(input_shape, input_idx, output_shape, output_idx):
    """Advance both cursors past a group that is neither a 1:1 match, a split, nor a merge.

    Starting at ``input_shape[input_idx]`` and ``output_shape[output_idx]``,
    dims are consumed from whichever side has the smaller running product until
    both running products are equal, i.e. up to the next boundary shared by the
    input and output shapes. Returns the next non-one indices after that
    boundary. If no shared boundary exists (products disagree, zero-size or
    negative dims), both cursors are moved to their ends so the caller stops.
    """
    in_len = len(input_shape)
    out_len = len(output_shape)
    in_acc = input_shape[input_idx]
    out_acc = output_shape[output_idx]
    input_idx += 1
    output_idx += 1
    while in_acc != out_acc:
        if in_acc < out_acc:
            if input_idx >= in_len:
                return in_len, out_len
            in_acc *= input_shape[input_idx]
            input_idx += 1
        else:
            if output_idx >= out_len:
                return in_len, out_len
            out_acc *= output_shape[output_idx]
            output_idx += 1
    return (get_next_non_one(input_shape, input_idx),
            get_next_non_one(output_shape, output_idx))


def view_propagation(input_shape, output_shape):
    """Derive sharding and combination annotations for viewing ``input_shape`` as ``output_shape``.

    Input and output dims are walked left to right (size-1 dims skipped) and
    matched group by group:

    * ``[**, A, **] -> [**, A, **]``: the input dim gets a shard id, gathered
      along the matching output dim.
    * ``[**, A, **] -> [**, a1, a2, **]`` (split): the input dim gets a shard
      id, gathered along the leftmost output dim of the split.
    * ``[**, a1, a2, **] -> [**, A, **]`` (merge): the leading input dim gets a
      shard id gathered along ``A``; when ``EXTEND_VIEW`` is set, every input
      dim of the group gets its own shard id and a ``chunk`` equal to the
      product of the group's input dims before it.

    Groups matching none of these (e.g. ``[2, 3, 4] -> [3, 8]``) are skipped up
    to the next boundary shared by input and output; their input dims keep
    ``ShardDim(0)``.

    Returns:
        dict with ``'sharding_ann'``, a ``ShardAnnotation`` holding one list
        with a ``ShardDim`` per input dim (``ShardDim(0)`` where no shard id
        was assigned), and ``'combination_ann'``, a dict mapping each assigned
        shard id (numbered from 1) to a ``functools.partial`` of
        ``CombinationFunc.gather``.
    """
    sharding_ann = ShardAnnotation([[ShardDim(0)] * len(input_shape)])
    combination_ann = {}

    input_idx = get_next_non_one(input_shape, 0)
    output_idx = get_next_non_one(output_shape, 0)

    shard_dim = 1

    # Both cursors are bounded: if the shape products disagree the output side
    # can run out first, and indexing past it would raise IndexError.
    while input_idx < len(input_shape) and output_idx < len(output_shape):
        in_size = input_shape[input_idx]
        out_size = output_shape[output_idx]
        if in_size == out_size:
            # [**, A, **] -> [**, A, **]
            sharding_ann[0][input_idx] = ShardDim(shard_dim)
            combination_ann[shard_dim] = functools.partial(CombinationFunc.gather, dim=output_idx)
            input_idx = get_next_non_one(input_shape, input_idx + 1)
            output_idx = get_next_non_one(output_shape, output_idx + 1)
            shard_dim += 1
        elif in_size > out_size:
            # [**, A, **] -> [**, a1, a2, **]
            leftmost_idx = output_idx
            accum_shape_ = out_size
            for o_idx in range(output_idx + 1, len(output_shape)):
                accum_shape_ *= output_shape[o_idx]
                if accum_shape_ == in_size:
                    sharding_ann[0][input_idx] = ShardDim(shard_dim)
                    combination_ann[shard_dim] = functools.partial(CombinationFunc.gather,
                                                                   dim=leftmost_idx)
                    output_idx = get_next_non_one(output_shape, o_idx + 1)
                    input_idx = get_next_non_one(input_shape, input_idx + 1)
                    shard_dim += 1
                    break
            else:
                # No run of output dims multiplies to A, so this is not a pure
                # split. Without advancing here the cursors would never move
                # and the outer loop would spin forever.
                input_idx, output_idx = _skip_unaligned_group(
                    input_shape, input_idx, output_shape, output_idx)
        else:
            # [**, a1, a2, **] -> [**, A, **]
            accum_shape_ = in_size
            for i_idx in range(input_idx + 1, len(input_shape)):
                accum_shape_ *= input_shape[i_idx]
                if accum_shape_ == out_size:
                    if EXTEND_VIEW:
                        # Each input dim of the group is gathered along A,
                        # offset by the product of the group's dims before it.
                        chunk_size_ = 1
                        for sub_idx in range(input_idx, i_idx + 1):
                            sharding_ann[0][sub_idx] = ShardDim(shard_dim)
                            combination_ann[shard_dim] = functools.partial(CombinationFunc.gather,
                                                                           dim=output_idx,
                                                                           chunk=chunk_size_)
                            chunk_size_ *= input_shape[sub_idx]
                            shard_dim += 1
                    else:
                        sharding_ann[0][input_idx] = ShardDim(shard_dim)
                        combination_ann[shard_dim] = functools.partial(CombinationFunc.gather,
                                                                       dim=output_idx)
                        shard_dim += 1

                    output_idx = get_next_non_one(output_shape, output_idx + 1)
                    input_idx = get_next_non_one(input_shape, i_idx + 1)
                    break
            else:
                # No run of input dims multiplies to A, so this is not a pure
                # merge; skip the group instead of looping forever.
                input_idx, output_idx = _skip_unaligned_group(
                    input_shape, input_idx, output_shape, output_idx)

    return {'sharding_ann': sharding_ann, 'combination_ann': combination_ann}


def view_propagation_preset(input_shape, output_shape, preset_anno):
    """Find the gather combination for a view whose input already carries a sharding.

    Only the first input dim whose ``shard_dim_id`` is non-zero is considered.
    The product of the input dims in front of it must equal the product of some
    prefix of ``output_shape``. Starting right after that prefix, the gather dim
    is the first position where the product of the output dims passed so far
    equals the sharded dim's ``chunk``. That position is then advanced past any
    size-1 output dims, matching ``view_propagation``, which never gathers along
    a size-1 dim.

    Args:
        input_shape: shape of the tensor being viewed.
        output_shape: target shape of the view.
        preset_anno: annotation whose row ``preset_anno[0]`` holds one entry per
            input dim, each exposing ``shard_dim_id`` and ``chunk``.

    Returns:
        ``functools.partial(CombinationFunc.gather, dim=...)``, or ``None`` when
        no input dim is sharded or the sharded dim cannot be mapped onto a
        single output dim.
    """
    # Product of the unsharded input dims preceding the first sharded one.
    accum_size = 1
    for idx, ann in enumerate(preset_anno[0]):
        if ann.shard_dim_id != 0:
            chunk = ann.chunk
            break
        accum_size *= input_shape[idx]
    else:
        # No sharded dim (or an empty row): nothing to propagate. Previously
        # this read `chunk` from an unsharded dim, or hit an unbound `idx`.
        return None

    out_len = len(output_shape)

    # Smallest output prefix whose product reaches accum_size; bounded so that
    # mismatched or zero-size shapes cannot index past the end.
    out_accum_size = 1
    out_idx = 0
    while out_accum_size < accum_size and out_idx < out_len:
        out_accum_size *= output_shape[out_idx]
        out_idx += 1

    if out_accum_size != accum_size:
        return None

    accum_chunk = 1
    for o_idx in range(out_idx, out_len):
        if chunk == accum_chunk:
            # Skip size-1 output dims so the gather lands on a real dim.
            gather_dim = get_next_non_one(output_shape, o_idx)
            if gather_dim >= out_len:
                return None
            return functools.partial(CombinationFunc.gather, dim=gather_dim)
        accum_chunk *= output_shape[o_idx]
    return None
