from typing import Tuple

import numpy as np
import scipy.linalg
import torch
import cvxpy
import matplotlib.pyplot as plt

import tools
import convolution_inversion


def np_to_torch(x: np.ndarray) -> torch.Tensor:
    """Wrap a NumPy array as a torch tensor and cast it to ``torch.FloatTensor``."""
    as_tensor = torch.from_numpy(x)
    return as_tensor.type(torch.FloatTensor)


def _do_random_test_for_size(num_in_channels: int,
                             num_out_channels: int,
                             output_h: int,
                             output_w: int,
                             kernel_height: int,
                             kernel_width: int) -> None:
    """Unimplemented stub: accepts the size parameters and does nothing."""
    return None


def safe_int(n: float) -> int:
    """Convert ``n`` to ``int``, refusing to silently drop a fractional part.

    Args:
        n: A number expected to hold an integral value (e.g. ``13.0``).

    Returns:
        ``int(n)``.

    Raises:
        ValueError: If ``int(n)`` does not compare equal to ``n``.
    """
    as_int = int(n)
    # Explicit raise instead of ``assert`` so the check is not stripped when
    # Python runs with ``-O``.
    if as_int != n:
        raise ValueError(f"expected an integral value, got {n!r}")
    return as_int


def g_matrix(kernel_size,
             num_channels,
             x_h,
             x_w,
             y_h,
             y_w,
             stride=(2, 2)) -> np.ndarray:
    """Build the dense pooling weight matrix ``g``.

    The result has ``num_channels * y_h * y_w`` rows and
    ``num_channels * x_h * x_w`` columns. For every column ``ir`` of the index
    matrix returned by ``convolution_inversion.build_ind_matrix`` and every
    channel ``ic``, row ``ic * y_h * y_w + ir`` receives that channel's
    flattened weights at the column positions listed in
    ``ind_matrix[:, ir]``.

    Args:
        kernel_size: Pooling window as ``(height, width)``.
        num_channels: Number of channels (same for input and output).
        x_h: Input height.
        x_w: Input width.
        y_h: Output height.
        y_w: Output width.
        stride: ``(stride_height, stride_width)`` forwarded to
            ``build_ind_matrix``. Defaults to ``(2, 2)``, the value that used
            to be hard-coded here.
            NOTE(review): callers pooling with a kernel other than 2x2
            probably want ``stride=kernel_size`` — confirm.

    Returns:
        The weight matrix as a 2-D ``np.ndarray``.
    """
    # presumably each column of ind_matrix holds the flat input indices of one
    # pooling window — TODO confirm against convolution_inversion.
    ind_matrix = convolution_inversion.build_ind_matrix(num_channels,
                                                        x_h,
                                                        x_w,
                                                        kernel_size,
                                                        stride)

    # Scaling by the window area presumably turns averaging weights into unit
    # weights — TODO confirm against build_avgpool_weights.
    w = convolution_inversion.build_avgpool_weights(kernel_size, num_channels) * np.prod(kernel_size)
    weight_flat = np.reshape(w, (num_channels, -1))

    rows_per_channel = y_h * y_w
    num_rows = num_channels * rows_per_channel
    num_cols = int(np.prod((num_channels, x_h, x_w)))
    g = np.zeros((num_rows, num_cols))

    for ir in range(ind_matrix.shape[1]):
        corresp_cols = ind_matrix[:, ir]
        for ic in range(num_channels):
            # Each channel owns a contiguous band of rows_per_channel rows.
            g[rows_per_channel * ic + ir, corresp_cols] = weight_flat[ic, :]

    return g


def _get_an_inverse(y: np.ndarray,
                    kernel_size) -> np.ndarray:
    """Set up the inequality system for inverting a pooled output ``y``.

    NOTE(review): this function is unfinished. It builds the system and calls
    ``tools.build_prototype_from_h_form``, but the result of that call is not
    yet turned into an input ``x``. The returned array is a NaN-filled
    placeholder of the right shape, not an actual inverse.

    Args:
        y: Pooled output with three axes, unpacked as
            ``(channels, height, width)``.
        kernel_size: Pooling window as ``(height, width)``; the input is taken
            to be ``kernel_size`` times larger than ``y`` along each spatial
            axis.

    Returns:
        A float array of shape ``(channels, y_h * kernel_size[0],
        y_w * kernel_size[1])`` filled with NaN.
    """
    y_channels, y_h, y_w = y.shape

    x_h = safe_int(y_h * kernel_size[0])
    x_w = safe_int(y_w * kernel_size[1])
    x_channels = y_channels
    x_shape = (x_channels, x_h, x_w)

    g = g_matrix(kernel_size, x_channels, x_h, x_w, y_h, y_w)

    # Rows are [y | -g]; presumably read by ``tools`` as ``y - g @ x >= 0``
    # — TODO confirm the H-form convention.
    h_ineq = np.hstack((tools.vec(y), -1 * g))
    # Zero-row companion matrix with the same number of columns as h_ineq.
    h_lin = np.empty((0, h_ineq.shape[1]))

    objective_sense = cvxpy.Minimize
    fall_back_to_vacuous_criterion = False

    # The call is kept because it is the only real work done so far (and it
    # may raise on a malformed system); its return value is not consumed yet.
    tools.build_prototype_from_h_form(h_ineq,
                                      h_lin,
                                      objective_sense,
                                      1,
                                      fall_back_to_vacuous_criterion)

    # TODO: derive x from the prototype. NaN (rather than uninitialised
    # memory) makes the placeholder deterministic and obvious if it is ever
    # consumed by mistake.
    x_vec = np.full((x_channels * x_h * x_w, ), np.nan)

    return np.reshape(x_vec, x_shape)


def test_single_inverse():
    """Smoke-test ``_get_an_inverse`` on the output of a 2x2 max-pool.

    Pools a random ``(1, 2, 26, 28)`` input, feeds the pooled result to
    ``_get_an_inverse`` and checks that the returned array has the shape of
    the un-pooled input (without the batch axis).
    """
    kernel_size = (2, 2)
    input_h = 26
    input_w = 28
    in_channels = 2

    x_torch = torch.rand((1, in_channels, input_h, input_w))
    y_torch = torch.nn.MaxPool2d(kernel_size=kernel_size)(x_torch)

    # Drop the batch axis: _get_an_inverse unpacks a three-axis shape.
    y_np = y_torch.detach().numpy()[0, :, :, :]
    the_inverse = _get_an_inverse(y_np, kernel_size)

    # Only the shape is checked; the values of the inverse are not validated.
    assert the_inverse.shape == (in_channels, input_h, input_w)


#
# def test_maxpool_mechanics1():
#     kernel_size = (2, 2)
#     input_h = 26
#     input_w = 28
#     in_channels = 2
#
#     # x_size = in_channels * input_h * input_w
#     # x_torch = torch.arange(1, 1 + x_size).view(1, in_channels, input_h, input_w).float()
#     x_torch = torch.rand((1, in_channels, input_h, input_w))
#     x_np = x_torch.detach().numpy()
#     x_np_vec = tools.vec(x_np)
#
#     maxpool = torch.nn.MaxPool2d(kernel_size=kernel_size)
#     x_maxpooled_torch = maxpool(x_torch)
#     x_maxpooled_np = x_maxpooled_torch.detach().numpy()
#
#     x_maxpooled_np_vec = tools.vec(x_maxpooled_np)
#
#     output_h = safe_int(input_h / kernel_size[0])
#     output_w = safe_int(input_w / kernel_size[1])
#
#     input_size = in_channels * input_h * input_w
#     output_size = in_channels * output_h * output_w
#
#     k = np.prod(kernel_size)
#     g = np.kron(np.eye(output_size), np.ones((k, 1)))
#     y = x_maxpooled_np
#     x = x_np
#
#     n = input_size
#
#     # num_constr = 10
#     unit_cube_h = tools.unit_cube_h_repr(n)
#     a = -1 * unit_cube_h[:, 1:]
#     b = tools.vec(unit_cube_h[:, 0])
#     nc = a.shape[0]
#     # upper = g @ x_maxpooled_np_vec
#
#     to_vstack = (np.hstack((a, np.zeros((nc, output_size)))),
#                  np.hstack((np.eye(input_size), -1 * g)))
#     stacked_weight = np.vstack(to_vstack)  # (nc + 96) x (96 + 24)
#     stacked_limit = np.vstack((b, np.zeros((input_size, 1))))
#
#     stacked_h = np.hstack((stacked_limit, -1 * stacked_weight))
#     stacked_h_lin = np.empty((0, stacked_h.shape[1]))
#     # stacked_v = tools.h_to_v(stacked_h, stacked_h_lin)
#
#     c = 0
#     assert x_maxpooled_torch[0, c, 0, 0] == torch.max(x_torch[0, c, :2, :2])
#
#     c = 1
#     assert x_maxpooled_torch[0, c, 0, 0] == torch.max(x_torch[0, c, :2, :2])

"""
Basic idea: we want to find all x for which there exists a y such that
y = maxpool(x, (2, 2))
and
y in P [iff Ay <= b]

e.g. y is 24 x 1,
     x is 96 x 1
     A is nc x 24
     b is nc x 1
          
Now,

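y = maxpool(x, (2, 2)) implies
vec(x) <= G @ vec(y)
(the converse does not hold: the inequality only says that y bounds x from
above within each window, so this is a relaxation of the max-pool constraint)

where G = [1; 1; 1; 1] kron I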

So, we have A @ vec(y) <= b
            vec(x) <= G @ vec(y)
iff

[[A, 0]; [-G, I]] @ [y; x] <= [b; 0]

Note: we do not actually care about y itself, but it has to appear as a variable in the system.
"""

# def _build_maxpool_indices() -> np.ndarray:
#     pass


if __name__ == "__main__":
    # Allow running the smoke test directly as a script, without pytest.
    # test_maxpool_mechanics1()
    test_single_inverse()
