#include <cstring>
#include <vector>
#include "grain_growth.hpp"

// Five-point stencil shared by grab_vals, forward_one_step and the neighbour
// bookkeeping. Column k is the sample at offset (dx[k], dy[k]) from the
// centre cell, and lapw[k] is its weight in the discrete Laplacian
// (forward_one_step divides the weighted sum by h^2).
//   k:      centre   +x   +y   -x   -y
const int GrainGrowthOneStep::dx[] = {  0,     1,   0,  -1,   0 };
const int GrainGrowthOneStep::dy[] = {  0,     0,   1,   0,  -1 };
// Number of stencil points (length of dx, dy and lapw).
const int GrainGrowthOneStep::laplen = 5;
const valueType GrainGrowthOneStep::lapw[] = { -4, 1, 1, 1, 1 };


GrainGrowthOneStep::GrainGrowthOneStep(int _Nx, int _Ny, uint _n_grains,
                                       uint _lshK, uint _lshL, valueType _h,
                                       valueType _A, valueType _B, valueType _L,
                                       valueType _kappa, valueType _dtime,
                                       valueType _lsh_r)
    // Base-class sizes (their use is defined by OneStep; presumably):
    //   _n_grains*laplen  - length of one stencil vector filled by grab_vals
    //   _Nx*_Ny*_n_grains - number of stored values, one plane per grain
    //   _Nx*_Ny           - number of grid cells
    : OneStep(_n_grains * laplen, _Nx * _Ny * _n_grains, _Nx * _Ny,
              _lshK, _lshL, _lsh_r),
      // Grid geometry and LSH parameters.
      Nx(_Nx), Ny(_Ny), size(_Nx * _Ny),
      n_grains(_n_grains), lshK(_lshK), lshL(_lshL),
      // Grid spacing and its square (the Laplacian denominator).
      h(_h), h2(_h * _h),
      // Coefficients of the local driving force and the update rate L.
      A(_A), B(_B), updateL(_L),
      // Gradient coefficient, time step, and the premultiplied dtime*L.
      kappa(_kappa), dtime(_dtime), dtimeL(_dtime * _L)
{}


void GrainGrowthOneStep::grab_vals(uint item, valueType* value_table, \
                                      valueType* vals) {
  //Coordinate2d c(0, 0);
  //Coordinate2d cc(c);
  
  uint start_pos = 0;
  uint start_vals_pos = 0;
  uint i = 0, pg = 0;
  
  //c.from_item(item, Ny);
  int c_x = item / Ny;
  int c_y = item % Ny;
  uint pd, root_item;
  int cc_x, cc_y;
  for ( ; i < laplen; ++ i) {
    //cc.from(c);
    cc_x = c_x; cc_y = c_y;
    
    cc_x += dx[i];
    cc_y += dy[i];

    cc_x = max(cc_x, 0);      // smch: may be changed to mod operation 
    cc_x = min(cc_x, Nx-1);
    cc_y = max(cc_y, 0);
    cc_y = min(cc_y, Ny-1);

    //this_item = cc.to_item(Ny);
    pd = inv.item2pd[((uint)cc_x)*Ny + ((uint)cc_y)];
    //root = inv.find_(pd);
    root_item = inv.d_item(inv.find_(pd));

    start_pos = 0;
    start_vals_pos = 0;
    for (pg = 0; pg < n_grains; ++ pg) {
      vals[start_vals_pos + i] = value_table[root_item + start_pos];
      start_pos += size;
      start_vals_pos += laplen;
    }
  }
}


void GrainGrowthOneStep::assign_vals(valueType* old_v, uint c_old,
                                     valueType* new_v, uint c_new)  {
  // Copies every grain's value at cell c_old of old_v into cell c_new of
  // new_v. Both buffers store one plane of `size` cells per grain.
  uint offset = 0;
  for (uint remaining = n_grains; remaining > 0; -- remaining, offset += size)
    new_v[offset + c_new] = old_v[offset + c_old];
}


