#include "nanovoid_app.h"
#include "png.h"
#include "cstring"
#include "iostream"
#include <fstream>
#include <sstream>
#include <string>
#include <algorithm>
#include <iterator>
#include <cassert>
#include <iomanip>
#define PNG_BYTES_TO_CHECK 4

using namespace std;

// Allocate an Nx x Ny x channel array of valueType as nested heap arrays
// (grid[x] is an array of Ny cell pointers, grid[x][y] an array of `channel`
// values), with every value set to zero.
// The caller owns the result; delete_3d_array releases this structure.
valueType*** init_zero_mat(uint Nx, uint Ny, uint channel) {
    valueType*** grid = new valueType**[Nx];
    for (uint x = 0; x < Nx; ++x) {
        grid[x] = new valueType*[Ny];
        for (uint y = 0; y < Ny; ++y) {
            valueType* cell = new valueType[channel];
            std::fill(cell, cell + channel, static_cast<valueType>(0.0));
            grid[x][y] = cell;
        }
    }
    return grid;
}

// Release an array created by init_zero_mat(Nx, Ny, channel).
// Every level of that structure is allocated with new[], so every level must
// be released with delete[]; using scalar delete on a new[] allocation is
// undefined behavior.
// `channel` is not needed (the innermost arrays are freed without their
// length); it is kept for signature symmetry with init_zero_mat.
// A NULL `del_mat` (e.g. a failed read_from_png) is a no-op.
void delete_3d_array(valueType*** del_mat, uint Nx, uint Ny, uint channel) {
    (void)channel;
    if (del_mat == NULL) {
        return;
    }
    for (uint i = 0; i < Nx; i++) {
        for (uint j = 0; j < Ny; j++) {
            delete[] del_mat[i][j];
        }
        delete[] del_mat[i];
    }
    delete[] del_mat;
}

// Load the PNG at `filepath` into a newly allocated width x height x 3 matrix
// indexed [x][y][c] (as created by init_zero_mat).
// Only the first sample of each pixel is used: R for colour images, the gray
// level for grayscale ones. It is scaled from 0..255 to 0..1 and replicated
// into all three channels.
//
// Returns NULL (after printing a message) if any of these happen:
//   - the file cannot be opened;
//   - libpng cannot allocate its structures;
//   - the data is not a decodable PNG;
//   - the image size differs from width x height.
// On success the caller owns the result and can free it with
// delete_3d_array(mat, width, height, 3).
valueType *** read_from_png(const char* filepath, int width, int height) {
    FILE* fp = fopen(filepath, "rb");
    if (fp == NULL) {
        printf("load_png_image err:fp == NULL\n");
        return 0;
    }

    png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
    if (png_ptr == NULL) {
        printf("load_png_image err: could not allocate read struct\n");
        fclose(fp);
        return 0;
    }
    png_infop info_ptr = png_create_info_struct(png_ptr);
    if (info_ptr == NULL) {
        printf("load_png_image err: could not allocate info struct\n");
        png_destroy_read_struct(&png_ptr, NULL, NULL);
        fclose(fp);
        return 0;
    }

    // libpng reports decode errors (including a bad signature) by longjmp-ing
    // back here with a non-zero value.
    // Everything the error branch touches (fp, png_ptr, info_ptr) is assigned
    // before setjmp and never modified afterwards, so none of it needs to be
    // volatile.
    // No C++ object with a destructor is created between here and
    // png_read_png, which keeps setjmp/longjmp well-defined in C++.
    if (setjmp(png_jmpbuf(png_ptr))) {
        printf("load_png_image err: failed to decode %s\n", filepath);
        png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
        fclose(fp);
        return 0;
    }

    png_init_io(png_ptr, fp);
    // Normalize every input format to 8-bit samples without alpha:
    //   EXPAND      palette -> RGB, 1/2/4-bit gray -> 8-bit, tRNS -> alpha
    //   STRIP_16    16-bit samples -> 8-bit
    //   STRIP_ALPHA drop any alpha channel
    // After this each row holds `channels` bytes per pixel (gray: 1, RGB: 3).
    png_read_png(png_ptr, info_ptr,
                 PNG_TRANSFORM_EXPAND | PNG_TRANSFORM_STRIP_16 | PNG_TRANSFORM_STRIP_ALPHA,
                 NULL);

    // png_read_png updates info_ptr, so these describe the transformed rows.
    const int w = (int)png_get_image_width(png_ptr, info_ptr);
    const int h = (int)png_get_image_height(png_ptr, info_ptr);
    const int bit_width = png_get_bit_depth(png_ptr, info_ptr);
    const int channels = png_get_channels(png_ptr, info_ptr);

    // Runtime check rather than assert: under NDEBUG a mismatch would
    // otherwise hand the caller a matrix of the wrong size.
    if (w != width || h != height) {
        printf("load_png_image err: image is %dx%d, expected %dx%d\n", w, h, width, height);
        png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
        fclose(fp);
        return 0;
    }

    printf("bit width %d\n", bit_width);
    png_bytep* row_pointers = png_get_rows(png_ptr, info_ptr);
    valueType*** org_mat = init_zero_mat(w, h, 3);
    for (int y = 0; y < h; y++) {
        const png_bytep row = row_pointers[y];
        for (int x = 0; x < w; x++) {
            const valueType v = ((valueType)row[x * channels]) / 255;
            org_mat[x][y][0] = v;
            org_mat[x][y][1] = v;
            org_mat[x][y][2] = v;
        }
    }

    png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
    fclose(fp);
    printf("Finish reading...\n");
    return org_mat;
}

