import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.hub as hub
import cv2
from ultralytics.vit import SAM


class SamVisionEncoder(nn.Module):
    """Frozen SAM image encoder followed by a trainable linear projection.

    The encoder is the ``image_encoder`` of the ultralytics ``SAM("sam_b.pt")``
    model. It is kept in eval mode with gradients disabled on all of its
    parameters; ``projection_layer`` is the only trainable part.

    No image preprocessing (resize / normalisation) is performed here; callers
    must pass tensors the encoder can consume directly.

    Args:
        embedding_dim: Output size of ``projection_layer``.
        name: Currently unused; kept so existing callers that pass it keep
            working.
        device: Device the whole module (encoder and projection) is moved to.
    """

    def __init__(self, embedding_dim, name="facebook/sam-vit-base", device="cpu"):
        super().__init__()
        self.device = device
        self.model = SAM("sam_b.pt").model.image_encoder
        self.freeze()

        # The projection acts on the last axis of get_features(), i.e. the
        # flattened trailing axes of the encoder output, and the pooling
        # window spans axis 1 of that output.
        # NOTE(review): assumes the encoder returns (B, 256, 64, 64) for
        # 1024x1024 inputs, giving 64 * 64 = 4096 -- TODO confirm.
        self.projection_layer = nn.Linear(4096, embedding_dim)
        self.cls_pool = nn.AvgPool1d(256)

        # Move encoder *and* projection together; moving only the encoder
        # leaves the projection on CPU and breaks forward() on other devices.
        self.to(device)

    def freeze(self):
        """Put the encoder in eval mode and disable gradients on its parameters."""
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode, but always keep the frozen encoder in eval mode.

        Without this override, calling ``.train()`` on this module (or on a
        parent module) would switch the frozen encoder back to training mode.
        """
        super().train(mode)
        self.model.eval()
        return self

    @torch.no_grad()
    def get_features(self, images):
        """Run the frozen encoder and flatten everything after axis 1.

        Args:
            images: Batch passed unchanged to the encoder.

        Returns:
            Tensor of shape ``(B, C, -1)`` where ``B`` and ``C`` are the first
            two axes of the encoder output. Computed without gradient tracking.
        """
        features = self.model(images)
        # reshape (not view) so a non-contiguous encoder output is also handled.
        return features.reshape(features.shape[0], features.shape[1], -1)

    def forward(self, images):
        """Encode ``images`` and project the result.

        Args:
            images: Batch passed to :meth:`get_features`.

        Returns:
            Tuple ``(cls, features)``. ``features`` is ``projection_layer``
            applied to the output of :meth:`get_features`. ``cls`` is
            ``features`` average-pooled along axis 1 with ``cls_pool`` and
            with the trailing axis squeezed; it is a pooled summary, not a
            learned class token.
        """
        features = self.get_features(images)
        features = self.projection_layer(features)
        # AvgPool1d pools over the last axis, so bring axis 1 to the end first.
        cls = self.cls_pool(features.permute(0, 2, 1)).squeeze(-1)
        return cls, features




if __name__ == "__main__":
    # Smoke test: build the encoder, report its size, and push one random
    # 1024x1024 RGB batch through it to print the output shapes.
    from hoi.model.hoibot.modules.utils import read_image, print_info_model
    # image = read_image()
    encoder = SamVisionEncoder(768)
    print_info_model(encoder, "sam")

    dummy_batch = torch.rand(1, 3, 1024, 1024)
    print("image:", dummy_batch.shape)

    outputs = encoder(dummy_batch)
    for label, tensor in zip(("cls_token", "patch_tokens"), outputs):
        print(label, tensor.shape)

