
import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda", use_flash_attn=False, **kwargs):
    """Load a tokenizer and causal-LM (optionally a LLaVA multimodal model).

    The loading strategy is chosen from substrings of ``model_name``
    (case-insensitive) and from whether ``model_base`` is given:

    * ``'llava'`` and ``'lora'`` in the name, with ``model_base``: load the base
      model, apply ``non_lora_trainables.bin``, then apply and merge the LoRA
      adapter found at ``model_path``.
    * ``'llava'`` in the name, with ``model_base``: load the base model and
      apply ``mm_projector.bin`` from ``model_path``.
    * ``'llava'`` in the name, no ``model_base``: load the full checkpoint
      from ``model_path`` (MPT / Mistral / LLaMA variant).
    * otherwise: load a plain language model, merging a PEFT adapter from
      ``model_path`` when ``model_base`` is given.

    Args:
        model_path: Local directory or Hugging Face Hub id of the checkpoint
            (or of the adapter / projector weights when ``model_base`` is set).
        model_base: Base model to load weights from, or ``None``.
        model_name: Name inspected for ``'llava'``, ``'lora'``, ``'mpt'`` and
            ``'mistral'`` to select the loading path.
        load_8bit: Request 8-bit loading (takes precedence over ``load_4bit``).
        load_4bit: Request 4-bit NF4 loading through ``BitsAndBytesConfig``.
        device_map: ``device_map`` forwarded to ``from_pretrained``.
        device: Target device. Any value other than ``"cuda"`` overrides
            ``device_map`` with ``{"": device}`` and is also used for the
            vision tower.
        use_flash_attn: When true, request ``attn_implementation='flash_attention_2'``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        A tuple ``(tokenizer, model, image_processor, context_len)``.
        ``image_processor`` is ``None`` unless ``'llava'`` is in ``model_name``;
        ``context_len`` is ``model.config.max_sequence_length`` when present,
        otherwise ``2048``.
    """
    from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
    from llava.model.language_model.llava_mpt import LlavaMptForCausalLM
    from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM

    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        # An explicit non-default device pins the whole model to that device.
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        # Only the config object is passed: supplying `load_in_4bit=True` as a
        # separate keyword together with `quantization_config` is rejected by
        # newer transformers releases, and the config alone is sufficient.
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    name_lc = model_name.lower()
    is_llava = 'llava' in name_lc

    if is_llava:
        # Load LLaVA model
        if 'lora' in name_lc and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in name_lc and model_base is not None:
            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)

            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)

            print('Loading additional LLaVA weights (e.g., mm_projector)...')
            non_lora_path = os.path.join(model_path, 'non_lora_trainables.bin')
            if os.path.exists(non_lora_path):
                non_lora_trainables = torch.load(non_lora_path, map_location='cpu', weights_only=True)
            else:
                # Not on local disk: treat `model_path` as a Hub repo id.
                from huggingface_hub import hf_hub_download
                cache_file = hf_hub_download(repo_id=model_path, filename='non_lora_trainables.bin')
                non_lora_trainables = torch.load(cache_file, map_location='cpu', weights_only=True)

            # Normalise checkpoint keys: drop a leading 'base_model.', then,
            # if any key is doubly nested ('model.model.'), drop one leading
            # 'model.' from every key that has it.
            base_prefix = 'base_model.'
            model_prefix = 'model.'
            non_lora_trainables = {
                (k[len(base_prefix):] if k.startswith(base_prefix) else k): v
                for k, v in non_lora_trainables.items()
            }
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {
                    (k[len(model_prefix):] if k.startswith(model_prefix) else k): v
                    for k, v in non_lora_trainables.items()
                }

            # Remove every '.base_layer' segment from the key names.
            non_lora_trainables = {k.replace(".base_layer", ""): v for k, v in non_lora_trainables.items()}

            # strict=False: the file only holds a subset of the model's parameters.
            model.load_state_dict(non_lora_trainables, strict=False)
            from peft import PeftModel

            modules_to_protect = [
                'mm_projector', 'depth_feature_projector', 'depth_visual_fusion_projector',
                'spatial_norm', 'spatial_stream_processor', 'final_fusion_merger',
                'fusion_projector', 'spatial_stream_projector', 'relational_encoder'
            ]

            # Snapshot the listed modules (those present as attributes of
            # `model`) so they can be restored after the adapter is merged.
            protected_weights = {}
            print('💾 Backing up weights for all non-LoRA multimodal modules...')
            for module_name in modules_to_protect:
                if hasattr(model, module_name):
                    module = getattr(model, module_name)
                    protected_weights[module_name] = {k: v.clone() for k, v in module.state_dict().items()}

            print('Loading LoRA weights and applying adapter...')
            model = PeftModel.from_pretrained(model, model_path)

            print('Merging LoRA weights...')
            model = model.merge_and_unload()

            if protected_weights:
                print('🔄 Unconditionally restoring weights for all protected modules...')
                with torch.no_grad():
                    for module_name, saved_state_dict in protected_weights.items():
                        if hasattr(model, module_name):
                            module = getattr(model, module_name)
                            module.load_state_dict(saved_state_dict)

            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in name_lc:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = LlavaConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            # weights_only=True matches the other checkpoint loads in this
            # function and avoids unpickling arbitrary objects.
            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu', weights_only=True)
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            cfg_pretrained = AutoConfig.from_pretrained(model_path)

            if getattr(cfg_pretrained, 'use_dual_stream_encoding', False):
                # Checkpoints flagged as dual-stream always use the LLaMA
                # variant, regardless of the model name.
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            elif 'mpt' in name_lc:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            elif 'mistral' in name_lc:
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
    else:
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)

            model = PeftModel.from_pretrained(model, model_path)

            model = model.merge_and_unload()

            model.to(torch.float16)
        else:
            if 'mpt' in name_lc:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if is_llava:
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        # The language model honours `device` whenever it is not "cuda" (see
        # the top of the function); place the vision tower the same way so both
        # end up on the same device. With the default `device` this is simply
        # `device_map`, as before.
        tower_target = device_map if device == "cuda" else device

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model(device_map=tower_target)
        if tower_target != 'auto':
            vision_tower.to(device=tower_target, dtype=kwargs.get('torch_dtype', torch.float16))
        if hasattr(model, 'lm_head') and 'bitsandbytes' in str(type(model.get_output_embeddings())):
            # The output embedding is a bitsandbytes layer: swap it for a plain
            # float16 Linear of the same shape on the same device.
            # NOTE(review): the replacement is freshly initialised -- the old
            # head's weights are not copied over -- and this runs after
            # `resize_token_embeddings` above. Confirm this is intended.
            old_lm_head = model.get_output_embeddings()
            in_features = old_lm_head.in_features
            out_features = old_lm_head.out_features
            lm_head_device = old_lm_head.weight.device
            new_lm_head = torch.nn.Linear(
                in_features,
                out_features,
                bias=False,
                device=lm_head_device,
                dtype=torch.float16,
            )

            model.set_output_embeddings(new_lm_head)
            print("Successfully replaced lm_head for safe resizing.")
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len