import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
torch.set_printoptions(precision=10)

def get_linear_ar_mask(c_out, c_in, upper=True):
    """Build a square triangular 0/1 mask over (output channel, input channel).

    Args:
        c_out (int): number of output channels (rows of the mask).
        c_in (int): number of input channels (columns); must equal ``c_out``.
        upper (bool): if truthy, keep the upper triangle including the
            diagonal (entry ``[r, c]`` is 1 when ``c >= r``); otherwise keep
            the lower triangle including the diagonal.

    Returns:
        torch.Tensor: float32 tensor of shape ``(c_out, c_in)``.

    Raises:
        AssertionError: if ``c_out != c_in``.
    """
    # The message needs one placeholder per value; with a single "%d" the
    # formatting itself raised TypeError and hid the real assertion failure.
    assert c_out == c_in, "Not equal!! : %d - %d" % (c_out, c_in)

    ones = np.ones([c_out, c_in], dtype=np.float32)
    mask = np.triu(ones) if upper else np.tril(ones)
    return torch.from_numpy(mask)

def get_invertible_filter_mask(kernel_size, n_channels, upper=True):
    """Build the 0/1 mask for a square, channel-preserving convolution filter.

    With ``center = (kernel_size - 1) // 2``:

    * ``upper`` truthy: the first ``center`` kernel rows and columns are zeroed;
    * ``upper`` falsy: the last ``center`` kernel rows and columns are zeroed.

    In both cases the ``(center, center)`` tap is then overwritten with the
    triangular channel mask returned by ``get_linear_ar_mask``.

    Args:
        kernel_size (int): spatial size of the (square) kernel.
        n_channels (int): number of input and output channels.
        upper (bool): selects which side of the kernel is kept, see above.

    Returns:
        tuple: ``(mask, center)`` where ``mask`` is a ``torch.long`` tensor of
        shape ``(n_channels, n_channels, kernel_size, kernel_size)`` and
        ``center`` is the centre index computed above.
    """
    center = (kernel_size - 1) // 2
    mask = torch.ones(n_channels, n_channels, kernel_size, kernel_size,
                      dtype=torch.long)

    # One slice describes the band of rows/columns to blank out.  Note that
    # for center == 0 the "last center entries" slice is -0: == 0:, i.e. it
    # covers everything; the centre assignment below then refills the tap.
    blanked = slice(None, center) if upper else slice(-center, None)
    mask[:, :, blanked, :] = 0
    mask[:, :, :, blanked] = 0

    # The float triangular mask is cast to long on assignment.
    mask[:, :, center, center] = get_linear_ar_mask(n_channels, n_channels,
                                                    upper=upper)
    return mask, center

def convmatrix2d(kernel, n_channels, input_h, input_w):
    """Materialise a stride-1, unpadded 2-D cross-correlation as a dense matrix.

    The returned matrix ``K`` satisfies ``K @ x.flatten() == conv(x).flatten()``
    for a single image ``x`` of shape ``(n_channels, input_h, input_w)``, where
    ``conv`` slides ``kernel`` without flipping (the same convention as
    ``torch.nn.functional.conv2d``).

    Args:
        kernel (torch.Tensor): filter of shape ``(c_out, n_channels, kH, kW)``.
        n_channels (int): number of input channels.
        input_h (int): height of the (already padded) input image.
        input_w (int): width of the (already padded) input image.

    Returns:
        torch.Tensor: matrix of shape
        ``(c_out * out_h * out_w, n_channels * input_h * input_w)`` with
        ``out_h = input_h - kH + 1`` and ``out_w = input_w - kW + 1``. It has
        the dtype and device of ``kernel``.
    """
    c_out, k_h, k_w = kernel.size(0), kernel.size(2), kernel.size(3)
    out_h = input_h - k_h + 1
    out_w = input_w - k_w + 1

    # Allocate with the kernel's dtype/device so that float64 (or non-CPU)
    # kernels are not silently downcast into a default float32 CPU buffer.
    m = torch.zeros(c_out, out_h, out_w, n_channels, input_h, input_w,
                    dtype=kernel.dtype, device=kernel.device)

    # For output position (i, j) the kernel covers the input window whose
    # top-left corner is (i, j).
    for i in range(out_h):
        for j in range(out_w):
            m[:, i, j, :, i:i + k_h, j:j + k_w] = kernel

    # Rows are counted from the kernel's own output-channel dimension, so the
    # reshape is consistent with the buffer allocated above.
    return m.reshape(c_out * out_h * out_w, n_channels * input_h * input_w)

