#include <jni.h>
#include <pthread.h>
#include <stdlib.h>

#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/bitmap.h>

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <future>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>

#include "cnpy.h"
#include "threadpool.h"

using namespace std;
using namespace cv;


// Element type of every lookup-table entry handled in this file.
typedef int16_t lut_t;

// NUM_THREADS: worker count handed to the thread pool; also the number of row regions
//              the last layer is split into.
// UPSCALE:     per-axis size multiplier of the output (out_width = width * UPSCALE).
// MID_CH:      number of feature maps carried from one layer to the next.
// LAYERS:      LAYERS - 1 is the number of mid-layer passes (and of mid LUT sets per channel).
const static size_t NUM_THREADS = 6, UPSCALE = 4, MID_CH = 8, LAYERS = 8;
// Input and output image dimensions in pixels. Assigned by doSRLUT on every call and
// read by the layer functions, which use them to interpret their raw buffers.
static size_t width, height, out_width, out_height;


/// Limits `var` to the range [min, max]: values below `min` yield `min`,
/// values above `max` yield `max`, everything else is returned unchanged.
template <typename T>
inline T my_clamp(T var, T min, T max)
{
    if (var < min)
    {
        return min;
    }
    if (max < var)
    {
        return max;
    }
    return var;
}

/**
 * First layer: turns the 8-bit, 3-channel input image into one 16-bit feature map.
 *
 * The output is zeroed, every output sample receives a constant offset of 2032, and each
 * input sample then adds the 3x3 patch `lut[sample]` into the 3x3 output neighbourhood
 * around its own position (per colour channel). For pixels on the image border, the parts
 * of the patch that would land outside the image are folded back onto the border
 * rows/columns by the dedicated corner and edge branches below.
 *
 * Uses the file-level `width` and `height`. Both must be at least 2: the corner branches
 * write to `row +/- 1` and `col +/- 1` without further checks.
 *
 * Fix: the top-left corner branch used to execute `out_img[0][0][c] = 1`, which discarded
 * the 2032 offset for that pixel only. It now accumulates exactly like the other corners.
 *
 * NOTE(review): the interior branch indexes the patch mirrored (`[2 - i][2 - j]`) while the
 * corner/edge branches index it unmirrored - confirm this asymmetry is intended.
 *
 * @param _in_img  height * width * 3 input samples, row-major, channels interleaved.
 * @param _out_img height * width * 3 output samples, same layout; fully overwritten.
 * @param lut      256 patches of 3x3 values, indexed by the input sample value.
 */
