
#include <Rmath.h>
#include <R.h>
#include "rf.h"

/*
 * Build one tree; call this multiple times to create a forest in regrf.
 *
 * Every node that is still marked NODE_TOSPLIT gets two candidate splits:
 * one on a feature of x (findBestSplitRL) and one on the response y itself
 * (findBestSplitRLy).  One of the two is then chosen at random, using the
 * ratio of their impurity decreases.  When the y split is chosen, a
 * sub-forest is additionally fitted (regRF) on the samples of that node and
 * appended to the "...k" output arrays.
 *
 * Inputs
 *   x            feature matrix; feature m of sample r (0-based) is read at
 *                x[m + r * mdim].
 *   y            response, one value per sample.
 *   mdim         number of features.
 *   nsample      number of samples.
 *   numNodes     length of the per-node arrays below.
 *   nthsize      a daughter with at most this many samples is terminal.
 *   mtry         forwarded to findBestSplitRL and regRF.
 *   cat          forwarded to findBestSplitRL and regRF.
 *   nTreesub     number of trees in each sub-forest.
 *   nthsizesub   forwarded to regRF for the sub-forest.
 *
 * Per-node outputs (length numNodes, indexed by node number k)
 *   lDaughter, rDaughter   1-based node numbers of the two daughters.
 *   splitCutoff            cut point (on x or on y).
 *   nodeMeans              running mean of y over the samples in the node.
 *   nodestatus             NODE_INTERIOR (x split), NODE_INTERIOR_V (y split)
 *                          or NODE_TERMINAL.
 *   bestSplitVarInNode     split variable returned by findBestSplitRL, or 0
 *                          when the node was split on y.
 *   treeSize               number of nodes whose status is non-zero.
 *   varUsed                varUsed[v - 1] is set to 1 for every x variable v
 *                          used in a split.
 *
 * Sub-forest outputs (one segment appended per y-split node k)
 *   seqtreesize            marker -1-k followed by the nTreesub tree sizes.
 *   cseqtreesize           incremented by 1 + nTreesub per segment.
 *   lDaughterk, rDaughterk, nodestatusk, bestSplitVarInNodek,
 *   nodeMeansk, splitCutoffk
 *                          marker -1-k followed by `sum` entries copied from
 *                          the sub-forest, `sum` being the total of the tree
 *                          sizes.
 *   lensubtree             incremented by 1 + sum per segment.
 *
 * NOTE(review): the write positions into the sub-forest arrays restart at 0
 * on every call while *cseqtreesize and *lensubtree are only incremented;
 * presumably the caller resets them or offsets the array pointers per tree.
 * Confirm against the caller.
 */
