#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from collections import OrderedDict
import torch
from modules import PredsegModule, PredsegSequential
from torchvision.models.segmentation import deeplabv3_resnet101

def get_resnet101_deeplab(neighbors=None):
    """Build a predseg network from the DeepLabV3-ResNet101 backbone.

    Only the ``backbone`` of torchvision's ``deeplabv3_resnet101`` is used;
    its classifier head is discarded. Every backbone submodule that exposes
    ``padding_mode`` is switched to reflect padding and all weights/biases
    are re-initialised. The stages ``layer1`` .. ``layer4`` are then each
    wrapped in a ``PredsegModule`` (with ``normalize_output=False``) and the
    backbone children are chained, in their original order, in a
    ``PredsegSequential``.

    Args:
        neighbors: neighbor specification forwarded unchanged to every
            ``PredsegModule``. Defaults to ``[[0, 1], [1, 0]]``; a fresh
            list is created on each call so that calls never share (and
            cannot accidentally mutate) a common default object.

    Returns:
        The assembled ``PredsegSequential`` model.
    """
    if neighbors is None:
        neighbors = [[0, 1], [1, 0]]

    # NOTE(review): `pretrained=False` only concerns the full segmentation
    # model; depending on the torchvision version the backbone may still be
    # loaded with its default weights. Learnable weights are overwritten
    # below, buffers are not touched -- confirm this is intended.
    backbone = deeplabv3_resnet101(pretrained=False).backbone

    for m in backbone.modules():
        if hasattr(m, 'padding_mode'):
            m.padding_mode = 'reflect'
        if hasattr(m, 'weight'):
            if m.weight.dim() > 2:
                # multi-dimensional kernels (e.g. convolutions)
                torch.nn.init.xavier_normal_(m.weight)
            else:
                # 1d/2d weights (e.g. normalisation scales) start at one
                torch.nn.init.constant_(m.weight, 1)
        if hasattr(m, 'bias') and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    layer_dict = torch.nn.ModuleDict(backbone.named_children())
    # (stage name, value passed as third PredsegModule argument); the values
    # presumably are the output channel counts of the ResNet101 stages.
    stages = (('layer1', 256), ('layer2', 512),
              ('layer3', 1024), ('layer4', 2048))
    for name, channels in stages:
        # Re-assigning an existing key keeps its position in the ModuleDict.
        layer_dict[name] = PredsegModule(
            layer_dict[name], neighbors, channels, normalize_output=False)
    return PredsegSequential(OrderedDict(layer_dict))


def get_predseg1(neighbors=None, normalize_output=True):
    """Build the small four-stage predseg1 network.

    The network alternates ``PredsegModule``-wrapped feature stages with
    strided 1x1 convolutions that downsample between them:

    1. 3x3 reflect-padded convolution, 3 -> 3 channels
    2. 1x1 convolution with stride 3, 3 -> 3 channels
    3. 11x11 reflect-padded convolution, 3 -> 64 channels
    4. 1x1 convolution with stride 2, 64 -> 128 channels
    5. ``Block`` with 128 channels (kernel 5, padding 2)
    6. 1x1 convolution with stride 2, 128 -> 256 channels
    7. ``Block`` with 256 channels (kernel 5, padding 2)

    Args:
        neighbors: neighbor specification forwarded unchanged to every
            ``PredsegModule``. Defaults to ``[[0, 1], [1, 0]]``; a fresh
            list is created on each call so that calls never share a common
            default object.
        normalize_output: forwarded as ``normalize_output`` to every
            ``PredsegModule``.

    Returns:
        The assembled ``PredsegSequential`` model.
    """
    if neighbors is None:
        neighbors = [[0, 1], [1, 0]]

    def wrap(inner, channels):
        # All stages share the same neighbors / normalize_output settings.
        return PredsegModule(
            inner, neighbors, channels, normalize_output=normalize_output)

    # Modules are created strictly in network order so that parameter
    # initialisation consumes the random number generator in the same
    # sequence as before.
    l1 = wrap(
        torch.nn.Conv2d(3, 3, 3, padding=1, padding_mode='reflect'), 3)
    downsample1 = torch.nn.Conv2d(3, 3, 1, stride=3)
    l2 = wrap(
        torch.nn.Conv2d(3, 64, 11, padding=5, padding_mode='reflect'), 64)
    downsample2 = torch.nn.Conv2d(64, 128, 1, stride=2)
    l3 = wrap(Block(128, 5, 2, 128), 128)
    downsample3 = torch.nn.Conv2d(128, 256, 1, stride=2)
    l4 = wrap(Block(256, 5, 2, 256), 256)
    return PredsegSequential(
        l1, downsample1, l2, downsample2, l3, downsample3, l4)


class Block(torch.nn.Module):
    """Residual building block (ResNet-like) for the predseg1 network.

    Two reflect-padded convolutions with a ReLU in between are applied to
    the input and the result is added back onto the input. There are no
    normalization layers. Downsampling is done outside of this block, so
    input and output are assumed to have the same number of channels and
    the same resolution; ``padding`` therefore has to match ``kernel_size``.

    Args:
        channels: number of input and output channels.
        kernel_size: kernel size used by both convolutions.
        padding: padding used by both convolutions.
        bottleneck: number of channels between the two convolutions;
            ``None`` means ``channels``.
    """

    def __init__(self, channels, kernel_size=5, padding=2,
                 bottleneck=None):
        super().__init__()
        hidden = channels if bottleneck is None else bottleneck
        conv_kwargs = dict(padding=padding, padding_mode='reflect')
        self.conv1 = torch.nn.Conv2d(
            channels, hidden, kernel_size, **conv_kwargs)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(
            hidden, channels, kernel_size, **conv_kwargs)

    def forward(self, x):
        """Return ``x`` plus the conv -> ReLU -> conv branch applied to it."""
        features = self.conv1(x)
        features = self.relu(features)
        residual = self.conv2(features)
        return x + residual
