"""
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py
"""
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from climategan.deeplab.mobilenet_v3 import SeparableConv2d
from climategan.utils import find_target_size


class _DeepLabHead(nn.Module):
    """
    Segmentation head: two 3x3 separable convolutions, then a 1x1 classifier.

    Args:
        nclass (int): number of output channels of the final 1x1 convolution.
        c1_channels (int): accepted but not used by this implementation.
        c4_channels (int): number of channels of the input feature map.
        norm_layer: normalisation layer class forwarded to ``SeparableConv2d``.
    """

    def __init__(
        self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
    ):
        super().__init__()
        # Options shared by both separable convolutions.
        separable_opts = dict(norm_layer=norm_layer, relu_first=False)
        layers = [
            SeparableConv2d(c4_channels, 256, 3, **separable_opts),
            SeparableConv2d(256, 256, 3, **separable_opts),
            nn.Conv2d(256, nclass, 1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x, c1=None):
        """Run the head on ``x``; ``c1`` is accepted but ignored."""
        logits = self.block(x)
        return logits


class ConvBNReLU(nn.Module):
    """
    2D convolution (with bias) followed by batch normalisation.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Notes:
        * Despite the class name, ``forward`` applies no activation: it only
          chains the convolution and the batch norm.
        * ``padding`` defaults to 1 whatever the kernel size, so a 1x1 kernel
          built without an explicit ``padding=0`` enlarges each spatial
          dimension by 2.

    Args:
        in_chan (int): number of input channels.
        out_chan (int): number of output channels.
        ks (int): kernel size.
        stride (int): convolution stride.
        padding (int): zero-padding added on each side.
        dilation (int): convolution dilation.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_options)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

    def init_weight(self):
        """Kaiming-normal init (a=1) for direct Conv2d children, zero bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, 0)


