/*
 * square64.cpp
 *
 */

#include "square64.h"
#include "Tools/cpu_support.h"
#include <stdexcept>
#include <iostream>
#include <assert.h>
using namespace std;

/**
 * 32 octets that can be accessed either individually (rows) or as one
 * 256-bit register (whole).
 */
union matrix32x8
{
    __m256i whole;
    octet rows[32];

    matrix32x8(const __m256i& x = _mm256_setzero_si256()) : whole(x) {}

    // Collect input.bytes[r][y] for the 32 consecutive indices
    // r = 32*x, ..., 32*x+31 into rows[0..31].
    matrix32x8(square64& input, int x, int y)
    {
        const int first = 32 * x;
        for (int offset = 0; offset < 32; offset++)
            rows[offset] = input.bytes[first + offset][y];
    }

    // Write the eight bit planes of the 32 octets to
    // output.halfrows[8*x .. 8*x+7][y]. Note that this shifts `whole`
    // in the process, so the content is consumed.
    // Throws runtime_error if AVX2 is not available.
    void transpose(square64& output, int x, int y)
    {
#if defined(__AVX2__) || !defined(__x86_64__)
        if (cpu_has_avx2())
        {
            // _mm256_movemask_epi8 collects the most significant bit of
            // every byte, so start with plane 7 and shift the next lower
            // bit into the top position after each extraction.
            for (int plane = 7; plane >= 0; plane--)
            {
                output.halfrows[8 * x + plane][y] = _mm256_movemask_epi8(whole);
                whole = _mm256_slli_epi64(whole, 1);
            }
            return;
        }
#endif
        (void) output, (void) x, (void) y;
        throw runtime_error("need AVX2 support");
    }
};

#ifdef DEBUG_TRANS
// Debug helper: write x to os as four space-prefixed hexadecimal long
// values and switch the stream back to decimal afterwards.
ostream& operator<<(ostream& os, const __m256i& x)
{
    const long* words = (const long*) &x;
    os << hex;
    for (int word = 0; word < 4; word++)
        os << " " << words[word];
    return os << dec;
}
#endif


// Expands to a `case I:` label for use inside a switch statement.
// LOWS is assigned _mm256_unpacklo_epi##I(A, B) and HIGHS is assigned
// _mm256_unpackhi_epi##I(A, B), where I is the element width in bits
// (it is pasted into the intrinsic name). The expansion ends in `break;`.
#define ZIP_CASE(I, LOWS, HIGHS, A, B) \
case I: \
    LOWS = _mm256_unpacklo_epi##I(A, B); \
    HIGHS = _mm256_unpackhi_epi##I(A, B); \
    break;

/**
 * Interleave a and b in chunks of chunk_size bits.
 *
 * For chunk sizes 8, 16, 32 and 64 the matching unpacklo/unpackhi
 * intrinsics are used. For 128, lows receives the lower 128-bit halves of
 * a and b (in this order) and highs receives the upper halves.
 *
 * Throws invalid_argument for any other chunk size and runtime_error if
 * AVX2 is not available.
 */
void zip(int chunk_size, __m256i& lows, __m256i& highs,
        const __m256i& a, const __m256i& b)
{
#if defined(__AVX2__) || !defined(__x86_64__)
    if (cpu_has_avx2())
    {
        if (chunk_size == 128)
        {
            // Start from plain copies, then exchange the upper half of
            // lows with the lower half of highs.
            lows = a;
            highs = b;
            __m128i* low_halves = (__m128i*) &lows;
            __m128i* high_halves = (__m128i*) &highs;
            swap(low_halves[1], high_halves[0]);
            return;
        }

        switch (chunk_size)
        {
        ZIP_CASE(8, lows, highs, a, b);
        ZIP_CASE(16, lows, highs, a, b);
        ZIP_CASE(32, lows, highs, a, b);
        ZIP_CASE(64, lows, highs, a, b);
        default:
            throw invalid_argument("not supported");
        }
        return;
    }
#endif
    (void) chunk_size, (void) lows, (void) highs, (void) a, (void) b;
    throw runtime_error("need AVX2 support");
}

/**
 * Transpose the matrix in place.
 *
 * The current content is copied, *this is reset to the value-initialized
 * state, and the result is written band by band: for every group of 32
 * input rows (k) the eight 256-bit registers covering that group are
 * re-arranged by a fixed sequence of zip() steps and then handed to
 * matrix32x8::transpose(), which writes output.halfrows[8*i..8*i+7][k].
 *
 * @param n_rows number of input rows to process, at most 64 (asserted);
 *               processing happens in groups of 32
 * @param n_cols number of input columns to process, at most 64 (asserted);
 *               processing happens in groups of 8
 *
 * zip() and matrix32x8::transpose() throw runtime_error if AVX2 is not
 * available.
 */
