/*
 *
 * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. All advertising materials mentioning features or use of this software
 *    must display the following acknowledgement:
 *    This product includes software developed by the Delft University of Technology.
 * 4. Neither the name of the Delft University of Technology nor the names of 
 *    its contributors may be used to endorse or promote products derived from 
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 
 * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 
 * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 
 * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY 
 * OF SUCH DAMAGE.
 *
 */


/* This code was adapted, with minor modifications, from Steve Hanov's great tutorial at http://stevehanov.ca/blog/index.php?id=130 */

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <new>
#include <queue>
#include <vector>


#ifndef VPTREE_H
#define VPTREE_H

/**
 * A point with an integer index and a privately owned copy of its coordinates.
 *
 * The coordinate buffer `_x` is allocated with malloc() and released with free().
 * A point without coordinates (for example a default-constructed one) has
 * `_x == NULL`; copying or assigning such a point yields another point without
 * coordinates instead of reading through the null pointer.
 */
class DataPoint {
    int _ind;       // index supplied by the creator of the point (-1 by default)

    // Returns a malloc'ed copy of the first D values of src, or NULL when there is
    // nothing to copy (src is NULL or D is not positive).
    // Throws std::bad_alloc when the allocation fails.
    static double *duplicate(const double *src, int D) {
        if (src == NULL || D <= 0) return NULL;
        double *dst = (double *) malloc((size_t) D * sizeof(double));
        if (dst == NULL) throw std::bad_alloc();
        for (int d = 0; d < D; d++) dst[d] = src[d];
        return dst;
    }

public:
    double *_x;     // owned coordinate buffer holding _D values, or NULL
    int _D;         // number of coordinates

    // Creates a point without coordinates (index -1, dimensionality 1, no buffer)
    DataPoint() {
        _D = 1;
        _ind = -1;
        _x = NULL;
    }

    // Creates a point with index `ind` that owns a copy of the first D values of x.
    // The caller keeps ownership of x. If x is NULL or D is not positive, the point
    // has no coordinate buffer.
    DataPoint(int D, int ind, double *x) {
        _D = D;
        _ind = ind;
        _x = duplicate(x, D);
    }

    // Copy constructor: makes a deep copy -- there is no old buffer to free
    DataPoint(const DataPoint &other) {
        _D = other.dimensionality();
        _ind = other.index();
        _x = duplicate(other._x, other._D);
    }

    ~DataPoint() { free(_x); }      // free(NULL) is a no-op

    // Assignment makes a deep copy and frees the old buffer. The new buffer is
    // allocated first, so *this is left untouched if the allocation fails.
    DataPoint &operator=(const DataPoint &other) {
        if (this != &other) {
            double *fresh = duplicate(other._x, other._D);
            free(_x);
            _x = fresh;
            _D = other.dimensionality();
            _ind = other.index();
        }
        return *this;
    }

    // Index given at construction time
    int index() const { return _ind; }

    // Number of coordinates
    int dimensionality() const { return _D; }

    // Value of coordinate d; the point must own a buffer and d must lie in [0, _D)
    double x(int d) const { return _x[d]; }
};

/**
 * Euclidean (L2) distance between two points.
 *
 * Sums the squared coordinate differences over the first t1._D coordinates and
 * returns the square root of that sum. Only t1's dimensionality is consulted, so
 * t2 must provide at least t1._D coordinates, and neither point's _x may be NULL
 * when t1._D is positive.
 *
 * Declared inline because the definition sits in a header: without it, including
 * this header from more than one translation unit produces duplicate definitions
 * at link time.
 */
inline double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
    const double *x1 = t1._x;
    const double *x2 = t2._x;
    double dd = .0;
    for (int d = 0; d < t1._D; d++) {
        const double diff = x1[d] - x2[d];
        dd += diff * diff;
    }
    return std::sqrt(dd);
}


template<typename T, double (*distance)(const T &, const T &)>
class VpTree {
public:

    // Default constructor
    VpTree() : _root(0) {}

    // Destructor
    ~VpTree() {
        delete _root;
    }

