from dataclasses import dataclass
from argparse_dataclass import ArgumentParser

from quick_extend.models.load_model import ModelConfig, load_model


@dataclass
class MergeConfig(ModelConfig):
    """Command-line configuration for the merge script.

    Inherits every field of ``ModelConfig`` and adds the destination that
    the merged model is written to.
    """

    # Destination handed to ``model.save_pretrained`` in ``main``.
    # NOTE(review): annotated as ``str`` but defaults to ``None``; the value
    # is effectively required even though the parser does not enforce it.
    output_path: str = None


def parse_args():
    """Build a ``MergeConfig`` from the command line and echo it.

    Returns:
        The populated ``MergeConfig`` instance produced by the
        dataclass-driven argument parser.
    """
    merge_config = ArgumentParser(MergeConfig).parse_args()
    # Echo the resolved configuration so the run's settings are visible.
    print(merge_config)
    return merge_config


def main():
    """Load a model, merge it via ``merge_and_unload`` and save the result.

    Steps:
        1. Parse the command line into a ``MergeConfig``.
        2. Load the model with ``load_model(config, for_training=False)``.
        3. Call ``merge_and_unload()`` on the loaded model (presumably folds
           adapter weights into the base model -- confirm against the model
           type returned by ``load_model``).
        4. Save the result to ``config.output_path`` using safetensors.

    Raises:
        ValueError: If ``output_path`` was not provided.
    """
    config = parse_args()

    # ``output_path`` defaults to None. Check it before the model load and
    # merge so a missing value fails immediately instead of at save time.
    if not config.output_path:
        raise ValueError("output_path must be set")

    # The tokenizer returned alongside the model is not used by this script.
    model, _tokenizer = load_model(config, for_training=False)

    merged_model = model.merge_and_unload()

    merged_model.save_pretrained(config.output_path, use_safetensors=True)


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
