from .general_encoder import GeneralVisionTower
from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig

class SigLIPVisionTower(GeneralVisionTower):
    """Vision tower whose default ``loader`` is the SigLIP class triple.

    The default ``loader`` is ``(SiglipVisionModel, SiglipImageProcessor,
    SiglipVisionConfig)``; how those classes are used is decided by
    ``GeneralVisionTower``, which is defined outside this file.
    """

    def __init__(self, vision_tower, args, delay_load=False, loader=(SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig)):
        """Pass every argument through, unchanged, to the base initializer."""
        super().__init__(vision_tower, args, delay_load, loader)

    def feature_select(self, image_forward_outs):
        """Return the hidden state chosen by ``self.select_layer``.

        ``image_forward_outs`` must expose a ``hidden_states`` sequence, which
        is indexed with ``self.select_layer``. The selected entry is returned
        as-is; no slicing or token removal happens here.

        NOTE(review): ``select_layer`` and ``select_feature`` are presumably
        set by ``GeneralVisionTower`` -- confirm against the base class.

        Raises:
            ValueError: if ``self.select_feature`` is anything but ``'patch'``.
        """
        # Index first, validate second: keeps the failure order callers see.
        selected = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature != 'patch':
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return selected