void first_layer_forward(uint8_t *_in_img, int16_t *_out_img, lut_t lut[256][3][3])
{
    auto in_img = reinterpret_cast<uint8_t(*)[width][3]>(_in_img);
    auto out_img = reinterpret_cast<int16_t(*)[width][3]>(_out_img);
    memset(out_img, 0, height * width * 3 * sizeof(int16_t));

    // lut_result[c] is the 3x3 patch selected by channel c of the current pixel
    lut_t lut_result[3][3][3];
    for (size_t row = 0; row < height; ++row)
    {
        for (size_t col = 0; col < width; ++col)
        {
            for (size_t c = 0; c < 3; ++c)
            {
                memcpy(lut_result[c], lut[in_img[row][col][c]], 9 * sizeof(lut_t));
                // constant offset, added once per output sample
                out_img[row][col][c] += 2032;
            }

            if (row == 0 || row == height - 1 || col == 0 || col == width - 1)
            {
                // corner: only a 2x2 neighbourhood exists inside the image
                if (row == 0 && col == 0)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][1][1] + lut_result[c][1][2] + lut_result[c][2][1] + lut_result[c][2][2];
                        out_img[row][col + 1][c] += lut_result[c][1][2] + lut_result[c][2][2];
                        out_img[row + 1][col][c] += lut_result[c][2][1] + lut_result[c][2][2];
                        out_img[row + 1][col + 1][c] += lut_result[c][2][2];
                    }
                }
                else if (row == 0 && col == width - 1)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][1][0] + lut_result[c][1][1] + lut_result[c][2][0] + lut_result[c][2][1];
                        out_img[row][col - 1][c] += lut_result[c][1][0] + lut_result[c][2][0];
                        out_img[row + 1][col][c] += lut_result[c][2][0] + lut_result[c][2][1];
                        out_img[row + 1][col - 1][c] += lut_result[c][2][0];
                    }
                }
                else if (row == height - 1 && col == 0)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][0][1] + lut_result[c][0][2] + lut_result[c][1][1] + lut_result[c][1][2];
                        out_img[row][col + 1][c] += lut_result[c][0][2] + lut_result[c][1][2];
                        out_img[row - 1][col][c] += lut_result[c][0][1] + lut_result[c][0][2];
                        out_img[row - 1][col + 1][c] += lut_result[c][0][2];
                    }
                }
                else if (row == height - 1 && col == width - 1)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][0][0] + lut_result[c][0][1] + lut_result[c][1][0] + lut_result[c][1][1];
                        out_img[row][col - 1][c] += lut_result[c][0][0] + lut_result[c][1][0];
                        out_img[row - 1][col][c] += lut_result[c][0][0] + lut_result[c][0][1];
                        out_img[row - 1][col - 1][c] += lut_result[c][0][0];
                    }
                }
                    // edge: a 2x3 (top/bottom) or 3x2 (left/right) neighbourhood exists
                else if (row == 0)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row][col + j - 1][c] += lut_result[c][1][j] + lut_result[c][2][j];
                            out_img[row + 1][col + j - 1][c] += lut_result[c][2][j];
                        }
                    }
                }
                else if (row == height - 1)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row][col + j - 1][c] += lut_result[c][1][j] + lut_result[c][0][j];
                            out_img[row - 1][col + j - 1][c] += lut_result[c][0][j];
                        }
                    }
                }
                else if (col == 0)
                {
                    for (size_t i = 0; i < 3; ++i)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col][c] += lut_result[c][i][1] + lut_result[c][i][2];
                            out_img[row + i - 1][col + 1][c] += lut_result[c][i][2];
                        }
                    }
                }
                else if (col == width - 1)
                {
                    for (size_t i = 0; i < 3; ++i)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col][c] += lut_result[c][i][1] + lut_result[c][i][0];
                            out_img[row + i - 1][col - 1][c] += lut_result[c][i][0];
                        }
                    }
                }
            }
                // center: the full 3x3 neighbourhood is inside the image
            else
            {
                for (size_t i = 0; i < 3; ++i)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col + j - 1][c] += lut_result[c][2 - i][2 - j];
                        }
                    }
                }
            }
        }
    }
}

/**
 * One mid-layer pass: combines MID_CH input feature maps into a single output feature map.
 *
 * The output is zeroed and every output sample receives a constant offset of 2032. For each
 * pixel and colour channel, every input feature value is quantised to a LUT index with
 * `(value + 8) >> 4`, clamped to [0, 254]; the 3x3 patches selected that way are summed over
 * all MID_CH features, and the summed patch is added into the 3x3 output neighbourhood of
 * the pixel. For border pixels, the parts of the patch that would land outside the image are
 * folded back onto the border rows/columns by the corner and edge branches below.
 *
 * Uses the file-level `width` and `height`. Both must be at least 2: the corner branches
 * write to `row +/- 1` and `col +/- 1` without further checks.
 *
 * Fix: the top-left corner branch used to execute `out_img[0][0][c] = 1`, which discarded
 * the 2032 offset for that pixel only. It now accumulates exactly like the other corners.
 *
 * NOTE(review): the interior branch indexes the patch mirrored (`[2 - i][2 - j]`) while the
 * corner/edge branches index it unmirrored - confirm this asymmetry is intended.
 *
 * @param _in_feats MID_CH pointers, each to height * width * 3 samples (row-major, interleaved).
 * @param _out_img  height * width * 3 output samples, same layout; fully overwritten.
 * @param lut       per input feature: 255 patches of 3x3 values, indexed by the quantised value.
 */