class ASPPv3Plus(nn.Module):
    """
    Parallel dilated-convolution branches whose outputs are concatenated and
    fused by a final 1x1 ``ConvBNReLU``.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Notes:
        * ``with_gp`` is hard-coded to ``False``, so the global-pooling branch
          is never built nor used.
        * ``conv_out`` is built with ``ks=1`` and ``ConvBNReLU``'s default
          ``padding=1``: its output is 2 pixels larger than its input along
          each spatial dimension.

    Args:
        backbone (str): ``"mobilenet"`` selects 320 input channels, any other
            value selects 2048.
        no_init (bool): if falsy, ``init_weight()`` is called at the end of
            the constructor.
    """

    def __init__(self, backbone, no_init):
        super().__init__()

        in_chan = 320 if backbone == "mobilenet" else 2048

        self.with_gp = False
        self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
        # conv2, conv3, conv4: 3x3 branches with dilation == padding == rate.
        for index, rate in enumerate((6, 12, 18), start=2):
            branch = ConvBNReLU(in_chan, 256, ks=3, dilation=rate, padding=rate)
            setattr(self, "conv{}".format(index), branch)

        n_branches = 4
        if self.with_gp:
            self.avg = nn.AdaptiveAvgPool2d((1, 1))
            self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
            n_branches = 5
        self.conv_out = ConvBNReLU(256 * n_branches, 256, ks=1)

        if not no_init:
            self.init_weight()

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1 and fuse."""
        height, width = x.size()[2:]
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        if self.with_gp:
            pooled = self.conv1x1(self.avg(x))
            pooled = F.interpolate(
                pooled, (height, width), mode="bilinear", align_corners=True
            )
            branches.append(pooled)
        stacked = torch.cat(branches, 1)
        return self.conv_out(stacked)

    def init_weight(self):
        """
        Kaiming-normal init (a=1) for *direct* Conv2d children, zero bias.

        NOTE(review): the direct children built in ``__init__`` are
        ``ConvBNReLU`` (and optionally ``AdaptiveAvgPool2d``) modules, none of
        which is an ``nn.Conv2d`` instance, so this loop matches nothing here.
        """
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, 0)


class Decoder(nn.Module):
    """
    Fuses a low-level feature map with ASPP features and predicts logits.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Args:
        n_classes (int): number of output channels of the final 1x1 conv.
    """

    def __init__(self, n_classes):
        super().__init__()
        # 1x1 projection of the low-level features: 256 -> 48 channels.
        self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
        # 304 = 48 projected channels + 256 channels expected from feat_aspp.
        fusion_layers = [
            ConvBNReLU(304, 256, ks=3, padding=1),
            ConvBNReLU(256, 256, ks=3, padding=1),
        ]
        self.conv_cat = nn.Sequential(*fusion_layers)
        self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)

    def forward(self, feat_low, feat_aspp):
        """
        Args:
            feat_low: feature map whose spatial size drives the output size.
            feat_aspp: feature map bilinearly resized to ``feat_low``'s size.

        Returns:
            Logits produced by ``conv_out`` on the fused features.
        """
        height, width = feat_low.size()[2:]
        projected = self.conv_low(feat_low)
        resized = F.interpolate(
            feat_aspp, (height, width), mode="bilinear", align_corners=True
        )
        fused = self.conv_cat(torch.cat([projected, resized], dim=1))
        return self.conv_out(fused)


"""
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
"""


class DeepLabV3Decoder(nn.Module):
    """
    Segmentation decoder operating on a ``(z_high, z_low)`` pair of features.

    With the ``"resnet"`` backbone the features go through ``ASPPv3Plus`` and
    ``Decoder``; with any other backbone a ``_DeepLabHead`` is applied to
    ``z_high`` only. The logits are finally bilinearly resized to the target
    size.

    Args:
        opts: configuration object; reads ``opts.gen.s.output_dim``,
            ``opts.gen.s.use_dada``, ``opts.gen.deeplabv3.backbone``,
            ``opts.gen.deeplabv3.pretrained_model`` and ``opts.tasks``.
        no_init (bool): if True, skip both the weight initialisation loop and
            the loading of pre-trained weights.
        freeze_bn (bool): recorded as ``self._freeze_bn``. It is only stored:
            call ``freeze_bn()`` to actually put the BatchNorm layers in eval
            mode.
    """

    def __init__(self, opts, no_init=False, freeze_bn=False):
        super().__init__()

        num_classes = opts.gen.s.output_dim
        self.backbone = opts.gen.deeplabv3.backbone
        self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
        # Stored under a private name: assigning the flag to
        # ``self.freeze_bn`` would shadow the ``freeze_bn()`` method below and
        # make it uncallable ('bool' object is not callable).
        self._freeze_bn = freeze_bn

        if self.backbone == "resnet":
            self.aspp = ASPPv3Plus(self.backbone, no_init)
            self.decoder = Decoder(num_classes)
        else:
            self.head = _DeepLabHead(num_classes, c4_channels=320)

        self._target_size = find_target_size(opts, "s")
        print(
            "      - {}:  setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        if not no_init:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.zeros_(m.bias)
            # Pre-trained weights (when available) override the init above.
            self.load_pretrained(opts)

    @staticmethod
    def _is_class_specific(tensor):
        """
        Tell whether ``tensor`` has a first dimension of size 19.

        Such tensors are skipped when loading pre-trained weights
        (presumably 19 is the class count of the pre-training dataset, which
        need not match this model's output — TODO confirm).
        """
        return len(tensor.shape) > 0 and tensor.shape[0] == 19

    @staticmethod
    def _sub_state_dict(state_dict, prefix, strip_prefix, skip_class_specific):
        """
        Select the entries of ``state_dict`` whose key starts with ``prefix``.

        Only the leading ``prefix`` is removed when ``strip_prefix`` is True
        (later occurrences of the same text in the key are left intact).
        """
        selected = {}
        for key, value in state_dict.items():
            if not key.startswith(prefix):
                continue
            if skip_class_specific and DeepLabV3Decoder._is_class_specific(value):
                continue
            selected[key[len(prefix):] if strip_prefix else key] = value
        return selected

    def load_pretrained(self, opts):
        """
        Load pre-trained weights for the configured backbone, if available.

        The checkpoint path is read from
        ``opts.gen.deeplabv3.pretrained_model.<backbone>``. If that path is
        empty or does not exist, the method returns without loading anything.

        Raises:
            AssertionError: if the backbone is neither "resnet" nor
                "mobilenet".
        """
        backbone = opts.gen.deeplabv3.backbone
        assert backbone in {"resnet", "mobilenet"}

        pretrained = opts.gen.deeplabv3.pretrained_model
        # Look at the checkpoint of the backbone actually in use.
        ckpt_path = pretrained.resnet if backbone == "resnet" else pretrained.mobilenet
        if not ckpt_path or not Path(ckpt_path).exists():
            return

        # Load on CPU so that GPU-saved checkpoints work on CPU-only hosts;
        # load_state_dict copies the values into the existing parameters.
        std = torch.load(ckpt_path, map_location="cpu")

        if backbone == "resnet":
            self.aspp.load_state_dict(
                self._sub_state_dict(std, "aspp.", True, False)
            )
            self.decoder.load_state_dict(
                self._sub_state_dict(std, "decoder.", True, True),
                strict=False,
            )
            print(
                "- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
            )
        else:
            self.load_state_dict(
                self._sub_state_dict(std, "head.", False, True),
                strict=False,
            )
            print(
                "    - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
            )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, z_depth=None):
        """
        Decode features into logits resized to the target size.

        Args:
            z (tuple | list): ``(z_high, z_low)`` feature pair.
            z_depth (optional): if given and ``self.use_dada`` is truthy,
                ``z_high`` is multiplied by it before decoding.

        Returns:
            Logits bilinearly interpolated to ``self._target_size``.

        Raises:
            AssertionError: if ``z`` is neither a tuple nor a list.
            ValueError: if the target size has not been set.
        """
        assert isinstance(z, (tuple, list))
        if self._target_size is None:
            raise ValueError(
                "self._target_size should be set with self.set_target_size() "
                "to interpolate logits to the target seg map's size"
            )

        z_high, z_low = z

        if z_depth is not None and self.use_dada:
            z_high = z_high * z_depth

        if self.backbone == "resnet":
            z_high = self.aspp(z_high)
            # NOTE(review): Decoder.forward is declared as
            # (feat_low, feat_aspp) but receives the ASPP output first here,
            # so the ASPP features go through conv_low and z_low is resized
            # to the ASPP resolution. Deliberately left unchanged: swapping
            # the arguments would alter the outputs of checkpoints trained
            # with this ordering — confirm before fixing.
            s = self.decoder(z_high, z_low)
        else:
            s = self.head(z_high)

        s = F.interpolate(
            s, size=self._target_size, mode="bilinear", align_corners=True
        )

        return s

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule in eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
