import torch
from torch import nn
from models.deeplab.deeplab_ori.backbone.resnet import ResNet101_multiscale
from models.shared import conv_block,up_conv
from models.layers import ASPP
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
import torch.nn.functional as F

class DeepLabv3p_boost(nn.Module):
    """DeepLabv3+-style network with side predictions at three depths.

    A ``ResNet101_multiscale`` backbone yields five feature maps. The deepest
    one is concatenated with an ASPP branch and then decoded through two
    U-Net-style stages (``up_conv`` followed by a skip concatenation and
    ``conv_block``) that fuse in the skip features ``e4`` and ``e3``. A
    prediction head is attached to the ASPP-fused features and to each of
    the two decoder stages.

    Args:
        in_ch: Number of input channels, forwarded to the backbone.
        num_classes: Number of output channels of every prediction head.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=True)
        out_ch = num_classes

        # Per-level channel counts reported by the backbone; the original
        # author noted them as [3, 64, 256, 512, 2048] -- TODO confirm.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # Stage 5 -> 4. The input is e5 concatenated with the ASPP output,
        # which is assumed to carry 256 channels (ASPP's second argument).
        self.Up5 = up_conv(filters[4] + 256, filters[3])
        self.Up_conv5 = conv_block(filters[3] + filters[3], filters[3])

        # Stage 4 -> 3.
        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[2] + filters[2], filters[2])

        # NOTE: shallower decoder stages (Up3 / Up2 with heads Pred2 / Pred1)
        # were prototyped and disabled; only the three heads below are built.

        self.Pred5 = nn.Sequential(
            conv_block(filters[4] + 256, filters[0], kernel_size=1),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )
        self.Pred4 = nn.Sequential(
            conv_block(filters[3], filters[0]),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )
        self.Pred3 = nn.Sequential(
            conv_block(filters[2], filters[0]),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Run the network and return the three side predictions.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            Tuple ``(pred5, pred4, pred3)`` of unnormalized per-class scores
            (no activation is applied). They come from the ASPP-fused deepest
            features and then from the two decoder stages, which are
            progressively shallower (presumably higher resolution, assuming
            ``up_conv`` upsamples -- confirm).
        """
        # The two shallowest backbone levels are not used by the active decoder.
        _, _, e3, e4, e5 = self.resnet(x)

        # Fuse the deepest backbone features with multi-scale ASPP context.
        e5 = torch.cat([e5, self.aspp(e5)], dim=1)
        pred5 = self.Pred5(e5)

        d5 = self.Up5(e5)
        d5 = self.Up_conv5(torch.cat((e4, d5), dim=1))
        pred4 = self.Pred4(d5)

        d4 = self.Up4(d5)
        d4 = self.Up_conv4(torch.cat((e3, d4), dim=1))
        pred3 = self.Pred3(d4)

        return pred5, pred4, pred3




class Decoder(nn.Module):
    """DeepLabv3+-style decoder that fuses high-level and low-level features.

    Pipeline:

    1. The low-level features are projected to 48 channels (1x1 conv, norm, ReLU).
    2. The high-level features are bilinearly resized to the low-level
       spatial size and concatenated with the projection.
    3. The result is refined by a 3x3 conv block to 256 channels (``mid``).
    4. ``mid`` goes through dropout, another 3x3 conv block, dropout, and a
       1x1 classifier.

    Args:
        num_classes: Output channels of the final 1x1 classifier.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        BatchNorm: Normalization-layer factory, called with a channel count.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes,
                 BatchNorm=SynchronizedBatchNorm2d):
        super().__init__()

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(
            nn.Conv2d(high_level_inplanes + 48, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
        )
        self.last_conv2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )
        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Fuse the two inputs and classify.

        Args:
            x_high_level: Coarse features with ``high_level_inplanes`` channels.
            low_level_feat: Finer features with ``low_level_inplanes``
                channels; their spatial size sets the output size.

        Returns:
            Tuple ``(mid, logits)``:

            - ``mid``: the 256-channel fused features, usable as the
              high-level input of another ``Decoder``.
            - ``logits``: the ``num_classes``-channel classifier output.
        """
        low = self.relu(self.bn1(self.conv1(low_level_feat)))

        # Bring the coarse features to the low-level resolution before concatenation.
        high = F.interpolate(x_high_level, size=low.size()[2:], mode='bilinear', align_corners=True)
        mid = self.last_conv(torch.cat((high, low), dim=1))
        logits = self.last_conv2(mid)
        return mid, logits

    def _init_weight(self):
        """Kaiming-initialize convolutions and reset affine norm parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                # Norm layers built with affine=False have no weight/bias;
                # the previous unconditional `.data` access crashed on them.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


class DeepLabv3p_boost_similar(nn.Module):
    """Variant of ``DeepLabv3p_boost`` built from three cascaded ``Decoder`` heads.

    The ASPP output of the deepest backbone level seeds the first
    ``Decoder``. Each later ``Decoder`` takes the 256-channel ``mid``
    features of the previous one as its high-level input and fuses them with
    a progressively shallower backbone level (``e5``, then ``e4``, then
    ``e3``).

    Args:
        in_ch: Number of input channels, forwarded to the backbone.
        num_classes: Number of output channels of every prediction head.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=True)
        out_ch = num_classes

        # Per-level channel counts reported by the backbone; the original
        # author noted them as [3, 64, 256, 512, 2048] -- TODO confirm.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # The first head's high-level input is the ASPP output, assumed to have
        # 256 channels (ASPP's second argument). Later heads receive a previous
        # Decoder's `mid`, which is produced by a 256-channel conv.
        self.Pred5 = Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[2])

    def forward(self, x):
        """Run the cascade and return the three side predictions.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            Tuple ``(pred5, pred4, pred3)`` of classifier outputs. Each one is
            produced at the spatial size of the backbone level it was fused with.
        """
        # The two shallowest backbone levels are not used.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        # The last stage's fused features have no consumer.
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3