#include <iostream>
#include <cassert>
#include <vector>
#include <numeric> 
#include <algorithm>
#include <queue>

using namespace std;

// One node of the Huffman tree. Leaves and internal nodes share this layout.
struct Node {
    uint freq;            // leaf: frequency of the symbol; internal node: sum of both children's freq
    int token_id;         // leaf: original symbol id; internal node: its slot number in the node array
    Node* left;           // child reached by bit 0 (nullptr for a leaf)
    Node* right;          // child reached by bit 1 (nullptr for a leaf)
    Node* parent;         // nullptr until the node is merged under a parent (stays nullptr for the root)
    int current_height;   // 1 for a leaf, otherwise 1 + the taller child's height
    // Branch label on the edge from the parent: 0 or 1, or -1 while the node
    // has no parent. Explicitly `signed char` because the signedness of plain
    // `char` is implementation-defined, and both the -1 sentinel and the
    // `bit >= 0` checks in this file need a signed type.
    signed char bit;
};


// Returns the positions 0..size-1 of `v` ordered by ascending value.
// Positions holding equal values are ordered by ascending position.
std::vector<int> sort_indexes(const uint* v, const int size) {
  std::vector<int> order(size);
  std::iota(order.begin(), order.end(), 0);

  // Value first, position as the tie-breaker.
  const auto precedes = [v](int lhs, int rhs) {
    return v[lhs] != v[rhs] ? v[lhs] < v[rhs] : lhs < rhs;
  };
  std::stable_sort(order.begin(), order.end(), precedes);

  return order;
}

// Strict ordering between the candidate pair (a, b) and the candidate pair (c, d).
// Returns true when (a, b) should be merged before (c, d). Keys, in order:
//   1. smaller combined frequency
//   2. smaller minimum frequency within the pair
//   3. smaller minimum token id within the pair
//   4. smaller maximum token id within the pair
bool compare_elements(Node* a, Node* b, Node* c, Node* d){
    const uint lhs_total = a->freq + b->freq;
    const uint rhs_total = c->freq + d->freq;
    if (lhs_total != rhs_total){
        return lhs_total < rhs_total;
    }

    const uint lhs_min_freq = std::min(a->freq, b->freq);
    const uint rhs_min_freq = std::min(c->freq, d->freq);
    if (lhs_min_freq != rhs_min_freq){
        return lhs_min_freq < rhs_min_freq;
    }

    const int lhs_low_id = std::min(a->token_id, b->token_id);
    const int rhs_low_id = std::min(c->token_id, d->token_id);
    if (lhs_low_id != rhs_low_id){
        return lhs_low_id < rhs_low_id;
    }

    const int lhs_high_id = std::max(a->token_id, b->token_id);
    const int rhs_high_id = std::max(c->token_id, d->token_id);
    return lhs_high_id < rhs_high_id;
}


// Faster implementation based on https://webspace.science.uu.nl/~leeuw112/huffman.pdf
// Van Leeuwen, Jan. "On the Construction of Huffman Trees." ICALP. 1976.
class HuffmanCoder{
    private:
        vector<Node> index;
        vector<int> indices_mapping;
        int max_height;
        Node* root;
    
    public:
        HuffmanCoder(uint* freqs, int num_freqs){
            assert(num_freqs > 1);

            // keep track of the indices of the sorted freqs
            vector<int> sorted_indexes = sort_indexes(freqs, num_freqs);
            this->indices_mapping = vector<int>(num_freqs);
            for (int i = 0; i < num_freqs; i++){
                this->indices_mapping[sorted_indexes[i]] = i;
            }

            // rearrange freqs based on the sorted indexes
            vector<uint> rearranged_freqs = vector<uint>(num_freqs);
            for (int i = 0; i < num_freqs; i++){
                rearranged_freqs[i] = freqs[sorted_indexes[i]];
            }

            this->index = vector<Node>(2 * num_freqs - 1);
            for(int i = 0; i < num_freqs; i++){
                this->index[i] = {rearranged_freqs[i], sorted_indexes[i], nullptr, nullptr, nullptr, 1, -1};
            }
            this->max_height = 1;

            vector<Node*> queue;
            int queue_cnt = 0;
            int index_cnt = 0;
            int index_insert_cnt = num_freqs;

            while (num_freqs - index_cnt > 0 || queue.size() - queue_cnt != 1){
                Node* a = nullptr;
                Node* b = nullptr;
                int cases = -1;

                if (num_freqs - index_cnt >= 2){
                    a = &(this->index[index_cnt]);
                    b = &(this->index[index_cnt + 1]);
                    cases = 0;
                }

                if (num_freqs - index_cnt >= 1 && queue.size() - queue_cnt >= 1){
                    if (a == nullptr || compare_elements(queue[queue_cnt], &(this->index[index_cnt]), a, b)){
                        a = &(this->index[index_cnt]);
                        b = queue[queue_cnt];
                        cases = 1;
                    }
                }
                    

                if (queue.size() - queue_cnt >= 2){
                    if (a == nullptr || compare_elements(queue[queue_cnt], queue[queue_cnt + 1], a, b)){
                        a = queue[queue_cnt];
                        b = queue[queue_cnt + 1];
                        cases = 2;
                    }
                }
                assert(cases != -1);

                if (a->freq > b->freq || (a->freq == b->freq && a->token_id > b->token_id)){
                    Node* temp = a;
                    a = b;
                    b = temp;
                }

                this->index[index_insert_cnt] = {a->freq + b->freq, index_insert_cnt, a, b, nullptr, max(a->current_height, b->current_height) + 1, -1};
                Node* node = &(this->index[index_insert_cnt]);
                ++index_insert_cnt;
                a->parent = node;
                b->parent = node;
                a->bit = 0;
                b->bit = 1;

                if (max(a->current_height, b->current_height) > this->max_height){
                    this->max_height = max(a->current_height, b->current_height);
                }

                queue.push_back(node);

                if (cases == 0){
                    index_cnt += 2;
                }else if (cases == 1){
                    index_cnt += 1;
                    queue_cnt += 1;
                }else if (cases == 2){
                    queue_cnt += 2;
                }
            }

            this->root = queue[queue_cnt];
            // std::cout << "prining" << std::endl;
            // print_tree(this->root, 0);
            // std::cout << "end prining" << std::endl;
        }

