import argparse
import os

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration


def preprocess_text_encoder_tokenizer(args: argparse.Namespace) -> None:
    """Extract the language model and tokenizer from a LLaVA checkpoint.

    Loads the full ``LlavaForConditionalGeneration`` checkpoint in float16 and
    saves only its ``language_model`` submodule plus the processor's tokenizer
    into a single output directory.

    Args:
        args: Namespace with ``input_dir`` (source LLaVA checkpoint directory)
            and ``output_dir`` (destination directory for both artifacts).
    """
    processor = AutoProcessor.from_pretrained(args.input_dir)
    # The model is only re-serialized here, never executed, so it stays on the
    # CPU: moving it to a GPU would copy all float16 weights for no benefit
    # and would fail outright on hosts without CUDA.
    model = LlavaForConditionalGeneration.from_pretrained(
        args.input_dir,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    model.language_model.save_pretrained(args.output_dir)
    processor.tokenizer.save_pretrained(args.output_dir)


def resolve_output_dir(input_dir: str, output_dir: str = "") -> str:
    """Return the directory the extracted text encoder/tokenizer is saved to.

    Args:
        input_dir: Path of the source LLaVA checkpoint directory.
        output_dir: Explicit output directory; an empty string selects the
            default.

    Returns:
        ``output_dir`` unchanged when it is non-empty, otherwise the parent
        directory of ``input_dir``.
    """
    if output_dir:
        return output_dir
    # abspath strips any trailing separator and anchors bare relative names,
    # so "ckpts/llava/" maps to ".../ckpts" (not back into the input dir,
    # which would overwrite the source checkpoint) and "llava" maps to the
    # current working directory rather than an unusable "".
    return os.path.dirname(os.path.abspath(input_dir))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="The path to the llava-llama-3-8b-v1_1-transformers.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="The output path of the llava-llama-3-8b-text-encoder-tokenizer. "
        "If empty, the parent directory of --input_dir is used.",
    )
    args = parser.parse_args()

    args.output_dir = resolve_output_dir(args.input_dir, args.output_dir)

    preprocess_text_encoder_tokenizer(args)
