#ifndef __INDEXER_H__
#define __INDEXER_H__

#include <tapa.h>
#include <ap_int.h>
#include <ap_fixed.h>
#include <hls_vector.h>
#include <hls_math.h>
#include <hls_stream.h>
#include <vector>
#include <iostream>
#include <cmath>
#include <cstdint>
#include <limits>

// Number of index heads: first dimension of the q_vec buffers and trip count
// of the per-head loops.
constexpr int NUM_INDEX_HEAD{16};
// Base step count per head; the kernels scale it by the squared head weight.
constexpr int NUM_STEP_BASE{16};
// Number of elements per query/key vector (packs streamed per vector set).
constexpr int HEAD_DIM{128};
// Number of candidate ids kept and emitted by the reranker.
constexpr int TOP_K{16};
// Not referenced by the code in this part of the header.
constexpr int MAX_BUF_CAND{512};

/**
 * Greedy per-head graph walk that emits one candidate per step.
 *
 * Start-up: reads HEAD_DIM packs from q_vec_fifo (lane j of pack i is stored
 * as q_vec[j][i]) and one pack from weight_fifo. Every pack read is forwarded
 * unchanged on q_vec_out_fifo / weight_out_fifo.
 *
 * For each head h the walk runs round(NUM_STEP_BASE * w_h * w_h) steps, where
 * w_h is lane h of the weight pack. A step:
 *   1. reads one pack of 16 ids from edge_list,
 *   2. reads HEAD_DIM packs from k_cache_fifo (each forwarded on
 *      k_cache_out_fifo) and accumulates, per lane, the inner product with
 *      q_vec[h],
 *   3. picks the lane with the largest inner product (lowest lane wins ties),
 *   4. writes edges[lane] to cand_id_fifo and the lane to cand_pos_fifo,
 *   5. requests edges[lane] on node_req_fifo unless this was the last step.
 *
 * Request/response balance: exactly one request is written to node_req_fifo
 * for every edge_list / k_cache response that is consumed. Issuing a request
 * after the final step (or a seed request for a head with zero steps) would
 * leave an unread response in the FIFOs that the next head would then consume
 * in place of the response for its own seed node.
 * NOTE(review): this assumes the producer answers each request with one
 * edge_list pack and HEAD_DIM k_cache packs and does not count on a trailing
 * request per head -- confirm against the producer task.
 *
 * @param node_req_fifo    out: node ids whose neighbourhood is requested
 * @param edge_list        in:  16 ids per requested node
 * @param k_cache_fifo     in:  HEAD_DIM packs per requested node
 * @param k_cache_out_fifo out: copy of every k_cache pack consumed
 * @param q_vec_fifo       in:  HEAD_DIM packs, one lane per head
 * @param q_vec_out_fifo   out: copy of every q_vec pack
 * @param weight_fifo      in:  one pack, one lane per head
 * @param weight_out_fifo  out: copy of the weight pack
 * @param cand_id_fifo     out: id selected at each step
 * @param cand_pos_fifo    out: lane (0..15) of the selected id at each step
 */
