import torch
from transformers import AutoTokenizer
from .llava_llama import LlavaLlamaForCausalLM, LlavaConfig

def load_pretrained_model(
    model_name,
    cache_dir,
    device_map="auto",
    torch_dtype=torch.float16,
    image_grid_pinpoints=None
    ):
    """Load a LLaVA tokenizer, model and image processor from a checkpoint.

    Args:
        model_name: Checkpoint identifier forwarded unchanged to every
            ``from_pretrained`` call below.
        cache_dir: Cache directory forwarded to every ``from_pretrained``
            call below.
        device_map: Forwarded to the model's ``from_pretrained`` and to the
            vision tower's ``load_model``. Any value other than ``"auto"``
            additionally moves the vision tower to ``"cuda"``.
        torch_dtype: Fallback dtype, used only when the loaded config does
            not carry a ``torch_dtype`` of its own.
        image_grid_pinpoints: When equal to ``5``, the model config's
            ``image_grid_pinpoints`` is overridden with ``[[672, 672]]``.
            Every other value (including the default ``None``) leaves the
            config untouched.

    Returns:
        A ``(tokenizer, model, image_processor, context_len)`` tuple, where
        ``image_processor`` is the vision tower's ``image_processor`` and
        ``context_len`` is ``model.config.tokenizer_model_max_length``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, use_fast=False)
    llava_cfg = LlavaConfig.from_pretrained(
        model_name, cache_dir=cache_dir)

    # Prefer the dtype recorded in the checkpoint config. Test the value
    # rather than the attribute's presence: Hugging Face configs normally
    # define ``torch_dtype`` even when it is unset (None), which previously
    # meant the caller-supplied fallback was never used.
    config_dtype = getattr(llava_cfg, "torch_dtype", None)
    model_dtype = config_dtype if config_dtype is not None else torch_dtype

    model = LlavaLlamaForCausalLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        low_cpu_mem_usage=True,
        config=llava_cfg,
        device_map=device_map,
        torch_dtype=model_dtype,
    )

    # NOTE(review): 5 is an undocumented sentinel selecting a single 672x672
    # grid; confirm its meaning with the callers before generalizing.
    if image_grid_pinpoints == 5:
        model.config.image_grid_pinpoints = [[672, 672]]

    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    if device_map != "auto":
        # NOTE(review): the target device is hard-coded to "cuda" for every
        # non-"auto" device_map (even e.g. "cpu") -- confirm this is intended.
        vision_tower.to(device="cuda", dtype=model_dtype)
    image_processor = vision_tower.image_processor

    # The context length is read from the model config.
    context_len = model.config.tokenizer_model_max_length

    return tokenizer, model, image_processor, context_len
