import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor



class SmallNet(nn.Module):
    '''Two-convolution classifier with a single linear output layer.

    forward() runs conv -> BN -> ReLU twice (the second convolution has
    stride 2), average-pools with a 4x4 window, flattens and applies ``fc``.
    ``fc`` expects 1024 input features and produces 10 outputs.
    '''

    def __init__(self):
        super().__init__()

        shared_conv_args = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(3, 64, stride=1, **shared_conv_args)
        self.conv2 = nn.Conv2d(64, 64, stride=2, **shared_conv_args)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc = nn.Linear(1024, 10)
        # fc1 and fc2 are registered (so they appear in the state dict) but
        # forward() does not call them.
        self.fc1 = nn.Linear(4096, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        '''Return the output of ``fc`` for the input batch ``x``.'''
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu(self.bn2(self.conv2(features)))
        pooled = F.avg_pool2d(features, 4)
        # Keep the batch dimension and collapse everything else.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

'''Pre-activation ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''


# Normalization settings read by the PreAct* classes below when they build
# their normalization layers.
# Active configuration: nn.BatchNorm2d with running statistics and affine
# parameters enabled.
normal_func = nn.BatchNorm2d
track_running_stats = True
affine = True

# Disabled alternative: nn.InstanceNorm2d without running statistics.
# normal_func = nn.InstanceNorm2d
# track_running_stats = False
# affine = True


# Announce at import time when running statistics are turned off.
if not track_running_stats:
    print('BN track False')


# class NewBN(torch.autograd.Function):

                                        

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    Layout: norm -> activation -> 3x3 conv -> norm -> activation -> 3x3 conv,
    with a shortcut added to the result. Normalization layers are built from
    the module-level ``normal_func`` / ``track_running_stats`` / ``affine``.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution and of the projection
            shortcut (when one is created).
        activation: name of the activation module; one of 'ReLU',
            'Softplus', 'GELU', 'ELU', 'LeakyReLU', 'SELU', 'CELU', 'Tanh',
            'SiLU'.
        softplus_beta: ``beta`` passed to nn.Softplus; only read when
            ``activation == 'Softplus'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super(PreActBlock, self).__init__()
        # Resolve the activation before building any layer so an unsupported
        # name fails here with a clear message instead of surfacing later in
        # forward() as a missing ``self.relu`` attribute.
        act_module = self._make_activation(activation, softplus_beta)

        self.bn1 = normal_func(in_planes, track_running_stats=track_running_stats, affine=affine)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = normal_func(planes, track_running_stats=track_running_stats, affine=affine)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # A 1x1 projection is only created when the shortcut has to change
        # the spatial stride or the channel count; otherwise the attribute is
        # deliberately left undefined (forward() checks for it).
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

        # Assigned last so the submodule registration order is unchanged.
        self.relu = act_module
        print(activation)

    @staticmethod
    def _make_activation(activation, softplus_beta):
        '''Return a new activation module for ``activation`` or raise ValueError.'''
        # Factories are lazy so only the selected module class is touched.
        factories = {
            'ReLU': lambda: nn.ReLU(inplace=True),
            'Softplus': lambda: nn.Softplus(beta=softplus_beta, threshold=20),
            'GELU': lambda: nn.GELU(),
            'ELU': lambda: nn.ELU(alpha=1.0, inplace=True),
            'LeakyReLU': lambda: nn.LeakyReLU(negative_slope=0.1, inplace=True),
            'SELU': lambda: nn.SELU(inplace=True),
            'CELU': lambda: nn.CELU(alpha=1.2, inplace=True),
            'Tanh': lambda: nn.Tanh(),
            'SiLU': lambda: nn.SiLU(),
        }
        try:
            factory = factories[activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable values passed as ``activation``.
            raise ValueError(
                'Unsupported activation %r; expected one of: %s'
                % (activation, ', '.join(factories))
            ) from None
        return factory()

    def forward(self, x):
        '''Apply the block to ``x`` and return the sum of the residual branch and the shortcut.'''
        out = self.relu(self.bn1(x))
        # With a projection, the shortcut starts from the pre-activated
        # tensor; without one, the raw input is passed through.
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    Layout: norm -> ReLU -> 1x1 conv -> norm -> ReLU -> 3x3 conv (strided)
    -> norm -> ReLU -> 1x1 conv to ``expansion * planes`` channels, plus a
    shortcut. Normalization layers are built from the module-level
    ``normal_func`` / ``track_running_stats`` / ``affine``.

    NOTE: ``activation`` and ``softplus_beta`` are accepted but not read;
    this block always applies ``F.relu``.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super().__init__()
        out_planes = self.expansion * planes
        norm_options = dict(track_running_stats=track_running_stats, affine=affine)

        self.bn1 = normal_func(in_planes, **norm_options)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = normal_func(planes, **norm_options)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = normal_func(planes, **norm_options)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        # Only create the 1x1 projection when the shortcut must change the
        # stride or the channel count; forward() checks for the attribute.
        needs_projection = (stride != 1) or (in_planes != out_planes)
        if needs_projection:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        '''Apply the block to ``x`` and return residual branch plus shortcut.'''
        pre_activated = F.relu(self.bn1(x))
        if hasattr(self, 'shortcut'):
            # The projection starts from the pre-activated tensor.
            residual = self.shortcut(pre_activated)
        else:
            residual = x

        out = self.conv1(pre_activated)
        out = F.relu(self.bn2(out))
        out = self.conv2(out)
        out = F.relu(self.bn3(out))
        out = self.conv3(out)
        out += residual
        return out


class PreActResNet(nn.Module):
    '''Pre-activation ResNet: a 3x3 stem convolution, four stages of ``block``, and a linear head.

    Args:
        block: block class; it is called as
            ``block(in_planes, planes, stride, activation=..., softplus_beta=...)``
            and must expose an ``expansion`` attribute.
        num_blocks: sequence whose first four entries give the number of
            blocks in each stage.
        num_classes: output size of the linear head.
        normalize: when truthy, the head has no bias and forward() L2-normalizes
            both the features (multiplied by ``scale``) and the head's weight.
        normalize_only_FN: when truthy, forward() L2-normalizes the features only.
        scale: multiplier applied to the normalized features when ``normalize`` is set.
        activation: activation name forwarded to every block.
        softplus_beta: forwarded to every block.

    NOTE: ``self.bn``, ``self.avgpool`` and ``self.relu`` are constructed here
    but forward() does not call them.
    '''

    def __init__(self, block, num_blocks, num_classes=10, normalize = False, normalize_only_FN = False, scale = 15, activation='ReLU', softplus_beta=1):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.normalize = normalize
        self.normalize_only_FN = normalize_only_FN
        self.scale = scale

        self.activation = activation
        self.softplus_beta = softplus_beta

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        feature_dim = 512 * block.expansion
        self.bn = normal_func(feature_dim, track_running_stats=track_running_stats, affine=affine)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The head drops its bias when feature/weight normalization is enabled.
        use_bias = False if self.normalize else True
        self.linear = nn.Linear(feature_dim, num_classes, bias=use_bias)

        # (name, factory) pairs; the factory is only called for the matching
        # name. An unrecognized name leaves ``self.relu`` unset.
        activation_factories = (
            ('ReLU', lambda: nn.ReLU(inplace=True)),
            ('Softplus', lambda: nn.Softplus(beta=softplus_beta, threshold=20)),
            ('GELU', lambda: nn.GELU()),
            ('ELU', lambda: nn.ELU(alpha=1.0, inplace=True)),
            ('LeakyReLU', lambda: nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('SELU', lambda: nn.SELU(inplace=True)),
            ('CELU', lambda: nn.CELU(alpha=1.2, inplace=True)),
            ('Tanh', lambda: nn.Tanh()),
            ('SiLU', lambda: nn.SiLU()),
        )
        for known_name, build in activation_factories:
            if activation == known_name:
                self.relu = build()
                print(known_name)
                break

        print('Use activation of ' + activation)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage: the first block uses ``stride``, the remaining ones stride 1.'''
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride,
                activation=self.activation, softplus_beta=self.softplus_beta))
            # The next block consumes this block's output channels.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        '''Return the linear head's output for the input batch ``x``.'''
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # Fixed 4x4 pooling window (the adaptive ``self.avgpool`` is not used).
        # NOTE(review): presumably sized for 32x32 inputs, as in test() -- confirm.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)

        if self.normalize_only_FN:
            out = F.normalize(out, p=2, dim=1)

        if self.normalize:
            out = F.normalize(out, p=2, dim=1) * self.scale
            # Overwrite the head's weight with its L2-normalized version
            # (along dim=1) on every forward call.
            for module in self.linear.modules():
                if isinstance(module, nn.Linear):
                    module.weight.data = F.normalize(module.weight, p=2, dim=1)
        return self.linear(out)


