

#include <R.h>
#include "rf.h"

// void simpleLinRegRL(int nsample, double *x, double *y, double *coef,
//                   double *mse, int *hasPred);


/*
 Train a regression random forest model (each tree is grown by regRLTree)
 */
void regRLRF(double *x, double *y, int *useweights, double *weights, int *xdim, int *sampsize,
           int *nthsize, int *nrnodes, int *nTree, int *mtry, int *imp,
           int *cat, int *maxcat, int *jprint, int *doProx, int *oobprox,
           int *biasCorr, double *yptr, double *errimp, double *impmat,
           double *impSD, double *prox, int *treeSize, int *nodestatus,
           int *lDaughter, int *rDaughter, double *avnode, int *mbest,
           double *upper, double *mse, int *keepf, int *replace,
           int *testdat, double *xts, int *nts, double *yts, int *labelts,
           double *yTestPred, double *proxts, double *msets, double *coef,
           int *nout, int *inbag,int *seqtreesize,int *cseqtreesize,
            int *lDaughterk,int *rDaughterk ,int *nodestatusk ,int *bestSplitVarInNodek, double *nodeMeansk, double *splitCutoffk, int *lensubtree,
            int *nTreesub, int *nthsizesub, double *Lp, int *E ) {
  /*************************************************************************
   Grows *nTree regression trees with regRLTree. Each tree is grown on a
   sample of *sampsize rows of (x, y) drawn by sampleDataRows.

   Input:
   xdim[0] = nsample = number of cases (rows)
   xdim[1] = mdim    = number of variables; x is read as x[m + k * mdim]
                       (variable m of case k)

   nthsize = number of cases in a node below which the tree will not split.
   Setting nthsize = 5 generally gives good results.

   nTree = number of trees in the run. 200-500 gives pretty good results.

   mtry = number of variables to pick to split on at each node. mdim/3
   seems to give generally good performance, but it can be altered up
   or down.

   sampsize, replace, useweights, weights are forwarded to sampleDataRows.
   Presumably it samples with replacement when *replace is non-zero and
   uses `weights` when *useweights is non-zero; confirm against
   sampleDataRows.

   keepf[0] (keepF) non-zero: tree j writes its output at these offsets:
     - j * *nrnodes in lDaughter/rDaughter/upper/avnode/nodestatus/mbest;
     - j * *nTreesub * *nrnodes in seqtreesize;
     - j * *E * *nTreesub * *nrnodes in the ...k arrays.
   When keepF is zero, every tree writes at offset 0.

   keepf[1] (keepInbag) non-zero: inbag[n + j * nsample] receives the
   number of times case n was drawn for tree j.

   Other effects visible here:
     - *jprint is set to *nTree + 1 when it is zero.
     - yptr and nout are zeroed (nsample entries).
     - yTestPred is zeroed (*nts entries) when *labelts is non-zero.
     - A header line is printed when *jprint <= *nTree.

   imp, doProx, oobprox, biasCorr, errimp, impmat, impSD, prox, mse, xts,
   yts, proxts, msets and coef are not referenced by this function.
   NOTE(review): presumably they are kept to match the R-side .C call.
   *************************************************************************/

  int nsample = xdim[0];
  int mdim = xdim[1];
  int ntest = *nts;
  int keepF = keepf[0];
  int keepInbag = keepf[1];
  int j, k, m, n, idx, idxrf, idk;
  double *yb, *xb;
  int *in, *varUsed, *sampledIndices;

  if (*jprint == 0) *jprint = *nTree + 1;

  /* S_alloc memory is zero-filled and released by R when the .C call
     returns, including when an R error unwinds the call, so none of
     these buffers needs an explicit free. */
  yb             = (double *) S_alloc(*sampsize, sizeof(double));
  xb             = (double *) S_alloc(mdim * *sampsize, sizeof(double));
  in             = (int *) S_alloc(nsample, sizeof(int));
  varUsed        = (int *) S_alloc(mdim, sizeof(int));
  sampledIndices = (int *) S_alloc(*sampsize, sizeof(int));

  zeroDouble(yptr, nsample);
  zeroInt(nout, nsample);
  if (*labelts) zeroDouble(yTestPred, ntest);

  /* print header for running output */
  if (*jprint <= *nTree) {
    Rprintf("     |      Out-of-bag   ");
    if (*testdat) Rprintf("|       Test set    ");
    Rprintf("|\n");
    Rprintf("Tree |      MSE  %%Var(y) ");
    if (*testdat) Rprintf("|      MSE  %%Var(y) ");
    Rprintf("|\n");
  }

  /* Every GetRNGstate() must be paired with PutRNGstate(). Otherwise the
     advanced RNG state is not written back to .Random.seed. */
  GetRNGstate();
  /*************************************
   * Start the loop over trees.
   *************************************/
  for (j = 0; j < *nTree; ++j) {
    /* Per-tree output offsets; with keepF == 0 every tree reuses slot 0. */
    idx = keepF ? j * *nrnodes : 0;
    idxrf = keepF ? *nTreesub * j * *nrnodes : 0;
    idk = keepF ? *E * *nTreesub * j * *nrnodes : 0;

    zeroInt(in, nsample);
    zeroInt(varUsed, mdim);

    /* Draw a random sample for growing a tree. */
    sampleDataRows(nsample, *sampsize, *useweights, *replace, weights, sampledIndices);

    /* Copy the drawn rows into (xb, yb); in[k] counts how often case k
       was drawn for this tree. */
    for (n = 0; n < *sampsize; ++n) {
      k = sampledIndices[n];
      in[k] += 1;
      yb[n] = y[k];
      for (m = 0; m < mdim; ++m) {
        xb[m + n * mdim] = x[m + k * mdim];
      }
    }

    if (keepInbag) {
      for (n = 0; n < nsample; ++n) inbag[n + j * nsample] = in[n];
    }

    /* grow the regression tree */
    regRLTree(xb, yb, mdim, *sampsize, lDaughter + idx, rDaughter + idx,
            upper + idx, avnode + idx, nodestatus + idx, *nrnodes,
            treeSize + j, *nthsize, *mtry, mbest + idx, cat, maxcat,
            varUsed, seqtreesize + idxrf, cseqtreesize + j,
            lDaughterk + idk, rDaughterk + idk, nodestatusk + idk,
            bestSplitVarInNodek + idk, nodeMeansk + idk, splitCutoffk + idk,
            lensubtree + j,
            *nTreesub, *nthsizesub, *Lp);
  }
  PutRNGstate();
}

