
#include <R.h>
#include "rf.h"


void zeroInt(int *x, int length) {
    /* Set the first `length` ints of x to zero. */
    size_t nbytes = (size_t) length * sizeof *x;
    memset(x, 0, nbytes);
}

void zeroDouble(double *x, int length) {
    /* Set the first `length` doubles of x to (all-bits-zero) 0.0. */
    size_t nbytes = (size_t) length * sizeof *x;
    memset(x, 0, nbytes);
}

void createClass(double *x, int realN, int totalN, int mdim) {
/* Fill cases realN..totalN-1 of x with a synthetic second class by
   bootstrapping each variable independently: every variable of every
   synthetic case is copied from its own randomly chosen real case
   (0..realN-1, assuming unif_rand() returns values in [0, 1)).
   x holds one case per column: element (var, case) is x[var + case*mdim]. */
    int row;
    for (row = realN; row < totalN; ++row) {
        double *synthetic = x + row * mdim;
        int var;
        for (var = 0; var < mdim; ++var) {
            int donor = (int) (unif_rand() * realN);
            synthetic[var] = x[var + donor * mdim];
        }
    }
}

void normClassWt(int *cl, const int nsample, const int nclass,
                 const int useWt, double *classwt, int *classFreq) {
    /* Turn class weights into per-class multipliers, in place.
       useWt != 0: the user-supplied classwt[] are first scaled to sum to one.
       useWt == 0: classwt[c] is set to the empirical share classFreq[c] / nsample.
       Afterwards classwt[c] becomes classwt[c] * nsample / classFreq[c]; classes
       with classFreq[c] == 0 get 0.0.  cl is not used. */
    int c;

    (void) cl;
    if (!useWt) {
        for (c = 0; c < nclass; ++c)
            classwt[c] = (double) classFreq[c] / nsample;
    } else {
        double total = 0.0;
        for (c = 0; c < nclass; ++c) total += classwt[c];
        for (c = 0; c < nclass; ++c) classwt[c] /= total;
    }

    for (c = 0; c < nclass; ++c) {
        if (classFreq[c] == 0) {
            classwt[c] = 0.0;
        } else {
            classwt[c] = classwt[c] * nsample / classFreq[c];
        }
    }
}

void makeA(double *x, const int mdim, const int nsample, int *cat, int *a,
           int *b) {
    /* makeA() builds the mdim x nsample integer arrays a and b from the data
       x.  All three arrays use the same layout: element (m, n) is stored at
       [m + n*mdim], i.e. one case per column.

       Numerical variables (cat[m] == 1): the values x(m, 1..nsample) are
       sorted from lowest to highest.  a(m, j) is the 1-based case number of
       the j-th smallest value.  b(m, n) is the rank of case n's value among
       the distinct values of variable m: 1 for the smallest, increasing by
       one at each strictly larger value, with tied values sharing a rank.
       When nsample == 1 the ranking loop does not run and b is not written.

       Categorical variables (cat[m] != 1): a(m, n) is x(m, n) truncated to
       int, i.e. the category code of case n.  b is not written for them. */
    int i, j, n1, n2, *index;
    double *v;

    /* Scratch buffers for one variable: its values and their case numbers. */
    v     = (double *) Calloc(nsample, double);
    index = (int *) Calloc(nsample, int);

    for (i = 0; i < mdim; ++i) {
        if (cat[i] == 1) { /* numerical predictor */
            for (j = 0; j < nsample; ++j) {
                v[j] = x[i + j * mdim];
                index[j] = j + 1;
            }
            R_qsort_I(v, index, 1, nsample);

            /*  this sorts the v(n) in ascending order. index(n) is the case
                number of that v(n) nth from the lowest (assume the original
                case numbers are 1,2,...).  */
            for (j = 0; j < nsample-1; ++j) {
                n1 = index[j];
                n2 = index[j + 1];
                a[i + j * mdim] = n1;
                /* The smallest value starts the ranking at 1. */
                if (j == 0) b[i + (n1-1) * mdim] = 1;
                /* The next case gets a new rank only if its value is strictly
                   larger; ties inherit the previous case's rank. */
                b[i + (n2-1) * mdim] =  (v[j] < v[j + 1]) ?
                    b[i + (n1-1) * mdim] + 1 : b[i + (n1-1) * mdim];
            }
            /* The loop above stops one short; store the largest value's case. */
            a[i + (nsample-1) * mdim] = index[nsample-1];
        } else { /* categorical predictor */
            for (j = 0; j < nsample; ++j)
                a[i + j*mdim] = (int) x[i + j * mdim];
        }
    }
    Free(index);
    Free(v);
}