void regRLTreeDV(double *x, double *y, int mdim, int nsample, int *lDaughter,
               int *rDaughter,
               double *splitCutoff, double *nodeMeans, int *nodestatus, int numNodes,
               int *treeSize, int nthsize, int mtry, int *bestSplitVarInNode, int *cat,
               int *varUsed, int *seqtreesize,int *cseqtreesize,
               int *lDaughterk,int *rDaughterk ,int *nodestatusk ,int *bestSplitVarInNodek, double *nodeMeansk, double *splitCutoffk,
               int *lensubtree,  int nTreesub, int nthsizesub) {
  int i, j, k, m, n, currentNode;
  int *rowIndices, *rowIndicesy, *nodestart, *numPointsInEachNode;
  int ndstart, ndend, currentNodeCount;
  int splitResult, splitResulty, msplit, msplity;
  /* Initialised so that the first pass of the main loop never reads an
     indeterminate value when seeding the y-split copy below. */
  int leftChildLastDataPoint = 0, leftChildLastDataPointy = 0;
  double label, nodeMean, sumOfNodeData;
  double decsplit, ubest;   /* impurity decrease / cut point of the x split */
  double decsplity, ybest;  /* impurity decrease / cut point of the y split */
  double p, u;              /* branch threshold and the uniform draw */
  int pointer_vk = 0;       /* next free slot in the "...k" sub-forest arrays */
  int pointer_vkt = 0;      /* next free slot in seqtreesize */
  int sum;                  /* total number of nodes in one sub-forest */

  nodestart = (int *) Calloc(numNodes, int);
  numPointsInEachNode = (int *) Calloc(numNodes, int);

  /* initialize some arrays for the tree */
  zeroInt(nodestatus, numNodes);
  zeroInt(nodestart, numNodes);
  zeroInt(numPointsInEachNode, numNodes);
  zeroDouble(nodeMeans, numNodes);

  /* Two 1-based row permutations: one reordered by the x split, one by the
     y split.  Only the one matching the chosen split is kept. */
  rowIndices = (int *) Calloc(nsample, int);
  rowIndicesy = (int *) Calloc(nsample, int);
  for (i = 0; i < nsample; ++i) {
    rowIndices[i] = i + 1;
    rowIndicesy[i] = i + 1;
  }

  /* The root holds every sample. */
  currentNode = 0;
  nodestart[0] = 0;
  numPointsInEachNode[0] = nsample;
  nodestatus[0] = NODE_TOSPLIT;

  /* running mean of Y over the root node */
  nodeMean = 0.0;
  for (i = 0; i < nsample; ++i) {
    label = y[rowIndices[i] - 1];
    nodeMean = (i * nodeMean + label) / (i + 1);
  }
  nodeMeans[0] = nodeMean;

  /* start main loop */
  for (k = 0; k < numNodes - 2; ++k) {
    /* Stop when there is no node left to visit or no room for two more
       daughters. */
    if (k > currentNode || currentNode >= numNodes - 2) {
      break;
    }
    /* skip if the node is not to be split */
    if (nodestatus[k] != NODE_TOSPLIT) {
      continue;
    }

#ifdef RF_DEBUG
    Rprintf("regTree: k=%d, av=%f\n", k, nodeMeans[k]);
#endif

    /* initialize for the next pair of split searches */
    ndstart = nodestart[k];
    ndend = ndstart + numPointsInEachNode[k] - 1;
    currentNodeCount = numPointsInEachNode[k];
    sumOfNodeData = currentNodeCount * nodeMeans[k];
    splitResult = 0;
    splitResulty = 0;
    decsplit = 0.0;
    decsplity = 0.0;

#ifdef RF_DEBUG
    Rprintf("before findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f\n",
            ndstart, ndend, splitResult, decsplit);
#endif

    /* Give the y-split search its own copy of the current ordering so the
       two searches do not disturb each other. */
    for (i = 0; i < nsample; ++i) {
      rowIndicesy[i] = rowIndices[i];
    }
    leftChildLastDataPointy = leftChildLastDataPoint;

    findBestSplitRL(x, rowIndices, y, mdim, nsample, ndstart, ndend, &msplit,
                    &decsplit, &ubest, &leftChildLastDataPoint, &splitResult, mtry, sumOfNodeData,
                    currentNodeCount, cat);

    /* NOTE(review): the last argument is the literal 1 where the x-split
       call passes `cat`; verify against findBestSplitRLy's prototype. */
    findBestSplitRLy(y,rowIndicesy, y, 1, nsample, ndstart, ndend, &msplity,
                     &decsplity, &ybest, &leftChildLastDataPointy, &splitResulty, 1, sumOfNodeData,
                     currentNodeCount, 1);

    if (splitResult == NODE_TERMINAL || splitResulty == NODE_TERMINAL) {
      /* Node is terminal: Mark it as such and move on to the next. */
      nodestatus[k] = NODE_TERMINAL;
      continue;
    }

    /* Choose between the x split and the y split at random.
       NOTE(review): if both decreases are 0 this is 0/0, the comparison
       below is then false and the y-split branch is taken; confirm that
       this is intended. */
    p = decsplity / (decsplit + decsplity);
    u = unif_rand();

    if (u <= p) {
      /* split on a feature of x */
      bestSplitVarInNode[k] = msplit;
      varUsed[msplit - 1] = 1;
      splitCutoff[k] = ubest;
      nodestatus[k] = NODE_INTERIOR;
    } else {
      /* split on y, and fit a sub-forest on the samples of this node */
      int nrnodesk = 2 * currentNodeCount + 1;
      int *lDaughterkk, *rDaughterkk, *nodestatuskk, *bestSplitVarInNodekk;
      int *treeSizerf;
      double *xb, *yb, *nodeMeanskk, *splitCutoffkk;

      bestSplitVarInNode[k] = 0; /* 0 marks "split on y" */
      splitCutoff[k] = ybest;
      nodestatus[k] = NODE_INTERIOR_V;

      /* Adopt the ordering produced by the y-split search. */
      for (i = 0; i < nsample; ++i) {
        rowIndices[i] = rowIndicesy[i];
      }
      leftChildLastDataPoint = leftChildLastDataPointy;

      yb = (double *) Calloc(currentNodeCount, double);
      xb = (double *) Calloc(mdim * currentNodeCount, double);
      lDaughterkk = (int *) Calloc(nrnodesk * nTreesub, int);
      rDaughterkk = (int *) Calloc(nrnodesk * nTreesub, int);
      nodestatuskk = (int *) Calloc(nrnodesk * nTreesub, int);
      bestSplitVarInNodekk = (int *) Calloc(nrnodesk * nTreesub, int);
      nodeMeanskk = (double *) Calloc(nrnodesk * nTreesub, double);
      splitCutoffkk = (double *) Calloc(nrnodesk * nTreesub, double);
      treeSizerf = (int *) Calloc(nTreesub, int);

      zeroDouble(nodeMeanskk, nrnodesk * nTreesub);
      zeroDouble(splitCutoffkk, nrnodesk * nTreesub);
      zeroInt(treeSizerf, nTreesub);

      /* Gather the samples of the current node into dense x/y buffers. */
      for (n = 0; n < currentNodeCount; ++n) {
        int row = rowIndices[n + ndstart] - 1;
        yb[n] = y[row];
        for (m = 0; m < mdim; ++m) {
          xb[m + n * mdim] = x[m + row * mdim];
        }
      }

      regRF(xb, yb, 0,1.0,mdim, currentNodeCount,
            nthsizesub,nrnodesk,nTreesub, mtry,
            cat, 1,treeSizerf,nodestatuskk,
            lDaughterkk, rDaughterkk, nodeMeanskk, bestSplitVarInNodekk,
            splitCutoffkk,1,1);

      /* Segment in seqtreesize: node marker, then one size per sub-tree. */
      sum = 0;
      seqtreesize[pointer_vkt] = -1 - k;
      for (i = 0; i < nTreesub; ++i) {
        sum += treeSizerf[i];
        seqtreesize[pointer_vkt + 1 + i] = treeSizerf[i];
      }
      pointer_vkt += 1 + nTreesub;
      *cseqtreesize += 1 + nTreesub;

      /* Segment in the sub-forest arrays: node marker, then `sum` entries. */
      lDaughterk[pointer_vk] = -1 - k;
      rDaughterk[pointer_vk] = -1 - k;
      nodestatusk[pointer_vk] = -1 - k;
      bestSplitVarInNodek[pointer_vk] = -1 - k;
      nodeMeansk[pointer_vk] = 1.0 * (-1 - k);
      splitCutoffk[pointer_vk] = 1.0 * (-1 - k);

      /* Copy exactly `sum` entries, matching the segment length recorded
         below, so nothing is read or written past the segment.
         NOTE(review): this takes the first `sum` entries of the sub-forest
         arrays, i.e. it assumes regRF stores its trees back to back;
         confirm against regRF. */
      for (j = 0; j < sum; ++j) {
        int dst = pointer_vk + 1 + j;
        lDaughterk[dst] = lDaughterkk[j];
        rDaughterk[dst] = rDaughterkk[j];
        nodestatusk[dst] = nodestatuskk[j];
        bestSplitVarInNodek[dst] = bestSplitVarInNodekk[j];
        nodeMeansk[dst] = nodeMeanskk[j];
        splitCutoffk[dst] = splitCutoffkk[j];
      }
      pointer_vk += 1 + sum;
      *lensubtree += 1 + sum;

      Free(yb);
      Free(xb);
      Free(lDaughterkk);
      Free(rDaughterkk);
      Free(nodestatuskk);
      Free(bestSplitVarInNodekk);
      Free(nodeMeanskk);
      Free(splitCutoffkk);
      Free(treeSizerf);
    }

    /* left node is the node after current, right node is 2 after */
    numPointsInEachNode[currentNode + 1] = leftChildLastDataPoint - ndstart + 1;
    numPointsInEachNode[currentNode + 2] = ndend - leftChildLastDataPoint;
    nodestart[currentNode + 1] = ndstart;
    nodestart[currentNode + 2] = leftChildLastDataPoint + 1;

    /* running mean of Y over the left daughter node */
    nodeMean = 0.0;
    for (j = ndstart; j <= leftChildLastDataPoint; ++j) {
      label = y[rowIndices[j] - 1];
      m = j - ndstart;
      nodeMean = (m * nodeMean + label) / (m + 1);
    }
    nodeMeans[currentNode + 1] = nodeMean;
    nodestatus[currentNode + 1] = NODE_TOSPLIT;
    if (numPointsInEachNode[currentNode + 1] <= nthsize) {
      nodestatus[currentNode + 1] = NODE_TERMINAL;
    }

    /* running mean of Y over the right daughter node */
    nodeMean = 0.0;
    for (j = leftChildLastDataPoint + 1; j <= ndend; ++j) {
      label = y[rowIndices[j] - 1];
      m = j - (leftChildLastDataPoint + 1);
      nodeMean = (m * nodeMean + label) / (m + 1);
    }
    nodeMeans[currentNode + 2] = nodeMean;
    nodestatus[currentNode + 2] = NODE_TOSPLIT;
    if (numPointsInEachNode[currentNode + 2] <= nthsize) {
      nodestatus[currentNode + 2] = NODE_TERMINAL;
    }

    /* map the daughter nodes (stored as 1-based node numbers) */
    lDaughter[k] = currentNode + 1 + 1;
    rDaughter[k] = currentNode + 2 + 1;

    /* Augment the tree by two nodes. */
    currentNode += 2;
#ifdef RF_DEBUG
    Rprintf(" after split: ldaughter=%d, rdaughter=%d, ncur=%d\n",
            lDaughter[k], rDaughter[k], currentNode);
#endif
  }

  /* Count the nodes actually used, and close any node still waiting to be
     split when the loop stopped. */
  *treeSize = numNodes;
  for (k = numNodes - 1; k >= 0; --k) {
    if (nodestatus[k] == 0) (*treeSize)--;
    if (nodestatus[k] == NODE_TOSPLIT) {
      nodestatus[k] = NODE_TERMINAL;
    }
  }

  Free(nodestart);
  Free(rowIndices);
  Free(rowIndicesy);
  Free(numPointsInEachNode);
}


