import torch
from transformers import AutoTokenizer
from .llava_llama import LlavaLlamaForCausalLM, LlavaConfig
# from .llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
from .constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

# Accepted string spellings for ``load_pretrained_model(torch_dtype=...)``.
_TORCH_DTYPES = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def _resolve_torch_dtype(torch_dtype):
    """Translate *torch_dtype* into a ``torch.dtype``, or ``None`` for "use the loader default"."""
    if torch_dtype is None:
        return None
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if isinstance(torch_dtype, str) and torch_dtype in _TORCH_DTYPES:
        return _TORCH_DTYPES[torch_dtype]
    # An unknown value used to be dropped silently (e.g. a typo such as
    # "fp16" loaded the model at the default precision). Keep that fallback,
    # but make it visible to the caller.
    import warnings
    warnings.warn(
        f"Unrecognized torch_dtype {torch_dtype!r}; expected one of "
        f"{sorted(_TORCH_DTYPES)} or a torch.dtype. Loading with the "
        "library default dtype.",
        stacklevel=3,
    )
    return None


def load_pretrained_model(
    model_name,
    cache_dir,
    device_map="auto",
    torch_dtype='float16',
    overwrite_config=None
    ):
    """Load a LLaVA-Llama model together with its tokenizer and image processor.

    Args:
        model_name: Name or path passed to ``from_pretrained`` for the
            tokenizer, the ``LlavaConfig`` and the model.
        cache_dir: Cache directory forwarded to every ``from_pretrained`` call.
        device_map: Forwarded to ``LlavaLlamaForCausalLM.from_pretrained`` and
            to the vision tower's ``load_model``. When it is anything other
            than ``"auto"``, the vision tower is additionally moved to CUDA.
        torch_dtype: ``"float16"``, ``"bfloat16"``, ``"float32"`` or a
            ``torch.dtype``. ``None`` or an unrecognized value (the latter
            with a warning) leaves the dtype to the loader's default.
        overwrite_config: Optional mapping; each ``key -> value`` pair is set
            as an attribute on the loaded ``LlavaConfig`` before the model is
            instantiated from it.

    Returns:
        Tuple ``(tokenizer, model, image_processor, context_len)``.
        ``image_processor`` is taken from the model's vision tower.
        ``context_len`` is ``model.config.tokenizer_model_max_length``, falling
        back to ``model.config.max_position_embeddings`` when the former is
        missing or ``None`` (and ``None`` if both are absent).
    """
    kwargs = {"device_map": device_map}
    dtype = _resolve_torch_dtype(torch_dtype)
    if dtype is not None:
        kwargs["torch_dtype"] = dtype

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, use_fast=False)

    llava_cfg = LlavaConfig.from_pretrained(model_name, cache_dir=cache_dir)
    if overwrite_config is not None:
        for key, value in overwrite_config.items():
            setattr(llava_cfg, key, value)

    model = LlavaLlamaForCausalLM.from_pretrained(
        model_name, cache_dir=cache_dir,
        low_cpu_mem_usage=True,
        config=llava_cfg, **kwargs)

    # Add the image special tokens enabled by the model config, then resize
    # the embedding matrix to the (possibly grown) tokenizer vocabulary.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    if device_map != "auto":
        # NOTE(review): this targets the default CUDA device whatever the
        # device_map value is (e.g. "cpu" or a per-module dict) — confirm.
        # Only cast when an explicit dtype was resolved; indexing
        # kwargs["torch_dtype"] unconditionally raised KeyError otherwise.
        if dtype is None:
            vision_tower.to(device="cuda")
        else:
            vision_tower.to(device="cuda", dtype=dtype)
    image_processor = vision_tower.image_processor

    context_len = getattr(model.config, "tokenizer_model_max_length", None)
    if context_len is None:
        context_len = getattr(model.config, "max_position_embeddings", None)

    if 'mistral' in model_name.lower():
        # Reuse EOS as the padding token. NOTE(review): the model class is
        # always LlavaLlamaForCausalLM here (the Mistral import is commented
        # out); this branch only adjusts padding.
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    return tokenizer, model, image_processor, context_len
