import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('..')
from model.bif import BiF_Conv2d, BiF_Linear
from torch.nn.modules.container import Sequential


def replace_modules(model, train_mask):
    """Swap the conv/linear children of ``model`` for BiF layers, in place.

    Direct children that are ``nn.Conv2d`` are replaced by ``BiF_Conv2d`` and
    those that are ``nn.Linear`` by ``BiF_Linear``.  Children that are a
    ``Sequential`` or a ``BasicBlock`` are descended into recursively; any
    other child is left untouched.

    Args:
        model: Module whose children are rewritten.
        train_mask: Passed through unchanged as the last constructor argument
            of every BiF layer created here.

    Returns:
        None.  ``model`` is modified in place.

    Note:
        Only the constructor arguments listed below are taken from the layer
        being replaced.  Just the first entry of ``kernel_size`` and ``stride``
        is forwarded, and a literal ``False`` fills the slot right before
        ``train_mask`` (presumably ``bias`` -- TODO confirm against
        ``model.bif``).
    """
    for child_name, child in model._modules.items():
        if isinstance(child, nn.Conv2d):
            bif_conv = BiF_Conv2d(
                child.in_channels,
                child.out_channels,
                child.kernel_size[0],
                child.stride[0],
                child.padding,
                False,
                train_mask,
            )
            setattr(model, child_name, bif_conv)
        if isinstance(child, nn.Linear):
            bif_fc = BiF_Linear(child.in_features, child.out_features, False, train_mask)
            setattr(model, child_name, bif_fc)
        # Containers hold further layers to convert, so recurse into them.
        if isinstance(child, Sequential) or isinstance(child, BasicBlock):
            replace_modules(child, train_mask)

def reset_module_masks(model, mask):
    """Hand ``mask`` to every BiF layer inside ``model``.

    Iterates over ``model.named_modules()`` and calls ``reset_mask(mask)`` on
    each module that is a ``BiF_Conv2d`` or a ``BiF_Linear``.

    Args:
        model: Module to traverse.
        mask: Value forwarded unchanged to each layer's ``reset_mask``.
    """
    bif_layer_types = (BiF_Conv2d, BiF_Linear)
    for _, layer in model.named_modules():
        if isinstance(layer, bif_layer_types):
            layer.reset_mask(mask)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut branch.

    The shortcut is an empty ``nn.Sequential`` unless the block changes the
    stride or the channel count, in which case it is a strided 1x1 convolution
    followed by batch norm.
    """

    # Multiplier applied to ``planes`` when sizing the shortcut projection.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both 3x3 convolutions.
            stride: Stride of the first convolution (and of the shortcut
                projection, when one is needed).
        """
        super().__init__()
        shortcut_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is needed only when the main path's output no longer
        # matches the input in resolution or width.
        needs_projection = stride != 1 or in_planes != shortcut_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, shortcut_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(shortcut_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden, inplace=True)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden, inplace=True)

