from transformers import (
    OPTForCausalLM, GPTJForCausalLM, AutoTokenizer, LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM, 
    Qwen3ForCausalLM,OlmoeForCausalLM,Qwen2MoeForCausalLM,PhimoeForCausalLM,Qwen3MoeForCausalLM,Gemma3ForCausalLM,
    LlamaTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, MixtralForCausalLM,OlmoForCausalLM,Olmo2ForCausalLM,
    AutoConfig,GenerationConfig
)

import torch
import os
import os.path as osp
import rich
import pandas as pd
import itertools

# Per-GPU memory budget (GPU index -> budget string) handed to
# `from_pretrained(..., max_memory=...)` by `get_model`. GPUs 0-3 each get
# the same 36846 MiB allowance.
mm_dict = {gpu_index: "36846MiB" for gpu_index in range(4)}

def get_model(device_num, model_name):

    device_list = device_num
    for i in mm_dict.keys():
        if i not in device_list:
            mm_dict[i] = "0MiB"


    main_path = "/models/"
    
    if model_name == "llama-2-7b":
        model_path = main_path + "Llama-2-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-2-7b-chat":
        model_path = main_path + "Llama-2-7b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-2-13b":
        model_path = main_path + "Llama-2-13b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-2-13b-chat":
        model_path = main_path + "Llama-2-13b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-2-70b":
        model_path = main_path + "Llama-2-70b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-2-70b-chat":
        model_path = main_path + "Llama-2-70b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-3-8b":
        model_path = main_path + "Llama-3-8b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-3-8b-instruct":
        model_path = main_path + "Llama-3-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-3-70b":
        model_path = main_path + "Llama-3-70b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-3-70b-instruct":
        model_path = main_path + "Llama-3-70b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "llama-moe-v1-3_5b-2_8-sft":
        model_path = main_path + "llama-moe-v1-3_5B-2_8-sft"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "llama-moe-v2-3_8b-2_8-sft":
        model_path = main_path + "llama-moe-v2-3_8B-2_8-sft"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = MixtralForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "mistral-7b-v0.3-instruct":
        model_path = main_path + "Mistral-7b-v0.3-instruct"
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
        model = MistralForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "mistral-7b-v0.3":
        model_path = main_path + "Mistral-7b-v0.3"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = MistralForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "mixtral-8x7b-instruct":
        model_path = main_path + "Mixtral-8x7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = MixtralForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "mixtral-8x7b":
        model_path = main_path + "Mixtral-8x7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = MixtralForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "opt-1.3b":
        model_path = main_path + "opt-1.3b"
        model = OPTForCausalLM.from_pretrained(
            model_path, device_map='auto', max_memory=mm_dict)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    elif model_name == "opt-6.7b":
        model_path = main_path + "opt-6.7b"
        model = OPTForCausalLM.from_pretrained(
            model_path, device_map='auto', torch_dtype=torch.float16, max_memory=mm_dict,offload_folder='offload', offload_state_dict = True)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    elif model_name == "opt-13b":
        model_path = main_path + "opt-13b"
        model = OPTForCausalLM.from_pretrained(
            model_path, device_map='auto', torch_dtype=torch.float16, max_memory=mm_dict,offload_folder='offload', offload_state_dict = True)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    elif model_name == "internlm2-chat-1.8b":
        model_path = main_path + "internlm2-chat-1_8b"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "internlm2-7b":
        model_path = main_path + "internlm2-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "internlm2-chat-7b":
        model_path = main_path + "internlm2-chat-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "internlm2-chat-20b":
        model_path = main_path + "internlm2-chat-20b"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict, trust_remote_code=True)
    elif model_name == "qwen1.5-0.5b-chat":
        model_path = main_path + "Qwen1.5-0.5b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-1.8b-chat":
        model_path = main_path + "Qwen1.5-1.8b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-4b-chat":
        model_path = main_path + "Qwen1.5-4b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-7b-chat":
        model_path = main_path + "Qwen1.5-7b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-14b-chat":
        model_path = main_path + "Qwen1.5-14b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-32b-chat":
        model_path = main_path + "Qwen1.5-32b-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-moe-a2.7b":
        model_path = main_path + "Qwen1.5-MoE-A2.7B"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2MoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen1.5-moe-a2.7b-chat":
        model_path = main_path + "Qwen1.5-MoE-A2.7B-chat"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2MoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2-7b":
        model_path = main_path + "Qwen2-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2-0.5b-instruct":
        model_path = main_path + "Qwen2-0.5b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2-1.5b-instruct":
        model_path = main_path + "Qwen2-1.5b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2-7b-instruct":
        model_path = main_path + "Qwen2-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2-72b-instruct":
        model_path = main_path + "Qwen2-72b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-0.5b-instruct":
        model_path = main_path + "Qwen2.5-0.5b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-1.5b-instruct":
        model_path = main_path + "Qwen2.5-1.5b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-3b-instruct":
        model_path = main_path + "Qwen2.5-3b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-7b-instruct":
        model_path = main_path + "Qwen2.5-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-14b-instruct":
        model_path = main_path + "Qwen2.5-14b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-32b-instruct":
        model_path = main_path + "Qwen2.5-32b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen2.5-72b-instruct":
        model_path = main_path + "Qwen2.5-72b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen2ForCausalLM.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-0.6b-instruct":
        model_path = main_path + "Qwen3-0.6b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-1.7b-instruct":
        model_path = main_path + "Qwen3-1.7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-4b":
        model_path = main_path + "Qwen3-4b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-4b-instruct":
        model_path = main_path + "Qwen3-4b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-8b":
        model_path = main_path + "Qwen3-8b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-8b-instruct":
        model_path = main_path + "Qwen3-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-14b":
        model_path = main_path + "Qwen3-14b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-14b-instruct":
        model_path = main_path + "Qwen3-14b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-32b-instruct":
        model_path = main_path + "Qwen3-32b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-30b-a3b":
        model_path = main_path + "Qwen3-30b-a3b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3MoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "qwen3-30b-a3b-instruct":
        model_path = main_path + "Qwen3-30b-a3b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Qwen3MoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "gemma-3-1b":
        model_path = main_path + "Gemma-3-1b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Gemma3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "gemma-3-4b":
        model_path = main_path + "Gemma-3-4b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Gemma3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "gemma-3-12b":
        model_path = main_path + "Gemma-3-12b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Gemma3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "gemma-3-27b":
        model_path = main_path + "Gemma-3-27b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = Gemma3ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "olmoe-7b":
        model_path = main_path + "OLMoE-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = OlmoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "olmoe-7b-instruct":
        model_path = main_path + "OLMoE-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = OlmoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "olmo-1b":
        model_path = main_path + "OLMo-1b"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = OlmoForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-7b":
        model_path = main_path + "OLMo-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = OlmoForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-7b-instruct":
        model_path = main_path + "OLMo-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = OlmoForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-2-1b-instruct":
        model_path = main_path + "OLMo-2-1b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = Olmo2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-2-7b-instruct":
        model_path = main_path + "OLMo-2-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = Olmo2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-2-13b-instruct":
        model_path = main_path + "OLMo-2-13b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = Olmo2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "olmo-2-32b-instruct":
        model_path = main_path + "OLMo-2-32b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
        model = Olmo2ForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
    elif model_name == "phi-3.5-moe-instruct":
        model_path = main_path + "Phi-3.5-moe-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = PhimoeForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict)
    elif model_name == "deepseek-v2-lite":
        model_path = main_path + "DeepSeek-v2-lite"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
        model.generation_config = GenerationConfig.from_pretrained(model_name)
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    elif model_name == "deepseek-v2-lite-chat":
        model_path = main_path + "DeepSeek-v2-lite"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', max_memory=mm_dict,trust_remote_code=True)
        model.generation_config = GenerationConfig.from_pretrained(model_name)
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    elif model_name == "gpt-j-6b":
        model_path = main_path + "gpt-j-6b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = GPTJForCausalLM.from_pretrained(
            model_path, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True,
            pad_token_id=tokenizer.eos_token_id, device_map='auto', max_memory=mm_dict)
    return model, tokenizer
