# Taken from https://github.com/nanxuanzhao/Good_transfer/blob/main/code_reconstruction_DIP/skip.py
# With some entries from https://github.com/nanxuanzhao/Good_transfer/blob/main/code_reconstruction_DIP/common.py

import numpy as np
import torch
import torch.nn as nn


def fill_noise(x, noise_type):
    """Fill tensor `x` in place with random noise.

    Args:
        x: tensor to overwrite; it is modified in place and nothing is returned.
        noise_type: "u" to fill via `x.uniform_()`, "n" to fill via `x.normal_()`.

    Raises:
        AssertionError: if `noise_type` is neither "u" nor "n".
    """
    if noise_type == "u":
        x.uniform_()
    elif noise_type == "n":
        x.normal_()
    else:
        # Raised explicitly instead of `assert False`: asserts are stripped
        # under `python -O`, which would leave `x` silently untouched. The
        # exception type is kept as AssertionError for backward compatibility.
        raise AssertionError(
            "unknown noise_type {!r}; expected 'u' or 'n'".format(noise_type)
        )


def get_noise(input_depth, spatial_size, noise_type="u", var=1.0 / 10):
    """Create a noise tensor of size (1 x `input_depth` x H x W).

    Args:
        input_depth: number of channels in the tensor.
        spatial_size: either an int (used for both H and W) or an indexable
            pair whose first two entries are H and W.
        noise_type: 'u' for uniform; 'n' for normal.
        var: factor the generated noise is multiplied by.

    Returns:
        The `torch.Tensor` holding the scaled noise.
    """
    if isinstance(spatial_size, int):
        height = width = spatial_size
    else:
        height, width = spatial_size[0], spatial_size[1]

    # Allocate first, then randomise and scale in place.
    noise = torch.zeros([1, input_depth, height, width])
    fill_noise(noise, noise_type)
    noise *= var
    return noise


class Downsampler(nn.Module):
    """Strided convolution whose weights are initialised from a resampling kernel.

    Kernel definitions follow
    http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf

    The wrapped `nn.Conv2d` maps `n_planes` channels to `n_planes` channels.
    Its weight and bias are zeroed and then the same 2-D kernel is written on
    the channel diagonal, so each channel is filtered independently.

    Args:
        n_planes: number of input (and output) channels.
        factor: stride of the convolution.
        kernel_type: a preset ("lanczos2", "lanczos3", "gauss12", "gauss1sq2")
            or a base kernel name ("lanczos", "gauss", "box"). Presets
            overwrite `kernel_width` and `support`/`sigma`; base names use the
            values passed in.
        phase: 0 or 0.5.
        kernel_width: kernel side length (ignored for presets).
        support: Lanczos support (ignored for Lanczos presets).
        sigma: Gaussian sigma (ignored for Gaussian presets).
        preserve_size: if True, the input is replication-padded before the
            convolution is applied.
    """

    def __init__(
        self,
        n_planes,
        factor,
        kernel_type,
        phase=0,
        kernel_width=None,
        support=None,
        sigma=None,
        preserve_size=False,
    ):
        super(Downsampler, self).__init__()

        assert phase in [0, 0.5], "phase should be 0 or 0.5"

        # Resolve presets into a base kernel family plus its parameters.
        if kernel_type in ("lanczos2", "lanczos3"):
            support = 2 if kernel_type == "lanczos2" else 3
            kernel_width = 2 * support * factor + 1
            base_type = "lanczos"
        elif kernel_type == "gauss12":
            kernel_width, sigma, base_type = 7, 1 / 2, "gauss"
        elif kernel_type == "gauss1sq2":
            kernel_width, sigma, base_type = 9, 1.0 / np.sqrt(2), "gauss"
        elif kernel_type in ["lanczos", "gauss", "box"]:
            base_type = kernel_type
        else:
            assert False, "wrong name kernel"

        # note that `kernel width` will be different to actual size for phase = 1/2
        self.kernel = get_kernel(
            factor, base_type, phase, kernel_width, support=support, sigma=sigma
        )

        conv_layer = nn.Conv2d(
            n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0
        )
        conv_layer.weight.data.zero_()
        conv_layer.bias.data.zero_()

        # Same kernel on every diagonal (channel -> same channel) slot.
        weights_2d = torch.from_numpy(self.kernel)
        for plane in range(n_planes):
            conv_layer.weight.data[plane, plane] = weights_2d

        self.downsampler_ = conv_layer

        if preserve_size:
            side = self.kernel.shape[0]
            # Odd kernels pad by half their extent; even ones account for the stride.
            shrink = 1 if side % 2 == 1 else factor
            self.padding = nn.ReplicationPad2d(int((side - shrink) / 2.0))

        self.preserve_size = preserve_size

    def forward(self, input):
        """Optionally pad `input`, remember the convolved tensor in `self.x`, and downsample."""
        x = self.padding(input) if self.preserve_size else input
        self.x = x
        return self.downsampler_(x)


