
#include <Rmath.h>
#include <R.h>
#include "rf.h"

/*
 * Build one regression tree; call this repeatedly to grow a forest.
 *
 * Every node that is still NODE_TOSPLIT draws u = unif_rand():
 *   - u <= Lp : ordinary split on one of the predictors (findBestSplitRL);
 *               the node becomes NODE_INTERIOR.
 *   - u >  Lp : a sub-forest of nTreesub trees is fitted with regRF on the
 *               observations of the node.  The ypred vector filled in by
 *               regRF (presumably its predictions for those observations --
 *               confirm against regRF) is used as a one-column pseudo
 *               predictor and the node is split on it (findBestSplitRLy);
 *               the node becomes NODE_INTERIOR_V with
 *               bestSplitVarInNode[k] = 0.
 *
 * x            predictors, read as x[var + obs * mdim]
 * y            response, length nsample
 * lDaughter, rDaughter, splitCutoff, nodeMeans, nodestatus,
 * bestSplitVarInNode
 *              per-node outputs of the main tree (numNodes entries);
 *              daughters are stored 1-based
 * treeSize     out: numNodes minus the nodes whose status is still 0
 * nthsize      a daughter with at most this many points is made terminal
 * varUsed      varUsed[v - 1] is set to 1 for every predictor split on
 * seqtreesize  out: for each sub-forest a marker (-1 - k) followed by the
 *              nTreesub tree sizes; *cseqtreesize grows by 1 + nTreesub
 * lDaughterk, rDaughterk, nodestatusk, bestSplitVarInNodek, nodeMeansk,
 * splitCutoffk
 *              out: for each sub-forest a marker (-1 - k) followed by the
 *              first `sum` entries of the matching regRF output array,
 *              where `sum` is the total of the tree sizes;
 *              *lensubtree grows by 1 + sum
 * nthsizesubk  passed through to regRF
 * Lp           probability threshold for choosing an ordinary split
 *
 * The write cursors into the sub-forest arrays start at 0 on every call,
 * whereas *cseqtreesize and *lensubtree are only incremented; the caller
 * is assumed to initialise them (TODO confirm against the caller).
 */
