from numba import jit, prange

import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# Print floating-point tensors with 10 digits of precision.
torch.set_printoptions(precision=10)
from inverse_numba import inv_conv2d_tri_upper

def get_linear_ar_mask(c_out, c_in, upper=True):
    """Return a square triangular 0/1 mask as a float32 torch tensor.

    Args:
        c_out: number of rows; must equal ``c_in``.
        c_in: number of columns.
        upper: if truthy, return the upper-triangular mask (ones on and above
            the diagonal); otherwise the lower-triangular mask (ones on and
            below the diagonal).

    Returns:
        torch.Tensor of shape (c_out, c_in) and dtype float32.

    Raises:
        AssertionError: if ``c_out != c_in``.
    """
    # Explicit raise (same exception type as before) so the check survives
    # `python -O`; the message needs one placeholder per argument.
    if c_out != c_in:
        raise AssertionError("Not equal!! : %d - %d" % (c_out, c_in))

    ones = np.ones([c_out, c_in], dtype=np.float32)
    # np.triu/np.tril keep the float32 dtype of `ones`.
    mask = np.triu(ones) if upper else np.tril(ones)
    return torch.from_numpy(mask)

def get_invertible_filter_mask(kernel_size, n_channels, upper=True):
    """Build a 0/1 mask restricting a square conv kernel to a triangular support.

    Args:
        kernel_size: spatial side length of the kernel.
        n_channels: number of channels (used for both kernel channel axes).
        upper: if truthy, zero the first ``centre`` rows and columns (keeping
            the bottom-right part from the centre onward); otherwise zero the
            last ``centre`` rows and columns. In both cases the centre tap gets
            the matching triangular channel mask from ``get_linear_ar_mask``.

    Returns:
        (mask, centre): a torch.long tensor of shape
        (n_channels, n_channels, kernel_size, kernel_size), and
        centre = (kernel_size - 1) // 2.
    """
    centre = (kernel_size - 1) // 2
    mask = torch.ones([n_channels, n_channels, kernel_size, kernel_size],
                      dtype=torch.long)

    # NOTE(review): when centre == 0, slice(-0, None) selects the whole axis,
    # so the lower mask zeroes every tap except the centre one.
    band = slice(None, centre) if upper else slice(-centre, None)
    mask[:, :, band, :] = 0
    mask[:, :, :, band] = 0

    # Triangular mixing between channels at the spatial centre.
    mask[:, :, centre, centre] = get_linear_ar_mask(n_channels, n_channels,
                                                    upper=bool(upper))
    return mask, centre

def convmatrix2d(kernel, n_channels, input_h, input_w):
    """Return the dense matrix of a stride-1, unpadded 2-D cross-correlation.

    Multiplying the result by one sample's input flattened channel-major as
    (C, H, W) gives that sample's flattened ``conv2d(input, kernel)`` output,
    ordered (out_channel, row, col).

    Args:
        kernel: 4-D tensor (out_channels, n_channels, kh, kw).
        n_channels: number of input channels; must match ``kernel.size(1)``.
        input_h: input height.
        input_w: input width.

    Returns:
        Tensor of shape (out_channels * out_h * out_w,
        n_channels * input_h * input_w), where out_h = input_h - kh + 1 and
        out_w = input_w - kw + 1. It has the kernel's dtype and device.
    """
    out_channels, _, kh, kw = kernel.shape
    out_h = input_h - kh + 1
    out_w = input_w - kw + 1

    # m[o, i, j, c, y, x] = weight of input pixel (c, y, x) in output (o, i, j).
    # new_zeros keeps the kernel's dtype/device instead of forcing float32 CPU.
    m = kernel.new_zeros(out_channels, out_h, out_w, n_channels,
                         input_h, input_w)
    for i in range(out_h):
        for j in range(out_w):
            m[:, i, j, :, i:i + kh, j:j + kw] = kernel

    # Row count uses the kernel's output channels (not n_channels), so
    # non-square channel mappings reshape correctly.
    return m.reshape(out_channels * out_h * out_w,
                     n_channels * input_h * input_w)

# --- Experiment configuration -------------------------------------------
batch_size = 10
kernel_size = 3
n_channels = 3
image_height = 1280
image_width = 1280
# Spatial index of the kernel centre.
kcenter = (kernel_size - 1) // 2

# Triangular-support kernel masks; `pad` is (kernel_size - 1) // 2.
# NOTE(review): mask_w2 is not used anywhere below in this script.
mask_w1, pad = get_invertible_filter_mask(kernel_size, n_channels, upper=True)
mask_w2, _ = get_invertible_filter_mask(kernel_size, n_channels, upper=False)

new_width = image_width + 2 * pad
new_height = image_height + 2 * pad

