import gc

from transformers import AutoTokenizer

from src.path import new_internvl_models_dir

from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel


# Build recipes for the "LargeV" variants, keyed by output model name.
# Each recipe names the checkpoint whose vision tower is reused ("vision_path")
# and the checkpoint whose language model and tokenizer are reused ("llm_path").
internvl_new_models_parameters = {
    "InternVL2-4B-LargeV": {"vision_path": "OpenGVLab/InternVL2-26B", "llm_path": "OpenGVLab/InternVL2-4B"},
    "InternVL2-8B-LargeV": {"vision_path": "OpenGVLab/InternVL2-26B", "llm_path": "OpenGVLab/InternVL2-8B"},
}


def _load_vision_donor(vision_path):
    """Load an InternVL checkpoint and keep only its config and vision tower.

    The full donor model is a local here. Once this returns, everything except
    ``vision_model`` is unreferenced and can be freed, so the donor's own
    language model does not stay resident.

    Returns:
        (config, vision_model) of the checkpoint at ``vision_path``.
    """
    print("Loading Vision model: ", vision_path)
    config: InternVLChatConfig = InternVLChatConfig.from_pretrained(vision_path)
    # NOTE(review): no torch_dtype is passed, so weights may be materialised in
    # float32 depending on the transformers version -- confirm intended dtype.
    donor = InternVLChatModel.from_pretrained(vision_path)
    return config, donor.vision_model


def _load_llm_donor(llm_path):
    """Load an InternVL checkpoint and keep only its config, LLM and tokenizer.

    Returns:
        (config, language_model, tokenizer) for the checkpoint at ``llm_path``.
    """
    print("Loading LLM: ", llm_path)
    config: InternVLChatConfig = InternVLChatConfig.from_pretrained(llm_path)
    donor = InternVLChatModel.from_pretrained(llm_path)
    tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True)
    return config, donor.language_model, tokenizer


def _merge_chat_config(vision_cfg, llm_cfg):
    """Build a chat config from vision-side settings of one config and LLM-side of another.

    The vision sub-config and all image-preprocessing fields come from
    ``vision_cfg``. The LLM sub-config, the conversation ``template`` and
    ``downsample_ratio`` come from ``llm_cfg``.
    """
    merged = InternVLChatConfig(
        vision_cfg.vision_config.to_dict(),
        llm_cfg.llm_config.to_dict(),
        # NOTE(review): taken from the LLM donor although every other
        # image-side field comes from the vision donor; kept as-is for
        # compatibility -- confirm this is intended.
        downsample_ratio=llm_cfg.downsample_ratio,
        pad2square=vision_cfg.pad2square,
        template=llm_cfg.template,  # this is for llm
        select_layer=vision_cfg.select_layer,
        dynamic_image_size=vision_cfg.dynamic_image_size,
        use_thumbnail=vision_cfg.use_thumbnail,
        ps_version=vision_cfg.ps_version,
        min_dynamic_patch=vision_cfg.min_dynamic_patch,
        max_dynamic_patch=vision_cfg.max_dynamic_patch,
    )
    merged.force_image_size = vision_cfg.force_image_size
    return merged


def build_model(model_name, vision_path, llm_path, output_dir=None):
    """Combine two checkpoints into a new InternVLChatModel and save it.

    Args:
        model_name: name of the output sub-directory.
        vision_path: checkpoint whose vision tower and image settings are reused.
        llm_path: checkpoint whose language model, template and tokenizer are reused.
        output_dir: parent directory for the result; defaults to
            ``new_internvl_models_dir``.

    Writes the model and tokenizer to ``output_dir / model_name``.
    """
    if output_dir is None:
        output_dir = new_internvl_models_dir
    print("Building new model: ", model_name)

    vision_cfg, vision_model = _load_vision_donor(vision_path)
    llm_cfg, language_model, tokenizer = _load_llm_donor(llm_path)

    print('Building InternVLChatConfig...')
    chat_config = _merge_chat_config(vision_cfg, llm_cfg)

    print('Building InternVLChatModel...')
    # NOTE(review): only the vision tower and the language model are passed in.
    # Any other weights (e.g. the vision-to-LLM projector) presumably start
    # freshly initialised and need training -- confirm.
    model = InternVLChatModel(chat_config, vision_model, language_model)

    target_dir = output_dir / model_name
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)


def main():
    """Build every variant listed in ``internvl_new_models_parameters``."""
    for model_name, params in internvl_new_models_parameters.items():
        build_model(model_name, params["vision_path"], params["llm_path"])
        # All weights of this variant are unreferenced now; collect before the
        # next (multi-GB) checkpoints are loaded to keep peak memory down.
        gc.collect()


if __name__ == "__main__":
    main()
