import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from .aspp import build_aspp
from .decoder import build_decoder
from .backbone import build_backbone

class conv_block(nn.Module):
    """
    Convolution Block

    Three stacked convolutions laid out as
    conv -> norm -> ReLU -> Dropout(0.5) -> conv -> norm -> ReLU
    -> Dropout(0.1) -> conv. The last convolution is followed by neither a
    normalisation layer nor an activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by each convolution.
        kernel_size: square kernel size. Padding is ``kernel_size // 2``, so
            odd kernel sizes preserve the spatial size.
        bn: callable that builds a normalisation layer from a channel count
            (e.g. ``nn.BatchNorm2d``). ``None`` falls back to
            ``nn.BatchNorm2d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, bn=None):
        super().__init__()

        # The declared default (None) used to be called directly and crashed;
        # resolve it to the plain PyTorch batch norm instead.
        if bn is None:
            bn = nn.BatchNorm2d

        padding = kernel_size // 2

        def conv(channels_in):
            # Bias-free convolution with stride 1 and "same"-style padding.
            return nn.Conv2d(channels_in, out_ch, kernel_size=kernel_size,
                             stride=1, padding=padding, bias=False)

        # Layer positions inside the Sequential are kept stable so existing
        # checkpoints (keys such as "conv.4.weight") still load.
        self.conv = nn.Sequential(
            conv(in_ch),
            bn(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            conv(out_ch),
            bn(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            conv(out_ch),
        )

    def forward(self, x):
        """Apply the stacked convolutions to ``x`` and return the result."""
        return self.conv(x)


class up_conv(nn.Module):
    """
    Up Convolution Block

    Doubles the spatial size with bilinear upsampling
    (``align_corners=True``) and then applies a bias-free 3x3 convolution,
    a normalisation layer and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        bn: callable that builds a normalisation layer from a channel count
            (e.g. ``nn.BatchNorm2d``). ``None`` falls back to
            ``nn.BatchNorm2d``.
    """

    def __init__(self, in_ch, out_ch, bn=None):
        super().__init__()

        # The declared default (None) used to be called directly and crashed;
        # resolve it to the plain PyTorch batch norm instead.
        if bn is None:
            bn = nn.BatchNorm2d

        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3,
                      stride=1, padding=1, bias=False),
            bn(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Upsample ``x`` by a factor of two and project it to ``out_ch`` channels."""
        return self.up(x)


