from numba import jit
import numpy as np

# consider tricks from here https://numba.pydata.org/numba-doc/dev/user/performance-tips.html

###########################################################

# Inverse functions

###########################################################

@jit(nopython=True)
def inv_conv2d_tri_upper(y, w):
    """ Inverts y from an emerging convolution (upper triangular case)

    Solves y = Kx for x by back substitution: positions are visited from the
    last row / last column / last channel towards the first, so every value
    of x that the current output depends on has already been computed.

    The forward operation that is inverted is, for offsets di, dj >= 0
    (out-of-range positions of x count as zero):

        y[b, co, i, j] = sum_{ci, di, dj} w[co, ci, kc + di, kc + dj] * x[b, ci, i + di, j + dj]

    where kc is the kernel center, and at the center tap (di == dj == 0) only
    channels ci >= co contribute. Entries of w outside this pattern are never
    read.

    Args:

    y (NumPy Array): batch, channels, H, W, array to be inverted
    w (NumPy Array): c_out, c_in, kH, kW, convolution filter used in the emerging context. Must
                     satisfy in_channel==out_channel when looking at the equivalent convolution
                     form y=Kx, K must be UPPER TRIANGULAR. The kernel center is derived from kW
                     only and reused for the rows (assumes a square, odd-sized kernel). The
                     center taps w[c, c, kc, kc] are used as divisors and must be non-zero.

    Returns:

    x = inv(y), same shape and dtype as y

    """
    batch_size, _, ih, iw = y.shape
    n_channels = w.shape[0]
    kw = w.shape[3]
    x = np.zeros_like(y)
    d = (kw + 1) // 2
    kcenter = d - 1
    init_row = kcenter + 1

    for b in range(batch_size):
        for _i in range(ih):
            i = ih - _i - 1
            # Number of already solved rows below row i that the kernel reaches.
            # Recomputed for every row so that no state leaks from one batch
            # element into the next one.
            n_jumps = min(_i, kcenter)
            for _j in range(iw):
                j = iw - _j - 1
                # Number of columns in [j, iw) covered by the kernel.
                j_slice = min(iw - j, d)
                for _c_out in range(n_channels):
                    c_out = n_channels - 1 - _c_out

                    for c_in in range(n_channels):
                        if c_in <= c_out:
                            # x[b, c_in, i, j] is not solved yet (or is the
                            # unknown itself): start one column to the right.
                            init_col_x = j + 1
                            init_col_w = kcenter + 1
                            j_slice_closest = j_slice - 1
                        else:
                            # Channels above c_out are already solved at (i, j).
                            init_col_x = j
                            init_col_w = kcenter
                            j_slice_closest = j_slice

                        # Contribution of the current row.
                        if j_slice_closest > 0:
                            x[b, c_out, i, j] -= x[b, c_in, i, init_col_x:init_col_x+j_slice_closest].dot(w[c_out, c_in, kcenter, init_col_w:init_col_w+j_slice_closest])

                        # Contribution of the rows below the current one.
                        if n_jumps > 0:
                            x[b, c_out, i, j] -= np.sum(x[b, c_in, i+1:i+1+n_jumps, j:j+j_slice]*w[c_out, c_in, init_row:init_row+n_jumps, kcenter:kcenter+j_slice])

                    x[b, c_out, i, j] += y[b, c_out, i, j]
                    x[b, c_out, i, j] /= w[c_out, c_out, kcenter, kcenter]

    return x

@jit(nopython=True)
def inv_conv2d_tri_lower(y, w):
    """ Inverts y from an emerging convolution (lower triangular case)

    Solves y = Kx for x by forward substitution: positions are visited from
    the first row / first column / first channel towards the last, so every
    value of x that the current output depends on has already been computed.

    The forward operation that is inverted is, for offsets di, dj >= 0
    (out-of-range positions of x count as zero):

        y[b, co, i, j] = sum_{ci, di, dj} w[co, ci, kc - di, kc - dj] * x[b, ci, i - di, j - dj]

    where kc is the kernel center, and at the center tap (di == dj == 0) only
    channels ci <= co contribute. Entries of w outside this pattern are never
    read.

    Args:

    y (NumPy Array): batch, channels, H, W, array to be inverted
    w (NumPy Array): c_out, c_in, kH, kW, convolution filter used in the emerging context. Must
                     satisfy in_channel==out_channel when looking at the equivalent convolution
                     form y=Kx, K must be LOWER TRIANGULAR. The kernel center is derived from kW
                     only and reused for the rows (assumes a square, odd-sized kernel). The
                     center taps w[c, c, kc, kc] are used as divisors and must be non-zero.

    Returns:

    x = inv(y), same shape and dtype as y

    """
    batch_size, _, ih, iw = y.shape
    n_channels = w.shape[0]
    kw = w.shape[3]
    x = np.zeros_like(y)
    d = (kw + 1) // 2
    kcenter = d - 1

    for b in range(batch_size):
        for i in range(ih):
            # Number of already solved rows above row i that the kernel reaches.
            n_jumps = min(i, kcenter)
            for j in range(iw):
                # Number of columns in [0, j] covered by the kernel.
                j_slice = min(j + 1, d)
                for c_out in range(n_channels):

                    for c_in in range(n_channels):
                        if c_in < c_out:
                            # Channels below c_out are already solved at (i, j):
                            # the slice ends on (and includes) column j.
                            end_col_x = j + 1
                            end_col_w = kcenter + 1
                            j_slice_closest = j_slice
                        else:
                            # x[b, c_in, i, j] is not solved yet (or is the
                            # unknown itself): stop one column to the left.
                            end_col_x = j
                            end_col_w = kcenter
                            j_slice_closest = j_slice - 1

                        # Contribution of the current row.
                        if j_slice_closest > 0:
                            x[b, c_out, i, j] -= x[b, c_in, i, end_col_x-j_slice_closest:end_col_x].dot(w[c_out, c_in, kcenter, end_col_w-j_slice_closest:end_col_w])

                        # Contribution of the rows above the current one.
                        if n_jumps > 0:
                            x[b, c_out, i, j] -= np.sum(x[b, c_in, i-n_jumps:i, j+1-j_slice:j+1]*w[c_out, c_in, kcenter-n_jumps:kcenter, kcenter+1-j_slice:kcenter+1])

                    x[b, c_out, i, j] += y[b, c_out, i, j]
                    x[b, c_out, i, j] /= w[c_out, c_out, kcenter, kcenter]

    return x
