import numpy as np
import torch

# Fix every source of randomness up front so the checks in this module are
# repeatable. The convolution tests draw their inputs from NumPy's global RNG,
# so it is seeded alongside PyTorch's generator.
np.random.seed(0)
torch.manual_seed(0)

# Request deterministic cuDNN kernels and disable its auto-tuner, which could
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def _relu(x: np.ndarray) -> np.ndarray:
    """Apply the rectified linear unit element-wise.

    Entries of ``x`` below zero are replaced by zero; every other entry is
    passed through unchanged.
    """
    floor = 0
    rectified = np.maximum(x, floor)
    return rectified


def _vec(x: np.ndarray) -> np.ndarray:
    """Reshape ``x`` into a single column.

    The result has shape ``(x.size, 1)``; elements appear in the order that
    ``np.reshape`` produces by default.
    """
    column_shape = (-1, 1)
    return np.reshape(x, column_shape)


def _get_test_data():
    """Build the fixed fixture shared by the loss tests.

    Returns:
        A ``(target, output)`` pair. ``target`` is a ``LongTensor`` holding
        five class indices; ``output`` is a 5x3 ``FloatTensor`` of scores with
        one row per observation.
    """
    labels = [1, 1, 2, 2, 0]
    scores = [
        [-2.4027e-01, 3.1288e-01, -1.2617e-01],
        [-2.7673e-02, 3.1151e-01, 3.6610e-02],
        [-3.4797e-02, 3.0438e-01, 3.1281e-02],
        [6.9885e-02, 2.2951e-01, 1.9513e-01],
        [-8.0072e-02, 3.4266e-01, -4.0489e-03],
    ]
    return torch.LongTensor(labels), torch.FloatTensor(scores)


def _one_hot_encode(y: np.ndarray,
                    num_classes: int) -> np.ndarray:
    """One-hot encode a vector of class indices.

    Args:
        y: 1-D array of integer labels, each in ``range(num_classes)``.
        num_classes: Number of columns in the encoding.

    Returns:
        Float array of shape ``(y.size, num_classes)`` containing a single 1
        per row, in the column named by the corresponding label.

    Raises:
        AssertionError: If ``y`` is not 1-D or holds a label outside
            ``range(num_classes)``.
    """
    assert 1 == len(y.shape)
    nr = y.size
    classes = np.arange(num_classes)
    # np.isin is the supported replacement for the deprecated np.in1d.
    assert np.all(np.isin(y, classes))

    rows = np.arange(nr)
    oh = np.zeros((nr, num_classes))
    oh[rows, y] = 1
    return oh


def _softmax(x: np.ndarray) -> np.ndarray:
    """Softmax of ``x`` taken along axis 1.

    Every entry is exponentiated and then divided by the sum of the
    exponentials in its row, so each row of the result sums to one. No
    max-shift is applied before exponentiating.
    """
    exponentials = np.exp(x)
    row_totals = exponentials.sum(axis=1, keepdims=True)
    return exponentials / row_totals


def _cross_entropy_loss(p: np.ndarray,
                        q: np.ndarray) -> np.ndarray:
    """Cross entropy of ``q`` relative to ``p``, one value per row.

    Computes ``-sum(p * log(q))`` along axis 1.

    Raises:
        AssertionError: If the shapes of ``p`` and ``q`` differ, or if any
            element-wise term is not finite.
    """
    assert p.shape == q.shape
    log_q = np.log(q)
    terms = -1 * p * log_q
    assert np.isfinite(terms).all()
    return terms.sum(axis=1)


def test_cross_entropy_loss1():
    """Check ``torch.nn.CrossEntropyLoss`` against a NumPy re-derivation.

    The reference value is the mean over observations of the cross entropy
    between the one-hot encoded targets and the softmax of the raw scores.
    """
    target, output = _get_test_data()
    criterion = torch.nn.CrossEntropyLoss()
    loss_pytorch = criterion(output, target)

    x_np = output.numpy()
    y_np = target.numpy()
    # One column per score column instead of a hard-coded class count.
    y_np_oh = _one_hot_encode(y_np, x_np.shape[1])

    probs = _softmax(x_np)
    xel = _cross_entropy_loss(y_np_oh, probs)

    loss_np = np.mean(xel)
    # Both sides are derived from float32 scores, so the default rtol of 1e-7
    # (below float32 machine epsilon) is tighter than the data can guarantee.
    np.testing.assert_allclose(loss_pytorch.numpy(), loss_np, rtol=1e-6)


