import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from .aspp import build_aspp,ASPP
from .decoder import build_decoder
from .backbone import build_backbone

class DeepLab_ref(nn.Module):
    """DeepLab-style segmentation network with a dense multi-scale fusion head.

    The backbone returns five feature maps ``x0..x4``; ``x4`` is also passed
    through ASPP. All six maps are resized to the input's spatial size,
    concatenated along dim 1 and projected to ``out_ch`` channels by
    ``Conv_final`` (1x1 conv -> BN -> ReLU -> 1x1 conv).

    Args:
        backbone: Backbone name forwarded to ``build_backbone`` and ``ASPP``.
            ``'drn'`` forces ``output_stride=8``.
        in_ch: Input channel count, forwarded to ``build_backbone``.
        out_ch: Output channel count of the final 1x1 conv.
        output_stride: Backbone output stride (overridden for ``'drn'``).
        sync_bn: Use ``SynchronizedBatchNorm2d`` instead of ``nn.BatchNorm2d``.
        freeze_bn: If True, ``get_1x_lr_params``/``get_10x_lr_params`` yield
            only Conv2d parameters and skip BatchNorm parameters.
    """

    def __init__(self, backbone, in_ch, out_ch, output_stride=16, sync_bn=True,
                 freeze_bn=False):
        super(DeepLab_ref, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        # Private name: ``self.freeze_bn`` is the method below, so the flag
        # must not shadow it (testing the bound method is always truthy).
        self._bn_frozen = bool(freeze_bn)

        self.backbone = build_backbone(in_ch, backbone, output_stride, BatchNorm)
        # NOTE(review): 1024 is assumed to be the channel count of x4 -- confirm
        # against build_backbone for each supported backbone.
        self.aspp = ASPP(1024, backbone, output_stride, BatchNorm)

        # Input width assumes x0..x4 have 64/128/256/512/1024 channels and the
        # ASPP output has 256 -- TODO confirm per backbone.
        self.Conv_final = nn.Sequential(
            nn.Conv2d(64 + 128 + 256 + 512 + 1024 + 256, 256, kernel_size=1),
            BatchNorm(256),
            nn.ReLU(),
            nn.Conv2d(256, out_ch, kernel_size=1),
        )

        # Conv2..Conv4 are not used by forward(); they are kept so existing
        # checkpoints (state_dict keys) still load.
        self.Conv2 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=1),
                                   BatchNorm(256),
                                   nn.ReLU())
        self.Conv3 = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=1),
                                   BatchNorm(256),
                                   nn.ReLU())
        self.Conv4 = nn.Sequential(nn.Conv2d(2048, 256, kernel_size=1),
                                   BatchNorm(256),
                                   nn.ReLU())

    def forward(self, input):
        """Fuse all backbone scales plus ASPP and predict at input resolution.

        Args:
            input: Tensor whose dims from index 2 on are the spatial size of
                the output (presumably ``(N, in_ch, H, W)`` -- confirm).

        Returns:
            Output of ``Conv_final`` at the spatial size of ``input``.
        """
        x0, x1, x2, x3, x4 = self.backbone(input)
        x5 = self.aspp(x4)
        size = input.shape[2:]
        # Every map (including x0) is brought to the input size so torch.cat
        # never fails on a spatial mismatch.
        features = [self._resize(f, size) for f in (x0, x1, x2, x3, x4, x5)]
        return self.Conv_final(torch.cat(features, dim=1))

    @staticmethod
    def _resize(feature, size):
        """Bilinearly resize ``feature`` to ``size``; return it as-is if it already matches."""
        if feature.shape[2:] == size:
            return feature
        return F.interpolate(feature, size=size, mode='bilinear', align_corners=True)

    def freeze_bn(self):
        """Put every BatchNorm layer (synchronized or plain) into eval mode.

        A later ``model.train()`` call switches them back to training mode.
        """
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _iter_lr_params(self, modules):
        """Yield trainable parameters of Conv2d (and, unless frozen, BatchNorm) layers."""
        # getattr default keeps models pickled before the flag existed working.
        if getattr(self, '_bn_frozen', False):
            wanted = (nn.Conv2d,)
        else:
            wanted = (nn.Conv2d, nn.BatchNorm2d, SynchronizedBatchNorm2d)
        for module in modules:
            for layer in module.modules():
                if isinstance(layer, wanted):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        """Yield backbone parameters intended for the base learning rate."""
        yield from self._iter_lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield head parameters (ASPP and the fusion head) for the 10x learning rate."""
        yield from self._iter_lr_params([self.aspp, self.Conv_final])

if __name__ == "__main__":
    # Smoke test: build a ResNet-backed model and run one forward pass on a
    # random 3-channel image. in_ch/out_ch are required by DeepLab_ref;
    # out_ch=21 is just an example value.
    model = DeepLab_ref(backbone='resnet', in_ch=3, out_ch=21, output_stride=16)
    model.eval()
    image = torch.rand(1, 3, 513, 513)
    with torch.no_grad():
        output = model(image)
    print(output.size())