void mid_layers_forward(int16_t *_in_feats[MID_CH], int16_t *_out_img, lut_t lut[MID_CH][255][3][3])
{
    auto out_img = reinterpret_cast<int16_t(*)[width][3]>(_out_img);
    memset(out_img, 0, height * width * 3 * sizeof(int16_t));

    // lut_result[c] is the patch for channel c, summed over all input features
    lut_t lut_result[3][3][3];
    int16_t lut_index;
    for (size_t row = 0; row < height; ++row)
    {
        for (size_t col = 0; col < width; ++col)
        {
            // reset
            memset(lut_result, 0, sizeof(lut_result));
            for (size_t c = 0; c < 3; ++c)
            {
                // constant offset, added once per output sample
                out_img[row][col][c] += 2032;
            }

            // look up table
            for (size_t feat_ch = 0; feat_ch < MID_CH; ++feat_ch)
            {
                auto in_feat = reinterpret_cast<int16_t(*)[width][3]>(_in_feats[feat_ch]);
                for (size_t c = 0; c < 3; ++c)
                {
                    // round to the nearest multiple of 16, then clamp to the table size
                    lut_index = (in_feat[row][col][c] + 8) >> 4;
                    lut_index = my_clamp(lut_index, (int16_t)0, (int16_t)254);
                    for (size_t i = 0; i < 3; ++i)
                    {
                        for (size_t j = 0; j < 3; ++j)
                        {
                            lut_result[c][i][j] += lut[feat_ch][lut_index][i][j];
                        }
                    }
                }
            }

            if (row == 0 || row == height - 1 || col == 0 || col == width - 1)
            {
                // corner: only a 2x2 neighbourhood exists inside the image
                if (row == 0 && col == 0)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][1][1] + lut_result[c][1][2] + lut_result[c][2][1] + lut_result[c][2][2];
                        out_img[row][col + 1][c] += lut_result[c][1][2] + lut_result[c][2][2];
                        out_img[row + 1][col][c] += lut_result[c][2][1] + lut_result[c][2][2];
                        out_img[row + 1][col + 1][c] += lut_result[c][2][2];
                    }
                }
                else if (row == 0 && col == width - 1)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][1][0] + lut_result[c][1][1] + lut_result[c][2][0] + lut_result[c][2][1];
                        out_img[row][col - 1][c] += lut_result[c][1][0] + lut_result[c][2][0];
                        out_img[row + 1][col][c] += lut_result[c][2][0] + lut_result[c][2][1];
                        out_img[row + 1][col - 1][c] += lut_result[c][2][0];
                    }
                }
                else if (row == height - 1 && col == 0)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][0][1] + lut_result[c][0][2] + lut_result[c][1][1] + lut_result[c][1][2];
                        out_img[row][col + 1][c] += lut_result[c][0][2] + lut_result[c][1][2];
                        out_img[row - 1][col][c] += lut_result[c][0][1] + lut_result[c][0][2];
                        out_img[row - 1][col + 1][c] += lut_result[c][0][2];
                    }
                }
                else if (row == height - 1 && col == width - 1)
                {
                    for (size_t c = 0; c < 3; ++c)
                    {
                        out_img[row][col][c] += lut_result[c][0][0] + lut_result[c][0][1] + lut_result[c][1][0] + lut_result[c][1][1];
                        out_img[row][col - 1][c] += lut_result[c][0][0] + lut_result[c][1][0];
                        out_img[row - 1][col][c] += lut_result[c][0][0] + lut_result[c][0][1];
                        out_img[row - 1][col - 1][c] += lut_result[c][0][0];
                    }
                }
                    // edge: a 2x3 (top/bottom) or 3x2 (left/right) neighbourhood exists
                else if (row == 0)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row][col + j - 1][c] += lut_result[c][1][j] + lut_result[c][2][j];
                            out_img[row + 1][col + j - 1][c] += lut_result[c][2][j];
                        }
                    }
                }
                else if (row == height - 1)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row][col + j - 1][c] += lut_result[c][1][j] + lut_result[c][0][j];
                            out_img[row - 1][col + j - 1][c] += lut_result[c][0][j];
                        }
                    }
                }
                else if (col == 0)
                {
                    for (size_t i = 0; i < 3; ++i)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col][c] += lut_result[c][i][1] + lut_result[c][i][2];
                            out_img[row + i - 1][col + 1][c] += lut_result[c][i][2];
                        }
                    }
                }
                else if (col == width - 1)
                {
                    for (size_t i = 0; i < 3; ++i)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col][c] += lut_result[c][i][1] + lut_result[c][i][0];
                            out_img[row + i - 1][col - 1][c] += lut_result[c][i][0];
                        }
                    }
                }
            }
                // center: the full 3x3 neighbourhood is inside the image
            else
            {
                for (size_t i = 0; i < 3; ++i)
                {
                    for (size_t j = 0; j < 3; ++j)
                    {
                        for (size_t c = 0; c < 3; ++c)
                        {
                            out_img[row + i - 1][col + j - 1][c] += lut_result[c][2 - i][2 - j];
                        }
                    }
                }
            }
        }
    }
}

