#include <cstring>
#include "grain_growth.h"
#define EPS 1e-3

// Offsets of the 5-entry stencil. Entry 0 is the cell itself; entries 1..4 are
// the +x, +y, -x and -y neighbours (dx is added to item / Ny, dy to item % Ny).
const int GrainGrowthOneBack::dx[] = {0, 1, 0,-1, 0};
const int GrainGrowthOneBack::dy[] = {0, 0, 1, 0,-1};
// Number of stencil entries, i.e. the length of dx, dy and lapw.
const int GrainGrowthOneBack::laplen = 5;
// Stencil weights paired entry-by-entry with dx/dy. Callers take the weighted
// sum of the stencil values and divide by h2 to obtain the Laplacian term.
const valueType GrainGrowthOneBack::lapw[] = {-4, 1, 1, 1, 1};

// Builds a solver for an _Nx x _Ny grid with _n_grains order parameters.
//
// Every per-cell quantity is stored twice over (hence the "* 2" factors below):
// the rest of this file uses the first n_grains fields for the order
// parameters and the following fields for the loss gradient w.r.t. them.
//
// Derived values cached here:
//   size     = _Nx * _Ny              (cells in one field)
//   vals_len = laplen * _n_grains * 2 (stencil values gathered per cell)
//   h2       = _h * _h
//   dtimeL   = _dtime * _L
//
// NOTE(review): the meaning of the six arguments forwarded to OneStep is
// defined by that base class, which is not visible here -- confirm there.
GrainGrowthOneBack::GrainGrowthOneBack(int _Nx, int _Ny, uint _n_grains, uint _lshK, uint _lshL, valueType _h,
                       valueType _A, valueType _B, valueType _L, valueType _kappa,
                       valueType _dtime, valueType _lsh_r)
    : OneStep(_n_grains*laplen*2, _Nx*_Ny*_n_grains*2, _Nx*_Ny, _lshK, _lshL, _lsh_r),
    Nx(_Nx), Ny(_Ny), size(_Nx*_Ny), n_grains(_n_grains), lshK(_lshK), lshL(_lshL), vals_len(laplen*_n_grains*2),
    h(_h), h2(_h*_h), A(_A), B(_B), updateL(_L), kappa(_kappa), dtime(_dtime), dtimeL(_dtime*_L)                       
{}


// Gathers the stencil values of every field around `item`.
//
// For each of the laplen stencil entries the neighbouring cell is clamped to
// the grid, resolved to the representative item of its union-find set, and the
// value of each of the 2*n_grains fields at that representative is copied out.
//
//   item        : linear cell index (x = item / Ny, y = item % Ny)
//   value_table : field-major table; field f of cell i lives at i + f*size
//   vals        : output; field f, stencil entry k is written to vals[f*laplen + k]
void GrainGrowthOneBack::grab_vals(uint item, valueType* value_table, valueType* vals)
{
    const int center_x = item / Ny;
    const int center_y = item % Ny;
    const uint n_fields = n_grains * 2;

    for (uint k = 0; k < laplen; ++k)
    {
        // Neighbour coordinates, clamped so that border cells reuse themselves.
        // smch: may be changed to mod operation
        const int nb_x = min(max(center_x + dx[k], 0), Nx - 1);
        const int nb_y = min(max(center_y + dy[k], 0), Ny - 1);

        // Representative item of the set the neighbour currently belongs to.
        const uint nb_pd = inv.item2pd[((uint)nb_x) * Ny + ((uint)nb_y)];
        const uint rep_item = inv.d_item(inv.find_(nb_pd));

        for (uint f = 0; f < n_fields; ++f)
            vals[f * laplen + k] = value_table[rep_item + f * size];
    }
}

// Copies all 2*n_grains field values of cell `c_old` in `old_v` to cell
// `c_new` in `new_v`. Both tables are field-major with a stride of `size`.
void GrainGrowthOneBack::assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new)
{
    const uint n_fields = n_grains * 2;
    for (uint field = 0, base = 0; field < n_fields; ++field, base += size)
        new_v[base + c_new] = old_v[base + c_old];
}

