#include "chunk_allocator.h"
#include <utility>
#include "str_utils.h"
#include "logging.h"

namespace GPT {

// Builds an allocator whose chunk budget is derived from a memory limit.
//
// Parameters:
//   memory_mb  - total memory budget, in MiB, used to derive max_chunks_.
//   chunk_size,
//   n_head,
//   d_embed    - chunk dimensions; stored and forwarded to every Chunk this
//                allocator creates. All three must be strictly positive.
//   options    - tensor options forwarded to every Chunk; its dtype decides
//                how many bytes one element occupies.
//
// Throws std::runtime_error if the dtype is not one of Float64 / Float32 /
// Float16 / Int8, or if any chunk dimension is not positive.
ChunkAllocator::ChunkAllocator(int memory_mb,
                               int chunk_size,
                               int n_head,
                               int d_embed,
                               torch::TensorOptions& options)
  : memory_mb_(memory_mb)
  , chunk_size_(chunk_size)
  , n_head_(n_head)
  , d_embed_(d_embed)
  , t_options_(options) {
    // A zero dimension would make the chunk size 0 and the division below
    // produce inf/NaN, whose conversion to an integer is undefined behaviour.
    if (chunk_size_ <= 0 || n_head_ <= 0 || d_embed_ <= 0) {
        throw std::runtime_error(
          fmt_str("chunk dimensions must be positive, got chunk_size={} n_head={} d_embed={}",
                  chunk_size_,
                  n_head_,
                  d_embed_));
    }

    double bytes_per_element = 0;
    if (options.dtype() == torch::kFloat64) {
        bytes_per_element = 8;
    } else if (options.dtype() == torch::kFloat32 || options.dtype() == torch::kFloat) {
        bytes_per_element = 4;
    } else if (options.dtype() == torch::kFloat16) {
        bytes_per_element = 2;
    } else if (options.dtype() == torch::kInt8) {
        bytes_per_element = 1;
    } else {
        throw std::runtime_error(fmt_str("unsupported dtype {}", options.dtype()));
    }

    // Promote to double *before* multiplying: the product of three ints can
    // overflow int, and single precision rounds sizes above 2^24 bytes.
    const double chunk_bytes = static_cast<double>(chunk_size_) * n_head_ * d_embed_ *
                               bytes_per_element * 2; // 2 for key and value
    const double budget_bytes = static_cast<double>(memory_mb_) * 1024 * 1024;
    max_chunks_ = static_cast<decltype(max_chunks_)>(std::floor(budget_bytes / chunk_bytes));
    LOG_INFO(
      "ChunkAllocator: {} MB memory in total, each chunk requires {} KB memory, max possible chunks {}",
      memory_mb,
      static_cast<float>(chunk_bytes / 1024),
      max_chunks_);
}

// Hands out one chunk and records it in the used set.
//
// A previously freed chunk is reused when one is available; otherwise a new
// Chunk is constructed, provided the allocator is not full(). The returned
// pointer is non-owning: newly created chunks are kept alive by the
// shared_ptr stored in chunks_.
//
// Throws std::runtime_error("ChunkAllocator is full") when no free chunk
// exists and the chunk budget is exhausted.
Chunk* ChunkAllocator::allocate() {
    if (!free_set_.empty()) {
        auto ite = free_set_.begin();
        // Copy the pointer out first: erase() invalidates the iterator, so it
        // must not be dereferenced afterwards.
        Chunk* recycled = *ite;
        used_set_.insert(recycled);
        free_set_.erase(ite);
        return recycled;
    }

    if (full()) {
        throw std::runtime_error("ChunkAllocator is full");
    }

    auto chunk = std::make_shared<Chunk>(chunk_size_, n_head_, d_embed_, t_options_);
    chunks_.push_back(chunk);
    used_set_.insert(chunk.get());
    return chunk.get();
}

// Allocates a chunk and fills it with a deep copy of `other`.
//
// Returns a non-owning pointer to the new chunk. Propagates any exception
// from allocate() or Chunk::deep_copy(); if the copy fails, the chunk is
// handed back to the pool first so the slot is not lost.
Chunk* ChunkAllocator::allocate(Chunk& other) {
    Chunk* chunk = allocate();
    try {
        chunk->deep_copy(other);
    } catch (...) {
        // The caller never receives the pointer, so nobody else could free it.
        free(chunk);
        throw;
    }
    return chunk;
}

// Allocates a chunk and appends tokens to it by forwarding `ids`, `k`, `v`,
// `start` and `end` unchanged to Chunk::append_tokens().
//
// Returns a non-owning pointer to the new chunk. Propagates any exception
// from allocate() or Chunk::append_tokens(); if the append fails, the chunk
// is handed back to the pool first so the slot is not lost.
Chunk* ChunkAllocator::allocate(std::vector<int>& ids,
                                torch::Tensor& k,
                                torch::Tensor& v,
                                int start,
                                int end) {
    Chunk* chunk = allocate();
    try {
        chunk->append_tokens(ids, k, v, start, end);
    } catch (...) {
        // The caller never receives the pointer, so nobody else could free it.
        free(chunk);
        throw;
    }
    return chunk;
}

// Allocates a chunk and appends tokens to it by forwarding `k`, `v`, `start`
// and `end` unchanged to Chunk::append_tokens() (the overload without ids).
//
// Returns a non-owning pointer to the new chunk. Propagates any exception
// from allocate() or Chunk::append_tokens(); if the append fails, the chunk
// is handed back to the pool first so the slot is not lost.
Chunk* ChunkAllocator::allocate(torch::Tensor& k, torch::Tensor& v, int start, int end) {
    Chunk* chunk = allocate();
    try {
        chunk->append_tokens(k, v, start, end);
    } catch (...) {
        // The caller never receives the pointer, so nobody else could free it.
        free(chunk);
        throw;
    }
    return chunk;
}

// Returns a chunk obtained from allocate() to the pool: the chunk is
// cleared and moved from the used set to the free set for later reuse.
//
// Throws std::runtime_error if `chunk` is not currently in the used set
// (for example on a double free or a pointer from elsewhere).
void ChunkAllocator::free(Chunk* chunk) {
    auto ite = used_set_.find(chunk);
    if (ite == used_set_.end()) {
        throw std::runtime_error("free is called, but chunk not found in used set");
    }

    // Take the pointer out of the set before erasing: erase() invalidates the
    // iterator, so it must not be dereferenced afterwards.
    Chunk* released = *ite;

    // Clear and publish to the free set first, so that if clear() throws the
    // chunk is still tracked in the used set instead of being lost.
    released->clear();
    free_set_.insert(released);
    used_set_.erase(ite);
}

// Reports whether the chunk budget is exhausted, i.e. the number of chunks
// tracked in the free and used sets together has reached max_chunks_.
bool ChunkAllocator::full() const {
    const auto tracked_chunks = free_set_.size() + used_set_.size();
    return tracked_chunks >= max_chunks_;
}

} // namespace GPT
