from .general_encoder import GeneralVisionTower
from transformers import Dinov2Model, BitImageProcessor, Dinov2Config

class DINOv2VisionTower(GeneralVisionTower):
    """Vision tower backed by a Hugging Face DINOv2 backbone.

    Construction is delegated entirely to ``GeneralVisionTower``. This
    subclass only supplies DINOv2-specific defaults for ``loader`` and
    customises two things:

    * how token features are picked out of the backbone output;
    * how the patch-grid side length is derived.
    """

    def __init__(
        self,
        vision_tower,
        args,
        delay_load=False,
        loader=(Dinov2Model, BitImageProcessor, Dinov2Config),
    ):
        # The base class receives the (model, processor, config) classes via
        # ``loader``; nothing else is DINOv2-specific at construction time.
        super().__init__(vision_tower, args, delay_load, loader)

    def feature_select(self, image_forward_outs):
        """Return token features from one hidden layer of the backbone output.

        Args:
            image_forward_outs: backbone forward output. Its ``hidden_states``
                sequence is indexed with ``self.select_layer``.

        Returns:
            The selected layer's tokens. With ``select_feature == 'patch'``,
            the token at position 0 of dim 1 is dropped (presumably the CLS
            token; confirm against the backbone). With ``'cls_patch'``, all
            tokens are kept unchanged.

        Raises:
            ValueError: if ``self.select_feature`` is neither ``'patch'`` nor
                ``'cls_patch'``.
        """
        layer_tokens = image_forward_outs.hidden_states[self.select_layer]
        mode = self.select_feature

        if mode == 'cls_patch':
            # Keep the leading token together with the patch tokens.
            return layer_tokens
        if mode == 'patch':
            # Discard the single leading token; keep only the patch tokens.
            return layer_tokens[:, 1:]

        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @property
    def num_patches_per_side(self):
        """Patches along one side of the processor's crop.

        Only the crop height is consulted, so a square crop is assumed.
        """
        crop_height = self.image_processor.crop_size["height"]
        patch_edge = self.config.patch_size
        return crop_height // patch_edge