class ResNet18(nn.Module):
    """Residual classifier: 3x3 stem, four residual stages, pooled linear head.

    Stage widths are fixed at 64/128/256/512 and stages two to four open with
    a stride-2 block.  The block type and the number of blocks per stage are
    supplied by the caller through ``block`` and ``layers``.
    """

    def __init__(self, block, layers, num_classes=10):
        """Assemble the network and initialise its weights.

        Args:
            block: Block class; called as ``block(in, out, stride=...)`` and
                expected to expose an ``expansion`` attribute.
            layers: Indexable with the block count of each of the four stages.
            num_classes: Output size of the final linear layer.
        """
        super().__init__()
        widths = [64, 128, 256, 512]
        # Input width of the next block to be built; advanced by make_layer().
        self.in_channels = widths[0]

        # Stem.
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(widths[0])
        self.relu = nn.ReLU(inplace=True)

        # Residual stages.
        self.layer1 = self.make_layer(block, widths[0], layers[0])
        self.layer2 = self.make_layer(block, widths[1], layers[1], stride=2)
        self.layer3 = self.make_layer(block, widths[2], layers[2], stride=2)
        self.layer4 = self.make_layer(block, widths[3], layers[3], stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(widths[3] * block.expansion, num_classes, bias=False)

        self.weight_init()

    def weight_init(self):
        """Kaiming-normal init for conv/linear weights; zero any linear bias."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                layer.bias.data.fill_(0.0)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and return it as ``nn.Sequential``.

        Only the first block receives ``stride``; the remaining ones use the
        block's default.  ``self.in_channels`` is updated to the stage's
        output width as a side effect.
        """
        stage = [block(self.in_channels, out_channels, stride=stride)]
        self.in_channels = out_channels * block.expansion
        stage.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, global average pooling and the linear head."""
        feats = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        return self.fc(torch.flatten(pooled, 1))

class BiF_ResNet18(ResNet18):
    """``ResNet18`` whose conv and linear layers are swapped for BiF layers.

    The plain network is built first; ``replace_modules`` then rewrites it in
    place, handing ``train_mask`` to every BiF layer it creates.

    NOTE(review): ``ResNet18.__init__`` runs ``weight_init`` *before* the
    swap, so the BiF layers created afterwards keep whatever initialisation
    their own constructors perform -- confirm this is intended.
    """

    def __init__(self, block, layers, num_classes=10, train_mask=None):
        """Build the network and convert it to BiF layers.

        Args:
            block: Block class forwarded to ``ResNet18``.
            layers: Per-stage block counts forwarded to ``ResNet18``.
            num_classes: Output size of the final layer.
            train_mask: Mask handed to every BiF layer.  A string of digits
                such as ``"101"`` is converted to ``[1, 0, 1]``; any other
                value is used as given.
        """
        super().__init__(block, layers, num_classes)
        self.train_mask = self._normalize_mask(train_mask)
        replace_modules(self, self.train_mask)

    @staticmethod
    def _normalize_mask(mask):
        """Convert a digit string to a list of ints; pass other values through."""
        if isinstance(mask, str):
            return [int(digit) for digit in mask]
        return mask

    def reset_masks(self, mask):
        """Store ``mask`` and push it into every BiF layer of the network.

        String masks are normalised exactly as in ``__init__`` so the layers
        always receive the same mask representation, whichever path set it.
        """
        self.train_mask = self._normalize_mask(mask)
        reset_module_masks(self, self.train_mask)

def resnet18(train_mask=None, num_classes=10):
    """Build an 18-layer-style network (two ``BasicBlock`` per stage).

    Args:
        train_mask: When ``None`` a plain ``ResNet18`` is returned; otherwise
            a ``BiF_ResNet18`` configured with this mask.
        num_classes: Output size of the final layer (defaults to 10, the
            default of both model classes).

    Returns:
        The constructed model.
    """
    stage_depths = [2, 2, 2, 2]
    # Identity check: ``== None`` is element-wise or ambiguous for array-like
    # masks, whereas ``is None`` is always a plain boolean.
    if train_mask is None:
        return ResNet18(BasicBlock, stage_depths, num_classes=num_classes)
    return BiF_ResNet18(
        BasicBlock, stage_depths, num_classes=num_classes, train_mask=train_mask
    )

class ResNet34(nn.Module):
    """Residual classifier: 3x3 stem, four residual stages, pooled linear head.

    Stage widths are fixed at 64/128/256/512 and stages two to four open with
    a stride-2 block.  The block type and the number of blocks per stage are
    supplied by the caller through ``block`` and ``layers``.
    """

    def __init__(self, block, layers, num_classes=100):
        """Assemble the network and initialise its weights.

        Args:
            block: Block class; called as ``block(in, out, stride=...)`` and
                expected to expose an ``expansion`` attribute.
            layers: Indexable with the block count of each of the four stages.
            num_classes: Output size of the final linear layer.
        """
        super().__init__()
        stage_widths = [64, 128, 256, 512]
        stem_width, head_width = stage_widths[0], stage_widths[3]
        # Input width of the next block to be built; advanced by make_layer().
        self.in_channels = stem_width

        # Stem.
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages.
        self.layer1 = self.make_layer(block, stage_widths[0], layers[0])
        self.layer2 = self.make_layer(block, stage_widths[1], layers[1], stride=2)
        self.layer3 = self.make_layer(block, stage_widths[2], layers[2], stride=2)
        self.layer4 = self.make_layer(block, stage_widths[3], layers[3], stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(head_width * block.expansion, num_classes, bias=False)

        self.weight_init()

    def weight_init(self):
        """Kaiming-normal init for conv/linear weights; zero any linear bias."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                layer.bias.data.fill_(0.0)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and return it as ``nn.Sequential``.

        Only the first block receives ``stride``; the remaining ones use the
        block's default.  ``self.in_channels`` is updated to the stage's
        output width as a side effect.
        """
        first = block(self.in_channels, out_channels, stride=stride)
        self.in_channels = out_channels * block.expansion
        rest = [block(self.in_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Run stem, the four stages, global average pooling and the linear head."""
        feats = self.conv1(x)
        feats = self.relu(self.bn1(feats))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        flat = torch.flatten(self.avgpool(feats), 1)
        return self.fc(flat)

class BiF_ResNet34(ResNet34):
    """``ResNet34`` whose conv and linear layers are swapped for BiF layers.

    The plain network is built first; ``replace_modules`` then rewrites it in
    place, handing ``train_mask`` to every BiF layer it creates.

    NOTE(review): ``ResNet34.__init__`` runs ``weight_init`` *before* the
    swap, so the BiF layers created afterwards keep whatever initialisation
    their own constructors perform -- confirm this is intended.
    """

    def __init__(self, block, layers, num_classes=100, train_mask=None):
        """Build the network and convert it to BiF layers.

        Args:
            block: Block class forwarded to ``ResNet34``.
            layers: Per-stage block counts forwarded to ``ResNet34``.
            num_classes: Output size of the final layer.
            train_mask: Mask handed to every BiF layer.  A string of digits
                such as ``"101"`` is converted to ``[1, 0, 1]``; any other
                value is used as given.
        """
        super().__init__(block, layers, num_classes)
        self.train_mask = self._normalize_mask(train_mask)
        replace_modules(self, self.train_mask)

    @staticmethod
    def _normalize_mask(mask):
        """Convert a digit string to a list of ints; pass other values through."""
        if isinstance(mask, str):
            return [int(digit) for digit in mask]
        return mask

    def reset_masks(self, mask):
        """Store ``mask`` and push it into every BiF layer of the network.

        String masks are normalised exactly as in ``__init__`` so the layers
        always receive the same mask representation, whichever path set it.
        """
        self.train_mask = self._normalize_mask(mask)
        reset_module_masks(self, self.train_mask)

def resnet34(train_mask=None, num_classes=100):
    """Build a 34-layer-style network (3/4/6/3 ``BasicBlock`` per stage).

    Args:
        train_mask: When ``None`` a plain ``ResNet34`` is returned; otherwise
            a ``BiF_ResNet34`` configured with this mask.
        num_classes: Output size of the final layer (defaults to 100, the
            default of both model classes).

    Returns:
        The constructed model.
    """
    stage_depths = [3, 4, 6, 3]
    # Identity check: ``== None`` is element-wise or ambiguous for array-like
    # masks, whereas ``is None`` is always a plain boolean.
    if train_mask is None:
        return ResNet34(BasicBlock, stage_depths, num_classes=num_classes)
    return BiF_ResNet34(
        BasicBlock, stage_depths, num_classes=num_classes, train_mask=train_mask
    )

if __name__ == "__main__":
    # Smoke test: build the plain (unmasked) ResNet-18 and dump its structure.
    net = resnet18()
    print(net)
    print("done")