#include "MatrixMul.h"

// #define TEST

#ifdef TEST
#include <stdio.h>
#include <iostream>
using namespace std;
#endif


/*
 * Four cross dot products of two packed input vectors (A_0, A_1) with two
 * packed weight vectors (B_0, B_1), accumulated over the PI lanes.
 *
 * Two weights share one multiplication per input element:
 *   in * ((w0 << 18) + w1) = (in * w0) << 18  +  in * w1
 * The low field [IW+WW-1 : 0] yields in*w1. The field starting at bit 18
 * yields in*w0. (The name suggests each multiply targets one DSP slice.)
 *
 * Each extracted product is shifted right by the lane's GW-bit field of S
 * before it is accumulated. On return:
 *   temp[0] = sum(A_0 * B_0)   temp[1] = sum(A_1 * B_0)
 *   temp[2] = sum(A_0 * B_1)   temp[3] = sum(A_1 * B_1)
 *
 * NOTE(review): the packing is only exact when IW <= 15 and WW <= 9 (so the
 * operands fit their 15/27-bit containers) and IW + WW <= 18 (so the low
 * product cannot carry into bit 18). Confirm against MatrixMul.h.
 */
void VecDotDq_DSPx2(TIN A_0, TIN A_1, TW B_0, TW B_1, TG S, ap_uint<OW> temp[4])
{
	for (int i = 0; i < 4; i++)
	{
#pragma HLS UNROLL
		temp[i] = 0;
	}

	// Index the two input vectors uniformly in the loop below.
	ap_uint<IW * PI> temp_in_[2];
	temp_in_[0] = A_0;
	temp_in_[1] = A_1;

	// Compute in one kernel spatially
	for (int pi = 0; pi < PI; pi++)
	{
#pragma HLS UNROLL

		ap_uint<42> temp_result[2];
		ap_uint<WW> temp_w_0 = B_0(WW * (pi + 1) - 1, WW * pi);
		ap_uint<WW> temp_w_1 = B_1(WW * (pi + 1) - 1, WW * pi);
		ap_uint<27> temp_w_0_expand = temp_w_0;
		ap_uint<27> temp_w_1_expand = temp_w_1;
		// w0 occupies bits [18 +: WW] and w1 bits [0 +: WW] of the 27-bit operand.
		ap_uint<27> temp_w_add = (temp_w_0_expand << 18) + temp_w_1_expand;

		ap_uint<GW> temp_s = S(GW * (pi + 1) - 1, GW * pi);

		for (int i = 0; i < 2; i++)
		{
#pragma HLS UNROLL
			ap_uint<IW> temp_in = temp_in_[i](IW * (pi + 1) - 1, IW * pi);
			ap_uint<15> temp_in_expand = temp_in;
			// One 15x27 multiply produces both in*w0 (high) and in*w1 (low).
			temp_result[i] = temp_in_expand * temp_w_add;
			for (int j = 0; j < 2; j++)
			{
#pragma HLS UNROLL
				// j == 0 selects the high field (w0), j == 1 the low field (w1).
				temp[2 * j + i] += temp_result[i](IW + WW - 1 + 18 * (1 - j), 18 * (1 - j)) >> temp_s;
			}
		}
	}
}


/*
 * Variant of VecDotDq_DSPx2 that packs both the inputs and the weights, so a
 * single wide multiplication per lane produces all four cross products:
 *
 *   ((in1 << 10) + in0) * ((w1 << 20) + w0)
 *     = in0*w0 + (in1*w0 << 10) + (in0*w1 << 20) + (in1*w1 << 30)
 *
 * The field starting at bit 10*k (IW+WW bits wide) is shifted right by the
 * lane's GW-bit field of S and added to temp[k]. On return:
 *   temp[0] = sum(A_0 * B_0)   temp[1] = sum(A_1 * B_0)
 *   temp[2] = sum(A_0 * B_1)   temp[3] = sum(A_1 * B_1)
 * This is the same output layout as VecDotDq_DSPx2.
 *
 * NOTE(review): exact only when IW <= 5 and WW <= 7 (so the shifted operands
 * fit their 15/27-bit containers) and IW + WW <= 10 (so the four fields do
 * not overlap). Confirm against MatrixMul.h.
 */
