import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from collections import OrderedDict
import torch
from .Normalize import Normalize

class ResidualBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice, plus a shortcut.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution. When it is not 1, the
            shortcut is average-pooled with the same stride so both paths
            have the same spatial size.
        activate_before_residual: when True, the first BN+ReLU is applied to
            ``x`` itself, so the shortcut also carries the activated tensor;
            otherwise only the residual branch sees the activation.

    When the channel count changes, the shortcut is zero-padded along the
    channel dimension (parameter-free, no 1x1 projection).
    """

    def __init__(self, in_channels, out_channels, stride, activate_before_residual=False):

        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activate_before_residual = activate_before_residual

        # The shortcut must be reshaped whenever the residual branch changes
        # either the channel count or the spatial size. Checking channels
        # alone left a strided, equal-width block with mismatched shapes.
        self.reshape_shortcut = in_channels != out_channels or stride != 1
        if self.reshape_shortcut:
            # AvgPool2d has no parameters, so this does not affect the state_dict.
            self.avg_pool = nn.AvgPool2d(stride, stride)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``.

        Assumes ``x`` is a 4-D ``(N, in_channels, H, W)`` tensor. The six-value
        pad below targets the third-from-last dimension, which is the channel
        dimension for 4-D input. TODO confirm no caller passes unbatched input.
        """
        if self.activate_before_residual:
            # Rebind x so the shortcut below also uses the activated tensor.
            x = F.relu(self.bn1(x), True)
            y = self.conv1(x)
        else:
            y = F.relu(self.bn1(x), True)
            y = self.conv1(y)

        y = F.relu(self.bn2(y), True)
        y = self.conv2(y)

        if self.reshape_shortcut:
            x = self.avg_pool(x)
            diff = self.out_channels - self.in_channels
            # Split the zero padding so an odd difference still yields exactly
            # out_channels channels. Even differences are padded symmetrically,
            # as before.
            pad_front = diff // 2
            pad_back = diff - pad_front
            x = F.pad(x, [0, 0, 0, 0, pad_front, pad_back], "constant", 0)

        return x + y


class WResnet_Rotate_VFlip(nn.Module):
    """Wide ResNet backbone with three linear heads on a shared pooled feature.

    Heads (attribute names kept stable so existing checkpoints still load):
        ``fc1``: ``num_classes`` outputs, returned for ``branch='label'``.
        ``fc2``: 4 outputs, returned for ``branch='rotate'``.
        ``fc3``: 2 outputs, returned for ``branch='vflip'``.

    Args:
        mean, std: forwarded to ``Normalize``, which is applied to the input first.
        num_classes: output size of the label head (default 10, as before).
        widen_factor: channel multiplier for the three stages. The default 10
            reproduces the original widths 16 -> 160 -> 320 -> 640.
        blocks_per_layer: residual blocks per stage (default 5, as before).
    """

    def __init__(self, mean, std, num_classes=10, widen_factor=10, blocks_per_layer=5):

        super().__init__()
        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        filters = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]

        self.normalize = Normalize(mean, std)

        # Stem conv expects 3 input channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=filters[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self.block(filters[0], filters[1], strides[0], activate_before_residual[0], blocks_per_layer)
        self.layer2 = self.block(filters[1], filters[2], strides[1], activate_before_residual[1], blocks_per_layer)
        self.layer3 = self.block(filters[2], filters[3], strides[2], activate_before_residual[2], blocks_per_layer)
        # Final BN + ReLU, needed because the residual blocks are pre-activation.
        self.last = nn.Sequential(OrderedDict([
            ('bn', nn.BatchNorm2d(filters[3])),
            ('relu', nn.ReLU(True))
        ]))

        self.fc1 = nn.Linear(filters[3], num_classes)
        self.fc2 = nn.Linear(filters[3], 4)
        self.fc3 = nn.Linear(filters[3], 2)

    def block(self, in_channels, out_channels, stride, activate_before_residual=False, num_blocks=5):
        """Build one stage as an ``nn.Sequential`` of ``num_blocks`` residual blocks.

        The first block performs the channel/stride transition. The remaining
        ``num_blocks - 1`` blocks keep ``out_channels`` at stride 1.

        Raises:
            ValueError: if ``num_blocks`` is less than 1.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        layers = [ResidualBlock(in_channels, out_channels, stride, activate_before_residual)]
        layers.extend(ResidualBlock(out_channels, out_channels, 1, False) for _ in range(num_blocks - 1))

        return nn.Sequential(*layers)

    def forward(self, x, branch='label'):
        """Run the backbone and return the output of the selected head(s).

        Assumes ``x`` is a ``(N, 3, H, W)`` image batch. TODO confirm against
        callers.

        Args:
            branch: ``'label'`` -> ``fc1`` logits, ``'rotate'`` -> ``fc2``,
                ``'vflip'`` -> ``fc3``. Any other value returns the tuple
                ``(fc1, fc2, fc3)`` outputs.
        """
        x = self.normalize(x)
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.last(x)
        # Global average pool over the spatial dimensions -> (N, channels).
        x = x.mean(dim=(2, 3))
        if branch == 'label':
            return self.fc1(x)
        elif branch == 'rotate':
            return self.fc2(x)
        elif branch == 'vflip':
            return self.fc3(x)
        else:
            return self.fc1(x), self.fc2(x), self.fc3(x)