/**
 * Last layer: expands every input pixel into an UPSCALE x UPSCALE output block.
 *
 * Only a horizontal band of rows is processed per call so that several calls can run in
 * parallel: the image is split into NUM_THREADS bands of `height / NUM_THREADS` rows and the
 * last band (`region_index == NUM_THREADS - 1`) additionally takes the remaining rows.
 *
 * For each pixel and colour channel, every input feature value is quantised to a LUT index
 * with `(value + 8) >> 4`, clamped to [0, 254]; the selected UPSCALE x UPSCALE patches are
 * summed over all MID_CH features. Each summed value is then mapped with `(v + 2048) >> 4`,
 * clamped to [0, 255] and stored at the matching position of the output block.
 *
 * Uses the file-level `width`, `height` and `out_width`.
 *
 * @param region_index band to process, expected in [0, NUM_THREADS).
 * @param _in_feats    MID_CH pointers, each to height * width * 3 samples (row-major, interleaved).
 * @param _out_img     (height * UPSCALE) x out_width x 3 output samples; only this band is written.
 * @param lut          per input feature: 255 patches of UPSCALE x UPSCALE values.
 */
void last_layer_forward(size_t region_index, int16_t *_in_feats[MID_CH], int16_t *_out_img, lut_t (*lut)[255][UPSCALE][UPSCALE])
{
    auto out_img = reinterpret_cast<int16_t(*)[out_width][3]>(_out_img);

    // zoning
    const size_t num_part = height / NUM_THREADS;
    const size_t row_start = region_index * num_part;
    const size_t row_end = (region_index == NUM_THREADS - 1) ? height : row_start + num_part;

    lut_t lut_result[3][UPSCALE][UPSCALE];
    for (size_t row = row_start; row < row_end; ++row)
    {
        for (size_t col = 0; col < width; ++col)
        {
            // reset
            memset(lut_result, 0, sizeof(lut_result));

            // look up table
            for (size_t feat_ch = 0; feat_ch < MID_CH; ++feat_ch)
            {
                auto in_feat = reinterpret_cast<int16_t(*)[width][3]>(_in_feats[feat_ch]);
                for (size_t c = 0; c < 3; ++c)
                {
                    // The index depends only on (feature, pixel, channel), so it is computed
                    // once here rather than for every element of the patch.
                    int16_t lut_index = (in_feat[row][col][c] + 8) >> 4;
                    lut_index = my_clamp(lut_index, (int16_t)0, (int16_t)254);
                    const auto &patch = lut[feat_ch][lut_index];
                    for (size_t i = 0; i < UPSCALE; ++i)
                    {
                        for (size_t j = 0; j < UPSCALE; ++j)
                        {
                            lut_result[c][i][j] += patch[i][j];
                        }
                    }
                }
            }

            // output
            for (size_t c = 0; c < 3; ++c)
            {
                for (size_t i = 0; i < UPSCALE; ++i)
                {
                    for (size_t j = 0; j < UPSCALE; ++j)
                    {
                        int16_t tmp_result = lut_result[c][i][j];
                        tmp_result = (tmp_result + 2048) >> 4;
                        tmp_result = my_clamp(tmp_result, (int16_t)0, (int16_t)255);
                        // Scale by UPSCALE instead of a fixed shift so the block origin stays
                        // correct if UPSCALE is ever changed from 4.
                        out_img[row * UPSCALE + i][col * UPSCALE + j][c] = tmp_result;
                    }
                }
            }
        }
    }
}

/**
 * Repacks a 4-bytes-per-pixel image into a 3-bytes-per-pixel image by copying the first
 * three bytes of every pixel and skipping the fourth.
 *
 * Both buffers are treated as tightly packed arrays of `height * width` pixels
 * (file-level dimensions).
 *
 * @param _dst destination, height * width * 3 bytes.
 * @param _src source, height * width * 4 bytes.
 */
void convert_inp(uint8_t *_dst, uint8_t *_src){
    const size_t pixel_count = height * width;
    const uint8_t *in_px = _src;
    uint8_t *out_px = _dst;
    for (size_t p = 0; p < pixel_count; ++p) {
        out_px[0] = in_px[0];
        out_px[1] = in_px[1];
        out_px[2] = in_px[2];
        in_px += 4;
        out_px += 3;
    }
}

