import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
torch.set_printoptions(precision=10)

def get_linear_ar_mask(c_out, c_in, upper=True):
    """Return a square triangular 0/1 mask over (c_out, c_in).

    Args:
        c_out (int): number of output channels (rows of the mask).
        c_in (int): number of input channels (columns); must equal ``c_out``.
        upper (bool): if truthy, keep the diagonal and everything above it
            (entry ``[r, c]`` is 1 when ``c >= r``); otherwise keep the
            diagonal and everything below it (``c <= r``).

    Returns:
        torch.Tensor: float32 tensor of shape ``(c_out, c_in)``.

    Raises:
        AssertionError: if ``c_out != c_in``.
    """
    # Both sizes are reported; the previous message had a single placeholder
    # for two values, so the failure path itself crashed with a TypeError.
    assert c_out == c_in, "Not equal!! : %d - %d" % (c_out, c_in)

    ones = np.ones([c_out, c_in], dtype=np.float32)
    mask = np.triu(ones) if upper else np.tril(ones)
    return torch.from_numpy(mask)

def get_invertible_filter_mask(kernel_size, n_channels, upper=True):
    """Build the 0/1 mask for a triangular (invertible) square filter.

    The mask has shape ``(n_channels, n_channels, kernel_size, kernel_size)``
    and dtype ``torch.long``. With ``half = (kernel_size - 1) // 2``:

    * ``upper`` truthy: the first ``half`` kernel rows and columns are zeroed.
    * ``upper`` falsy: the last ``half`` kernel rows and columns are zeroed.

    In both cases the tap at ``[half, half]`` is replaced by the channel mask
    from ``get_linear_ar_mask`` with the matching orientation.

    Returns:
        tuple: ``(mask, half)``.
    """
    half = (kernel_size - 1) // 2

    # Kernel rows/columns that are switched off entirely.
    if upper:
        dropped = slice(None, half)
    else:
        dropped = slice(-half, None)

    mask = torch.ones(n_channels, n_channels, kernel_size, kernel_size,
                      dtype=torch.long)
    mask[:, :, dropped, :] = 0
    mask[:, :, :, dropped] = 0

    # The centre tap mixes channels through a triangular matrix.
    centre = get_linear_ar_mask(n_channels, n_channels,
                                upper=upper if upper else False)
    mask[:, :, half, half] = centre
    return mask, half

def convmatrix2d(kernel, n_channels, input_h, input_w):
    """Expand a convolution kernel into its dense matrix form.

    Args:
        kernel: tensor indexed as (out_channel, in_channel, kH, kW).
        n_channels (int): number of input channels; the final reshape also
            uses it as the number of output channels.
        input_h (int): height of the input the matrix applies to.
        input_w (int): width of the input the matrix applies to.

    Returns:
        A 2-D tensor with ``out_h * out_w * n_channels`` rows and
        ``n_channels * input_h * input_w`` columns, where
        ``out_h = input_h - kH + 1`` and ``out_w = input_w - kW + 1``.
        Each row holds the kernel placed at one output position.
    """
    k_h = kernel.shape[2]
    k_w = kernel.shape[3]
    out_h = input_h - k_h + 1
    out_w = input_w - k_w + 1

    # One copy of the input grid per (out_channel, out_row, out_col).
    dense = torch.zeros(kernel.shape[0], out_h, out_w,
                        n_channels, input_h, input_w)

    for flat_pos in range(out_h * out_w):
        top, left = divmod(flat_pos, out_w)
        dense[:, top, left, :, top:top + k_h, left:left + k_w] = kernel

    n_rows = out_h * out_w * n_channels
    n_cols = n_channels * input_h * input_w
    return dense.reshape(n_rows, n_cols)

