from transformers import AutoBackbone, AutoImageProcessor
import torch
import torch.nn as nn

class DinoV2Encoder(nn.Module):
    """DINOv2 feature extractor for a reference/source image pair.

    Images are preprocessed with the Depth-Anything image processor and fed
    to a DINOv2 backbone loaded through ``transformers.AutoBackbone``. The
    backbone exposes four intermediate stages. Both preprocessing and the
    backbone forward pass run under ``torch.no_grad()``, so no gradients flow
    into the backbone from the features produced here.
    """

    # Supported backbone name -> (Hugging Face checkpoint, stages to expose).
    # Each entry requests four evenly spaced stages; dinov2-large uses twice
    # the stage indices of small/base. Tuples keep the shared table immutable.
    _BACKBONES = {
        'dinov2-small': ("facebook/dinov2-small", ("stage3", "stage6", "stage9", "stage12")),
        'dinov2-base': ("facebook/dinov2-base", ("stage3", "stage6", "stage9", "stage12")),
        'dinov2-large': ("facebook/dinov2-large", ("stage6", "stage12", "stage18", "stage24")),
    }

    def __init__(self, cfg):
        """Build the image processor and the DINOv2 backbone.

        Args:
            cfg: configuration object. Reads ``cfg.visual_branch.image_size``
                (used as both target height and width for the image
                processor) and ``cfg.visual_branch.backbone`` (must be a key
                of ``_BACKBONES``).

        Raises:
            ValueError: if ``cfg.visual_branch.backbone`` is not supported.
        """
        super().__init__()
        visual_cfg = cfg.visual_branch

        # Validate before loading any pretrained artifact so a bad config
        # fails immediately instead of after a download / disk load.
        backbone = visual_cfg.backbone
        if backbone not in self._BACKBONES:
            raise ValueError('Unsupported visual backbone "{}".'.format(backbone))
        checkpoint, out_features = self._BACKBONES[backbone]

        image_size = {'height': visual_cfg.image_size, 'width': visual_cfg.image_size}
        self.imgprocess = AutoImageProcessor.from_pretrained(
            "LiheYoung/depth-anything-small-hf", size=image_size)

        # reshape_hidden_states=False keeps the backbone's token-sequence
        # layout instead of spatial maps; downstream code presumably uses the
        # *_patch_size grid written in forward() to reshape (confirm).
        self.dinov2 = AutoBackbone.from_pretrained(
            checkpoint, out_features=list(out_features), reshape_hidden_states=False)

        # Pixel side length of one ViT patch, taken from the loaded model
        # rather than hard-coded, so the patch-grid math matches the backbone.
        self.patch_size = self.dinov2.config.patch_size

    def _encode_image(self, images):
        """Preprocess ``images`` and run them through the backbone without grad.

        Returns:
            ``(feature_maps, (rows, cols))``. ``feature_maps`` is the
            backbone output's ``feature_maps`` (one entry per requested
            stage). ``(rows, cols)`` is the processed image's dims 2 and 3
            floor-divided by ``self.patch_size``.
        """
        with torch.no_grad():
            pixel_values = self.imgprocess(
                images=images, return_tensors="pt").pixel_values.to(self.dinov2.device)
            feature_maps = self.dinov2(pixel_values).feature_maps
        # assumes pixel_values is (batch, channels, height, width) — TODO confirm
        patch_grid = (pixel_values.shape[2] // self.patch_size,
                      pixel_values.shape[3] // self.patch_size)
        return feature_maps, patch_grid

    def forward(self, data_dict):
        """Add DINOv2 features for ``data_dict['ref_img']`` and ``data_dict['src_img']``.

        Writes into ``data_dict`` (mutated in place and also returned):
            ``ref_patch_size`` / ``src_patch_size``: ``(rows, cols)`` patch
                grid of each processed image.
            ``ref_dino`` / ``src_dino``: the backbone ``feature_maps`` for
                each image.
        """
        ref_dino_feats, ref_patch_grid = self._encode_image(data_dict['ref_img'])
        src_dino_feats, src_patch_grid = self._encode_image(data_dict['src_img'])

        # Same insertion order as before, for callers that iterate the dict.
        data_dict['ref_patch_size'] = ref_patch_grid
        data_dict['src_patch_size'] = src_patch_grid
        data_dict['ref_dino'] = ref_dino_feats
        data_dict['src_dino'] = src_dino_feats

        return data_dict