void graph_traversal(
    // reading graph connection
    tapa::ostream<int>& node_req_fifo,
    tapa::istream<tapa::vec_t<int, 16>>& edge_list,
    // reading k cache
    tapa::istream<tapa::vec_t<float, 16>>& k_cache_fifo,
    tapa::ostream<tapa::vec_t<float, 16>>& k_cache_out_fifo,
    tapa::istream<tapa::vec_t<float, 16>>& q_vec_fifo,
    tapa::ostream<tapa::vec_t<float, 16>>& q_vec_out_fifo,
    tapa::istream<tapa::vec_t<float, 16>>& weight_fifo,
    tapa::ostream<tapa::vec_t<float, 16>>& weight_out_fifo,
    tapa::ostream<int>& cand_id_fifo,
    tapa::ostream<int>& cand_pos_fifo
) {
    // Local copy of the query, indexed [head][element].
    float q_vec[NUM_INDEX_HEAD][HEAD_DIM];
    #pragma HLS array_partition variable=q_vec complete dim=1
    for (int i = 0; i < HEAD_DIM; i++) {
        #pragma HLS pipeline II=1
        tapa::vec_t<float, 16> q_vec_pack = q_vec_fifo.read();
        q_vec_out_fifo.write(q_vec_pack);
        for (int j = 0; j < 16; j++) {
            #pragma HLS unroll
            q_vec[j][i] = q_vec_pack[j];
        }
    }
    auto weight_pack = weight_fifo.read();
    weight_out_fifo.write(weight_pack);

    for (int h = 0; h < NUM_INDEX_HEAD; h++) {
        const float weight = weight_pack[h];
        const int num_steps =
            static_cast<int>(std::round(NUM_STEP_BASE * weight * weight));

        // start with index 0, greedy search; only seed the walk when at least
        // one step will consume the response
        if (num_steps > 0) {
            node_req_fifo.write(0);
        }

        for (int step = 0; step < num_steps; step++) {
            tapa::vec_t<int, 16> edges = edge_list.read();
            float ip_results[16];
            #pragma HLS array_partition variable=ip_results complete
            for (int i = 0; i < 16; i++) {
                #pragma HLS unroll
                ip_results[i] = 0.0f;
            }
            // One inner product per lane against this head's query.
            for (int i = 0; i < HEAD_DIM; i++) {
                #pragma HLS pipeline II=1
                tapa::vec_t<float, 16> k_cache_pack = k_cache_fifo.read();
                k_cache_out_fifo.write(k_cache_pack);
                float q_val = q_vec[h][i];
                for (int j = 0; j < 16; j++) {
                    #pragma HLS unroll
                    ip_results[j] += q_val * k_cache_pack[j];
                }
            }
            // Argmax over the lanes; strict '>' keeps the lowest lane on ties.
            int max_id = 0;
            float max_val = ip_results[0];
            for (int i = 1; i < 16; i++) {
                #pragma HLS pipeline II=1
                if (ip_results[i] > max_val) {
                    max_val = ip_results[i];
                    max_id = i;
                }
            }
            const int next_node = edges[max_id];
            cand_id_fifo.write(next_node);
            cand_pos_fifo.write(max_id);
            // Do not request a neighbourhood that no step will read.
            if (step + 1 < num_steps) {
                node_req_fifo.write(next_node);
            }
        }
    }
}

/**
 * Scores every candidate produced by the traversal and keeps the TOP_K best.
 *
 * Start-up: reads HEAD_DIM packs from q_vec_fifo (lane j of pack i is stored
 * as q_vec[j][i]) and one pack from weight_fifo.
 *
 * For each head h and each of its round(NUM_STEP_BASE * w_h * w_h) steps:
 *   1. buffers HEAD_DIM packs from k_cache_fifo as k_cache[lane][element],
 *   2. reads one id from cand_id_fifo and one lane from cand_pos_fifo,
 *   3. computes, for every head j, the inner product of q_vec[j] with
 *      k_cache[cand_pos], and sums those weighted by weight_pack[j],
 *   4. if the score beats the current minimum of the top-K set, overwrites
 *      that minimum slot and rescans for the new minimum.
 *
 * The top-K slots start at the lowest finite float so that any real score,
 * including negative ones, is accepted while empty slots remain. Slots that
 * are never filled keep id -1.
 *
 * After all heads, the TOP_K ids are written as one pack on topk_id_fifo (in
 * slot order, not sorted by score). The same id may appear more than once if
 * it is streamed in more than once.
 *
 * @param cand_id_fifo  in:  candidate id per step
 * @param cand_pos_fifo in:  lane of the candidate inside the k_cache packs;
 *                           used to index k_cache, so it must be in 0..15
 * @param q_vec_fifo    in:  HEAD_DIM packs, one lane per head
 * @param k_cache_fifo  in:  HEAD_DIM packs per step
 * @param weight_fifo   in:  one pack, one lane per head
 * @param topk_id_fifo  out: single pack holding the TOP_K selected ids
 */
