import numpy as np

# ---- Problem size -----------------------------------------------------------
# A and the reference CRef have HO*WO rows; B and CRef have CI*3*3 columns;
# CO is the shared dimension that the reference sums over.  The names read
# like output height/width and input/output channels of a 3x3 convolution
# (assumption -- not established by this script).
HO, WO = 16, 16
CI, CO = 16, 16

# ---- Bit widths of the emitted ap_uint<> arrays -----------------------------
# IW: entries of A, WW: entries of B, OW: entries of RefC, GW: entries of S.
IW, WW, OW, GW = 4, 4, 13, 2

# ---- Packing factors of the emitted arrays ----------------------------------
# PI splits the CO dimension (A, S and the rows of B); PO splits the CI*3*3
# columns of B.
PI, PO = 4, 4

# Number of shift groups.  Each S value selects one of G accumulators, so G
# must be at least 2**GW.
G = 4

# Random operands, drawn uniformly over the full unsigned range of the
# ap_uint<> width each one is emitted with:
#   A: (HO*WO, CO)    values in [0, 2**IW - 1]
#   B: (CO, CI*3*3)   values in [0, 2**WW - 1]
#   S: (CO,)          shift-group index per row k of B, in [0, 2**GW - 1]
# Drawing the unsigned value directly (instead of rounding a centred float,
# clamping the top value and re-offsetting) keeps every value equally
# likely; the old scheme under-sampled 0 and over-sampled the maximum
# (for the 2-bit S: probabilities 1/8, 1/4, 1/4, 3/8).
# Mind the data types so the reference computation cannot overflow: NumPy's
# default integer dtype (64-bit on most platforms) is used.
# The legacy global RNG is used, so calling np.random.seed(...) before this
# point makes the generated test vectors reproducible.
A = np.random.randint(0, 2**IW, size=(HO * WO, CO), dtype=int)
B = np.random.randint(0, 2**WW, size=(CO, CI * 3 * 3), dtype=int)
S = np.random.randint(0, 2**GW, size=CO, dtype=int)

# Golden reference.  For every output element (i, j), the CO products
# A[i, k] * B[k, j] are accumulated into G partial sums chosen by S[k].
# Partial sum g is scaled by 2**-g and floored, and the G results are added:
#     CRef[i, j] = sum_g floor( sum_{k : S[k] == g} A[i, k] * B[k, j] * 2**-g )
# Selecting the rows/columns of group g turns the inner sum into one matrix
# product, so the reference costs G matmuls instead of a Python loop over
# every (i, j, k).  All intermediate values are exact integers.

# S indexes G accumulators; fail loudly rather than silently dropping a
# group or wrapping a negative index.
if S.size and (S.min() < 0 or S.max() >= G):
    raise ValueError(f"shift-group indices in S must lie in [0, {G - 1}]")

CRef = np.zeros([HO * WO, CI * 3 * 3])
for g in range(G):
    in_group = S == g
    # Integer partial sums for the CO entries in group g; an empty group
    # yields an all-zero (HO*WO, CI*3*3) matrix.
    partial = A[:, in_group] @ B[in_group, :]
    CRef = CRef + np.floor(partial * 2.0 ** (-g))

print(A)
print(B)
print(CRef)

CRef = CRef.astype(int)

# RefC is emitted as ap_uint<OW>; make sure no reference value would be
# truncated by that width.
if CRef.size and CRef.max() >= 2**OW:
    raise ValueError(
        f"reference value {CRef.max()} does not fit in ap_uint<{OW}>"
    )

def _c_initializer(values):
    """Render an int or a (nested) list of ints as a C brace initializer.

    Nesting follows the list structure.  Elements are comma-separated with
    no spaces, e.g. ``[[1, 2], [3, 4]] -> "{{1,2},{3,4}}"``.
    """
    if isinstance(values, list):
        return "{" + ",".join(_c_initializer(v) for v in values) + "}"
    return str(values)


# The packed C arrays need whole PI/PO groups; refuse to silently drop
# trailing elements.
if CO % PI or (CI * 3 * 3) % PO:
    raise ValueError("CO must be divisible by PI and CI*3*3 by PO")

# A[j][i][pi] = A[j, PI*i + pi]
A_packed = A.reshape(HO * WO, CO // PI, PI)
# B[i][j][po*PI + pi] = B[PI*i + pi, j*PO + po]: each innermost vector holds
# a PI x PO tile of B, stored column by column.
B_packed = (
    B.reshape(CO // PI, PI, CI * 3 * 3 // PO, PO)
    .transpose(0, 2, 3, 1)
    .reshape(CO // PI, CI * 3 * 3 // PO, PO * PI)
)
# S[i][pi] = S[PI*i + pi]
S_packed = S.reshape(CO // PI, PI)

# The width names IW/WW/OW/GW are emitted literally; they are presumably
# defined in MatrixMul.h (assumption -- not visible here).
A_str = (f"ap_uint<IW> A[{HO * WO}][{CO // PI}][{PI}]="
         + _c_initializer(A_packed.tolist()) + ";\n")
B_str = (f"ap_uint<WW> B[{CO // PI}][{CI * 3 * 3 // PO}][{PI * PO}]="
         + _c_initializer(B_packed.tolist()) + ";\n")
CRef_str = (f"ap_uint<OW> RefC[{HO * WO}][{CI * 3 * 3}]="
            + _c_initializer(CRef.tolist()) + ";\n")
S_str = (f"ap_uint<GW> S[{CO // PI}][{PI}]="
         + _c_initializer(S_packed.tolist()) + ";\n")

# Echo the generated initializers to stdout.
print(A_str)
print(B_str)
print(CRef_str)
print(S_str)

head_str = "#include \"MatrixMul.h\"\n"

print(head_str)

# Write the include line followed by all initializers to tb_var.h.  The
# `with` block guarantees the file is flushed and closed even if a write
# fails; the content is plain ASCII, so the explicit encoding does not change
# the bytes written.
with open('tb_var.h', 'w', encoding='utf-8') as out_file:
    out_file.write(head_str)
    out_file.write(A_str)
    out_file.write(B_str)
    out_file.write(CRef_str)
    out_file.write(S_str)