void modA(int *a, int *nuse, const int nsample, const int mdim,
	  int *cat, const int maxcat, int *ncase, int *jin) {
    /* Restrict the presorted index array a to the cases flagged in jin.

       jin:   per-case flags (length nsample); a nonzero entry marks the case
              as used (presumably "in bag" -- confirm against the caller).
       nuse:  output; set to the number of nonzero entries of jin.
       a:     mdim x nsample, element (i, n) at a[i + n*mdim].  For each
              variable with cat[i] == 1, row i is assumed to hold 1-based case
              numbers (as produced by makeA).  Its first *nuse entries are
              overwritten with the used case numbers in their existing order,
              so the row keeps its sorted order.  Entries at positions >= *nuse
              are left as they were.  Rows with cat[i] != 1 are not modified.
       ncase: only when maxcat > 1, its first *nuse entries receive the 1-based
              case numbers of the used cases in increasing order.
       nsample, mdim, cat, maxcat are read only. */
    int i, j, k, m, nt;

    *nuse = 0;
    for (i = 0; i < nsample; ++i) if (jin[i]) (*nuse)++;

    for (i = 0; i < mdim; ++i) {
      /* k: read position in row i; nt: write position.  k never falls behind
         nt, so the in-place compaction never reads an entry it already
         overwrote. */
      k = 0;
      nt = 0;
      if (cat[i] == 1) {
          for (j = 0; j < nsample; ++j) {
              if (jin[a[i + k * mdim] - 1]) {
                  /* Case at position k is used: move it down to slot nt. */
                  a[i + nt * mdim] = a[i + k * mdim];
                  k++;
              } else {
                  /* Otherwise scan forward for the next used case. */
                  for (m = 0; m < nsample - k; ++m) {
                      if (jin[a[i + (k + m) * mdim] - 1]) {
                          a[i + nt * mdim] = a[i + (k + m) * mdim];
                          k += m + 1;
                          break;
                      }
                  }
              }
              nt++;
              /* Stop once all *nuse used cases have been placed. */
              if (nt >= *nuse) break;
          }
      }
    }
    if (maxcat > 1) {
        /* Same forward scan, but over case numbers 1..nsample directly,
           recording each used case's 1-based number in ncase. */
        k = 0;
        nt = 0;
        for (i = 0; i < nsample; ++i) {
            if (jin[k]) {
                k++;
                ncase[nt] = k;
            } else {
                for (j = 0; j < nsample - k; ++j) {
                    if (jin[k + j]) {
                        ncase[nt] = k + j + 1;
                        k += j + 1;
                        break;
                    }
                }
            }
            nt++;
            if (nt >= *nuse) break;
        }
    }
}

void Xtranslate(double *x, int mdim, int nrnodes, int nsample,
                int *bestvar, double *bestsplit, double *bestsplitnext,
                double *xbestsplit, int *nodestatus, int *cat, int treeSize) {
/*
 For each of the first treeSize nodes whose nodestatus is 1 (presumably the
 internal/split nodes), write a split value into xbestsplit:
   - numerical variable (cat == 1): bestsplit and bestsplitnext hold 1-based
     case numbers (columns of x); the split value is the midpoint of variable
     m's values in those two cases.
   - categorical variable: the stored split (presumably the packed category
     code) is copied unchanged.
 Nodes with any other status are left untouched.  nrnodes and nsample are
 not used.
*/
    int node;

    for (node = 0; node < treeSize; ++node) {
        int var;
        double left, right;

        if (nodestatus[node] != 1) continue;

        var = bestvar[node] - 1;
        if (cat[var] != 1) {
            xbestsplit[node] = bestsplit[node];
            continue;
        }
        left  = x[var + ((int) bestsplit[node] - 1) * mdim];
        right = x[var + ((int) bestsplitnext[node] - 1) * mdim];
        xbestsplit[node] = 0.5 * (left + right);
    }
}