// Advances cell `c` by one step of this solver and back-propagates the loss
// gradient through that step.
//
//   vals  : stencil values gathered by grab_vals; field f occupies
//           vals[f*laplen .. f*laplen + laplen) with entry 0 being the cell
//           itself. Fields 0/1 are eta_1/eta_2, fields 2/3 are the incoming
//           gradients dloss/deta_1 and dloss/deta_2.
//   c     : cell index the results are written to
//   new_v : output table with a stride of num_items per field; receives the
//           stepped eta_1, eta_2 and the propagated gradients.
//
// Also accumulates the parameter gradients via accumulate_weight_derivative.
// Like the rest of the gradient code, this assumes exactly two grains.
void GrainGrowthOneBack::forward_one_step(valueType *vals, uint c, valueType *new_v)
{
    // Stepped order parameters (only entry 0 of each field is produced).
    valueType back_vals[this->vals_len/2];
    forward_one_step_vals(vals, back_vals);
    new_v[c] = back_vals[0];
    new_v[c + num_items] = back_vals[laplen];

    // Rebuild a full stencil around the stepped centre values: the gradient
    // fields are taken over unchanged, and the eta neighbours are shifted by
    // the same amount the centre moved.
    valueType back_vals_full[this->vals_len];
    std::memcpy(back_vals_full, vals, this->vals_len * sizeof(valueType));

    back_vals_full[0] = back_vals[0];
    back_vals_full[laplen] = back_vals[laplen];
    for (uint i = 1; i < laplen; ++ i) {
        back_vals_full[i] = vals[i] + back_vals[0] - vals[0];
        back_vals_full[i + laplen] = vals[i + laplen] + back_vals[laplen] - vals[laplen];
    }

    accumulate_weight_derivative(back_vals_full, c);

    // assume only two grains
    valueType eta_1 = back_vals_full[0];
    valueType eta_2 = back_vals_full[laplen];
    valueType dloss_deta1 = back_vals_full[laplen * 2];
    valueType dloss_deta2 = back_vals_full[laplen * 3];
    valueType eta_1_sqr = eta_1*eta_1;
    valueType eta_2_sqr = eta_2*eta_2;

    // Naming: dloss_detaX_detaX_detaY = (dloss/deta_X) * (deta_X/deta_Y), i.e.
    // the share of the gradient w.r.t. eta_Y that flows through eta_X.

    // d(eta_X)/d(eta_Y) for X != Y; it is symmetric in the two grains.
    const valueType cross_jacobian = -dtimeL * (2 * eta_2 * 2 * eta_1);

    // Self term for eta_1, plus the Laplacian applied to the gradient field.
    valueType dloss_deta1_deta1_deta1 = dloss_deta1 * (1 - dtimeL * (-A + B*3*eta_1_sqr + 2*eta_2_sqr));
    dloss_deta1_deta1_deta1 += dtimeL * kappa * inner_product(back_vals_full+laplen*2, lapw, laplen) / h2;
    // Cross term: the gradient of eta_2 flowing back into eta_1. It must be
    // scaled by the upstream gradient dloss_deta2 (chain rule).
    valueType dloss_deta2_deta2_deta1 = dloss_deta2 * cross_jacobian;

    // Self term for eta_2. This has to use dloss_deta2; using dloss_deta1 here
    // would propagate the wrong grain's gradient.
    valueType dloss_deta2_deta2_deta2 = dloss_deta2 * (1 - dtimeL * (-A + B*3*eta_2_sqr + 2*eta_1_sqr));
    dloss_deta2_deta2_deta2 += dtimeL * kappa * inner_product(back_vals_full+laplen*3, lapw, laplen) / h2;
    // Cross term: the gradient of eta_1 flowing back into eta_2.
    valueType dloss_deta1_deta1_deta2 = dloss_deta1 * cross_jacobian;

    // A value at (or within EPS of) the clamp bounds passes no gradient, so
    // every term that flows through that eta is dropped.
    if (eta_1 <= EPS || eta_1 >= 1.0) {
        dloss_deta1_deta1_deta1 = 0;
        dloss_deta1_deta1_deta2 = 0;
    }

    if (eta_2 <= EPS || eta_2 >= 1.0) {
        dloss_deta2_deta2_deta1 = 0;
        dloss_deta2_deta2_deta2 = 0;
    }

    new_v[c+num_items*2] = dloss_deta1_deta1_deta1 + dloss_deta2_deta2_deta1;
    new_v[c+num_items*3] = dloss_deta2_deta2_deta2 + dloss_deta1_deta1_deta2;
}

