#ifndef C_IMPLEMENT_GRAIN_GROWTH_H
#define C_IMPLEMENT_GRAIN_GROWTH_H

#include "nanovoid.h"

// Forward time-stepping kernel for the grain-growth model, plugged into the
// generic OneStep driver (presumably declared in nanovoid.h).
class GrainGrowthOneStep : public OneStep
{
private:
    // Neighbour offsets of the stencil, defined out of line:
    //   dx = {0, 1, 0, -1, 0}
    //   dy = {0, 0, 1, 0, -1}
    static const int dx[];
    static const int dy[];

    // Stencil length and Laplacian weights, defined out of line:
    //   lapw = {-4, 1, 1, 1, 1}
    static const int laplen;
    static const valueType lapw[];

    // Grid extent; `size` is presumably Nx * Ny (assigned in the constructor).
    int Nx;
    int Ny;
    int size;

    // Number of grain order parameters, plus the lshK / lshL settings
    // (presumably the K/L parameters of the hashing scheme — confirm).
    uint n_grains;
    uint lshK;
    uint lshL;

    // Grid spacing, model coefficients and time-step parameters, assigned in
    // the constructor definition (h2 / dtimeL are presumably derived values).
    valueType h;
    valueType h2;
    valueType A;
    valueType B;
    valueType updateL;
    valueType kappa;
    valueType dtime;
    valueType dtimeL;

public:
    GrainGrowthOneStep(int _Nx, int _Ny, uint _n_grains,
                       uint _lshK, uint _lshL,
                       valueType _h, valueType _A, valueType _B, valueType _L,
                       valueType _kappa, valueType _dtime, valueType _lsh_r);

    // --- OneStep interface -------------------------------------------------
    void grab_vals(uint c, valueType *value_table, valueType *vals) override;
    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;
    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    void merge_neighbor_into_n_list(uint c, PNBucket *t) override;
    void move_out_neighbor_from_n_list(uint c, PNBucket *t) override;

    // --- Dense image <-> internal representation ---------------------------
    void encode_from_img(valueType *img);
    // NOTE(review): returned buffer's ownership is decided in the .cpp — confirm.
    valueType *decode_to_img();
};




// Backward counterpart of GrainGrowthOneStep: implements the OneStep interface
// and also accumulates derivatives with respect to the model parameters
// (presumably the adjoint pass used for training — confirm in the .cpp).
class GrainGrowthOneBack : public OneStep
{
private:
    // Stencil tables, defined out of line:
    //   dx   = {0, 1, 0, -1, 0}
    //   dy   = {0, 0, 1, 0, -1}
    //   lapw = {-4, 1, 1, 1, 1}
    static const int dx[];
    static const int dy[];
    static const int laplen;
    static const valueType lapw[];

    // Number of values packed per cell by grab_vals(); assigned in the
    // constructor definition.
    // NOTE(review): an earlier comment here ("5 * 2 * 2", "first 13 is for cv,
    // second 13 for ci ...") was carried over from the nanovoid model, whose
    // cv/ci fields this class does not have. Confirm the real layout in the .cpp.
    uint vals_len;

    // Stencil footprints: 5 points for one Laplacian application, and 13
    // points (presumably the radius-2 diamond) for two composed applications.
    static const uint lap_len_1st = 5;
    static const uint lap_len_2nd = 13;

    // Grid extent; `size` is presumably Nx * Ny.
    int Nx;
    int Ny;
    int size;

    // Number of grain order parameters and the lshK / lshL settings.
    uint n_grains;
    uint lshK;
    uint lshL;

    // Model / time-step parameters (same meaning as in GrainGrowthOneStep).
    valueType h;
    valueType h2;
    valueType A;
    valueType B;
    valueType updateL;
    valueType kappa;
    valueType dtime;
    valueType dtimeL;

