#include "immintrin.h"
#include <iostream>
#include <cinttypes>
__m256 relu6_back(__m256 input, __m256 grad) {
    __m256 zero_v = _mm256_set1_ps((float)(0));
    __m256 six_v = _mm256_set1_ps((float)(6));

    __m256 mask =  _mm256_and_ps(_mm256_cmp_ps(zero_v, input, 0x11), _mm256_cmp_ps(input, six_v, 0x11));
    return _mm256_and_ps(mask, grad);
}

// Small driver: runs relu6_back on a vector whose lanes all hold 8.9 with an
// upstream gradient of 100 in every lane, then prints the first result lane.
int main(void) {
	// _mm256_stream_ps requires a 32-byte aligned destination. Request that
	// alignment from the compiler instead of rounding the address by hand
	// through an integer type, which is not guaranteed to be pointer-sized.
	alignas(32) float result[8];

	const __m256 srce = _mm256_set1_ps(8.9f);
	const __m256 grad = _mm256_set1_ps(100.0f);

	// Non-temporal store of all eight lanes; this thread reads it back below.
	_mm256_stream_ps(result, relu6_back(srce, grad));
	std::cout << result[0] << '\n';

	return 0;
}