
#ifndef EFANNA2E_DISTANCE_H
#define EFANNA2E_DISTANCE_H

#include <immintrin.h>
#include <x86intrin.h>
#include <math.h>
#include <iostream>

namespace efanna2e {
// Identifiers for the supported distance metrics. Every enumerator is given
// an explicit numeric value (0 through 4).
enum Metric {
    L2 = 0,
    INNER_PRODUCT = 1,
    FAST_L2 = 2,
    PQ = 3,
    COSINE = 4
};
/**
 * Abstract interface for a distance (or dissimilarity) function over two
 * float arrays. Concrete metrics derive from this class and implement
 * compare().
 */
class Distance {
   public:
    /**
     * Computes the metric for the arrays `a` and `b`.
     *
     * @param a      first input array
     * @param b      second input array
     * @param length element count handed to the implementation
     * @return the metric value as defined by the concrete subclass
     */
    virtual float compare(const float *a, const float *b, unsigned length) const = 0;

    // Virtual so that implementations can be destroyed through a Distance pointer.
    virtual ~Distance() = default;
};

class DistanceCosine : public Distance {
   public:
    float compare(const float *a, const float *b, unsigned size) const {
         __m256 msum0 = _mm256_setzero_ps();
        __m256 msum1 = _mm256_setzero_ps();
        __m256 msum2 = _mm256_setzero_ps();

        // Calculate dot product and magnitudes
        for (uint32_t i = 0; i < size; i += 8) {
            __m256 mx = _mm256_loadu_ps(a + i);
            __m256 my = _mm256_loadu_ps(b + i);
            const __m256 a_m_b1 = _mm256_mul_ps(mx, my);
            msum0 = _mm256_add_ps(msum0, a_m_b1);

            const __m256 a_m_a = _mm256_mul_ps(mx, mx);
            msum1 = _mm256_add_ps(msum1, a_m_a);

            const __m256 b_m_b = _mm256_mul_ps(my, my);
            msum2 = _mm256_add_ps(msum2, b_m_b);
        }

        // Horizontal sum of all elements in msum0, msum1, and msum2
        __m128 vlow0 = _mm256_castps256_ps128(msum0);
        __m128 vhigh0 = _mm256_extractf128_ps(msum0, 1);
        vlow0 = _mm_add_ps(vlow0, vhigh0);
        __m128 shuf0 = _mm_movehdup_ps(vlow0);
        __m128 sums0 = _mm_add_ps(vlow0, shuf0);
        shuf0 = _mm_movehl_ps(shuf0, sums0);
        sums0 = _mm_add_ss(sums0, shuf0);
        float dot_product = _mm_cvtss_f32(sums0);

        __m128 vlow1 = _mm256_castps256_ps128(msum1);
        __m128 vhigh1 = _mm256_extractf128_ps(msum1, 1);

        vlow1 = _mm_add_ps(vlow1, vhigh1);
        __m128 shuf1 = _mm_movehdup_ps(vlow1);
        __m128 sums1 = _mm_add_ps(vlow1, shuf1);
        shuf1 = _mm_movehl_ps(shuf1, sums1);
        sums1 = _mm_add_ss(sums1, shuf1);
        float magnitude_a = _mm_cvtss_f32(sums1);

        __m128 vlow2 = _mm256_castps256_ps128(msum2);
        __m128 vhigh2 = _mm256_extractf128_ps(msum2, 1);
        vlow2 = _mm_add_ps(vlow2, vhigh2);
        __m128 shuf2 = _mm_movehdup_ps(vlow2);
        __m128 sums2 = _mm_add_ps(vlow2, shuf2);
        shuf2 = _mm_movehl_ps(shuf2, sums2);
        sums2 = _mm_add_ss(sums2, shuf2);
        float magnitude_b = _mm_cvtss_f32(sums2);

        // Calculate cosine distance
        float cosine_similarity = dot_product / (std::sqrt(magnitude_a) * std::sqrt(magnitude_b));
        return -cosine_similarity;
    }
};

class DistanceL2 : public Distance {
    static inline __m128 masked_read(int d, const float *x) {
        // assert(0 <= d && d < 4);
        __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
        switch (d) {
            case 3:
                buf[2] = x[2];
            case 2:
                buf[1] = x[1];
            case 1:
                buf[0] = x[0];
        }
        return _mm_load_ps(buf);
        // cannot use AVX2 _mm_mask_set1_epi32
    }

