from typing import Tuple

import numpy as np
import torch

import convolution_inversion


def np_to_torch(x: np.ndarray) -> torch.Tensor:
    """Convert a NumPy array into a float32 CPU tensor.

    ``torch.from_numpy`` wraps ``x`` without copying; the cast to float32
    only allocates new storage when ``x`` is not already float32.
    """
    wrapped = torch.from_numpy(x)
    return wrapped.float()


def _generate_data_to_test_convolution_inversion(in_channels: int,
                                                 out_channels: int,
                                                 output_h: int,
                                                 output_w: int,
                                                 kernel_width: int,
                                                 kernel_height: int) -> Tuple[np.ndarray,
                                                                            np.ndarray,
                                                                            np.ndarray,
                                                                            np.ndarray]:
    kernel_size = (kernel_height, kernel_width)
    input_h = output_h + kernel_height - 1
    input_w = output_w + kernel_width - 1

    bias = torch.rand([out_channels])
    weight = torch.rand((out_channels, in_channels, kernel_height, kernel_width))

    conv = torch.nn.Conv2d(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size)

    conv.weight.data = weight
    conv.bias.data = bias

    # x_size = in_channels * input_h * input_w
    # x_torch = torch.arange(1, 1 + x_size).view(1, in_channels, input_h, input_w).float()
    x_torch = torch.rand((1, in_channels, input_h, input_w))

    convolved = conv(x_torch)
    c = convolved[0, :, :, :].detach().numpy()
    w = weight.detach().numpy()
    b = bias.detach().numpy()
    x = x_torch.detach().numpy()
    return c, w, b, x


def _output_size_after_convolution(image_dim: int,
                                  n_padding: int,
                                  kernel_size: int,
                                  stride: int) -> int:
    sz = int((image_dim - kernel_size + 2 * n_padding) / stride + 1)
    return sz


def test_convolution_mechanics1():
    """Check a single-channel Conv2d against an explicit im2col matmul.

    Also checks that ``Fold`` inverts ``Unfold`` once every pixel is divided
    by the number of sliding windows that cover it.
    """
    # this and the next tests are adapted from:
    # https://gist.github.com/hsm207/7bfbe524bfd9b60d1a9e209759064180
    kernel_size = 2
    image_dim = 4
    x = torch.arange(1, image_dim * image_dim + 1).view(-1, 1, image_dim, image_dim).float()
    w = torch.arange(1, kernel_size * kernel_size + 1).view(-1, 1, kernel_size, kernel_size).float()

    conv = torch.nn.Conv2d(in_channels=1,
                           out_channels=1,
                           kernel_size=kernel_size,
                           stride=1)

    conv.weight.data = w
    conv.bias.data = torch.zeros([1])

    n_output = _output_size_after_convolution(image_dim=image_dim,
                                              kernel_size=kernel_size,
                                              n_padding=0,
                                              stride=1)
    convolved = conv(x)
    # im2col
    unfold = torch.nn.Unfold(kernel_size=kernel_size,
                             padding=0,
                             stride=1)
    # col2im
    fold = torch.nn.Fold(output_size=(image_dim, image_dim),
                         kernel_size=kernel_size,
                         padding=0,
                         stride=1)

    x_unfolded = unfold(x)

    # fold() sums overlapping windows; dividing by fold(unfold(ones)), i.e. the
    # per-pixel window count, recovers the original image.
    input_ones = torch.ones(x.shape, dtype=x.dtype)
    divisor = fold(unfold(input_ones))
    x_refolded = fold(x_unfolded) / divisor

    # Only sums and divisions of small integers are involved, which are exact
    # in float32, so exact equality is safe for this check.
    assert torch.equal(x, x_refolded)

    w_unfolded = unfold(w).transpose(2, 1)

    convolved2_flat = torch.matmul(w_unfolded, x_unfolded)
    convolved2 = convolved2_flat.view(-1, conv.out_channels, n_output, n_output)

    # Conv2d may use a different algorithm / summation order than matmul, so
    # compare with a tolerance, as test_convolution_mechanics2 does.
    np.testing.assert_allclose(convolved2.detach().numpy(),
                               convolved.detach().numpy())


