"""https://github.com/HobbitLong/RepDistiller/blob/master/models/vgg.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from bypass.core.activation import ActivationForBypass,ActivationForDx2
from bypass.core.detect import BypassModel
# Public API exported by ``from <module> import *``. ``vgg8``/``vgg8_bn`` are
# public factories defined below and were previously missing from this list.
__all__ = [
    'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
    'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',
]


# Checkpoint download URLs keyed by architecture name (presumably torchvision's
# ImageNet VGG checkpoints). No visible code in this file reads them.
# NOTE(review): this VGG's module layout (block0..block4 + a single Linear)
# likely does not match those checkpoints' state_dict keys — confirm before use.
model_urls = dict(
    vgg11='https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    vgg13='https://download.pytorch.org/models/vgg13-c768596a.pth',
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg19='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
)

# Conv output widths per VGG variant. Each config has five inner lists, one per
# block (block0 .. block4) built by ``VGG``; every config's final width is 512.
VGG_CFG = dict(
    A=[[64], [128], [256, 256], [512, 512], [512, 512]],  # vgg11 / vgg11_bn
    B=[[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],  # vgg13 / vgg13_bn
    D=[  # vgg16 / vgg16_bn
        [64, 64],
        [128, 128],
        [256, 256, 256],
        [512, 512, 512],
        [512, 512, 512],
    ],
    E=[  # vgg19 / vgg19_bn
        [64, 64],
        [128, 128],
        [256, 256, 256, 256],
        [512, 512, 512, 512],
        [512, 512, 512, 512],
    ],
    S=[[64], [128], [256], [512], [512]],  # vgg8 / vgg8_bn
)

class VGG(nn.Module):
    """VGG backbone: five conv blocks, global average pooling, one Linear head.

    Args:
        cfg: Five lists of conv output widths, one per block (``cfg[0]`` ..
            ``cfg[4]``). An ``'M'`` entry inserts a 2x2 max-pool inside a block.
            The last entry of each list is used as a channel count, so it must
            be an integer.
        batch_norm: If True, every conv is followed by ``BatchNorm2d``.
        num_classes: Number of classifier outputs.

    Each block is built without its final ReLU (see ``_make_layers``); that
    ReLU is applied in ``forward`` through ``relu0`` .. ``relu4``.
    """

    def __init__(self, cfg, batch_norm=False, num_classes=1000):
        super(VGG, self).__init__()
        # NOTE: attribute names and registration order fix the state_dict keys
        # and the order in which _initialize_weights draws random numbers.
        self.block0 = self._make_layers(cfg[0], batch_norm, 3)
        self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
        self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
        self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
        self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])

        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Global average pool: the classifier input does not depend on H/W.
        self.pool4 = nn.AdaptiveAvgPool2d((1, 1))

        self.relu0 = nn.ReLU(inplace=False)
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.relu3 = nn.ReLU(inplace=False)
        self.relu4 = nn.ReLU(inplace=False)

        self.flatten = torch.nn.Flatten()

        # block4 emits cfg[4][-1] channels (the same convention used above for
        # block input widths). This equals 512 for the standard configs, but it
        # also lets narrower cfgs work instead of being hard-wired to 512.
        self.classifier = nn.Linear(cfg[4][-1], num_classes)
        self._initialize_weights()

    def get_feat_modules(self):
        """Return the blocks and pools in forward order (the ReLU modules are omitted)."""
        return nn.ModuleList([
            self.block0, self.pool0,
            self.block1, self.pool1,
            self.block2, self.pool2,
            self.block3, self.pool3,
            self.block4, self.pool4,
        ])

    def get_bn_before_relu(self):
        """Return the last layer of blocks 1-4.

        With ``batch_norm=True`` these are ``BatchNorm2d`` layers, the ones
        directly before ``relu1`` .. ``relu4``. Otherwise they are ``Conv2d``.
        """
        return [self.block1[-1], self.block2[-1], self.block3[-1], self.block4[-1]]

    def forward(self, x, return_features=False):
        """Compute class logits.

        Args:
            x: Input batch laid out as ``Conv2d`` expects, i.e. (N, 3, H, W).
                ``pool3`` is applied only when ``H == 64``. For other heights
                (e.g. 32) block4 runs at block3's resolution.
            return_features: If True, also return the flattened, globally
                pooled features fed to the classifier.

        Returns:
            ``logits``, or ``(logits, features)`` when ``return_features``.
        """
        h = x.shape[2]
        x = self.relu0(self.block0(x))
        x = self.pool0(x)
        x = self.block1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.block2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.block3(x)
        x = self.relu3(x)
        if h == 64:
            x = self.pool3(x)
        x = self.block4(x)
        x = self.relu4(x)
        x = self.pool4(x)
        features = self.flatten(x)
        x = self.classifier(features)
        if return_features:
            return x, features
        return x

    @staticmethod
    def _make_layers(cfg, batch_norm=False, in_channels=3):
        """Build one block as a Sequential of conv[/BN]/ReLU (and 'M' max-pool) layers.

        The final element is dropped so that the block ends before its last
        activation (on a BN layer, or on a conv when ``batch_norm`` is False).
        """
        layers = []
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
                else:
                    layers += [conv2d, nn.ReLU()]
                in_channels = v
        layers = layers[:-1]
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Re-initialise parameters.

        Conv weights are drawn from N(0, 2/(k*k*out_channels)) and Linear
        weights from N(0, 0.01^2). All biases are set to 0 and BN scales to 1.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
 

def vgg8(**kwargs):
    """Build a VGG-8 model (configuration "S") without batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``batch_norm``).
            Pretrained weights are not supported: ``VGG`` has no
            ``pretrained`` parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['S'], **kwargs)
    return model


