from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import transformers
import torch
from huggingface_hub import login
import sys
import os

# Model key -> checkpoint location (absolute path, path relative to the
# current working directory, or a Hugging Face Hub id).
# NOTE(review): several "./models..." entries have no "/" after "models"
# (e.g. "./modelsLlama-2-7b-hf/"); load_llm spells them the same way, so they
# may be real directory names — confirm before changing them.
_MODEL_PATHS = {
    # Llama-2 chat / base
    "llama-7b": "/data/models/huggingface/meta-llama/Llama-2-7b-chat-hf/",
    "llama-13b": "/data/models/huggingface/meta-llama/Llama-2-13b-chat-hf/",
    "llama-70b": "/data/models/huggingface/meta-llama/Llama-2-70b-chat-hf/",
    "llama-7b-base": "./modelsLlama-2-7b-hf/",
    "llama-13b-base": "/data/models/huggingface/meta-llama/Llama-2-13b-hf/",
    "llama-70b-base": "/data/models/huggingface/meta-llama/Llama-2-70b-hf/",
    # Vicuna
    "vicuna-7b": "/data/models/huggingface/lmsys/vicuna-7b-v1.5-16k",
    "vicuna-13b": "/data/models/huggingface/lmsys/vicuna-13b-v1.5-16k",
    # Falcon
    "falcon-7b": "./models/huggingface/tiiuae/falcon-7b",
    "falcon-40b": "/data/models/huggingface/tiiuae/falcon-40b",
    # Phi
    "phi-2": "./models/huggingface/microsoft/phi-2",
    # BLOOMZ
    "bloomz-560m": "./models/huggingface/bigscience/bloomz-560m",
    "bloomz-1b1": "./models/huggingface/bigscience/bloomz-1b1",
    "bloomz-1b7": "./models/huggingface/bigscience/bloomz-1b7",
    "bloomz-3b": "./models/huggingface/bigscience/bloomz-3b",
    "bloomz-7b1": "./models/huggingface/bigscience/bloomz-7b1",
    # OLMo
    "olmo-1b": "./models/huggingface/allenai/OLMo-1B",
    "olmo-7b": "./models/huggingface/allenai/OLMo-7B",
    # Mistral / Mixtral
    "mistral-7b": "./modelsMistral-7B-Instruct-v0.2",
    "mistral-8x7b": "/data/models/huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
    # Pythia
    "pythia-410m": "EleutherAI/pythia-410m",
    "pythia-1b": "./models/huggingface/EleutherAI/pythia-1b",
    "pythia-1.4b": "./models/huggingface/EleutherAI/pythia-1.4b",
    "pythia-2.8b": "./models/huggingface/EleutherAI/pythia-2.8b",
    "pythia-6.9b": "./models/huggingface/EleutherAI/pythia-6.9b",
    "pythia-12b": "./models/huggingface/EleutherAI/pythia-12b",
    # GPT-2 (Hub ids)
    "gpt2-medium": "gpt2-medium",
    "gpt2-large": "gpt2-large",
    "gpt2-xl": "gpt2-xl",
    # Qwen chat
    "qwen-7b": "/data/models/huggingface/qwen/Qwen-7B-Chat",
    "qwen-14b": "/data/models/huggingface/qwen/Qwen-14B-Chat",
    "qwen-72b": "/data/models/huggingface/qwen/Qwen-72B-Chat",
    # Original LLaMA (huggyllama)
    "huggyllama-7b": "./models/huggingface/huggyllama/llama-7b",
    "huggyllama-30b": "./modelsllama-30b",
    "huggyllama-65b": "./modelsllama-65b",
    # Gemma
    "gemma-2b": "./models/gemma-2b",
    "gemma-7b": "./models/gemma-7b",
    "gemma-2b-it": "./models/gemma-2b-it",
    "gemma-7b-it": "./models/gemma-7b-it",
}


