#pragma once

#include <cstdint>
#include <cuda_runtime.h>


namespace group_norm_v2 {

// Plain aggregate of sizing and launch-configuration values passed through the
// `Meta *meta_ptr` argument of gn_cuda() and gn_bwd_cuda().
//
// NOTE(review): the code that fills or consumes these fields is not visible
// from this header, so the per-field notes below are inferred from the field
// names only and must be confirmed against the implementation. Field order
// and types are part of the layout shared with that code; do not reorder.
struct Meta {
    int64_t red_buffer_size;  // presumably the size needed for `red_buffer` (units not visible here — TODO confirm)
    int64_t barrier_size;     // presumably the size needed for `barrier` (units not visible here — TODO confirm)
    int BLOCK_DIM_X;          // presumably the thread-block x dimension selected for the launch — TODO confirm
    int C_PER_BLOCK;          // presumably channels handled per block — TODO confirm
    int ROWS_PER_BLOCK;       // presumably rows handled per block — TODO confirm
    int VEC_ELEMS;            // presumably elements per vectorized access — TODO confirm
    bool LOAD_TWICE;          // presumably whether the input is read twice — TODO confirm
    int BLOCKS_PER_SM;        // presumably resident blocks per SM targeted by the launch — TODO confirm
    bool HARDWARE_CLUSTER;    // presumably whether hardware thread-block clusters are used — TODO confirm
    int wgrad_sync_method;    // presumably selects how weight-gradient results are synchronized; valid values not visible here
};

// Host-side entry point declared for element type T; the definition (and any
// explicit instantiations) live outside this header.
//
// NOTE(review): judging by the namespace and the sibling gn_bwd_cuda(), this
// is presumably the group-norm forward pass. The parameter notes below are
// inferred from names only — confirm against the definition.
//
//   out, x, w, b        T buffers; presumably output, input, weight and bias.
//   eps                 float epsilon; presumably the variance stabilizer.
//   silu                presumably enables a fused SiLU activation.
//   n, hw               64-bit extents; presumably batch size and spatial size.
//   num_groups, channels_per_group
//                       group layout of the channel dimension.
//   mean_var_out        float buffer; presumably receives per-group statistics.
//   red_buffer, barrier scratch buffers; required sizes are presumably the
//                       ones reported via Meta::red_buffer_size and
//                       Meta::barrier_size — TODO confirm.
//   sm_margin           presumably the number of SMs to leave unused.
//   stream, device_id   CUDA stream and device the work targets.
//   meta_ptr, meta_only Meta in/out pointer and a flag; presumably meta_only
//                       requests filling *meta_ptr without launching work.
template <typename T>
void gn_cuda(
    T *out,
    T *x,
    T *w,
    T *b,
    float eps,
    bool silu,
    int64_t n,
    int64_t hw,
    int num_groups,
    int channels_per_group,
    float *mean_var_out,
    float *red_buffer,
    unsigned *barrier,
    int sm_margin,
    cudaStream_t stream,
    int device_id,
    Meta *meta_ptr,
    bool meta_only);

// Host-side entry point declared for element type T; the definition (and any
// explicit instantiations) live outside this header.
//
// NOTE(review): judging by the name and namespace, this is presumably the
// group-norm backward pass. The parameter notes below are inferred from names
// only — confirm against the definition.
//
//   grad_input, grad_weight, grad_bias
//                       T buffers; presumably the gradients produced.
//   grad_output         T buffer; presumably the incoming upstream gradient.
//   x, w, b             T buffers; presumably the forward input, weight, bias.
//   mean_var            float buffer; presumably the statistics saved by the
//                       forward call.
//   eps                 float epsilon; presumably the variance stabilizer.
//   silu                presumably indicates the forward used a fused SiLU.
//   n, hw               64-bit extents; presumably batch size and spatial size.
//   num_groups, channels_per_group
//                       group layout of the channel dimension.
//   red_buffer, barrier scratch buffers; required sizes are presumably the
//                       ones reported via Meta::red_buffer_size and
//                       Meta::barrier_size — TODO confirm.
//   sm_margin           presumably the number of SMs to leave unused.
//   stream, device_id   CUDA stream and device the work targets.
//   meta_ptr, meta_only Meta in/out pointer and a flag; presumably meta_only
//                       requests filling *meta_ptr without launching work.
template <typename T>
void gn_bwd_cuda(
    T *grad_input,
    T *grad_weight,
    T *grad_bias,
    T *grad_output,
    T *x,
    T *w,
    T *b,
    float *mean_var,
    float eps,
    bool silu,
    int64_t n,
    int64_t hw,
    int num_groups,
    int channels_per_group,
    float *red_buffer,
    unsigned *barrier,
    int sm_margin,
    cudaStream_t stream,
    int device_id,
    Meta *meta_ptr,
    bool meta_only);

}  // namespace group_norm_v2