/**
 * JNI entry point: runs the LUT super-resolution pipeline on `lr_bitmap` and writes the
 * UPSCALE-times larger result into `sr_bitmap`.
 *
 * Pipeline: drop the 4th byte of every input pixel -> first layer (MID_CH feature maps) ->
 * LAYERS - 1 mid layers (ping-pong between two sets of feature buffers) -> last layer
 * (split into NUM_THREADS row bands) -> convert to 8-bit, add a 4th channel, copy to the
 * output bitmap.
 *
 * Both bitmaps must be RGBA_8888. The input must be at least 2x2 pixels (the layer
 * functions write to neighbouring rows/columns of border pixels) and the output bitmap must
 * be at least `width * UPSCALE` x `height * UPSCALE`. Row strides of both bitmaps are
 * honoured. The LUT arrays are only read; they are released with JNI_ABORT so nothing is
 * copied back.
 *
 * @param first_lut_Java raw bytes viewed as lut_t[MID_CH][256][3][3].
 * @param mid_luts_Java  raw bytes viewed as lut_t[MID_CH][LAYERS - 1][MID_CH][255][3][3].
 * @param last_lut_Java  raw bytes viewed as lut_t[MID_CH][255][UPSCALE][UPSCALE].
 * @return elapsed time of the layer computation in milliseconds, or -1 if the bitmaps are
 *         unusable, a JNI resource could not be acquired, or processing failed (in the last
 *         case a java.lang.RuntimeException is also raised).
 */