        void print_tree(Node* node, int height){
            if (node == nullptr){
                return;
            }
            // pring spaces based on height
            std::cout << std::string(height, ' ');
            std::cout << node->freq << " " << node->token_id << " " << node->current_height << " " << node->bit << std::endl;
            print_tree(node->left, height + 1);
            print_tree(node->right, height + 1);
        }

        int encode_symbol(int symbol, bool* bits){
            struct Node* node = &(this->index[this->indices_mapping[symbol]]);
            assert(node != nullptr);

            vector<int> reversed_results;
            
            while(node != nullptr && node->bit >= 0){
                // bits[offset] = node->bit;
                reversed_results.push_back(node->bit);
                node = node->parent;
            }

            for(int i = reversed_results.size() - 1; i >= 0; i--){
                bits[reversed_results.size() - 1 - i] = reversed_results[i];
            }

            return reversed_results.size();
        }

        int decode_symbol(bool* bits, int length_of_bits, int* offset){
            struct Node* node = this->root;
            
            while(node != nullptr && (*offset) < length_of_bits){
                if(bits[*offset]){
                    node = node->right;
                }else{
                    node = node->left;
                }
                ++(*offset);
                if (node->current_height == 1){
                    break;
                }
            }
            if (node->current_height == 1){
                return node->token_id;
            }
            return -1;
        }

        int get_max_height(){
            return this->max_height;
        }
    
        void get_code_lengths_impl(uint* output, Node* node, int depth){
            if (node == nullptr){
                return;
            }
            if (node->current_height == 1){
                output[node->token_id] = depth;
                return;
            }
            get_code_lengths_impl(output, node->left, depth + 1);
            get_code_lengths_impl(output, node->right, depth + 1);
        }
    
        void get_code_lengths(uint* output){
            get_code_lengths_impl(output, this->root, 0);
        }

        void delete_node(Node* node){
            if(node == nullptr){
                return;
            }
            delete_node(node->parent);
            if (node->left != nullptr){
                node->left->parent = nullptr;
            }
            if (node->right != nullptr){
                node->right->parent = nullptr;
            }
            delete node;
        }

        ~HuffmanCoder(){
            // delete each node
            // for(int i = 0; i < this->index.size(); i++){
            //     delete_node(this->index[i]);
            // }

            this->indices_mapping.clear();
            this->index.clear();
        }
};

// C-linkage entry points: each one forwards to the HuffmanCoder member of the
// same name so the coder can be driven through a plain C ABI.
extern "C" {
    // Heap-allocates a coder; release it with HuffmanCoder_delete().
    HuffmanCoder* HuffmanCoder_new(uint* freqs, int num_freqs){
        return new HuffmanCoder(freqs, num_freqs);
    }

    // Destroys a coder obtained from HuffmanCoder_new().
    void HuffmanCoder_delete(HuffmanCoder* huffman_coder){
        delete huffman_coder;
    }

    // Forwards to HuffmanCoder::encode_symbol; returns the number of bits written.
    int HuffmanCoder_encode_symbol(HuffmanCoder* huffman_coder, int symbol, bool* bits){
        return huffman_coder->encode_symbol(symbol, bits);
    }

    // Forwards to HuffmanCoder::decode_symbol; returns the decoded symbol or -1.
    int HuffmanCoder_decode_symbol(HuffmanCoder* huffman_coder, bool* bits, int length_of_bits, int* offset){
        return huffman_coder->decode_symbol(bits, length_of_bits, offset);
    }

    // Forwards to HuffmanCoder::get_code_lengths, which fills `output` per symbol.
    void HuffmanCoder_get_code_lengths(HuffmanCoder* huffman_coder, uint* output){
        huffman_coder->get_code_lengths(output);
    }

    // Forwards to HuffmanCoder::get_max_height.
    int HuffmanCoder_get_max_height(HuffmanCoder* huffman_coder){
        return huffman_coder->get_max_height();
    }
}