void permuteOOB(int m, double *x, int *in, int nsample, int mdim) {
/* Shuffle, in place, the values of variable m among the OOB cases only.
 *   m:       row of x holding the variable to shuffle
 *   x:       data matrix, one case per column (element (m, i) at x[m + i*mdim])
 *   in:      per-case flag; cases with in[i] == 0 are the OOB ones shuffled
 *   nsample: number of cases (columns) in x
 *   mdim:    number of variables (rows) in x
 * Values of cases with in[i] != 0 are not touched.  The OOB values are
 * gathered, Fisher-Yates shuffled with unif_rand(), and scattered back.
 */
    double *buf = (double *) Calloc(nsample, double);
    double held;
    int count = 0, i, slot, pick;

    /* Gather the OOB values in case order. */
    for (i = 0; i < nsample; ++i)
        if (!in[i]) buf[count++] = x[m + i * mdim];

    /* Fisher-Yates: swap slot-1 with a random position in [0, slot). */
    for (slot = count; slot > 0; --slot) {
        pick = (int) (slot * unif_rand());
        held = buf[slot - 1];
        buf[slot - 1] = buf[pick];
        buf[pick] = held;
    }

    /* Scatter the shuffled values back to the same OOB positions. */
    count = 0;
    for (i = 0; i < nsample; ++i)
        if (!in[i]) x[m + i * mdim] = buf[count++];

    Free(buf);
}

/* Accumulate pairwise proximity counts from terminal-node labels. */
void computeProximity(double *prox, int oobprox, int *node, int *inbag,
                      int *oobpair, int n) {
/* For every pair of distinct cases r < c that share a terminal-node label,
   add 1.0 to both prox[c*n + r] and prox[r*n + c].
   prox:    n x n proximity matrix, updated in place (kept symmetric)
   oobprox: 0 = count every pair; nonzero = only count pairs where both cases
            have inbag <= 0, and for such pairs also bump oobpair
   node:    terminal node label of each case
   inbag:   per-case in-bag indicator; only read when oobprox is nonzero
   oobpair: n x n count of pairs that were OOB together; only touched when
            oobprox is nonzero
   n:       total number of cases
*/
    int r, c;

    for (r = 0; r < n; ++r) {
        for (c = r + 1; c < n; ++c) {
            if (oobprox) {
                /* Skip the pair unless both cases are out of bag. */
                if (inbag[r] > 0 || inbag[c] > 0) continue;
                oobpair[c * n + r]++;
                oobpair[r * n + c]++;
            }
            if (node[r] != node[c]) continue;
            prox[c * n + r] += 1.0;
            prox[r * n + c] += 1.0;
        }
    }
}

double pack(const int nBits, const int *bits) {
    /* Pack nBits 0/1 flags into a double, bits[0] being the least significant:
       the result is sum over i of bits[i] * 2^i, built by Horner's rule from
       the most significant flag down.  The value is exact only while it fits
       a double's 53-bit significand (nBits <= 53).  nBits <= 0 yields 0.0
       without reading bits at all. */
    double packed = 0.0;
    int i;

    for (i = nBits - 1; i >= 0; --i) packed = 2.0 * packed + bits[i];
    return packed;
}