int writeImage(const char* filepath, const char* dir_name, int width, int height, char* title, valueType*** input){
    int code = 0;
    char* filename_arr[3] = {"cv", "ci", "eta"};
    for (int fv = 0; fv < 3; fv++)
    {
        FILE *fp = NULL;
        png_structp png_ptr = NULL;
        png_infop info_ptr = NULL;
        png_bytep row = NULL;

        string prefix = filepath;
        string append = filename_arr[fv];
        string dir = dir_name;
        string fullname = dir_name + append + "_" + prefix;
        fp = fopen(fullname.c_str(), "wb");
        if (fp == NULL) {
            fprintf(stderr, "Could not open file %s for writing\n", fullname.c_str());
            code = 1;
            if (fp != NULL) fclose(fp);
            if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
            if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp) NULL);
            if (row != NULL) free(row);
            return code;
        }

        png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
        if (png_ptr == NULL) {
            fprintf(stderr, "Could not allocate write struct\n");
            code = 1;
            if (fp != NULL) fclose(fp);
            if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
            if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp) NULL);
            if (row != NULL) free(row);
            return code;
        }

        info_ptr = png_create_info_struct(png_ptr);
        if (info_ptr == NULL) {
            fprintf(stderr, "Could not allocate info struct\n");
            code = 1;
            if (fp != NULL) fclose(fp);
            if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
            if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp) NULL);
            if (row != NULL) free(row);
            return code;
        }

        if (setjmp(png_jmpbuf(png_ptr))) {
            fprintf(stderr, "Error during png creation\n");
            code = 1;
            if (fp != NULL) fclose(fp);
            if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
            if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp) NULL);
            if (row != NULL) free(row);
            return code;
        }

        png_init_io(png_ptr, fp);

        png_set_IHDR(png_ptr, info_ptr, width, height,
                     8, PNG_COLOR_TYPE_GRAY, PNG_INTERLACE_NONE,
                     PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);

        if (title != NULL) {
            png_text title_text;
            title_text.compression = PNG_TEXT_COMPRESSION_NONE;
            title_text.key = "Title";
            title_text.text = title;
            png_set_text(png_ptr, info_ptr, &title_text, 1);
        }

        png_write_info(png_ptr, info_ptr);

        row = (png_bytep) malloc(1 * width * sizeof(png_byte));
//    memset(row, 0, 1 * width * sizeof(png_byte));

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // should generate 3 image here, one for cv, one for ci, one for eta
                row[x] = (png_byte)(input[x][y][fv] * 255);
            }
            png_write_row(png_ptr, row);
        }
        png_write_end(png_ptr, NULL);

        if (fp != NULL) fclose(fp);
        if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
        if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp) NULL);
        if (row != NULL) free(row);
    }


    return code;
}

