import numpy as np
import torch
import torch.nn as nn


class Downsampler(nn.Module):
    """
    http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf
    """

    def __init__(
        self,
        n_planes,
        factor,
        kernel_type,
        phase=0,
        kernel_width=None,
        support=None,
        sigma=None,
        preserve_size=False,
    ):
        super().__init__()

        assert phase in [0, 0.5], "phase should be 0 or 0.5"

        if kernel_type == "lanczos2":
            support = 2
            kernel_width = 4 * factor + 1
            kernel_type_ = "lanczos"

        elif kernel_type == "lanczos3":
            support = 3
            kernel_width = 6 * factor + 1
            kernel_type_ = "lanczos"

        elif kernel_type == "gauss12":
            kernel_width = 7
            sigma = 1 / 2
            kernel_type_ = "gauss"

        elif kernel_type == "gauss1sq2":
            kernel_width = 9
            sigma = 1.0 / np.sqrt(2)
            kernel_type_ = "gauss"

        elif kernel_type in ["lanczos", "gauss", "box"]:
            kernel_type_ = kernel_type

        else:
            raise AssertionError("wrong name kernel")

        # note that `kernel width` will be different to actual size for phase = 1/2
        self.kernel = get_kernel(
            factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma
        )

        downsampler = nn.Conv2d(
            n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0
        )
        downsampler.weight.data[:] = 0
        downsampler.bias.data[:] = 0

        kernel_torch = torch.from_numpy(self.kernel)
        for i in range(n_planes):
            downsampler.weight.data[i, i] = kernel_torch

        self.downsampler_ = downsampler

        if preserve_size:
            if self.kernel.shape[0] % 2 == 1:
                pad = int((self.kernel.shape[0] - 1) / 2.0)
            else:
                pad = int((self.kernel.shape[0] - factor) / 2.0)

            self.padding = nn.ReplicationPad2d(pad)

        self.preserve_size = preserve_size

    def forward(self, input):
        if self.preserve_size:
            x = self.padding(input)
        else:
            x = input
        self.x = x
        return self.downsampler_(x)


def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    assert kernel_type in ["lanczos", "gauss", "box"]

    # factor  = float(factor)
    if phase == 0.5 and kernel_type != "box":
        kernel = np.zeros([kernel_width - 1, kernel_width - 1])
    else:
        kernel = np.zeros([kernel_width, kernel_width])

    if kernel_type == "box":
        assert phase == 0.5, "Box filter is always half-phased"
        kernel[:] = 1.0 / (kernel_width * kernel_width)

    elif kernel_type == "gauss":
        assert sigma, "sigma is not specified"
        assert phase != 0.5, "phase 1/2 for gauss not implemented"

        center = (kernel_width + 1.0) / 2.0
        print(center, kernel_width)
        sigma_sq = sigma * sigma

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                di = (i - center) / 2.0
                dj = (j - center) / 2.0
                kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj) / (2 * sigma_sq))
                kernel[i - 1][j - 1] = kernel[i - 1][j - 1] / (2.0 * np.pi * sigma_sq)
    elif kernel_type == "lanczos":
        assert support, "support is not specified"
        center = (kernel_width + 1) / 2.0

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                if phase == 0.5:
                    di = abs(i + 0.5 - center) / factor
                    dj = abs(j + 0.5 - center) / factor
                else:
                    di = abs(i - center) / factor
                    dj = abs(j - center) / factor

                np.pi * np.pi

                val = 1
                if di != 0:
                    val = (
                        val
                        * support
                        * np.sin(np.pi * di)
                        * np.sin(np.pi * di / support)
                    )
                    val = val / (np.pi * np.pi * di * di)

                if dj != 0:
                    val = (
                        val
                        * support
                        * np.sin(np.pi * dj)
                        * np.sin(np.pi * dj / support)
                    )
                    val = val / (np.pi * np.pi * dj * dj)

                kernel[i - 1][j - 1] = val

    else:
        raise AssertionError("wrong method name")

    kernel /= kernel.sum()

    return kernel