void VecDotDq_DSPx4(TIN A_0, TIN A_1, TW B_0, TW B_1, TG S, ap_uint<OW> temp[4])
{
	for (int k = 0; k < 4; k++)
	{
#pragma HLS UNROLL
		temp[k] = 0;
	}

	// Compute in one kernel spatially
	for (int lane = 0; lane < PI; lane++)
	{
#pragma HLS UNROLL

		// Packed activations: in1 starts at bit 10, in0 at bit 0.
		ap_uint<IW> a0 = A_0(IW * (lane + 1) - 1, IW * lane);
		ap_uint<IW> a1 = A_1(IW * (lane + 1) - 1, IW * lane);
		ap_uint<15> a0_wide = a0;
		ap_uint<15> a1_wide = a1;
		ap_uint<15> a_packed = (a1_wide << 10) + a0_wide;

		// Packed weights: w1 starts at bit 20, w0 at bit 0.
		ap_uint<WW> b0 = B_0(WW * (lane + 1) - 1, WW * lane);
		ap_uint<WW> b1 = B_1(WW * (lane + 1) - 1, WW * lane);
		ap_uint<27> b0_wide = b0;
		ap_uint<27> b1_wide = b1;
		ap_uint<27> b_packed = (b1_wide << 20) + b0_wide;

		ap_uint<GW> shift = S(GW * (lane + 1) - 1, GW * lane);

		// 15 x 27 -> 42 bits; the four cross products sit in disjoint fields.
		ap_uint<42> product = a_packed * b_packed;

#ifdef TEST
		cout << "W0 " <<  b0_wide << " W1 " << b1_wide << " W " << b_packed << "\n";
		cout << "IN0 " <<  a0_wide << " IN1 " << a1_wide << " IN " << a_packed << "\n";
		cout << "temp_result " <<  product << " 0 " << product(IW + WW - 1, 0) << "\n";
#endif

		// Field k (bits starting at 10*k) -> temp[k]; see the layout above.
		for (int k = 0; k < 4; k++)
		{
#pragma HLS UNROLL
			temp[k] += product(IW + WW - 1 + k * 10, k * 10) >> shift;
		}
	}
}

/*
 * Streaming quantized GEMM core that produces two output pixels per pass.
 *
 * For each of the HO*WO/2 pixel pairs, the kernel first reads CO/PI packed
 * words from A. The low PI*IW bits of each word belong to the even pixel and
 * the high PI*IW bits to the odd pixel. Then, for every output block oc, it
 * accumulates over all CO/PI input blocks and writes one TOUT2 word to C.
 * In that word, bits [OW*po +: OW] hold channel PO*oc+po of the even pixel
 * and bits [OW*(PO+po) +: OW] hold the same channel of the odd pixel.
 *
 * B_TW[po][ic_block][oc] holds the PI packed weights for channel PO*oc+po.
 * S_TW[ic_block] holds the PI packed per-lane right-shift amounts that are
 * applied to each product before accumulation.
 *
 * NOTE(review): here CO/PI is the reduction dimension and CI*3*3/PO is the
 * output dimension, the reverse of what the names suggest. The names are
 * kept to match the rest of the project.
 * NOTE(review): loop_po handles two output channels per iteration, so PO is
 * assumed to be even.
 */
