from .general_encoder import GeneralVisionTower
from transformers import AutoModel, CLIPImageProcessor, AutoConfig

class AIMv2VisionTower(GeneralVisionTower):
    """Vision tower wrapper for AIMv2 encoders.

    Construction is inherited unchanged from ``GeneralVisionTower`` (no
    ``__init__`` is defined here). This subclass only customizes how the
    feature tensor is picked out of the encoder's forward outputs.
    """

    # NOTE(review): an earlier, commented-out ``__init__`` forwarded
    # ``loader=(AutoModel, CLIPImageProcessor, AutoConfig)`` to the base class.
    # It was never active, so the inherited ``__init__`` is what runs; confirm
    # the base class's default loader is appropriate for AIMv2 checkpoints.

    def feature_select(self, image_forward_outs):
        """Pick the feature tensor to expose from an encoder forward pass.

        Reads ``self.select_layer`` and ``self.select_feature``, which are
        presumably set by ``GeneralVisionTower.__init__`` (TODO confirm).

        Args:
            image_forward_outs: Output object of the vision encoder. It must
                expose ``last_hidden_state`` and, when ``self.select_layer``
                is non-zero, an indexable ``hidden_states``.

        Returns:
            ``image_forward_outs.last_hidden_state`` when ``self.select_layer``
            is 0, otherwise ``image_forward_outs.hidden_states[self.select_layer]``.
            The selected value is returned unchanged.

        Raises:
            ValueError: If ``self.select_feature`` is anything other than
                ``'patch'``.
        """
        # A select_layer of 0 means "use the final output" rather than
        # hidden_states[0]. hidden_states is only read when a non-zero layer is
        # requested, so it need not be populated otherwise.
        if self.select_layer != 0:
            image_features = image_forward_outs.hidden_states[self.select_layer]
        else:
            image_features = image_forward_outs.last_hidden_state

        # Only 'patch' is supported, and every token is kept as-is (presumably
        # AIMv2 emits no CLS token that would need stripping — TODO confirm).
        if self.select_feature != 'patch':
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features