// Computes the stepped centre value of every grain from its stencil.
//
//   vals   : grain g occupies vals[g*laplen .. g*laplen + laplen), entry 0
//            being the centre cell
//   back_v : output; the result for grain g is written to back_v[g*laplen]
//            (the other entries are left untouched)
//
// Each result is vals + dtimeL * (d_energy - kappa * laplacian), clamped to
// the interval [0, 1].
void GrainGrowthOneBack::forward_one_step_vals(valueType *vals, valueType *back_v){
    // Squared centre value per grain, and their total.
    valueType sq[n_grains];
    valueType sq_total = 0;
    for (uint g = 0; g < n_grains; ++g)
    {
        const valueType center = vals[g * laplen];
        sq[g] = center * center;
        sq_total += sq[g];
    }

    for (uint g = 0; g < n_grains; ++g)
    {
        const uint base = g * laplen;
        const valueType center = vals[base];

        const valueType d_energy = center * (-A + B * sq[g] + 2 * (sq_total - sq[g]));
        const valueType lap_eta = inner_product(lapw, vals + base, laplen) / h2;

        valueType stepped = center + dtimeL * (d_energy - kappa * lap_eta);
        stepped = std::max(stepped, (valueType)0.);
        stepped = std::min(stepped, (valueType)1.0);
        back_v[base] = stepped;
    }
}

// Updates the neighbour list of bucket `t` after `item` has left the bucket.
//
// Step 1: every grid neighbour of `item` that is currently in the n_list is
//         re-examined; it is removed when no cell of its own stencil (itself
//         and its four clamped neighbours) resolves to the bucket's root.
// Step 2: `item` itself is added to the n_list when it is not already there
//         and at least one of its neighbours still resolves to the root.
//
//   item : linear index of the cell that was moved out
//   t    : bucket whose n_list (identified by t->n_list_id) is maintained
void GrainGrowthOneBack::move_out_neighbor_from_n_list(uint item, PNBucket *t) {
    Coordinate2d c(0, 0);
    c.from_item(item, Ny);
    uint root_item = t->p_list;
    uint root_pd = inv.item2pd[root_item];

    for (uint i = 1; i < laplen; ++i)
    {
        Coordinate2d cc(c);
        cc.x += dx[i];
        cc.y += dy[i];

        cc.x = max(cc.x, 0);
        cc.x = min(cc.x, Nx - 1);
        cc.y = max(cc.y, 0);
        cc.y = min(cc.y, Ny - 1);

        // Clamping at the border folded the neighbour back onto `item`.
        if (cc.x == c.x && cc.y == c.y)
            continue;

        // Only neighbours that are currently in the n_list can be removed.
        uint cc_hash;
        uint cc_it = pnb.n_list_hash.find(t->n_list_id, cc.to_item(Ny), cc_hash);
        if (cc_it == UINT_NULL)
            continue;

        // Check the neighbour's own stencil: it stays in the n_list as long as
        // any of those cells still belongs to the bucket.
        bool clean_out = true;
        for (uint j = 0; j < laplen; ++j)
        {
            Coordinate2d c3(cc);
            // Index with j (the inner stencil entry). Indexing with i would
            // test the same cell laplen times.
            c3.x += dx[j];
            c3.y += dy[j];

            c3.x = max(c3.x, 0);
            c3.x = min(c3.x, Nx - 1);
            c3.y = max(c3.y, 0);
            c3.y = min(c3.y, Ny - 1);

            if (inv.find_(inv.item2pd[c3.to_item(Ny)]) == root_pd)
            {
                clean_out = false;
                break;
            }
        }
        if (clean_out)
        {
            pnb.n_list_hash.delete_(cc_it);
        }
    }

    // Nothing more to do when `item` is already recorded as a neighbour.
    uint item_hash;
    uint item_it = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
    if (item_it != UINT_NULL)
        return;

    // `item` becomes a neighbour of the bucket if it still touches it.
    bool add_in = false;
    for (uint i = 1; i < laplen; ++i)
    {
        Coordinate2d cc(c);
        cc.x += dx[i];
        cc.y += dy[i];

        cc.x = max(cc.x, 0);
        cc.x = min(cc.x, Nx - 1);
        cc.y = max(cc.y, 0);
        cc.y = min(cc.y, Ny - 1);

        if (cc.x == c.x && cc.y == c.y)
            continue;

        if (inv.find_(inv.item2pd[cc.to_item(Ny)]) == root_pd)
        {
            add_in = true;
            break;
        }
    }

    if (add_in)
    {
        // item_hash was filled in by the failed find() above.
        pnb.n_list_hash.insert_(t->n_list_id, item, item_hash);
    }
}

