/*
 * BitMatrix.cpp
 *
 */

#include <mpirxx.h>

#include "BitMatrix.h"
#include "Tools/random.h"
#include "Tools/BitVector.h"
#include "Tools/intrinsics.h"
#include "Math/Square.h"

union matrix16x8
{
    __m128i whole;
    octet rows[16];

    matrix16x8() : whole(_mm_setzero_si128()) {}
    matrix16x8(__m128i x) { whole = x; }

    bool get_bit(int x, int y)
    { return (rows[x] >> y) & 1; }

    void input(square128& input, int x, int y);
    void transpose(square128& output, int x, int y);
};

class square16
{
public:
    // 16x16 in two halves, 128 bits each
    matrix16x8 halves[2];

    bool get_bit(int x, int y)
    { return halves[y/8].get_bit(x, y % 8); }

    void input(square128& output, int x, int y);
    void transpose(square128& output, int x, int y);

    void check_transpose(square16& dual);
    void print();
};

#ifdef __clang__
#define UNROLL_LOOPS
#else
#define UNROLL_LOOPS __attribute__((optimize("unroll-loops")))
#endif

UNROLL_LOOPS
inline void matrix16x8::input(square128& input, int x, int y)
{
    for (int l = 0; l < 16; l++)
        rows[l] = input.bytes[16*x+l][y];
}

UNROLL_LOOPS
inline void square16::input(square128& input, int x, int y)
{
    for (int i = 0; i < 2; i++)
        halves[i].input(input, x, 2 * y + i);
}

UNROLL_LOOPS
inline void matrix16x8::transpose(square128& output, int x, int y)
{
    for (int j = 0; j < 8; j++)
    {
        int row = _mm_movemask_epi8(whole);
        whole = _mm_slli_epi64(whole, 1);

        // _mm_movemask_epi8 uses most significant bit, hence +7-j
        output.doublebytes[8*x+7-j][y] = row;
    }
}

UNROLL_LOOPS
inline void square16::transpose(square128& output, int x, int y)
{
    for (int i = 0; i < 2; i++)
        halves[i].transpose(output, 2 * x + i, y);
}

#ifdef __AVX2__
union matrix32x8
{
    __m256i whole;
    octet rows[32];

    matrix32x8() : whole(_mm256_setzero_si256()) {}
    matrix32x8(__m256i x) { whole = x; }

    void input(square128& input, int x, int y);
    void transpose(square128& output, int x, int y);
};

class square32
{
public:
    matrix32x8 quarters[4];

    void input(square128& input, int x, int y);
    void transpose(square128& output, int x, int y);
};

UNROLL_LOOPS
inline void matrix32x8::input(square128& input, int x, int y)
{
    for (int l = 0; l < 32; l++)
        rows[l] = input.bytes[32*x+l][y];
}

UNROLL_LOOPS
inline void square32::input(square128& input, int x, int y)
{
    for (int i = 0; i < 4; i++)
        quarters[i].input(input, x, 4 * y + i);
}

UNROLL_LOOPS
inline void matrix32x8::transpose(square128& output, int x, int y)
{
    for (int j = 0; j < 8; j++)
    {
        int row = _mm256_movemask_epi8(whole);
        whole = _mm256_slli_epi64(whole, 1);

        // _mm_movemask_epi8 uses most significant bit, hence +7-j
        output.words[8*x+7-j][y] = row;
    }
}

UNROLL_LOOPS
inline void square32::transpose(square128& output, int x, int y)
{
    for (int i = 0; i < 4; i++)
        quarters[i].transpose(output, 4 * x + i, y);
}
#endif

#ifdef __AVX2__
typedef square32 subsquare;
#define N_SUBSQUARES 4
#else
typedef square16 subsquare;
#define N_SUBSQUARES 8
#endif

// hypercube permutation
#ifndef __AVX2__
const int perm[] = { 0, 8, 4, 0xc, 2, 0xa, 6, 0xe, 1, 9, 5, 0xd, 3, 0xb, 7, 0xf };
#else
const int perm2[] = { 0, 4, 2, 6, 1, 5, 3, 7, 8, 0xc, 0xa, 0xe, 9, 0xd, 0xb, 0xf };
#endif