void square64::transpose(int n_rows, int n_cols)
{
#ifdef DEBUG_TRANS
    cout << "transpose" << endl;
    print();
#endif

    assert(n_rows <= 64);
    assert(n_cols <= 64);

    // Read from a copy; anything not written below keeps the value from
    // the reset.
    square64 tmp = *this;
    *this = {};

    for (int k = 0; k < DIV_CEIL(n_rows, 32); k++)
    {
        // Eight consecutive quadrows of the copy
        // (presumably 32 rows of 64 bits each - confirm against square64.h).
        __m256i x[8], lows[4], highs[4];
        memcpy(x, &tmp.quadrows[8 * k], sizeof(x));
#ifdef DEBUG_TRANS
        for (int j = 0; j < 8; j++)
            if (not _mm256_testz_si256(x[j], x[j]))
            {
                cout << "transpose k " << k << " j " << j << ": ";
                for (int i = 0; i < 4; i++)
                    cout << hex << " " << ((long*)&x[j])[i];
                cout << dec << endl;
            }
#endif
        // First pass: chunk sizes 128 and 64 only.
        for (int chunk_size = 128; chunk_size >= 64; chunk_size /= 2)
        {
            for (int j = 0; j < 4; j ++)
            {
                int a, b;
                if (chunk_size == 128)
                {
                    // pairs (0,4), (1,5), (2,6), (3,7)
                    a = j;
                    b = a + 4;
                }
                else
                {
                    // chunk_size == 64: pairs (0,2), (1,3), (4,6), (5,7)
                    a = j / 2 * 2 + j;
                    b = a + 2;
                }
                zip(chunk_size, lows[j], highs[j], x[a], x[b]);
            }
            // The four low results become x[0..3], the high ones x[4..7].
            memcpy(x, lows, sizeof(lows));
            memcpy(&x[4], highs, sizeof(highs));
#ifdef DEBUG_TRANS
            for (int j = 0; j < 8; j++)
                if (not _mm256_testz_si256(x[j], x[j]))
                {
                    cout << "transpose k " << k << " chunk " << chunk_size
                            << " j " << j << ": ";
                    for (int i = 0; i < 4; i++)
                        cout << hex << " " << ((long*)&x[j])[i];
                    cout << dec << endl;
                }
#endif
        }
        // Second pass: chunk sizes 8, 16, 32 and 64.
        for (int chunk_size = 8; chunk_size < 128; chunk_size *= 2)
        {
            for (int j = 0; j < 4; j ++)
            {
                int a, b;
                switch (chunk_size)
                {
                case 8:
                    // pairs (0,4), (1,5), (2,6), (3,7)
                    a = j;
                    b = a + 4;
                    break;
                case 16:
                    // pairs (0,2), (1,3), (4,6), (5,7)
                    a = j / 2 * 2 + j;
                    b = a + 2;
                    break;
                default:
                    // 32 and 64: pairs (0,1), (2,3), (4,5), (6,7)
                    a = 2 * j;
                    b = a + 1;
                    break;
                }
                zip(chunk_size, lows[j], highs[j], x[a], x[b]);
            }

            memcpy(x, lows, sizeof(lows));
            memcpy(&x[4], highs, sizeof(highs));
#ifdef DEBUG_TRANS
            for (int j = 0; j < 8; j++)
                if (not _mm256_testz_si256(x[j], x[j]))
                {
                    cout << "transpose k " << k << " chunk " << chunk_size
                            << " j " << j << ": ";
                    for (int i = 0; i < 4; i++)
                        cout << hex << " " << ((long*)&x[j])[i];
                    cout << dec << endl;
                }
#endif
        }

        // The zip sequence leaves the registers in bit-reversed order
        // (perm is the bit reversal of 0..7), so undo that when picking
        // the register for output group i.
        const int perm[] = { 0, 4, 2, 6, 1, 5, 3, 7 };
        for (int i = 0; i < DIV_CEIL(n_cols, 8); i++)
        {
            matrix32x8(x[perm[i]]).transpose(*this, i, k);
        }
    }
#ifdef DEBUG_TRANS
    cout << "after transpose" << endl;
    print();
#endif
}

/**
 * Inequality comparison over all 64 rows.
 *
 * @param other matrix to compare against
 * @return true if at least one row differs from the corresponding row of
 *         other, false if all 64 rows are equal
 */
bool square64::operator !=(const square64& other)
{
    for (int i = 0; i < 64; i++)
        if (rows[i] != other.rows[i])
            return true;    // first difference settles it
    return false;
}


/**
 * Dump the matrix to standard output: one line per row index 0..63, each
 * consisting of the values of get_bit(row, col) for col 0..63 without
 * separators.
 */
void square64::print()
{
    for (int row = 0; row < 64; ++row)
    {
        for (int col = 0; col < 64; ++col)
        {
            cout << get_bit(row, col);
        }
        cout << endl;
    }
    cout << flush;
}
