from simplecv.interface import CVModule
from simplecv import registry
import torch
import torch.nn as nn
import torch.nn.functional as F

# from .config import config as cfg
from .deeplab_encoder import DeeplabEncoder
from .deeplab_decoder import DeeplabDecoder


@registry.MODEL.register('deeplabv3plus')
class Deeplabv3plus(CVModule):
    """DeepLabv3+ segmentation model: encoder -> decoder -> 1x1 classifier -> upsample.

    The full default configuration is produced by ``set_defalut_config``; its
    ``model.params`` section is kept in ``self.cfg`` and handed to the
    ``CVModule`` base class.
    """

    def __init__(self, in_ch=3, num_classes=5):
        """Build the encoder, decoder and classification head.

        Args:
            in_ch: Currently unused; kept for interface compatibility. Nothing
                in this class configures the encoder's input channels from it.
            num_classes: Number of output channels of the 1x1 classification
                head. It takes precedence over ``other.num_classes`` in the
                default config and is written back there, so the stored
                config describes the head that is actually built.
        """
        # set_defalut_config() fills self.cfg in place, so it must exist first.
        self.cfg = {}
        self.set_defalut_config()
        # Only the model parameters are used from here on.
        self.cfg = self.cfg['model']['params']
        # The default config says 16 classes while the argument defaults to 5;
        # keep the config consistent with the head constructed below.
        self.cfg['other']['num_classes'] = num_classes
        super().__init__(self.cfg)

        self.encoder = DeeplabEncoder(self.cfg['encoder']['params'])
        self.decoder = DeeplabDecoder(self.cfg['decoder']['params'])
        self.cls_pred_conv = nn.Conv2d(self.decoder.decoder_dim, num_classes, 1)

    def forward(self, x, y=None, **kwargs):
        """Return per-pixel class logits for ``x``.

        Args:
            x: Input batch fed to the encoder; presumably (N, C, H, W) --
                TODO confirm against DeeplabEncoder.
            y: Ignored; accepted for interface compatibility (no loss is
                computed in this method).
            **kwargs: Ignored.

        Returns:
            Raw logits (no softmax/sigmoid applied) from the 1x1 classifier,
            upsampled bilinearly (``align_corners=True``) by
            ``cfg['other']['scale_factor']``.
        """
        # The encoder returns a sequence of feature maps that the decoder
        # takes as positional arguments.
        feat_list = self.encoder(x)
        out = self.decoder(*feat_list)
        out = self.cls_pred_conv(out)
        # Read from self.cfg -- the dict this class built and __init__ reads
        # from -- instead of relying on how the base class exposes `config`.
        return F.interpolate(
            out,
            scale_factor=self.cfg['other']['scale_factor'],
            mode='bilinear',
            align_corners=True,
        )

    def set_defalut_config(self):
        """Merge the default configuration into ``self.cfg`` (in place).

        Top-level sections: ``model`` (encoder / decoder / loss / other
        parameters), ``data``, ``optimizer``, ``learning_rate``, ``train`` and
        ``test``. Keys are consumed by other components and are kept verbatim,
        including the ``batchnoram_trainable`` spelling.

        NOTE(review): the method name matches a CVModule hook; if the base
        class also calls it during ``__init__``, it runs against the
        already-narrowed ``model.params`` dict and adds these top-level
        sections to it -- confirm against CVModule.
        """
        self.cfg.update(
            model=dict(
                type='deeplabv3plus',
                params=dict(
                    encoder=dict(
                        type='deeplab_encoder',
                        params=dict(
                            resnet_encoder=dict(
                                resnet_type='resnet50',
                                include_conv5=True,
                                batchnoram_trainable=True,
                                pretrained=True,
                                freeze_at=0,
                                output_stride=16,
                            ),
                            aspp=dict(
                                in_channel=2048,
                                aspp_dim=256,
                                atrous_rates=(6, 12, 18),
                                add_image_level=True,
                                use_bias=True,
                                use_batchnorm=False,
                                norm_type='batchnorm',
                            ),
                        ),
                    ),
                    decoder=dict(
                        type='deeplab_decoder',
                        params=dict(
                            low_level_feature_channel=256,
                            encoder_feature_channel=256,
                            reduction_dim=48,
                            decoder_dim=256,
                            num_3x3conv=2,
                            scale_factor=4.0,
                            use_bias=True,
                            use_batchnorm=False,
                            norm_fn='batchnorm',
                        ),
                    ),
                    loss=dict(
                        cls_loss=dict(
                            type='cross_entropy',
                            params=dict(
                                ignore_index=-100,
                                reduction='mean',
                            ),
                        ),
                    ),
                    other=dict(
                        use_softmax=True,
                        # Overwritten by the num_classes argument in __init__.
                        num_classes=16,
                        # Factor applied to the classifier output in forward().
                        scale_factor=4,
                    ),
                ),
            ),
            data=dict(
                train=dict(
                    type='segdataloader',
                    params=dict(
                        image_dir='',
                        mask_dir='',
                        filenameList_path=None,
                        training=True,
                        image_format='jpg',
                        mask_format='png',
                        batch_size=2,
                        num_workers=0,
                        pin_memory=True,
                        drop_last=False,
                    ),
                ),
                test=dict(
                    type='segdataloader',
                    params=dict(
                        image_dir='',
                        mask_dir='',
                        filenameList_path=None,
                        training=False,
                        image_format='jpg',
                        mask_format='png',
                        batch_size=1,
                        num_workers=0,
                        pin_memory=True,
                        drop_last=False,
                    ),
                ),
            ),
            optimizer=dict(
                type='sgd',
                params=dict(
                    lr=0.01,
                    momentum=0.9,
                    weight_decay=0.0001,
                ),
            ),
            learning_rate=dict(
                type='multistep',
                params=dict(
                    base_lr=0.01,
                    steps=(60000, 80000),
                    gamma=0.1,
                    warmup_step=500,
                    warmup_init_lr=0.01 / 3,
                ),
            ),
            train=dict(
                forward_times=1,
                num_iters=90000,
            ),
            test=dict(),
        )