// Updates the neighbour list of bucket `t` after `item` has joined the bucket.
//
// Every grid neighbour of `item` that does not resolve to the bucket's root is
// recorded in the n_list (insert_no_duplicate takes care of repeats), and
// `item` itself is removed from the n_list if it was listed there.
void GrainGrowthOneBack::merge_neighbor_into_n_list(uint item, PNBucket* t) {
    Coordinate2d center(0, 0);
    center.from_item(item, Ny);

    const uint bucket_item = t->p_list;
    const uint bucket_pd = inv.item2pd[bucket_item];

    for (uint k = 1; k < laplen; ++k)
    {
        Coordinate2d nb(center);
        nb.x += dx[k];
        nb.y += dy[k];
        nb.x = min(max(nb.x, 0), Nx - 1);
        nb.y = min(max(nb.y, 0), Ny - 1);

        // Skip neighbours that the border clamp folded back onto `item`.
        const bool is_center = (nb.x == center.x && nb.y == center.y);
        if (!is_center)
        {
            const uint nb_item = nb.to_item(Ny);
            const uint nb_pd = inv.item2pd[nb_item];
            if (inv.find_(nb_pd) != bucket_pd)
                pnb.n_list_hash.insert_no_duplicate(t->n_list_id, nb_item);
        }
    }

    // A member of the bucket must not be listed as its neighbour.
    uint item_hash;
    const uint slot = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
    if (slot != UINT_NULL)
        pnb.n_list_hash.delete_(slot);
}

// Expands the bucketed state back into dense per-cell arrays.
//
// Every cell takes the values stored at the representative item of its
// union-find set. Returns a newly allocated array of two pointers:
//   [0] img : size * n_grains order-parameter values (field-major)
//   [1] dev : size * n_grains loss-gradient values   (field-major)
// All three allocations use new[]; the caller is responsible for releasing
// them with delete[].
//
// The gradient fields are read at offset size * 2, matching where
// encode_from_img stores them.
valueType** GrainGrowthOneBack::decode_to_img() {
    const uint plane = (uint)size;
    const uint total = plane * n_grains;
    const uint dloss_base = plane * 2;

    valueType* img = new valueType[total];
    valueType* dev = new valueType[total];

    Coordinate2d c(0, 0);
    for (c.x = 0; c.x < Nx; ++c.x)
        for (c.y = 0; c.y < Ny; ++c.y)
        {
            uint item = c.to_item(Ny);
            uint item_pd = inv.item2pd[item];
            uint root = inv.find_(item_pd);
            uint root_item = inv.d_item(root);

            uint start_pos = 0;
            for (uint pg = 0; pg < n_grains; ++pg)
            {
                img[start_pos + item] = old_v[start_pos + root_item];
                // The gradient goes into dev. Writing it into img would
                // overwrite the values above and leave dev uninitialised.
                dev[start_pos + item] = old_v[start_pos + root_item + dloss_base];
                start_pos += plane;
            }
        }

    valueType** ret_mtx = new valueType*[2];
    ret_mtx[0] = img;
    ret_mtx[1] = dev;

    return ret_mtx;
}