/*--
 Predict from a trained random forest model: per-tree predictions are summed into ypred and divided by the number of trees
 --------------------------------------------------------------------*/
void regRLForest(double *x,double *ypred, int *mdim, int *n,
                 int *ntree, int *lDaughter, int *rDaughter,
                 int *nodestatus, int *nrnodes, double *xsplit,
                 double *avnodes, int *mbest, int *treeSize, int *cat,
                 int *maxcat, int *nodes, int *nodex, int *seqtreesize,
                 int *lDaughterk,int *rDaughterk ,int *nodestatusk ,int *bestSplitVarInNodek, double *nodeMeansk, double *splitCutoffk,
                 int *nrnodesk, int *cseqtreemax, int *nTreesub) {
  /*
   Runs predictRegRLTree for each of the *ntree stored trees on the *n
   rows of x, then combines the results:
     - each tree's predictions are added to ypred;
     - ypred is then divided by *ntree.
   ypred is NOT cleared here, so any initial contents are included in the
   sum.

   Tree t reads its arrays at these offsets:
     - t * *nrnodes in lDaughter/rDaughter/nodestatus/xsplit/avnodes/mbest;
     - t * *nrnodesk in the ...k arrays;
     - t * *cseqtreemax in seqtreesize.

   When *nodes is non-zero, nodex holds one block of *n entries per tree
   (all zeroed first). Otherwise a single *n-entry block is zeroed and
   passed to every tree.
   */
  const int nCase = *n;
  const int nTrees = *ntree;
  double *treePred = (double *) S_alloc(nCase, sizeof(double));
  int nodeOff = 0;   /* offset into the per-tree node arrays */
  int subOff = 0;    /* offset into the ...k arrays */
  int seqOff = 0;    /* offset into seqtreesize */
  int nodexOff = 0;  /* offset into nodex */
  int t, c;

  zeroInt(nodex, *nodes ? nCase * nTrees : nCase);

  for (t = 0; t < nTrees; ++t) {
    zeroDouble(treePred, nCase);

    predictRegRLTree(x, nCase, *mdim, lDaughter + nodeOff, rDaughter + nodeOff,
                     nodestatus + nodeOff, treePred, xsplit + nodeOff,
                     avnodes + nodeOff, mbest + nodeOff, treeSize[t], cat, *maxcat,
                     nodex + nodexOff, seqtreesize + seqOff,
                     lDaughterk + subOff, rDaughterk + subOff, nodestatusk + subOff,
                     bestSplitVarInNodek + subOff, nodeMeansk + subOff,
                     splitCutoffk + subOff,
                     *nTreesub);

    for (c = 0; c < nCase; ++c) ypred[c] += treePred[c];

    nodeOff += *nrnodes;
    subOff += *nrnodesk;
    seqOff += *cseqtreemax;
    if (*nodes) nodexOff += nCase;
  }

  for (c = 0; c < nCase; ++c) ypred[c] /= nTrees;
}