void regRLTree(double *x, double *y, int mdim, int nsample, int *lDaughter,
             int *rDaughter,
             double *splitCutoff, double *nodeMeans, int *nodestatus, int numNodes,
             int *treeSize, int nthsize, int mtry, int *bestSplitVarInNode, int *cat,int *maxcat,
           int *varUsed, int *seqtreesize,int *cseqtreesize,
           int *lDaughterk,int *rDaughterk ,int *nodestatusk ,int *bestSplitVarInNodek, double *nodeMeansk, double *splitCutoffk,
           int *lensubtree,  int nTreesub, int nthsizesubk, double Lp) {
  int i, j, k, m, n, currentNode, *rowIndices, *nodestart, *numPointsInEachNode;
  int ndstart, ndend, leftChildLastDataPoint, currentNodeCount;
  int splitResult, splitResulty, msplit, msplity;
  double label, nodeMean, decsplit, ubest, sumOfNodeData;
  double ybest;     /* cutoff chosen on the pseudo predictor */
  double decsplity; /* criterion value returned for the pseudo-predictor split */
  double u;         /* uniform draw deciding the kind of split */
  int pointer_vk = 0;  /* write cursor into the *k sub-forest node arrays */
  int pointer_vkt = 0; /* write cursor into seqtreesize */
  int sum;             /* total number of nodes in the current sub-forest */
  int *lDaughterkk, *rDaughterkk, *bestSplitVarInNodekk, nrnodesk, *nodestatuskk, *treeSizerf;
  double *xb, *yb, *nodeMeanskk, *splitCutoffkk, *ypred, *ypesudo;

  nodestart = (int *) Calloc(numNodes, int);
  numPointsInEachNode = (int *) Calloc(numNodes, int);

  /* initialize some arrays for the tree */
  zeroInt(nodestatus, numNodes);
  zeroInt(nodestart, numNodes);
  zeroInt(numPointsInEachNode, numNodes);
  zeroDouble(nodeMeans, numNodes);

  /* pseudo predictor, indexed by original observation number */
  ypesudo = (double *) Calloc(nsample, double);
  zeroDouble(ypesudo, nsample);

  /* rowIndices holds 1-based observation numbers, grouped by node */
  rowIndices = (int *) Calloc(nsample, int);
  for (i = 1; i <= nsample; ++i) {
    rowIndices[i - 1] = i;
  }

  currentNode = 0;
  nodestart[0] = 0;
  numPointsInEachNode[0] = nsample;
  nodestatus[0] = NODE_TOSPLIT;

  /* running mean of Y for the root */
  nodeMean = 0.0;
  for (i = 0; i < nsample; ++i) {
    label = y[rowIndices[i] - 1];
    nodeMean = (i * nodeMean + label) / (i + 1);
  }
  nodeMeans[0] = nodeMean;

  /* start main loop */
  for (k = 0; k < numNodes - 2; ++k) {
    if (k > currentNode || currentNode >= numNodes - 2) {
      break;
    }
    /* skip if the node is not to be split */
    if (nodestatus[k] != NODE_TOSPLIT) {
      continue;
    }

#ifdef RF_DEBUG
    Rprintf("regTree: k=%d, av=%f\n", k, nodeMeans[k]);
#endif

    /* initialize for next call to findbestsplit */
    ndstart = nodestart[k];
    ndend = ndstart + numPointsInEachNode[k] - 1;
    currentNodeCount = numPointsInEachNode[k];
    sumOfNodeData = currentNodeCount * nodeMeans[k];
    splitResult = 0;
    splitResulty = 0;
    decsplit = 0.0;
    decsplity = 0.0;

#ifdef RF_DEBUG
    Rprintf("before findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f\n",
            ndstart, ndend, splitResult, decsplit);
#endif

    u = unif_rand();
    if (u <= Lp) {
      /* horizontal split: ordinary best split over the predictors */
      findBestSplitRL(x, rowIndices, y, mdim, nsample, ndstart, ndend, &msplit,
                      &decsplit, &ubest, &leftChildLastDataPoint, &splitResult, mtry, sumOfNodeData,
                      currentNodeCount, cat);
      if (splitResult == NODE_TERMINAL) {
        /* Node is terminal: Mark it as such and move on to the next. */
        nodestatus[k] = NODE_TERMINAL;
        continue;
      }

      bestSplitVarInNode[k] = msplit;
      varUsed[msplit - 1] = 1;
      splitCutoff[k] = ubest;
      nodestatus[k] = NODE_INTERIOR; /* split on a predictor */
    } else {
      /* vertical split: fit a sub-forest on the node, split on its output */
      sum = 0;

      yb = (double *) Calloc(currentNodeCount, double);
      xb = (double *) Calloc(mdim * currentNodeCount, double);
      ypred = (double *) Calloc(currentNodeCount, double);
      nrnodesk = 2 * currentNodeCount + 1;
      lDaughterkk = (int *) Calloc(nrnodesk * nTreesub, int);
      rDaughterkk = (int *) Calloc(nrnodesk * nTreesub, int);
      nodestatuskk = (int *) Calloc(nrnodesk * nTreesub, int);
      bestSplitVarInNodekk = (int *) Calloc(nrnodesk * nTreesub, int);
      nodeMeanskk = (double *) Calloc(nrnodesk * nTreesub, double);
      splitCutoffkk = (double *) Calloc(nrnodesk * nTreesub, double);

      zeroDouble(nodeMeanskk, nrnodesk * nTreesub);
      zeroDouble(splitCutoffkk, nrnodesk * nTreesub);
      zeroDouble(ypred, currentNodeCount);

      /* gather the observations that currently sit in this node */
      for (n = 0; n < currentNodeCount; ++n) {
        yb[n] = y[rowIndices[n + ndstart] - 1];
        for (m = 0; m < mdim; ++m) {
          xb[m + n * mdim] = x[m + (rowIndices[n + ndstart] - 1) * mdim];
        }
      }
      treeSizerf = (int *) Calloc(nTreesub, int);
      zeroInt(treeSizerf, nTreesub);

      regRF(xb, yb, 0,1.0,mdim, currentNodeCount,
            nthsizesubk,nrnodesk,nTreesub, mtry,
            cat, maxcat,treeSizerf,nodestatuskk,
            lDaughterkk, rDaughterkk, nodeMeanskk, bestSplitVarInNodekk,
            splitCutoffkk,1,1,ypred);

      for (i = 0; i < nTreesub; i++) {
        sum += treeSizerf[i];
      }

      /* record the tree sizes: marker (-1 - k), then one size per tree */
      seqtreesize[pointer_vkt] = -1 - k;
      for (j = 0; j < nTreesub; ++j) {
        seqtreesize[pointer_vkt + 1 + j] = treeSizerf[j];
      }
      pointer_vkt += 1 + nTreesub;
      *cseqtreesize += 1 + nTreesub;

      /* record the sub-forest nodes: marker (-1 - k), then `sum` entries */
      lDaughterk[pointer_vk] = -1 - k;
      rDaughterk[pointer_vk] = -1 - k;
      nodestatusk[pointer_vk] = -1 - k;
      bestSplitVarInNodek[pointer_vk] = -1 - k;
      nodeMeansk[pointer_vk] = 1.0 * (-1 - k);
      splitCutoffk[pointer_vk] = 1.0 * (-1 - k);

      /* Copy exactly `sum` entries so the record length matches the
         1 + sum by which the cursor and *lensubtree advance. */
      for (j = 0; j < sum; ++j) {
        lDaughterk[pointer_vk + 1 + j] = lDaughterkk[j];
        rDaughterk[pointer_vk + 1 + j] = rDaughterkk[j];
        nodestatusk[pointer_vk + 1 + j] = nodestatuskk[j];
        bestSplitVarInNodek[pointer_vk + 1 + j] = bestSplitVarInNodekk[j];
        nodeMeansk[pointer_vk + 1 + j] = nodeMeanskk[j];
        splitCutoffk[pointer_vk + 1 + j] = splitCutoffkk[j];
      }
      pointer_vk += 1 + sum;
      *lensubtree += 1 + sum;

      /* scatter the sub-forest output back to original observation order */
      for (n = 0; n < currentNodeCount; ++n) {
        ypesudo[rowIndices[n + ndstart] - 1] = ypred[n];
      }

      /* all per-node scratch is released before the node can be skipped */
      Free(yb);
      Free(xb);
      Free(ypred);
      Free(treeSizerf);
      Free(lDaughterkk);
      Free(rDaughterkk);
      Free(nodestatuskk);
      Free(bestSplitVarInNodekk);
      Free(nodeMeanskk);
      Free(splitCutoffkk);

      findBestSplitRLy(ypesudo,rowIndices, y, 1, nsample, ndstart, ndend, &msplity,
                       &decsplity, &ybest, &leftChildLastDataPoint, &splitResulty, 1, sumOfNodeData,
                       currentNodeCount, 1);
      if (splitResulty == NODE_TERMINAL) {
        /* Node is terminal: Mark it as such and move on to the next. */
        nodestatus[k] = NODE_TERMINAL;
        continue;
      }
      bestSplitVarInNode[k] = 0; /* 0 means the pseudo predictor was used */
      splitCutoff[k] = ybest;
      nodestatus[k] = NODE_INTERIOR_V; /* split on the sub-forest output */
    }

    /* left node is the node after current, right node is 2 after */
    numPointsInEachNode[currentNode + 1] = leftChildLastDataPoint - ndstart + 1;
    numPointsInEachNode[currentNode + 2] = ndend - leftChildLastDataPoint;
    nodestart[currentNode + 1] = ndstart;
    nodestart[currentNode + 2] = leftChildLastDataPoint + 1;

    /* running mean of Y for the left daughter node */
    nodeMean = 0.0;
    for (j = ndstart; j <= leftChildLastDataPoint; ++j) {
      label = y[rowIndices[j] - 1];
      m = j - ndstart;
      nodeMean = (m * nodeMean + label) / (m + 1);
    }
    nodeMeans[currentNode + 1] = nodeMean;
    nodestatus[currentNode + 1] = NODE_TOSPLIT;
    if (numPointsInEachNode[currentNode + 1] <= nthsize) {
      nodestatus[currentNode + 1] = NODE_TERMINAL;
    }

    /* running mean of Y for the right daughter node */
    nodeMean = 0.0;
    for (j = leftChildLastDataPoint + 1; j <= ndend; ++j) {
      label = y[rowIndices[j] - 1];
      m = j - (leftChildLastDataPoint + 1);
      nodeMean = (m * nodeMean + label) / (m + 1);
    }
    nodeMeans[currentNode + 2] = nodeMean;
    nodestatus[currentNode + 2] = NODE_TOSPLIT;
    if (numPointsInEachNode[currentNode + 2] <= nthsize) {
      nodestatus[currentNode + 2] = NODE_TERMINAL;
    }

    /* map the daughter nodes (stored 1-based) */
    lDaughter[k] = currentNode + 1 + 1;
    rDaughter[k] = currentNode + 2 + 1;

    /* Augment the tree by two nodes. */
    currentNode += 2;
#ifdef RF_DEBUG
    Rprintf(" after split: ldaughter=%d, rdaughter=%d, ncur=%d\n",
            lDaughter[k], rDaughter[k], currentNode);
#endif

  }

  /* Count the nodes that were actually used; any node still waiting to be
     split when the loop stopped becomes terminal. */
  *treeSize = numNodes;
  for (k = numNodes - 1; k >= 0; --k) {
    if (nodestatus[k] == 0) (*treeSize)--;
    if (nodestatus[k] == NODE_TOSPLIT) {
      nodestatus[k] = NODE_TERMINAL;
    }
  }

  Free(nodestart);
  Free(rowIndices);
  Free(numPointsInEachNode);
  Free(ypesudo);
}

