# from torchvision.models import wide_resnet101_2
from torchvision.models.resnet import Bottleneck, ResNet

class WideResnet101(ResNet):
    """Wide ResNet-101 backbone with optional feature outputs.

    Builds torchvision's ``ResNet`` with ``width_per_group=128`` (twice the
    ``ResNet`` default bottleneck width of 64), the configuration torchvision
    uses for ``wide_resnet101_2``. ``forward`` can additionally return the
    flattened pooled feature or a list of intermediate activations.
    """

    def __init__(self,
                 block=Bottleneck,
                 layers=(3, 4, 23, 3),
                 num_classes=1000):
        """Construct the network.

        Args:
            block: Residual block class passed through to ``ResNet``.
            layers: Number of blocks in each of the four stages. Defaults to
                the ResNet-101 layout; a tuple is used so the default is not a
                mutable object shared between instances (``ResNet`` only
                indexes it).
            num_classes: Output size of the final ``fc`` layer.
        """
        super().__init__(block=block,
                         layers=layers,
                         num_classes=num_classes,
                         width_per_group=128)
        # Size of the flattened pooled feature fed to ``fc``. Reading it from
        # the layer keeps it in sync with ``block`` (2048 for ``Bottleneck``).
        self.feature_size = self.fc.in_features

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Run the network, optionally returning intermediate features.

        Args:
            x: Input batch passed to ``conv1``.
            return_feature: If True, also return the flattened pooled feature.
            return_feature_list: If True and ``return_feature`` is False, also
                return the list of per-stage activations.

        Returns:
            ``logits`` by default; ``(logits, feature)`` when
            ``return_feature`` is set (it takes precedence); otherwise
            ``(logits, feature_list)`` when ``return_feature_list`` is set.
            ``feature_list`` holds the stem output (after max-pool), the
            outputs of ``layer1``..``layer3``, and the average-pooled (not yet
            flattened) ``layer4`` output.
        """
        feature1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        feature2 = self.layer1(feature1)
        feature3 = self.layer2(feature2)
        feature4 = self.layer3(feature3)
        # The last stage is stored *after* average pooling, so the final list
        # entry is the pooled map rather than the raw layer4 activation.
        feature5 = self.avgpool(self.layer4(feature4))
        feature = feature5.view(feature5.size(0), -1)
        logits_cls = self.fc(feature)

        if return_feature:
            return logits_cls, feature
        if return_feature_list:
            # Only materialize the list when the caller asked for it.
            return logits_cls, [feature1, feature2, feature3, feature4, feature5]
        return logits_cls

    # NOTE(review): the helpers below are disabled; either delete them or
    # re-enable them together with tests.

    # def forward_threshold(self, x, threshold):
    #     feature1 = self.relu(self.bn1(self.conv1(x)))
    #     feature1 = self.maxpool(feature1)
    #     feature2 = self.layer1(feature1)
    #     feature3 = self.layer2(feature2)
    #     feature4 = self.layer3(feature3)
    #     feature5 = self.layer4(feature4)
    #     feature5 = self.avgpool(feature5)
    #     feature = feature5.clip(max=threshold)
    #     feature = feature.view(feature.size(0), -1)
    #     logits_cls = self.fc(feature)

    #     return logits_cls, feature

    # def get_fc(self):
    #     fc = self.fc
    #     return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()
