# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch.nn as nn

from functools import partial
from mmcv_custom import load_checkpoint
from mmdet.utils import get_root_logger
from mmdet.models.builder import BACKBONES
import models


@BACKBONES.register_module()
class HiRep(models.HiRep):
    """Detection backbone wrapper around ``models.HiRep``.

    The class is registered in the mmdet ``BACKBONES`` registry. On top of
    the parent model it adds one projection head per output level
    (``norm_layer`` followed by ``nn.Linear``) that maps each feature map
    returned by ``forward_features`` to ``out_embed_dim`` channels.

    Args:
        img_size, patch_size, in_chans, embed_dim, mlp_depth, depth,
        num_heads, bridge_mlp_ratio, mlp_ratio, num_outs, drop_path_rate,
        norm_layer, ape, rpe, patch_norm, use_checkpoint, **kwargs:
            Forwarded unchanged to ``models.HiRep.__init__``.
            ``embed_dim``, ``num_outs`` and ``norm_layer`` are additionally
            used here to build the projection heads.
        out_embed_dim: Number of output channels produced by every
            projection head.
        init_cfg: Mapping that must contain a ``'checkpoint'`` entry
            holding the path of the checkpoint file loaded by
            :meth:`init_weights`.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=512,
                 mlp_depth=3,
                 depth=24,
                 num_heads=8,
                 bridge_mlp_ratio=3.,
                 mlp_ratio=4.,
                 num_outs=5,
                 out_embed_dim=256,
                 drop_path_rate=0.1,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 ape=True, rpe=True,
                 patch_norm=True,
                 use_checkpoint=False,
                 init_cfg=None,
                 **kwargs):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            mlp_depth=mlp_depth,
            depth=depth,
            num_heads=num_heads,
            bridge_mlp_ratio=bridge_mlp_ratio,
            mlp_ratio=mlp_ratio,
            num_outs=num_outs,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            ape=ape, rpe=rpe,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            **kwargs)
        self.init_cfg = init_cfg

        # The parent constructor sets `num_classes`; it is removed here.
        # NOTE(review): presumably a classification-only attribute that the
        # detection backbone does not need -- confirm against models.HiRep.
        del self.num_classes

        # Input width of each projection head: level 0 uses embed_dim // 4,
        # level 1 uses embed_dim // 2, every later level uses embed_dim.
        reduced_dims = {0: embed_dim // 4, 1: embed_dim // 2}
        level_dims = [reduced_dims.get(level, embed_dim)
                      for level in range(num_outs)]

        self.out_embed = nn.ModuleList([
            nn.Sequential(
                norm_layer(dim),
                nn.Linear(dim, out_embed_dim),
            )
            for dim in level_dims
        ])

    def init_weights(self):
        """Initialize all modules, then load the pretrained checkpoint.

        ``self._init_weights`` is applied to every sub-module first, after
        which the checkpoint named by ``init_cfg['checkpoint']`` is loaded
        with ``strict=False``.

        Raises:
            ValueError: If ``init_cfg`` is ``None`` or the checkpoint path
                is not an existing file.
            AssertionError: If ``init_cfg`` has no ``'checkpoint'`` entry.
        """
        if self.init_cfg is None:
            raise ValueError(
                f'{self.__class__.__name__} requires `init_cfg` with a '
                f'`checkpoint` entry, but `init_cfg` is None')

        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is kept.
        if 'checkpoint' not in self.init_cfg:
            raise AssertionError(
                f'Only support '
                f'specify `Pretrained` in '
                f'`init_cfg` in '
                f'{self.__class__.__name__} ')

        self.apply(self._init_weights)
        pretrained = self.init_cfg['checkpoint']
        logger = get_root_logger()
        if not os.path.isfile(pretrained):
            raise ValueError(f"checkpoint path {pretrained} is invalid")
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Return the projected multi-level features as a tuple.

        The feature maps from ``forward_features`` are taken in reverse
        order and paired with the projection heads in ``self.out_embed``;
        pairing stops at the shorter of the two sequences.

        Each feature map is a 4-D tensor. Dim 1 is moved to the last
        position so the head acts on it, then moved back afterwards.
        NOTE(review): presumably (N, C, H, W) layout -- confirm against
        ``models.HiRep.forward_features``.
        """
        features = self.forward_features(x)
        projected = []
        for feat, embed in zip(reversed(features), self.out_embed):
            channels_last = feat.permute(0, 2, 3, 1)
            projected.append(embed(channels_last).permute(0, 3, 1, 2))
        return tuple(projected)