/*
 * Called on every node that is to be split on a predictor: draw up to mtry
 * predictors at random without replacement and keep the one whose best
 * cutoff maximises
 *     suml^2 / npopl + sumr^2 / npopr - sumnode^2 / nodeCount.
 *
 * x, y      data; x is read as x[var + (obs - 1) * mdim]
 * jdex      1-based observation numbers; positions ndstart..ndend (inclusive)
 *           are the observations of this node.  On success that range is
 *           reordered so that the left daughter's observations come first.
 * bestVarToReturn   out: 1-based index of the chosen predictor, -1 if none
 * decsplit          out: criterion value of the chosen split (0.0 if none)
 * bestSplitToReturn out: cutoff; for a categorical predictor it is replaced
 *                   by pack() of the indicator "category mean < cutoff"
 * ndendl            out: last position in jdex that belongs to the left
 *                   daughter (written only when a split is found)
 * jstat             out: set to NODE_TERMINAL when no split is found,
 *                   otherwise left untouched
 * sumnode, nodeCount  sum of y and number of observations in the node
 * cat               cat[v] == 1 marks predictor v as numeric, anything else
 *                   is treated as its number of categories
 *
 * The sequence of unif_rand() calls is part of the observable behaviour
 * (it decides ties), so statement order here matters.
 */