    // Accumulators paired one-to-one with the parameters above; presumably the
    // loss gradient for each, filled by accumulate_weight_derivative().
    valueType dh;
    valueType dh2;
    valueType dA;
    valueType dB;
    valueType dupdateL;
    valueType dkappa;
    valueType ddtime;
    valueType ddtimeL;

public:
    GrainGrowthOneBack(int _Nx, int _Ny, uint _n_grains,
                       uint _lshK, uint _lshL,
                       valueType _h, valueType _A, valueType _B, valueType _L,
                       valueType _kappa, valueType _dtime, valueType _lsh_r);

    // --- OneStep interface -------------------------------------------------
    void grab_vals(uint c, valueType *value_table, valueType *vals) override;
    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;
    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    void merge_neighbor_into_n_list(uint c, PNBucket *t) override;
    void move_out_neighbor_from_n_list(uint c, PNBucket *t) override;

    // --- Dense image <-> internal representation ---------------------------
    // Takes both the state image and the incoming loss gradient `dloss`.
    void encode_from_img(valueType *img, valueType *dloss);
    // NOTE(review): returns an array of buffers; ownership is set in the .cpp.
    valueType **decode_to_img();

    // --- Calculation helpers -----------------------------------------------
    // Runs one step on an already-packed `vals` buffer, writing into `new_v`.
    void forward_one_step_vals(valueType *vals, valueType *new_v);

    // --- Parameter gradients -----------------------------------------------
    // Adds cell `c`'s contribution to the d* accumulators.
    void accumulate_weight_derivative(valueType *vals, uint c);

    void print_derivative();

    valueType *decode_derivative();
};

// ---- Loss helpers (definitions live in the implementation file) ------------

// NOTE(review): presumably writes the elementwise MSE gradient of `eta` against
// `eta_ref` into `dloss`; all three buffers hold `size` elements — confirm.
void calculate_mse_loss(valueType *eta, valueType *eta_ref, valueType *dloss, uint size);

// NOTE(review): presumably returns the summed squared error over `size` elements.
valueType sum_mse_loss(valueType *eta, valueType *eta_ref, uint size);

// NOTE(review): presumably returns the sum of the first `size` entries of `eta`.
valueType sum_mtx(valueType *eta, uint size);

// ---- File readers: the grid/grain/step counts are reported via out-params ---
valueType *read_data_file(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step);
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step);
valueType *read_init_state_1(const char *filename, uint &Nx, uint &Ny, uint &n_grains);


struct ReturnData {
    valueType* eta1;
    valueType* eta2;
    valueType* eta1_eta2;

    ReturnData(valueType* _eta1, valueType* _eta2, valueType* _eta1_eta2) 
        : eta1(_eta1), eta2(_eta2), eta1_eta2(_eta1_eta2) {}
};

struct ReturnLabel {
    valueType* eta1_ref;
    valueType* eta2_ref;
    valueType* eta1_eta2_ref;

    ReturnLabel(valueType* _eta1_ref, valueType* _eta2_ref, valueType* _eta1_eta2_ref) 
        : eta1_ref(_eta1_ref), eta2_ref(_eta2_ref), eta1_eta2_ref(_eta1_eta2_ref) {}
};

struct ReturnItem {
    ReturnData data;
    ReturnLabel ref;
    ReturnItem(ReturnData _rd, ReturnLabel _rl) : data(_rd), ref(_rl) {}
};


// Indexable view over a recorded grain-growth trajectory read from `data_path`
// (presumably via read_data_file). Items are returned as ReturnItem.
//
// NOTE(review): all fields are public; a private layout was sketched earlier
// but never enabled. `all_data` is a raw pointer, and no destructor or copy
// control is declared here, so release of that buffer (if any) happens
// elsewhere. Copies of this object share the same buffer — confirm ownership
// before copying.
class GrainGrowthDataset {
public:
    valueType *all_data;
    uint Nx, Ny, n_grains, n_step;
    int start_skip;
    int skip_step;
    int cnt;

    GrainGrowthDataset(char *data_path, int _start_skip, int _skip_step);

    ReturnItem get_item(size_t index);

    // Number of available items. Declared const so it can also be queried
    // through a const reference. (Functions defined in-class are already
    // implicitly inline.)
    int get_len() const {
        return cnt;
    }
};





#endif // C_IMPLEMENT_GRAIN_GROWTH_H
