#ifndef C_IMPLEMENT_NANOVOID_APP_H
#define C_IMPLEMENT_NANOVOID_APP_H

#include "nanovoid.h"
#include <cstring>
#include <math.h>

// NOTE(review): EPS and N are object-like macros defined in a header, so they
// leak into every translation unit that includes it. `N` in particular will
// rewrite any later identifier spelled `N` (e.g. a template parameter or a
// local variable). Their consumers are not visible here; confirm before
// renaming them or turning them into typed constants.
#define EPS 1e-1
#define N 1

class Coordinate2d3c {
    // 2-dimensional, 3-channel coordinate of a pixel.
    // The flat value table stores the channels back to back:
    //     c1_1 ... c1_n, c2_1 ... c2_n, c3_1 ... c3_n
    // Each channel holds `size` items laid out row-major with `len` items per
    // row, so: item = x * len + y + channel * size   (channel is 0-based).
public:
    int x, y;

    Coordinate2d3c(const Coordinate2d3c &) = default;

    Coordinate2d3c(int _x, int _y) :
            x(_x), y(_y) {}

    // Decode a flat item index into (x, y). The item may come from any
    // channel; the channel offset is discarded by the modulo.
    //   len  - items per row, i.e. the range of y (0 <= y < len)
    //   size - items per channel (the channel stride)
    // Both len and size must be non-zero.
    inline void from_item(uint item, uint len, uint size) {
        uint c = item % size;
        x = (int) (c / len);
        y = (int) (c % len);
    }

    // Encode (x, y) as the flat index inside the 0-based `channel`.
    // All arithmetic is unsigned, matching the per-channel helpers below.
    inline uint to_item(uint channel, uint len, uint size) const {
        return ((uint) x) * len + ((uint) y) + size * channel;
    }

    // Flat index of (x, y) in channel 1 (`size` is unused for channel 1 but
    // kept so the three helpers share one signature).
    inline uint to_item_c1(uint len, uint size) const {
        return to_item(0, len, size);
    }

    // Flat index of (x, y) in channel 2.
    inline uint to_item_c2(uint len, uint size) const {
        return to_item(1, len, size);
    }

    // Flat index of (x, y) in channel 3.
    inline uint to_item_c3(uint len, uint size) const {
        return to_item(2, len, size);
    }

};

class ParameterSet {
    // Plain bundle of the 14 scalar parameters consumed by the NanoVoid
    // steppers. The constructor seeds every field with the same value;
    // callers then overwrite the individual fields they need.
public:
    valueType energy_v0, energy_i0, kBT0, kappa_v0, kappa_i0, kappa_eta0, r_bulk0, r_surf0, p_casc0, bias0, vg0, diff_v0, diff_i0, L0;

    // Initializers are listed in declaration order.
    ParameterSet(valueType init_value) :
            energy_v0(init_value), energy_i0(init_value), kBT0(init_value),
            kappa_v0(init_value), kappa_i0(init_value), kappa_eta0(init_value),
            r_bulk0(init_value), r_surf0(init_value), p_casc0(init_value),
            bias0(init_value), vg0(init_value),
            diff_v0(init_value), diff_i0(init_value),
            L0(init_value) {}
};

// NanoVoid stepper implementing the OneStep interface (see the methods marked
// `override` below). Each item carries 3 channels: cv, ci and eta.
class NanoVoidOneStep : public OneStep {
    static const uint vals_len = 13 * 3; // first 13 is for cv, second 13 for ci, third 13 for eta
    static const uint lap_len_1st = 5;   // number of points in the Laplacian stencil (lapw)
    static const uint lap_len_2nd = 13;  // number of points in the bi-Laplacian stencil (laplapw)
    static const uint n_channels = 3; // like n_grains

    // NOTE(review): lsh_r / lshK / lshL are taken from the constructor and look
    // like locality-sensitive-hashing settings; confirm their exact role.
    valueType lsh_r; // = 1e-4;

    // debug switch
    static const uint debug_on = 0;

    // Stencil offsets: entry k addresses the neighbour (dx[k], dy[k]). Per the
    // values noted below: centre, 4 axis neighbours, 4 diagonals, then the 4
    // axis neighbours at distance 2.
    static const int dy[]; // = {0, 1, 0,-1, 0, 1,-1, 1,-1, 2, 0,-2, 0};
    static const int dx[]; // = {0, 0, 1, 0,-1, 1, 1,-1,-1, 0, 2, 0,-2};

    // Stencil weights over the offsets above. Assuming the out-of-line
    // definitions match the values noted here: lapw is the 5-point discrete
    // Laplacian and laplapw the 13-point discrete bi-Laplacian.
    static const valueType laplapw[]; // = {20,-8,-8,-8,-8,2,2,2,2,1,1,1,1};
    static const valueType lapw[]; // = {-4, 1, 1, 1, 1};