void findBestSplitRL(double *x, int *jdex, double *y, int mdim, int nsample,
                   int ndstart, int ndend, int *bestVarToReturn, double *decsplit,
                   double *bestSplitToReturn, int *ndendl, int *jstat, int mtry,
                   double sumnode, int nodeCount, int *cat) {
  int last, numCategoriesAllVars[MAX_CAT], icat[MAX_CAT], numCategoriesForVar, nl, nr, npopl, npopr, tieVar, tieVal;
  int i, j, kv, l, *varIndices, *ncase;
  double *xt, *ut, *v, *yl, sumcat[MAX_CAT], avcat[MAX_CAT], tavcat[MAX_CAT], valueAtBestSplit;
  double crit, bestSplitForAllVariables, bestSplitWithinVariable, suml, sumr, d, critParent;
  /* Sampling is without replacement, so at most mdim draws are possible.
     Without this bound `last` would reach -1 and the swap below would
     touch varIndices[-1]. */
  int numTries = (mtry < mdim) ? mtry : mdim;

  /*Calloc is different from calloc, r handles memory allocation instead of os*/
  ut = (double *) Calloc(nsample, double);  /* x values of the best variable so far */
  xt = (double *) Calloc(nsample, double);  /* x values of the candidate variable */
  v  = (double *) Calloc(nsample, double);  /* sorted copy of xt within the node */
  yl = (double *) Calloc(nsample, double);  /* y aligned with xt */
  varIndices  = (int *) Calloc(mdim, int);
  ncase = (int *) Calloc(nsample, int);
  zeroDouble(avcat, MAX_CAT);
  zeroDouble(tavcat, MAX_CAT);

  /* START BIG LOOP */
  *bestVarToReturn = -1;
  *decsplit = 0.0;
  bestSplitForAllVariables = 0.0;
  valueAtBestSplit = 0.0;
  for (i=0; i < mdim; ++i)
  {
    varIndices[i] = i;
  }

  last = mdim - 1;
  tieVar = 1;
  /*choose mtry number of variables, choose the one with best split*/
  for (i = 0; i < numTries; ++i)
  {
    /*sample without replacement: choose random, move to end,
    do next random choice in range 1 to len - n*/
    bestSplitWithinVariable = 0.0;
    j = (int) (unif_rand() * (last+1));
    kv = varIndices[j];
    swapInt(varIndices[j], varIndices[last]);
    last--;

    numCategoriesForVar = cat[kv];
    if (numCategoriesForVar == 1) {
      /* numeric variable */
      for (j = ndstart; j <= ndend; ++j) {
        xt[j] = x[kv + (jdex[j] - 1) * mdim]; /*indexing to represent 2d in a 1d vector */
        yl[j] = y[jdex[j] - 1];
      }
    } else {
      /* categorical variable */
      zeroInt(numCategoriesAllVars, MAX_CAT);
      zeroDouble(sumcat, MAX_CAT);
      for (j = ndstart; j <= ndend; ++j) {
        l = (int) x[kv + (jdex[j] - 1) * mdim];
        sumcat[l - 1] += y[jdex[j] - 1];
        numCategoriesAllVars[l - 1] ++;
      }
      /* Compute means of Y by category. */
      for (j = 0; j < numCategoriesForVar; ++j) {
        avcat[j] = numCategoriesAllVars[j] ? sumcat[j] / numCategoriesAllVars[j] : 0.0;
      }
      /* Make the category mean the `pseudo' X data. */
      for (j = 0; j < nsample; ++j) {
        xt[j] = avcat[(int) x[kv + (jdex[j] - 1) * mdim] - 1];
        yl[j] = y[jdex[j] - 1];
      }
    }
    /* copy the x data in this node. */
    for (j = ndstart; j <= ndend; ++j) {
      v[j] = xt[j];
    }

    /* ncase starts as the identity (1-based positions); after the sort,
       ncase[j] is the position whose value ended up at rank j. */
    for (j = 1; j <= nsample; ++j) {
      ncase[j - 1] = j;
    }

    R_qsort_I(v, ncase, ndstart + 1, ndend + 1);

    /* the variable is constant within the node: nothing to split on */
    if (v[ndstart] >= v[ndend]) {
      continue;
    }
    /* ncase(n)=case number of v nth from bottom */
    /* Sweep left to right, moving one observation at a time from the
       right-hand sums into the left-hand sums. */
    critParent = sumnode * sumnode / nodeCount;
    suml = 0.0;
    sumr = sumnode;
    npopl = 0;
    npopr = nodeCount;
    crit = 0.0;
    tieVal = 1;
    /* Search through the "gaps" in the x-variable. */
    for (j = ndstart; j <= ndend - 1; ++j) {
      d = yl[ncase[j] - 1];
      suml += d;
      sumr -= d;
      npopl++;
      npopr--;
      if (v[j] < v[j+1]) {
        crit = (suml * suml / npopl) + (sumr * sumr / npopr) - critParent;
        if (crit > bestSplitWithinVariable) {
          valueAtBestSplit = (v[j] + v[j+1]) / 2.0;
          bestSplitWithinVariable = crit;
          tieVal = 1;
        }
        /* Note: this test also holds right after the update above, so a
           strict improvement still consumes one unif_rand() draw. */
        if (crit == bestSplitWithinVariable) {
          tieVal++;
          if (unif_rand() < 1.0 / tieVal) {
            valueAtBestSplit = (v[j] + v[j+1]) / 2.0;
            bestSplitWithinVariable = crit;
          }
        }
      }
    }
    if (bestSplitWithinVariable > bestSplitForAllVariables) {
      *bestSplitToReturn = valueAtBestSplit;
      *bestVarToReturn = kv + 1;
      bestSplitForAllVariables = bestSplitWithinVariable;
      for (j = ndstart; j <= ndend; ++j) {
        ut[j] = xt[j];
      }
      if (cat[kv] > 1) {
        for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
      }
      tieVar = 1;
    }
    /* Same pattern as above: ties between variables are broken at random. */
    if (bestSplitWithinVariable == bestSplitForAllVariables) {
      tieVar++;
      if (unif_rand() < 1.0 / tieVar) {
        *bestSplitToReturn = valueAtBestSplit;
        *bestVarToReturn = kv + 1;
        bestSplitForAllVariables = bestSplitWithinVariable;
        for (j = ndstart; j <= ndend; ++j) {
          ut[j] = xt[j];
        }
        if (cat[kv] > 1) {
          for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
        }
      }
    }

  }
  *decsplit = bestSplitForAllVariables;

  /* If best split can not be found, set to terminal node and return. */
  if (*bestVarToReturn != -1) {
    /* Partition jdex[ndstart..ndend]: observations at or below the cutoff
       first (ncase is reused as scratch), the rest after them. */
    nl = ndstart;
    for (j = ndstart; j <= ndend; ++j) {
      if (ut[j] <= *bestSplitToReturn) {
        nl++;
        ncase[nl-1] = jdex[j];
      }
    }
    *ndendl = imax2(nl - 1, ndstart);
    nr = *ndendl + 1;
    for (j = ndstart; j <= ndend; ++j) {
      if (ut[j] > *bestSplitToReturn) {
        if (nr >= nsample) break;
        nr++;
        ncase[nr - 1] = jdex[j];
      }
    }
    /* keep at least one observation for the right daughter */
    if (*ndendl >= ndend) *ndendl = ndend - 1;
    for (j = ndstart; j <= ndend; ++j) jdex[j] = ncase[j];

    numCategoriesForVar = cat[*bestVarToReturn - 1];
    if (numCategoriesForVar > 1) {
      for (j = 0; j < numCategoriesForVar; ++j) {
        icat[j] = (tavcat[j] < *bestSplitToReturn) ? 1 : 0;
      }
      *bestSplitToReturn = pack(numCategoriesForVar, icat);
    }
  } else *jstat = NODE_TERMINAL;

  Free(ncase);
  Free(varIndices);
  Free(v);
  Free(yl);
  Free(xt);
  Free(ut);
}

