# this code is from Feature-Distillation/models/dino.py

import os
from typing import Union, List
import torch
from .clip.clip import _download
from .outputwithattnmodels.vit import VisionTransformer
import torch.distributed as dist

from .dinov2.models import vision_transformer as vits
from .utils import interpolate_pos_embed

_MODELS = {
    "DINO": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
    "DINO_T": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth"
}

class WarpperVisionTransformer(VisionTransformer):
    """``VisionTransformer`` subclass adding a ``dtype`` property and an
    image-to-feature-map helper that casts its input to that dtype."""

    def __init__(self, **kwargs):
        # Every configuration option is forwarded untouched to the parent.
        super().__init__(**kwargs)

    @property
    def dtype(self):
        """Dtype of the model, taken from the weight of ``self.norm``."""
        norm_weight = self.norm.weight
        return norm_weight.dtype

    def encode_image_featuremap(self, image):
        """Cast ``image`` to ``self.dtype`` and run ``forward_featuremap`` on it."""
        casted_image = image.type(self.dtype)
        return self.forward_featuremap(casted_image)

def _download_checkpoint(url):
    """Download ``url`` with ``_download``, coordinating ranks under torch.distributed.

    Without an initialized process group, the file is simply downloaded.

    With one, the rank whose ``rank % torch.cuda.device_count() == 0`` downloads
    first and then enters the barrier. Every other rank enters the barrier first
    and only then calls ``_download``. All ranks therefore hit exactly one barrier
    and do not race on the same file.

    Returns:
        Whatever ``_download`` returns (used by the caller as the checkpoint path).
    """
    # Check availability before any other dist API: in builds without distributed
    # support, functions like ``is_initialized``/``barrier`` may not be usable.
    if not (dist.is_available() and dist.is_initialized()):
        return _download(url, sha_check=False)

    # NOTE(review): assumes one process per GPU with node-contiguous ranks, so this
    # selects one downloader per node — confirm for the launch setup in use.
    is_local_leader = int(dist.get_rank()) % torch.cuda.device_count() == 0
    if is_local_leader:
        model_path = _download(url, sha_check=False)
        dist.barrier()
    else:
        # Wait for the leader to finish, then resolve the path; presumably
        # ``_download`` reuses the already-present file — TODO confirm.
        dist.barrier()
        model_path = _download(url, sha_check=False)
    return model_path


def load_dino(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", **kwargs):
    """Build a ViT (patch 16, dim 768, depth 12, 12 heads) and load DINO weights.

    Args:
        name: A key of ``_MODELS`` (the checkpoint is downloaded) or a path to an
            existing local checkpoint file.
        device: Device the returned model is moved to.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The ``WarpperVisionTransformer`` with weights loaded, moved to ``device``.

    Raises:
        RuntimeError: If ``name`` is neither a known model key nor an existing file.
    """
    if name in _MODELS:
        model_path = _download_checkpoint(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; ")

    state_dict = torch.load(model_path, map_location="cpu")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=0)
    model = WarpperVisionTransformer(**model_kwargs)
    # NOTE(review): loading is strict; the "DINO_T" entry points at a *full*
    # checkpoint file whose top-level keys may not match this model — verify.
    msg = model.load_state_dict(state_dict)
    print(msg)
    return model.to(device)

def load_dinov2(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", **kwargs):
    """Build a DINOv2 ViT and load pretrained weights from a file next to this module.

    Args:
        name: Model name in this project's convention. The first 7 characters
            (presumably a ``"dinov2_"`` prefix — TODO confirm) are replaced by
            ``"vit_"`` to obtain the constructor name in ``vision_transformer``.
        device: Device the returned model is moved to.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The model with weights loaded (non-strictly), moved to ``device``.
    """
    vit_kwargs = dict(
        img_size=224,
        patch_size=14,
        init_values=1.0,
        ffn_layer="mlp",
        block_chunks=0,
    )
    name = 'vit_' + name[7:]  # translate this project's name to the DINOv2 arch name

    def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
        # e.g. "vit_base" -> "dinov2_vitb14" for patch size 14.
        compact_arch_name = arch_name.replace("_", "")[:4]
        return f"dinov2_{compact_arch_name}{patch_size}"

    checkpoint_file = _make_dinov2_model_name(name, vit_kwargs["patch_size"]) + "_pretrain.pth"
    model = vits.__dict__[name](**vit_kwargs)

    current_directory = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(current_directory, checkpoint_file)
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Resize the stored positional embedding when its shape differs from the model's.
    if state_dict['pos_embed'].shape != model.pos_embed.shape:
        interpolate_pos_embed(model, state_dict)

    # Compare the keys of the model and of the checkpoint and report any
    # that appear in only one of them.
    model_keys = set(model.state_dict().keys())
    state_dict_keys = set(state_dict.keys())
    different_keys = model_keys.symmetric_difference(state_dict_keys)
    print("keys mismatch:", different_keys)

    model.load_state_dict(state_dict, strict=False)

    # Report the checkpoint file actually loaded (previously this printed the arch name).
    print('=> loading pre-trained model from {}'.format(checkpoint_path))
    print(model)

    return model.to(device)