def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    """Build a square 2-D resampling kernel normalised to sum to 1.

    Args:
        factor: downsampling factor; scales the sample distances of the
            Lanczos kernel (unused by the box and Gaussian kernels).
        kernel_type: "lanczos", "gauss" or "box".
        phase: 0 or 0.5. The box kernel requires 0.5; the Gaussian kernel does
            not support 0.5.
        kernel_width: nominal side length. For `phase == 0.5` with a non-box
            kernel the returned array has side `kernel_width - 1`.
        support: Lanczos support (required for "lanczos").
        sigma: Gaussian sigma (required for "gauss").

    Returns:
        A float `numpy.ndarray` whose entries sum to 1.

    Raises:
        AssertionError: on an unknown `kernel_type` or a missing/unsupported
            parameter combination.
    """
    assert kernel_type in ["lanczos", "gauss", "box"]

    if phase == 0.5 and kernel_type != "box":
        kernel = np.zeros([kernel_width - 1, kernel_width - 1])
    else:
        kernel = np.zeros([kernel_width, kernel_width])

    n_rows, n_cols = kernel.shape

    if kernel_type == "box":
        assert phase == 0.5, "Box filter is always half-phased"
        kernel[:] = 1.0 / (kernel_width * kernel_width)

    elif kernel_type == "gauss":
        assert sigma, "sigma is not specified"
        assert phase != 0.5, "phase 1/2 for gauss not implemented"

        center = (kernel_width + 1.0) / 2.0
        sigma_sq = sigma * sigma

        # Indices are 1-based so that `center` falls on the middle sample.
        for i in range(1, n_rows + 1):
            for j in range(1, n_cols + 1):
                di = (i - center) / 2.0
                dj = (j - center) / 2.0
                weight = np.exp(-(di * di + dj * dj) / (2 * sigma_sq))
                kernel[i - 1][j - 1] = weight / (2.0 * np.pi * sigma_sq)

    else:  # "lanczos" -- guaranteed by the assertion at the top
        assert support, "support is not specified"
        center = (kernel_width + 1) / 2.0
        # A half-phase kernel samples midway between the integer positions.
        shift = 0.5 if phase == 0.5 else 0.0

        for i in range(1, n_rows + 1):
            for j in range(1, n_cols + 1):
                di = abs(i + shift - center) / factor
                dj = abs(j + shift - center) / factor

                # Separable kernel: multiply the 1-D Lanczos weight of each
                # axis; a zero distance contributes a weight of exactly 1.
                val = 1
                for d in (di, dj):
                    if d != 0:
                        val = (
                            val
                            * support
                            * np.sin(np.pi * d)
                            * np.sin(np.pi * d / support)
                        )
                        val = val / (np.pi * np.pi * d * d)

                kernel[i - 1][j - 1] = val

    kernel /= kernel.sum()

    return kernel


class Concat(nn.Module):
    """Run several modules on the same input and concatenate their outputs.

    The modules passed as `*args` are registered under the names "0", "1", ...
    Their outputs are concatenated along `dim`. If the outputs disagree in
    size along dimension 2 or 3, each one is centre-cropped to the smallest
    size found along that dimension before concatenation.
    """

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        for position, branch in enumerate(args):
            self.add_module(str(position), branch)

    def forward(self, input):
        outputs = [branch(input) for branch in self._modules.values()]

        heights = [out.shape[2] for out in outputs]
        widths = [out.shape[3] for out in outputs]
        min_h = min(heights)
        min_w = min(widths)

        sizes_agree = all(h == min_h for h in heights) and all(
            w == min_w for w in widths
        )
        if not sizes_agree:
            cropped = []
            for out in outputs:
                # Offsets that centre the crop window inside this output.
                top = (out.size(2) - min_h) // 2
                left = (out.size(3) - min_w) // 2
                cropped.append(out[:, :, top : top + min_h, left : left + min_w])
            outputs = cropped

        return torch.cat(outputs, dim=self.dim)

    def __len__(self):
        return len(self._modules)


