import numpy as np

# Generate a random integer matrix product CRef = A @ B and dump A, B and CRef
# as C++ ap_int array initialisers into tb_var.h (which #includes MatrixMul.h).
#
# The IW/WW/OW names inside the emitted strings are C++ identifiers, presumably
# defined in MatrixMul.h; the Python values below only drive data generation
# here, so keep them in sync with the header by hand.

HO = 16   # HO * WO = number of rows of A and RefC (presumably output height)
WO = 16   # (presumably output width)
CI = 8    # A has 3 * 3 * CI columns and B has 3 * 3 * CI rows
CO = 8    # B and RefC have CO columns
IW = 4    # A entries lie in [-2**(IW-1), 2**(IW-1) - 1]; emitted as ap_int<IW>
WW = 4    # B entries lie in [-2**(WW-1), 2**(WW-1) - 1]; emitted as ap_int<WW>
OW = 10   # RefC is emitted as ap_int<OW> (this Python value is not used below)
PI = 4    # A columns / B rows are packed in groups of 9 * PI
PO = 4    # B columns are packed in groups of PO

# NOTE(review): |CRef| can reach 3 * 3 * CI * 2**(IW-1) * 2**(WW-1) (= 4608
# with the values above), well outside a signed 10-bit range; confirm the
# header's OW, or that wraparound in ap_int<OW> is intended.

# The packed layouts below require whole groups; otherwise columns would be
# silently dropped while the emitted array dimensions no longer match the data.
if CI % PI or CO % PO:
    raise ValueError("CI must be a multiple of PI and CO a multiple of PO")

# Use an explicit 64-bit dtype so the reference products and sums cannot
# overflow (np.int was just an alias of the builtin int and was removed in
# NumPy 1.24).
A = np.round((np.random.rand(HO * WO, 3 * 3 * CI) - 0.5) * 2**IW).astype(np.int64)
B = np.round((np.random.rand(3 * 3 * CI, CO) - 0.5) * 2**WW).astype(np.int64)

# Rounding can yield +2**(IW-1) (resp. WW), which does not fit a signed
# IW-bit integer; clamp it to the largest representable value.
A[A == 2**(IW - 1)] = 2**(IW - 1) - 1
B[B == 2**(WW - 1)] = 2**(WW - 1) - 1

CRef = np.matmul(A, B)

print(A)
print(B)
print(CRef)


def _c_initializer(values):
    """Render a scalar or (nested) array as a C brace initializer.

    Scalars become their decimal text; each array axis adds one level of
    braces, e.g. ``[[1, -2], [3, 4]]`` -> ``"{{1,-2},{3,4}}"``.
    """
    if np.ndim(values) == 0:
        return str(values)
    return "{" + ",".join(_c_initializer(v) for v in values) + "}"


# Packed A, shape [HO*WO][CI/PI][PI*9]:
#   A_mem[j][i][PI*k + pi] = A[j][PI*9*i + PI*k + pi]
# The columns of A are already in (i, k, pi) order, so a reshape suffices.
A_mem = A.reshape(HO * WO, CI // PI, 9 * PI)

# Packed B, shape [CI/PI][CO/PO][PI*PO*9], entries in (po, k, pi) order:
#   B_mem[i][j][9*PI*po + PI*k + pi] = B[PI*9*i + PI*k + pi][PO*j + po]
B_mem = (
    B.reshape(CI // PI, 9 * PI, CO // PO, PO)   # axes (i, PI*k + pi, j, po)
    .transpose(0, 2, 3, 1)                      # axes (i, j, po, PI*k + pi)
    .reshape(CI // PI, CO // PO, PO * 9 * PI)
)

A_str = ("ap_int<IW> A[{}][{}][{}]=".format(HO * WO, CI // PI, PI * 9)
         + _c_initializer(A_mem) + ";\n")
B_str = ("ap_int<WW> B[{}][{}][{}]=".format(CI // PI, CO // PO, PI * PO * 9)
         + _c_initializer(B_mem) + ";\n")
CRef_str = ("ap_int<OW> RefC[{}][{}]=".format(HO * WO, CO)
            + _c_initializer(CRef) + ";\n")

print(A_str)
print(B_str)
print(CRef_str)

head_str = "#include \"MatrixMul.h\"\n"

print(head_str)

# Write the testbench header; the context manager closes the file even if a
# write fails.
with open('tb_var.h', 'w') as infile:
    infile.write(head_str)
    infile.write(A_str)
    infile.write(B_str)
    infile.write(CRef_str)
