import numpy as np
import scipy
from utils import randomize_locations, run_concurrent_solver, rand, G


# Problem dimensions (reported by main() as n, d and T).
HIDDEN_SIZE = 4                          # n: length of the hidden state h
OUTPUT_SIZE = HIDDEN_SIZE // 4           # d: number of readout rows
INPUT_SIZE = HIDDEN_SIZE                 # T: number of scalar input steps
F_DIM = HIDDEN_SIZE - OUTPUT_SIZE
WEIGHT_VARS = OUTPUT_SIZE * INPUT_SIZE   # free weight entries solved for in each system

# Solver configuration forwarded to run_concurrent_solver().
TRIES = 1000
INITIALIZATIONS_PER_TRY = 400
DEBUG = False                            # when True, solutions are validated by checker()


class Equations:
    """Square nonlinear system whose roots are RNN weights reproducing a teacher.

    The unknown vector ``x`` is split into its first ``WEIGHT_VARS`` entries
    (free entries of the top ``k`` rows of the recurrent matrix ``W``) and the
    remaining entries (coordinates of ``f`` in the null space of a linear
    system); ``main`` sizes ``x`` as ``k * INPUT_SIZE``.

    ``f`` is a ``(HIDDEN_SIZE, INPUT_SIZE)`` matrix built so that, for any
    coordinates (assuming the linear system is consistent), ``d @ f == a`` and
    ``w_bot @ g == f[k:, :]``, where ``g`` is ``f`` shifted one column left with
    ``b`` appended as the last column. Calling the instance returns the
    remaining residual ``w_top @ g - f[:k, :]``. At a root ``W @ g == f``,
    i.e. ``f[:, i] == W^(INPUT_SIZE - i) @ b``.

    Args:
        k: number of top rows of ``W`` that carry the free weight variables;
            the other ``HIDDEN_SIZE - k`` rows (``w_bot``) are drawn once and
            stay fixed.

    ``rand`` and ``randomize_locations`` come from ``utils``; their exact
    distributions are not visible here.
    """

    def __init__(self, k):
        self.k = k
        # The draw order below is kept as-is in case ``rand`` shares a global
        # RNG; ``_create_f_params`` must run last since it reads b, d, a, w_bot.
        self.b = rand(HIDDEN_SIZE)
        self.d = rand(OUTPUT_SIZE, HIDDEN_SIZE, order=HIDDEN_SIZE)
        self.a = create_a_from_teacher()
        self.w_params = self._create_w_params()
        self.w_bot = rand(HIDDEN_SIZE - k, HIDDEN_SIZE, order=HIDDEN_SIZE)
        self.f_params = self._create_f_params()

    def _create_f_linear_matrix(self):
        """Build the matrix of the linear constraints on vec(f).

        ``f`` is flattened column-major, so ``f[r, c]`` is unknown number
        ``HIDDEN_SIZE * c + r``. Rows, in this order:

        * for each column ``i < INPUT_SIZE - 1``:
          ``w_bot @ f[:, i + 1] - f[k:, i]`` (target 0);
        * ``-f[k:, INPUT_SIZE - 1]`` (target ``-w_bot @ b``);
        * for each column ``i``: ``d @ f[:, i]`` (target ``a[:, i]``).

        The row order must match ``_create_f_linear_results``.
        """
        n_bot = HIDDEN_SIZE - self.k
        n_unknowns = HIDDEN_SIZE * INPUT_SIZE
        bot_rows = np.arange(n_bot)
        parts = []
        for i in range(INPUT_SIZE - 1):
            part = np.zeros((n_bot, n_unknowns))
            # w_bot acts on column i + 1 of f ...
            part[:, HIDDEN_SIZE * (i + 1):HIDDEN_SIZE * (i + 2)] = self.w_bot
            # ... and must equal the bottom rows f[k:, i] of column i.
            part[bot_rows, HIDDEN_SIZE * i + self.k + bot_rows] = -1
            parts.append(part)
        # Last column of g is b, so f[k:, -1] is pinned to w_bot @ b.
        last = np.zeros((n_bot, n_unknowns))
        last[bot_rows, HIDDEN_SIZE * (INPUT_SIZE - 1) + self.k + bot_rows] = -1
        parts.append(last)
        for i in range(INPUT_SIZE):
            part = np.zeros((OUTPUT_SIZE, n_unknowns))
            part[:, HIDDEN_SIZE * i:HIDDEN_SIZE * (i + 1)] = self.d
            parts.append(part)
        return np.vstack(parts)

    def _create_f_linear_results(self):
        """Right-hand side of the system built by ``_create_f_linear_matrix``."""
        n_bot = HIDDEN_SIZE - self.k
        return np.concatenate((
            np.zeros(n_bot * (INPUT_SIZE - 1)),
            (-self.w_bot @ self.b).ravel(),
            # column-major flattening lists a[:, 0], a[:, 1], ... as the d rows expect
            self.a.reshape(INPUT_SIZE * OUTPUT_SIZE, order='F'),
        ))

    def _create_f_params(self):
        """Parameterize every solution of the linear constraints on vec(f).

        Returns:
            ``(basis, particular)``: ``basis`` is an orthonormal null-space
            basis of the constraint matrix and ``particular`` its minimum-norm
            least-squares solution, so ``basis @ x + particular`` satisfies the
            constraints for any ``x`` when the system is consistent.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.linalg`` is loaded on older SciPy releases.
        from scipy.linalg import null_space

        mat = self._create_f_linear_matrix()
        res = self._create_f_linear_results()
        return null_space(mat), np.linalg.pinv(mat) @ res

    def get_f(self, x):
        """Map null-space coordinates ``x`` to the ``(HIDDEN_SIZE, INPUT_SIZE)`` matrix f."""
        basis, particular = self.f_params
        return (basis @ x + particular).reshape(HIDDEN_SIZE, INPUT_SIZE, order='F')

    def _create_w_params(self):
        """Split the top ``k`` rows of W into free variables and random constants.

        Returns:
            ``(w_var_mat, const_mat)``: ``w_var_mat`` has shape
            ``(k * HIDDEN_SIZE, WEIGHT_VARS)`` and scatters the weight
            variables into the row-major flattening of the top rows;
            ``const_mat`` has shape ``(k, HIDDEN_SIZE)`` and holds the fixed
            entries.
        """
        w_const = rand(HIDDEN_SIZE * self.k - WEIGHT_VARS, order=HIDDEN_SIZE)
        # var / const are indexed as (row, col) pairs below.
        var, const = randomize_locations(n=HIDDEN_SIZE, k=self.k, r=WEIGHT_VARS, max_per_row=INPUT_SIZE)

        const_mat = np.zeros((self.k, HIDDEN_SIZE))
        for i, loc in enumerate(const):
            const_mat[loc[0], loc[1]] = w_const[i]

        w_var_mat = np.zeros((self.k * HIDDEN_SIZE, WEIGHT_VARS))
        for i, loc in enumerate(var):
            # row-major index, matching the C-order reshape in get_w_top
            w_var_mat[loc[0] * HIDDEN_SIZE + loc[1], i] = 1

        return w_var_mat, const_mat

    def get_w_top(self, x):
        """Assemble the top ``k`` rows of W from the weight variables ``x``."""
        w_var_mat, const_mat = self.w_params
        return (w_var_mat @ x).reshape(self.k, HIDDEN_SIZE) + const_mat

    def get_g(self, f):
        """Return f shifted one column left with ``b`` appended as the last column."""
        return np.concatenate((f[:, 1:], self.b.reshape(HIDDEN_SIZE, 1)), axis=1)

    def check_f_params(self):
        """Sanity-check that a random null-space point satisfies the f constraints.

        Raises:
            AssertionError: if ``d @ f != a`` or ``w_bot @ g != f[k:, :]``.
                Raised explicitly so the check also runs under ``python -O``.
        """
        x = rand((self.k - OUTPUT_SIZE) * INPUT_SIZE)
        f = self.get_f(x)
        if not np.isclose(self.d @ f, self.a).all():
            raise AssertionError('d @ f does not reproduce a')
        g = self.get_g(f)
        if not np.isclose(self.w_bot @ g, f[self.k:, :]).all():
            raise AssertionError('w_bot @ g does not reproduce the bottom rows of f')

    def get_w_top_f(self, x):
        """Split ``x`` into weight variables and f coordinates; return (w_top, f)."""
        x_w = x[:WEIGHT_VARS]
        x_f = x[WEIGHT_VARS:]
        return self.get_w_top(x_w), self.get_f(x_f)

    def get_w_f(self, x):
        """Return the full ``(HIDDEN_SIZE, HIDDEN_SIZE)`` W (top rows over w_bot) and f."""
        w_top, f = self.get_w_top_f(x)
        return np.vstack((w_top, self.w_bot)), f

    def __call__(self, x):
        """Residual ``w_top @ g - f[:k, :]`` flattened row-major to length ``k * INPUT_SIZE``."""
        w_top, f = self.get_w_top_f(x)
        g = self.get_g(f)
        return (w_top.dot(g) - f[:self.k, :]).reshape(self.k * INPUT_SIZE)


def create_a_from_teacher():
    """Draw a random teacher RNN ``(b, d, W)`` and return its output matrix ``a``.

    Column ``i`` of the result is ``d @ W^(INPUT_SIZE - i) @ b``, giving one
    column per input step (powers INPUT_SIZE, INPUT_SIZE - 1, ..., 1).
    ``rand`` comes from ``utils``; its distribution is not visible here.
    """
    # Keep the draw order (b, d, W) unchanged in case rand shares a global RNG.
    teacher_b = rand(HIDDEN_SIZE)
    teacher_d = rand(OUTPUT_SIZE, HIDDEN_SIZE, order=HIDDEN_SIZE)
    teacher_w = rand(HIDDEN_SIZE, HIDDEN_SIZE, order=HIDDEN_SIZE)
    columns = []
    for power in range(INPUT_SIZE, 0, -1):
        columns.append(teacher_d @ np.linalg.matrix_power(teacher_w, power) @ teacher_b)
    return np.column_stack(columns)


def check(e, x, inp=None, atol=1e-3):
    """Verify that the RNN recovered from solution ``x`` reproduces the teacher.

    Runs the recurrence ``h <- W @ h + b * u`` from ``h = 0`` over the scalar
    inputs ``u``, reads out ``d @ W @ h`` and compares it with ``a @ inp``.

    Args:
        e: an ``Equations``-like object providing ``b``, ``d``, ``a`` and
            ``get_w_f``.
        x: solver solution vector passed to ``e.get_w_f``.
        inp: input sequence; when omitted, ``rand(INPUT_SIZE)`` is drawn as
            before.
        atol: absolute tolerance of the comparison.

    Raises:
        AssertionError: if the simulated and target outputs differ. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    if inp is None:
        inp = rand(INPUT_SIZE)
    w, _ = e.get_w_f(x)
    h = np.zeros(w.shape[0])  # w is square with one row per hidden unit
    for u in inp:
        h = w.dot(h) + e.b * u
    s = e.d @ w.dot(h)
    s_real = e.a @ inp
    if not np.isclose(s, s_real, atol=atol).all():
        raise AssertionError(f'RNN output {s} does not match teacher output {s_real}')