// Loads a dense state into the solver and groups cells into buckets.
//
//   img   : size * n_grains order-parameter values, copied to the start of old_v
//   dloss : size * n_grains loss-gradient values, copied to old_v at offset size * 2
//
// After the copy every cell becomes its own union-find set. The cells are then
// visited in shuffled order (in chunks of at most 1024 * 1024); each cell's
// stencil values are hashed with every LSH table, and the cell either joins
// the first bucket whose hash code it matches or starts a new bucket.
void GrainGrowthOneBack::encode_from_img(valueType *img, valueType *dloss){
    for (uint i = 0; i < size * n_grains; ++i)
        old_v[i] = img[i];

    // NOTE(review): the gradient block starts at size * 2, not size * n_grains;
    // the two only coincide when n_grains == 2.
    for (uint i = 0; i < size * n_grains; ++i)
        old_v[i + size * 2] = dloss[i];

    // Start with every cell in its own set.
    uint usize = (uint)size;
    for (uint i = 0; i < usize; ++i)
    {
        inv.makeset(i);
    }

    //inv.check_from_dfslist(usize);

    // Scratch buffers: the stencil values of one cell and one LSH code.
    valueType vals[((uint)laplen) * n_grains * 2];
    int item_lsh[lshK];
    uint item_k = 0, item = 0;

    uint handle_once = min(usize, (uint)1024 * 1024);

    // NOTE(review): this is a stack array of up to 1024 * 1024 entries --
    // confirm the stack limit of the target platform is large enough.
    uint item_vec[handle_once];

    for (uint handle_start = 0; handle_start < usize; handle_start += handle_once)
    {
        // Fill the chunk with consecutive item indices, then shuffle it so the
        // cells of this chunk are visited in random order.
        uint num_handled = min(usize - handle_start, handle_once);
        for (item_k = 0; item_k < num_handled; ++item_k)
            item_vec[item_k] = handle_start + item_k;

        // NOTE(review): std::random_shuffle was removed in C++17 -- confirm the
        // language standard this file is built with.
        random_shuffle(item_vec, item_vec + num_handled);

        for (item_k = 0; item_k < num_handled; ++item_k)
        {
            item = item_vec[item_k];

            grab_vals(item, old_v, vals);

            uint cl = 0;
            // Look for an existing bucket: the first table whose hash code for
            // this cell is already present decides the bucket.
            uint pnb_to_add = UINT_NULL;
            for (cl = 0; cl < L; ++cl)
            {
                lsh.lsh(vals, cl, item_lsh);

                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it != UINT_NULL)
                {
                    pnb_to_add = hash_t.hp.d[hp_it].pnb;
                    break;
                }
            }

            uint item_pd = inv.item2pd[item];
            PNBucket *pn_it = NULL;
            if (pnb_to_add != UINT_NULL)
            {
                // Merge the cell into the bucket; the union may change which
                // item represents the bucket.
                pn_it = &(pnb.d[pnb_to_add]);
                uint ori_pn_plist = pn_it->p_list;
                uint pn_pd = inv.item2pd[pn_it->p_list];
                uint root = inv.union_(item_pd, pn_pd);
                pn_it->p_list = inv.d_item(root);

                if (pn_it->p_list != ori_pn_plist)
                {
                    // The representative changed: per table, empty the new
                    // representative's hash-pointer set and move the old
                    // representative's entries over to it.
                    for (cl = 0; cl < L; ++cl)
                    {
                        item2hp_hash.clear(item2hp_id[pn_it->p_list][cl]);
                        item2hp_hash.move_from_id_to_id(item2hp_id[ori_pn_plist][cl],
                                                        item2hp_id[pn_it->p_list][cl]);
                    }
                }
                // no need to update old_v (it is their accurate value).
                merge_neighbor_into_n_list(item, pn_it);
            }
            else
            {
                // No matching hash code in any table: open a new bucket that
                // is represented by this cell.
                pnb_to_add = pnb.new_elem();
                pn_it = &(pnb.d[pnb_to_add]);
                pn_it->p_list = item;
                //assert((pn_it->n_list).size() == 0);
                merge_neighbor_into_n_list(item, pn_it);
            }

            // Register the cell's hash code in every table where it does not
            // already point at the chosen bucket.
            for (cl = 0; cl < L; ++cl)
            {
                lsh.lsh(vals, cl, item_lsh);

                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it == UINT_NULL || hash_t.hp.d[hp_it].pnb != pnb_to_add)
                {
                    uint hp_id = hash_t.hp.new_elem();
                    // Note: this pointer shadows the uint hp_it declared above.
                    HashPointer *hp_it = &(hash_t.hp.d[hp_id]);
                    memcpy(hp_it->lsh_hash_code, item_lsh, sizeof(int) * K);
                    hp_it->hash_code = hash_t.hash_from_lsh(item_lsh);
                    hp_it->pnb = pnb_to_add;
                    hash_t.insert(hp_id, hp_it->hash_code, cl);
                    // Remember the new hash pointer under the bucket's
                    // representative so it can be moved on a later merge.
                    item2hp_hash.insert_no_duplicate(item2hp_id[pn_it->p_list][cl], hp_id);
                }
            }
        }
    }
    //inv.check_from_dfslist(usize);
    //check_non_empty_item2hp();
}

