# Copyright (c) 
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from legoscale.parallelisms.parallel_dims import ParallelDims
from legoscale.parallelisms.parallelize_llama import parallelize_llama
from legoscale.parallelisms.pipeline_llama import pipeline_llama


# Names exported by `from <this package> import *`: the two model-name lookup
# tables defined in this module plus the re-exported ParallelDims.
__all__ = ["models_parallelize_fns", "models_pipelining_fns", "ParallelDims"]

# Lookup table keyed by model name. Both llama entries point at the very same
# `parallelize_llama` object, so one implementation serves either name.
models_parallelize_fns = dict.fromkeys(("llama2", "llama3"), parallelize_llama)
# Lookup table keyed by model name. Both llama entries point at the very same
# `pipeline_llama` object, so one implementation serves either name.
models_pipelining_fns = dict.fromkeys(("llama2", "llama3"), pipeline_llama)
