import fire

from src.modeling_args import ModelArgs
from src.utils import merge_lora_checkpoints


def run(
        ckpt_dir: str,
        world_size: int,
        config_file: str
):
    """Merge the LoRA checkpoints stored under ``ckpt_dir``.

    A ``ModelArgs`` is created with a fixed ``max_seq_len`` of 512,
    ``local_rank`` of -1 and the given ``world_size``. ``from_json`` is then
    applied with ``config_file``. The resulting ``world_size`` and
    ``n_layers`` are handed to ``merge_lora_checkpoints``.

    Args:
        ckpt_dir: Checkpoint directory forwarded to ``merge_lora_checkpoints``.
        world_size: Initial ``world_size`` given to ``ModelArgs``. The value
            actually used is read back from the args object after
            ``from_json``; whether the config can override it depends on
            ``ModelArgs.from_json`` (not visible here; TODO confirm).
        config_file: Config path handed to ``ModelArgs.from_json``.
    """
    # Build the base argument object first, then layer the config on top.
    base_args = ModelArgs(
        max_seq_len=512,
        local_rank=-1,
        world_size=world_size,
    )
    model_args = base_args.from_json(config_file)

    merge_lora_checkpoints(
        ckpt_dir=ckpt_dir,
        world_size=model_args.world_size,
        layers=model_args.n_layers,
    )


if __name__ == '__main__':
    # Expose `run` as a command-line entry point. Fire maps command-line
    # arguments onto its parameters (ckpt_dir, world_size, config_file).
    fire.Fire(component=run)