/*
 * Variant of the best-split search in which a single scalar `cat` describes
 * every column of x (cat == 1: numeric, otherwise the number of
 * categories).  Within this file it is called with a one-column x holding
 * the sub-forest output, mdim = 1, mtry = 1 and cat = 1.
 *
 * Up to mtry columns are drawn at random without replacement and the one
 * whose best cutoff maximises
 *     suml^2 / npopl + sumr^2 / npopr - sumnode^2 / nodeCount
 * is kept.
 *
 * x, y      data; x is read as x[var + (obs - 1) * mdim]
 * jdex      1-based observation numbers; positions ndstart..ndend (inclusive)
 *           are the observations of this node.  On success that range is
 *           reordered so that the left daughter's observations come first.
 * bestVarToReturn   out: 1-based index of the chosen column, -1 if none
 * decsplit          out: criterion value of the chosen split (0.0 if none)
 * bestSplitToReturn out: cutoff; when cat > 1 it is replaced by pack() of
 *                   the indicator "category mean < cutoff"
 * ndendl            out: last position in jdex that belongs to the left
 *                   daughter (written only when a split is found)
 * jstat             out: set to NODE_TERMINAL when no split is found,
 *                   otherwise left untouched
 * sumnode, nodeCount  sum of y and number of observations in the node
 *
 * The sequence of unif_rand() calls is part of the observable behaviour
 * (it decides ties), so statement order here matters.
 */