extern "C" JNIEXPORT int64_t JNICALL
Java_com_example_expandedconv_MainActivity_doSRLUT(
        JNIEnv *env,
        jobject thiz,
        jobject lr_bitmap,
        jobject sr_bitmap,
        jbyteArray first_lut_Java,
        jbyteArray mid_luts_Java,
        jbyteArray last_lut_Java) {

    // validate both bitmaps before acquiring anything that would need releasing
    AndroidBitmapInfo lr_info;
    AndroidBitmapInfo sr_info;
    if (AndroidBitmap_getInfo(env, lr_bitmap, &lr_info) != ANDROID_BITMAP_RESULT_SUCCESS ||
        AndroidBitmap_getInfo(env, sr_bitmap, &sr_info) != ANDROID_BITMAP_RESULT_SUCCESS)
    {
        return -1;
    }
    if (lr_info.format != ANDROID_BITMAP_FORMAT_RGBA_8888 ||
        sr_info.format != ANDROID_BITMAP_FORMAT_RGBA_8888 ||
        lr_info.width < 2 || lr_info.height < 2 ||
        sr_info.width < lr_info.width * UPSCALE ||
        sr_info.height < lr_info.height * UPSCALE)
    {
        return -1;
    }

    // get lut (second argument is the optional isCopy out-parameter, not a flag)
    jbyte *tmp_first_lut = env->GetByteArrayElements(first_lut_Java, nullptr);
    jbyte *tmp_mid_luts = env->GetByteArrayElements(mid_luts_Java, nullptr);
    jbyte *tmp_last_lut = env->GetByteArrayElements(last_lut_Java, nullptr);

    void *lr_pixels = nullptr;
    void *sr_pixels = nullptr;
    bool lr_locked = false;
    bool sr_locked = false;

    // single cleanup path shared by every exit below
    auto release_all = [&]() {
        if (sr_locked)
        {
            AndroidBitmap_unlockPixels(env, sr_bitmap);
        }
        if (lr_locked)
        {
            AndroidBitmap_unlockPixels(env, lr_bitmap);
        }
        // JNI_ABORT: the tables were only read, so skip the copy-back
        if (tmp_first_lut != nullptr)
        {
            env->ReleaseByteArrayElements(first_lut_Java, tmp_first_lut, JNI_ABORT);
        }
        if (tmp_mid_luts != nullptr)
        {
            env->ReleaseByteArrayElements(mid_luts_Java, tmp_mid_luts, JNI_ABORT);
        }
        if (tmp_last_lut != nullptr)
        {
            env->ReleaseByteArrayElements(last_lut_Java, tmp_last_lut, JNI_ABORT);
        }
    };

    if (tmp_first_lut == nullptr || tmp_mid_luts == nullptr || tmp_last_lut == nullptr)
    {
        release_all();
        return -1;
    }

    lr_locked = AndroidBitmap_lockPixels(env, lr_bitmap, &lr_pixels) == ANDROID_BITMAP_RESULT_SUCCESS;
    sr_locked = lr_locked &&
                AndroidBitmap_lockPixels(env, sr_bitmap, &sr_pixels) == ANDROID_BITMAP_RESULT_SUCCESS;
    if (!sr_locked)
    {
        release_all();
        return -1;
    }

    auto first_lut = reinterpret_cast<lut_t(*)[256][3][3]>(tmp_first_lut);
    auto mid_luts = reinterpret_cast<lut_t(*)[LAYERS - 1][MID_CH][255][3][3]>(tmp_mid_luts);
    auto last_lut = reinterpret_cast<lut_t(*)[255][UPSCALE][UPSCALE]>(tmp_last_lut);

    // init
    width = lr_info.width;
    height = lr_info.height;
    out_width = width * UPSCALE;
    out_height = height * UPSCALE;

    int64_t duration = -1;
    try
    {
        //pre-process
        Mat in_img_mat(static_cast<int>(height), static_cast<int>(width), CV_8UC3);
        auto in_img_ptr = reinterpret_cast<uint8_t *>(in_img_mat.data);

        // convert_inp expects tightly packed rows; repack first if the bitmap rows are padded
        auto lr_bytes = static_cast<uint8_t *>(lr_pixels);
        const size_t lr_row_bytes = width * 4;
        vector<uint8_t> packed_lr;
        if (lr_info.stride != lr_row_bytes)
        {
            packed_lr.resize(lr_row_bytes * height);
            for (size_t row = 0; row < height; ++row)
            {
                memcpy(packed_lr.data() + row * lr_row_bytes, lr_bytes + row * lr_info.stride, lr_row_bytes);
            }
            lr_bytes = packed_lr.data();
        }
        convert_inp(in_img_ptr, lr_bytes);

        // start
        auto start = chrono::high_resolution_clock::now();

        // mid features: two sets of buffers that swap roles after every mid layer
        int16_t *feats_in[MID_CH], *feats_out[MID_CH];
        Mat feats_in_mat[MID_CH], feats_out_mat[MID_CH];
        for (size_t ch = 0; ch < MID_CH; ++ch)
        {
            feats_in_mat[ch] = Mat::zeros(static_cast<int>(height), static_cast<int>(width), CV_16UC3);
            feats_out_mat[ch] = Mat::zeros(static_cast<int>(height), static_cast<int>(width), CV_16UC3);
            feats_in[ch] = reinterpret_cast<int16_t *>(feats_in_mat[ch].data);
            feats_out[ch] = reinterpret_cast<int16_t *>(feats_out_mat[ch].data);
        }

        threadpool th_pool{NUM_THREADS};
        future<void> f_funcs[MID_CH];

        // first layer: one task per output feature map
        for (size_t i = 0; i < MID_CH; ++i)
        {
            f_funcs[i] = th_pool.commit(first_layer_forward, in_img_ptr, feats_in[i], first_lut[i]);
        }
        for (size_t i = 0; i < MID_CH; ++i)
        {
            f_funcs[i].get();
        }

        // mid layers
        for (size_t layer = 0; layer < LAYERS - 1; ++layer)
        {
            for (size_t i = 0; i < MID_CH; ++i)
            {
                f_funcs[i] = th_pool.commit(mid_layers_forward, feats_in, feats_out[i], mid_luts[i][layer]);
            }
            for (size_t i = 0; i < MID_CH; ++i)
            {
                f_funcs[i].get();
            }

            // swap buffer (only after every task of this layer has finished reading feats_in)
            for (size_t i = 0; i < MID_CH; ++i)
            {
                std::swap(feats_in[i], feats_out[i]);
            }
        }

        // output buffer
        Mat sr_img_mat = Mat::zeros(static_cast<int>(out_height), static_cast<int>(out_width), CV_16UC3);
        auto sr_img_ptr = reinterpret_cast<int16_t *>(sr_img_mat.data);

        // last layer: one task per row band, tracked in an array sized for exactly that
        future<void> band_tasks[NUM_THREADS];
        for (size_t i = 0; i < NUM_THREADS; ++i)
        {
            band_tasks[i] = th_pool.commit(last_layer_forward, i, feats_in, sr_img_ptr, last_lut);
        }
        for (size_t i = 0; i < NUM_THREADS; ++i)
        {
            band_tasks[i].get();
        }

        // end
        auto end = chrono::high_resolution_clock::now();
        duration = chrono::duration_cast<chrono::milliseconds>(end - start).count();

        // convert output
        Mat tmp_out_img;
        Mat out_img;
        sr_img_mat.convertTo(tmp_out_img, CV_8UC3);
        cvtColor(tmp_out_img, out_img, COLOR_BGR2BGRA);

        // copy row by row so a padded destination stride is respected
        auto sr_bytes = static_cast<uint8_t *>(sr_pixels);
        const size_t sr_row_bytes = out_width * 4;
        for (size_t row = 0; row < out_height; ++row)
        {
            memcpy(sr_bytes + row * sr_info.stride, out_img.ptr<uint8_t>(static_cast<int>(row)), sr_row_bytes);
        }
    }
    catch (const std::exception &e)
    {
        // A C++ exception must not unwind through the JNI frame: release, then report to Java.
        release_all();
        jclass ex_cls = env->FindClass("java/lang/RuntimeException");
        if (ex_cls != nullptr)
        {
            env->ThrowNew(ex_cls, e.what());
        }
        return -1;
    }

    // release resource
    release_all();

    return duration;
}