// Adds one cell's contribution to the loss gradients w.r.t. the model
// parameters (dupdateL, dA, dB, dkappa).
//
//   vals : stencil values; fields 0/1 are eta_1/eta_2 and fields 2/3 are
//          dloss/deta_1 and dloss/deta_2 (entry 0 of each field is the cell)
//   c    : item whose set size (inv.d_size) weights the contribution
//
// The derivatives below are those of the update
//   eta_new = eta - dtime * L * (d_energy - kappa * lap(eta)),
//   d_energy = eta * (-A + B * eta^2 + 2 * eta_other^2).
// Like the rest of the gradient code, this assumes exactly two grains.
void GrainGrowthOneBack::accumulate_weight_derivative(valueType *vals, uint c){
    uint root_pd = inv.item2pd[c];
    uint bucket_size = inv.d_size(root_pd);

    valueType eta_1 = vals[0];
    valueType eta_2 = vals[laplen];
    valueType dloss_deta1 = vals[laplen * 2];
    valueType dloss_deta2 = vals[laplen * 3];
    valueType eta_1_sqr = eta_1*eta_1;
    valueType eta_2_sqr = eta_2*eta_2;

    valueType d_energy_eta1 = eta_1*(-A + B*eta_1_sqr + 2*(eta_2_sqr));
    valueType d_energy_eta2 = eta_2*(-A + B*eta_2_sqr + 2*(eta_1_sqr));

    valueType lapeta1 = inner_product(vals, lapw, laplen) / h2;
    valueType lapeta2 = inner_product(vals + laplen, lapw, laplen) / h2;

    // deta_dL
    valueType deta1_dL = -dtime * (d_energy_eta1 - kappa*lapeta1);
    valueType deta2_dL = -dtime * (d_energy_eta2 - kappa*lapeta2);

    // deta_dA: d(d_energy)/dA = -eta
    valueType deta1_dA = -dtimeL * (-eta_1);
    valueType deta2_dA = -dtimeL * (-eta_2);

    // deta_dB: d(d_energy)/dB = +eta^3, so unlike the A term there is no
    // inner minus sign.
    valueType deta1_dB = -dtimeL * (eta_1_sqr*eta_1);
    valueType deta2_dB = -dtimeL * (eta_2_sqr*eta_2);

    // deta_dkappa: d(-kappa * lap)/dkappa = -lap
    valueType deta1_dkappa = -dtimeL * (-lapeta1);
    valueType deta2_dkappa = -dtimeL * (-lapeta2);

    // A value at (or within EPS of) the clamp bounds contributes no gradient.
    if (!(eta_1 <= EPS || eta_1 >= 1.0)) {
        dupdateL += bucket_size * dloss_deta1 * deta1_dL;
        dA += bucket_size * dloss_deta1 * deta1_dA;
        dB += bucket_size * dloss_deta1 * deta1_dB;
        dkappa += bucket_size * dloss_deta1 * deta1_dkappa;
    }

    if (!(eta_2 <= EPS || eta_2 >= 1.0)) {
        dupdateL += bucket_size * dloss_deta2 * deta2_dL;
        dA += bucket_size * dloss_deta2 * deta2_dA;
        dB += bucket_size * dloss_deta2 * deta2_dB;
        dkappa += bucket_size * dloss_deta2 * deta2_dkappa;
    }
}

// Prints the accumulated parameter gradients to standard output, one per
// line, after flushing any pending C stdio output.
void GrainGrowthOneBack::print_derivative(){
    fflush(stdout);
    cout << "L: " << dupdateL << endl
         << "A: " << dA << endl
         << "B: " << dB << endl
         << "kappa: " << dkappa << endl;
}

// Returns the accumulated parameter gradients as a newly allocated array of
// four values in the order {dupdateL, dA, dB, dkappa}. The array is created
// with new[]; the caller is responsible for releasing it with delete[].
valueType* GrainGrowthOneBack::decode_derivative(){
    valueType* const grads = new valueType[4];

    valueType* slot = grads;
    *slot++ = dupdateL;
    *slot++ = dA;
    *slot++ = dB;
    *slot = dkappa;

    return grads;
}