// Parse a whitespace-separated text matrix from `filepath` into a newly
// allocated width x height x 3 matrix created by init_zero_mat. The caller
// owns the result.
//
// Layout: lines are channel-major.
//   - Lines [0, height) fill channel 0.
//   - Lines [height, 2*height) fill channel 1.
//   - Lines [2*height, 3*height) fill channel 2.
// Value i of line L goes to org_mat[L % height][i][L / height].
// Short lines and a missing file leave the remaining cells 0. Lines beyond
// 3*height are ignored.
// NOTE(review): this stores [row][col], while init_zero_mat allocates
// [width][height] and read_from_png/writeImage use [x][y]. The two agree only
// when width == height. For other sizes, cells that fall outside the
// allocation are skipped rather than written out of bounds. Confirm which
// layout is intended.
//
// Prints how many parsed values are neither 0 nor 1 ("ring pixel").
valueType *** read_from_data(const char* filepath, int width, int height) {
    valueType*** org_mat = init_zero_mat(width, height, 3);
    if (width <= 0 || height <= 0) {
        // Nothing to fill; also avoids the `% height` by zero below.
        return org_mat;
    }

    std::ifstream infile(filepath);
    if (!infile) {
        std::cerr << "Unable to open file: " << filepath << '\n';
    }

    int line_num = 0;
    int total_ring = 0;
    std::string line;
    std::vector<valueType> v;
    while (std::getline(infile, line)) {
        const int row = line_num % height;
        const int channel = line_num / height;
        line_num++;
        if (channel >= 3) {
            break;  // only 3 channels are allocated
        }

        v.clear();
        std::istringstream ss(line);
        std::copy(std::istream_iterator<valueType>(ss),
                  std::istream_iterator<valueType>(),
                  std::back_inserter(v));

        // Never read past the values actually present on this line.
        const int n_cols = std::min(width, static_cast<int>(v.size()));
        for (int i = 0; i < n_cols; i++) {
            if (row < width && i < height) {
                org_mat[row][i][channel] = v[i];
            }
            if (v[i] != 1.0 && v[i] != 0.0)
                total_ring++;
        }
    }
    printf("ring pixel: %d\n", total_ring);
    return org_mat;
}

// Write channels 0..2 of `input` to `filepath` as text.
// For each channel there are `height` lines of `width` space-separated values
// in fixed notation with 8 decimals. Line i, column j holds input[i][j][c].
// This is the same channel-major layout that read_from_data parses.
// NOTE(review): input[i][j] uses i < height and j < width, whereas
// init_zero_mat allocates [width][height]. The two agree only when
// width == height. Confirm the intended layout.
// dir_name and title are unused; they are kept for signature parity with
// writeImage.
// Returns 0 on success, 1 if the file cannot be opened or a write fails.
int writeData(const char* filepath, const char* dir_name, int width, int height, char* title, valueType*** input){
    (void)dir_name;
    (void)title;

    std::ofstream myfile(filepath, std::ios_base::binary);
    if (!myfile.good()) {
        std::cerr << "Unable to open file: " << filepath << '\n';
        return 1;
    }

    // Formatting flags are sticky; setting them once applies to every value.
    myfile << std::fixed << std::setprecision(8);
    for (int c = 0; c < 3; c++) {
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                myfile << input[i][j][c] << ' ';
            }
            myfile << '\n';  // no per-line flush
        }
    }

    myfile.flush();
    if (!myfile) {
        std::cerr << "Error while writing file: " << filepath << '\n';
        return 1;
    }
    return 0;
}

// Gradient of the summed squared-error loss
//   L = sum_{i,j,c} (img[i][j][c] - ground_truth[i][j][c])^2
// with respect to img, element-wise: dL/dimg = 2 * (img - ground_truth).
// Both inputs are indexed [x][y][c] with x < width, y < height, c < channel.
// L is printed ("sum loss") but not returned. A "mean" reduction would divide
// L (and its gradient) by width * height * channel.
//
// Returns a newly allocated width x height x max(channel, 3) matrix owned by
// the caller. It has at least 3 channels, as before, so callers indexing
// channels 0..2 keep working. Channels at or beyond `channel` stay 0.
valueType *** get_dloss(int width, int height, int channel, valueType *** img, valueType *** ground_truth) {
    // Allocating exactly 3 channels overflowed when channel > 3.
    const int alloc_channels = channel > 3 ? channel : 3;
    valueType*** dloss = init_zero_mat(width, height, alloc_channels);

    // reduction = "sum"
    valueType sum_loss = 0.0;
    for (int i = 0; i < width; i++) {
        for (int j = 0; j < height; j++) {
            for (int c = 0; c < channel; c++) {
                const valueType diff = img[i][j][c] - ground_truth[i][j][c];
                sum_loss += diff * diff;
                dloss[i][j][c] = 2 * diff;
            }
        }
    }
    printf("sum loss: %f\n", sum_loss);
    return dloss;
}