void unpack(const double pack, const int nBits, int *bits) {
    /* Inverse of pack(): write the nBits least significant binary digits of
       the (non-negative, integral) value `pack` into bits[0..nBits-1] as 0/1,
       bits[0] being the least significant.  Exactly nBits entries are written.
       A 64-bit unsigned type is used so digits above bit 31 survive on
       platforms where unsigned long is 32 bits.  Negative or NaN input is
       treated as 0.  Values are assumed to be below 2^64 (pack() produces at
       most 2^53 - 1 for nBits <= 53). */
    unsigned long long v = (pack > 0.0) ? (unsigned long long) pack : 0ULL;
    int i;

    for (i = 0; i < nBits; ++i) {
        bits[i] = (int) (v & 1ULL);
        v >>= 1;
    }
}

/*
unsigned int pack(int nBits, int *bits) {
    int i = nBits;
	unsigned int pack = 0;
    while (--i >= 0) pack += bits[i] << i;
    return(pack);
}

void unpack(int nBits, unsigned int pack, int *bits) {
    int i;
    for (i = 0; i < nBits; pack >>= 1, ++i) bits[i] = pack & 1;
}
*/

/* Fortran-callable entry point: dereference the by-reference arguments and
   forward them to the C unpack(). */
void F77_NAME(unpack)(double *pack, int *nBits, int *bits) {
    const double packedValue = *pack;
    const int width = *nBits;

    unpack(packedValue, width, bits);
}

void sampleDataRows(int populationSize, int sampleSize, int useWeights,
                    int withReplacement, double *weights, int *sampledIndices) {
    /* Fill sampledIndices[0..sampleSize-1] with 0-based row indices drawn from
       0..populationSize-1 by one of four samplers:
         withReplacement != 0 -> rows may repeat; otherwise each row at most once
         useWeights != 0      -> rows drawn according to weights[]; else uniformly
       Sanity checks (e.g. sampleSize <= populationSize when sampling without
       replacement) are expected to have been done by the caller. */
    if (withReplacement && useWeights) {
        sampleWithReplacementWithWeights(sampleSize, populationSize, weights, sampledIndices);
    } else if (withReplacement) {
        sampleWithReplacement(sampleSize, populationSize, sampledIndices);
    } else if (useWeights) {
        sampleWithoutReplacementWithWeights(sampleSize, populationSize, weights, sampledIndices);
    } else {
        sampleWithoutReplacement(sampleSize, populationSize, sampledIndices);
    }
}

void sampleWithReplacementWithWeights(int sampleSize, int populationSize, double *weights, int *sampledIndices) {
    int mflag = 0;
    int numBoundaries = populationSize + 1;
    double *boundaries = (double*)Calloc(numBoundaries, double);
    calculateBoundaries(weights, boundaries, populationSize, numBoundaries);

    double sample;
    for(int i = 0; i < sampleSize; ++i) {
        sample = unif_rand();
        sampledIndices[i] = findInterval(boundaries, numBoundaries, sample, 1, 1,0, &mflag) - 1;
    }
}

void sampleWithoutReplacementWithWeights(int sampleSize, int populationSize, double *weights, int *sampledIndices) {
    /*A snippet to run once when rf is run to verify the output of find_interval, which is what we use to sample with weights*/
    /*if(tree == 0){
        int testflag = 0;
        Rprintf("Testing the find interval function in r\n");
        double testBoundaries[] = {0, .2, .6, .6, .6, .6, .8, .8, .8, 1};
        int testInterval = findInterval(testBoundaries, 11, .7, 1, 1,0, &testflag);
        Rprintf("The resultant test interval found was %d\n", testInterval);
    }*/
    int mflag = 0;
    int numBoundaries = populationSize + 1;
    double *weightsForTree = (double*)Calloc(populationSize, double);

    /*copy over weights to array so we don't mess with original weights when removing them*/
    for(int i = 0; i < populationSize; i++){
        weightsForTree[i] = weights[i];
    }
    double *boundaries = (double*)Calloc(numBoundaries, double);
    int *populationTaken = (int*)Calloc(populationSize, int);
    zeroInt(populationTaken, populationSize);
    calculateBoundaries(weightsForTree, boundaries, populationSize, numBoundaries);
    int first;
    int sampledIndex;
    int sampledIndexCopy;
    double sample;

    for(int i = 0; i < sampleSize; ++i) {
        sample = unif_rand();
        sampledIndex = findInterval(boundaries, numBoundaries, sample, 1, 1,0, &mflag) - 1;
        first = sampledIndex;
        sampledIndexCopy = sampledIndex;
        sampledIndices[i] = sampledIndex;
        populationTaken[sampledIndex] = -1;
        removeWeightAndNormalize(weightsForTree, sampledIndex, populationSize);
        calculateBoundaries(weightsForTree, boundaries, populationSize, numBoundaries);
    }
}