def test_multi_margin_loss():
    """Check ``torch.nn.MultiMarginLoss`` against a NumPy re-derivation.

    For an observation with scores ``x`` and target class ``y`` the reference
    loss is ``sum_{i != y} max(0, margin - x[y] + x[i]) ** p`` divided by the
    number of classes; the final value is the mean over observations.
    """
    # https://pytorch.org/docs/stable/nn.html#torch.nn.MultiMarginLoss
    target, output = _get_test_data()

    margin = 1
    p = 1

    criterion = torch.nn.MultiMarginLoss(p=p, margin=margin)
    loss_pytorch = criterion(output, target)

    x_np = output.numpy()
    y_np = target.numpy()

    nr = x_np.shape[0]
    rows = np.arange(nr)

    xy = x_np[rows, y_np]
    difference_matrix = x_np - _vec(xy)

    # Mask the target column by index. Testing ``difference_matrix == 0``
    # would also drop any non-target class whose score ties with the target's,
    # even though such a class contributes ``margin`` to the loss.
    is_target = np.zeros(x_np.shape, dtype=bool)
    is_target[rows, y_np] = True

    pre_activ_loss_matrix = np.where(is_target, 0, difference_matrix + margin)
    loss_matrix = _relu(pre_activ_loss_matrix) ** p
    # The mean over axis 1 divides by the class count; the target column
    # contributes zero to the sum.
    loss_by_obs = np.mean(loss_matrix, axis=1)
    loss_numpy = np.mean(loss_by_obs)

    # float32 on both sides: allow a relative error above float32 epsilon.
    np.testing.assert_allclose(loss_numpy,
                               loss_pytorch.numpy(), rtol=1e-6, atol=1e-10)


def test_conv2d_0():
    """Convolving an all-ones image yields one constant per output channel.

    With every input pixel equal to one, each output pixel of a channel equals
    the sum of that channel's kernel weights plus its bias.
    """
    in_channels = 1
    out_channels = 2
    kernel_size = 2

    x_np = np.ones((1, in_channels, 32, 32))
    x = torch.from_numpy(x_np).type(torch.FloatTensor)

    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    convolved = conv(x).detach().numpy()
    weight = conv.weight.detach().numpy()
    bias = conv.bias.detach().numpy()

    for channel in range(out_channels):
        expected = np.sum(weight[channel, 0, :, :]) + bias[channel]
        # The convolution may accumulate its float32 products in a different
        # order than np.sum, so compare with a tolerance instead of ``==``.
        np.testing.assert_allclose(convolved[0, channel, :, :], expected,
                                   rtol=1e-5, atol=1e-6)


def test_conv2d_1():
    """Spot-check single-input-channel ``Conv2d`` outputs on a random image.

    Three output pixels (top-left and bottom-right of channel 0, bottom-right
    of channel 1) are compared with the hand-computed sum of the matching 2x2
    input patch times the kernel, plus the bias.
    """
    in_channels = 1
    out_channels = 2
    kernel_size = 2
    x_np = np.random.uniform(size=(1, in_channels, 32, 32))
    x = torch.from_numpy(x_np).type(torch.FloatTensor)

    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    convolved = conv(x).detach().numpy()
    weight = conv.weight.detach().numpy()
    bias = conv.bias.detach().numpy()

    # (output channel, input window along both spatial axes, output index)
    checks = [
        (0, slice(None, 2), 0),
        (0, slice(-2, None), -1),
        (1, slice(-2, None), -1),
    ]
    for channel, window, pos in checks:
        patch = x_np[0, 0, window, window]
        expected = np.sum(patch * weight[channel, 0, :, :]) + bias[channel]
        # The convolution runs in float32 while the reference uses the
        # float64 input, so an absolute bound of 1e-7 (float32 epsilon) is
        # too tight to hold for arbitrary random draws.
        np.testing.assert_allclose(convolved[0, channel, pos, pos], expected,
                                   rtol=1e-5, atol=1e-5)


def test_conv2d_2():
    """Spot-check a multi-input-channel ``Conv2d`` on a random image.

    The top-left output pixel of channel 0 must equal the sum, over all input
    channels, of the top-left 2x2 patch times the kernel, plus the bias.
    """
    in_channels = 3
    out_channels = 2
    kernel_size = 2
    x_np = np.random.uniform(size=(1, in_channels, 32, 32))
    x = torch.from_numpy(x_np).type(torch.FloatTensor)

    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    convolved = conv(x).detach().numpy()
    weight = conv.weight.detach().numpy()
    bias = conv.bias.detach().numpy()

    expected = np.sum(x_np[0, :, :2, :2] * weight[0, :, :, :]) + bias[0]
    # The convolution runs in float32 while the reference uses the float64
    # input, so an absolute bound of 1e-7 (float32 epsilon) is too tight to
    # hold for arbitrary random draws.
    np.testing.assert_allclose(convolved[0, 0, 0, 0], expected,
                               rtol=1e-5, atol=1e-5)