def conv2d_tri_inv_filter_lower(y, w, d, img_w, n_channels):
    """Invert ``y = K x`` for a lower-triangular convolution by forward substitution.

    Entries of ``x`` are solved one flat index at a time, in increasing order;
    each step subtracts the contribution of already-solved entries and divides
    by the diagonal weight.

    Args:
        y (tensor): shape ``(batch_size, N)``. Flat index ``n`` is treated as
            channel ``n % n_channels`` of pixel ``n // n_channels``, i.e. the
            channels of one pixel are adjacent.
        w (tensor): filter indexed as ``w[c_out, c_in, row, col]``. Only rows
            and columns ``0 .. d-1`` are read, and ``w[c, c, d-1, d-1]`` is
            the diagonal weight. In the equivalent form ``y = K x`` the matrix
            ``K`` must be LOWER TRIANGULAR. ``w`` is not modified.
        d (int): effective (non-zero) width of the filter.
        img_w (int): image width in pixels, without padding.
        n_channels (int): number of input and output channels.

    Returns:
        tensor: ``x`` with the shape of ``y`` and dtype ``torch.float``.
    """
    _, N = y.shape
    x = torch.zeros_like(y, dtype=torch.float)
    jumps_up = 0
    dc = d*n_channels
    imgwc = img_w*n_channels

    # If the smallest |diagonal weight| is below 1, scale the filter up so it
    # becomes 1; the solution is scaled back before returning.
    normalizer = max(1, 1 / w[:,:,d-1,d-1].diag().abs().min().clone().detach())
    # Out of place: the caller's filter must not be rescaled as a side effect.
    w = w * normalizer

    for n in range(N):
        c = n % n_channels
        w_out = w[c,:]  # weights producing output channel c: (c_in, row, col)
        solved_x_weighted_sum = 0.
        if n > 0:
            pix_out = n // n_channels

            # Already-solved entries on the current image row that fall
            # inside the filter window, ending just before n.
            tmp = (n % imgwc)
            if tmp < dc:
                n_closest = tmp
            else:
                n_closest = dc - n_channels + c

            n_closest_indices = torch.arange(n-n_closest, n)

            # The same window extended to the end of the current pixel, then
            # shifted up by whole image rows (imgwc entries per row) for up
            # to d - 1 rows above.
            n_trailing_base_indices = torch.arange(n-n_closest, n - c + n_channels)
            jumps_up = pix_out // img_w if jumps_up < d - 1 else d - 1
            list_block_indices = [n_trailing_base_indices - imgwc*jump for jump in range(1, jumps_up+1)[::-1]]

            # will contain <= d index lists, ordered from top filter row down
            if n_closest > 0:
                list_block_indices.append(n_closest_indices)

            init_row = d - jumps_up - 1
            for rel_row, index_list in enumerate(list_block_indices):
                solved_x = x[:,index_list]
                channel_list = index_list % n_channels
                n_moves = len(index_list) // n_channels
                row = init_row + rel_row
                # Column offsets count back from the filter's last column;
                # the bottom filter row includes the current pixel itself.
                if row == d - 1:
                    offsets = range(0, n_moves + 1)[::-1]
                else:
                    offsets = range(0, n_moves)[::-1]
                # Match x's dtype so the matmul below is well-defined; this
                # no longer depends on any module-level tensor.
                solved_x_weights = torch.tensor([w_out[c_in, row, d - offsets[i//n_channels] - 1]
                                                 for i, c_in in enumerate(channel_list)], dtype=x.dtype)
                solved_x_weighted_sum += torch.matmul(solved_x, solved_x_weights)

        x[:, n] = (y[:, n] - solved_x_weighted_sum) / w_out[c, d-1, d-1]

    x *= normalizer
    return x

def conv2d_tri_inv_filter_upper(z, w):
    zs = z.shape
    batchsize, n_channels, height, width = zs
    ksize = w.shape[0]
    kcenter = 1

    # Subtract bias term.

    diagonal = np.diag(w[kcenter, kcenter, :, :])
    # print(diagonal[np.argsort(diagonal)])
    factor = 1./np.min(diagonal)
    factor = max(1, factor)
    factor = 1.
    is_upper = True

    print('factor is', factor)
    # print('w is', w.transpose(3, 2, 0, 1))

    x_np = np.zeros(zs)
    z_np = np.array(z, dtype='float64')
    w_np = np.array(w, dtype='float64')

    w_np *= factor

    def filter2image(j, i, m, k):
        m_ = (m - kcenter)
        k_ = (k - kcenter)
        return j+k_, i+m_

    def in_bound(idx, lower, upper):
        return (idx >= lower) and (idx < upper)

    def reverse_range(n, reverse):
        if reverse:
            return range(n)
        else:
            return reversed(range(n))

    counter = 0
    for b in range(batchsize):
        for j in reverse_range(height, is_upper):
            for i in reverse_range(width, is_upper):
                for c_out in reverse_range(n_channels, not is_upper):
                    for c_in in range(n_channels):
                        for k in range(ksize):
                            for m in range(ksize):
                                if k == kcenter and m == kcenter and \
                                        c_in == c_out:
                                    continue

                                j_, i_ = filter2image(j, i, m, k)

                                if not in_bound(j_, 0, height):
                                    continue

                                if not in_bound(i_, 0, width):
                                    continue

                                x_np[b, c_out, j, i] -= \
                                    w_np[c_out, c_in, k, m] \
                                    * x_np[b, c_in, j_, i_]

                    # Compute value for x
                    counter += 1
                    x_np[b, c_out, j, i] += z_np[b, c_out, j, i]
                    x_np[b, c_out, j, i] /= \
                        w_np[ c_out, c_out, kcenter, kcenter]

    x_np = x_np * factor
    print(counter)

    return x_np.astype('float32')

# ---------------------------------------------------------------------------
# Driver script: build a masked (upper-triangular) kernel, apply it to a random
# batch with zero padding, then time conv2d_tri_inv_filter_upper on the result.
# Everything below runs at import time and defines module-level names.
# ---------------------------------------------------------------------------
kernel_size = 3
assert (kernel_size - 1) % 2 == 0, "kernel size (dxd) should have d uneven"
n_channels = 3
image_height = 128
image_width = 128

# `pad` is (kernel_size - 1) // 2, returned alongside the mask.
# mask_w2 (the lower-triangular mask) is built but not used by the active code below.
mask_w1, pad = get_invertible_filter_mask(kernel_size, n_channels, upper=True)
mask_w2, _ = get_invertible_filter_mask(kernel_size, n_channels, upper=False)

# Padded image size; only referenced by the disabled code further down.
new_width = image_width+2*pad
new_height = image_height+2*pad

# Random test batch of 100 images.
image_orig = torch.randn((100, n_channels, image_height, image_width))
batch_size = image_orig.size(0)
# padding left, right, top, bottom of a N x C x H x W tensor
pad_w1 = nn.ConstantPad2d((pad, pad, pad, pad), 0)
image = pad_w1(image_orig)
# Random kernel with mean 10, then masked so only the "upper" taps survive.
kernel = torch.zeros(n_channels, n_channels, kernel_size,
                     kernel_size).normal_() + 10
kernel*=mask_w1
# Boost the tap at [1, 1]; that is the centre tap only while kernel_size == 3.
kernel[:,:,1,1] *= 10
# --- Disabled debugging code (kept for reference) --------------------------
# It builds the dense convolution matrix K, strips the rows/columns that
# belong to padding, and compares K @ image against the conv2d output `a`.
# K = convmatrix2d(kernel, n_channels, new_height, new_width)

# tmp = []
# remove_j_interleaves = torch.tensor([list(range(new_width-pad+(new_width*k),
#                                                 new_width+(new_width*k) + pad)) for k in
#                                      range(new_height*n_channels)]).flatten()
# remove_j_blocks = torch.tensor([list(range(new_height*new_width*c,
#                                            new_height*new_width*c + new_width*pad)) +
#                                 list(range(new_height*new_width*(c+1) -
#                                            new_width*pad, new_height*new_width*(c+1))) for
#                                 c in range(n_channels)]).flatten()
# for i, row in enumerate(K):
#     for j, element in enumerate(row):
#         # remove blocks of padded columns
#         if j not in remove_j_blocks:
#             if (not j % new_width == 0) and (j not in remove_j_interleaves):
#                 tmp.append(element)

# tmp = torch.Tensor(tmp).reshape(image_height*image_width*n_channels,
#                                 image_height*image_width*n_channels)

# tmp = tmp[alternate_channel_index,:]
# tmp = tmp[:,alternate_channel_index]
# plt.imshow(tmp)
# plt.show()
#b = K @ image.flatten()
# c = tmp @ image_orig.flatten()[alternate_channel_index]

# print("B",(a - b.view(a.shape)).abs().sum())
# print("C", (a.view(c.shape)[alternate_channel_index] - c).abs().sum())

# Index permutation from channel-interleaved order (pixel 0 ch 0, pixel 0 ch 1,
# ...) to the channel-major flat position i + H*W*c of a single image.
# Only the disabled code above refers to it.
alternate_channel_index = torch.Tensor([i+(image_height*image_width)*c
                                        for i in range(image_height*image_width)
                                        for c in range(n_channels)]).flatten().long()
# Forward pass: the input was padded by `pad` on every side, so `a` has the
# same spatial size as image_orig.
a = torch.nn.functional.conv2d(image, kernel)
a = a.clone()
a = a.contiguous()
# Time the inverse. NOTE(review): x_recon is never compared with image_orig,
# so this measures run time only, not reconstruction quality.
t = time.time()
x_recon = conv2d_tri_inv_filter_upper(a, kernel)

print("Seconds", time.time() -t)