image_orig = torch.randn((batch_size, n_channels, image_height, image_width))
# Zero-pad `pad` pixels on the left, right, top and bottom of the
# N x C x H x W tensor; for an odd kernel_size, a valid conv2d of the padded
# image then has the original H x W spatial size.
pad_w1 = nn.ConstantPad2d((pad, pad, pad, pad), 0)
image = pad_w1(image_orig)
# Standard-normal samples shifted by +10, restricted to mask_w1's support.
kernel = torch.zeros(n_channels, n_channels, kernel_size,
                     kernel_size).normal_() + 10
kernel *= mask_w1
# Scale the centre taps by 10 (presumably to keep the triangular system well
# conditioned — confirm). Index with kcenter rather than a hard-coded 1 so the
# centre is still the tap being scaled when kernel_size changes.
kernel[:, :, kcenter, kcenter] *= 10
t = time.time()
#K = convmatrix2d(kernel, n_channels, new_height, new_width)
# NOTE: while the line above is disabled, this times an empty section.
print("Make K:", time.time() - t)

# t = time.time()
# tmp = []
# remove_j_interleaves = torch.tensor([list(range(new_width-pad+(new_width*k),
#                                                 new_width+(new_width*k) + pad)) for k in
#                                      range(new_height*n_channels)]).flatten()
# remove_j_blocks = torch.tensor([list(range(new_height*new_width*c,
#                                            new_height*new_width*c + new_width*pad)) +
#                                 list(range(new_height*new_width*(c+1) -
#                                            new_width*pad, new_height*new_width*(c+1))) for
#                                 c in range(n_channels)]).flatten()
# for i, row in enumerate(K):
#     for j, element in enumerate(row):
#         # remove blocks of padded columns
#         if j not in remove_j_blocks:
#             if (not j % new_width == 0) and (j not in remove_j_interleaves):
#                 tmp.append(element)

# print("Make T:", time.time() - t)
# tmp = torch.Tensor(tmp).reshape(image_height*image_width*n_channels,
#                                 image_height*image_width*n_channels)

# Gather index converting a channel-major (C, H*W) flattening into a
# pixel-major (H*W, C) one: entry p * n_channels + c holds p + H*W*c.
# Exact int64 arange arithmetic replaces a ~H*W*C-element Python list routed
# through a float32 tensor, which was slow and would silently round indices
# above 2**24 for larger images.
alternate_channel_index = (
    torch.arange(image_height * image_width).unsqueeze(1)
    + (image_height * image_width) * torch.arange(n_channels).unsqueeze(0)
).flatten()

# tmp = tmp[alternate_channel_index,:]
# tmp = tmp[:,alternate_channel_index]

# b = K @ image.flatten()
# c = tmp @ image_orig.flatten()[alternate_channel_index]


# Forward pass: unpadded conv2d of the zero-padded image with the masked kernel.
a = torch.nn.functional.conv2d(image, kernel)
# Explicit copy of the conv output.
a = a.clone()
# Disabled: variant that interleaves channels via alternate_channel_index.
#a = a.reshape(batch_size, -1)[:,alternate_channel_index].contiguous()
# Hand the output and the kernel to the inverse routine as float64 NumPy arrays.
np_a = np.array(a, dtype='float64')
np_kernel = np.array(kernel, dtype='float64')
# The first call is deliberately left untimed — presumably a warm-up so numba
# JIT compilation of inv_conv2d_tri_upper is excluded from the timing below
# (confirm in inverse_numba). Only the second call is timed.
x_recon = torch.from_numpy(inv_conv2d_tri_upper(np_a, np_kernel).astype('float32'))
t = time.time()
x_recon = torch.from_numpy(inv_conv2d_tri_upper(np_a, np_kernel).astype('float32'))
print("Seconds", time.time() -t)
# Reconstruction error against the unpadded original: mean of sqrt(diff**2)
# (the mean absolute error), then the maximum squared error.
# NOTE(review): assumes x_recon has image_orig's shape — confirm in inverse_numba.
print(((image_orig - x_recon)**2).sqrt().mean())
print(((image_orig - x_recon)**2).max())
# Disabled: same error metrics for the channel-interleaved variant.
#print(((image_orig.view(batch_size,-1)[:,alternate_channel_index] - x_recon.view(batch_size,-1))**2).sqrt().mean())
#print(((image_orig.view(batch_size,-1)[:,alternate_channel_index] - x_recon.view(batch_size,-1))**2).max())

# t = time.time()
# x_recon = torch.triangular_solve(a.T, tmp)
# print("Seconds tri solve", time.time() -t)
# print(((image_orig.view(batch_size,-1)[:,alternate_channel_index] - x_recon[0].view(batch_size,-1))**2).sqrt().mean())

# print("B",(a - b[alternate_channel_index]).abs().mean())
# print("C", (a - c).abs().mean())
# plt.imshow(tmp)
# plt.show()
