#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <vector>
#include "kv_cache.hpp"

namespace py = pybind11;
using torch::indexing::Slice;

// Builds an empty cache.
//
// Two tensors (keys and values) of shape
//   {num_hidden_layers, num_seqs, num_key_value_heads, max_cache_len, head_dim}
// are created on `device` with element type `dtype`. They come from
// torch::empty, so their contents are not meaningful until written.
// valid_len starts at 0 and every slot index in [0, max_cache_len) is placed
// on the free list in ascending order.
KVCache::KVCache(int num_hidden_layers,
            int num_key_value_heads,
            int max_cache_len,
            int head_dim,
            torch::Device device,
            torch::Dtype dtype,
            int num_seqs)
        : num_hidden_layers(num_hidden_layers),
          num_key_value_heads(num_key_value_heads),
          max_cache_len(max_cache_len),
          head_dim(head_dim),
          device(device),
          dtype(dtype),
          num_seqs(num_seqs) {
    // Both caches share one layout and one set of tensor options.
    const auto cache_options = torch::TensorOptions().device(device).dtype(dtype);
    const std::vector<int64_t> cache_shape{
        num_hidden_layers, num_seqs, num_key_value_heads, max_cache_len, head_dim};

    k_cache = torch::empty(cache_shape, cache_options);
    v_cache = torch::empty(cache_shape, cache_options);

    // Nothing has been written along the cache-length dimension yet.
    valid_len = 0;

    // Every slot starts out free.
    int slot = 0;
    while (slot < max_cache_len) {
        free_indices.push_back(slot);
        ++slot;
    }
}

// Returns how many slot indices are currently recorded in allocated_indices.
int KVCache::num_allocated() const {
    const auto allocated_count = allocated_indices.size();
    return static_cast<int>(allocated_count);
}

// Reserves `sequence_length` slots from the front of the free list.
//
// The reserved indices are removed from free_indices, appended to
// allocated_indices in the same order, and returned to the caller.
//
// Throws std::runtime_error if `sequence_length` is negative or exceeds the
// number of free slots. Nothing is modified when an exception is thrown.
std::vector<int> KVCache::allocate(int sequence_length) {
    // A negative length would slip past the capacity check below and then form
    // an iterator before begin(), which is undefined behaviour.
    if (sequence_length < 0) {
        throw std::runtime_error("sequence_length must be non-negative");
    }
    if (static_cast<int>(free_indices.size()) < sequence_length) {
        throw std::runtime_error("Not enough space in cache");
    }

    const auto first = free_indices.begin();
    const auto last = first + sequence_length;

    std::vector<int> indices(first, last);
    free_indices.erase(first, last);
    allocated_indices.insert(allocated_indices.end(), indices.begin(), indices.end());
    return indices;
}

void KVCache::free(const std::vector<int>& indices, const std::vector<int>& unfree_indices) {
    std::map<int, int> index_map;
    for (int i = 0; i < allocated_indices.size(); ++i) {
        index_map[allocated_indices[i]] = i;
    }
    for (int idx : indices) {
        free_indices.push_back(idx);
        auto it = std::find(allocated_indices.begin(), allocated_indices.end(), idx);
        if (it == allocated_indices.end()) {
            throw std::runtime_error("Index not found in allocated indices");
        }
        allocated_indices.erase(it);
    }
    valid_len = valid_len - indices.size() - unfree_indices.size();
    std::vector<int> unfree_idxs;
    for (int idx : unfree_indices) {
        unfree_idxs.push_back(index_map[idx]);
    }
    // 1) Define a real std::vector of TensorIndex
    using namespace torch::indexing;
    std::vector<TensorIndex> indxs = {
        Slice(),    // all on dim 0
        Slice(),    // all on dim 1
        Slice(),    // all on dim 2
        torch::tensor(unfree_idxs, torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU)), // unfree on dim 3
        Slice()     // all on dim 4
    };

    // 2) Use the same indxs for reading...
    at::Tensor k_sel = k_cache.index(indxs);
    at::Tensor v_sel = v_cache.index(indxs);
    // 3) ... and writing
    k_cache.index_put_({Slice(), Slice(), Slice(), Slice(valid_len, valid_len + unfree_idxs.size()), Slice()}, k_sel);
    v_cache.index_put_({Slice(), Slice(), Slice(), Slice(valid_len, valid_len + unfree_idxs.size()), Slice()}, v_sel);
    valid_len = valid_len + unfree_idxs.size();
}

// Writes new key/value states for one layer and returns that layer's cache
// up to and including the newly written entries.
//
// key_states / value_states - tensors whose dim 2 is the number of new
//                  entries `s`; they are written to
//                  cache[layer_idx, :, :, valid_len : valid_len + s, :].
// layer_idx      - layer to write to.
// indices, new_indices - currently unused by this implementation.
//
// Returns a Python tuple (k, v) where each element is
// cache[layer_idx, :, :, 0 : valid_len + s, :]. These come from slice/integer
// indexing, so they are expected to alias the cache storage rather than copy.
//
// valid_len is advanced by `s` only when the last layer
// (layer_idx == num_hidden_layers - 1) is updated, so all layers of one step
// write to the same range.
//
// Throws std::runtime_error if the new entries would not fit within
// max_cache_len. torch raises its own error if key_states has fewer than
// three dimensions or the shapes are incompatible with the cache.
py::tuple KVCache::update(torch::Tensor key_states,
                    torch::Tensor value_states,
                    int layer_idx,
                    const std::vector<int>& indices,
                    const std::vector<int>& new_indices) {
    // size() is bounds-checked, unlike operator[] on sizes().
    const int s = static_cast<int>(key_states.size(2));
    const auto end = valid_len + s;

    // Slices clamp to the tensor extent, so without this guard an overflowing
    // write can be silently dropped (e.g. a single-token update broadcasting
    // into an empty slice) while valid_len still advances past the cache end.
    if (end > max_cache_len) {
        throw std::runtime_error("Not enough space in cache");
    }

    const Slice incoming(valid_len, end);
    const Slice populated(0, end);

    k_cache.index_put_({layer_idx, Slice(), Slice(), incoming, Slice()}, key_states);
    v_cache.index_put_({layer_idx, Slice(), Slice(), incoming, Slice()}, value_states);

    auto k_cache_slice = k_cache.index({layer_idx, Slice(), Slice(), populated, Slice()});
    auto v_cache_slice = v_cache.index({layer_idx, Slice(), Slice(), populated, Slice()});

    if (num_hidden_layers == layer_idx + 1) {
        // Last layer of this step: commit the new length.
        valid_len = end;
    }
    return py::make_tuple(k_cache_slice, v_cache_slice);
}

// Stub: not implemented yet, so the reported sequence length is always 0.
int KVCache::get_seq_length() const {
    constexpr int kUnimplementedLength = 0;
    return kUnimplementedLength;
}