def test_convolution_mechanics2():
    """Check a multi-channel Conv2d against an explicit im2col matmul.

    All sizes are derived from the channel counts, the output size and the
    kernel size declared at the top, so they cannot drift out of sync.
    """
    in_channels = 3
    out_channels = 2
    output_h = 3
    output_w = 3
    kernel_size = 2
    # Stride 1 and no padding: each spatial axis shrinks by kernel_size - 1.
    input_h = output_h + kernel_size - 1
    input_w = output_w + kernel_size - 1

    conv = torch.nn.Conv2d(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           stride=1)
    n_weights = out_channels * in_channels * kernel_size * kernel_size
    n_inputs = in_channels * input_h * input_w
    w = torch.arange(1, n_weights + 1).view(
        out_channels, in_channels, kernel_size, kernel_size).float()
    x = torch.arange(1, n_inputs + 1).view(1, in_channels, input_h, input_w).float()

    conv.weight.data = w
    conv.bias.data = torch.zeros([conv.out_channels])

    result_conv2d_torch = conv(x)

    unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0, stride=1)
    x_unfold = unfold(x)
    # Each output filter is unfolded as its own "image", giving one row per filter.
    w_unfold = unfold(w).transpose(2, 1)

    result_conv2d_matmul_flat = torch.matmul(w_unfold, x_unfold)

    result_conv2d_matmul = result_conv2d_matmul_flat.view(-1,
                                                          out_channels,
                                                          output_h,
                                                          output_w)

    np.testing.assert_allclose(result_conv2d_torch.detach().numpy(),
                               result_conv2d_matmul.detach().numpy())


def _do_random_test_for_size(num_in_channels: int,
                             num_out_channels: int,
                             output_h: int,
                             output_w: int,
                             kernel_height: int,
                             kernel_width: int) -> None:
    """Invert a random Conv2d output and check both parts of the solution.

    Generates random input/weight/bias with
    ``_generate_data_to_test_convolution_inversion`` and passes the
    convolution output to ``convolution_inversion.conv2d_inversion_kernel``.
    It then asserts, to within 5e-4, that:

    * convolving ``x - particular_solution`` gives ~0, and
    * convolving a random combination of the nullspace basis gives ~bias in
      every output channel.
    """
    tolerance = 5e-4

    # Keyword arguments are required: the generator's signature lists
    # kernel_width before kernel_height (the reverse of this function), so
    # positional passing would swap them for non-square kernels.
    c, w, b, x = _generate_data_to_test_convolution_inversion(
        in_channels=num_in_channels,
        out_channels=num_out_channels,
        output_h=output_h,
        output_w=output_w,
        kernel_width=kernel_width,
        kernel_height=kernel_height,
    )
    particular_solution, nullspace_basis = (
        convolution_inversion.conv2d_inversion_kernel(c, w, b))

    # Basis vectors are stacked along the last axis; draw random coefficients
    # and form the corresponding linear combination.
    nullspace_dim = nullspace_basis.shape[-1]
    random_nullspace_weights = np.random.normal(size=(1, 1, nullspace_dim))
    random_nullspace_elem = np.sum(nullspace_basis * random_nullspace_weights, axis=-1)

    conv = torch.nn.Conv2d(in_channels=num_in_channels,
                           out_channels=num_out_channels,
                           kernel_size=(kernel_height, kernel_width))

    conv.weight.data = np_to_torch(w)
    conv.bias.data = np_to_torch(b)

    # conv() adds the bias, so conv(x - p) = W*x - W*p + b. Requiring ~0 here
    # therefore means p must satisfy W*p = W*x + b = c.
    # NOTE(review): confirm this matches conv2d_inversion_kernel's convention.
    conved_difference = conv(np_to_torch(x - particular_solution))
    conved_difference_np = conved_difference.detach().numpy()
    assert np.max(np.abs(conved_difference_np)) < tolerance

    # Add a batch axis so the nullspace element can be fed through conv.
    conved_nullspace_elem = conv(np_to_torch(random_nullspace_elem[None, :, :, :]))
    conved_nullspace_elem_np = conved_nullspace_elem.detach().numpy()

    # The weight maps a nullspace element to ~0, leaving only each channel's bias.
    residual = conved_nullspace_elem_np[0] - b[:, None, None]
    assert np.max(np.abs(residual)) < tolerance


def test_convolution_simple_random1():
    """Random inversion round trip: 3 -> 2 channels, 3x3 output, 2x2 kernel."""
    _do_random_test_for_size(num_in_channels=3,
                             num_out_channels=2,
                             output_h=3,
                             output_w=3,
                             kernel_height=2,
                             kernel_width=2)


def test_convolution_simple_random2():
    """Random inversion round trip: 3 -> 2 channels, 3x4 output, 2x2 kernel."""
    _do_random_test_for_size(num_in_channels=3,
                             num_out_channels=2,
                             output_h=3,
                             output_w=4,
                             kernel_height=2,
                             kernel_width=2)


if __name__ == "__main__":
    # Run every test defined in this module when executed as a script.
    test_convolution_mechanics1()
    test_convolution_mechanics2()
    test_convolution_simple_random1()
    test_convolution_simple_random2()