    // Nx and Ny are the image dimensions; size is the number of pixels per
    // channel, size = Nx * Ny (e.g. Nx = Ny = 128, size = 128 * 128).
    // Older code used `size` to mean Nx; update any remaining uses to this meaning.
    int Nx, Ny, size;

    uint lshK, lshL;

    // Model parameters (fields listed in ParameterSet).
    ParameterSet p;

public:
    NanoVoidOneStep(int _Nx, int _Ny, ParameterSet _p, valueType _lsh_r, uint _lshK, uint _lshL);

    // OneStep hooks. NOTE(review): judging by the names, grab_vals gathers the
    // vals_len stencil values around item c, forward_one_step writes the
    // updated values into new_v, and the *_n_list hooks maintain the neighbour
    // list in bucket t; confirm against the definitions.
    void grab_vals(uint c, valueType *value_table, valueType *vals) override;

    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;

    void merge_neighbor_into_n_list(uint c, PNBucket *t) override;

    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    void move_out_neighbor_from_n_list(uint c, PNBucket *t) override;

    // log_with_mask, log_with_mask_single, masked_fill are helper functions used
    // in the calculation of forward_one_step.
    void log_with_mask(valueType *mat, valueType eps, uint len);

    valueType log_with_mask_single(valueType p, valueType eps);

    void masked_fill(valueType *mat, int *mask, valueType eps, uint len);

    // encode_from_img / decode_to_img (and the *_torch variants) convert between
    // the internal value table and external image buffers.
    // NOTE(review): ownership of the arrays returned by the decode_* functions
    // is not visible here; confirm who frees them.
    void encode_from_img(valueType ***img);

    void encode_from_img_torch(valueType *cv, valueType *ci, valueType *eta);

    valueType ***decode_to_img();

    valueType **decode_to_img_torch();
};



class NanoVoidNormal {
    // Stand-alone NanoVoid stepper that owns its own old/new value tables.
    // Unlike NanoVoidOneStep it does not derive from OneStep and carries no
    // lsh_* parameters.
    uint value_table_size;              // the value table size (for initialization)
    uint num_items;                     // the number of items
    valueType* old_v;                   // old value table
    valueType* new_v;                   // new value table

    static const uint vals_len = 13 * 3; // first 13 is for cv, second 13 for ci, third 13 for eta
    static const uint lap_len_1st = 5;   // number of points in the Laplacian stencil (lapw)
    static const uint lap_len_2nd = 13;  // number of points in the bi-Laplacian stencil (laplapw)

    // Stencil offsets: entry k addresses the neighbour (dx[k], dy[k]). Per the
    // values noted below: centre, 4 axis neighbours, 4 diagonals, then the 4
    // axis neighbours at distance 2.
    static const int dy[]; // = {0, 1, 0,-1, 0, 1,-1, 1,-1, 2, 0,-2, 0};
    static const int dx[]; // = {0, 0, 1, 0,-1, 1, 1,-1,-1, 0, 2, 0,-2};

    // Stencil weights: lapw is the 5-point Laplacian, laplapw the 13-point
    // bi-Laplacian (assuming the definitions match the values noted here).
    static const valueType laplapw[]; // = {20,-8,-8,-8,-8,2,2,2,2,1,1,1,1};
    static const valueType lapw[]; // = {-4, 1, 1, 1, 1};

    int size;  // here the size is actually the range of x, i.e. 128 for 128*128*3
    static const uint verbose = 0;

    // Model parameters (fields listed in ParameterSet).
    ParameterSet p;

public:
    NanoVoidNormal(uint _size, ParameterSet _p);
    ~NanoVoidNormal();

    // The destructor is user-defined and the object holds raw value-table
    // pointers. The implicit member-wise copy would therefore make two objects
    // share old_v/new_v, which is a double free if the destructor releases
    // them. Copying is disabled to make that mistake a compile error.
    NanoVoidNormal(const NanoVoidNormal &) = delete;
    NanoVoidNormal &operator=(const NanoVoidNormal &) = delete;

    void grab_vals(uint item, valueType *value_table, valueType *vals);

    void forward_one_step(valueType *vals, uint c, valueType *new_v);

    // Helpers used by forward_one_step.
    void log_with_mask(valueType *mat, valueType eps, uint len);

    valueType log_with_mask_single(valueType p, valueType eps);

    void masked_fill(valueType *mat, int *mask, valueType eps, uint len);