def test_conv2d_3():
    """Spot-check ``Conv2d`` at several positions of one output channel.

    For output channel 1 the pixels at diagonal positions 0, 2 and
    third-from-last are compared with the hand-computed sum, over all input
    channels, of the matching input patch times the kernel, plus the bias.
    """
    in_channels = 3
    out_channels = 5
    kernel_size = 2
    x_np = np.random.uniform(size=(1, in_channels, 32, 32))
    x = torch.from_numpy(x_np).type(torch.FloatTensor)

    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    convolved = conv(x).detach().numpy()
    weight = conv.weight.detach().numpy()
    bias = conv.bias.detach().numpy()

    obs = 0
    out_channel_ind = 1
    kernel = weight[out_channel_ind, :, :, :]

    # Output pixel (i, i) is produced from input rows/cols i .. i+kernel_size-1.
    last = convolved.shape[-1]
    for pos in (0, 2, last - 3):
        stop = pos + kernel_size
        patch = x_np[obs, :, pos:stop, pos:stop]
        expected = np.sum(patch * kernel) + bias[out_channel_ind]
        # The convolution runs in float32 while the reference uses the
        # float64 input, so an absolute bound of 1e-7 (float32 epsilon) is
        # too tight to hold for arbitrary random draws.
        np.testing.assert_allclose(convolved[obs, out_channel_ind, pos, pos],
                                   expected, rtol=1e-5, atol=1e-5)


def vec(x: np.ndarray) -> np.ndarray:
    """Return ``x`` reshaped to a single column, i.e. shape ``(x.size, 1)``."""
    n_rows, n_cols = -1, 1
    return np.reshape(x, (n_rows, n_cols))


def test_conv2d_inverse():
    """Check that subtracting the bias leaves the pure kernel response.

    NOTE: despite the name, recovering the input image from the convolved
    output is not implemented; only the bias-centering step is exercised.
    """
    in_channels = 1
    out_channels = 1
    kernel_size = 2
    x_np = np.random.uniform(size=(1, in_channels, 32, 32))
    x = torch.from_numpy(x_np).type(torch.FloatTensor)

    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    convolved = conv(x).detach().numpy()
    weight = conv.weight.detach().numpy()
    bias = conv.bias.detach().numpy()

    # Index the bias as (1, channels, 1, 1) so it broadcasts per output
    # channel and the centered result keeps the shape of ``convolved``
    # (a fixed (2, 31, 31) block would add a spurious second channel).
    centered_convolved = convolved - bias[None, :, None, None]
    assert centered_convolved.shape == convolved.shape

    # Reference response: the 2x2 kernel applied as a cross-correlation,
    # assembled from the four shifted views of the image.
    image = x_np[0, 0]
    w = weight[0, 0]
    expected = (w[0, 0] * image[:-1, :-1] + w[0, 1] * image[:-1, 1:]
                + w[1, 0] * image[1:, :-1] + w[1, 1] * image[1:, 1:])
    np.testing.assert_allclose(centered_convolved[0, 0], expected,
                               rtol=1e-5, atol=1e-5)

    # TODO: solve for the input image given ``centered_convolved`` and ``w``.


def _super_simple_matrix_inversion():
    """Lay out the pieces for writing a 2-D convolution as a matrix product.

    Work in progress: for a fixed 4x4 input and 2x2 kernel this allocates the
    patch matrix (one row per kernel entry, one column per output position)
    and the flattened kernel, but neither fills nor uses them. Returns None.
    """
    # https://medium.com/@_init_/an-illustrated-explanation-of-performing-2d-convolutions-using-matrix-multiplications-1e8de8cd2544
    # https://gist.github.com/hsm207/7bfbe524bfd9b60d1a9e209759064180

    # Input holds 1..16 row by row; the kernel holds 1..4 row by row.
    x = np.arange(1, 17).reshape(4, 4)
    w = np.arange(1, 5).reshape(2, 2)

    kernel_size = (2, 2)
    k_rows, k_cols = kernel_size
    nr_in, nc_in = x.shape

    nr = int(np.prod(kernel_size))
    out_rows = nr_in + 1 - k_rows
    out_cols = nc_in + 1 - k_cols
    nc = int(out_rows * out_cols)

    patches_mat = np.empty((nr, nc))

    w_flat = w.reshape(1, -1)


if __name__ == "__main__":
    # Run a selection of the checks, in order, when executed as a script.
    for _run in (
        test_cross_entropy_loss1,
        test_conv2d_1,
        test_conv2d_2,
        test_conv2d_3,
    ):
        _run()