UNROLL_LOOPS
void square128::transpose()
{
#ifdef USE_SUBSQUARES
    for (int j = 0; j < N_SUBSQUARES; j++)
        for (int k = 0; k < j; k++)
        {
            subsquare a, b;
            a.input(*this, k, j);
            b.input(*this, j, k);
            a.transpose(*this, j, k);
            b.transpose(*this, k, j);
        }

    for (int j = 0; j < N_SUBSQUARES; j++)
    {
        subsquare a;
        a.input(*this, j, j);
        a.transpose(*this, j, j);
    }
#else
#define EIGHTTOSIXTYFOUR X(8) X(16) X(32) X(64)
#define X(I) { \
        const int J = I / 4; \
        for (int i = 0; i < 16 / J; i++) \
        { \
            for (int j = 0; j < J / 2; j++) \
            { \
                int a = base + J * i + j; \
                int b = a + J/2; \
                __m128i tmp = _mm_unpacklo_epi##I(rows[a], rows[b]); \
                rows[b] = _mm_unpackhi_epi##I(rows[a], rows[b]); \
                rows[a] = tmp; \
            } \
        } \
    }
#ifdef __AVX2__
#define Z(I) { \
        const int J = I / 8; \
        for (int i = 0; i < 16 / J; i++) \
        { \
            for (int j = 0; j < J / 2; j++) \
            { \
                int a = base + J * i + j; \
                int b = a + J/2; \
                __m256i tmp = _mm256_unpacklo_epi##I(doublerows[a], doublerows[b]); \
                doublerows[b] = _mm256_unpackhi_epi##I(doublerows[a], doublerows[b]); \
                doublerows[a] = tmp; \
            } \
        } \
    }

    square128 tmp;
    for (int k = 0; k < 4; k++)
    {
        int base = k * 16 * 2;
        X(8)
        base += 16;
        X(8)
        base = k * 16;
        Z(16) Z(32) Z(64)
        for (int i = 0; i < 8; i++)
        {
            int a = base + i;
            int b = a + 8;
            __m128i tmp = rows[2 * b];
            rows[2 * b] = rows[2 * a + 1];
            rows[2 * a + 1] = tmp;
        }
        for (int i = 0; i < 16; i++)
        {
            int j = perm2[i];
            tmp.doublerows[base + i] = doublerows[base + j];
        }
    }

    for (int i = 0; i < 16; i++)
    {
        for (int k = 0; k < 4; k++)
            matrix32x8(tmp.doublerows[k * 16 + i]).transpose(*this, i, k);
    }
#else // __AVX2__
    square128 tmp;
    for (int k = 0; k < 8; k++)
    {
        int base = k * 16;
        EIGHTTOSIXTYFOUR
        for (int i = 0; i < 16; i++)
        {
            int j = perm[i];
            tmp.rows[base + i] = rows[base + j];
        }
    }

    for (int i = 0; i < 16; i++)
    {
        for (int k = 0; k < 8; k++)
            matrix16x8(tmp.rows[k * 16 + i]).transpose(*this, i, k);
    }
#endif // __AVX2__
#endif // __USE_SUBSQUARES__
}

void square128::randomize(PRNG& G)
{
    G.get_octets((octet*)&rows, sizeof(rows));
}

void square128::randomize(int row, PRNG& G)
{
    rows[row] = G.get_doubleword();
}


void square128::conditional_add(BitVector& conditions, square128& other, int offset)
{
    for (int i = 0; i < 128; i++)
        if (conditions.get_bit(128 * offset + i))
            rows[i] ^= other.rows[i];
}

template <>
void Square<gf2n_long>::to(gf2n_long& result, false_type)
{
    int128 high, low;
    for (int i = 0; i < 128; i++)
    {
        low ^= rows[i].get() << i;
        high ^= rows[i].get() >> (128 - i);
    }
    result.reduce(high, low);
}

void square128::check_transpose(square128& dual, int i, int k)
{
    for (int j = 0; j < 16; j++)
        for (int l = 0; l < 16; l++)
            if (get_bit(16 * i + j, 16 * k + l) != dual.get_bit(16 * k + l, 16 * i + j))
            {
                cout << "Error in 16x16 square (" << i << "," << k << ")" << endl;
                print(i, k);
                cout << "dual" << endl;
                dual.print(i, k);
                exit(1);
            }
}