void findBestSplitRLy(double *x, int *jdex, double *y, int mdim, int nsample,
                     int ndstart, int ndend, int *bestVarToReturn, double *decsplit,
                     double *bestSplitToReturn, int *ndendl, int *jstat, int mtry,
                     double sumnode, int nodeCount, int cat) {
  int last, numCategoriesAllVars[MAX_CAT], icat[MAX_CAT], numCategoriesForVar, nl, nr, npopl, npopr, tieVar, tieVal;
  int i, j, kv, l, *varIndices, *ncase;
  double *xt, *ut, *v, *yl, sumcat[MAX_CAT], avcat[MAX_CAT], tavcat[MAX_CAT], valueAtBestSplit;
  double crit, bestSplitForAllVariables, bestSplitWithinVariable, suml, sumr, d, critParent;
  /* Sampling is without replacement, so at most mdim draws are possible.
     Without this bound `last` would reach -1 and the swap below would
     touch varIndices[-1]. */
  int numTries = (mtry < mdim) ? mtry : mdim;

  /*Calloc is different from calloc, r handles memory allocation instead of os*/
  ut = (double *) Calloc(nsample, double);  /* x values of the best column so far */
  xt = (double *) Calloc(nsample, double);  /* x values of the candidate column */
  v  = (double *) Calloc(nsample, double);  /* sorted copy of xt within the node */
  yl = (double *) Calloc(nsample, double);  /* y aligned with xt */
  varIndices  = (int *) Calloc(mdim, int);
  ncase = (int *) Calloc(nsample, int);
  zeroDouble(avcat, MAX_CAT);
  zeroDouble(tavcat, MAX_CAT);

  /* START BIG LOOP */
  *bestVarToReturn = -1;
  *decsplit = 0.0;
  bestSplitForAllVariables = 0.0;
  valueAtBestSplit = 0.0;
  for (i=0; i < mdim; ++i)
  {
    varIndices[i] = i;
  }

  last = mdim - 1;
  tieVar = 1;
  /*choose mtry number of variables, choose the one with best split*/
  for (i = 0; i < numTries; ++i)
  {
    /*sample without replacement: choose random, move to end,
     do next random choice in range 1 to len - n*/
    bestSplitWithinVariable = 0.0;
    j = (int) (unif_rand() * (last+1));
    kv = varIndices[j];
    swapInt(varIndices[j], varIndices[last]);
    last--;

    numCategoriesForVar = cat;
    if (numCategoriesForVar == 1) {
      /* numeric variable */
      for (j = ndstart; j <= ndend; ++j) {
        xt[j] = x[kv + (jdex[j] - 1) * mdim]; /*indexing to represent 2d in a 1d vector */
        yl[j] = y[jdex[j] - 1];
      }
    } else {
      /* categorical variable */
      zeroInt(numCategoriesAllVars, MAX_CAT);
      zeroDouble(sumcat, MAX_CAT);
      for (j = ndstart; j <= ndend; ++j) {
        l = (int) x[kv + (jdex[j] - 1) * mdim];
        sumcat[l - 1] += y[jdex[j] - 1];
        numCategoriesAllVars[l - 1] ++;
      }
      /* Compute means of Y by category. */
      for (j = 0; j < numCategoriesForVar; ++j) {
        avcat[j] = numCategoriesAllVars[j] ? sumcat[j] / numCategoriesAllVars[j] : 0.0;
      }
      /* Make the category mean the `pseudo' X data. */
      for (j = 0; j < nsample; ++j) {
        xt[j] = avcat[(int) x[kv + (jdex[j] - 1) * mdim] - 1];
        yl[j] = y[jdex[j] - 1];
      }
    }
    /* copy the x data in this node. */
    for (j = ndstart; j <= ndend; ++j) {
      v[j] = xt[j];
    }

    /* ncase starts as the identity (1-based positions); after the sort,
       ncase[j] is the position whose value ended up at rank j. */
    for (j = 1; j <= nsample; ++j) {
      ncase[j - 1] = j;
    }

    R_qsort_I(v, ncase, ndstart + 1, ndend + 1);

    /* the column is constant within the node: nothing to split on */
    if (v[ndstart] >= v[ndend]) {
      continue;
    }
    /* ncase(n)=case number of v nth from bottom */
    /* Sweep left to right, moving one observation at a time from the
       right-hand sums into the left-hand sums. */
    critParent = sumnode * sumnode / nodeCount;
    suml = 0.0;
    sumr = sumnode;
    npopl = 0;
    npopr = nodeCount;
    crit = 0.0;
    tieVal = 1;
    /* Search through the "gaps" in the x-variable. */
    for (j = ndstart; j <= ndend - 1; ++j) {
      d = yl[ncase[j] - 1];
      suml += d;
      sumr -= d;
      npopl++;
      npopr--;
      if (v[j] < v[j+1]) {
        crit = (suml * suml / npopl) + (sumr * sumr / npopr) - critParent;
        if (crit > bestSplitWithinVariable) {
          valueAtBestSplit = (v[j] + v[j+1]) / 2.0;
          bestSplitWithinVariable = crit;
          tieVal = 1;
        }
        /* Note: this test also holds right after the update above, so a
           strict improvement still consumes one unif_rand() draw. */
        if (crit == bestSplitWithinVariable) {
          tieVal++;
          if (unif_rand() < 1.0 / tieVal) {
            valueAtBestSplit = (v[j] + v[j+1]) / 2.0;
            bestSplitWithinVariable = crit;
          }
        }
      }
    }
    if (bestSplitWithinVariable > bestSplitForAllVariables) {
      *bestSplitToReturn = valueAtBestSplit;
      *bestVarToReturn = kv + 1;
      bestSplitForAllVariables = bestSplitWithinVariable;
      for (j = ndstart; j <= ndend; ++j) {
        ut[j] = xt[j];
      }
      if (cat > 1) {
        for (j = 0; j < cat; ++j) tavcat[j] = avcat[j];
      }
      tieVar = 1;
    }
    /* Same pattern as above: ties between columns are broken at random. */
    if (bestSplitWithinVariable == bestSplitForAllVariables) {
      tieVar++;
      if (unif_rand() < 1.0 / tieVar) {
        *bestSplitToReturn = valueAtBestSplit;
        *bestVarToReturn = kv + 1;
        bestSplitForAllVariables = bestSplitWithinVariable;
        for (j = ndstart; j <= ndend; ++j) {
          ut[j] = xt[j];
        }
        if (cat > 1) {
          for (j = 0; j < cat; ++j) tavcat[j] = avcat[j];
        }
      }
    }

  }
  *decsplit = bestSplitForAllVariables;

  /* If best split can not be found, set to terminal node and return. */
  if (*bestVarToReturn != -1) {
    /* Partition jdex[ndstart..ndend]: observations at or below the cutoff
       first (ncase is reused as scratch), the rest after them. */
    nl = ndstart;
    for (j = ndstart; j <= ndend; ++j) {
      if (ut[j] <= *bestSplitToReturn) {
        nl++;
        ncase[nl-1] = jdex[j];
      }
    }
    *ndendl = imax2(nl - 1, ndstart);
    nr = *ndendl + 1;
    for (j = ndstart; j <= ndend; ++j) {
      if (ut[j] > *bestSplitToReturn) {
        if (nr >= nsample) break;
        nr++;
        ncase[nr - 1] = jdex[j];
      }
    }
    /* keep at least one observation for the right daughter */
    if (*ndendl >= ndend) *ndendl = ndend - 1;
    for (j = ndstart; j <= ndend; ++j) jdex[j] = ncase[j];

    numCategoriesForVar = cat;
    if (numCategoriesForVar > 1) {
      for (j = 0; j < numCategoriesForVar; ++j) {
        icat[j] = (tavcat[j] < *bestSplitToReturn) ? 1 : 0;
      }
      *bestSplitToReturn = pack(numCategoriesForVar, icat);
    }
  } else *jstat = NODE_TERMINAL;

  Free(ncase);
  Free(varIndices);
  Free(v);
  Free(yl);
  Free(xt);
  Free(ut);
}

