import torch
import torch.nn as nn
from peft import PeftModelForCausalLM
import time
import os
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from modules.init_embeddings import init_embeddings_normal, CodebookEmbedding

def _split_checkpoint_path(checkpoint_path):
    """Split *checkpoint_path* into ``(parent_dir, checkpoint_name)``.

    Trailing slashes are ignored so that ``"runs/best/"`` and ``"runs/best"``
    resolve to the same pair.

    Raises:
        ValueError: if no checkpoint name is left after stripping slashes.
    """
    parent_dir, checkpoint_name = os.path.split(checkpoint_path.rstrip("/"))
    if not checkpoint_name:
        raise ValueError(f"cannot derive a checkpoint name from {checkpoint_path!r}")
    return parent_dir, checkpoint_name


def merge_to_base_model(args, device_map={"": "cuda"}):
    """Merge a fine-tuned checkpoint into its base model and save the result.

    The merged model is written to ``<checkpoint parent>/merged/<checkpoint name>``.
    If that path already exists, nothing is loaded or merged and the path is
    returned immediately.

    Args:
        args: Attribute-style namespace. Attributes read here:
            ``checkpoint_path``, ``model_name_or_path``, ``use_slow_tokenizer``,
            ``torch_dtype``, ``vl_vocab_size``, ``expand_vocab``, ``use_lora``
            and, only when ``expand_vocab == "factorized"``,
            ``visual_codebook`` and ``factorized_linear_mlp``.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            The default dict is never mutated (it is copied before use).

    Returns:
        Path of the directory holding the merged model.

    Raises:
        ValueError: if ``args.checkpoint_path`` has no final path component.
    """
    # A trailing "/" used to produce an empty checkpoint name, which made
    # ``merged_path`` equal to the freshly created "merged" directory and
    # therefore always looked like an already-merged model.
    checkpoint_parent_path, checkpoint_name = _split_checkpoint_path(args.checkpoint_path)
    merged_checkpoint_parent_path = os.path.join(checkpoint_parent_path, "merged")
    os.makedirs(merged_checkpoint_parent_path, exist_ok=True)
    merged_path = os.path.join(merged_checkpoint_parent_path, checkpoint_name)
    if os.path.exists(merged_path):
        print(f"model {merged_path} already exists, skipping merge")
        return merged_path

    # Work on a private copy so the shared default dict can never be altered.
    if isinstance(device_map, dict):
        device_map = dict(device_map)

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        legacy=False,
        use_fast=not args.use_slow_tokenizer,
    )

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # NOTE(review): the tokenizer is only used for its pad/eos ids below; it is
    # not saved next to the merged model - confirm that this is intended.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
    tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=args.torch_dtype,
        use_flash_attention_2=False,
    )

    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.eos_token_id

    # Capture the embedding shape before resizing; only the hidden size
    # (index 1) is used afterwards.
    input_embedding_size = model.get_input_embeddings().weight.shape
    # expand the language vocab to vision-language vocab
    model.resize_token_embeddings(args.vl_vocab_size)

    is_factorized = args.expand_vocab == "factorized"
    if is_factorized:
        # Only the tensor's shape is needed here, so load it onto the CPU
        # regardless of the device it was saved from.
        # NOTE: torch.load unpickles the file - only point this at trusted files.
        codebook_file = os.path.join(args.visual_codebook, 'visual_tokenizer', 'tokenizer_encoder.bin')
        visual_codebook_weights = torch.load(codebook_file, map_location="cpu")['quantize.embedding.weight']
        num_codes, codebook_dim = visual_codebook_weights.shape[0], visual_codebook_weights.shape[1]
        hidden_size = input_embedding_size[1]

        # Recreate the extra modules so the checkpoint loaded below has
        # somewhere to put their weights (presumably stored in the checkpoint;
        # the codebook tensor loaded above is not copied in here).
        model.model.visual_codebook = CodebookEmbedding(num_tokens=num_codes, codebook_dim=codebook_dim)
        if args.factorized_linear_mlp:
            model.model.visual_factorized_linear = nn.Sequential(
                nn.Linear(codebook_dim, hidden_size, bias=False),
                nn.SiLU(),
                nn.Linear(hidden_size, hidden_size, bias=False),
            )
        else:
            model.model.visual_factorized_linear = nn.Linear(codebook_dim, hidden_size, bias=False)

    if args.use_lora:
        peft_model = PeftModelForCausalLM.from_pretrained(
            model,
            model_id=args.checkpoint_path,
            is_trainable=False,
        )
        model = peft_model.merge_and_unload(progressbar=True, safe_merge=True)
    else:
        # TODO: need to test here
        # NOTE(review): this builds a new model from the checkpoint without the
        # device_map / torch_dtype used above, and the factorized modules
        # attached above are presumably not present on it - verify before
        # relying on this branch.
        model = model.from_pretrained(args.checkpoint_path)

    if is_factorized:
        with torch.no_grad():
            embed_weight = model.model.embed_tokens.weight
            # Project the codebook into embedding space, then move the result
            # to wherever the embedding matrix lives instead of assuming CUDA.
            visual_embeddings = model.model.visual_factorized_linear(
                model.model.visual_codebook.weight
            ).to(device=embed_weight.device, dtype=embed_weight.dtype)
            # TODO: maybe use += here when direct+factorized
            # Overwrite the last rows of the input embedding matrix.
            embed_weight.data[-visual_embeddings.shape[0]:] = visual_embeddings
        # The helper modules are dropped before saving.
        del model.model.visual_codebook
        del model.model.visual_factorized_linear

    model.save_pretrained(merged_path)
    print(f"merged model saved to {merged_path}")
    return merged_path

def test():
    """Manually exercise ``merge_to_base_model`` on two hard-coded checkpoints.

    The first run uses ``expand_vocab='normal'``; the second reuses the same
    ``args`` object with the factorized settings applied on top. A timestamp
    is printed before and after each merge. The ``YOUR_ROOT_PATH`` placeholders
    must be replaced with real paths for this to do anything useful.
    """
    from dotmap import DotMap

    def _print_timestamp():
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    def _apply(target, settings):
        # Plain attribute assignment, one setting at a time.
        for key, value in settings.items():
            setattr(target, key, value)

    def _timed_merge(run_args):
        _print_timestamp()
        _ = merge_to_base_model(run_args)
        _print_timestamp()

    args = DotMap()

    # Run 1: LoRA checkpoint trained with the directly expanded vocabulary.
    _apply(args, {
        "model_name_or_path": 'YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf',
        "use_slow_tokenizer": True,
        "vl_vocab_size": 48386,
        "image_start_token_id": 32000,
        "torch_dtype": torch.bfloat16,
        "use_lora": True,
        "checkpoint_path": 'YOUR_ROOT_PATH/model/checkpoint/MLLM/ablation_vocab_direct/best_1427_ppl_12.997',
        "expand_vocab": 'normal',
        "factorized_linear_mlp": False,
    })
    _timed_merge(args)

    # Run 2: same base settings, switched to the factorized MLP variant.
    _apply(args, {
        "checkpoint_path": 'YOUR_ROOT_PATH/model/checkpoint/MLLM/debug_factorized_mlp_2_2e_lora_qvkomlp/best_315_ppl_35.036/',
        "expand_vocab": 'factorized',
        "factorized_linear_mlp": True,
        "visual_codebook": 'YOUR_ROOT_PATH/model/LaVIT-7B-v2',
    })
    _timed_merge(args)