import numpy as np

# Matrix dimensions used below: A is (HO * WO) x CO and B is CO x (CI * 3 * 3).
# NOTE(review): the names suggest conv-layer output height/width and
# input/output channel counts with a 3x3 kernel -- confirm against MatrixMul.h.
HO = 16
WO = 16
CI = 16
CO = 16
# Bit widths of the unsigned operands: A values lie in [0, 2**IW - 1] and
# B values lie in [0, 2**WW - 1].
IW = 4
WW = 4
# OW is not used numerically in this script; the emitted header refers to the
# symbols IW / WW / OW by name (presumably defined by MatrixMul.h -- verify).
OW = 13
# Packing factors: how many consecutive elements are grouped into the
# innermost dimension of the emitted A and B arrays.
PI = 4
PO = 4


def _random_unsigned(rows, cols, bits):
    """Return a (rows, cols) int array of random values in [0, 2**bits - 1]."""
    half = 2 ** (bits - 1)
    # Rounding a uniform sample scaled to [-half, half) yields integers in
    # [-half, half]; the single out-of-range value +half is clamped down.
    signed = np.round((np.random.rand(rows, cols) - 0.5) * 2 ** bits).astype(int)
    signed[signed == half] = half - 1
    # Shift the signed range to an unsigned one.
    return signed + half


def _c_initializer(values):
    """Render a (nested) list of ints as a C brace initializer without spaces."""
    if isinstance(values, list):
        return "{" + ",".join(_c_initializer(item) for item in values) + "}"
    return str(values)


# Mind the data types so that nothing overflows: with the constants above the
# largest possible product entry is CO * (2**IW - 1) * (2**WW - 1) = 3600.
A = _random_unsigned(HO * WO, CO, IW)
B = _random_unsigned(CO, CI * 3 * 3, WW)

CRef = np.matmul(A, B)

print(A)
print(B)
print(CRef)

# Packed layouts written to the header:
#   A[row][g][pi]          = A[row][PI * g + pi]
#   B[g][c][po * PI + pi]  = B[PI * g + pi][c * PO + po]
# The reshapes raise ValueError if CO is not a multiple of PI or CI * 3 * 3 is
# not a multiple of PO, instead of silently dropping trailing elements.
A_packed = A.reshape(HO * WO, CO // PI, PI)
B_packed = (
    B.reshape(CO // PI, PI, CI * 3 * 3 // PO, PO)
    .transpose(0, 2, 3, 1)
    .reshape(CO // PI, CI * 3 * 3 // PO, PO * PI)
)

# tolist() converts to plain Python ints so str() gives bare decimal digits.
A_str = (
    "ap_uint<IW> A[{}][{}][{}]=".format(HO * WO, CO // PI, PI)
    + _c_initializer(A_packed.tolist())
    + ";\n"
)
B_str = (
    "ap_uint<WW> B[{}][{}][{}]=".format(CO // PI, CI * 3 * 3 // PO, PI * PO)
    + _c_initializer(B_packed.tolist())
    + ";\n"
)
CRef_str = (
    "ap_uint<OW> RefC[{}][{}]=".format(HO * WO, CI * 3 * 3)
    + _c_initializer(CRef.tolist())
    + ";\n"
)

print(A_str)
print(B_str)
print(CRef_str)

head_str = "#include \"MatrixMul.h\"\n"

print(head_str)

# The context manager closes the file even if one of the writes fails.
with open('tb_var.h', 'w') as infile:
    infile.write(head_str)
    infile.write(A_str)
    infile.write(B_str)
    infile.write(CRef_str)