void sampleWithReplacement(int sampleSize, int populationSize, int *sampledIndices) {
    /* Draw sampleSize indices independently and uniformly: each is
       unif_rand() * populationSize truncated to int, i.e. in
       0..populationSize-1 assuming unif_rand() returns values in [0, 1). */
    int *out = sampledIndices;
    int left = sampleSize;

    while (left > 0) {
        *out++ = (int) (unif_rand() * populationSize);
        --left;
    }
}

void sampleWithoutReplacement(int sampleSize, int populationSize, int *sampledIndices) {
    /* Draw sampleSize distinct indices uniformly from 0..populationSize-1
       using a partial Fisher-Yates shuffle over a scratch pool.  Requires
       sampleSize <= populationSize (checked by the caller).

       After a pick, the last live pool entry is moved into the picked slot and
       the live region shrinks by one.  This is a plain move rather than a
       swap: the vacated tail slot is never read again.  (An XOR swap would
       zero the slot whenever both operands are the same element.) */
    int *pool = (int *) Calloc(populationSize, int);
    int last = populationSize - 1;

    for (int i = 0; i < populationSize; ++i) pool[i] = i;

    for (int i = 0; i < sampleSize; ++i, --last) {
        int pick = (int) (unif_rand() * (last + 1));
        sampledIndices[i] = pool[pick];
        pool[pick] = pool[last];
    }

    Free(pool);
}

void removeWeightAndNormalize(double *weights, int indexToRemove, int populationSize) {
    /* Set weights[indexToRemove] to 0 and rescale the remaining weights so
       they sum to one.  The divisor is the actual sum of what remains, so the
       input does not need to be normalized already.  If nothing positive
       remains (sum == 0), the weights are left as they are after the removal
       instead of being divided by zero. */
    double remainingSum = 0.0;

    weights[indexToRemove] = 0.0;
    for (int i = 0; i < populationSize; i++) remainingSum += weights[i];
    if (remainingSum == 0.0) return;

    for (int i = 0; i < populationSize; i++) weights[i] /= remainingSum;
}

void normalizeWeights(double *weights, int numWeights){
    /* Rescale weights[0..numWeights-1] in place so they sum to one.
       If the weights sum to exactly zero (including numWeights <= 0) they are
       left unchanged, instead of every entry becoming NaN from 0/0. */
    double weightSum = 0.0;

    for (int i = 0; i < numWeights; i++) weightSum += weights[i];
    if (weightSum == 0.0) return;

    for (int i = 0; i < numWeights; i++) weights[i] /= weightSum;
}

void calculateBoundaries(double *weights, double *boundaries, int populationSize, int numBoundaries) {
    /* Write the cumulative sums of weights into boundaries:
       boundaries[0] = 0 and boundaries[k] = weights[0] + ... + weights[k-1]
       for k = 1..populationSize.  Any slots from populationSize + 1 up to
       numBoundaries - 1 are cleared to 0.0.  Callers pass
       numBoundaries = populationSize + 1. */
    double running = 0.0;
    int k;

    boundaries[0] = running;
    for (k = 0; k < populationSize; ++k) {
        running += weights[k];
        boundaries[k + 1] = running;
    }
    for (k = populationSize + 1; k < numBoundaries; ++k) boundaries[k] = 0.0;
}



