#pragma once

#include "task.h"
#include "trace.h"

namespace GPT {

/// Declaration of the CPU attention kernel entry point; the definition lives
/// in a separate translation unit that is not visible from this header.
///
/// Everything below about *semantics* is inferred from names only and must be
/// confirmed against the definition. Only the types and reference/pointer
/// kinds are established by this declaration.
///
/// NOTE(review): this header uses `std::vector` and `torch::Tensor` without
/// including `<vector>` or a torch header itself -- presumably they arrive
/// transitively through "task.h" / "trace.h"; verify, and consider including
/// them directly.
///
/// @param tasks      Collection of `Task` objects, passed by non-const
///                   reference (so the callee is permitted to modify it).
/// @param q          Tensor passed by non-const reference; presumably the
///                   attention query -- TODO confirm shape/dtype/layout.
/// @param output     Tensor passed by non-const reference; presumably where
///                   the result is written -- TODO confirm.
/// @param n_head     Presumably the number of attention heads -- TODO confirm.
/// @param chunk_size Presumably the size of a processing chunk; units are not
///                   visible here -- TODO confirm.
/// @param partition  Meaning not determinable from this header (count vs.
///                   index) -- TODO confirm against the definition.
/// @param trace      Raw pointer to a `Trace` object; ownership and whether
///                   `nullptr` is accepted are not visible here -- TODO confirm.
void attn_cpu_kernel(std::vector<Task>& tasks,
                     torch::Tensor& q,
                     torch::Tensor& output,
                     int n_head,
                     int chunk_size,
                     int partition,
                     Trace* trace);

} // namespace GPT