   public:
    float compare(const float *x, const float *y, unsigned d) const {
        // use avx 2
        d = 128;
        __m256 msum = _mm256_setzero_ps();

        // Process chunks of 8 floats
        while (d >= 8) {
            __m256 mx = _mm256_loadu_ps(x);
            x += 8;
            __m256 my = _mm256_loadu_ps(y);
            y += 8;
            const __m256 diff = _mm256_sub_ps(mx, my);
            msum = _mm256_add_ps(msum, _mm256_mul_ps(diff, diff));
            d -= 8;
        }

        // Sum the elements of msum
        __m128 msum_low = _mm256_extractf128_ps(msum, 0);
        __m128 msum_high = _mm256_extractf128_ps(msum, 1);
        msum_low = _mm_add_ps(msum_low, msum_high);

        // Process remaining elements
        float buf[8] __attribute__((aligned(32)));
        _mm_store_ps(buf, msum_low);
        float sum = buf[0] + buf[1] + buf[2] + buf[3];

        // Handle remaining elements that are less than 8
        for (unsigned i = 0; i < d; ++i) {
            float diff = x[i] - y[i];
            sum += diff * diff;
        }

        return sum;
    }
};

class DistanceInnerProduct : public Distance {
   public:
    static inline __m128 masked_read(int d, const float *x) {
        // assert(0 <= d && d < 4);
        __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
        switch (d) {
            case 3:
                buf[2] = x[2];
            case 2:
                buf[1] = x[1];
            case 1:
                buf[0] = x[0];
        }
        return _mm_load_ps(buf);
        // cannot use AVX2 _mm_mask_set1_epi32
    }
    float compare(const float *a, const float *b, unsigned size) const {
        __m256 msum0 = _mm256_setzero_ps();
        for (uint32_t i = 0; i < size; i += 8) {
            __m256 mx = _mm256_loadu_ps(a + i);
            __m256 my = _mm256_loadu_ps(b + i);
            const __m256 a_m_b1 = _mm256_mul_ps(mx, my);
            msum0 = _mm256_add_ps(msum0, a_m_b1);
        }
        float result[8];
        _mm256_storeu_ps(result, msum0);
        float sum = 0;
        for (int i = 0; i < 8; i++) {
            sum += result[i];
        }

        return -1.0 * sum;
    }
};
/**
 * L2 helper built on the inner product: norm() yields a vector's squared
 * norm, and the four-argument compare() combines such a norm with the inner
 * product of two vectors.
 */
class DistanceFastL2 : public DistanceInnerProduct {
   public:
    /**
     * Squared Euclidean norm of `a`, i.e. the sum of a[i]^2.
     *
     * Uses AVX when __AVX__ is defined, otherwise SSE when __SSE2__ is
     * defined, otherwise plain scalar code. Only whole SIMD groups are
     * loaded; leftover elements are added one by one, so the buffer needs
     * exactly `size` readable floats and no padding.
     *
     * @param a    input vector (no alignment required)
     * @param size number of elements in `a`
     * @return the sum of squares of the first `size` elements of `a`
     */
    float norm(const float *a, unsigned size) const {
        float result = 0;
        unsigned done = 0;  // count of leading elements already folded into result

#if defined(__AVX__)
        const unsigned D = size & ~7U;  // elements covered by whole 8-float groups
        const unsigned DR = D % 16;     // 0 or 8: one odd group beyond the 16-wide loop
        const unsigned DD = D - DR;     // elements covered by the 16-wide loop

        __m256 sum = _mm256_setzero_ps();
        if (DR) {
            const __m256 v = _mm256_loadu_ps(a + DD);
            sum = _mm256_add_ps(sum, _mm256_mul_ps(v, v));
        }
        for (unsigned i = 0; i < DD; i += 16) {
            const __m256 v0 = _mm256_loadu_ps(a + i);
            sum = _mm256_add_ps(sum, _mm256_mul_ps(v0, v0));
            const __m256 v1 = _mm256_loadu_ps(a + i + 8);
            sum = _mm256_add_ps(sum, _mm256_mul_ps(v1, v1));
        }

        float lanes[8];
        _mm256_storeu_ps(lanes, sum);
        result = lanes[0] + lanes[1] + lanes[2] + lanes[3] + lanes[4] + lanes[5] + lanes[6] + lanes[7];
        done = D;
#elif defined(__SSE2__)
        const unsigned D = size & ~3U;  // elements covered by whole 4-float groups
        const unsigned DR = D % 16;     // 0, 4, 8 or 12: odd groups beyond the 16-wide loop
        const unsigned DD = D - DR;     // elements covered by the 16-wide loop
        const float *odd_groups = a + DD;

        __m128 sum = _mm_setzero_ps();
        // Odd groups are consumed from the highest address downwards.
        for (unsigned off = DR; off >= 4; off -= 4) {
            const __m128 v = _mm_loadu_ps(odd_groups + off - 4);
            sum = _mm_add_ps(sum, _mm_mul_ps(v, v));
        }
        for (unsigned i = 0; i < DD; i += 16) {
            for (unsigned k = 0; k < 16; k += 4) {
                const __m128 v = _mm_loadu_ps(a + i + k);
                sum = _mm_add_ps(sum, _mm_mul_ps(v, v));
            }
        }

        float lanes[4];
        _mm_storeu_ps(lanes, sum);
        result = lanes[0] + lanes[1] + lanes[2] + lanes[3];
        done = D;
#else
        // Scalar fallback, unrolled four elements at a time.
        const unsigned unrolled_end = size & ~3U;
        for (unsigned i = 0; i < unrolled_end; i += 4) {
            const float sq0 = a[i] * a[i];
            const float sq1 = a[i + 1] * a[i + 1];
            const float sq2 = a[i + 2] * a[i + 2];
            const float sq3 = a[i + 3] * a[i + 3];
            result += sq0 + sq1 + sq2 + sq3;
        }
        done = unrolled_end;
#endif

        // Leftover elements that did not fill a whole group.
        for (unsigned i = done; i < size; ++i) {
            result += a[i] * a[i];
        }
        return result;
    }

    // Keep the inherited three-argument compare() visible next to the overload below.
    using DistanceInnerProduct::compare;

    /**
     * Returns norm - 2 * (a . b).
     *
     * When `norm` is the squared norm of one of the two vectors, this equals
     * their squared L2 distance minus the squared norm of the other vector.
     *
     * @param a    first vector, `size` floats
     * @param b    second vector, `size` floats
     * @param norm precomputed squared norm to add
     * @param size number of elements in each vector
     *
     * NOTE(review): the original source tagged this overload "not implement";
     * confirm whether any caller relies on it.
     */
    float compare(const float *a, const float *b, float norm, unsigned size) const {
        // DistanceInnerProduct::compare returns -(a . b), so adding twice
        // that value subtracts 2 * (a . b).
        return norm + 2.0f * DistanceInnerProduct::compare(a, b, size);
    }
};
}  // namespace efanna2e

#endif  // EFANNA2E_DISTANCE_H
