import shutil
import sys
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM


# Default model identifiers to download. Each entry is passed unchanged to
# `from_pretrained` by `download_model`. When extra command-line arguments are
# given, the `__main__` block replaces this list with them. Commented-out
# entries are inactive and kept only for reference.
model_names = [
    # "facebook/opt-350m",
    # "meta-llama/Llama-2-7b-hf",
    # "meta-llama/Llama-2-13b-hf",
    # "tiiuae/falcon-7b",
    # "mosaicml/mpt-7b",
    # "TheBloke/gpt4-alpaca-lora-13B-HF",
    # "meta-llama/Llama-2-70b-hf",
    # "lmsys/vicuna-7b-v1.5",
    # "lmsys/vicuna-13b-v1.5",
    "WizardLM/WizardLM-13B-V1.2",
    "WizardLM/WizardMath-7B-V1.0",
    "WizardLM/WizardMath-13B-V1.0",
]


def create_better_name(model_name):
    """Derive a filesystem-friendly directory name from a model identifier.

    The last ``/``-separated component of *model_name* is lowercased, a
    trailing ``-hf`` suffix is dropped, and every ``-`` and ``.`` is replaced
    with ``_``.

    Trailing slashes are ignored, so ``"org/name/"`` yields the same result as
    ``"org/name"`` rather than an empty string.

    >>> create_better_name("meta-llama/Llama-2-7b-hf")
    'llama_2_7b'
    >>> create_better_name("lmsys/vicuna-7b-v1.5/")
    'vicuna_7b_v1_5'
    """
    # Strip trailing slashes first: without this, "org/name/" would split to
    # an empty last component and the caller would end up targeting its base
    # directory instead of a per-model subdirectory.
    better_name = model_name.rstrip("/").split("/")[-1].lower()
    if better_name.endswith("-hf"):
        better_name = better_name[:-3]
    return better_name.replace("-", "_").replace(".", "_")
        

def download_model(model_name):
    """Download a model and its tokenizer and save them under the base path.

    The target directory is ``base_model_path / create_better_name(model_name)``.
    If that directory already exists, nothing is downloaded and the function
    returns immediately.

    The files are first written to a sibling ``<name>.partial`` directory and
    renamed to the final location only after both the model and the tokenizer
    have been saved. An interrupted run therefore never leaves a
    half-written directory that later runs would mistake for a finished
    download and skip.

    Note: this reads the module-level ``base_model_path``, which this file
    assigns only inside its ``__main__`` block.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
    """
    print(f"Downloading {model_name}")

    better_name = create_better_name(model_name)
    model_path = base_model_path / better_name

    if model_path.exists():
        print(f"{model_path.as_posix()} exists, skipping")
        return

    print(f"Saving {model_name} to {model_path.as_posix()}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Staging directory next to the final one. The ".partial" suffix cannot
    # collide with a real model directory because create_better_name()
    # replaces every "." with "_".
    staging_path = model_path.parent / (model_path.name + ".partial")
    if staging_path.exists():
        # Leftover from a previous run that was interrupted mid-save.
        shutil.rmtree(staging_path)

    model.save_pretrained(staging_path.as_posix())
    tokenizer.save_pretrained(staging_path.as_posix())

    # Publish only once everything has been written successfully.
    staging_path.rename(model_path)
    print("Saved")


if __name__ == "__main__":
    # Command line: BASE_MODEL_PATH [MODEL_NAME ...]
    # Exit with a usage message instead of an IndexError traceback when the
    # required base path argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} BASE_MODEL_PATH [MODEL_NAME ...]")

    # Assigned at module level on purpose: download_model() reads this global.
    base_model_path = Path(sys.argv[1])

    # Any further arguments replace the default model list.
    if len(sys.argv) > 2:
        model_names = sys.argv[2:]

    for model_name in model_names:
        download_model(model_name)