#include "MatrixMul.h"

// #define TEST

#ifdef TEST
#include <stdio.h>
#include <iostream>
using namespace std;
#endif

/*
 * MatrixMul: streaming 3x3-window multiply-accumulate kernel.
 *
 * For each of the HO * WO iterations of the outer loop:
 *   1. CI/PI bundles are read from stream A and buffered in win_in_array.
 *   2. For each of the CO/PO output blocks, the products of the buffered
 *      inputs and the weights B[ic_block][oc] are accumulated over
 *      ic_block, the 9 window positions (k) and the PI packed lanes (pi),
 *      yielding PO sums of width OW.
 *   3. The PO sums are packed into one TOUT word (sum po occupies bits
 *      [OW * po, OW * (po + 1) - 1]) and written to stream C.
 *
 * Packing layout read by this function:
 *   - win_in.data[k]         : lane pi in bits [IW * pi, IW * (pi + 1) - 1]
 *   - win_w.data[po * 9 + k] : lane pi in bits [WW * pi, WW * (pi + 1) - 1]
 *
 * A : input stream, CI/PI reads per outer iteration.
 * B : weight bundles indexed [ic_block][oc].
 * C : output stream, CO/PO writes per outer iteration.
 */
void MatrixMul(hls::stream<BundleT<9, TIN> >& A, BundleT<9 * PO, TW> B[CI/PI][CO/PO], hls::stream<TOUT>& C)
{
	typedef BundleT<9, TIN> TIN_PI;
	typedef BundleT<9 * PO, TW> TW_PI;

	// Input bundles of the current spatial element, reused for every oc.
	TIN_PI win_in_array[CI/PI];

	TIN_PI win_in;
	// Partition the input window itself; the k/pi loops below are fully
	// unrolled and read every element of win_in.data.
#pragma HLS ARRAY_PARTITION variable=win_in.data complete dim=0
	TW_PI win_w;
#pragma HLS ARRAY_PARTITION variable=win_w.data complete dim=0

	// Running sums over ic_block for the PO outputs of the current block.
	ap_int<OW> out_buffer[PO];
#pragma HLS ARRAY_PARTITION variable=out_buffer cyclic factor=4 dim=1

	// Partial sums contributed by a single ic_block.
	ap_int<OW> out_temp[PO];
#pragma HLS ARRAY_PARTITION variable=out_temp cyclic factor=4 dim=1

	TOUT out;

	ap_int<OW> temp;
	ap_int<WW> temp_w;
	ap_int<IW> temp_in;

	// GEMM Blocks in C
	for (int j = 0; j < HO * WO; j++)
	{
		// Buffer A in one spatial element
loop_copy_in:
		for (int ic_block = 0; ic_block < CI/PI; ic_block++)
		{
#pragma HLS UNROLL
			win_in_array[ic_block] = A.read();
		}
loop_oc:
		for (int oc = 0; oc < CO/PO; oc++)
		{
#pragma HLS PIPELINE
			// Start every output block from zero.
loop_clear_out_buffer:
			for (int po = 0; po < PO; po++)
			{
				out_buffer[po] = 0;
			}
			// Compute one Block in C (ic_block controls windows sliding)
loop_ic_block:
			for (int ic_block = 0; ic_block < CI/PI; ic_block++)
			{
				win_in = win_in_array[ic_block];
				win_w = B[ic_block][oc];
loop_po:
				for (int po = 0; po < PO; po++)
				{
#pragma HLS UNROLL
					temp = 0;
					// Compute in one kernel spatially
					for (int k = 0; k < 9; k++)
					{
#pragma HLS UNROLL
						for (int pi = 0; pi < PI; pi++)
						{
#pragma HLS UNROLL
							// Extract lane pi of the weight and input words.
							temp_w = win_w.data[po * 9 + k](WW * (pi + 1) - 1, WW * pi);
							temp_in = win_in.data[k](IW * (pi + 1) - 1, IW * pi);
							temp += temp_in * temp_w;
#ifdef TEST
							cout << "w: " << temp_w << " " << "temp_in: " << temp_in << "temp: " << temp << "\n";
#endif
						}
					}
					out_temp[po] = temp;
#ifdef TEST
					cout << "temp: " << out_temp[po] << "\n";
#endif
				} // po
				// Fold this ic_block's partial sums into the running totals.
				for (int po = 0; po < PO; po++)
				{
#pragma HLS UNROLL
					out_buffer[po] += out_temp[po];
				}
#ifdef TEST
				for (int po = 0; po < PO; po++)
				{
					cout << "out_buffer " << po << ": " << out_buffer[po] << "\n";
				}
#endif
			} // ic_block
			// Pack the PO sums into a single output word.
			for (int po = 0; po < PO; po++)
			{
#pragma HLS UNROLL
				out(OW * (po + 1) - 1, OW * po) = out_buffer[po];
			}
			C.write(out);
		} // oc
	} // j
}

