
import torch
import torch.nn as nn
from torch import Tensor
# from torchvision.models.resnet import resnet50
from .resnet_wide import Bottleneck, BasicBlock, resnet18, resnet50x1, resnet50x2, resnet50x4

# Registry mapping a model-type string to the matching constructor imported
# from ``.resnet_wide``. Lookups by ``model_type`` go through this table.
RESNET_DICT = {
    'resnet18': resnet18,
    'resnet50x1': resnet50x1,
    'resnet50x2': resnet50x2,
    'resnet50x4': resnet50x4,
}

class ResnetBackbone(nn.Module):
    """Take the feature embedding as input, output the feature for classification.

    Only the four residual stages (``layer1`` .. ``layer4``) and the
    ``avgpool`` module of the network built by the selected constructor are
    kept; any other sub-modules of that network are not used here.
    ``forward`` runs the stages in order, applies ``avgpool`` and flattens
    everything after the batch dimension.
    """

    def __init__(self, model_type='resnet50x1'):
        """Build the backbone from one of the constructors in ``RESNET_DICT``.

        Args:
            model_type: Key of ``RESNET_DICT`` selecting the architecture.

        Raises:
            AssertionError: If ``model_type`` is not a key of ``RESNET_DICT``.
        """
        super().__init__()
        assert model_type in RESNET_DICT, f"Unsupported resnet type {model_type}"
        model_func = RESNET_DICT[model_type]
        # The constructor builds a complete network; only the residual
        # stages and the pooling layer are retained on this module.
        base_model = model_func()
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.avgpool = base_model.avgpool
        self.init_weights(zero_init_residual=True)

    def load_from_resnet(self, resnet_model):
        """Replace the stages and pooling layer with those of ``resnet_model``.

        The sub-modules are assigned by reference (no copy), so their
        parameters are shared with ``resnet_model`` afterwards.

        Args:
            resnet_model: Object exposing ``layer1`` .. ``layer4`` and
                ``avgpool`` attributes.
        """
        self.layer1 = resnet_model.layer1
        self.layer2 = resnet_model.layer2
        self.layer3 = resnet_model.layer3
        self.layer4 = resnet_model.layer4
        self.avgpool = resnet_model.avgpool

    def init_weights(self, zero_init_residual=False):
        """(Re-)initialise convolution and normalisation parameters in place.

        ``nn.Conv2d`` weights get Kaiming-normal initialisation (``fan_out``,
        ReLU); ``nn.BatchNorm2d`` / ``nn.GroupNorm`` get weight 1 and bias 0.
        Normalisation layers without affine parameters are skipped.

        Args:
            zero_init_residual: When true, additionally set the weight of the
                last normalisation layer of every ``Bottleneck`` (``bn3``)
                and ``BasicBlock`` (``bn2``) to zero.
        """
        # The original author noted this pass is "actually not needed";
        # presumably the constructors already initialise their weights --
        # TODO confirm against .resnet_wide.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # With affine=False these layers hold weight/bias = None,
                # which nn.init.constant_ cannot handle.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock):
                    last_norm = m.bn2
                else:
                    continue
                # Skip norm layers that carry no learnable scale.
                if getattr(last_norm, 'weight', None) is not None:
                    nn.init.constant_(last_norm.weight, 0)

    def forward(self, x):
        """Run the four stages, pool, and flatten.

        Args:
            x: Input tensor accepted by ``layer1``.

        Returns:
            The pooled output flattened from dimension 1 onwards.
        """
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        return torch.flatten(x, 1)

class FCNResnetBackbone(ResnetBackbone):
    """Fully-convolutional variant: take the feature embedding as input and
    return the output of ``layer4`` directly.

    Unlike the parent class, ``avgpool`` and the final flatten are not applied.
    """

    def forward(self, x):
        """Pass ``x`` through ``layer1`` .. ``layer4`` and return the result."""
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in stages:
            x = stage(x)
        return x


class VideoResnetBackbone(ResnetBackbone):
    """Apply the residual stages to every frame of a batch of clips.

    The batch and frame axes are merged before the stages and split again
    afterwards; ``avgpool`` and the final flatten are not applied.
    """

    def forward(self, x):
        """Input:
            x:[B,F,C,H,W]
        Output:
            [B,F,C',H',W'], where C', H', W' are the sizes of the last three
            dimensions of the ``layer4`` output.
        """
        batch, frames, channels, height, width = x.size()
        # Fold frames into the batch axis so the stages see a 4-D tensor.
        x = x.reshape(batch * frames, channels, height, width)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Re-read the sizes of the stage output before unfolding the frames.
        _, out_channels, out_height, out_width = x.size()
        return x.reshape(batch, frames, out_channels, out_height, out_width)
