#include "chunk.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace GPT {

/// Creates an empty chunk with room for `capacity` tokens.
/// The key and value caches are zero-filled tensors of shape {n_head, capacity, d_embed}
/// created with `options`; key_ptr_/value_ptr_ cache their raw data pointers.
Chunk::Chunk(int capacity, int n_head, int d_embed, torch::TensorOptions& options) {
    // No tokens stored yet; n_seqs starts at 1 (presumably the number of sequences
    // sharing this chunk — confirm against the header).
    n_seqs = 1;
    n_tokens_ = 0;
    tokens.reserve(capacity);

    const std::vector<int64_t> cache_shape{n_head, capacity, d_embed};
    key_ = torch::zeros(cache_shape, options);
    value_ = torch::zeros(cache_shape, options);
    key_ptr_ = key_.data_ptr();
    value_ptr_ = value_.data_ptr();
}

/// Creates a chunk of the given capacity and fills its caches with the token range
/// [start, end) taken along dim 1 of `k` and `v`. The token-id vector is left empty.
/// @throws std::runtime_error if end < start or the range does not fit in `capacity`.
Chunk::Chunk(int capacity,
             int n_head,
             int d_embed,
             torch::Tensor& k,
             torch::Tensor& v,
             int start,
             int end,
             torch::TensorOptions& options)
  : Chunk(capacity, n_head, d_embed, options) {
    const int n = end - start;
    // Validate up front, mirroring append_tokens(): an inverted range would otherwise
    // record a negative token count, and an oversize one only fails deep inside torch.
    if (n < 0) {
        throw std::runtime_error("invalid token range for Chunk");
    }
    if (n > capacity) {
        throw std::runtime_error("no enough space in Chunk");
    }

    // Assigning to a slice (an rvalue Tensor) copies data into key_/value_ in place.
    key_.slice(1, 0, n) = k.slice(1, start, end);
    value_.slice(1, 0, n) = v.slice(1, start, end);
    n_tokens_ = n;
}

/// Creates a chunk holding the tokens ids[start, end) together with the matching
/// key/value slices [start, end) along dim 1 of `k` and `v`.
/// @throws std::runtime_error if [start, end) is not a valid range inside `ids`
///         (tensor-side range/capacity errors come from the delegated constructor).
Chunk::Chunk(int capacity,
             int n_head,
             int d_embed,
             std::vector<int>& ids,
             torch::Tensor& k,
             torch::Tensor& v,
             int start,
             int end,
             torch::TensorOptions& options)
  : Chunk(capacity, n_head, d_embed, k, v, start, end, options) {
    // ids.begin() + start/end is undefined behaviour outside [0, ids.size()].
    if (start < 0 || end < start || static_cast<std::size_t>(end) > ids.size()) {
        throw std::runtime_error("invalid token range for Chunk");
    }
    // `tokens` is still empty here (the base constructor only reserves), so assign == append.
    tokens.assign(ids.begin() + start, ids.begin() + end);
}

/// Copy-constructs a chunk holding a deep copy of `other`'s token ids and key/value
/// caches via deep_copy(), which throws std::runtime_error if `other` has children.
Chunk::Chunk(Chunk const& other) {
    // key_/value_ start out as default-constructed (undefined) tensors, and deep_copy()
    // writes into the existing tensors with copy_(). Allocate storage with the same
    // shape/dtype/device first and cache its data pointers, as the other constructors do.
    key_ = torch::empty_like(other.key_);
    key_ptr_ = key_.data_ptr();
    value_ = torch::empty_like(other.value_);
    value_ptr_ = value_.data_ptr();
    deep_copy(other);
}

/// Returns true when both chunks report the same n_tokens() and their first
/// n_tokens() token ids match element-wise.
/// NOTE(review): assumes both `tokens` vectors hold at least n_tokens() ids. A chunk built
/// from k/v only (no ids) leaves `tokens` empty while n_tokens_ > 0, and this would then
/// read past the end — confirm callers never compare such chunks.
bool Chunk::equal(Chunk& other) {
    const auto count = n_tokens();
    if (count != other.n_tokens()) {
        return false;
    }
    // A non-positive count matches trivially (the same as a zero-iteration loop).
    return count <= 0 || std::equal(tokens.begin(), tokens.begin() + count, other.tokens.begin());
}

bool Chunk::equal(std::vector<int>& ids, int start, int end) {
    int n = end - start;
    if (this->n_tokens() != n) {
        return false;
    }
    for (int i = 0; i < n; i++) {
        if (this->tokens[i] != ids[start + i]) {
            return false;
        }
    }
    return true;
}