def conv2d_tri_inv_filter_lower(y, w, d, img_w, n_channels, alternate_channel_index):
    """Invert an emerging convolution whose matrix form is LOWER triangular.

    Solves ``y = K x`` for ``x`` by forward substitution, walking the flat
    index ``n`` upwards.  Entry ``n`` is treated as channel ``n % n_channels``
    of pixel ``n // n_channels`` (pixel-major, channel-interleaved order), and
    pixels are laid out row by row with ``img_w`` pixels per row.  Each unknown
    uses already-solved entries from the same image row (up to ``d - 1`` pixels
    back plus the lower channels of the current pixel) and from up to ``d - 1``
    image rows above.

    Args:
        y (torch.Tensor): convolution output flattened to ``(batch_size, N)``
            in the interleaved order described above.
        w (torch.Tensor): filter of shape ``(c_out, c_in, kH, kW)``.  Only taps
            with row and column index ``<= d - 1`` are read; ``w[:, :, d-1, d-1]``
            supplies the diagonal of ``K``.  The tensor is NOT modified.
        d (int): effective (non-zero) width of the filter; ``d - 1`` is the
            index of the centre tap.
        img_w (int): width of the unpadded image in pixels.
        n_channels (int): number of input and output channels.
        alternate_channel_index (torch.Tensor): accepted for interface
            compatibility; not used by this function.

    Returns:
        torch.Tensor: ``x`` with the shape of ``y`` and dtype ``torch.float``,
        in the same interleaved order.
    """
    _, N = y.shape
    x = torch.zeros_like(y, dtype=torch.float)
    jumps_up = 0
    dc = d * n_channels
    imgwc = img_w * n_channels

    # Rescale the system so the smallest diagonal magnitude is at least 1.
    # Solving (s * K) x' = y yields x' = x / s, hence the final `x *= s`.
    # A scaled copy is used so the caller's weights stay untouched.
    smallest_diag = w[:, :, d - 1, d - 1].diag().abs().min().detach()
    normalizer = max(1, 1 / smallest_diag)
    w = w * normalizer

    for n in range(N):
        c = n % n_channels
        w_out = w[c]  # (c_in, kH, kW) taps producing output channel c
        solved_x_weighted_sum = 0.
        if n > 0:
            pix_out = n // n_channels

            # Number of already-solved entries of the current image row that
            # lie inside the kernel footprint.
            tmp = n % imgwc
            n_closest = tmp if tmp < dc else dc - n_channels + c
            n_closest_indices = torch.arange(n - n_closest, n)

            # Same horizontal span, extended to the end of the current pixel;
            # shifted up by whole image rows it addresses the rows above.
            n_trailing_base_indices = torch.arange(n - n_closest, n - c + n_channels)
            jumps_up = pix_out // img_w if jumps_up < d - 1 else d - 1
            list_block_indices = [n_trailing_base_indices - imgwc * jump
                                  for jump in range(jumps_up, 0, -1)]
            if n_closest > 0:
                list_block_indices.append(n_closest_indices)

            init_row = d - jumps_up - 1
            for rel_row, index_list in enumerate(list_block_indices):
                solved_x = x[:, index_list]
                channel_list = index_list % n_channels
                n_moves = len(index_list) // n_channels
                row = init_row + rel_row
                # On the centre row the (partial) current pixel counts as one
                # extra horizontal step.
                last_offset = n_moves if row == d - 1 else n_moves - 1
                offsets = range(last_offset, -1, -1)
                cols = torch.tensor([d - offsets[i // n_channels] - 1
                                     for i in range(len(index_list))],
                                    dtype=torch.long)
                # Gather the taps from `w` itself (the original read a module
                # global named `kernel`) and match x's dtype for the matmul.
                solved_x_weights = w_out[channel_list, row, cols].to(x.dtype)
                solved_x_weighted_sum += torch.matmul(solved_x, solved_x_weights)

        x[:, n] -= solved_x_weighted_sum
        x[:, n] += y[:, n]
        x[:, n] /= w_out[c, d - 1, d - 1]

    x *= normalizer
    return x
def conv2d_tri_inv_filter_upper(y:torch.Tensor, w:torch.Tensor, d:int, img_w:int, n_channels:int,
                                channel_layout_to_pytoch_layout:torch.Tensor):
    """Invert an emerging convolution whose matrix form is UPPER triangular.

    Solves ``y = K x`` for ``x`` by back substitution, walking the flat index
    ``n`` from ``N - 1`` down to 0.  Entry ``n`` is treated as channel
    ``n % n_channels`` of pixel ``n // n_channels`` (pixel-major,
    channel-interleaved order), with ``img_w`` pixels per image row.  Each
    unknown uses already-solved entries to its right in the same image row and
    in up to ``d - 1`` image rows below.

    Args:
        y: convolution output flattened to ``(batch_size, N)`` in the
            interleaved order described above.
        w: filter of shape ``(c_out, c_in, kH, kW)`` with ``c_out == c_in``,
            ``kH == kW`` and an odd kernel size.  ``w[:, :, d-1, d-1]`` supplies
            the diagonal of ``K``.
        d: effective (non-zero) width of the filter; ``d - 1`` is used as the
            index of the centre tap.
        img_w: width of the unpadded image in pixels.
        n_channels: ignored; the channel count is taken from ``w.shape[0]``.
        channel_layout_to_pytoch_layout: accepted for interface compatibility;
            not used by this function.

    Returns:
        ``x`` with the shape and dtype of ``y``, in the same interleaved order.

    Raises:
        AssertionError: if ``w`` is not square in channels, not square
            spatially, or has an even kernel size.
    """
    _, N = y.shape
    out_channels, in_channels, kh, kw = w.shape
    assert out_channels == in_channels , "input channels and output channels has to be equal"
    assert kw == kh , "kernel must be dxd"
    assert (kw - 1) % 2 == 0, "kernel size (dxd) should have d uneven"
    n_channels = out_channels  # the filter, not the argument, is authoritative
    x = torch.zeros_like(y)
    jumps_down = 0
    dc = d * n_channels
    imgwc = img_w * n_channels
    kcenter = d - 1
    n_pixels = N // n_channels

    for n in range(N - 1, -1, -1):
        c_out = n % n_channels
        w_out = w[c_out]  # (c_in, kH, kW) taps producing output channel c_out
        if n < N - 1:
            reverse_c = n_channels - 1 - c_out

            # Number of already-solved entries to the right of n in the same
            # image row that lie inside the kernel footprint.
            tmp = (N - n - 1) % imgwc
            n_closest = tmp if tmp < dc else dc - n_channels + reverse_c

            # Image rows available below the current pixel, capped at the
            # kernel reach.  Once the cap has been hit it is kept for the
            # remaining (earlier) pixels.
            jumps_down = (n_pixels - (n // n_channels) - 1) // img_w if jumps_down < kcenter else kcenter
            init_row = kcenter + 1
            # Channel of flat entry n + 1.
            rel_out_channel = (c_out + 1) % n_channels

            for c_in in range(n_channels):
                reverse_cin = n_channels - 1 - c_in
                # reverse_cin >= 0, so this also implies n_closest > 0.
                if n_closest > reverse_cin:
                    # Distance from n + 1 to the first entry of channel c_in.
                    rel_idx = (c_in - rel_out_channel) % n_channels
                    x_indices = range(n + 1 + rel_idx, n + n_closest + 1, n_channels)
                    # Channels <= c_out first reappear in the NEXT pixel,
                    # i.e. one kernel column to the right of the centre.
                    rel_col = 1 if c_in <= c_out else 0
                    for i, x_idx in enumerate(x_indices):
                        col_idx = kcenter + rel_col + i
                        x[:, n] -= x[:, x_idx] * w_out[c_in, kcenter, col_idx]

                # Entries of channel c_in from the current pixel rightwards;
                # shifted down by whole image rows they address the rows below.
                leading_base_indices_cin = range(n - c_out + c_in, n + n_closest + 1, n_channels)
                for rel_row, jump in enumerate(range(1, jumps_down + 1)):
                    row = init_row + rel_row
                    for i, base_idx in enumerate(leading_base_indices_cin):
                        x_idx = base_idx + imgwc * jump
                        col_idx = kcenter + i
                        x[:, n] -= x[:, x_idx] * w_out[c_in, row, col_idx]

        x[:, n] += y[:, n]
        x[:, n] /= w_out[c_out, kcenter, kcenter]

    return x

# ---------------------------------------------------------------------------
# Experiment set-up (runs at import time): build a random image and a masked
# random kernel, then materialise the convolution as a dense matrix.
# ---------------------------------------------------------------------------
batch_size = 1
kernel_size = 3
n_channels = 3
image_height = 3
image_width = 25
# Index of the centre tap of the kernel.
kcenter = (kernel_size - 1 ) // 2

# `pad` is the centre index returned alongside the mask; it is used below as
# the amount of zero padding on every image border.
mask_w1, pad = get_invertible_filter_mask(kernel_size, n_channels, upper=True)
# NOTE(review): mask_w2 is not used by any live code in this part of the file.
mask_w2, _ = get_invertible_filter_mask(kernel_size, n_channels, upper=False)

# Spatial size of the image after padding.
new_width = image_width+2*pad
new_height = image_height+2*pad

image_orig = torch.randn((batch_size, n_channels, image_height, image_width))
# padding left, right, top, bottom of a N x C x H x W tensor
pad_w1 = nn.ConstantPad2d((pad, pad, pad, pad), 0)
image = pad_w1(image_orig)
# Random normal kernel, restricted to the taps allowed by the upper mask.
kernel = torch.zeros(n_channels, n_channels, kernel_size,
                     kernel_size).normal_()
kernel*=mask_w1
# Overwrite the centre tap with the identity over channels; the indexing
# expression is a view, so this modifies `kernel` in place.
torch.nn.init.eye_(kernel[:,:,kcenter,kcenter])
#kernel[:,:,1,1] *= 10

# Time the construction of the dense convolution matrix for the padded image.
# NOTE(review): K is only consumed by the commented-out checks further down.
t = time.time()
K = convmatrix2d(kernel, n_channels, new_height, new_width)
print("Make K:", time.time() - t)

# t = time.time()
# tmp = []
# remove_j_interleaves = torch.tensor([list(range(new_width-pad+(new_width*k),
#                                                 new_width+(new_width*k) + pad)) for k in
#                                      range(new_height*n_channels)]).flatten()
# remove_j_blocks = torch.tensor([list(range(new_height*new_width*c,
#                                            new_height*new_width*c + new_width*pad)) +
#                                 list(range(new_height*new_width*(c+1) -
#                                            new_width*pad, new_height*new_width*(c+1))) for
#                                 c in range(n_channels)]).flatten()
# for i, row in enumerate(K):
#     for j, element in enumerate(row):
#         # remove blocks of padded columns
#         if j not in remove_j_blocks:
#             if (not j % new_width == 0) and (j not in remove_j_interleaves):
#                 tmp.append(element)

# print("Make T:", time.time() - t)
# tmp = torch.Tensor(tmp).reshape(image_height*image_width*n_channels,
#                                 image_height*image_width*n_channels)

# Permutation from the flattened (channel, pixel) order to pixel-major,
# channel-interleaved order: position p * n_channels + c holds the flat index
# of channel c at pixel p.
alternate_channel_index = torch.Tensor([i+(image_height*image_width)*c
                                        for i in range(image_height*image_width)
                                        for c in range(n_channels)]).flatten().long()

# tmp = tmp[alternate_channel_index,:]
# tmp = tmp[:,alternate_channel_index]

# b = K @ image.flatten()
# c = tmp @ image_orig.flatten()[alternate_channel_index]


# Forward pass through the masked convolution, then time the inversion.
with torch.no_grad():
    a = torch.nn.functional.conv2d(image, kernel)
    a = a.clone()
    # Flatten per sample and reorder to the interleaved layout.
    a = a.reshape(batch_size, -1)[:,alternate_channel_index].contiguous()
    t = time.time()
    # int((kernel_size + 1) / 2) is passed as `d`, the effective filter width.
    x_recon = conv2d_tri_inv_filter_upper(a, kernel, int((kernel_size +1)/2),
                                          image_width, n_channels, alternate_channel_index)
print("Seconds", time.time() -t)
# sqrt of the squared difference is the absolute difference, so this prints
# the mean absolute reconstruction error against the original image.
print(((image_orig.view(batch_size,-1)[:,alternate_channel_index] - x_recon.view(batch_size,-1))**2).sqrt().mean())

# t = time.time()
# x_recon = torch.triangular_solve(a.T, tmp)
# print("Seconds tri solve", time.time() -t)
# print(((image_orig.view(batch_size,-1)[:,alternate_channel_index] - x_recon[0].view(batch_size,-1))**2).sqrt().mean())

# print("B",(a - b[alternate_channel_index]).abs().mean())
# print("C", (a - c).abs().mean())
# plt.imshow(tmp)
# plt.show()