/**
 * JNI entry point: loads a .npy asset and returns its array payload as a Java byte[].
 *
 * The asset is opened in buffer mode, exposed as a read-only FILE* via fmemopen and parsed
 * with cnpy. The returned array contains exactly the parsed payload (`arr.num_bytes()`
 * bytes). Previously `filelength` bytes were copied out of the payload buffer, which reads
 * past its end by the size of the .npy header.
 *
 * Failures are reported as a java.lang.RuntimeException (with a null return value) instead
 * of a C++ exception, which must not propagate through a JNI frame.
 *
 * @param assetManagerJava android.content.res.AssetManager to read from.
 * @param nameJava         asset path of the .npy file.
 * @return byte[] holding the raw array data, or null with a pending Java exception.
 */
extern "C"
JNIEXPORT jbyteArray JNICALL
Java_com_example_expandedconv_MainActivity_loadLUT(JNIEnv *env, jobject thiz, jobject assetManagerJava, jstring nameJava) {
    // raises a Java RuntimeException and yields the null result to return alongside it
    auto fail = [env](const char *message) -> jbyteArray {
        jclass ex_cls = env->FindClass("java/lang/RuntimeException");
        if (ex_cls != nullptr)
        {
            env->ThrowNew(ex_cls, message);
        }
        return nullptr;
    };

    const char *filename = env->GetStringUTFChars(nameJava, nullptr);
    if (filename == nullptr)
    {
        return nullptr; // the JVM has already raised OutOfMemoryError
    }

    AAssetManager *assetManager = AAssetManager_fromJava(env, assetManagerJava);
    AAsset *raw_asset = (assetManager != nullptr)
                        ? AAssetManager_open(assetManager, filename, AASSET_MODE_BUFFER)
                        : nullptr;
    // the name is only needed to open the asset; release it on every path
    env->ReleaseStringUTFChars(nameJava, filename);
    if (raw_asset == nullptr)
    {
        return fail("npy_load: Unable to open asset");
    }
    // Declared before the FILE guard so that the stream is closed before the asset.
    std::unique_ptr<AAsset, decltype(&AAsset_close)> asset{raw_asset, &AAsset_close};

    const off_t filelength = AAsset_getLength(asset.get());
    // fmemopen takes a non-const buffer; mode "rb" guarantees it is never written
    void *lut_buffer = const_cast<void *>(AAsset_getBuffer(asset.get()));
    FILE *raw_fp = (lut_buffer != nullptr && filelength > 0)
                   ? fmemopen(lut_buffer, static_cast<size_t>(filelength), "rb")
                   : nullptr;
    if (!raw_fp)
    {
        return fail("npy_load: Unable to open file");
    }
    std::unique_ptr<FILE, decltype(&fclose)> fp{raw_fp, &fclose};

    try
    {
        cnpy::NpyArray arr = cnpy::load_the_npy_file(fp.get());

        // size of the parsed payload, not of the file (which also contains the header)
        const jsize payload_bytes = static_cast<jsize>(arr.num_bytes());
        auto *p_lut = arr.data<int8_t>();
        jbyteArray lut = env->NewByteArray(payload_bytes);
        if (lut == nullptr)
        {
            return nullptr; // allocation failed; Java exception already pending
        }
        env->SetByteArrayRegion(lut, 0, payload_bytes, p_lut);
        return lut;
    }
    catch (const std::exception &e)
    {
        return fail(e.what());
    }
}

