import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d


class Decoder(nn.Module):
    """Three-level decoder fusing high-, mid- and low-level feature maps.

    The high-level map is projected and resized to the mid-level map's
    spatial size; the two are concatenated and refined; the result is resized
    to the low-level map's spatial size, concatenated with a gated low-level
    projection, and reduced to ``num_classes`` output channels.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        backbone: Accepted for interface compatibility; not read by this
            class.
        BatchNorm: Callable ``BatchNorm(num_features)`` returning the
            normalisation layer placed after every convolution but the last.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super().__init__()

        # Channel counts that the three inputs of ``forward`` must have.
        l_feat_ch = 64
        m_feat_ch = 256
        # NOTE(review): the original comment tags 256 as "aspp" and lists 2048
        # as the alternative -- presumably the ASPP output width versus the
        # raw backbone width; confirm against the caller.
        h_feat_ch = 256

        # 1x1 projection of the high-level features. Spatial resizing is done
        # in ``forward`` so it can follow the actual size of ``m_feat``.
        self.h_conv = nn.Sequential(
            nn.Conv2d(h_feat_ch, 256, 1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection of the mid-level features.
        self.m_conv = nn.Sequential(
            nn.Conv2d(m_feat_ch, 256, 1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
        )

        # Refines the concatenated high + mid features (256 + 256 channels).
        self.hm_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection of the low-level features down to 48 channels.
        self.l_conv = nn.Sequential(
            nn.Conv2d(l_feat_ch, 48, 1, bias=False),
            BatchNorm(48),
            nn.ReLU(inplace=True),
        )

        # Final head over the fused (256 + 48)-channel map; the last 1x1
        # convolution emits ``num_classes`` channels without bias.
        self.hml_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False),
        )

        # Learnable scalar gate on the low-level branch. Tanh followed by
        # ReLU keeps the gate value in [0, 1); it starts at tanh(1) ~= 0.76.
        self.decatt = nn.parameter.Parameter(torch.ones(1).float())
        self.att_gate = nn.Sequential(nn.Tanh(), nn.ReLU())

        self._init_weight()

    @staticmethod
    def _resize_like(x, ref):
        """Bilinearly resize ``x`` to the spatial size (dims 2+) of ``ref``."""
        return F.interpolate(
            x, size=ref.shape[2:], mode='bilinear', align_corners=True
        )

    def forward(self, l_feat, m_feat, h_feat):
        """Fuse the three feature maps into per-class score maps.

        Args:
            l_feat: Low-level features with 64 channels.
            m_feat: Mid-level features with 256 channels.
            h_feat: High-level features with 256 channels.

        Returns:
            Tensor with ``num_classes`` channels and the spatial size of
            ``l_feat``.
        """
        att = self.att_gate(self.decatt)

        h_feat = self.h_conv(h_feat)
        m_feat = self.m_conv(m_feat)
        l_feat = self.l_conv(l_feat) * att

        # Resize to the reference map instead of a fixed x4 factor, so inputs
        # whose sizes are not exact multiples of 4 still concatenate. With
        # align_corners=True this equals the x4 upsample whenever the sizes
        # do differ by exactly 4.
        h_feat = self._resize_like(h_feat, m_feat)
        hm_feat = self.hm_conv(torch.cat((h_feat, m_feat), dim=1))

        hm_feat = self._resize_like(hm_feat, l_feat)
        out = self.hml_conv(torch.cat((hm_feat, l_feat), dim=1))

        return out

    def _init_weight(self):
        """Kaiming-normal init for conv weights; BN weight=1 and bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_decoder(num_classes, backbone, BatchNorm):
    """Construct a ``Decoder`` from the given arguments.

    All three arguments are forwarded unchanged to ``Decoder``.
    """
    decoder = Decoder(
        num_classes=num_classes,
        backbone=backbone,
        BatchNorm=BatchNorm,
    )
    return decoder