/// Appends the key/value slices [start, end) along dim 1 of `k` and `v` after the
/// tokens already stored, and advances n_tokens_. Token ids are not touched.
/// @throws std::runtime_error if end < start, or if the range does not fit in the
///         remaining capacity.
void Chunk::append_tokens(torch::Tensor& k, torch::Tensor& v, int start, int end) {
    const int n = end - start;
    // An inverted range gives empty slices, so the copy "succeeds" while n_tokens_
    // silently shrinks; reject it explicitly.
    if (n < 0) {
        throw std::runtime_error("invalid token range for Chunk");
    }
    if (n > capacity() - n_tokens()) {
        throw std::runtime_error("no enough space in Chunk");
    }

    const auto offset = n_tokens();
    // Assigning to a slice (an rvalue Tensor) copies data into key_/value_ in place.
    key_.slice(1, offset, offset + n) = k.slice(1, start, end);
    value_.slice(1, offset, offset + n) = v.slice(1, start, end);
    n_tokens_ += n;
}

/// Appends token ids ids[start, end) together with the matching key/value slices
/// [start, end) along dim 1 of `k` and `v`.
/// @throws std::runtime_error if [start, end) is not a valid range inside `ids`, or if
///         the tensor append fails its range/capacity checks.
void Chunk::append_tokens(std::vector<int>& ids,
                          torch::Tensor& k,
                          torch::Tensor& v,
                          int start,
                          int end) {
    // Check the id range before touching the caches. Otherwise a bad range would read
    // past `ids` (UB) after key_/value_ and n_tokens_ had already been updated.
    if (start < 0 || end < start || static_cast<std::size_t>(end) > ids.size()) {
        throw std::runtime_error("invalid token range for Chunk");
    }
    append_tokens(k, v, start, end);
    tokens.insert(tokens.end(), ids.begin() + start, ids.begin() + end);
}

/// Overwrites this chunk with a deep copy of `other`: token ids, token count and the
/// key/value cache contents. n_seqs is reset to 1. Copying onto itself is a no-op.
/// @throws std::runtime_error if `other` has children (only tail chunks may be copied).
void Chunk::deep_copy(Chunk const& other) {
    if (this == &other) {
        return;
    }

    if (!other.children.empty()) {
        throw std::runtime_error("Copy a non tailing chunk is forbidden");
    }

    // Copy tensors before metadata, so a failing copy_() leaves tokens/n_tokens_ untouched.
    if (key_.defined()) {
        // Copy in place so the storage behind key_ptr_ stays the same.
        key_.copy_(other.key_);
    } else {
        // An undefined tensor (e.g. a default-constructed member) cannot be copied into.
        key_ = other.key_.clone();
        key_ptr_ = key_.data_ptr();
    }
    if (value_.defined()) {
        // copy_() already copies the data; cloning the source first was a wasted allocation.
        value_.copy_(other.value_);
    } else {
        value_ = other.value_.clone();
        value_ptr_ = value_.data_ptr();
    }

    n_seqs = 1;
    tokens = other.tokens;
    n_tokens_ = other.n_tokens_;
}

/// Returns the first child (in `children` order) whose tokens equal ids[start, end),
/// or nullptr when no child matches.
Chunk* Chunk::find_child(std::vector<int>& ids, int start, int end) {
    const auto match = std::find_if(children.begin(), children.end(), [&](Chunk* candidate) {
        return candidate->equal(ids, start, end);
    });
    return match == children.end() ? nullptr : *match;
}

/// Appends `child` to this chunk's children, sets its parent to this chunk, and returns it.
/// The child is not removed from any previous parent's children list.
Chunk* Chunk::add_child(Chunk* child) {
    children.emplace_back(child);
    child->parent = this;
    return child;
}

/// Inserts `child` at position `idx` in this chunk's children, sets its parent to this
/// chunk, and returns it.
/// @throws std::runtime_error if idx is outside [0, children.size()].
Chunk* Chunk::insert_child(int idx, Chunk* child) {
    // begin() + idx outside [begin, end] is an invalid iterator; vector::insert would be UB.
    if (idx < 0 || static_cast<std::size_t>(idx) > children.size()) {
        throw std::runtime_error("insert_child index out of range");
    }
    children.insert(children.begin() + idx, child);
    child->parent = this;
    return child;
}

std::string Chunk::to_string(bool brief) {
    std::vector<std::string> string_tokens;
    if (brief && tokens.size() > 3) {
        string_tokens.push_back(std::to_string(tokens[0]));
        string_tokens.push_back("...");
        string_tokens.push_back(std::to_string(tokens[tokens.size() - 1]));
    } else {
        std::transform(tokens.begin(),
                       tokens.end(),
                       std::back_inserter(string_tokens),
                       [](int num) { return std::to_string(num); });
    }
    std::string s = fmt_str("%d: [%s]", n_seqs, join_str(string_tokens, ",").c_str());
    return s;
}

} // namespace GPT