def check_f(e, x, atol=1e-3):
    """Verify the f matrix recovered from solution ``x``.

    Checks that ``d @ f == a`` (default ``np.isclose`` tolerances) and that
    every column satisfies ``f[:, i] == W^(T - i) @ b`` within ``atol``,
    where ``T`` is the number of columns of ``f``.

    Args:
        e: an ``Equations``-like object providing ``b``, ``d``, ``a`` and
            ``get_w_f``.
        x: solver solution vector passed to ``e.get_w_f``.
        atol: absolute tolerance of the per-column power check.

    Raises:
        AssertionError: on the first violated condition. Raised explicitly
            (not via ``assert``) so the check survives ``python -O``.
    """
    w, f = e.get_w_f(x)
    if not np.isclose(e.d @ f, e.a).all():
        raise AssertionError('d @ f does not reproduce the teacher outputs a')
    # Equals INPUT_SIZE for an f produced by Equations.get_f.
    n_steps = f.shape[1]
    for i in range(n_steps):
        expected = np.linalg.matrix_power(w, n_steps - i) @ e.b
        if not np.isclose(expected, f[:, i], atol=atol).all():
            raise AssertionError(f'column {i} of f is not W^{n_steps - i} @ b')


def checker(e, x):
    """Run every debug validation of solution ``x`` against ``e``: check_f, then check."""
    for validation in (check_f, check):
        validation(e, x)


def main():
    """Print the run configuration and solve the system for each k.

    ``k`` ranges over ``OUTPUT_SIZE + 1 .. HIDDEN_SIZE``; the system for a
    given ``k`` gets an ``x0`` size of ``k * INPUT_SIZE``. What
    ``run_concurrent_solver`` does with these lives in ``utils``.
    """
    banner = (
        f'running with n={HIDDEN_SIZE}, d={OUTPUT_SIZE}, T={INPUT_SIZE}, g={G}, '
        f'{TRIES} tries with {INITIALIZATIONS_PER_TRY} initializations per try'
    )
    print(banner, flush=True)

    ks = range(OUTPUT_SIZE + 1, HIDDEN_SIZE + 1)
    x0_size = {}
    for k in ks:
        x0_size[k] = k * INPUT_SIZE

    # Solutions are only validated in debug mode.
    validate = None
    if DEBUG:
        validate = checker
    run_concurrent_solver(Equations, x0_size, validate, ks, TRIES, INITIALIZATIONS_PER_TRY)


if __name__ == '__main__':
    # Script entry point; importing the module does not start the solver.
    main()
