import sys
import torch
import torch.nn as nn
# Location of a local DINOv2 checkout. It is appended to sys.path so that the
# `dinov2` package imported below can be resolved.
# NOTE(review): "path_to_dinov2" looks like a placeholder that must be edited
# per machine — confirm before relying on the local import.
dinov2_path = "path_to_dinov2"
sys.path.append(dinov2_path)
from dinov2.hub.backbones import _make_dinov2_model

# Replace torch.hub's private `_validate_not_a_forked_repo` hook with a stub
# that accepts three arguments and always returns True.
# NOTE(review): presumably a workaround for failures of that check during
# `torch.hub.load` (e.g. GitHub API errors) — confirm. This touches a private
# torch API, so it may break with other torch versions.
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

class DinoV2Encoder(nn.Module):
    """Wrap a DINOv2 backbone and expose one of its normalized feature outputs.

    The backbone is either fetched through ``torch.hub`` or built locally with
    ``_make_dinov2_model`` without pretrained weights.

    Args:
        name: Entry-point name passed to ``torch.hub.load`` for the
            ``facebookresearch/dinov2`` repository. It is stored on the
            instance, but only selects the architecture when
            ``use_pretrained_encoder`` is true.
        feature_key: Key looked up in the dict returned by the backbone's
            ``forward_features``. Must be ``"x_norm_patchtokens"`` or
            ``"x_norm_clstoken"``.
        use_pretrained_encoder: When true, load the model via ``torch.hub``.
            When false, build a ``vit_small`` model with ``patch_size=14`` and
            ``pretrained=False``; ``name`` does not influence this branch.

    Raises:
        ValueError: If ``feature_key`` is not one of the two supported keys.
            This is raised before any model is loaded or constructed.
    """

    def __init__(self, name: str, feature_key: str, use_pretrained_encoder: bool = True):
        super().__init__()

        # Validate the key up front so a typo fails immediately instead of
        # after a potentially slow hub download / model construction.
        if feature_key == "x_norm_patchtokens":
            latent_ndim = 2
        elif feature_key == "x_norm_clstoken":
            latent_ndim = 1
        else:
            raise ValueError(f"Invalid feature key: {feature_key}")

        self.name = name
        self.feature_key = feature_key
        self.latent_ndim = latent_ndim

        if use_pretrained_encoder:
            self.base_model = torch.hub.load("facebookresearch/dinov2", name)
        else:
            # NOTE: architecture and patch size are fixed here regardless of
            # `name`.
            self.base_model = _make_dinov2_model(
                arch_name="vit_small", pretrained=False, patch_size=14
            )

        # Mirror the backbone's embedding width and patch size for consumers.
        self.emb_dim = self.base_model.num_features
        self.patch_size = self.base_model.patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the selected backbone feature for ``x``.

        Args:
            x: Input forwarded unchanged to the backbone's
                ``forward_features``. Assumed to be an image batch shaped
                ``(B, 3, H, W)`` — TODO confirm against callers.

        Returns:
            The tensor stored under ``self.feature_key``. When
            ``latent_ndim == 1`` (CLS-token key) a singleton dimension is
            inserted at index 1 as a dummy patch dimension.
        """
        emb = self.base_model.forward_features(x)[self.feature_key]
        if self.latent_ndim == 1:
            emb = emb.unsqueeze(1)  # dummy patch dim
        return emb
    

# if __name__ == '__main__':
#     encoder = DinoV2Encoder(name="dinov2_vits14", feature_key="x_norm_patchtokens", use_pretrained_encoder=True)
#     x = torch.randn(1, 3, 140, 140)
#     emb = encoder(x)
#     print(emb.shape)