import torch

def print_layer_lst_brief(layer_lst, max_shown=50):
    """Print a compact, de-duplicated overview of parameter names.

    Each name is truncated to a prefix of its dot-separated components:
    five components for names starting with ``model.vision_tower.vision_tower``
    and three for every other name (e.g. ``model.layers.0``).  Unique
    prefixes are kept in first-seen order.

    Args:
        layer_lst: Sequence of parameter names (``len`` is taken on it).
        max_shown: Maximum number of prefixes to print; ``None`` prints all.

    Returns:
        The list of unique prefixes, in first-seen order.
    """
    def _prefix(layer_name):
        # Vision-tower names nest deeper, so keep more components for them.
        if layer_name.startswith("model.vision_tower.vision_tower"):
            prefix_len = 5
        else:
            prefix_len = 3
        return '.'.join(layer_name.split('.')[:prefix_len])

    # dict.fromkeys keeps insertion order and gives O(1) membership checks,
    # unlike repeated `in list` scans (quadratic over thousands of names).
    layers_lst_combined = list(dict.fromkeys(_prefix(name) for name in layer_lst))
    print(f"layers total num:{len(layer_lst)}")
    print(f"{layers_lst_combined[:max_shown]}")
    return layers_lst_combined
 
def _count_large_diffs(params1, params2, rel_threshold):
    """Count elements whose |difference| exceeds ``rel_threshold`` * (largest |difference|)."""
    # Subtraction is not defined for bool tensors, so compare them as integers.
    if params1.dtype == torch.bool:
        params1 = params1.int()
    if params2.dtype == torch.bool:
        params2 = params2.int()
    diffed_params = torch.abs(params1 - params2)
    if diffed_params.numel() == 0:
        return 0
    max_diff = torch.max(diffed_params)
    if not max_diff > 0:
        # All differences are zero (or NaN): avoid a 0/0 normalisation.
        return 0
    return int((diffed_params / max_diff > rel_threshold).sum().item())