/*====================================================================*/
void predictRegRLTree(double *x,  int nsample, int mdim,
                    int *lDaughter, int *rDaughter, int *nodestatus,
                    double *ypred, double *split, double *nodepred,
                    int *splitVar, int treeSize, int *cat, int maxcat,
                    int *nodex, int *seqtreesize,
                    int *lDaughterk,int *rDaughterk ,int *nodestatusk ,int *bestSplitVarInNodek, double *nodeMeansk, double *splitCutoffk,
                    int  nTreesub) {
  int i, j, k, m,n,t,total_treenode,node, *cbestsplit,pointer_xindrf,index_k,index_kk,*nodexk,*treeSizerf;
  double dpack,*ypredk ,*nodeMeanskk,*splitCutoffkk;

  int currentNodeCount,nrnodescurrent,xdimk;
  double *xc;
  int *lDaughterkk,*rDaughterkk , *nodestatuskk , *bestSplitVarInNodekk;
  // Rprintf("xindrf=%d\n",
  //         xindrf[0]);
  // Rprintf("nodestatusk=%d\n",
  //         nodestatusk[0]);

//
  /* decode the categorical splits */
  if (maxcat > 1) {
    cbestsplit = (int *) Calloc(maxcat * treeSize, int);
    zeroInt(cbestsplit, maxcat * treeSize);
    for (i = 0; i < treeSize; ++i) {
      if (nodestatus[i] != NODE_TERMINAL && cat[splitVar[i] - 1] > 1) {
        dpack = split[i];
        /* unpack `npack' into bits */
        /* unpack(dpack, maxcat, cbestsplit + i * maxcat); */
        for (j = 0; j < cat[splitVar[i] - 1]; ++j) {
          cbestsplit[j + i*maxcat] = ((unsigned long) dpack & 1) ? 1 : 0;
          dpack = dpack / 2.0 ;
          /* cbestsplit[j + i*maxcat] = npack & 1; */
        }
      }
    }
  }
//
  for (i = 0; i < nsample; ++i) {
    k = 0;
    pointer_xindrf=0;
    while (nodestatus[k] != NODE_TERMINAL) { /* go down the tree */
      //     m = splitVar[k] - 1;


      if(nodestatus[k]==NODE_INTERIOR){

        m = splitVar[k] - 1;
        if (cat[m] == 1) {

            k = (x[m + i*mdim] <= split[k]) ?
            lDaughter[k] - 1 : rDaughter[k] - 1;

        } else {
          /* Split by a categorical predictor */
            k = cbestsplit[(int) x[m + i * mdim] - 1 + k * maxcat] ?
            lDaughter[k] - 1 : rDaughter[k] - 1;

          // k = cbestsplit[(int) x[m + i * mdim] - 1 + k * maxcat] ?
          // lDaughter[k] - 1 : rDaughter[k] - 1;
        }

       // Rprintf("m=%d",splitVar[k]);
      }
else if (nodestatus[k]==NODE_INTERIOR_V ){



    //find index of -k-1

    index_k =0;
    t=0;
    while(seqtreesize[t]!=-k-1){
      t++;
    }
    index_k=t;
    // Rprintf("nTreesub=%d",nTreesub);

//     // find currentNodeCount
//
treeSizerf = (int *) Calloc(nTreesub, int);
zeroInt(treeSizerf, nTreesub);
    //t=index_k+1;
    currentNodeCount=0;
    for(t=index_k+1; t<index_k+1+nTreesub;++t  ){
      currentNodeCount+=seqtreesize[t];
      treeSizerf[t-index_k-1]=seqtreesize[t];

    }
//
//
//
//     //
//     Rprintf("currentNodeCount=%d",currentNodeCount);

    // //
   // yb         = (double *) Calloc(currentNodeCount, double);
   // xb         = (double *) Calloc(mdim * currentNodeCount, double);
   xc         =(double *) Calloc(mdim * 1,double);
    //nrnodescurrent = 2 * currentNodeCount + 1;
   //Rprintf("nrnodes=%d",nrnodes);
   lDaughterkk = (int *) Calloc(currentNodeCount, int);
   rDaughterkk = (int *) Calloc(currentNodeCount, int);
   nodestatuskk = (int *) Calloc(currentNodeCount, int);
   bestSplitVarInNodekk=(int *) Calloc(currentNodeCount, int);
   nodeMeanskk = (double *) Calloc(currentNodeCount, double);
   splitCutoffkk =  (double *) Calloc(currentNodeCount, double);
    ypredk         =(double *) Calloc(1, double);
   zeroDouble(ypredk, 1);
       zeroDouble(nodeMeanskk ,currentNodeCount);
       zeroDouble(splitCutoffkk, currentNodeCount);

    total_treenode=0;
    while(lDaughterk[total_treenode]!=-k-1){
      total_treenode++;
    }

    index_kk=total_treenode;

    for(node=0; node<currentNodeCount;++node){
      lDaughterkk[node] = lDaughterk[index_kk+1+node];

      rDaughterkk[node] = rDaughterk[index_kk+1+node];
      nodestatuskk[node] = nodestatusk[index_kk+1+node];
      bestSplitVarInNodekk[node] = bestSplitVarInNodek[index_kk+1+node];
      nodeMeanskk[node] = nodeMeansk[index_kk+1+node];
      splitCutoffkk[node] = splitCutoffk[index_kk+1+node];
      //Rprintf("nodeMeanskk=%f",nodeMeanskk[node]);

    }



    // copy of current x[i]
    for(m = 0; m < mdim; ++m) {
      xc[m] = x[m + i * mdim];
      //Rprintf("xc=%f",xc[m]);

    }


// // // //
regForest(xc, ypredk, mdim, 1,
          nTreesub, lDaughterkk, rDaughterkk,
          nodestatuskk, treeSizerf, splitCutoffkk,
          nodeMeanskk, bestSplitVarInNodekk,treeSizerf, cat,
          maxcat,0);


    //Rprintf("ypredk=%f",ypredk[0]);
    k = (ypredk[0]<= split[k]) ?
    lDaughter[k] - 1 : rDaughter[k] - 1;
   // Rprintf("k=%d",k);



    Free(lDaughterkk);
    Free(rDaughterkk);
    Free(bestSplitVarInNodekk);
    Free(nodeMeanskk);
     Free(treeSizerf);
     Free(splitCutoffkk);
       Free(nodestatuskk);
     Free(xc);
     Free(ypredk);

  }
  //pointer_xindrf++;//=pointer_xindrf+currentNodeCount;
    }

    /* terminal node: assign prediction and move on to next */
    ypred[i] = nodepred[k];
    //Rprintf("ypredk=%f", ypred[i]);
    nodex[i] = k + 1;
  }


  if (maxcat > 1) Free(cbestsplit);
}