/*
 * top: array-interface wrapper around MatrixMul.
 *
 * Steps, in order:
 *   1. Pack A into 9-element bundles and push them all into A_stream.
 *      A[j][i][PI * k + pi] goes to data[k], bits [IW * pi, IW * (pi + 1) - 1].
 *   2. Pack B into B_bundle. B[i][j][PI * 9 * po + PI * k + pi] goes to
 *      B_bundle[i][j].data[po * 9 + k], bits [WW * pi, WW * (pi + 1) - 1].
 *   3. Call MatrixMul(A_stream, B_bundle, C_stream).
 *   4. Read HO * WO * (CO / PO) words from C_stream and unpack each into PO
 *      entries of C: bits [OW * po, OW * (po + 1) - 1] go to C[i][PO * j + po].
 *
 * NOTE(review): every A bundle is written before MatrixMul is called and
 * every C word is read only after it returns. No depth is set on either
 * stream here, so the FIFO depths presumably need to cover that when this
 * function is synthesized rather than only C-simulated -- TODO confirm.
 */
void top(ap_int<IW> A[HO * WO][CI / PI][PI * 9], ap_int<WW> B[CI / PI][CO / PO][PI * PO * 9], ap_int<OW> C[HO * WO][CO])
{

	hls::stream<BundleT<9, TIN> > A_stream;
#pragma HLS STREAM variable=A_stream dim=1

	hls::stream<TOUT> C_stream;
	// C_stream is a stream, not an array, so it takes the same STREAM
	// directive as A_stream (ARRAY_MAP only applies to arrays).
#pragma HLS STREAM variable=C_stream dim=1

	BundleT<9 * PO, TW> B_bundle[CI/PI][CO/PO];

	BundleT<9, TIN> A_stream_temp;

	// Pack and stream the input, one bundle per (j, i).
	for (int j = 0; j < HO * WO; j++)
	{
		for (int i = 0; i < CI/PI; i++)
		{
			for (int k = 0; k < 9; k++)
			{
				for (int pi = 0; pi < PI; pi++)
				{
					// Order: k, pi
					A_stream_temp.data[k](IW * (pi + 1) - 1, IW * pi) = A[j][i][PI * k + pi];
				}
			}
			A_stream.write(A_stream_temp);
		}
	}

	// Pack the weights into one bundle per (i, j).
	for (int i = 0; i < CI/PI; i++)
	{
		for (int j = 0; j < CO/PO; j++)
		{
			for (int po = 0; po < PO; po++)
			{
				for (int k = 0; k < 9; k++)
				{
					for (int pi = 0; pi < PI; pi++)
					{
						// Order: po, k, pi
						B_bundle[i][j].data[po * 9 + k](WW * (pi + 1) - 1, WW * pi) = B[i][j][PI * 9 * po + PI * k + pi];
					}
				}
			}
		}
	}

	MatrixMul(A_stream, B_bundle, C_stream);

	// Drain the result stream and unpack PO outputs from each word.
	TOUT C_buffer;
	for (int i = 0; i < HO*WO; i++)
	{
		for (int j = 0; j < CO/PO; j++)
		{
			C_buffer = C_stream.read();
			for (int po = 0; po < PO; po++)
			{
				C[i][PO * j + po] = C_buffer(OW * (po + 1) - 1, OW * po);
#ifdef TEST
				cout << "C[i][PO * j + po]: " << C[i][PO * j + po] << "\n";
#endif
			}
		}
	}

}