def vgg8_bn(**kwargs):
    """Build a VGG-8 model (configuration "S") with batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (e.g. ``num_classes``). Pretrained
            weights are not supported: ``VGG`` has no ``pretrained``
            parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['S'], batch_norm=True, **kwargs)
    return model


def vgg11(**kwargs):
    """Build a VGG-11 model (configuration "A") without batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``batch_norm``).
            Pretrained weights are not supported: ``VGG`` has no
            ``pretrained`` parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['A'], **kwargs)
    return model


def vgg11_bn(**kwargs):
    """Construct a batch-normalized VGG-11 (configuration "A").

    Keyword arguments (e.g. ``num_classes``) are passed straight to ``VGG``.
    """
    return VGG(VGG_CFG['A'], batch_norm=True, **kwargs)


def vgg13(**kwargs):
    """Build a VGG-13 model (configuration "B") without batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``batch_norm``).
            Pretrained weights are not supported: ``VGG`` has no
            ``pretrained`` parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['B'], **kwargs)
    return model


def vgg13_bn(**kwargs):
    """Construct a batch-normalized VGG-13 (configuration "B").

    Keyword arguments (e.g. ``num_classes``) are passed straight to ``VGG``.
    """
    return VGG(VGG_CFG['B'], batch_norm=True, **kwargs)


def vgg16(**kwargs):
    """Build a VGG-16 model (configuration "D") without batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``batch_norm``).
            Pretrained weights are not supported: ``VGG`` has no
            ``pretrained`` parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['D'], **kwargs)
    return model


def vgg16_bn(**kwargs):
    """Construct a batch-normalized VGG-16 (configuration "D").

    Keyword arguments (e.g. ``num_classes``) are passed straight to ``VGG``.
    """
    return VGG(VGG_CFG['D'], batch_norm=True, **kwargs)


def vgg19(**kwargs):
    """Build a VGG-19 model (configuration "E") without batch normalization.

    Args:
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``batch_norm``).
            Pretrained weights are not supported: ``VGG`` has no
            ``pretrained`` parameter, so passing one raises ``TypeError``.

    Returns:
        A freshly initialised ``VGG``.
    """
    model = VGG(VGG_CFG['E'], **kwargs)
    return model


def vgg19_bn(**kwargs):
    """Construct a batch-normalized VGG-19 (configuration 'E').

    Keyword arguments (e.g. ``num_classes``) are passed straight to ``VGG``.
    """
    return VGG(VGG_CFG['E'], batch_norm=True, **kwargs)

