import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel
# from internvl.model.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT
def split_model(num_layers, vit_alpha=0.5, world_size=None):
    """Build a device map that spreads the language-model layers over the GPUs.

    GPU 0 is also assigned the vision model, the projector (``mlp1``), the
    embeddings, the final norm, the output head, the rotary embedding and the
    last decoder layer, so it only receives a ``1 - vit_alpha`` share of a
    full GPU's layer quota.

    Args:
        num_layers: Number of language-model decoder layers to place.
        vit_alpha: Fraction of GPU 0 set aside for the non-layer modules.
        world_size: Number of GPUs to split across. When ``None`` it is read
            from ``torch.cuda.device_count()``.

    Returns:
        A dict mapping module names to GPU indices.

    Raises:
        ValueError: If there is no GPU to split across, or ``vit_alpha``
            leaves no capacity for the layers.
    """
    # Imported here because the module's import block does not provide `math`.
    import math

    if world_size is None:
        world_size = torch.cuda.device_count()
    if world_size < 1:
        raise ValueError(f"split_model needs at least one GPU, got world_size={world_size}")

    # GPU 0 counts as only (1 - vit_alpha) of a GPU when sharing out layers.
    capacity = world_size - vit_alpha
    if capacity <= 0:
        raise ValueError(
            f"vit_alpha={vit_alpha} leaves no capacity on {world_size} GPU(s)"
        )

    per_gpu = math.ceil(num_layers / capacity)
    quotas = [per_gpu] * world_size
    quotas[0] = math.ceil(per_gpu * (1 - vit_alpha))

    device_map = {}
    layer_idx = 0
    for gpu, quota in enumerate(quotas):
        for _ in range(quota):
            # The rounded-up quotas can add up to more than num_layers;
            # stop instead of emitting entries for layers that do not exist.
            if layer_idx >= num_layers:
                break
            device_map[f'language_model.model.layers.{layer_idx}'] = gpu
            layer_idx += 1

    # Everything that is not a decoder layer is placed on GPU 0.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    if num_layers > 0:
        # The last decoder layer is pinned to GPU 0 alongside the norm and head.
        device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    device_map['language_model.model.rotary_emb'] = 0

    return device_map


def _load_non_lora_trainables(model, lora_path):
    """Load ``non_lora_trainables.bin`` from ``lora_path`` into ``model`` if the file exists."""
    weights_file = os.path.join(lora_path, 'non_lora_trainables.bin')
    if not os.path.exists(weights_file):
        print(f'No non_lora_trainables.bin found in {lora_path}, skipping.')
        return
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(weights_file, map_location='cpu')
    # Strip the wrapper prefixes so the keys line up with the bare model.
    base_prefix = 'base_model.'
    state = {(k[len(base_prefix):] if k.startswith(base_prefix) else k): v for k, v in state.items()}
    if any(k.startswith('model.model.') for k in state):
        model_prefix = 'model.'
        state = {(k[len(model_prefix):] if k.startswith(model_prefix) else k): v for k, v in state.items()}
    model.load_state_dict(state, strict=False)


def _snapshot_parameters(model):
    """Return CPU copies of every non-meta parameter of ``model``, keyed by name."""
    snapshot = {}
    for name, param in model.named_parameters():
        if param.is_meta:
            print(f"Warning: {name} is a meta tensor, skipping.")
            continue
        # copy=True forces a real copy even when the parameter already lives on the CPU.
        snapshot[name] = param.detach().to(device="cpu", copy=True)
    return snapshot


def _compare_with_snapshot(snapshot, model):
    """Compare ``model`` to ``snapshot``; return ``(compared_any, changed)``."""
    compared_any = False
    for name, param in model.named_parameters():
        before = snapshot.get(name)
        if before is None or param.is_meta:
            # Not in the snapshot (e.g. skipped meta tensor): nothing to compare.
            continue
        compared_any = True
        # Compare on the CPU; parameters may be spread over several devices.
        if not torch.equal(before, param.detach().to("cpu")):
            return True, True
    return compared_any, False