    // Conversion between the value tables and a 3-D image array.
    // NOTE(review): ownership of the array returned by decode_to_img is not
    // visible here; confirm who frees it.
    void encode_from_img(valueType ***img);

    valueType ***decode_to_img();

    void next();
};

// Back-propagation counterpart of NanoVoidOneStep, also implementing the
// OneStep interface (see the `override` methods below). Each item carries 6
// channels: cv, ci, eta and their loss gradients dloss_dcv, dloss_dci, dloss_deta.
class NanoVoidOneBack : public OneStep {
    static const uint vals_len = 13 * 3 * 2; // first 13 is for cv, second 13 for ci, third 13 for eta
                                             // fourth 13 is for dloss_dcv, fifth 13 is for dloss_dci, sixth 13 is for dloss_deta
    static const uint lap_len_1st = 5;   // number of points in the Laplacian stencil (lapw)
    static const uint lap_len_2nd = 13;  // number of points in the bi-Laplacian stencil (laplapw)
    static const uint n_channels = 6; // like n_grains. first 3 is cv ci eta, last 3 is dloss_dcv, dloss_dci, dloss_deta

    // NOTE(review): lsh_r / lshK / lshL are taken from the constructor and look
    // like locality-sensitive-hashing settings; confirm their exact role.
    valueType lsh_r; // = 1e-4;

    // debug switch
    static const uint debug_on = 0;

    // Stencil offsets: entry k addresses the neighbour (dx[k], dy[k]). Per the
    // values noted below: centre, 4 axis neighbours, 4 diagonals, then the 4
    // axis neighbours at distance 2.
    static const int dy[]; // = {0, 1, 0,-1, 0, 1,-1, 1,-1, 2, 0,-2, 0};
    static const int dx[]; // = {0, 0, 1, 0,-1, 1, 1,-1,-1, 0, 2, 0,-2};

    // Stencil weights: lapw is the 5-point Laplacian, laplapw the 13-point
    // bi-Laplacian (assuming the definitions match the values noted here).
    static const valueType laplapw[]; // = {20,-8,-8,-8,-8,2,2,2,2,1,1,1,1};
    static const valueType lapw[]; // = {-4, 1, 1, 1, 1};

    // Nx and Ny are the image dimensions; size is the number of pixels per
    // channel, size = Nx * Ny (e.g. Nx = Ny = 128, size = 128 * 128).
    // Older code used `size` to mean Nx; update any remaining uses to this meaning.
    int Nx, Ny, size;

    uint lshK, lshL;

    // p: model parameters. NOTE(review): dp presumably accumulates the
    // derivative with respect to each field of p (see
    // accumulate_weight_derivative / decode_derivative); confirm.
    ParameterSet p, dp;

public:
    NanoVoidOneBack(int _Nx, int _Ny, ParameterSet _p, valueType _lsh_r, uint _lshK, uint _lshL);

    // OneStep hooks (see NanoVoidOneStep for the forward counterparts).
    void grab_vals(uint item, valueType *value_table, valueType *vals) override;

    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;

    void forward_one_step_vals(valueType *vals, valueType *new_v); // calculate vals specifically

    void merge_neighbor_into_n_list(uint item, PNBucket *t) override;

    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    void move_out_neighbor_from_n_list(uint item, PNBucket *t) override;

    // Numeric helpers used during the step computation.
    void log_with_mask(valueType *mat, valueType eps, uint len);

    valueType log_with_mask_single(valueType p, valueType eps);

    void masked_fill(valueType *mat, int *mask, valueType eps, uint len);

    // Load state (img) and loss gradients (dloss) into the value table, and
    // read results back out.
    // NOTE(review): ownership of the arrays returned by the decode_* functions
    // is not visible here; confirm who frees them.
    void encode_from_img(valueType ***img, valueType ***dloss);

    void encode_from_img_torch(valueType *cv, valueType *ci, valueType *eta, valueType *dloss_cv, valueType *dloss_ci, valueType *dloss_eta);

    valueType ***decode_to_img();

    valueType **decode_to_img_torch();

    // Parameter-derivative bookkeeping (presumably operates on dp; confirm).
    void accumulate_weight_derivative(valueType *vals, uint c);

    void print_derivative();

    valueType *decode_derivative();
};

class NanovoidOneBackNormal {
    // Stand-alone back-propagation stepper that owns its own old/new value
    // tables. Unlike NanoVoidOneBack it does not derive from OneStep and
    // carries no lsh_* parameters.
    uint value_table_size;          // value table size
    uint num_items;                 // the number of items, actually is size here
    valueType* old_v;               // old value table
    valueType* new_v;               // new value table

