import torch.nn as nn

import os
import sys

# Directory containing this file; its parent is appended to the module search
# path (presumably so the sibling `resnet` / `resnet_masked` modules imported
# below can be found -- verify against the repository layout).
dirname = os.path.dirname(__file__)
sys.path.append(os.path.join(dirname, '../'))

from resnet import *
from resnet_masked import *


######## ResNet without any pruning 

def ResNet18():
    """Build an unpruned ``ResNet`` from ``BasicBlock`` with layer config [2, 2, 2, 2].

    The model is created with ``n_class=10``; ``norm_layer`` is left at
    whatever default ``ResNet`` defines.
    """
    layer_cfg = [2, 2, 2, 2]
    model = ResNet(BasicBlock, layer_cfg, n_class=10)
    return model

def ResNet18_gn():
    """Build an unpruned ``ResNet`` (``BasicBlock``, [2, 2, 2, 2]) using ``GroupNorm``.

    Identical to the plain 18-layer factory except that ``norm_layer`` is
    set to ``GroupNorm``; ``n_class`` is 10.
    """
    layer_cfg = [2, 2, 2, 2]
    model = ResNet(BasicBlock, layer_cfg, n_class=10, norm_layer=GroupNorm)
    return model

def ResNet18_in():
    """Build an unpruned ``ResNet`` (``BasicBlock``, [2, 2, 2, 2]) using ``InstanceNorm``.

    Identical to the plain 18-layer factory except that ``norm_layer`` is
    set to ``InstanceNorm``; ``n_class`` is 10.
    """
    layer_cfg = [2, 2, 2, 2]
    model = ResNet(BasicBlock, layer_cfg, n_class=10, norm_layer=InstanceNorm)
    return model


######## ResNet with element-wise pruning 

def BetaResNet18():
    """Build a ``BetaResNet`` from ``MaskedBasicBlock`` with layer config [2, 2, 2, 2].

    The model is created with ``n_class=10``; ``norm_layer`` is left at
    whatever default ``BetaResNet`` defines.
    """
    layer_cfg = [2, 2, 2, 2]
    model = BetaResNet(MaskedBasicBlock, layer_cfg, n_class=10)
    return model

def BetaResNet18_sbn():
    """Build a ``BetaResNet`` (``MaskedBasicBlock``, [2, 2, 2, 2]) with ``MaskedBatchNorm_no_tracking``.

    Same layout as ``BetaResNet18`` but with ``norm_layer`` set to
    ``MaskedBatchNorm_no_tracking``; ``n_class`` is 10.
    """
    layer_cfg = [2, 2, 2, 2]
    model = BetaResNet(
        MaskedBasicBlock,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedBatchNorm_no_tracking,
    )
    return model

def BetaResNet18_gn():
    """Build a ``BetaResNet`` (``MaskedBasicBlock``, [2, 2, 2, 2]) with ``MaskedGroupNorm``.

    Same layout as ``BetaResNet18`` but with ``norm_layer`` set to
    ``MaskedGroupNorm``; ``n_class`` is 10.
    """
    layer_cfg = [2, 2, 2, 2]
    model = BetaResNet(
        MaskedBasicBlock,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedGroupNorm,
    )
    return model

def BetaResNet18_in():
    """Build a ``BetaResNet`` (``MaskedBasicBlock``, [2, 2, 2, 2]) with ``MaskedInstanceNorm``.

    Same layout as ``BetaResNet18`` but with ``norm_layer`` set to
    ``MaskedInstanceNorm``; ``n_class`` is 10.
    """
    layer_cfg = [2, 2, 2, 2]
    model = BetaResNet(
        MaskedBasicBlock,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedInstanceNorm,
    )
    return model

def BetaResNet34_sbn():
    """Build a ``BetaResNet`` (``MaskedBasicBlock``, [3, 4, 6, 3]) with ``MaskedBatchNorm_no_tracking``.

    ``n_class`` is 10 and ``norm_layer`` is ``MaskedBatchNorm_no_tracking``.
    """
    layer_cfg = [3, 4, 6, 3]
    model = BetaResNet(
        MaskedBasicBlock,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedBatchNorm_no_tracking,
    )
    return model

def BetaResNet50_sbn():
    """Build a ``BetaResNet`` (``MaskedBottleneck``, [3, 4, 6, 3]) with ``MaskedBatchNorm_no_tracking``.

    ``n_class`` is 10 and ``norm_layer`` is ``MaskedBatchNorm_no_tracking``.

    The layer configuration is the standard ResNet-50 bottleneck layout
    [3, 4, 6, 3].  This factory previously passed [3, 4, 23, 3], which is
    the ResNet-101 layout and did not match the function's name.

    NOTE(review): checkpoints saved from the earlier 101-layer layout will
    presumably not load into the model returned here -- confirm whether any
    such checkpoints exist before relying on this factory to restore them.
    """
    # ResNet-50 uses 3/4/6/3 bottleneck blocks; 3/4/23/3 would be ResNet-101.
    layer_cfg = [3, 4, 6, 3]
    return BetaResNet(
        MaskedBottleneck,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedBatchNorm_no_tracking,
    )

def BetaResNet110_sbn():
    """Build a ``BetaResNet`` (``MaskedBasicBlock``, [18, 18, 18, 18]) with ``MaskedBatchNorm_no_tracking``.

    ``n_class`` is 10 and ``norm_layer`` is ``MaskedBatchNorm_no_tracking``.

    NOTE(review): the commonly cited ResNet-110 uses three groups of 18
    basic blocks; this factory passes four groups of 18 -- confirm that the
    name reflects the intended depth.
    """
    layer_cfg = [18, 18, 18, 18]
    model = BetaResNet(
        MaskedBasicBlock,
        layer_cfg,
        n_class=10,
        norm_layer=MaskedBatchNorm_no_tracking,
    )
    return model
