import numpy as np


# Number of readout rows: a is (D, T), d is (D, N), and solve rewrites the
# first D rows of the recurrent matrix. Must not exceed N, because the leading
# (D, D) block of d is inverted.
D = 5
# Sequence length: number of inputs fed to run_net and number of columns of a.
# Must not exceed N (w_from_f inverts the leading (T, T) block of g).
T = 8
# Hidden-state size: base_w is (N, N) and b has N entries.
N = 10
# When True, get_f_column and solve assert their intermediate results.
DEBUG = True


def is_close(a, b):
    """Return True when ``a`` and ``b`` agree element-wise.

    Both the relative and the absolute tolerance are 1e-2, which is loose
    enough to absorb the float64 round-trips done by ``inverse``.
    """
    return np.allclose(a, b, rtol=1e-02, atol=1e-02)


def inverse(m):
    """Invert the square matrix ``m`` and return the result as long double.

    The inversion itself runs in float64 (``np.linalg`` does not accept
    long-double input), so the result only carries float64 precision even
    though it is widened back to ``np.longdouble``.
    """
    as_double = m.astype(np.float64)
    inverted = np.linalg.inv(as_double)
    return inverted.astype(np.longdouble)


def get_f_column(w_bot, d, a, prev):
    """Compute one state column ``f_col`` from the previous one.

    The column is split into an ``upper`` part (the first ``k`` entries, with
    ``k = d.shape[0]``) and a ``lower`` part. The lower part is fully
    determined by the fixed bottom rows of the recurrent matrix applied to
    ``prev``; the upper part is then solved so that the readout
    ``d @ f_col`` equals ``a``.

    Args:
        w_bot: Fixed bottom rows of the recurrent matrix, shape (N - k, N).
        d: Readout matrix, shape (k, N); its leading (k, k) block must be
            invertible.
        a: Target readout for this column, shape (k,).
        prev: State vector of length N that this column is obtained from by
            one application of the recurrent matrix.

    Returns:
        Vector of length N whose last N - k entries equal ``w_bot @ prev``
        and whose readout ``d @ f_col`` reproduces ``a``.

    Raises:
        ValueError: If the row count of ``w_bot`` and the shape of ``d`` do
            not add up to a consistent state size.
    """
    # Derive the split from the readout itself rather than the module-level D,
    # so the function stays correct for any shape of d.
    k = d.shape[0]
    if k + w_bot.shape[0] != d.shape[1]:
        raise ValueError(
            f"w_bot has {w_bot.shape[0]} rows but d of shape {d.shape} "
            f"implies {d.shape[1] - k}"
        )

    lower = w_bot @ prev
    # Contribution of the already-determined lower entries to the readout.
    additive = d[:, k:] @ lower
    result = a - additive
    # Solve d[:, :k] @ upper = result for the free upper entries.
    upper = inverse(d[:, :k]) @ result
    f_col = np.concatenate((upper, lower))

    if DEBUG:
        assert is_close(d[:, :k] @ upper + d[:, k:] @ lower, a)
        assert is_close(d @ f_col, a)

    return f_col


def w_from_f(base_w, f, g):
    """Build a recurrent matrix whose top ``D`` rows map ``g`` onto ``f``.

    Only the first ``T`` columns of the first ``D`` rows are free; every other
    entry is copied from ``base_w``. The free entries are chosen so that
    ``w[:D] @ g == f[:D]``.

    Args:
        base_w: Starting recurrent matrix, shape (N, N).
        f: Target state matrix with one column per time step, shape (N, T).
        g: Source state matrix, shape (N, T); its leading (T, T) block must
            be invertible.

    Returns:
        (N, N) matrix equal to ``base_w`` except in the block ``[:D, :T]``.
    """
    steps = g.shape[1]  # T: one column of g per time step
    fixed = base_w[:D, steps:]  # top-row entries kept from base_w
    # Part of f[:D] already produced by the fixed columns.
    additive = fixed @ g[steps:, :]
    # Solve free @ g[:steps] = f[:D] - additive for all D rows at once; the
    # (T, T) inverse is loop-invariant, so it is computed a single time.
    free = (f[:D] - additive) @ inverse(g[:steps, :])
    top = np.concatenate((free, fixed), axis=1)
    return np.concatenate((top, base_w[D:, :]))


def solve(base_w, a, d, b):
    """Return a recurrent matrix ``w`` realizing the readouts in ``a``.

    The returned ``w`` satisfies ``d @ w^(T - i) @ b ≈ a[:, i]`` for every
    time step ``i``. The state columns ``f[:, i] = w^(T - i) @ b`` are built
    backwards starting from ``b``, and ``w`` is then fitted so that
    ``w @ g == f``, where ``g`` is ``f`` shifted by one step.
    """
    lower_rows = base_w[D:]
    columns = []
    current = b
    # Walk from the last time step to the first; each new column is one more
    # application of w to the previous one.
    for step in reversed(range(T)):
        current = get_f_column(lower_rows, d, a[:, step], current)
        columns.append(current)
    columns.reverse()  # columns[i] is now the state for time step i

    f = np.column_stack(columns)
    if DEBUG:
        assert is_close(d @ f, a)

    # g[:, i] = f[:, i + 1], with b closing the chain as the last column.
    g = np.column_stack(columns[1:] + [b])
    w = w_from_f(base_w, f, g)

    if DEBUG:
        for i in range(T):
            power = np.linalg.matrix_power(w, T - i)
            assert is_close(power @ b, f[:, i])
            assert is_close(d @ power @ b, a[:, i])

    return w


def run_net(w, b, d, x):
    """Run a linear recurrent network over the input sequence ``x``.

    The hidden state starts at zero and is updated as ``h = w @ h + b * x_t``
    for every element of ``x``. The returned readout is ``d @ w @ h``, i.e.
    one extra application of ``w`` after the last input.

    Args:
        w: Square recurrent matrix, shape (n, n).
        b: Input vector, shape (n,).
        d: Readout matrix with n columns.
        x: Iterable of scalar inputs, one per time step.

    Returns:
        The readout vector ``d @ w @ h``.
    """
    # Size the state from w rather than the module-level N so the function
    # works for any hidden size.
    h = np.zeros(w.shape[0])
    for x_t in x:
        h = w @ h + b * x_t
    return d @ w @ h


def get_random(*size):
    """Return a long-double array of the given shape, uniform on [0, 0.5)."""
    samples = np.random.rand(*size)
    scaled = samples.astype(np.longdouble) * 0.5
    return scaled


def main():
    """Randomized end-to-end check of ``solve``.

    Builds a random problem and solves for ``w``. It then verifies that only
    the permitted entries changed and, on random inputs, that the network
    output matches the linear map ``a @ x``. Failures surface as
    ``AssertionError``.
    """
    # NOTE(review): no RNG seed is set, so every run draws a different problem.
    base_w = get_random(N, N)
    a = get_random(D, T)
    b = get_random(N)
    d = get_random(D, N)
    w = solve(base_w, a, d, b)

    # Rows D: and columns T: must be carried over from base_w untouched.
    rows_kept = w[D:, :] == base_w[D:, :]
    cols_kept = w[:, T:] == base_w[:, T:]
    assert rows_kept.all()
    assert cols_kept.all()

    for _ in range(T):
        x = np.random.randn(T).astype(np.longdouble)
        assert is_close(a @ x, run_net(w, b, d, x))


if __name__ == '__main__':
    # Run the randomized end-to-end check only when executed as a script.
    main()
