from pathlib import Path

import torch
import torch.nn as nn
from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
from climategan.deeplab.mobilenet_v3 import MobileNetV2
from climategan.deeplab.resnet101_v3 import ResNet101
from climategan.deeplab.resnetmulti_v2 import ResNetMulti


def create_encoder(opts, no_init=False, verbose=0):
    """Build the generator encoder selected by ``opts.gen.encoder.architecture``.

    Args:
        opts: configuration namespace. Reads ``opts.gen.encoder.architecture``
            and, for "deeplabv3", ``opts.gen.deeplabv3.backbone``.
        no_init: forwarded to the encoder builder (skips loading pre-trained
            weights there).
        verbose: when > 0, print which encoder is added; also forwarded to the
            encoder builder for both architectures.

    Returns:
        The encoder module (``DeeplabV2Encoder`` or the DeepLabv3 backbone).

    Raises:
        NotImplementedError: the architecture name is not recognised.
    """
    architecture = opts.gen.encoder.architecture
    if architecture == "deeplabv2":
        if verbose > 0:
            print("  - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)
    if architecture == "deeplabv3":
        if verbose > 0:
            backbone = opts.gen.deeplabv3.backbone
            print("  - Add Deeplabv3 ({}) Encoder".format(backbone))
        # Forward verbose as the deeplabv2 path does; previously it was dropped.
        return build_v3_backbone(opts, no_init, verbose=verbose)
    raise NotImplementedError("Unknown encoder: {}".format(architecture))


def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Build the segmentation decoder head named by ``opts.gen.s.architecture``.

    Args:
        opts: configuration namespace; ``opts.gen.s.architecture`` selects the
            decoder ("deeplabv2" or "deeplabv3").
        no_init: forwarded to ``DeepLabV3Decoder`` only.
        verbose: when > 0, print which decoder is added.

    Returns:
        A ``DeepLabV2Decoder`` or ``DeepLabV3Decoder`` instance.

    Raises:
        NotImplementedError: the architecture name is not recognised.
    """
    arch = opts.gen.s.architecture
    if arch not in ("deeplabv2", "deeplabv3"):
        raise NotImplementedError(
            "Unknown Segmentation architecture: {}".format(arch)
        )

    use_v2 = arch == "deeplabv2"
    if verbose > 0:
        print("  - Add DeepLabV2Decoder" if use_v2 else "  - Add DeepLabV3Decoder")

    if use_v2:
        return DeepLabV2Decoder(opts)
    return DeepLabV3Decoder(opts, no_init)


def build_v3_backbone(opts, no_init, verbose=0):
    """Build the DeepLabv3+ backbone named by ``opts.gen.deeplabv3.backbone``.

    Args:
        opts: configuration namespace. Reads ``opts.gen.deeplabv3.backbone``,
            ``.output_stride`` and ``.pretrained_model.{resnet,mobilenet}``.
        no_init: when True, the ResNet checkpoint is not loaded; the flag is
            also forwarded to both backbone constructors.
        verbose: verbosity level forwarded to ``ResNet101``; when > 0 a missing
            ResNet checkpoint is reported.

    Returns:
        A ``ResNet101`` or ``MobileNetV2`` module.

    Raises:
        FileNotFoundError: backbone is "mobilenet" and its pretrained model path
            does not exist.
        NotImplementedError: the backbone name is not recognised.
    """
    backbone = opts.gen.deeplabv3.backbone
    output_stride = opts.gen.deeplabv3.output_stride

    if backbone == "resnet":
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if no_init:
            return resnet

        ckpt_path = opts.gen.deeplabv3.pretrained_model.resnet
        if not Path(ckpt_path).exists():
            # Best effort: a missing checkpoint keeps the fresh initialization.
            if verbose > 0:
                print(
                    "    - No pre-trained Resnet101 checkpoint at {}".format(ckpt_path)
                )
            return resnet

        # The weights are copied into the freshly built module, so load them on
        # CPU: this also works for GPU-saved checkpoints on CPU-only hosts.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        prefix = "backbone."
        # Strip only the leading prefix the key was selected on, not every
        # occurrence of it within the key.
        resnet.load_state_dict(
            {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        )
        print("    - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder")
        return resnet

    if backbone == "mobilenet":
        pretrained_path = opts.gen.deeplabv3.pretrained_model.mobilenet
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not Path(pretrained_path).exists():
            raise FileNotFoundError(
                "MobileNetV2 pretrained model not found: {}".format(pretrained_path)
            )
        mobilenet = MobileNetV2(no_init=no_init, pretrained_path=pretrained_path)
        print("    - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))


class DeeplabV2Encoder(nn.Module):
    """DeepLabv2 encoder wrapping ``ResNetMulti``, optionally warm-started from a
    saved checkpoint."""

    # Second component of a checkpoint key naming sub-modules that are NOT
    # transferred into the encoder; the model keeps its own values for them.
    _SKIPPED_MODULES = ("layer5", "resblock")

    def __init__(self, opts, no_init=False, verbose=0):
        """Deeplab architecture encoder.

        Args:
            opts: configuration namespace. Reads ``opts.gen.deeplabv2.nblocks``,
                ``.use_pretrained``, ``.pretrained_model`` and
                ``opts.gen.encoder.n_res``.
            no_init: when True, never load pre-trained weights.
            verbose: when > 0, print a message after loading weights.
        """
        super().__init__()

        self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
        if opts.gen.deeplabv2.use_pretrained and not no_init:
            # The tensors are copied into self.model, so load them on CPU; this
            # also works for GPU-saved checkpoints on CPU-only hosts.
            saved_state_dict = torch.load(
                opts.gen.deeplabv2.pretrained_model, map_location="cpu"
            )
            self.model.load_state_dict(
                self._remap_state_dict(saved_state_dict, self.model.state_dict())
            )
            if verbose > 0:
                print("    - Loaded pretrained weights")

    @classmethod
    def _remap_state_dict(cls, saved_state_dict, model_state_dict):
        """Merge checkpoint entries into a copy of ``model_state_dict``.

        Each checkpoint key has its first dotted component dropped before it is
        stored, so ``"X.conv1.weight"`` becomes ``"conv1.weight"``. Keys whose
        second component is in ``_SKIPPED_MODULES`` are ignored, so the model's
        own values for them are kept. ``model_state_dict`` is not mutated.
        """
        new_params = model_state_dict.copy()
        for key, value in saved_state_dict.items():
            parts = key.split(".")
            if parts[1] not in cls._SKIPPED_MODULES:
                new_params[".".join(parts[1:])] = value
        return new_params

    def forward(self, x):
        """Run the wrapped ``ResNetMulti`` model on ``x``."""
        return self.model(x)
