import torch
import transformers
import tokenizers
import os, sys
from dataclasses import dataclass, field
import argparse
from PIL import Image

def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas filled with ``background_color``.

    The canvas side is ``max(width, height)``. The image is pasted at the
    left/top edge along its longer axis and offset by ``(long - short) // 2``
    along its shorter axis, so it sits (floor-)centered. An image that is
    already square is returned as-is (same object, no copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img

    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Shift only along the shorter dimension.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas

def find_all_linear_names(model):
    """Return the distinct short names of ``torch.nn.Linear`` submodules.

    Every submodule from ``model.named_modules()`` is considered. A module is
    skipped when its qualified name contains ``'mm_projector'``,
    ``'vision_tower'`` or ``'vision_resampler'``. For the rest, only the last
    dotted component is kept (e.g. ``'layers.0.q_proj'`` -> ``'q_proj'``).
    ``'lm_head'`` is removed from the result. Per the original variable naming,
    the list is meant for use as LoRA target module names.

    Returns:
        list[str] of unique names, in set-iteration order.
    """
    skipped = ('mm_projector', 'vision_tower', 'vision_resampler')
    found = {
        qualified.split('.')[-1]
        for qualified, mod in model.named_modules()
        if not any(tag in qualified for tag in skipped)
        and isinstance(mod, torch.nn.Linear)
    }
    found.discard('lm_head')
    return list(found)

def add_special_tokens_and_resize_model(tokenizer, model, vq_idx_nums=8192):
    """Register VQ index tokens on ``tokenizer`` and grow ``model``'s embeddings.

    If ``tokenizer.additional_special_tokens`` is already non-empty, nothing is
    changed and that list is returned as-is.

    Otherwise the tokens ``'<start_index>'``, ``'<idx_0>'`` ...
    ``'<idx_{vq_idx_nums-1}>'``, ``'<end_index>'`` and ``'<pixel_newline>'``
    are added as additional special tokens. ``model.resize_token_embeddings``
    is called with ``len(tokenizer)``. Each newly added row of the input and
    output embedding matrices is set to the mean of the rows that existed
    before, with both means computed before either matrix is written.

    Returns:
        The existing special-token list on the early-exit path. Otherwise the
        count returned by ``tokenizer.add_special_tokens``.

    NOTE(review): the early exit triggers on *any* pre-existing additional
    special tokens, not only the ones added here. A tokenizer that ships its
    own special tokens would never receive the VQ tokens. Confirm the
    tokenizers in use start with none.
    NOTE(review): the two paths return different types (list vs int).
    NOTE(review): assumes ``model.get_output_embeddings()`` is not None.
    """
    existing = tokenizer.additional_special_tokens
    if len(existing) > 0:
        return existing

    new_tokens = [
        '<start_index>',
        *(f"<idx_{i}>" for i in range(vq_idx_nums)),
        '<end_index>',
        '<pixel_newline>',
    ]
    added = tokenizer.add_special_tokens({'additional_special_tokens': new_tokens})
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return added

    # Gather both matrices and their old-row means first, then write, so tied
    # input/output weights see the same pre-update statistics.
    tables = [
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    ]
    old_means = [table[:-added].mean(dim=0, keepdim=True) for table in tables]
    for table, mean_row in zip(tables, old_means):
        table[-added:] = mean_row

    return added

def _make_vision_args(select_layer):
    """Build a vision-argument namespace; only ``mm_vision_select_layer`` varies.

    Keyword order is kept fixed so the resulting ``Namespace`` attributes (and
    its repr) come out in the same order for every caller.
    """
    return argparse.Namespace(
        freeze_backbone=False,
        mm_patch_merge_type='flat',
        mm_projector_type='mlp2x_gelu',
        mm_use_im_patch_token=False,
        mm_use_im_start_end=False,
        mm_vision_select_feature='patch',
        mm_vision_select_layer=select_layer,
        model_name_or_path=None,
        pretrain_mm_mlp_adapter=None,
        tune_mm_mlp_adapter=False,
        version=None,
        vision_tower=None,
    )


# The two vision configs differ only in the selected layer index. The
# "com"/"gen" prefixes presumably mean comprehension vs generation; confirm
# with the consumers of these namespaces.
com_vision_args = _make_vision_args(-2)
gen_vision_args = _make_vision_args(1)

def load_weights(model, hlora_path, fusion_layer_path=None, map_location='cpu'):
    """Load the ``hlora`` checkpoint, and optionally a fusion-layer one, into ``model``.

    Each checkpoint is read with ``torch.load`` and applied with
    ``model.load_state_dict(..., strict=False)``. Keys missing from a
    checkpoint are tolerated. Keys the model does not have are reported with
    a printed warning. Note that in PyTorch, ``strict=False`` does not
    suppress tensor size-mismatch errors.

    Args:
        model: ``torch.nn.Module`` to load into; modified in place.
        hlora_path: path to the hlora state-dict file.
        fusion_layer_path: optional path to a fusion-layer state-dict file;
            skipped when falsy.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can be read on any host.
            ``load_state_dict`` then copies values onto whatever device and
            dtype the model's tensors already have.

    Returns:
        The same ``model`` object.
    """
    hlora_weights = torch.load(hlora_path, map_location=map_location)
    hlora_unexpected_keys = model.load_state_dict(hlora_weights, strict=False).unexpected_keys
    if hlora_unexpected_keys:
        print(f"Warning: Unexpected keys in hlora checkpoint: {hlora_unexpected_keys}")

    if fusion_layer_path:
        fusion_layer_weights = torch.load(fusion_layer_path, map_location=map_location)
        fusion_layer_unexpected_keys = model.load_state_dict(
            fusion_layer_weights, strict=False
        ).unexpected_keys
        if fusion_layer_unexpected_keys:
            print(f"Warning: Unexpected keys in fusion_layer checkpoint: {fusion_layer_unexpected_keys}")

    return model

import torch
import torch.nn as nn

def recursive_getattr(obj, attr_path):
    """Walk all but the last component of a dotted path.

    Despite the name, this does NOT return the final attribute. It returns
    ``(owner, final_name)``, where ``owner`` is reached by ``getattr`` over
    every component except the last. This suits a later
    ``getattr(owner, final_name)`` or ``setattr(owner, final_name, value)``.

    Example: ``recursive_getattr(m, "a.b.c")`` -> ``(m.a.b, "c")``.
    """
    *parents, leaf = attr_path.split('.')
    owner = obj
    for component in parents:
        owner = getattr(owner, component)
    return owner, leaf

def reshape_hyper_parameter(model, key, target_shape):
    """Replace the parameter addressed by ``key`` with a new one of ``target_shape``.

    Only keys ending in ``.weight`` or ``.bias`` are handled. Any other key is
    a no-op. The owning submodule is located from the dotted prefix of
    ``key``, and the named attribute is replaced by a fresh
    ``torch.nn.Parameter`` drawn from N(0, 0.02). The new parameter keeps the
    old one's dtype, device and ``requires_grad``. If the old attribute is
    ``None`` (e.g. a layer without bias), torch defaults are used and
    ``requires_grad=True``.

    Module metadata such as ``nn.Linear.in_features`` / ``out_features`` is
    not updated.

    Args:
        model: root ``torch.nn.Module``.
        key: state-dict style key, e.g. ``"layers.0.hyper.weight"``.
        target_shape: shape of the replacement parameter.
    """
    if key.endswith(".weight"):
        param_type = "weight"
    elif key.endswith(".bias"):
        param_type = "bias"
    else:
        # Only weight/bias entries are resized here; other keys are ignored.
        return

    # "a.b.c.weight" -> "a.b.c". get_submodule walks the dotted path,
    # including numeric ModuleList/Sequential indices, and returns the module
    # that owns the parameter (the model itself for an empty path).
    module_path = key[: -(len(param_type) + 1)]
    module = model.get_submodule(module_path)
    old_param = getattr(module, param_type)

    if old_param is None:
        factory_kwargs = {}
        requires_grad = True
        old_shape = None
    else:
        # Match the existing tensor's placement and precision. Otherwise a
        # model on GPU / in half precision would end up with one stray
        # float32 CPU parameter, since load_state_dict copies into the
        # existing tensor rather than replacing it.
        factory_kwargs = {"dtype": old_param.dtype, "device": old_param.device}
        requires_grad = old_param.requires_grad
        old_shape = tuple(old_param.shape)

    print(f"Reshaping {key}: {old_shape} -> {tuple(target_shape)}")

    new_param = torch.nn.Parameter(
        torch.empty(target_shape, **factory_kwargs), requires_grad=requires_grad
    )
    torch.nn.init.normal_(new_param, mean=0.0, std=0.02)
    # Module.__setattr__ re-registers the Parameter under the same name.
    setattr(module, param_type, new_param)


def load_weights_hyper(model, hlora_path, fusion_layer_path=None):
    """Load an ``hlora`` checkpoint whose hyper-network tensors may have new shapes.

    Steps:
      1. Load ``hlora_path`` onto the CPU.
      2. For each checkpoint key the model also has, where the shapes differ
         and the key contains ``"hyper"`` (case-insensitive), call
         ``reshape_hyper_parameter`` so the model tensor takes the
         checkpoint's shape.
      3. Apply the checkpoint with ``load_state_dict(..., strict=False)`` and
         print any checkpoint keys absent from the model's state dict.
      4. If ``fusion_layer_path`` is truthy, load it the same way (no
         reshaping) and print its unexpected keys.

    Args:
        model: ``torch.nn.Module`` to load into; modified in place.
        hlora_path: path to the hlora state-dict file.
        fusion_layer_path: optional path to a fusion-layer state-dict file.

    Returns:
        The same ``model`` object.
    """
    checkpoint = torch.load(hlora_path, map_location='cpu')
    own_state = model.state_dict()

    for name, saved_tensor in checkpoint.items():
        if name not in own_state:
            continue
        if own_state[name].shape == saved_tensor.shape:
            continue
        if "hyper" in name.lower():
            reshape_hyper_parameter(model, name, saved_tensor.shape)
        # NOTE(review): a mismatch on a non-"hyper" key is left untouched.
        # PyTorch's load_state_dict raises on size mismatches even with
        # strict=False, so such a checkpoint fails on the next call.

    model.load_state_dict(checkpoint, strict=False)

    # Checkpoint keys with no counterpart in the model's state dict.
    leftover = set(checkpoint.keys()) - set(model.state_dict().keys())
    if leftover:
        print(f"Warning: Unexpected keys in hlora checkpoint: {leftover}")

    if fusion_layer_path:
        fusion_checkpoint = torch.load(fusion_layer_path, map_location='cpu')
        fusion_leftover = model.load_state_dict(fusion_checkpoint, strict=False)[1]
        if fusion_leftover:
            print(f"Warning: Unexpected keys in fusion_layer checkpoint: {fusion_leftover}")

    return model
