import math

from utils import randomize_locations, run_concurrent_solver, rand, G
import numpy as np


# Problem dimensions. In this file w1 is N x B (f is reshaped to (N, B) and
# compared with w2 @ w1), w2 is N x N and w3 is D x N, so w3 @ w2 @ w1 maps
# B-dimensional inputs to D-dimensional outputs with a hidden width of N.
N = 8
B = 4
D = 6
# Number of labelled samples; set equal to the input width B.
SAMPLES = B
# Total number of free weight variables. B * D equals D * SAMPLES, the number of
# output residuals DeepFF produces.
WEIGHT_VARS = B * D
# Per-layer share of the weight variables.
# NOTE(review): DeepFF consumes exactly 3 * WEIGHT_PER_LAYER weight entries, so
# WEIGHT_VARS should be divisible by 3.
WEIGHT_PER_LAYER = WEIGHT_VARS // 3
# Forwarded to run_concurrent_solver (and printed by main as "tries" and
# "initializations per try").
TRIES = 1000
INITIALIZATIONS_PER_TRY = 1000


class SubsetLearnedMatrix:
    """A ``d1``-row, ``d2``-column matrix whose top rows are partly driven by variables.

    The top ``k`` rows (``k`` clamped to ``d1``) form a block. ``vars`` positions
    in it, chosen by ``randomize_locations``, are filled from the vector passed to
    :meth:`get`. The remaining top-block positions hold random constants drawn at
    construction. The bottom ``d1 - k`` rows are fixed random values.
    """

    def __init__(self, k, d1, d2, vars, max_per_row=None, max_per_col=None):
        """Draw the fixed parts of the matrix and choose the variable locations.

        Args:
            k: requested number of top rows that may hold variables; clamped to ``d1``.
            d1: number of rows of the assembled matrix.
            d2: number of columns of the assembled matrix.
            vars: number of variable entries placed inside the top ``k x d2`` block.
            max_per_row: forwarded to ``randomize_locations``.
            max_per_col: forwarded to ``randomize_locations``.

        Raises:
            AssertionError: if ``vars`` does not fit in the (clamped) top block.
        """
        self.k = min(k, d1)
        # Validate against the clamped k. The variables must fit in the rows that
        # are actually allocated; otherwise the constant count
        # ``d2 * k - vars`` used below would be negative.
        assert vars <= self.k * d2, (
            f'vars={vars} does not fit in the top {self.k}x{d2} block')
        self.d1 = d1
        self.d2 = d2
        # Fixed bottom rows; presumably rand(r, c, ...) returns an r x c array.
        self._bot = rand(d1 - self.k, d2, order=d2)
        self._top_params = self._create_top_params(max_per_row, max_per_col, vars)

    def _create_top_params(self, max_per_row, max_per_col, vars):
        """Build the linear map that assembles the top block from the variables.

        Returns:
            ``(w_var_mat, const_mat)``. ``w_var_mat`` has shape ``(k * d2, vars)``
            and column ``i`` has a single 1 at the row-major flat index of the
            i-th variable location. ``const_mat`` has shape ``(k, d2)``, with the
            drawn constants at the constant locations and zeros elsewhere.
        """
        w_const = rand(self.d2 * self.k - vars, order=self.d2)
        var, const = randomize_locations(n=self.d2, k=self.k, r=vars, max_per_row=max_per_row,
                                         max_per_col=max_per_col)

        const_mat = np.zeros((self.k, self.d2))
        for i, loc in enumerate(const):
            const_mat[loc[0], loc[1]] = w_const[i]

        w_var_mat = np.zeros((self.k * self.d2, vars))
        for i, loc in enumerate(var):
            # Row-major flat index, matching the reshape(k, d2) in get().
            w_var_mat[loc[0] * self.d2 + loc[1], i] = 1

        return w_var_mat, const_mat

    def get(self, x):
        """Assemble the full matrix for the variable vector ``x`` (``vars`` entries).

        The top ``k x d2`` block (variables scattered into place, plus constants)
        is stacked on top of the fixed bottom rows.
        """
        w_var_mat, const_mat = self._top_params
        top = (w_var_mat @ x).reshape(self.k, self.d2) + const_mat
        return np.vstack((top, self._bot))


