import os

from typing import List
from huggingface_hub import snapshot_download

# Maps the short model names accepted by this script to the Hugging Face Hub
# repository ids that are handed to `snapshot_download`.
MODELINFO = {
    # Instruction-tuned 8B Llama 3 checkpoint.
    "llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
    # Base 7B Llama 2 checkpoint (transformers-format "-hf" repo).
    "llama-2": "meta-llama/Llama-2-7b-hf",
    # 1.3B Sheared-LLaMA checkpoint from princeton-nlp.
    "sheared-llama": "princeton-nlp/Sheared-LLaMA-1.3B",
    # Base 7B Mistral checkpoint, v0.1.
    "mistral-7b": "mistralai/Mistral-7B-v0.1",
}

def downloadllms(models: List[str], dest_root_path: str) -> None:
    """
    Download the requested LLMs from huggingface.co, one folder per model.

    Each model is fetched with ``snapshot_download`` using
    ``<dest_root_path>/<model>`` as the Hugging Face Hub *cache* directory
    (``cache_dir``), so files are stored in the Hub's cache layout inside that
    folder. A model whose folder already contains anything is skipped.

    Args:
        models (List[str]): Short model names to download. Each must be a key
            of ``MODELINFO`` (currently "llama-3", "llama-2", "sheared-llama"
            and "mistral-7b").
        dest_root_path (str): Root directory under which one sub-folder per
            model is created.

    Raises:
        KeyError: If any name in ``models`` is not a key of ``MODELINFO``.
            All names are checked before any folder is created or any
            download starts.
    """
    # Validate every name first so a typo fails immediately, rather than after
    # earlier downloads have run and an empty folder was left behind.
    unknown = [model for model in models if model not in MODELINFO]
    if unknown:
        raise KeyError(
            f"Unknown model name(s) {unknown}; expected one of {sorted(MODELINFO)}."
        )

    for model in models:
        model_dir = os.path.join(dest_root_path, model)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(model_dir, exist_ok=True)

        if os.listdir(model_dir):
            # NOTE(review): "non-empty folder" is only a heuristic for
            # "already downloaded". An interrupted download presumably leaves
            # partial files here too and would then be skipped; delete the
            # folder to force a fresh attempt.
            print(f"The model {model} already exists.")
            continue

        snapshot_download(repo_id=MODELINFO[model], cache_dir=model_dir)
        print(f"Successfully download {model}!")


if __name__ == "__main__":
    # Script entry point: fetch every supported checkpoint into ./cache/<name>/.
    targets = ["llama-3", "llama-2", "sheared-llama", "mistral-7b"]
    root = "./cache"
    downloadllms(targets, root)