#pragma once

#include <list>
#include <memory>
#include <set>
#include <vector>
#include "chunk.h"

namespace GPT {

/// @brief Hands out `Chunk` objects by raw pointer and takes them back via free().
///
/// Only the declaration is visible in this header. Notes below that go beyond
/// the types and signatures are hedged and should be confirmed against the
/// implementation file.
///
/// The allocator is non-copyable: it tracks chunks by raw pointer in
/// `free_set_` / `used_set_`, so a copy would duplicate that bookkeeping for
/// the same `Chunk` objects.
class ChunkAllocator {
  public:
    /// @brief Construct an allocator.
    /// @param memory_mb  presumably the memory budget in megabytes — TODO confirm.
    /// @param chunk_size presumably the per-chunk capacity — TODO confirm units.
    /// @param n_head     presumably the number of attention heads — TODO confirm.
    /// @param d_embed    presumably the embedding width — TODO confirm.
    /// @param options    tensor options; a same-typed member `t_options_` exists,
    ///                   so this is presumably stored for later tensor creation.
    ChunkAllocator(int memory_mb,
                   int chunk_size,
                   int n_head,
                   int d_embed,
                   torch::TensorOptions& options);

    // Copying is disabled. Both copy operations are deleted together: deleting
    // only the copy constructor still leaves an implicitly-defaulted copy
    // assignment, which would let `a = b` compile. Because copy operations are
    // user-declared, no move operations are implicitly generated either.
    ChunkAllocator(ChunkAllocator const& other) = delete;
    ChunkAllocator& operator=(ChunkAllocator const& other) = delete;
    virtual ~ChunkAllocator() = default;

    /// @brief Obtain a chunk.
    /// @return Non-owning pointer to a chunk. NOTE(review): behavior when the
    ///         allocator is full is not visible here — check full() first or
    ///         confirm against the implementation.
    Chunk* allocate();

    /// @brief Obtain a chunk based on an existing one.
    /// @param other source chunk; presumably its contents are copied — TODO confirm.
    Chunk* allocate(Chunk& other);

    /// @brief Obtain a chunk built from token ids and key/value tensors.
    /// @param ids   token ids.
    /// @param k     key tensor (shape not visible here).
    /// @param v     value tensor (shape not visible here).
    /// @param start first index of the range to take.
    /// @param end   end index of the range; NOTE(review): inclusive vs.
    ///              exclusive is not visible here — confirm.
    Chunk* allocate(std::vector<int>& ids, torch::Tensor& k, torch::Tensor& v, int start, int end);

    /// @brief Same as the overload above, but without token ids.
    Chunk* allocate(torch::Tensor& k, torch::Tensor& v, int start, int end);

    /// @brief Hand a chunk back to the allocator.
    /// @param chunk a pointer presumably obtained from one of the allocate()
    ///              overloads — TODO confirm handling of foreign/null pointers.
    void free(Chunk* chunk);

    /// @brief Report whether the allocator is full.
    /// NOTE(review): presumably true once `max_chunks_` chunks are in use — confirm.
    bool full() const;

  private:
    std::list<std::shared_ptr<Chunk>> chunks_;  ///< Chunk objects held via shared_ptr.
    std::set<Chunk*> free_set_;                 ///< Presumably chunks available for allocate().
    std::set<Chunk*> used_set_;                 ///< Presumably chunks currently handed out.
    int memory_mb_;
    int chunk_size_;
    int n_head_;
    int d_embed_;
    torch::TensorOptions t_options_;
    int max_chunks_;                            ///< Presumably derived from memory_mb_ and chunk geometry.
};

}