import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.hub as hub
import cv2

class Dinov2(nn.Module):
    """Feature extractor wrapping a DINOv2 backbone loaded from ``torch.hub``.

    ``forward`` returns the normalized CLS token and patch tokens produced by
    the backbone's ``forward_features``. Unless ``train_feature_extractor`` is
    True, the backbone's parameters are frozen, it runs under
    ``torch.no_grad()``, and it is kept in eval mode even when ``.train()`` is
    called on this wrapper.

    Args:
        name: Entry point in the ``facebookresearch/dinov2`` hub repo
            (e.g. ``"dinov2_vits14"``, ``"dinov2_vitb14"``).
        train_feature_extractor: If True, gradients flow through the backbone.
        device: Device the backbone is moved to; also stored as ``self.device``.
    """

    # Embedding width per backbone size tag. Only consulted when the loaded
    # model does not expose an integer ``embed_dim`` attribute itself.
    _EMBED_DIMS = {"vits14": 384, "vitb14": 768, "vitl14": 1024, "vitg14": 1536}

    def __init__(self, name="dinov2_vits14", train_feature_extractor=False, device="cuda:0"):
        super().__init__()
        self.device = device
        self.image_transforms = self.define_transforms()
        self.model = hub.load('facebookresearch/dinov2', name)
        self.model.to(device)
        self.train_feature_extractor = train_feature_extractor
        self.embedding_dim = self._resolve_embedding_dim(name, self.model)
        if not train_feature_extractor:
            self.freeze()

    @classmethod
    def _resolve_embedding_dim(cls, name, model=None):
        """Return the token embedding width for backbone ``name``.

        Prefers ``model.embed_dim`` when it is an int. Otherwise it matches
        the size tag contained in ``name``, so ``_reg`` variants resolve too.
        It falls back to 384, the previous default, for unknown names.
        """
        embed_dim = getattr(model, "embed_dim", None)
        if isinstance(embed_dim, int):
            return embed_dim
        for tag, dim in cls._EMBED_DIMS.items():
            if tag in name:
                return dim
        return 384

    def info_model(self):
        """Print the wrapper's total parameter count (backbone included), in millions."""
        total_params = sum(p.numel() for p in self.parameters())
        print("[{}] - {:.2f}M".format("Feature Extractor: DINOv2", total_params / 10 ** 6))

    def freeze(self):
        """Put the backbone in eval mode and disable gradients for its parameters."""
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set training mode on the wrapper while keeping a frozen backbone in eval.

        ``nn.Module.train`` recurses into child modules. Without this override,
        calling ``.train()`` on the wrapper would undo the ``self.model.eval()``
        done by ``freeze()``.
        """
        super().train(mode)
        model = getattr(self, "model", None)
        if model is not None and not getattr(self, "train_feature_extractor", True):
            model.eval()
        return self

    def define_transforms(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Build the image preprocessing pipeline.

        The pipeline converts the image to a tensor, resizes the shorter side
        to 256 with bicubic interpolation, center-crops to 224x224, and then
        normalizes with ``mean``/``std``. The defaults are the standard
        ImageNet statistics.
        """
        return T.Compose([
            T.ToTensor(),
            T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.Normalize(mean=mean, std=std),
            ])

    def preprocess(self, image):
        """Apply ``self.image_transforms`` to a single image.

        ``T.ToTensor`` accepts a PIL image or an HxWxC numpy array, not a tensor.
        """
        return self.image_transforms(image)

    def get_features(self, images):
        """Run the backbone and return ``(cls_token, patch_tokens)``.

        Assumes ``images`` is a batched float tensor (B, 3, H, W) on
        ``self.device``; this is not checked here, so confirm it at the call site.
        """
        if self.train_feature_extractor:
            # Leave grad mode to the caller so an outer no_grad() is respected.
            features = self.model.forward_features(images)
        else:
            with torch.no_grad():
                features = self.model.forward_features(images)
        return features["x_norm_clstoken"], features["x_norm_patchtokens"]

    def forward(self, images):
        """Alias for ``get_features``."""
        return self.get_features(images)




if __name__ == '__main__':
    # Smoke test: load the default backbone and push one random image through it.
    from hoi.model.hoibot.modules.utils import print_info_model

    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    dinov2 = Dinov2(device=device)
    print_info_model(dinov2)
    print(dinov2)

    # A random tensor stands in for a preprocessed image; 224x224 matches the
    # CenterCrop size used by Dinov2.define_transforms.
    image = torch.randn(3, 224, 224).to(dinov2.device)
    print("image:", image.shape)
    cls_token, patch_tokens = dinov2(image.unsqueeze(0))
    print("cls_token", cls_token.shape)
    print("patch_tokens", patch_tokens.shape)