def Small18(num_classes=10, normalize = False, normalize_only_FN = False, scale = 15, activation='ReLU', softplus_beta=1):
    '''Return a PreActResNet built from PreActBlock with stage depths [2, 2, 2, 2].

    All keyword arguments are forwarded unchanged to PreActResNet.
    '''
    resnet_options = dict(
        num_classes=num_classes,
        normalize=normalize,
        normalize_only_FN=normalize_only_FN,
        scale=scale,
        activation=activation,
        softplus_beta=softplus_beta,
    )
    return PreActResNet(PreActBlock, [2, 2, 2, 2], **resnet_options)

def PreActResNet34():
    '''Return a PreActResNet built from PreActBlock with stage depths [3, 4, 6, 3] and default options.'''
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBlock, stage_depths)

def PreActResNet50():
    '''Return a PreActResNet built from PreActBottleneck with stage depths [3, 4, 6, 3] and default options.'''
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBottleneck, stage_depths)

def PreActResNet101():
    '''Return a PreActResNet built from PreActBottleneck with stage depths [3, 4, 23, 3] and default options.'''
    stage_depths = [3, 4, 23, 3]
    return PreActResNet(PreActBottleneck, stage_depths)

def PreActResNet152():
    '''Return a PreActResNet built from PreActBottleneck with stage depths [3, 8, 36, 3] and default options.'''
    stage_depths = [3, 8, 36, 3]
    return PreActResNet(PreActBottleneck, stage_depths)


def test():
    '''Smoke test: push one random 1x3x32x32 batch through Small18 and print the output size.'''
    # `PreActResNet18` is neither defined nor imported in this module, so the
    # previous call raised NameError; Small18 is the [2, 2, 2, 2] factory here.
    net = Small18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# test()