def get_paths_from_string(llm_string):
    """Return the checkpoint location registered for *llm_string*.

    The result is either a filesystem path ("/data/..." absolute, or "./..."
    relative to the current working directory) or a Hugging Face Hub id such
    as "gpt2-xl".

    Raises:
        KeyError: if *llm_string* is not a registered model key.
    """
    return _MODEL_PATHS[llm_string]

# Ordered (substring, left_pad) rules. The FIRST substring contained in the
# model string decides the answer, so order matters: "huggy" must be checked
# before "llama" because "huggyllama-*" contains both.
_LEFT_PAD_RULES = (
    ("qwen", False),
    ("huggy", True),
    ("bloom", True),
    ("llama", False),
    ("vicuna", False),
    ("phi", False),
    ("mistral", True),
    ("falcon", False),
    ("pythia", False),  # NOTE(review): original author marked this as unverified
    ("olmo", False),
    ("gemma", True),
)


def get_left_pad(llm_string):
    """Return True if inputs for *llm_string* should be padded on the left.

    Matching is by substring against ``_LEFT_PAD_RULES``, first match wins.
    Strings that match no rule (e.g. "gpt2-xl") get right padding (False).
    """
    return next(
        (pad_left for needle, pad_left in _LEFT_PAD_RULES if needle in llm_string),
        False,
    )

# Ordered (substring, add_token) rules; the first substring contained in the
# model string decides the answer.
_ADD_TOKEN_RULES = (
    ("vicuna", True),
    ("huggy", True),
    ("llama", True),
    ("mistral", True),
    ("phi", False),
    ("falcon", False),
    ("gemma", True),
)


def get_add_token(llm_string):
    """Return the per-model ``add_token`` flag for *llm_string*.

    NOTE(review): the flag's meaning is defined by callers not visible here
    (presumably whether special tokens such as BOS should be added when
    tokenizing) — confirm.

    Matching is by substring against ``_ADD_TOKEN_RULES``, first match wins.
    Model strings that match no rule (e.g. bloomz, pythia, gpt2, qwen, olmo)
    return False instead of raising ``UnboundLocalError``. This mirrors the
    fallback used by ``get_left_pad``.
    """
    for needle, add_token in _ADD_TOKEN_RULES:
        if needle in llm_string:
            return add_token
    return False

def _load_spec(path, *, tokenizer=None, model=None, check_exists=False, announce=False):
    """Bundle everything load_llm needs to load one checkpoint.

    path: local directory or Hugging Face Hub id given to ``from_pretrained``.
    tokenizer / model: extra keyword arguments for the tokenizer / model loader.
    check_exists: verify ``path`` exists on disk first; exit with an error if not.
    announce: print "Downloading" before loading (kept from the Llama-2 branches).
    """
    return {
        "path": path,
        "tokenizer": dict(tokenizer or {}),
        "model": dict(model or {}),
        "check_exists": check_exists,
        "announce": announce,
    }


def _fp16_model_kwargs(**extra):
    """Model kwargs for a half-precision, low-CPU-memory, ``device_map="auto"`` load."""
    return dict(device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16, **extra)


