import numpy as np

import utils


def encode_matrices(Ap, Bp, codebook):
    """Encode the block stacks ``Ap`` and ``Bp`` by evaluating them as polynomials.

    The leading axis of each input is treated as a stack of coefficient blocks,
    and every encoded block is that polynomial evaluated at a codebook-derived
    point:

        Aenc[i, j] = sum_{k < m} Ap[k] * codebook[i, j] ** k
        Benc[i]    = sum_{l < n} Bp[l] * t[i] ** l,
            where t = utils.cheby_poly(codebook[:, 0], order=m)

    Args:
        Ap: array of shape (m, dim1_split, dim2).
        Bp: array of shape (n, dim2, dim3_split). Its ``dim2`` must match Ap's.
        codebook: 2-D array of shape (n_groups, >= m). Only the first ``m``
            columns feed the A encoding; only column 0 feeds the B encoding.

    Returns:
        Tuple ``(Aenc, Benc)`` of float64 arrays with shapes
        (n_groups, m, dim1_split, dim2) and (n_groups, dim2, dim3_split).

    Raises:
        ValueError: if an input has the wrong rank, if Ap and Bp disagree on
            ``dim2``, or if ``codebook`` has fewer than ``m`` columns.
    """
    m, dim1_split, dim2 = Ap.shape
    n, dim2_b, dim3_split = Bp.shape
    n_groups, n_cols = codebook.shape

    # Without this check a mismatched inner dimension could be silently
    # broadcast (e.g. Ap dim2 == 1) instead of being reported.
    if dim2 != dim2_b:
        raise ValueError(
            f"inner dimensions disagree: Ap has dim2={dim2}, Bp has dim2={dim2_b}"
        )
    if n_cols < m:
        raise ValueError(f"codebook needs at least m={m} columns, got {n_cols}")

    # c1[i, j, k] = codebook[i, j] ** k  -> shape (n_groups, m, m)
    c1 = np.stack([codebook[:, :m] ** k for k in range(m)], axis=-1)

    # Evaluate cheby_poly once rather than once per power l.
    # Assumes it returns one value per codebook row, i.e. shape (n_groups,),
    # which the per-row indexing of the encoding requires -- TODO confirm.
    t = utils.cheby_poly(codebook[:, 0], order=m)
    # c2[i, l] = t[i] ** l  -> shape (n_groups, n)
    c2 = np.stack([t ** l for l in range(n)], axis=-1)

    # Contract the block-stack axis in one vectorized step instead of looping
    # over (i, j) in Python. astype keeps the float64 result dtype that the
    # np.zeros output buffers used to impose.
    Aenc = np.einsum("ijk,kab->ijab", c1, Ap).astype(np.float64, copy=False)
    Benc = np.einsum("il,lbc->ibc", c2, Bp).astype(np.float64, copy=False)

    return Aenc, Benc