    // Function to create a new VpTree from data
    void create(const std::vector <T> &items) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }

    // Function that uses the tree to find the k nearest neighbors of target
    void search(const T &target, int k, std::vector <T> *results, std::vector<double> *distances) {

        // Use a priority queue to store intermediate results on
        std::priority_queue <HeapItem> heap;

        // Variable that tracks the distance to the farthest point in our results

        // Perform the search
        double _tau = DBL_MAX;
        search(_root, target, k, heap, _tau);

        // Gather final results
        results->clear();
        distances->clear();
        while (!heap.empty()) {
            results->push_back(_items[heap.top().index]);
            distances->push_back(heap.top().dist);
            heap.pop();
        }

        // Results are in reverse order
        std::reverse(results->begin(), results->end());
        std::reverse(distances->begin(), distances->end());
    }

private:
    std::vector <T> _items;

    // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
    struct Node {
        int index;              // index of point in node
        double threshold;       // radius(?)
        Node *left;             // points closer by than threshold
        Node *right;            // points farther away than threshold

        Node() :
                index(0), threshold(0.), left(0), right(0) {}

        ~Node() {               // destructor
            delete left;
            delete right;
        }
    } *_root;


    // An item on the intermediate result queue
    struct HeapItem {
        HeapItem(int index, double dist) :
                index(index), dist(dist) {}

        int index;
        double dist;

        bool operator<(const HeapItem &o) const {
            return dist < o.dist;
        }
    };

    // Distance comparator for use in std::nth_element
    struct DistanceComparator {
        const T &item;

        DistanceComparator(const T &item) : item(item) {}

        bool operator()(const T &a, const T &b) {
            return distance(item, a) < distance(item, b);
        }
    };

    // Function that (recursively) fills the tree
    Node *buildFromPoints(int lower, int upper) {
        if (upper == lower) {     // indicates that we're done here!
            return NULL;
        }

        // Lower index is center of current node
        Node *node = new Node();
        node->index = lower;

        if (upper - lower > 1) {      // if we did not arrive at leaf yet

            // Choose an arbitrary point and move it to the start
            int i = (int) ((double) rand() / RAND_MAX * (upper - lower - 1)) + lower;
            std::swap(_items[lower], _items[i]);

            // Partition around the median distance
            int median = (upper + lower) / 2;
            std::nth_element(_items.begin() + lower + 1,
                             _items.begin() + median,
                             _items.begin() + upper,
                             DistanceComparator(_items[lower]));

            // Threshold of the new node will be the distance to the median
            node->threshold = distance(_items[lower], _items[median]);

            // Recursively build tree
            node->index = lower;
            node->left = buildFromPoints(lower + 1, median);
            node->right = buildFromPoints(median, upper);
        }

        // Return result
        return node;
    }

    // Helper function that searches the tree    
    void search(Node *node, const T &target, int k, std::priority_queue <HeapItem> &heap, double &_tau) {
        if (node == NULL) return;     // indicates that we're done here

        // Compute distance between target and current node
        double dist = distance(_items[node->index], target);

        // If current node within radius tau
        if (dist < _tau) {
            if (heap.size() == k)
                heap.pop();                 // remove furthest node from result list (if we already have k results)
            heap.push(HeapItem(node->index, dist));           // add current node to result list
            if (heap.size() == k) _tau = heap.top().dist;     // update value of tau (farthest point in result list)
        }

        // Return if we arrived at a leaf
        if (node->left == NULL && node->right == NULL) {
            return;
        }

        // If the target lies within the radius of ball
        if (dist < node->threshold) {
            if (dist - _tau <=
                node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child first
                search(node->left, target, k, heap, _tau);
            }

            if (dist + _tau >=
                node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child
                search(node->right, target, k, heap, _tau);
            }

            // If the target lies outsize the radius of the ball
        } else {
            if (dist + _tau >=
                node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child first
                search(node->right, target, k, heap, _tau);
            }

            if (dist - _tau <=
                node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child
                search(node->left, target, k, heap, _tau);
            }
        }
    }
};

#endif