void GrainGrowthOneStep::forward_one_step(valueType* vals, uint c, valueType* new_v) {
  // One explicit time step for cell c.
  //   vals  - stencil values as filled by grab_vals: vals[g*laplen + k],
  //           where k == 0 is the centre cell.
  //   new_v - output planes; writes new_v[g*size + c] for every grain g.
  // For each grain g with centre value eta:
  //   dF      = eta * (-A + B*eta^2 + 2*sum_{j != g} eta_j^2)
  //   lap     = (sum_k lapw[k] * vals[g*laplen + k]) / h^2
  //   new eta = clamp(eta - dtime*L*(dF - kappa*lap), 0, 1)
  //
  // eta^2 is recomputed in the second pass rather than cached in a
  // variable-length array. VLAs are a compiler extension, not standard C++,
  // and put the stack usage under the control of n_grains. The product is
  // the same expression, so results are unchanged.

  // Pass 1: sum of squared centre values over all grains.
  valueType sum_eta2 = 0;
  uint start_vals_pos = 0;
  for (uint pg = 0; pg < n_grains; ++ pg) {
    const valueType eta = vals[start_vals_pos];
    sum_eta2 += eta*eta;
    start_vals_pos += laplen;
  }

  // Pass 2: update each grain.
  start_vals_pos = 0;
  uint start_pos = 0;
  for (uint pg = 0; pg < n_grains; ++ pg) {
    const valueType eta = vals[start_vals_pos];
    const valueType eta2 = eta*eta;
    const valueType d_energy = eta*(-A + B*eta2 + 2*(sum_eta2 - eta2));
    const valueType lap_eta = inner_product(lapw, vals + start_vals_pos, laplen) / h2;

    valueType next = eta - dtimeL*(d_energy - kappa*lap_eta);
    next = max(next, (valueType)0.);
    next = min(next, (valueType)1.0);
    new_v[start_pos + c] = next;

    start_vals_pos += laplen;
    start_pos += size;
  }
}


void GrainGrowthOneStep::move_out_neighbor_from_n_list(uint item, PNBucket* t) {
  // Updates bucket t's neighbour list (t->n_list_id in pnb.n_list_hash)
  // for `item`.
  //  1. Each grid neighbour cc of item that is currently in the list is
  //     removed unless some stencil cell of cc (cc itself or one of its four
  //     clamped neighbours) belongs to the bucket's set (root of t->p_list).
  //  2. If item is not in the list and one of its neighbours belongs to the
  //     bucket's set, item is added.
  // NOTE(review): step 2 only makes sense if item has already been detached
  // from the bucket's set in `inv` before this call; confirm with callers.
  Coordinate2d c(0, 0);
  c.from_item(item, Ny);

  uint root_item = t->p_list;
  uint root_pd = inv.item2pd[root_item];

  for (uint i = 1; i < laplen; ++ i) {
    Coordinate2d cc(c);
    cc.x += dx[i];
    cc.y += dy[i];

    cc.x = max(cc.x, 0);
    cc.x = min(cc.x, Nx-1);
    cc.y = max(cc.y, 0);
    cc.y = min(cc.y, Ny-1);

    // Clamping at the border can map the neighbour back onto item itself.
    if (cc.x == c.x && cc.y == c.y)
      continue;

    uint cc_hash;
    uint cc_it = pnb.n_list_hash.find(t->n_list_id, cc.to_item(Ny), cc_hash);
    if (cc_it == UINT_NULL)
      continue;

    // Keep cc only if it still touches the bucket's set.
    bool clean_out = true;
    for (uint j = 0; j < laplen; ++ j) {
      Coordinate2d c3(cc);
      // Offset by stencil point j. The original used dx[i]/dy[i] here, so
      // it re-tested the single cell c + 2*d[i] laplen times and never
      // looked at cc's actual neighbours.
      c3.x += dx[j];
      c3.y += dy[j];

      c3.x = max(c3.x, 0);
      c3.x = min(c3.x, Nx-1);
      c3.y = max(c3.y, 0);
      c3.y = min(c3.y, Ny-1);

      if (inv.find_(inv.item2pd[c3.to_item(Ny)]) == root_pd) {
        clean_out = false;
        break;
      }
    }
    if (clean_out) {
      pnb.n_list_hash.delete_(cc_it);
    }
  }

  // item_hash is filled by find() and reused by insert_() below.
  uint item_hash;
  uint item_it = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
  if (item_it != UINT_NULL)
    return ;

  bool add_in = false;
  for (uint i = 1; i < laplen; ++ i) {
    Coordinate2d cc(c);
    cc.x += dx[i];
    cc.y += dy[i];

    cc.x = max(cc.x, 0);
    cc.x = min(cc.x, Nx-1);
    cc.y = max(cc.y, 0);
    cc.y = min(cc.y, Ny-1);

    if (cc.x == c.x && cc.y == c.y)
      continue;

    if (inv.find_(inv.item2pd[cc.to_item(Ny)]) == root_pd) {
      add_in = true;
      break;
    }
  }

  if (add_in){
    pnb.n_list_hash.insert_(t->n_list_id, item, item_hash);
  }
}