void reranker(
    tapa::istream<int>& cand_id_fifo,
    tapa::istream<int>& cand_pos_fifo,
    tapa::istream<tapa::vec_t<float, 16>>& q_vec_fifo,
    tapa::istream<tapa::vec_t<float, 16>>& k_cache_fifo,
    tapa::istream<tapa::vec_t<float, 16>>& weight_fifo,
    tapa::ostream<tapa::vec_t<int, 16>>& topk_id_fifo
) {
    // Score assigned to an empty slot: below every score that can be inserted.
    const float kEmptyScore = std::numeric_limits<float>::lowest();

    float q_vec[NUM_INDEX_HEAD][HEAD_DIM];
    #pragma HLS array_partition variable=q_vec complete dim=1
    float topk_val[TOP_K];
    int topk_id[TOP_K];
    #pragma HLS array_partition variable=topk_val complete
    #pragma HLS array_partition variable=topk_id complete
    for (int i = 0; i < TOP_K; i++) {
        #pragma HLS unroll
        topk_val[i] = kEmptyScore;
        topk_id[i] = -1;
    }
    // Slot currently holding the smallest score, and that score.
    int min_id = 0;
    float min_val = kEmptyScore;

    for (int i = 0; i < HEAD_DIM; i++) {
        #pragma HLS pipeline II=1
        tapa::vec_t<float, 16> q_vec_pack = q_vec_fifo.read();
        for (int j = 0; j < 16; j++) {
            #pragma HLS unroll
            q_vec[j][i] = q_vec_pack[j];
        }
    }
    auto weight_pack = weight_fifo.read();
    for (int h = 0; h < NUM_INDEX_HEAD; h++) {
        const float weight = weight_pack[h];
        const int num_steps =
            static_cast<int>(std::round(NUM_STEP_BASE * weight * weight));
        for (int step = 0; step < num_steps; step++) {
            // Buffer the whole key set of this step, indexed [lane][element].
            float k_cache[16][HEAD_DIM];
            #pragma HLS array_partition variable=k_cache complete dim=1
            for (int i = 0; i < HEAD_DIM; i++) {
                #pragma HLS pipeline II=1
                tapa::vec_t<float, 16> k_cache_pack = k_cache_fifo.read();
                for (int j = 0; j < 16; j++) {
                    #pragma HLS unroll
                    k_cache[j][i] = k_cache_pack[j];
                }
            }
            int cand_id = cand_id_fifo.read();
            int cand_pos = cand_pos_fifo.read();

            // Inner product of the selected key with every head's query.
            float ip_result[16];
            #pragma HLS array_partition variable=ip_result complete
            for (int i = 0; i < 16; i++) {
                #pragma HLS unroll
                ip_result[i] = 0.0f;
            }
            for (int i = 0; i < HEAD_DIM; i++) {
                #pragma HLS pipeline II=1
                for (int j = 0; j < 16; j++) {
                    #pragma HLS unroll
                    ip_result[j] += q_vec[j][i] * k_cache[cand_pos][i];
                }
            }
            // Weighted sum over heads gives the candidate's score.
            float ind_score = 0.0f;
            for (int i = 0; i < 16; i++) {
                #pragma HLS pipeline II=1
                ind_score += weight_pack[i] * ip_result[i];
            }

            if (ind_score > min_val) {
                // Replace the current minimum, then rescan for the new one.
                topk_val[min_id] = ind_score;
                topk_id[min_id] = cand_id;
                min_val = topk_val[0];
                min_id = 0;
                for (int i = 1; i < TOP_K; i++) {
                    #pragma HLS pipeline II=1
                    if (topk_val[i] < min_val) {
                        min_val = topk_val[i];
                        min_id = i;
                    }
                }
            }
        }
    }

    // pack and send
    tapa::vec_t<int, 16> topk_id_pack;
    for (int i = 0; i < TOP_K; i++) {
        #pragma HLS unroll
        topk_id_pack[i] = topk_id[i];
    }
    topk_id_fifo.write(topk_id_pack);
}

#endif // __INDEXER_H__