void MatrixMulx4(hls::stream<TIN2 >& A, TW B_TW[PO][CO/PI][CI * 3 * 3/PO], TG S_TW[CO/PI], hls::stream<TOUT2>& C)
{
#pragma HLS ARRAY_PARTITION variable=B_TW complete dim=0

	// One packed (even, odd) input word per input block. The loops below index
	// this with ic_block < CO/PI, so it must hold CO/PI entries. A CI/PI size
	// overflows whenever CO > CI.
	TIN2 win_in_array[CO / PI];
#pragma HLS ARRAY_PARTITION variable=win_in_array complete dim=0

	// Accumulators: [0, PO) for the even pixel, [PO, 2*PO) for the odd pixel.
	// Every element is read and written in each pipelined ic_block iteration,
	// so the buffer is fully partitioned into registers.
	ap_uint<OW> out_buffer[PO * 2];
#pragma HLS ARRAY_PARTITION variable=out_buffer complete dim=0

	TOUT2 out;

	// GEMM Blocks in C
	for (int j = 0; j < HO * WO / 2; j++)
	{
		// Buffer A in one spatial element
loop_copy_in:
		for (int ic_block = 0; ic_block < CO/PI; ic_block++)
		{
#pragma HLS UNROLL
			win_in_array[ic_block] = A.read();
		}
loop_oc:
		for (int oc = 0; oc < CI * 3 * 3/PO; oc++)
		{
loop_clear_out_buffer:
			for (int po = 0; po < PO * 2; po++)
			{
#pragma HLS UNROLL
				out_buffer[po] = 0;
			}
			// Compute one Block in C (ic_block controls windows sliding)
loop_ic_block:
			for (int ic_block = 0; ic_block < CO/PI; ic_block++)
			{
#pragma HLS PIPELINE
				// Inputs and scales depend only on ic_block, not on po.
				TIN2 win_in_ele = win_in_array[ic_block];
				TIN win_in_ele_0 = win_in_ele(PI * IW - 1, 0);
				TIN win_in_ele_1 = win_in_ele(2 * PI * IW - 1, PI * IW);
				TG win_s_ele = S_TW[ic_block];
loop_po:
				for (int po = 0; po < PO/2; po++)
				{
#pragma HLS UNROLL
					// Scratch local to this po iteration, so the unrolled copies of
					// VecDotDq_DSPx2 do not share, and serialize on, one array.
					ap_uint<OW> temp[4];
					TW win_w_ele_0 = B_TW[2 * po][ic_block][oc];
					TW win_w_ele_1 = B_TW[2 * po + 1][ic_block][oc];
					VecDotDq_DSPx2(win_in_ele_0, win_in_ele_1, win_w_ele_0, win_w_ele_1, win_s_ele, temp);

					// temp[0]/temp[2]: even pixel x channels 2po / 2po+1.
					out_buffer[2 * po] += temp[0];
					out_buffer[2 * po + 1] += temp[2];

					// temp[1]/temp[3]: odd pixel x channels 2po / 2po+1.
					out_buffer[PO + 2 * po] += temp[1];
					out_buffer[PO + 2 * po + 1] += temp[3];
				}

//#ifdef TEST
//				for (po = 0; po < PO; po++)
//				{
//					cout << "out_buffer " << po << ": " << out_buffer[po] << "\n";
//				}
//#endif
			} // ic_block
loop_out_pack:
			for (int po = 0; po < PO; po++)
			{
#pragma HLS UNROLL
				out(OW * (po + 1) - 1, OW * po) = out_buffer[po];
				out(OW * (PO + po + 1) - 1, OW * (PO + po)) = out_buffer[PO + po];
			}
			C.write(out);
		} // oc
	} // j
}


void top(ap_uint<IW> A[HO * WO][CO / PI][PI], ap_uint<WW> B[CO / PI][CI * 3 * 3 / PO][PI * PO], ap_uint<GW> S[CO/PI][PI], ap_uint<OW> C[HO * WO][CI * 3 * 3])
{

	hls::stream<TIN2 > A_stream;
#pragma HLS STREAM variable=A_stream dim=1

	hls::stream<TOUT2 > C_stream;
#pragma HLS ARRAY_MAP variable=C_stream horizontal

	TW B_TW[PO][CO/PI][CI * 3 * 3/PO];
#pragma HLS ARRAY_PARTITION variable=B_TW complete dim=0

	TG S_TW[CO/PI];

	TIN2 A_stream_temp;

	int i, j, k, pi, po;

	for (i = 0; i < CO/PI; i++)
	{
		for (pi = 0; pi < PI; pi++)
		{
			S_TW[i](GW * (pi + 1) - 1, GW * pi) = S[i][pi];
		}
	}

	for (j = 0; j < HO * WO / 2; j++)
	{
		for (i = 0; i < CO/PI; i++)
		{
			for (pi = 0; pi < PI; pi++)
			{
				// Order: k, pi
				A_stream_temp(IW * (pi + 1) - 1, IW * pi) = A[2 * j][i][pi];
				A_stream_temp(IW * (PI + pi + 1) - 1, IW * (PI + pi)) = A[2 * j + 1][i][pi];
			}
			A_stream.write(A_stream_temp);
		}
	}

	for (i = 0; i < CO/PI; i++)
	{
		for (j = 0; j < CI * 3 * 3/PO; j++)
		{
			for (po = 0; po < PO; po++)
			{
				for (pi = 0; pi < PI; pi++)
				{
					// Order: po, k, pi
					B_TW[po][i][j](WW * (pi + 1) - 1, WW * pi) = B[i][j][PI * po + pi];
				}
			}
		}
	}

	MatrixMulx4(A_stream, B_TW, S_TW, C_stream);

	TOUT2 C_buffer;
	for (i = 0; i < HO * WO/2; i++)
	{
		for (j = 0; j < CI * 3 * 3/PO; j++)
		{
			C_buffer = C_stream.read();
			for (po = 0; po < PO; po++)
			{
				C[2 * i][PO * j + po] = C_buffer(OW * (po + 1) - 1, OW * po);
				C[2 * i + 1][PO * j + po] = C_buffer(OW * (PO + po + 1) - 1, OW * (PO + po));
//#ifdef TEST
//				cout << "C[i][PO * j + po]: " << C[i][PO * j + po] << "\n";
//#endif
			}
		}
	}

}