# class cifar100_BypassVGG19(BypassVGG):
#     def __init__(self,**kwargs):
#         super().__init__(VGG_CFG['E'],batch_norm=False,num_classes=100)
class BNRelu(nn.BatchNorm2d):
    """``BatchNorm2d`` whose output is passed through a ReLU in the same module.

    ``relu`` is a class attribute rather than an instance attribute. It is
    therefore shared by all instances and is not registered as a child module.
    """

    relu = nn.ReLU()

    def forward(self, x):
        normalized = super().forward(x)
        return self.relu(normalized)
class BypassVGG(VGG):
    """``VGG`` that hands itself to ``BypassModel`` once construction finishes.

    ``src_modules`` (module types) and ``bypass_wrapper`` are forwarded to
    ``BypassModel``. Subclasses override them to choose what gets targeted.
    ``input_shape`` is presumably (C, H, W) of a single sample; confirm
    against bypass.core.detect.
    """

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self, cfg, batch_norm=False, num_classes=1000, input_shape=(3, 32, 32)):
        super().__init__(cfg, batch_norm=batch_norm, num_classes=num_classes)
        # Called for its effect on ``self``; the return value is discarded.
        # NOTE(review): presumably wraps matching submodules in place — confirm.
        BypassModel(self, input_shape, self.src_modules, self.bypass_wrapper)
        
class cifar100_BypassVGG19(BypassVGG):
    """Batch-normalized VGG-19 (config 'E') with 100 outputs, targeting ``nn.ReLU`` layers.

    Extra keyword arguments (e.g. ``input_shape``) are forwarded to ``BypassVGG``.
    """

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self, **kwargs):
        super().__init__(VGG_CFG['E'], batch_norm=True, num_classes=100, **kwargs)
class cifar100_NaiveBypassVGG19(BypassVGG):
    """VGG-19 (config 'E', 100 outputs) built from fused ``BNRelu`` layers.

    ``BNRelu`` modules are the targets passed to ``BypassModel``. With
    ``batch_norm`` enabled, each block therefore ends with a ``BNRelu`` (BN
    followed by ReLU), so ``get_bn_before_relu`` returns ``BNRelu`` instances.
    """

    src_modules = [BNRelu]
    bypass_wrapper = ActivationForBypass

    def __init__(self, **kwargs):
        super().__init__(VGG_CFG['E'], batch_norm=True, num_classes=100, **kwargs)

    @staticmethod
    def _make_layers(cfg, batch_norm=False, in_channels=3):
        """Build one block; BN+ReLU pairs are fused into a single ``BNRelu``.

        A placeholder ``nn.Identity`` follows each ``BNRelu`` so that dropping
        the final element (as the base class does) removes the placeholder
        rather than the fused layer.
        """
        modules = []
        for width in cfg:
            if width == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(in_channels, width, kernel_size=3, padding=1))
            if batch_norm:
                modules.extend([BNRelu(width), nn.Identity()])
            else:
                modules.append(nn.ReLU())
            in_channels = width
        return nn.Sequential(*modules[:-1])
if __name__ == '__main__':
    # Manual smoke test: build the plain and bypass VGG-19 variants.
    ckpt_path = '/workspace/jaeheun_MildPruning/save_dx2/mild_pruning_W/cifar100_vgg19/baseline/save/cifar100_vgg19.pth'
    x = torch.randn(2, 3, 32, 32)
    net = vgg19_bn(num_classes=100)
    # VGG.forward returns only the logits unless return_features=True.
    # Unpacking its plain output into two names would just split the batch
    # into its two rows.
    logit, feats = net(x, return_features=True)
    print(logit.shape, feats.shape)

    net = cifar100_BypassVGG19()
    print(1)

    # net.load_state_dict(torch.load(ckpt_path))

    # for m in net.get_bn_before_relu():
    #     if isinstance(m, nn.BatchNorm2d):
    #         print('pass')
    #     else:
    #         print('warning')