def compare_state_dict(state_dict1, state_dict2, device=None, rel_threshold=0.01):
    """Compare two state dicts tensor by tensor and print a per-prefix summary.

    Keys present in both dicts are compared in ``state_dict1``'s order, so the
    output is deterministic.  Keys found in only one dict are reported
    separately.  When dtypes differ, fp16/bf16/fp32 tensors are upcast to fp32
    so that a pure precision change is not reported as a difference; any
    remaining dtype mismatch is printed and both tensors are promoted to a
    common dtype before comparison.

    Args:
        state_dict1: Mapping from parameter name to tensor.
        state_dict2: Mapping from parameter name to tensor.
        device: Device used for the comparison.  Defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.
        rel_threshold: For same-shape tensors that differ, elements whose
            absolute difference exceeds this fraction of the largest absolute
            difference are counted and printed.

    Returns:
        Tuple ``(same_prefixes, diff_prefixes)`` as returned by
        ``print_layer_lst_brief`` for the identical and the differing layers.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    upcast_dtypes = (torch.float16, torch.bfloat16, torch.float32)

    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())
    print(f"layers num:{len(keys1)}, {len(keys2)}")

    # Report keys that cannot be compared instead of silently skipping them.
    only_in_1 = [name for name in state_dict1 if name not in keys2]
    only_in_2 = [name for name in state_dict2 if name not in keys1]
    if only_in_1:
        print("layers only in first:")
        print_layer_lst_brief(only_in_1)
    if only_in_2:
        print("layers only in second:")
        print_layer_lst_brief(only_in_2)

    # Iterate in checkpoint order; iterating a set of str is hash-seed dependent.
    common_layers = [name for name in state_dict1 if name in keys2]

    same_layers_lst = []
    diff_layers_lst = []
    for layer_name in common_layers:
        params1 = state_dict1[layer_name].to(device)
        params2 = state_dict2[layer_name].to(device)
        if params1.dtype != params2.dtype:
            if params1.dtype in upcast_dtypes:
                params1 = params1.float()
            if params2.dtype in upcast_dtypes:
                params2 = params2.float()
            if params1.dtype != params2.dtype:
                print(f"diff dtype: {params1.dtype}, {params2.dtype}")
                # Promote to a common dtype so equality and subtraction are well defined.
                common_dtype = torch.promote_types(params1.dtype, params2.dtype)
                params1 = params1.to(common_dtype)
                params2 = params2.to(common_dtype)

        if torch.equal(params1, params2):
            same_layers_lst.append(layer_name)
            continue

        diff_layers_lst.append(layer_name)
        if params1.shape == params2.shape:
            # Same shape but different values: report how widespread the difference is.
            print("check value")
            num_large = _count_large_diffs(params1, params2, rel_threshold)
            print(f"diffed_params num: {num_large}")

    # combine the layers with the same prefix
    print("same layers:")
    same_layers_lst_combined = print_layer_lst_brief(same_layers_lst)
    print("diff layers:")
    diff_layers_lst_combined = print_layer_lst_brief(diff_layers_lst)
    return same_layers_lst_combined, diff_layers_lst_combined


def compare_all_layers(checkpoint_path1, checkpoint_path2, map_location="cpu"):
    """Load two checkpoint files and compare their contents layer by layer.

    Args:
        checkpoint_path1: Path (str or path-like) of the first ``torch.save`` file.
        checkpoint_path2: Path (str or path-like) of the second ``torch.save`` file.
        map_location: Forwarded to ``torch.load``.  Defaults to ``"cpu"`` so
            loading does not depend on (or fill) the device the checkpoint was
            saved from; ``compare_state_dict`` moves tensors to its own
            comparison device.

    Returns:
        Whatever ``compare_state_dict`` returns for the two loaded objects.
    """
    print(f"compare {str(checkpoint_path1).split('/')[-2:]} and {str(checkpoint_path2).split('/')[-2:]}")
    # NOTE: torch.load unpickles the file -- only use this on trusted checkpoints.
    state_dict1 = torch.load(checkpoint_path1, map_location=map_location)
    state_dict2 = torch.load(checkpoint_path2, map_location=map_location)

    return compare_state_dict(state_dict1, state_dict2)

def compare_models_weight(model1, model2):
    """Report which weights two models share and which differ.

    Takes each model's ``state_dict()`` (first ``model1``, then ``model2``)
    and delegates the layer-by-layer check to ``compare_state_dict``, which
    prints the summary.  Returns ``None``.
    """
    compare_state_dict(model1.state_dict(), model2.state_dict())

if __name__ == "__main__":
    # Other checkpoints that have been compared before (kept for reference):
    # ckpt0 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-v2-pretrain-1/mm_projector.bin"
    # ckpt1 = './checkpoints/llava-pretrain-llama-2-7b-chat/mm_projector.bin'
    # ckpt2 = "./llava/backbone/checkpoints_tune/llava-llama-2-7b-chat-DETR-v2-pretrain-1-tune-1/mm_projector.bin"
    # compare_all_layers(ckpt0, ckpt2)
    # ckpt4 = "./llava/backbone/checkpoints_tune/llava-llama-2-7b-chat-DETR-v2-pretrain-1-tune-1/checkpoint-400/pytorch_model-00001-of-00002.bin"
    # ckpt7 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-pretrain-1/checkpoint-4000/pytorch_model-00001-of-00002.bin"
    # ckpt8 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-pretrain-1/checkpoint-4000/pytorch_model-00002-of-00002.bin"

    # Both shards of the lightning-preview checkpoint; org_1 is needed by the
    # optional pretrain comparison below.
    org_1 = "./checkpoints/llava-llama-2-7b-chat-lightning-preview/pytorch_model-00001-of-00002.bin"
    org_2 = "./checkpoints/llava-llama-2-7b-chat-lightning-preview/pytorch_model-00002-of-00002.bin"
    mm1 = "./checkpoints/llava-llama-2-7b-chat-lightning-preview/mm_projector_7b_chat.bin"
    mm2 = "./LLaVA/checkpoints/llava-pretrain-llama-2-7b-chat/mm_projector.bin"
    compare_all_layers(mm1, org_2)
    compare_all_layers(mm1, mm2)

    # Disabled by default: set to True to also compare the DETR-v2 pad
    # pretrain shards against the lightning-preview and tuned checkpoints.
    run_pretrain_comparisons = False
    if run_pretrain_comparisons:
        # compare_all_layers(ckpt3, ckpt7)
        # ckpt5 = "./checkpoints/llava-llama-2-7b-chat-lightning-preview/pytorch_model-00002-of-00002.bin"
        ckpt_tune_2 = "./llava/backbone/checkpoints_tune/llava-llama-2-7b-chat-DETR-v2-pretrain-1-tune-1/checkpoint-400/pytorch_model-00002-of-00002.bin"
        # ckpt_detr_2_all = "./llava/backbone/checkpoints_vision_branch/detr-v2-model.pt"
        # # compare_all_layers(ckpt5, ckpt6)
        # compare_all_layers(ckpt5, ckpt8)

        # ckpt_clip_branch = "./llava/backbone/checkpoints_vision_branch/aligned_vision_model.pt"
        # compare_all_layers(ckpt9, ckpt4) # vision tower and mm_projector is same
        # compare_all_layers(ckpt10, ckpt6)
        # compare_all_layers(ckpt10, ckpt3)
        # compare_all_layers(ckpt10, ckpt4)
        # compare_all_layers(ckpt10, ckpt9) # right: model.mm_projector diff
        ckpt_pretrain_1 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-v2-pad-pretrain-1/checkpoint-2000/pytorch_model-00001-of-00002.bin"
        ckpt_pretrain_2 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-v2-pad-pretrain-1/checkpoint-2000/pytorch_model-00002-of-00002.bin"
        # ckpt13 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-pretrain-1/checkpoint-4000/pytorch_model-00001-of-00002.bin"
        # ckpt14 = "./llava/backbone/checkpoints_pretrain/llava-llama-2-7b-chat-DETR-pretrain-1/checkpoint-4000/pytorch_model-00002-of-00002.bin"
        # compare_all_layers(ckpt11, ckpt13)
        compare_all_layers(ckpt_pretrain_1, org_1)
        compare_all_layers(ckpt_pretrain_2, org_2)
        compare_all_layers(ckpt_pretrain_2, ckpt_tune_2)  # diff: mm,
        # compare_all_layers(ckpt14, ckpt4)
        # only need to check why the LM is different during training