class Head(nn.Module):
    """
    UNet-style attention head
    Paper : https://arxiv.org/abs/1505.04597

    Five encoder stages separated by four 2x2 max-pools, four decoder stages
    with skip connections, and a small convolutional output stack followed by
    Tanh and ReLU, so every output value lies in [0, 1).

    Because the encoder halves the resolution four times (rounding down) and
    the decoder concatenates 2x-upsampled maps with the encoder maps, the
    input height and width must be divisible by 16; otherwise the
    concatenations fail with a size mismatch.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n_filters: channel count of the first stage; later stages use
            2x, 4x, 8x and 16x this value.
        bn: callable that builds a normalisation layer from a channel count
            (e.g. ``nn.BatchNorm2d``). ``None`` falls back to
            ``nn.BatchNorm2d``.
    """

    def __init__(self, in_ch=3, out_ch=1, n_filters=32, bn=None):
        super().__init__()

        # The declared default (None) used to be called directly and crashed;
        # resolve it once here so every sub-block receives a usable factory.
        if bn is None:
            bn = nn.BatchNorm2d

        filters = [n_filters * mult for mult in (1, 2, 4, 8, 16)]

        # Attribute names and registration order are kept as-is so existing
        # checkpoints remain loadable.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, filters[0], bn=bn)
        self.Conv2 = conv_block(filters[0], filters[1], bn=bn)
        self.Conv3 = conv_block(filters[1], filters[2], bn=bn)
        self.Conv4 = conv_block(filters[2], filters[3], bn=bn)
        self.Conv5 = conv_block(filters[3], filters[4], bn=bn)

        # Decoder: each Up* halves the channels; the matching Up_conv* takes
        # the skip concatenation (twice the channels) back down.
        self.Up5 = up_conv(filters[4], filters[3], bn=bn)
        self.Up_conv5 = conv_block(filters[4], filters[3], bn=bn)

        self.Up4 = up_conv(filters[3], filters[2], bn=bn)
        self.Up_conv4 = conv_block(filters[3], filters[2], bn=bn)

        self.Up3 = up_conv(filters[2], filters[1], bn=bn)
        self.Up_conv3 = conv_block(filters[2], filters[1], bn=bn)

        self.Up2 = up_conv(filters[1], filters[0], bn=bn)
        self.Up_conv2 = conv_block(filters[1], filters[0], bn=bn)

        # Output stack: 5x5 conv, 1x1 conv, then a 1x1 projection to out_ch.
        self.Conv = nn.Sequential(
            nn.Conv2d(filters[0], filters[0], kernel_size=5, stride=1, padding=2),
            bn(filters[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters[0], filters[0], kernel_size=1),
            bn(filters[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters[0], out_ch, kernel_size=1),
        )

        # Tanh followed by ReLU: negative responses become 0, positive ones
        # are squashed below 1.
        self.active = nn.Sequential(
            nn.Tanh(),
            nn.ReLU()
        )

    def forward(self, x):
        """
        Run the encoder/decoder and return the activated output map.

        Args:
            x: input tensor; height and width must be divisible by 16.

        Returns:
            Tensor with ``out_ch`` channels and values in [0, 1).
        """
        # Encoder path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Decoder path: upsample, concatenate the skip (skip first), fuse.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        d1 = self.Conv(d2)
        return self.active(d1)


class DeepLab_att_begining(nn.Module):
    """
    DeepLab variant that computes an attention map on the raw input and
    passes it to the backbone together with the image.

    Args:
        backbone: backbone name forwarded to ``build_backbone``,
            ``build_aspp`` and ``build_decoder``. ``'drn'`` forces
            ``output_stride`` to 8.
        in_ch: number of input image channels.
        num_classes: number of classes forwarded to ``build_decoder``.
        output_stride: output stride forwarded to the backbone and ASPP.
        sync_bn: use ``SynchronizedBatchNorm2d`` when truthy, otherwise
            ``nn.BatchNorm2d``.
        freeze_bn: when truthy, the learning-rate parameter groups contain
            convolution parameters only (no batch-norm parameters).

    Note:
        The ``freeze_bn`` flag is stored privately as ``_freeze_bn``.
        Previously it was stored as ``self.freeze_bn``, which shadowed the
        ``freeze_bn()`` method and made it impossible to call on an instance.
    """

    def __init__(self, backbone, in_ch, num_classes, output_stride=16, sync_bn=True, freeze_bn=False):
        super().__init__()
        if backbone == 'drn':
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.att = Head(in_ch=in_ch, out_ch=16, n_filters=16, bn=BatchNorm)  # bs*16*H*W
        self.backbone = build_backbone(in_ch, backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm, modified=True)

        # Kept under a private name so it does not hide the freeze_bn() method.
        self._freeze_bn = freeze_bn

    def forward(self, x):
        """
        Args:
            x: input image batch.

        Returns:
            Tuple ``(decoder_output, att_mean)`` where ``att_mean`` is the
            scalar mean of the attention map.
        """
        att = self.att(x)  # bs*16*H*W
        att_l = att.mean()

        # The backbone receives the image and the attention map as one tuple.
        backbone_input = (x, att)
        x_1, x_2, x_4, x_8, x_16 = self.backbone(backbone_input)
        # NOTE(review): the ASPP result is computed but not passed to the
        # decoder, which receives x_16 instead. Kept as-is to preserve
        # behaviour; confirm against the decoder whether the last argument
        # should be the ASPP output.
        x = self.aspp(x_16)
        x = self.decoder(x_1, x_2, x_4, x_8, x_16)
        # The decoder output is returned as-is; no resize to the input
        # resolution is applied here.

        return x, att_l

    def freeze_bn(self):
        """Put every batch-norm layer of the model into eval mode."""
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _lr_params(self, modules):
        """Yield trainable conv (and, unless BN is frozen, batch-norm) parameters of ``modules``."""
        if self._freeze_bn:
            wanted = (nn.Conv2d,)
        else:
            wanted = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            for layer in module.modules():
                if isinstance(layer, wanted):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p

    # NOTE(review): parameters of ``self.att`` are in neither group below, so
    # an optimizer built only from these two generators will not update the
    # attention head. Confirm whether that is intended.
    def get_1x_lr_params(self):
        """Yield the backbone parameters (base learning-rate group)."""
        yield from self._lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield the ASPP and decoder parameters (10x learning-rate group)."""
        yield from self._lr_params([self.aspp, self.decoder])

if __name__ == "__main__":
    # Smoke test: build the model defined in this file and run one forward
    # pass. num_classes=21 is an arbitrary value for this check.
    model = DeepLab_att_begining(backbone='mobilenet', in_ch=3, num_classes=21, output_stride=8)
    model.eval()
    # 512 instead of 513: the attention Head pools four times and
    # concatenates upsampled maps with encoder maps, which requires the
    # height and width to be divisible by 16.
    dummy = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        # forward() returns (decoder_output, attention_mean).
        output, att_mean = model(dummy)
    print(output.size())
    print(att_mean)