# Per-model loading recipes. Paths and kwargs are exactly those load_llm has
# always used.
# NOTE(review): some paths here differ from get_paths_from_string
# (llama-13b-base, falcon-7b, olmo-1b/7b casing, pythia-12b) — confirm which
# location is authoritative.
_LLM_LOAD_SPECS = {
    # Llama-2 chat / base: existence-checked, fp16.
    "llama-7b": _load_spec("/data/models/huggingface/meta-llama/Llama-2-7b-chat-hf/",
                           model=_fp16_model_kwargs(), check_exists=True, announce=True),
    "llama-7b-base": _load_spec("./modelsLlama-2-7b-hf/",
                                model=_fp16_model_kwargs(), check_exists=True, announce=True),
    "llama-13b": _load_spec("/data/models/huggingface/meta-llama/Llama-2-13b-chat-hf/",
                            model=_fp16_model_kwargs(), check_exists=True, announce=True),
    "llama-13b-base": _load_spec("./modelsLlama-2-13b-hf/",
                                 model=_fp16_model_kwargs(), check_exists=True, announce=True),
    "llama-70b": _load_spec("/data/models/huggingface/meta-llama/Llama-2-70b-chat-hf/",
                            model=_fp16_model_kwargs(), check_exists=True, announce=True),
    "llama-70b-base": _load_spec("/data/models/huggingface/meta-llama/Llama-2-70b-hf/",
                                 model=_fp16_model_kwargs(), check_exists=True, announce=True),
    # Gemma: existence-checked; only the 7b variants are loaded in fp16.
    "gemma-2b": _load_spec("./models/gemma-2b",
                           model={"device_map": "auto"}, check_exists=True),
    "gemma-2b-it": _load_spec("./models/gemma-2b-it",
                              model={"device_map": "auto"}, check_exists=True),
    "gemma-7b": _load_spec("./models/gemma-7b",
                           model=_fp16_model_kwargs(), check_exists=True),
    "gemma-7b-it": _load_spec("./models/gemma-7b-it",
                              model=_fp16_model_kwargs(), check_exists=True),
    # Vicuna: existence-checked, fp16.
    "vicuna-7b": _load_spec("/data/models/huggingface/lmsys/vicuna-7b-v1.5-16k",
                            model=_fp16_model_kwargs(), check_exists=True),
    "vicuna-13b": _load_spec("/data/models/huggingface/lmsys/vicuna-13b-v1.5-16k",
                             model=_fp16_model_kwargs(), check_exists=True),
    # Falcon: existence-checked, remote code.
    "falcon-7b": _load_spec("/data/models/huggingface/tiiuae/falcon-7b",
                            tokenizer={"force_download": False},
                            model={"device_map": "auto", "trust_remote_code": True,
                                   "force_download": False},
                            check_exists=True),
    "falcon-40b": _load_spec("/data/models/huggingface/tiiuae/falcon-40b",
                             tokenizer={"trust_remote_code": True},
                             model=_fp16_model_kwargs(trust_remote_code=True),
                             check_exists=True),
    # Phi
    "phi-2": _load_spec("./models/huggingface/microsoft/phi-2",
                        tokenizer={"trust_remote_code": True},
                        model={"device_map": "auto", "trust_remote_code": True}),
    # BLOOMZ
    "bloomz-560m": _load_spec("./models/huggingface/bigscience/bloomz-560m",
                              model={"device_map": "auto"}),
    "bloomz-1b1": _load_spec("./models/huggingface/bigscience/bloomz-1b1",
                             model={"device_map": "auto"}),
    "bloomz-1b7": _load_spec("./models/huggingface/bigscience/bloomz-1b7",
                             model={"device_map": "auto"}),
    "bloomz-3b": _load_spec("./models/huggingface/bigscience/bloomz-3b",
                            model={"device_map": "auto"}),
    "bloomz-7b1": _load_spec("./models/huggingface/bigscience/bloomz-7b1",
                             model={"device_map": "auto"}),
    # OLMo: remote code.
    "olmo-1b": _load_spec("./models/huggingface/allenai/olmo-1b",
                          tokenizer={"trust_remote_code": True},
                          model={"trust_remote_code": True, "device_map": "auto"}),
    "olmo-7b": _load_spec("./models/huggingface/allenai/olmo-7b",
                          tokenizer={"trust_remote_code": True},
                          model={"trust_remote_code": True, "device_map": "auto"}),
    # Mistral / Mixtral: fp16.
    "mistral-7b": _load_spec("./modelsMistral-7B-Instruct-v0.2",
                             model=_fp16_model_kwargs()),
    "mistral-8x7b": _load_spec("/data/models/huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
                               model=_fp16_model_kwargs()),
    # Pythia (410m and 12b are Hub ids, the rest local copies).
    "pythia-410m": _load_spec("EleutherAI/pythia-410m",
                              model={"device_map": "auto"}),
    "pythia-1b": _load_spec("./models/huggingface/EleutherAI/pythia-1b",
                            model={"device_map": "auto"}),
    "pythia-1.4b": _load_spec("./models/huggingface/EleutherAI/pythia-1.4b",
                              model={"device_map": "auto"}),
    "pythia-2.8b": _load_spec("./models/huggingface/EleutherAI/pythia-2.8b",
                              model={"device_map": "auto"}),
    "pythia-6.9b": _load_spec("./models/huggingface/EleutherAI/pythia-6.9b",
                              model={"device_map": "auto"}),
    "pythia-12b": _load_spec("EleutherAI/pythia-12b",
                             model={"device_map": "auto"}),
    # GPT-2 (Hub ids)
    "gpt2-medium": _load_spec("gpt2-medium", model={"device_map": "auto"}),
    "gpt2-large": _load_spec("gpt2-large", model={"device_map": "auto"}),
    "gpt2-xl": _load_spec("gpt2-xl", model={"device_map": "auto"}),
    # Qwen chat: remote code, explicit pad token.
    "qwen-7b": _load_spec("/data/models/huggingface/qwen/Qwen-7B-Chat",
                          tokenizer={"trust_remote_code": True, "pad_token": "<|endoftext|>"},
                          model={"device_map": "auto", "trust_remote_code": True}),
    "qwen-14b": _load_spec("/data/models/huggingface/qwen/Qwen-14B-Chat",
                           tokenizer={"trust_remote_code": True, "pad_token": "<|endoftext|>"},
                           model={"device_map": "auto", "trust_remote_code": True}),
    "qwen-72b": _load_spec("/data/models/huggingface/qwen/Qwen-72B-Chat",
                           tokenizer={"trust_remote_code": True, "pad_token": "<|endoftext|>"},
                           model=_fp16_model_kwargs(trust_remote_code=True)),
    # Original LLaMA (huggyllama)
    "huggyllama-7b": _load_spec("./models/huggingface/huggyllama/llama-7b",
                                model={"device_map": "auto"}),
    "huggyllama-30b": _load_spec("./modelsllama-30b", model=_fp16_model_kwargs()),
    "huggyllama-65b": _load_spec("./modelsllama-65b", model=_fp16_model_kwargs()),
}