void GrainGrowthOneStep::merge_neighbor_into_n_list(uint item, \
                                                    PNBucket* t)  {
  Coordinate2d c(0, 0);
  c.from_item(item, Ny);

  uint root_item = t->p_list;
  uint root_pd = inv.item2pd[root_item];

  for (uint i = 1; i < laplen; ++ i) {
    Coordinate2d cc(c);
    cc.x += dx[i];
    cc.y += dy[i];

    cc.x = max(cc.x, 0);
    cc.x = min(cc.x, Nx-1);
    cc.y = max(cc.y, 0);
    cc.y = min(cc.y, Ny-1);

    if (cc.x == c.x && cc.y == c.y)
      continue;
    
    uint this_item = cc.to_item(Ny);
    //if ((t->n_list).find(this_item) != (t->n_list).end())
    //  continue;
    // this does not need to be implemented; because we automatically ensure no duplication.
    
    uint pd = inv.item2pd[this_item];

    if (inv.find_(pd) != root_pd){
      //(t->n_list).insert(this_item);
      pnb.n_list_hash.insert_no_duplicate(t->n_list_id, this_item);
    }
  }

  //set<uint>::iterator it = (t->n_list).find(item);
  //if (it != (t->n_list).end())
  //  (t->n_list).erase(it);
  uint item_hash;
  uint it = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
  if (it != UINT_NULL)
    pnb.n_list_hash.delete_(it);  
}


valueType* GrainGrowthOneStep::decode_to_img() {
  // Expands the compressed state back to a dense image: every cell receives
  // the values stored at the representative cell of its union-find set.
  // Layout: img[g*size + cell]. The buffer is allocated with new[]; the
  // caller owns it and must release it with delete[].
  const uint cells = (uint)size;
  valueType* img = new valueType[cells*n_grains];

  Coordinate2d pos(0, 0);
  for (pos.x = 0; pos.x < Nx; ++ pos.x) {
    for (pos.y = 0; pos.y < Ny; ++ pos.y) {
      const uint cell = pos.to_item(Ny);
      const uint cell_pd = inv.item2pd[cell];
      const uint rep = inv.d_item(inv.find_(cell_pd));

      uint off = 0;
      for (uint g = 0; g < n_grains; ++ g, off += size)
        img[off + cell] = old_v[off + rep];
    }
  }

  return img;
}