    static const uint vals_len = 13 * 3 * 2; // first 13 is for cv, second 13 for ci, third 13 for eta
                                             // fourth 13 is for dloss_dcv, fifth 13 is for dloss_dci, sixth 13 is for dloss_deta
    static const uint lap_len_1st = 5;   // number of points in the Laplacian stencil (lapw)
    static const uint lap_len_2nd = 13;  // number of points in the bi-Laplacian stencil (laplapw)
    static const uint n_channels = 6; // like n_grains. first 3 is cv ci eta, last 3 is dloss_dcv, dloss_dci, dloss_deta

    // debug switch
    static const uint debug_on = 0;

    // Stencil offsets: entry k addresses the neighbour (dx[k], dy[k]). Per the
    // values noted below: centre, 4 axis neighbours, 4 diagonals, then the 4
    // axis neighbours at distance 2.
    static const int dy[]; // = {0, 1, 0,-1, 0, 1,-1, 1,-1, 2, 0,-2, 0};
    static const int dx[]; // = {0, 0, 1, 0,-1, 1, 1,-1,-1, 0, 2, 0,-2};

    // Stencil weights: lapw is the 5-point Laplacian, laplapw the 13-point
    // bi-Laplacian (assuming the definitions match the values noted here).
    static const valueType laplapw[]; // = {20,-8,-8,-8,-8,2,2,2,2,1,1,1,1};
    static const valueType lapw[]; // = {-4, 1, 1, 1, 1};

    // Nx and Ny are the image dimensions; size is the number of pixels per
    // channel, size = Nx * Ny (e.g. Nx = Ny = 128, size = 128 * 128).
    // Older code used `size` to mean Nx; update any remaining uses to this meaning.
    int Nx, Ny, size;

    // p: model parameters. NOTE(review): dp presumably accumulates the
    // derivative with respect to each field of p; confirm.
    ParameterSet p, dp;

public:
    NanovoidOneBackNormal(int _Nx, int _Ny, ParameterSet _p);
    ~NanovoidOneBackNormal();

    // The destructor is user-defined and the object holds raw value-table
    // pointers. The implicit member-wise copy would therefore make two objects
    // share old_v/new_v, which is a double free if the destructor releases
    // them. Copying is disabled to make that mistake a compile error.
    NanovoidOneBackNormal(const NanovoidOneBackNormal &) = delete;
    NanovoidOneBackNormal &operator=(const NanovoidOneBackNormal &) = delete;

    void grab_vals(uint item, valueType *value_table, valueType *vals);

    void forward_one_step(valueType *vals, uint c, valueType *new_v);

    void forward_one_step_vals(valueType *vals, valueType *new_v); // calculate vals specifically

    // Numeric helpers used during the step computation.
    void log_with_mask(valueType *mat, valueType eps, uint len);

    valueType log_with_mask_single(valueType p, valueType eps);

    void masked_fill(valueType *mat, int *mask, valueType eps, uint len);

    // Load state (img) and loss gradients (dloss); read results back out.
    // NOTE(review): ownership of the array returned by decode_to_img is not
    // visible here; confirm who frees it.
    void encode_from_img(valueType ***img, valueType ***dloss);

    valueType ***decode_to_img();

    // Parameter-derivative bookkeeping (presumably operates on dp; confirm).
    void accumulate_weight_derivative(valueType *vals, uint c);

    void print_derivative();

    void next();

};

// Presumably allocates a zero-filled 3-D array sized by Nx, Ny and channel.
// TODO confirm the index order and that it pairs with delete_3d_array.
valueType *** init_zero_mat(uint Nx, uint Ny, uint channel);

// Presumably releases an array created by init_zero_mat. The dimensions are
// passed in, so they must match the ones used at allocation (confirm).
void delete_3d_array(valueType*** del_mat, uint Nx, uint Ny, uint channel);

// Load a width x height image from a PNG file / from a data file.
// NOTE(review): the ownership of the returned array and the behaviour on a
// missing or malformed file are not visible here; confirm.
valueType *** read_from_png(const char* filepath, int width, int height);
valueType *** read_from_data(const char* filepath, int width, int height);

// Write `input` as a width x height image, labelled with `title`.
// NOTE(review): how filepath and dir_name combine, and the meaning of the
// return value, are not visible here. `title` is non-const; making it
// `const char*` requires changing the definition too (C++ mangles the type).
int writeImage(const char* filepath, const char* dir_name, int width, int height, char* title, valueType*** input);

// Presumably computes the loss gradient of img against ground_truth. The loss
// function and the ownership of the returned array are not visible; confirm.
valueType *** get_dloss(int width, int height, int channel, valueType *** img, valueType *** ground_truth);

#endif //C_IMPLEMENT_NANOVOID_APP_H