void square128::check_transpose(square128& dual)
{
    for (int i = 0; i < 8; i++)
        for (int k = 0; k < 8; k++)
            check_transpose(dual, i, k);
}

void square16::print()
{
    for (int i = 0; i < 2; i++)
    {
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < 2; k++)
            {
                for (int l = 0; l < 8; l++)
                    cout << halves[k].get_bit(8 * i + j, l);
                cout << " ";
            }
            cout << endl;
        }
        cout << endl;
    }
}

void square128::print(int i, int k)
{
    square16 a;
    a.input(*this, i, k);
    a.print();
}

void square128::print()
{
    for (int i = 0; i < 128; i++)
    {
        for (int j = 0; j < 128; j++)
            cout << get_bit(i, j);
        cout << endl;
    }
}

void square128::print_octets()
{
    for (int i = 0; i < 128; i++)
    {
        for (int j = 0; j < 16; j++)
            cout << hex << (int)bytes[i][j] << " ";
        cout << endl;
    }
    cout << dec;
}

void square128::print_doublerows()
{
    for (int i = 0; i < 64; i++)
    {
        for (int j = 0; j < 32; j++)
        {
            cout.width(2);
            cout << hex << (int)bytes[2*i+j/16][j%16] << " ";
        }
        cout << endl;
    }
    cout << dec;
}

void square128::set_zero()
{
    for (int i = 0; i < 128; i++)
        rows[i] = _mm_setzero_si128();
}

square128& square128::operator^=(square128& other)
{
    for (int i = 0; i < 128; i++)
        rows[i] ^= other.rows[i];
    return *this;
}

square128& square128::add(square128& other)
{
    return *this ^= other;
}

square128& square128::sub(square128& other)
{
    return *this ^= other;
}

square128& square128::rsub(square128& other)
{
    return *this ^= other;
}

square128& square128::operator^=(const __m128i* other)
{
    __m128i value = _mm_loadu_si128(other);
    for (int i = 0; i < 128; i++)
        rows[i] ^= value;
    return *this;
}

square128& square128::sub(const __m128i* other)
{
    return *this ^= other;
}

square128& square128::operator^=(BitVector& other)
{
    return *this ^= (__m128i*)other.get_ptr();
}

bool square128::operator==(square128& other)
{
    for (int i = 0; i < 128; i++)
    {
        if (int128(rows[i]) != other.rows[i])
            return false;
    }
    return true;
}

void square128::pack(octetStream& o) const
{
    o.append((octet*)this->bytes, sizeof(bytes));
}

void square128::unpack(octetStream &o)
{
    o.consume((octet*)this->bytes, sizeof(bytes));
}


BitMatrix::BitMatrix(int length)
{
    resize(length);
}

void BitMatrix::resize(int length)
{
    if (length % 128 != 0)
        throw invalid_length(
                to_string(length) + " not divisible by "
                        + to_string(128));
    squares.resize(length / 128);
}

void BitMatrix::transpose()
{
    for (size_t i = 0; i < squares.size(); i++)
        squares[i].transpose();
}

void BitMatrix::check_transpose(BitMatrix& dual)
{
    for (size_t i = 0; i < squares.size(); i++)
    {
        for (int j = 0; j < 128; j++)
            for (int k = 0; k < 128; k++)
                if (squares[i].get_bit(j, k) != dual.squares[i].get_bit(k, j))
                {
                    cout << "First error in square " << i << " row " << j
                            << " column " << k << endl;
                    squares[i].print(i / 8, j / 8);
                    dual.squares[i].print(i / 8, j / 8);
                    return;
                }
    }
    cout << "No errors in transpose" << endl;
}

void BitMatrix::vertical_to(vector<BitVector>& output)
{
    int n = 128 * squares.size();
    output.resize(n);
    for (int i = 0; i < n; i++)
    {
        output[i].resize(128);
        output[i].set_int128(0, squares[i / 128].rows[i % 128]);
    }
}

template <>
void Slice<BitMatrix>::transpose()
{
    for (size_t i = start; i < end; i++)
        bm.squares[i].transpose();
}
