import torch
from torch import nn
from models.deeplab.deeplab_ori.backbone.resnet import ResNet101_multiscale,ResNet50_multiscale,Bottleneck
from models.shared import conv_block,up_conv
from models.layers import ASPP,PPM,Aux_Module
from torch.nn import BatchNorm2d
import torch.nn.functional as F
from .shared import conv_block




class FPN_Decoder(nn.Module):
    """Fuse a higher-level feature map with a lower-level (finer) one.

    The low-level input is projected to ``low_dim`` channels (3x3 conv, norm,
    ReLU). The high-level input is bilinearly resized to the low-level spatial
    size. The two are concatenated along the channel axis, and a 3x3 conv,
    norm and ReLU map the result to ``out_dim`` channels.

    Args:
        num_classes: accepted for interface compatibility; not used here.
        high_level_inplanes: channel count of ``x_high_level``.
        low_level_inplanes: channel count of ``low_level_feat``.
        low_dim: channels after projecting the low-level input.
        out_dim: channels of the fused output.
        rm_BN: accepted for interface compatibility; not used here.
        dropout: accepted for interface compatibility; not used here.
        BatchNorm: factory called as ``BatchNorm(num_channels)`` to build each
            normalisation layer (may be non-affine).
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes,low_dim=48,out_dim=256,rm_BN=False,dropout=0.5, BatchNorm = nn.BatchNorm2d):
        super().__init__()

        # Low-level projection; 3x3 / stride 1 / padding 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(low_level_inplanes, low_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BatchNorm(low_dim)
        self.relu = nn.ReLU()

        # Fusion conv over the concatenation [high-level, projected low-level].
        self.last_conv = nn.Sequential(nn.Conv2d(high_level_inplanes+low_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(out_dim),
                                       nn.ReLU(),
                                       )

        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Fuse the two inputs.

        Args:
            x_high_level: (N, high_level_inplanes, h, w) tensor.
            low_level_feat: (N, low_level_inplanes, H, W) tensor.

        Returns:
            (N, out_dim, H, W) tensor (non-negative, since it ends with ReLU).
        """
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x_high_level = F.interpolate(x_high_level, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x_high_level, low_level_feat), dim=1)
        mid = self.last_conv(x)
        return mid

    def _init_weight(self):
        """Apply Kaiming-normal init to convs; set BatchNorm2d affine params to 1 / 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # Norm layers built with affine=False have weight/bias == None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


class Up_sample_k(nn.Module):
    """Project to ``out_ch`` channels and upsample spatially by a factor of 2**k.

    k == 0  -> a single 1x1 conv + BatchNorm2d + ReLU stage (no resize).
    k > 0   -> k stages, each 1x1 conv + BatchNorm2d + ReLU followed by a 2x
               bilinear upsample (align_corners=True); the first stage reads
               ``in_ch`` channels, later ones read ``out_ch``.
    k < 0   -> no layers at all (the module returns its input unchanged).

    The layers live in ``self.upsamlek`` (name kept for state_dict keys).
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,in_ch,out_ch,k,**kwargs):
        super().__init__()

        def make_stage(channels_in, upsample):
            stage = [
                nn.Conv2d(channels_in, out_ch, kernel_size=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
            if upsample:
                stage.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            return stage

        if k == 0:
            layers = make_stage(in_ch, upsample=False)
        else:
            layers = [
                layer
                for idx in range(k)
                for layer in make_stage(out_ch if idx else in_ch, upsample=True)
            ]

        self.upsamlek = nn.Sequential(*layers)

    def forward(self,x):
        """Run ``x`` through the stacked stages."""
        return self.upsamlek(x)


class Semantic_FPN_no(nn.Module):
    """Single-output semantic FPN segmentation model on a ResNet-50 backbone.

    Top-down path: ASPP on the deepest backbone feature, then three
    FPN_Decoder stages (Pred5 -> Pred4 -> Pred3) that fuse progressively
    shallower backbone features. Each decoder output is projected to 128
    channels, resized to Pred3's resolution, summed, and classified by a
    1x1 conv.

    Args:
        in_ch: number of input channels; the backbone is built with
            ``pretrained=True`` only when ``in_ch == 3``.
        num_classes: number of output channels of the final classifier.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self,in_ch,num_classes,**kwargs):
        super().__init__()
        pre = in_ch == 3
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=pre)

        out_ch = num_classes
        # Non-trainable [0.5, 0.5, 0.0]. Not read by forward(); kept so the
        # state_dict layout stays unchanged.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage backbone channel counts; presumably [3, 64, 256, 512, 2048]
        # per the original note — confirm against ResNet50_multiscale.
        filters = self.resnet.out_s
        # ASPP output assumed to have 256 channels (matches Pred5's high_level_inplanes).
        self.aspp = ASPP(filters[-1], 256, BatchNorm=nn.BatchNorm2d)

        # Decoder output channels: Pred5 -> 256, Pred4 -> 128, Pred3 -> 128.
        self.Pred5 = FPN_Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[4], low_dim=256, out_dim=256)
        self.Pred4 = FPN_Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[3], low_dim=512, out_dim=128)
        self.Pred3 = FPN_Decoder(out_ch, high_level_inplanes=128, low_level_inplanes=filters[2], low_dim=256, out_dim=128)

        # Each upsampler's in_ch must equal the out_dim of the decoder feeding
        # it. Pred4/Pred3 emit 128 channels, so up4/up3 take 128; with 256
        # here, forward() fails with a conv channel-mismatch error.
        self.up5 = Up_sample_k(256, 128, 2)
        self.up4 = Up_sample_k(128, 128, 1)
        self.up3 = Up_sample_k(128, 128, 0)

        self.final = nn.Sequential(nn.Conv2d(128, num_classes, 1))

    def forward(self, x):
        """Return un-normalised class scores (no activation applied).

        The output has ``num_classes`` channels at the spatial size of the
        Pred3 decoder output.
        """
        _, _, e3, e4, e5 = self.resnet(x)

        mid_e5 = self.Pred5(self.aspp(e5), e5)
        mid_e4 = self.Pred4(mid_e5, e4)
        mid_e3 = self.Pred3(mid_e4, e3)

        size = mid_e3.shape[2:]
        o5 = F.interpolate(self.up5(mid_e5), size=size, mode='bilinear', align_corners=True)
        o4 = F.interpolate(self.up4(mid_e4), size=size, mode='bilinear', align_corners=True)
        o3 = self.up3(mid_e3)  # k=0: no resize, already at `size`

        return self.final(o5 + o4 + o3)




class Semantic_FPN_boost(nn.Module):
    """Semantic FPN on a ResNet-50 backbone with one prediction head per level.

    Same top-down decoder as the single-output variant (ASPP, then decoder5
    -> decoder4 -> decoder3). Each level is projected to 128 channels and
    resized to decoder3's resolution. It then gets its own head (Dropout(0.1)
    followed by a 1x1 conv to ``num_classes``).

    Args:
        in_ch: number of input channels; the backbone is built with
            ``pretrained=True`` only when ``in_ch == 3``.
        num_classes: number of output channels of every prediction head.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self,in_ch,num_classes,**kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=bool(in_ch == 3))

        # Non-trainable constant parameter; not read by forward().
        self.ts = nn.Parameter(torch.tensor([0.5, 0.5, 0.0]), requires_grad=False)

        chans = self.resnet.out_s  # per-stage backbone channel counts
        self.aspp = ASPP(chans[-1], 256, BatchNorm=nn.BatchNorm2d)

        # Decoder output channels: decoder5 -> 256, decoder4 -> 128, decoder3 -> 128.
        self.decoder5 = FPN_Decoder(num_classes, high_level_inplanes=256,
                                    low_level_inplanes=chans[4], low_dim=256, out_dim=256)
        self.decoder4 = FPN_Decoder(num_classes, high_level_inplanes=256,
                                    low_level_inplanes=chans[3], low_dim=512, out_dim=128)
        self.decoder3 = FPN_Decoder(num_classes, high_level_inplanes=128,
                                    low_level_inplanes=chans[2], low_dim=256, out_dim=128)

        self.up5 = Up_sample_k(256, 128, 2)
        self.up4 = Up_sample_k(128, 128, 1)
        self.up3 = Up_sample_k(128, 128, 0)

        def make_head():
            return nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(128, num_classes, kernel_size=1, stride=1),
            )

        self.pred5 = make_head()
        self.pred4 = make_head()
        self.pred3 = make_head()

    def forward(self, x):
        """Return a tuple ``(o5, o4, o3)`` of per-level class-score maps.

        All three share decoder3's spatial size and have ``num_classes``
        channels.
        """
        _, _, e3, e4, e5 = self.resnet(x)

        mid_e5 = self.decoder5(self.aspp(e5), e5)
        mid_e4 = self.decoder4(mid_e5, e4)
        mid_e3 = self.decoder3(mid_e4, e3)

        target_size = mid_e3.shape[2:]

        def resize(t):
            return F.interpolate(t, size=target_size, mode='bilinear', align_corners=True)

        o5 = self.pred5(resize(self.up5(mid_e5)))
        o4 = self.pred4(resize(self.up4(mid_e4)))
        o3 = self.pred3(self.up3(mid_e3))
        return o5, o4, o3
