import torch
import torchvision
import torch.nn as nn



class DPTHead(nn.Module):
    def __init__(
        self, 
        in_channels, 
        inner_channels, 
        use_clstoken=False,
        out_channel=1024,
    ):
        super(DPTHead, self).__init__()
        
        self.use_clstoken = use_clstoken
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in inner_channels
        ])
        
        if use_clstoken:
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))
        
        self.output_conv = nn.Conv2d(sum(inner_channels) , out_channel, kernel_size=1, stride=1, padding=0)
        
    
    def forward(self, out_features, patch_h, patch_w):
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]
            
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
            
            x = self.projects[i](x)
            
            out.append(x)
        
        fusion_feats = torch.cat(out, dim=1)        

        fusion_feats = self.output_conv(fusion_feats)
        
        return fusion_feats

class DINOBase(nn.Module):
    def __init__(self, output_dim=128, only_global=False):
        super().__init__()
        self.only_global = only_global
        self.dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', pretrained=True)
        self.dino_normlize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        if not only_global:
            in_dim = self.dino_model.blocks[0].attn.qkv.in_features # 768
            hidden_dims=256
            out_dims=[256, 512, 1024, 1024]
            # modules
            self.projects = nn.ModuleList([
                nn.Conv2d(
                    in_dim, out_dim, kernel_size=1, stride=1, padding=0,
                ) for out_dim in out_dims
            ])
            self.resize_layers = nn.ModuleList([
                nn.ConvTranspose2d(
                    out_dims[0], out_dims[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    out_dims[1], out_dims[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    out_dims[3], out_dims[3], kernel_size=3, stride=2, padding=1
                )
            ])
            self.layer_rn = nn.ModuleList([
                nn.Conv2d(out_dims[0]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
                nn.Conv2d(out_dims[1]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
                nn.Conv2d(out_dims[2]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
                nn.Conv2d(out_dims[3]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
            ])

            self.refinenet = nn.ModuleList([
                FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
                FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
                FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
                FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
            ])
            self.output_conv = nn.Conv2d(hidden_dims, output_dim, kernel_size=3, stride=1, padding=1)
    
    def forward(self, images, output_size=None):
        images = self.dino_normlize(images)
   

        patch_h, patch_w = images.shape[-2]//14, images.shape[-1]//14
        if self.only_global:
            image_features = self.dino_model.get_intermediate_layers(images, 1)
            out_global = image_features[0][:, 0]
            return out_global
        else:
            image_features = self.dino_model.get_intermediate_layers(images, 4) # four tensors of size  (bs,1369,768)
            out_features = []
            for i, feature in enumerate(image_features):
                feature = feature.permute(0, 2, 1).reshape(
                    (feature.shape[0], feature.shape[-1], patch_h, patch_w)
                ) # (bs,768,37,37)
                feature = self.projects[i](feature)
                feature = self.resize_layers[i](feature)
                feature = torch.cat([
                        torchvision.transforms.functional.resize(images, (feature.shape[-2], feature.shape[-1]), antialias=True).detach(),
                        feature
                    ], dim=1
                )
                out_features.append(feature)
            layer_rns = []
            for i, feature in enumerate(out_features):
                layer_rns.append(self.layer_rn[i](feature))

            path_4 = self.refinenet[0](layer_rns[3], size=layer_rns[2].shape[2:])
            path_3 = self.refinenet[1](path_4, layer_rns[2], size=layer_rns[1].shape[2:])
            path_2 = self.refinenet[2](path_3, layer_rns[1], size=layer_rns[0].shape[2:])
            path_1 = self.refinenet[3](path_2, layer_rns[0])
            out = self.output_conv(path_1)
            if output_size is not None:
                out = nn.functional.interpolate(out, output_size, mode="bilinear", align_corners=True) # (bs,256,296,296)
            out_global = image_features[-1][:, 0] # （bs, 768)
            out_multiscale = image_features[-1]  # （bs, 1369， 768)
            return out, out_global, out_multiscale


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups=1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )
        
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn==True:
            out = self.bn1(out)
       
        out = self.activation(out)
        out = self.conv2(out)
        if self.bn==True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features//2
        
        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
        
        self.skip_add = nn.quantized.FloatFunctional()

        self.size=size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}
        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        output = self.out_conv(output)
        return output