def load_pretrained_model_both(model_path, model_base, prompt_tuning_adding, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    """Load an InternVL chat model, merge its LoRA adapter and optionally add prompt tuning.

    Steps performed:
      1. Load the tokenizer and ``InternVLChatModel`` from ``model_base``.
      2. Load ``non_lora_trainables.bin`` from ``<model_path>/llava-lora`` if present.
      3. Load the LoRA adapter from ``<model_path>/llava-lora`` and merge it.
      4. Check that merging changed at least one parameter.
      5. If ``prompt_tuning_adding`` is truthy, attach the prompt-tuning adapter
         from ``<model_path>/llava-prompt_tuning`` (or from ``model_path``
         itself when a matching adapter config file is found there).

    Args:
        model_path: Directory holding the adapter sub-directories.
        model_base: Path or id of the base checkpoint; also used for the tokenizer.
        prompt_tuning_adding: Whether to load prompt-tuning weights after the merge.
        load_8bit: Load the base model with ``load_in_8bit=True``.
        load_4bit: Load the base model with a 4-bit nf4 ``BitsAndBytesConfig``.
        device_map: ``"auto"`` uses the map built by ``split_model``; any other
            non-``None`` value is forwarded unchanged; ``None`` loads without a
            device map and moves an unquantized model to CUDA afterwards.
        device: When not ``"cuda"``, the whole model is mapped to this device.
        use_flash_attn: Request ``attn_implementation='flash_attention_2'``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        RuntimeError: If no compared parameter differs after merging the LoRA weights.
    """
    # Imported lazily so peft is only required when this loader is called.
    from peft import PeftModel

    load_kwargs = dict(kwargs)
    load_kwargs.setdefault('low_cpu_mem_usage', True)
    load_kwargs.setdefault('torch_dtype', torch.bfloat16)

    if load_8bit:
        load_kwargs['load_in_8bit'] = True
    elif load_4bit:
        # Only the config object is passed, not load_in_4bit as well.
        load_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    quantized = bool(load_8bit or load_4bit)

    if use_flash_attn:
        load_kwargs['attn_implementation'] = 'flash_attention_2'

    if device != "cuda":
        resolved_map = {"": device}
    elif device_map == "auto":
        config = InternVLChatConfig.from_pretrained(model_base)
        resolved_map = split_model(config.llm_config.num_hidden_layers)
    else:
        resolved_map = device_map
    if resolved_map is not None:
        load_kwargs['device_map'] = resolved_map

    tokenizer = AutoTokenizer.from_pretrained(model_base, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(model_base, **load_kwargs).eval()
    if resolved_map is None and not quantized:
        model = model.cuda()

    lora_path = os.path.join(model_path, "llava-lora")
    print('Loading additional LLaVA weights...')
    _load_non_lora_trainables(model, lora_path)

    # Keep a copy of the pre-LoRA weights for the sanity check below.
    original_parameters = _snapshot_parameters(model)

    print('Loading LoRA weights...')
    model = PeftModel.from_pretrained(model, lora_path)

    print('Merging LoRA weights...')
    model = model.merge_and_unload()
    print('Model is loaded...')
    print(f"prompt_tuning_adding is {prompt_tuning_adding}")

    # Compare the pre-LoRA snapshot with the merged model; identical weights
    # mean the adapter had no effect.
    compared_any, changed = _compare_with_snapshot(original_parameters, model)
    if not compared_any:
        print('Warning: no parameters could be compared, skipping the LoRA sanity check.')
    elif not changed:
        raise RuntimeError("withoutlora_model和model的所有参数完全相同，可能未正确加载LoRA或Prompt Tuning权重。")
    else:
        print("withoutlora_model和model的参数已成功加载，且不相同。")
    del original_parameters  # release the CPU copies

    if prompt_tuning_adding:
        print(f"prompt_tuning_adding is {prompt_tuning_adding}, loading prompt tuning weights...")
        prompt_tuning_path = os.path.join(model_path, "llava-prompt_tuning")
        if os.path.exists(prompt_tuning_path):
            print('Loading Prompt Tuning weights...')
            model = PeftModel.from_pretrained(model, prompt_tuning_path)
            print('Prompt Tuning weights loaded.')
        else:
            # Fallback: prompt-tuning files stored directly under model_path.
            # Deliberately best-effort: failures are reported, not raised.
            try:
                config_files = [f for f in os.listdir(model_path) if f.startswith("adapter_config") and "prompt" in f]
                if config_files:
                    print('Loading Prompt Tuning weights (auto-detect)...')
                    model = PeftModel.from_pretrained(model, model_path)
                    print('Prompt Tuning weights loaded.')
                else:
                    print(f'No prompt tuning weights found under {model_path}.')
            except Exception as e:
                print(f'No prompt tuning weights found or failed to load: {e}')

    return tokenizer, model