import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d


class Decoder(nn.Module):
    """DeepLabV3+-style decoder head.

    The low-level backbone features are reduced to 48 channels (1x1 conv,
    norm, ReLU). The high-level input ``x`` is bilinearly resized to their
    spatial size and concatenated with them. Two 3x3 conv stages then refine
    the result, and a 1x1 classifier produces ``num_classes`` output channels.

    Args:
        num_classes: number of output channels of the final 1x1 conv.
        backbone: backbone name; selects the channel count expected for
            ``low_level_feat``. Unsupported names raise NotImplementedError.
        BatchNorm: normalization layer class, called as ``BatchNorm(channels)``.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()

        # (accepted backbone names, channels of that backbone's low-level map)
        low_level_channels = (
            (('resnet', 'drn'), 256),
            (('xception',), 128),
            (('mobilenet',), 24),
            (('resnet-modified',), 128 * 4),
        )
        for names, channels in low_level_channels:
            if backbone in names:
                low_level_inplanes = channels
                break
        else:
            raise NotImplementedError

        # Low-level projection: 1x1 conv down to 48 channels.
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()

        # 304 input channels = channels of x (presumably 256 from the encoder;
        # TODO confirm) + the 48 projected low-level channels.
        refine_layers = [
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*refine_layers)

        self._init_weight()

    def forward(self, x, low_level_feat):
        """Fuse high-level ``x`` with ``low_level_feat`` and return logits.

        The output has the spatial size of ``low_level_feat``.
        """
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        fused = torch.cat([upsampled, skip], dim=1)
        return self.last_conv(fused)

    def _init_weight(self):
        """Initialize conv and batch-norm parameters.

        Conv weights get Kaiming-normal init; conv biases keep PyTorch's
        default. Batch-norm layers get weight 1 and bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()


class Decoder_modified(nn.Module):
    """U-Net-style decoder over a five-level feature pyramid.

    Starting from the deepest map ``x_16``, each stage does three things:
    upsample the running decoder state by 2 (``up_conv``), concatenate it
    with the matching skip feature, and refine the result with a
    ``conv_block``. A final 3x3 conv maps the 64-channel output to
    ``num_classes`` channels.

    Channel counts implied by the layer definitions below:
    x_16: 2048, x_8: 1024, x_4: 512, x_2: 256, x_1: 64.

    Args:
        num_classes: number of output channels of the final conv.
        backbone: accepted for signature parity with ``Decoder``; unused.
        BatchNorm: accepted for signature parity with ``Decoder``; unused.
            NOTE(review): ``conv_block``/``up_conv`` are built with their
            default ``nn.BatchNorm2d``, so a synchronized BatchNorm passed
            here is currently ignored.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super().__init__()
        # Feature widths from the shallowest (x_1) to the deepest (x_16) level.
        f_d = [64, 256, 512, 1024, 2048]

        self.up16 = up_conv(in_ch=f_d[-1], out_ch=f_d[-2])
        self.conv8 = conv_block(in_ch=f_d[-1], out_ch=f_d[-2], kernel_size=3)

        self.up8 = up_conv(in_ch=f_d[-2], out_ch=f_d[-3])
        self.conv4 = conv_block(in_ch=f_d[-2], out_ch=f_d[-3], kernel_size=3)

        self.up4 = up_conv(in_ch=f_d[-3], out_ch=f_d[-4])
        self.conv2 = conv_block(in_ch=f_d[-3], out_ch=f_d[-4], kernel_size=3)

        self.up2 = up_conv(in_ch=f_d[-4], out_ch=f_d[-5])
        # x_1 (64 ch) + upsampled d_2 (64 ch) -> 128 input channels.
        self.conv1 = conv_block(in_ch=f_d[-5] * 2, out_ch=f_d[-5], kernel_size=3)

        self.conv = nn.Conv2d(f_d[0], num_classes, 3, stride=1, padding=1)

        self._init_weight()

    @staticmethod
    def _match_size(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size if they differ.

        The ``up_conv`` blocks upsample by a fixed factor of 2. When a skip
        feature has an odd height or width (e.g. from a stride-2 encoder on
        a 513-pixel input), the upsampled map is one pixel larger, and
        concatenating them would fail.
        """
        if x.shape[2:] != ref.shape[2:]:
            x = F.interpolate(x, size=ref.shape[2:], mode='bilinear', align_corners=True)
        return x

    def forward(self, x_1, x_2, x_4, x_8, x_16):
        """Decode the pyramid into logits at ``x_1``'s spatial size."""
        d_8 = self.conv8(torch.cat((x_8, self._match_size(self.up16(x_16), x_8)), dim=1))
        d_4 = self.conv4(torch.cat((x_4, self._match_size(self.up8(d_8), x_4)), dim=1))
        d_2 = self.conv2(torch.cat((x_2, self._match_size(self.up4(d_4), x_2)), dim=1))
        d_1 = self.conv1(torch.cat((x_1, self._match_size(self.up2(d_2), x_1)), dim=1))
        return self.conv(d_1)

    def _init_weight(self):
        """Initialize conv and batch-norm parameters.

        Conv weights get Kaiming-normal init; conv biases keep PyTorch's
        default. Batch-norm layers get weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_decoder(num_classes, backbone, BatchNorm, modified):
    """Construct a decoder head.

    Returns ``Decoder_modified`` when ``modified`` is truthy, otherwise
    ``Decoder``. The remaining arguments are forwarded positionally to the
    chosen class.
    """
    decoder_cls = Decoder_modified if modified else Decoder
    return decoder_cls(num_classes, backbone, BatchNorm)


class conv_block(nn.Module):
    """
    Convolution block: three stride-1 convolutions with normalization, ReLU
    and dropout between them.

    Layout: conv -> norm -> ReLU -> Dropout(0.5) -> conv -> norm -> ReLU ->
    Dropout(0.1) -> conv. The last conv has no normalization or activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of every conv.
        kernel_size: square kernel size. Padding is ``kernel_size // 2``, so
            odd sizes preserve the spatial size; even sizes grow it by one
            pixel per conv.
        norm_layer: keyword-only callable ``norm_layer(num_features)`` that
            returns a normalization module. Defaults to ``nn.BatchNorm2d``;
            pass e.g. a synchronized BatchNorm class for multi-GPU training.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, *, norm_layer=nn.BatchNorm2d):
        super().__init__()

        padding = kernel_size // 2

        # The Sequential layout (and hence state_dict keys) is independent
        # of norm_layer, so checkpoints stay compatible.
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                      stride=1, padding=padding, bias=False),
            norm_layer(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size,
                      stride=1, padding=padding, bias=False),
            norm_layer(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size,
                      stride=1, padding=padding, bias=False)
            )

    def forward(self, x):
        """Apply the block; output has ``out_ch`` channels."""
        return self.conv(x)


class up_conv(nn.Module):
    """
    Up-convolution block: 2x bilinear upsampling followed by a 3x3 conv,
    normalization and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        norm_layer: keyword-only callable ``norm_layer(num_features)`` that
            returns a normalization module. Defaults to ``nn.BatchNorm2d``;
            pass e.g. a synchronized BatchNorm class for multi-GPU training.
    """

    def __init__(self, in_ch, out_ch, *, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Layout (and state_dict keys) is independent of norm_layer.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3,
                      stride=1, padding=1, bias=False),
            norm_layer(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return ``x`` upsampled 2x in height and width with ``out_ch`` channels."""
        return self.up(x)