void GrainGrowthOneStep::encode_from_img(valueType* img) {
  // Builds the compressed representation from a dense image
  // (img[g*size + cell], same layout as old_v):
  //   1. copies img into old_v;
  //   2. puts every cell in its own union-find set;
  //   3. visits cells in shuffled order (in chunks of at most 1M), hashes each
  //      cell's stencil with the L LSH functions, and either merges the cell
  //      into the first bucket found in any table or opens a new bucket; then
  //      registers the cell's hash pointers for its bucket.

  for (uint i = 0; i < size*n_grains; ++ i)
    old_v[i] = img[i];

  const uint usize = (uint)size;
  for (uint i = 0; i < usize; ++ i)
    inv.makeset(i);

  // Scratch buffers on the heap. These were variable-length arrays before;
  // item_vec alone could hold 1M uints (4 MB) on the stack.
  std::vector<valueType> vals(((uint)laplen)*n_grains);
  std::vector<int> item_lsh(lshK);
  const uint handle_once = min(usize, (uint)1024*1024);
  std::vector<uint> item_vec(handle_once);

  for (uint handle_start = 0; handle_start < usize; handle_start += handle_once) {
    const uint num_handled = min(usize - handle_start, handle_once);
    for (uint item_k = 0; item_k < num_handled; ++ item_k)
      item_vec[item_k] = handle_start + item_k;

    // NOTE(review): std::random_shuffle is deprecated in C++14 and removed in
    // C++17. It is kept here because std::shuffle would change the visiting
    // order (and thus the encoding) relative to existing runs.
    random_shuffle(item_vec.begin(), item_vec.begin() + num_handled);

    for (uint item_k = 0; item_k < num_handled; ++ item_k) {
      const uint item = item_vec[item_k];
      grab_vals(item, old_v, &vals[0]);

      // Find a bucket that this cell's stencil already hashes into.
      uint pnb_to_add = UINT_NULL;
      for (uint cl = 0; cl < L; ++ cl) {
        lsh.lsh(&vals[0], cl, &item_lsh[0]);
        const uint hp_found = hash_t.find(&item_lsh[0], cl);
        if (hp_found != UINT_NULL) {
          pnb_to_add = hash_t.hp.d[hp_found].pnb;
          break;
        }
      }

      uint item_pd = inv.item2pd[item];
      PNBucket* pn_it = NULL;
      if (pnb_to_add != UINT_NULL) {
        // Merge into the existing bucket. old_v needs no update: the bucket
        // representative's values stand in for the whole set.
        pn_it = &(pnb.d[pnb_to_add]);
        const uint ori_pn_plist = pn_it->p_list;
        const uint pn_pd = inv.item2pd[pn_it->p_list];
        const uint root = inv.union_(item_pd, pn_pd);
        pn_it->p_list = inv.d_item(root);

        // If the representative changed, move its hash-pointer registrations
        // over to the new representative.
        if (pn_it->p_list != ori_pn_plist) {
          for (uint cl = 0; cl < L; ++ cl) {
            item2hp_hash.clear(item2hp_id[pn_it->p_list][cl]);
            item2hp_hash.move_from_id_to_id(item2hp_id[ori_pn_plist][cl],
                                            item2hp_id[pn_it->p_list][cl]);
          }
        }
        merge_neighbor_into_n_list(item, pn_it);
      } else {
        // Open a new bucket represented by this cell.
        pnb_to_add = pnb.new_elem();
        pn_it = &(pnb.d[pnb_to_add]);
        pn_it->p_list = item;
        merge_neighbor_into_n_list(item, pn_it);
      }

      // Make this cell's hash codes point at its bucket in every table where
      // they do not already.
      for (uint cl = 0; cl < L; ++ cl) {
        lsh.lsh(&vals[0], cl, &item_lsh[0]);
        const uint hp_found = hash_t.find(&item_lsh[0], cl);
        if (hp_found == UINT_NULL || hash_t.hp.d[hp_found].pnb != pnb_to_add) {
          const uint hp_id = hash_t.hp.new_elem();
          // Take the pointer after new_elem(), which may grow hash_t.hp.d.
          HashPointer* new_hp = &(hash_t.hp.d[hp_id]);
          // NOTE(review): copies K ints from a buffer sized lshK; presumably
          // K == lshK (both come from _lshK in the constructor) - confirm.
          memcpy(new_hp->lsh_hash_code, &item_lsh[0], sizeof(int)*K);
          new_hp->hash_code = hash_t.hash_from_lsh(&item_lsh[0]);
          new_hp->pnb = pnb_to_add;
          hash_t.insert(hp_id, new_hp->hash_code, cl);
          item2hp_hash.insert_no_duplicate(item2hp_id[pn_it->p_list][cl], hp_id);
        }
      }
    }
  }
}
