#include <algorithm>
#include <limits>
#include <stack>
#include <stdexcept>
#include <string>
#include <vector>
#include "attention.h"
#include "spin_lock.h"
#include "logging.h"

namespace GPT {

Attention::Attention(int n_head,
                     int d_embed,
                     int chunk_size,
                     int memory_mb,
                     bool share_prefix,
                     torch::Dtype dtype,
                     torch::Device device)
  : n_head_(n_head)
  , d_embed_(d_embed)
  , chunk_size_(chunk_size)
  , share_prefix_(share_prefix)
  , t_options_(torch::TensorOptions()
                 .dtype(dtype)
                 .layout(torch::kStrided)
                 .device(device)
                 .requires_grad(false))
  , chunk_allocator_(memory_mb, chunk_size, n_head, d_embed, t_options_)
#ifdef USE_CUDA
//    , gpu_kernel_(n_head, chunk_size)
#endif
{
    // The root is a sentinel chunk that starts with zero sequences; get_chunks()
    // never turns it into an attention task.
    root_ = std::make_shared<Chunk>(n_head, 0, d_embed, t_options_);
    root_->n_seqs = 0;

    const torch::Device target = t_options_.device();

    // GPU path: only valid when the library was built with CUDA support.
    if (target.is_cuda()) {
#ifdef USE_CUDA
        LOG_INFO("ChunkAttn: Runs on cuda:{}", target.index());
        return;
#else
        throw std::runtime_error("CUDA is not enabled");
#endif
    }

    // Anything that is neither CUDA nor CPU is rejected outright.
    if (!target.is_cpu()) {
        throw std::runtime_error("Unsupported device");
    }

    // CPU path: size the OpenMP team to torch's configured thread count.
    // Requires an MKL-enabled build; otherwise construction fails.
    const int cpu_num_threads = torch::get_num_threads();
    LOG_INFO("ChunkAttn: Runs on CPU. Set omp num of threads to {}", cpu_num_threads);
#ifdef USE_MKL
    omp_set_dynamic(0); // Explicitly disable dynamic teams
    omp_set_num_threads(cpu_num_threads);
#else
    throw std::runtime_error("MKL is not enabled");
#endif
}

// q: [n_head, n_seqs, d_embed]
// Runs chunked attention of the query against every cached chunk.
//
// q: [n_head, n_seqs, d_embed], contiguous, on the same device type as the
//    KV cache this object was constructed for.
// partition / trace: forwarded unchanged to the device-specific kernel.
// Returns a zero-initialised tensor of shape [n_head, q.size(1), d_embed]
// that the kernel fills in.
// Throws std::runtime_error if q is non-contiguous, has the wrong rank, head
// count or embedding size, or lives on a different device type.
torch::Tensor Attention::forward(torch::Tensor q, int partition, Trace* trace) {
    // The kernels manipulate tensor data through raw data pointers, so q must
    // be contiguous. A transposed view, for example, is a non-contiguous view
    // of the original storage.
    // https://www.tutorialspoint.com/pytorch-how-to-check-if-a-tensor-is-contiguous-or-not
    if (!q.is_contiguous()) {
        throw std::runtime_error("q is not contiguous");
    }

    // The output is sized from n_head_ / d_embed_, and the kernels presumably
    // index q with the same extents. A mismatched q would be read past its
    // buffer, so reject it up front.
    if (q.dim() != 3 || q.size(0) != n_head_ || q.size(2) != d_embed_) {
        throw std::runtime_error("q must have shape [n_head=" + std::to_string(n_head_) +
                                 ", n_seqs, d_embed=" + std::to_string(d_embed_) +
                                 "], got a " + std::to_string(q.dim()) + "-d tensor");
    }

    // Only the device type is compared: a cache configured as plain "cuda"
    // (no index) must still accept q on "cuda:N".
    if (q.device().type() != t_options_.device().type()) {
        throw std::runtime_error("q is on a different device type than the KV cache");
    }

    std::vector<Task> tasks = get_chunks();
    torch::Tensor output = torch::zeros({ n_head_, q.size(1), d_embed_ }, t_options_);

#ifdef USE_MKL
    if (t_options_.device().is_cpu()) {
        attn_cpu_kernel(tasks, q, output, n_head_, chunk_size_, partition, trace);
    }
#endif
#ifdef USE_CUDA
    if (t_options_.device().is_cuda()) {
        gpu_kernel_.attention(q, tasks, output, t_options_, partition, trace);
    }
#endif

    return output;
}

// Returns the tail (leaf) chunk of sequence `seq_idx`.
//
// Sequences are numbered left to right: each child of a chunk owns the next
// `child->n_seqs` consecutive indices. The lookup descends from the root,
// picking at each level the child whose range contains seq_idx.
// Throws std::out_of_range when seq_idx is negative, the tree holds no
// sequences, or no child range contains seq_idx. Previously an unmatched
// index spun forever in the descent loop.
Chunk* Attention::at(int seq_idx) {
    Chunk* chunk = root_.get();
    if (seq_idx < 0 || chunk->children.empty()) {
        throw std::out_of_range("seq_idx " + std::to_string(seq_idx) + " is out of range");
    }

    int start = 0;  // first sequence index covered by the current chunk
    while (!chunk->children.empty()) {
        Chunk* next = nullptr;
        for (auto& child : chunk->children) {
            int end = start + child->n_seqs;
            if (start <= seq_idx && seq_idx < end) {
                next = child;  // `start` stays at this child's first index
                break;
            }
            start = end;
        }
        if (next == nullptr) {
            throw std::out_of_range("seq_idx " + std::to_string(seq_idx) + " is out of range");
        }
        chunk = next;
    }
    return chunk;
}

void Attention::remove(int seq_idx) {
    Chunk* seq_tail = at(seq_idx);
    do {
        seq_tail->n_seqs -= 1;
        Chunk* chunk_parent = seq_tail->parent;
        if (seq_tail->n_seqs <= 0) {
            chunk_parent->children.erase(
              std::remove(chunk_parent->children.begin(), chunk_parent->children.end(), seq_tail),
              chunk_parent->children.end());
            chunk_allocator_.free(seq_tail);
        }
        seq_tail = chunk_parent;
    } while (seq_tail != root_.get());
}

void Attention::duplicate(int seq_idx, int copies) {
    if (copies < 1) {
        return;
    }

    Chunk* chunk = root_.get();
    int start = 0;
    while (chunk->children.size() > 0) {
        chunk->n_seqs += copies;
        for (auto& child : chunk->children) {
            int end = start + child->n_seqs;
            if (start <= seq_idx && seq_idx < end) {
                chunk = child;
                break;
            }
            start = end;
        }
    }

    auto it = std::find(chunk->parent->children.begin(), chunk->parent->children.end(), chunk);
    int distance = it - chunk->parent->children.begin() + 1;
    for (int i = 0; i < copies; i++) {
        auto chunk_copy = chunk_allocator_.allocate(*chunk);
        chunk->parent->insert_child(distance, chunk_copy);
        distance += 1;
    }
}

int Attention::add_prompt(std::vector<int>& tokens, torch::Tensor k, torch::Tensor v) {
    assert(k.size(1) == v.size(1));
    int seq_idx = 0;
    int n_tokens = tokens.size();
    Chunk* prev = root_.get();
    // if the remaining tokens are less than chunk_size_, then a new
    // chunk must be created in order not to share with other chunks

    prev->n_seqs += 1;
    int start = 0;

    // skip the common prefix
    if (share_prefix_) {
        int last_possible = n_tokens - chunk_size_;
        for (; start <= last_possible; start += chunk_size_) {
            auto next = prev->find_child(tokens, start, start + chunk_size_);
            if (next) {
                next->n_seqs += 1;
                for (auto it = prev->children.begin(); *it != next; it++) {
                    seq_idx += (*it)->n_seqs;
                }
                prev = next;
                continue;
            } else {
                break;
            }
        }
    }
    seq_idx += (prev->n_seqs - 1);

    for (; start < n_tokens; start += chunk_size_) {
        int end = std::min(start + chunk_size_, n_tokens);
        auto child = chunk_allocator_.allocate(tokens, k, v, start, end);
        prev = prev->add_child(child);
    }

    // make sure full chunks are not the tails
    if (prev->full()) {
        auto child = chunk_allocator_.allocate();
        prev = prev->add_child(child);
    }

    return seq_idx;
}

void Attention::append_completions(std::vector<int>& tokens, torch::Tensor k, torch::Tensor v) {
    int visited = 0;
    std::stack<Chunk*> stack;
    stack.push(root_.get());

    while (!stack.empty()) {
        Chunk* current = stack.top();
        stack.pop();

        // If it's a leaf node, process it
        if (current->children.size() == 0) {
            current->append_tokens(tokens, k, v, visited, visited + 1);
            // make sure full chunks are not the tails
            if (current->full()) {
                auto next = chunk_allocator_.allocate();
                current->add_child(next);
            }
            visited += 1;
            continue;
        }

        for (auto ite = current->children.rbegin(); ite != current->children.rend(); ite++) {
            stack.push(*ite);
        }
    }
}

int Attention::append_completion(int seq_idx,
                                 std::vector<int>& tokens,
                                 torch::Tensor k,
                                 torch::Tensor v) {
    // locate the seq.
    auto chunk = at(seq_idx);
    // the tail chunk can not be full for sure
    assert(!chunk->full());
    assert(chunk->children.size() == 0);
    assert(chunk->n_seqs == 1);

    if (tokens.size() > 1) {
        auto it = std::find(chunk->parent->children.begin(), chunk->parent->children.end(), chunk);
        int distance = it - chunk->parent->children.begin() + 1;
        for (int i = 1; i < tokens.size(); i++) {
            auto chunk_copy = chunk_allocator_.allocate(*chunk);
            chunk_copy->append_tokens(tokens, k, v, i, i + 1);
            chunk->parent->insert_child(distance, chunk_copy);
            distance += 1;

            // make sure full chunks are not the tails
            if (chunk_copy->full()) {
                auto next = chunk_allocator_.allocate();
                chunk_copy->add_child(next);
            }
        }
    }
    chunk->append_tokens(tokens, k, v, 0, 1);
    // make sure full chunks are not the tails
    if (chunk->full()) {
        auto next = chunk_allocator_.allocate();
        chunk->add_child(next);
    }
    return seq_idx;
}

// Debug helper: writes the chunk tree to stdout, one chunk per line, using
// Chunk::to_string(true) for each chunk and indenting four spaces per depth
// level. A null `root` starts the dump at this object's root chunk.
void Attention::print(Chunk* root, int level) {
    Chunk* node = (root != nullptr) ? root : root_.get();

    const std::string indent(level > 0 ? level * 4 : 0, ' ');
    std::cout << indent << node->to_string(true) << std::endl;

    for (auto child : node->children) {
        print(child, level + 1);
    }
}

std::vector<Task> Attention::get_chunks() {
    auto start = std::chrono::high_resolution_clock::now();
    std::list<std::tuple<int, Chunk*>> queue;
    queue.push_back(std::make_tuple(0, root_.get()));
    std::vector<Task> tasks;
    while (!queue.empty()) {
        auto& item = queue.front();
        int seq_begin = std::get<0>(item);
        auto chunk = std::get<1>(item);
        int seq_end = seq_begin + chunk->n_seqs;
        queue.pop_front();

        int start = seq_begin;
        for (auto& child : chunk->children) {
            int end = start + child->n_seqs;
            queue.push_back(std::make_tuple(start, child));
            start = end;
        }

        if (chunk == root_.get() || chunk->n_tokens() <= 0) {
            continue;
        }

        tasks.emplace_back(chunk, seq_begin, seq_end);
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
    // std::cout << "get_chunks(us): " << duration.count() << std::endl;
    return tasks;
}

// Tuple-based view of get_chunks(). Each task becomes
// (key, value, n_tokens, seq_idx_begin, seq_idx_end), in the same order as
// get_chunks() produces them.
std::vector<std::tuple<torch::Tensor, torch::Tensor, int, int, int>> Attention::get_chunks_raw() {
    using Row = std::tuple<torch::Tensor, torch::Tensor, int, int, int>;

    const std::vector<Task> tasks = get_chunks();
    std::vector<Row> rows;
    rows.reserve(tasks.size());

    for (const auto& task : tasks) {
        const auto& chunk = task.chunk;
        rows.emplace_back(chunk->key(),
                          chunk->value(),
                          chunk->n_tokens(),
                          task.seq_idx_begin,
                          task.seq_idx_end);
    }
    return rows;
}

} // namespace GPT