import os
import re
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deeplabv3.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from model.deeplabv3.aspp import build_aspp
from model.deeplabv3.decoder import build_decoder
from model.deeplabv3.backbone import build_backbone

class DeepLab(nn.Module):
    """DeepLab v3+ style segmentation network: backbone -> ASPP -> decoder.

    The decoder output is bilinearly resized back to the input's spatial size.

    Args:
        backbone: Backbone name passed to ``build_backbone``, ``build_aspp`` and
            ``build_decoder``. ``'drn'`` forces ``output_stride`` to 8.
        output_stride: Output stride passed to the backbone and ASPP builders.
        num_classes: Passed to ``build_decoder``; also stored as ``self.n_classes``.
        sync_bn: If true, build with ``SynchronizedBatchNorm2d`` instead of
            ``nn.BatchNorm2d``.
        freeze_bn: If true, ``get_1x_lr_params``/``get_10x_lr_params`` yield only
            ``nn.Conv2d`` parameters, leaving BatchNorm parameters out of the
            optimizer groups. This flag does not by itself switch BatchNorm
            layers to eval mode; call :meth:`freeze_bn` for that.
        freeze_backbone: If true, set ``requires_grad = False`` on every
            backbone parameter.
        pretrained_backbone: Forwarded to ``build_backbone`` as ``pretrained``.
    """

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=False, freeze_bn=False, freeze_backbone=False, pretrained_backbone=True):
        super().__init__()
        if backbone == 'drn':
            # The 'drn' backbone is always built with output stride 8.
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm, pretrained=pretrained_backbone)
        if freeze_backbone:
            self._freeze_backbone()

        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        for part_name, part in (("backbone", self.backbone), ("aspp", self.aspp), ("decoder", self.decoder)):
            print("Total number of paramerters in {} is {}  ".format(
                part_name, sum(p.numel() for p in part.parameters())))

        self.n_channels = 3
        self.n_classes = num_classes

        # Stored under a private name on purpose: assigning ``self.freeze_bn``
        # would place a bool in the instance dict, shadowing the ``freeze_bn()``
        # method and making ``model.freeze_bn()`` raise TypeError.
        self._freeze_bn = freeze_bn

    def forward(self, input):
        """Run backbone -> ASPP -> decoder and resize the result to ``input``'s size.

        Assumes ``input`` is (N, C, H, W); the target size is ``input.size()[2:]``.
        """
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def freeze_bn(self):
        """Put every BatchNorm layer (synchronized or regular) in eval mode.

        ``nn.Module.train()`` switches submodules back to training mode, so call
        this after ``model.train()`` if BatchNorm should stay in eval mode.
        """
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _iter_lr_params(self, modules):
        """Yield trainable Conv2d (and, unless BN is frozen, BatchNorm) params of ``modules``."""
        if self._freeze_bn:
            layer_types = (nn.Conv2d,)
        else:
            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            for layer in module.modules():
                if isinstance(layer, layer_types):
                    for p in layer.parameters():
                        # Frozen parameters (requires_grad=False) are skipped.
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        """Yield backbone parameters intended for the base learning rate."""
        yield from self._iter_lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield ASPP and decoder parameters intended for the 10x learning rate."""
        yield from self._iter_lr_params([self.aspp, self.decoder])

    def _freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        for p in self.backbone.parameters():
            p.requires_grad = False

    def init_backbone(self, backbone_path, freeze=True):
        """Optionally load backbone weights from a ``.pth`` file, then optionally freeze.

        Keys starting with ``backbone.`` or ``module.`` have that prefix stripped
        and are loaded (strictly) into ``self.backbone``; all other keys are
        ignored.

        Args:
            backbone_path: Path to a ``.pth`` state dict, or ``None`` to skip loading.
            freeze: If true, set ``requires_grad = False`` on all backbone params
                (applied even when ``backbone_path`` is ``None``).

        Raises:
            FileNotFoundError: ``backbone_path`` is not an existing file.
            ValueError: ``backbone_path`` does not end with ``.pth``.
        """
        if backbone_path is not None:
            if not os.path.isfile(backbone_path):
                raise FileNotFoundError("backbone checkpoint not found: {}".format(backbone_path))
            if not backbone_path.endswith(".pth"):
                raise ValueError("backbone checkpoint must be a .pth file: {}".format(backbone_path))
            # map_location='cpu' allows GPU-saved checkpoints to load on CPU-only
            # hosts; load_state_dict copies values onto the backbone's own device.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(backbone_path, map_location='cpu')
            backbone_dict = collections.OrderedDict()
            for key, value in state_dict.items():
                for prefix in ("backbone.", "module."):
                    if key.startswith(prefix):
                        backbone_dict[key[len(prefix):]] = value
                        break
            self.backbone.load_state_dict(backbone_dict)
        if freeze:
            self._freeze_backbone()

if __name__ == "__main__":
    # Smoke test: one random 513x513 RGB image through the mobilenet variant.
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    dummy_input = torch.rand(1, 3, 513, 513)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print(output.size())