class DeepFF:
    """Residual function for fitting a three-layer linear network to labelled samples.

    The flat argument vector is laid out as ``[w1 vars | w2 vars | w3 vars | f]``.
    Each layer gets ``WEIGHT_PER_LAYER`` entries, and the tail is the ``N x B``
    matrix ``f`` that stands in for the product ``w2 @ w1``. Calling the object
    returns the residuals of ``w2 @ w1 = f`` and ``w3 @ f @ x = y``.
    """

    def __init__(self, k):
        budget = WEIGHT_PER_LAYER
        # The first layer is stored transposed (B rows, N columns).
        self._w1_T = SubsetLearnedMatrix(k, B, N, budget)
        # Only the square middle layer gets per-row / per-column placement limits.
        self._w2 = SubsetLearnedMatrix(k, N, N, budget, max_per_row=B, max_per_col=D)
        self._w3 = SubsetLearnedMatrix(k, D, N, budget)
        samples, labels = create_labeled_samples()
        self.x = samples
        self.y = labels

    def __call__(self, x):
        """Return the residual vector for the flat variable vector ``x``.

        The result is the flattened ``w2 @ w1 - f`` (``N * B`` entries) followed by
        the flattened ``w3 @ f @ self.x - self.y`` (``D * SAMPLES`` entries).
        """
        per_layer = WEIGHT_PER_LAYER
        first_vars = x[:per_layer]
        middle_vars = x[per_layer:2 * per_layer]
        last_vars = x[2 * per_layer:3 * per_layer]

        w1 = self._w1_T.get(first_vars).T
        w2 = self._w2.get(middle_vars)
        w3 = self._w3.get(last_vars)
        f = x[3 * per_layer:].reshape(N, B)

        hidden_residual = w2 @ w1 - f
        output_residual = w3 @ f @ self.x - self.y
        return np.concatenate((hidden_residual.reshape(N * B),
                               output_residual.reshape(D * SAMPLES)))


def randomize_layers():
    """Draw random ground-truth weights ``(w1, w2, w3)`` for the three layers.

    Presumably ``rand(r, c, ...)`` returns an ``r x c`` array. In that case the
    shapes are ``(N, B)``, ``(N, N)`` and ``(D, N)``.
    """
    first = rand(N, B, order=B)
    middle = rand(N, N, order=N)
    last = rand(D, N, order=D)
    return first, middle, last


def create_labeled_samples():
    """Return random inputs and their labels under freshly drawn ground-truth layers.

    The inputs come from ``rand(B, SAMPLES)``, presumably one sample per column.
    The labels are ``w3 @ w2 @ w1 @ inputs``.
    """
    inputs = rand(B, SAMPLES)
    layers = randomize_layers()
    w1, w2, w3 = layers
    labels = w3 @ w2 @ w1 @ inputs
    return inputs, labels


def main():
    """Run the concurrent solver for every candidate ``k`` at the configured sizes.

    Raises:
        ValueError: if ``WEIGHT_VARS`` cannot be split evenly over the three layers.
    """
    # DeepFF consumes exactly 3 * WEIGHT_PER_LAYER weight entries and reshapes the
    # rest of the vector into an N x B matrix. Presumably the solver passes
    # vectors of length num_vars (below); with a remainder, that reshape fails
    # deep inside the solver, so reject the configuration here.
    if WEIGHT_VARS % 3 != 0:
        raise ValueError(
            f'WEIGHT_VARS={WEIGHT_VARS} must be divisible by 3 (one share per layer)')
    print(
        f'running with n={N}, d={D}, b={B}, g={G}, m={SAMPLES}, {TRIES} tries with {INITIALIZATIONS_PER_TRY} '
        f'initializations per try', flush=True)
    min_k = math.ceil(WEIGHT_PER_LAYER / B)
    max_k = min(WEIGHT_PER_LAYER, N)
    # Weight variables for all three layers plus the N x B entries of f.
    num_vars = WEIGHT_VARS + N * B
    run_concurrent_solver(DeepFF, num_vars, None, range(min_k, max_k + 1), TRIES,
                          INITIALIZATIONS_PER_TRY)


# Run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