def load_llm(llm_string):
    """Load the causal-LM model and tokenizer registered under *llm_string*.

    Args:
        llm_string: a key of ``_LLM_LOAD_SPECS`` (e.g. "llama-7b", "pythia-1b").

    Returns:
        ``(model, tokenizer)`` from ``AutoModelForCausalLM.from_pretrained`` and
        ``AutoTokenizer.from_pretrained``.

    Exits the process with status 1 (``sys.exit(1)``) when *llm_string* is
    unknown, or when an existence-checked local checkpoint directory is
    missing. Any other loading error propagates from ``from_pretrained``.
    """
    spec = _LLM_LOAD_SPECS.get(llm_string)
    if spec is None:
        print("Invalid LLM string: ", llm_string)
        # Non-zero status: a bare sys.exit() would report success to the shell.
        sys.exit(1)

    path = spec["path"]
    if spec["check_exists"] and not os.path.exists(path):
        print(f"No model found at {path}")
        sys.exit(1)

    if spec["announce"]:
        # Message kept from the original code; weights are read from `path`.
        print("Downloading")

    tokenizer = AutoTokenizer.from_pretrained(path, **spec["tokenizer"])
    model = AutoModelForCausalLM.from_pretrained(path, **spec["model"])
    return model, tokenizer

def main(argv=None):
    """Smoke test: load a model, print it, and generate a short completion.

    Usage: python <this file> [LLM_STRING]   (defaults to "llama-7b")

    Args:
        argv: argument list without the program name; defaults to sys.argv[1:].
    """
    argv = sys.argv[1:] if argv is None else argv
    # "llama" alone is not a registered key; default to a real one.
    llm_string = argv[0] if argv else "llama-7b"

    model, tokenizer = load_llm(llm_string)
    print(model)
    print(tokenizer)

    # try generation
    input_text = "What is the meaning of life?"
    # load_llm uses device_map="auto", so the weights may not be on the CPU;
    # move the prompt to the model's device before generating.
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
    print("Done")


if __name__ == "__main__":
    main()