class GenNoise(nn.Module):
    """Module that returns standard-normal noise shaped after its input.

    The output has the same size as the input in every dimension except
    dimension 1, which is set to `dim2`. The input's values are not used;
    only its size, dtype and device are.
    """

    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        size = list(input.size())
        size[1] = self.dim2

        # Sample directly in the input's dtype and on its device. This replaces
        # the deprecated `torch.autograd.Variable` wrapper and avoids
        # allocating a CPU tensor only to convert it afterwards.
        return torch.randn(size, dtype=input.dtype, device=input.device)


def act(act_fun="LeakyReLU"):
    """Create an activation module.

    Args:
        act_fun: either one of the strings "LeakyReLU" (negative slope 0.2,
            in-place), "ELU" or "none" (an empty `nn.Sequential`), or a
            callable such as `nn.ReLU` that is invoked with no arguments.

    Returns:
        The constructed module.

    Raises:
        AssertionError: if `act_fun` is a string that is not recognised.
    """
    if not isinstance(act_fun, str):
        return act_fun()

    if act_fun == "LeakyReLU":
        return nn.LeakyReLU(0.2, inplace=True)
    if act_fun == "ELU":
        return nn.ELU()
    if act_fun == "none":
        return nn.Sequential()

    # Raised explicitly instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this function silently return None. The
    # exception type is kept as AssertionError for backward compatibility.
    raise AssertionError(
        "unknown activation {!r}; expected 'LeakyReLU', 'ELU' or 'none'".format(
            act_fun
        )
    )


def bn(num_features):
    """Return an `nn.BatchNorm2d` layer over `num_features` channels."""
    layer = nn.BatchNorm2d(num_features)
    return layer


def conv(
    in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"
):
    """Build a convolution stage with optional padding and downsampling layers.

    Args:
        in_f: number of input channels.
        out_f: number of output channels.
        kernel_size: convolution kernel size; the padding amount is
            `int((kernel_size - 1) / 2)`.
        stride: requested downsampling factor.
        bias: whether the convolution has a bias term.
        pad: "reflection" puts an `nn.ReflectionPad2d` in front of an unpadded
            convolution; any other value uses the convolution's own padding.
        downsample_mode: only consulted when `stride != 1`. "stride" uses a
            strided convolution. "avg", "max", "lanczos2" and "lanczos3"
            instead run the convolution with stride 1 and append a pooling or
            `Downsampler` layer that performs the downsampling.

    Returns:
        An `nn.Sequential` of, in order, the optional reflection padding, the
        convolution and the optional downsampling layer.

    Raises:
        AssertionError: if `stride != 1` and `downsample_mode` is not one of
            the supported values.
    """
    downsampler = None
    if stride != 1 and downsample_mode != "stride":
        if downsample_mode == "avg":
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == "max":
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ["lanczos2", "lanczos3"]:
            downsampler = Downsampler(
                n_planes=out_f,
                factor=stride,
                kernel_type=downsample_mode,
                phase=0.5,
                preserve_size=True,
            )
        else:
            # Raised explicitly instead of `assert False`: under `python -O`
            # the assert would vanish and a stride-1 convolution with no
            # downsampling at all would be returned silently. The exception
            # type is kept as AssertionError for backward compatibility.
            raise AssertionError(
                "unknown downsample_mode {!r}; expected 'stride', 'avg', 'max', "
                "'lanczos2' or 'lanczos3'".format(downsample_mode)
            )

        # The separate downsampling layer does the striding.
        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == "reflection":
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = [
        layer for layer in (padder, convolver, downsampler) if layer is not None
    ]
    return nn.Sequential(*layers)


def dip_ae(
    num_input_channels=2,
    num_output_channels=3,
    num_channels_down=[16, 32, 64, 128, 128],
    num_channels_up=[16, 32, 64, 128, 128],
    num_channels_skip=[4, 4, 4, 4, 4],
    filter_size_down=3,
    filter_size_up=3,
    filter_skip_size=1,
    need_sigmoid=True,
    need_bias=True,
    pad="zero",
    upsample_mode="nearest",
    downsample_mode="stride",
    act_fun="LeakyReLU",
    need1x1_up=True,
):
    """Assembles encoder-decoder with skip connections.

    The network is built as nested `nn.Sequential` containers, one nesting
    level per scale. At each scale the input goes through a "deeper" branch
    (a downsampling conv, a second conv, the next scale's sub-network and an
    upsampling layer) and, when `num_channels_skip[i] != 0`, through a
    parallel "skip" branch; the two are concatenated along the channel
    dimension and followed by the up-path convolutions.

    Arguments:
        num_input_channels: channels of the network input.
        num_output_channels: channels produced by the final 1x1 convolution.
        num_channels_down: per-scale channel counts of the down path.
        num_channels_up: per-scale channel counts of the up path.
        num_channels_skip: per-scale channel counts of the skip branches;
            0 disables the skip branch at that scale.
        filter_size_down: kernel size(s) of the down-path convolutions.
        filter_size_up: kernel size(s) of the up-path convolutions.
        filter_skip_size: kernel size of the skip-branch convolutions.
        need_sigmoid: append `nn.Sigmoid` after the final convolution.
        need_bias: whether the convolutions have bias terms.
        act_fun: Either string 'LeakyReLU|ELU|none' or module (e.g. nn.ReLU)
        pad (string): zero|reflection (default: 'zero')
        upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
        downsample_mode (string): 'stride|avg|max|lanczos2|lanczos3'
            (default: 'stride')
        need1x1_up: add an extra 1x1 conv + batch norm + activation at the
            end of every up-path stage.

    `filter_size_down`, `filter_size_up`, `upsample_mode` and
    `downsample_mode` may each be a single value (applied to every scale) or a
    list/tuple with one entry per scale. The list defaults are only read,
    never mutated.

    Returns:
        The assembled `nn.Sequential` model.
    """
    assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)

    n_scales = len(num_channels_down)

    def _per_scale(value):
        # Broadcast a single setting to every scale; pass sequences through.
        if isinstance(value, (list, tuple)):
            return value
        return [value] * n_scales

    upsample_mode = _per_scale(upsample_mode)
    downsample_mode = _per_scale(downsample_mode)
    filter_size_down = _per_scale(filter_size_down)
    filter_size_up = _per_scale(filter_size_up)

    last_scale = n_scales - 1

    model = nn.Sequential()
    # `model_tmp` always points at the container the current scale is built in.
    model_tmp = model

    input_depth = num_input_channels
    for i in range(n_scales):

        deeper = nn.Sequential()
        skip = nn.Sequential()

        # Channels leaving the deeper branch: the deepest scale has no nested
        # sub-network, so it emits its own down-path width.
        k = num_channels_down[i] if i == last_scale else num_channels_up[i + 1]

        # The branches are registered while still empty and filled in below.
        if num_channels_skip[i] != 0:
            model_tmp.append(Concat(1, skip, deeper))
        else:
            model_tmp.append(deeper)

        model_tmp.append(bn(num_channels_skip[i] + k))

        if num_channels_skip[i] != 0:
            skip.append(
                conv(
                    input_depth,
                    num_channels_skip[i],
                    filter_skip_size,
                    bias=need_bias,
                    pad=pad,
                )
            )
            skip.append(bn(num_channels_skip[i]))
            skip.append(act(act_fun))

        # Downsampling convolution (factor 2).
        deeper.append(
            conv(
                input_depth,
                num_channels_down[i],
                filter_size_down[i],
                2,
                bias=need_bias,
                pad=pad,
                downsample_mode=downsample_mode[i],
            )
        )
        deeper.append(bn(num_channels_down[i]))
        deeper.append(act(act_fun))

        deeper.append(
            conv(
                num_channels_down[i],
                num_channels_down[i],
                filter_size_down[i],
                bias=need_bias,
                pad=pad,
            )
        )
        deeper.append(bn(num_channels_down[i]))
        deeper.append(act(act_fun))

        # Container for the next (deeper) scale; left unattached at the deepest one.
        deeper_main = nn.Sequential()
        if i != last_scale:
            deeper.append(deeper_main)

        deeper.append(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        model_tmp.append(
            conv(
                num_channels_skip[i] + k,
                num_channels_up[i],
                filter_size_up[i],
                1,
                bias=need_bias,
                pad=pad,
            )
        )
        model_tmp.append(bn(num_channels_up[i]))
        model_tmp.append(act(act_fun))

        if need1x1_up:
            model_tmp.append(
                conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)
            )
            model_tmp.append(bn(num_channels_up[i]))
            model_tmp.append(act(act_fun))

        input_depth = num_channels_down[i]
        model_tmp = deeper_main

    model.append(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if need_sigmoid:
        model.append(nn.Sigmoid())

    return model
