#include <iostream>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#include "type_shim.h"
#include "compat.h"


// Returns the highest set bit of n as a power of two; when n is already
// exactly that power of two, the result is halved instead (so n = 1 gives 0).
__device__ __forceinline__ int lastpow2(int n)
{
  const int top_bit = 1 << (31 - __clz(n));
  return (n == top_bit) ? (top_bit >> 1) : top_bit;
}

// Host helper: for n >= 1 returns the smallest power of two that is >= n.
// Works by smearing the highest set bit of (n - 1) into every lower bit and
// then adding one. n == 0 wraps around and yields 0.
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
    unsigned int smeared = n - 1;
    for (unsigned int shift = 1; shift <= 16; shift <<= 1) {
        smeared |= (smeared >> shift);
    }
    return smeared + 1;
}

// Host helper: returns the largest power of two that is <= n (0 for n == 0).
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
    unsigned int smeared = n;
    for (unsigned int shift = 1; shift <= 16; shift <<= 1) {
        smeared |= (smeared >> shift);
    }
    // Every bit below the top one is now set; removing the lower half leaves
    // only the top bit.
    return smeared - (smeared >> 1);
}


#define WARP_SIZE 32

// Sums `val` across a warp with shuffle-down steps of WARP_SIZE/2, ..., 1.
// All lanes are named in the shuffle mask; after the last step lane 0 holds
// the sum over the whole warp.
template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
{
  const unsigned int full_mask = 0xffffffff;
  #pragma unroll
  for (int delta = WARP_SIZE >> 1; delta >= 1; delta /= 2) {
    const T neighbour = __shfl_down_sync(full_mask, val, delta);
    val = val + neighbour;
  }
  return val;
}

// Block-wide sum of `val` over a 2D thread block.
//   x   : scratch buffer with at least one slot per warp of the block
//   val : this thread's contribution
// The total ends up in the thread whose flattened id is 0; the value returned
// to other threads is only a partial result.
template<typename T>
__device__ __forceinline__ T reduce_block(T *x, T val)
{
  const int linear_tid = threadIdx.y*blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * blockDim.y;
  const int lane = linear_tid % WARP_SIZE;
  const int warp = linear_tid / WARP_SIZE;

  if (num_threads > 32) {
    // Stage 1: each warp reduces on its own and parks its sum in x[warp].
    val = warp_reduce_sum(val);
    if (lane == 0) {
      x[warp] = val;
    }

    __syncthreads();

    // Stage 2: the first (num_threads / WARP_SIZE) threads pick up one warp
    // sum each, everyone else contributes zero.
    if (linear_tid < num_threads / WARP_SIZE) {
      val = x[lane];
    } else {
      val = T(0);
    }
  }

  // Final pass is done by the first warp only.
  if (warp == 0) {
    val = warp_reduce_sum(val);
  }

  return val;
}

#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
#define ELEMENTS_PER_THREAD 16
#define OPTIMAL_TILE_W 32
#define MAX_H_BLOCK 128
#define MAX_BLOCK_SIZE 512

// Computes 1 + (x - 1) / y (the ceiling of x / y for positive x, y) and then
// rounds that quotient DOWN to a power of two via h_last_pow2. The result can
// therefore be smaller than the true ceiling.
__host__ int div_ru(int x, int y) {
  const int ceil_quotient = (x - 1) / y + 1;
  return h_last_pow2(ceil_quotient);
}

// Chooses block and grid dimensions for the (reduction, stride) kernels below.
//   reduction : extent of the dimension that is reduced / looped over (grid.y, block.y)
//   stride    : extent of the contiguous channel dimension (grid.x, block.x)
//   block     : [out] block dimensions, at most MAX_BLOCK_SIZE threads
//   grid      : [out] grid dimensions
//   coop_flag : when true, grid.y is collapsed to 1 unless it is at least 8
__host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    // block_y could not fill the block; widen block_x to use the spare threads
    block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  // grid.x must cover every channel: the kernels derive their channel index
  // from blockIdx.x alone and never loop over it. div_ru() rounds the block
  // count DOWN to a power of two (e.g. stride 96 / block_x 32 -> 2 blocks),
  // which would silently skip the trailing channels, so use a plain ceiling
  // division here.
  int grid_x = 1 + (stride - 1) / block_x;
  // grid.y may under-cover the reduction dimension; the kernels loop over it.
  int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}

// Merges one Welford partial (num_new, mean_new, m2n_new) into the running
// partial (count, mean, m2n) in place. m2n is the sum of squared deviations
// from the mean.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& num_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  // Reciprocal of the merged count, clamped to 1 so that merging two empty
  // partials does not divide by zero.
  const T inv_total = T(1.0) / max(1, (count + num_new));
  const T mean_gap = mean - mean_new;
  mean = (mean_new * num_new + mean * count) * inv_total;
  m2n += m2n_new + mean_gap * mean_gap * num_new * count * inv_total;
  count += num_new;
}

// Warp-level Welford reduction: with shuffle-down steps of WARP_SIZE/2, ..., 1
// every lane merges the partial (num, mean, m2n) of the lane `delta` above it
// into its own. Lane 0 ends up with the merged statistics of the warp.
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
  const unsigned int full_mask = 0xffffffff;
  #pragma unroll
  for (int delta = WARP_SIZE >> 1; delta >= 1; delta /= 2) {
    const auto other_num = __shfl_down_sync(full_mask, num, delta);
    const auto other_mean = __shfl_down_sync(full_mask, mean, delta);
    const auto other_m2n = __shfl_down_sync(full_mask, m2n, delta);
    welford_merge_element(num, mean, m2n, other_num, other_mean, other_m2n);
  }
}

// Block-level Welford reduction of (mean, m2n, num).
//   x          : scratch for per-warp (mean, m2n) pairs, two T slots per warp
//   count      : scratch for per-warp element counts, one int per warp
//   mean/m2n/num : this thread's partial on entry; the block result in the
//                  thread with thread_id 0 on exit
//   block_size : number of threads in the block
//   thread_id  : flattened thread index within the block
template <typename T>
__device__ void welford_reduce_mean_m2n(
      T* __restrict__ x,
      int* __restrict__ count,
      T &mean,
      T &m2n,
      int &num,
      int block_size,
      int thread_id)
{
  const int warp_index = thread_id / WARP_SIZE;
  const int lane_index = thread_id % WARP_SIZE;

  if (block_size > 32) {
    // Stage 1: reduce inside every warp, lane 0 publishes the warp partial.
    warp_reduce_mean_m2n(mean, m2n, num);
    if (lane_index == 0) {
      count[warp_index] = num;
      x[2 * warp_index] = mean;
      x[2 * warp_index + 1] = m2n;
    }
    __syncthreads();

    // Stage 2: the first warp loads one published partial per lane; lanes
    // beyond the number of warps start from an empty partial.
    if (warp_index == 0) {
      const bool has_partial = thread_id < block_size / WARP_SIZE;
      if (has_partial) {
        mean = x[lane_index * 2];
        m2n = x[lane_index * 2 + 1];
        num = count[lane_index];
      } else {
        mean = T(0);
        m2n = T(0);
        num = int(0);
      }
    }
  }

  if (warp_index == 0) {
    warp_reduce_mean_m2n(mean, m2n, num);
  }
}

// Product of the sizes of every dimension from index 2 onwards, i.e. the
// spatial extent of an (N, C, ...) tensor. Reads input.size(2) unconditionally.
__host__ int get_tensor_spatial_size(const at::Tensor& input)
{
  const auto num_dims = input.ndimension();
  auto spatial = input.size(2);
  int dim = 3;
  while (dim < num_dims) {
    spatial *= input.size(dim);
    ++dim;
  }
  return spatial;
}

// Scalar type to accumulate in for `input`: Half is promoted to Float, every
// other type is returned unchanged.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
  const auto input_type = input.scalar_type();
  if (input_type == at::ScalarType::Half) {
    return at::ScalarType::Float;
  }
  return input_type;
}

// Size in bytes of one element of `input`. With accumulation == true the size
// of the promoted accumulation type (see promote_scalartype) is returned.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
  if (accumulation) {
    return at::elementSize(promote_scalartype(input));
  }
  return at::elementSize(input.scalar_type());
}

// Merges Welford partials along threadIdx.y through shared memory, one
// halving step per loop iteration. Each of the three shared arrays needs one
// slot per thread of the block. On return the threads with threadIdx.y == 0
// hold the merged (count, mean, m2n) of their column.
// Contains __syncthreads(): must be called by every thread of the block.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // publish this thread's partial
  const auto own_slot = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_count[own_slot] = count;
  shmem_mean[own_slot] = mean;
  shmem_m2n[own_slot] = m2n;

#pragma unroll
  for (int half = blockDim.y/2; half > 0; half >>= 1) {
    __syncthreads();
    const bool is_receiver = threadIdx.y < half && threadIdx.y + half < blockDim.y;
    if (!is_receiver) {
      continue;
    }

    // pull the partial that sits `half` rows below and fold it in
    const auto peer_slot = own_slot + half * blockDim.x;
    const auto peer_count = shmem_count[peer_slot];
    const auto peer_mean = shmem_mean[peer_slot];
    const auto peer_m2n = shmem_m2n[peer_slot];

    welford_merge_element(count, mean, m2n, peer_count, peer_mean, peer_m2n);

    // republish for the next step (redundant after the final step)
    shmem_mean[own_slot] = mean;
    shmem_m2n[own_slot] = m2n;
    shmem_count[own_slot] = count;
  }
}

// Sums sum_dy and sum_dy_xmu along threadIdx.y through shared memory, one
// halving step per loop iteration. Both shared arrays need one slot per thread
// of the block. On return the threads with threadIdx.y == 0 hold the column
// totals. Contains __syncthreads(): must be called by every thread of the block.
template<typename T>
__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
                                                     T& sum_dy_xmu,
                                                     T* shmem_sum_dy,
                                                     T* shmem_sum_dy_xmu) {
  // publish this thread's partial sums
  const auto own_slot = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_sum_dy[own_slot] = sum_dy;
  shmem_sum_dy_xmu[own_slot] = sum_dy_xmu;

#pragma unroll
  for (int half = blockDim.y/2; half > 0; half >>= 1) {
    __syncthreads();
    const bool is_receiver = threadIdx.y < half && threadIdx.y + half < blockDim.y;
    if (!is_receiver) {
      continue;
    }

    const auto peer_slot = own_slot + half * blockDim.x;
    sum_dy += shmem_sum_dy[peer_slot];
    sum_dy_xmu += shmem_sum_dy_xmu[peer_slot];

    // republish for the next step (redundant after the final step)
    shmem_sum_dy[own_slot] = sum_dy;
    shmem_sum_dy_xmu[own_slot] = sum_dy_xmu;
  }
}


// Welford kernel computing the per-feature mean and biased variance.
// One block handles the feature blockIdx.x: threadIdx.y strides over the batch
// dimension and threadIdx.x strides over the spatial dimension.
//   input          : indexed as batch * (fs * ss) + feature * ss + spatial
//   out_mean       : [out] one mean per feature
//   out_var_biased : [out] one biased variance (m2n / count) per feature
//   bs / fs / ss   : batch, feature and spatial extents
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      const int bs,
      const int fs,
      const int ss) {
  const int block_size = blockDim.x * blockDim.y;
  const int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

  // per-thread running statistics
  int seen = 0;
  accscalar_t running_mean = accscalar_t(0);
  accscalar_t running_m2n = accscalar_t(0);

  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    const int row_start = blockIdx.x*ss + batch_id*ss*fs;
    // sequential welford update, one element at a time
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      seen++;
      const auto value = static_cast<accscalar_t>(input[offset+row_start]);
      const auto delta = value - running_mean;
      running_mean += delta / seen;
      running_m2n += delta * (value - running_mean);
    }
  }

  // Shared scratch: the first 32 ints hold per-warp counts, the remaining 128
  // ints are reinterpreted as per-warp (mean, m2n) pairs.
  static __shared__ int s_mem[160];
  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];

  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, running_mean, running_m2n, seen, block_size, thread_id);

  // thread 0 holds the block-wide result
  if (thread_id == 0) {
    out_mean[blockIdx.x] = static_cast<outscalar_t>(running_mean);
    out_var_biased[blockIdx.x] = static_cast<outscalar_t>(running_m2n/seen);
  }
}

// Elementwise batch-norm forward: out = weight * (input - mean) * inv_std + shift.
// blockIdx.x selects the feature; batch entries are spread over
// blockIdx.y/threadIdx.y and spatial positions over blockIdx.z/threadIdx.x.
// The batch stride in memory is gridDim.x * ss.
// weight / shift may be NULL, in which case 1 / 0 are used.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int ss,
      const int bs) {
  // per-feature parameters, loaded once per thread
  const auto feat_mean = mean[blockIdx.x];
  const auto feat_inv_std = inv_std[blockIdx.x];
  const auto feat_weight = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
  const auto feat_shift = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);

  int batch_offset = blockIdx.y*blockDim.y + threadIdx.y;
  while (batch_offset < bs) {
    const int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    int offset = threadIdx.x + blockIdx.z*blockDim.x;
    while (offset < ss) {
      const int idx = address_base+offset;
      out[idx] = static_cast<scalar_t>(feat_weight * (static_cast<accscalar_t>(input[idx]) - feat_mean ) * feat_inv_std + feat_shift);
      offset += gridDim.z*blockDim.x;
    }
    batch_offset += gridDim.y*blockDim.y;
  }
}

// Backward BN reduction kernel. For the feature blockIdx.x it computes
//   sum_dy     = sum(grad_output)
//   sum_dy_xmu = sum(grad_output * (input - mean))
// and writes grad_bias = sum_dy and grad_weight = sum_dy_xmu * inv_std (each
// only when the pointer is non-NULL) plus the raw sums to sum_dy_o /
// sum_dy_xmu_o. grad_input is produced by a separate kernel so that the raw
// sums can be all-reduced across processes in between (sync BN).
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void reduce_bn_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      const int bs,
      const int fs,
      const int ss) {
  // scratch for reduce_block, reinterpreted as accscalar_t below
  static __shared__ int s_mem[64];

  const int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

  const auto feat_mean = mean[blockIdx.x];
  const auto feat_inv_std = inv_std[blockIdx.x];

  // Kahan-compensated running sums: *_total is the sum, *_comp the compensation
  accscalar_t dy_total = 0.0;
  accscalar_t dy_xmu_total = 0.0;
  accscalar_t dy_comp = 0.0;
  accscalar_t dy_xmu_comp = 0.0;
  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    const int row_start = blockIdx.x*ss + batch_id*ss*fs;
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      const auto dy = static_cast<accscalar_t>(grad_output[offset+row_start]);
      const auto x = static_cast<accscalar_t>(input[offset+row_start]);

      // accumulate dy
      {
        const auto term = dy - dy_comp;
        const auto next = dy_total + term;
        dy_comp = (next - dy_total) - term;
        dy_total = next;
      }

      // accumulate dy * (x - mean)
      {
        const auto term = dy * (x - feat_mean) - dy_xmu_comp;
        const auto next = dy_xmu_total + term;
        dy_xmu_comp = (next - dy_xmu_total) - term;
        dy_xmu_total = next;
      }
    }
  }

  dy_total = reduce_block((accscalar_t*)s_mem, dy_total);
  // the scratch buffer is reused, so wait until the first reduction has read it
  __syncthreads();
  dy_xmu_total = reduce_block((accscalar_t*)s_mem, dy_xmu_total);

  // only thread 0 holds the block-wide sums
  if (thread_id != 0) {
    return;
  }
  if (grad_bias != NULL) {
    grad_bias[blockIdx.x] = static_cast<layerscalar_t>(dy_total);
  }
  if (grad_weight != NULL) {
    grad_weight[blockIdx.x] = static_cast<layerscalar_t>(dy_xmu_total * feat_inv_std);
  }
  sum_dy_o[blockIdx.x] = dy_total;
  sum_dy_xmu_o[blockIdx.x] = dy_xmu_total;
}

// Elementwise batch-norm backward. With div = sum of numel[0 .. world_size):
//   grad_input = (dy - sum_dy / div
//                 - (x - mean) * inv_std^2 * sum_dy_xmu / div) * weight * inv_std
// blockIdx.x selects the feature; batch entries are spread over
// blockIdx.y/threadIdx.y and spatial positions over blockIdx.z/threadIdx.x.
// weight may be NULL, in which case 1 is used.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int ss,
      const int bs) {
  // total element count over all entries of numel
  int64_t div = 0;
  for (int rank = 0; rank < world_size; rank++) {
    div += numel[rank];
  }

  const auto feat_mean = static_cast<accscalar_t>(mean[blockIdx.x]);
  const auto mean_dy = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
  const auto feat_inv_std = inv_std[blockIdx.x];
  const auto feat_weight = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
  // scale applied to the whole bracket
  const auto out_scale = feat_weight * feat_inv_std;
  // scale applied to the (x - mean) term
  const auto xmu_scale = feat_inv_std * feat_inv_std * sum_dy_xmu[blockIdx.x] / div;

  int batch_offset = blockIdx.y*blockDim.y+threadIdx.y;
  while (batch_offset < bs) {
    const int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    int offset = threadIdx.x + blockIdx.z*blockDim.x;
    while (offset < ss) {
      const int idx = address_base+offset;
      grad_input[idx] = (static_cast<accscalar_t>(grad_output[idx]) - mean_dy - (static_cast<accscalar_t>(input[idx]) - feat_mean) * xmu_scale) * out_scale;
      offset += gridDim.z*blockDim.x;
    }
    batch_offset += gridDim.y*blockDim.y;
  }
}

// welford kernel for c last tensor calculating mean/biased_variance
//
// The input is addressed as a 2D (m, c) array: element (m, c) lives at
// input[m * stride + c]. threadIdx.x/blockIdx.x select the channel c,
// threadIdx.y/blockIdx.y stride over the reduced dimension m.
//
//   input          : values to reduce, reduction_size x stride
//   out_mean       : [out] per-channel mean
//   out_var_biased : [out] per-channel biased variance (m2n / count)
//   staging_data   : global scratch, only used when gridDim.y > 1; laid out as
//                    stride*gridDim.y means, then stride*gridDim.y m2n values,
//                    then stride*gridDim.y int counts
//   semaphores     : one counter per blockIdx.x, only used when gridDim.y > 1
//                    (presumably zero-initialized by the caller - TODO confirm)
//   reduction_size : extent of m
//   stride         : extent of c
//
// Uses 2 * MAX_BLOCK_SIZE accscalar_t + MAX_BLOCK_SIZE int of static shared
// memory, indexed by threadIdx.x + threadIdx.y * blockDim.x.
template
   <typename scalar_t,
    typename accscalar_t,
    typename outscalar_t,
    int PARALLEL_LOADS>
__global__ void
welford_kernel_c_last(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {
  // hide latency with concurrency
  // each thread keeps PARALLEL_LOADS independent welford accumulators
  accscalar_t x_mean[PARALLEL_LOADS];
  accscalar_t m_2_n[PARALLEL_LOADS];
  int count[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    x_mean[i] = accscalar_t(0);
    m_2_n[i] = accscalar_t(0);
    // count is an int array; the accscalar_t zero is converted on assignment
    count[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  // number of outer iterations needed to cover reduction_size when every
  // iteration advances m by inner_loop_stride * PARALLEL_LOADS
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_math[PARALLEL_LOADS];
    accscalar_t x_count_inv[PARALLEL_LOADS];
    accscalar_t is_valid[PARALLEL_LOADS];

    // load multiple data in
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        // out-of-range slot: zeros make the update below a no-op for m_2_n
        // and leave x_mean unchanged
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate mean/m2n with welford
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }

  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }

  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];

  // after this call the threads with threadIdx.y == 0 hold the block result
  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);

  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    // carve the three staging regions out of staging_data
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    // one thread per block bumps the counter of its blockIdx.x column; the
    // block that observes gridDim.y-1 is the last one to arrive
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    // make is_last_block_done visible to the whole block
    __syncthreads();

    // check that all data is now available in global memory
    // (is_last_block_done is uniform across the block, so the barriers inside
    // welford_merge_block_vertical below are reached by all of its threads)
    if (is_last_block_done) {
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);

      // every thread merges the staged partials of blocks y, y + blockDim.y, ...
      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int num_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);

        welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
      }

      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
        out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
      }
    }
  } else {
    // single block along y: the block result is already the final result
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
      out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
    }
  }
}

// Parallel welford kernel merging per-process (mean, biased var) pairs into
// the overall mean / unbiased var / inv_std of every feature.
//   mean, var_biased : world_size rows of feature_size values each
//   numel            : element count behind each row
//   out_mean         : [out] merged mean per feature
//   out_var          : [out] m2n / (count - 1) per feature
//   inv_std          : [out] 1 / sqrt(m2n / count + eps) per feature
// Threads stride over the features by the total number of launched threads.
template <typename scalar_t>
__global__ void welford_kernel_parallel(
      const scalar_t* __restrict__ mean,
      const scalar_t* __restrict__ var_biased,
      const int* __restrict__ numel,
      scalar_t* __restrict__ out_mean,
      scalar_t* __restrict__ out_var,
      scalar_t* __restrict__ inv_std,
      const int world_size,
      const int feature_size,
      const float eps) {

  int feature = blockIdx.x * blockDim.x + threadIdx.x;
  while (feature < feature_size) {
    scalar_t merged_mean = 0;
    scalar_t merged_m2n = 0;
    int merged_count = 0;

    // walk down the column of this feature, one row per process
    int address = feature;
    for (int rank = 0; rank < world_size; rank++) {
      // biased variance times element count gives the row's m2n
      welford_merge_element(merged_count, merged_mean, merged_m2n, numel[rank], mean[address], var_biased[address]*numel[rank]);
      address += feature_size;
    }

    out_mean[feature] = merged_mean;
    out_var[feature] = merged_m2n/ (merged_count - 1);
    inv_std[feature] = scalar_t(1) / sqrt(merged_m2n/merged_count + eps);

    feature += gridDim.x * blockDim.x;
  }
}

// Elementwise batch-norm forward for c last tensors:
//   out = weight * (input - mean) * inv_std + shift (+ z) (optionally ReLU'd)
// The tensors are addressed as 2D (m, c) arrays with element (m, c) at
// index m * stride + c. threadIdx.x/blockIdx.x select the channel c,
// threadIdx.y/blockIdx.y stride over m.
//   z         : optional residual added before the ReLU, may be NULL
//   weight    : optional per-channel scale, NULL means 1
//   shift     : optional per-channel offset, NULL means 0
//   fuse_relu : when true, results <= 0 are written as 0
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride,
      const bool fuse_relu) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the last channel have no work. Leaving here (there is no
  // barrier in this kernel) also keeps the per-channel loads below inside the
  // bounds of mean / inv_std / weight / shift.
  if (c_offset >= stride) {
    return;
  }

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// ReLU backward for the fused BN(+z)+ReLU forward on c last tensors.
// Recomputes the pre-activation weight * (input - mean) * inv_std + shift (+ z)
// and writes out = 0 where it is <= 0, out = grad_output elsewhere.
// The tensors are addressed as 2D (m, c) arrays with element (m, c) at
// index m * stride + c. threadIdx.x/blockIdx.x select the channel c,
// threadIdx.y/blockIdx.y stride over m.
//   z      : optional residual, may be NULL
//   weight : optional per-channel scale, NULL means 1
//   shift  : optional per-channel offset, NULL means 0
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the last channel have no work. Leaving here (there is no
  // barrier in this kernel) also keeps the per-channel loads below inside the
  // bounds of mean / inv_std / weight / shift.
  if (c_offset >= stride) {
    return;
  }

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// batchnorm backward reduction kernel for c last tensor
//
// For every channel c it computes
//   sum_dy     = sum over m of grad_output(m, c)
//   sum_dy_xmu = sum over m of grad_output(m, c) * (input(m, c) - mean[c])
// and writes grad_bias = sum_dy and grad_weight = sum_dy_xmu * inv_std[c]
// (each only when the pointer is non-NULL) plus the raw sums to sum_dy_o /
// sum_dy_xmu_o. Element (m, c) lives at index m * stride + c;
// threadIdx.x/blockIdx.x select c, threadIdx.y/blockIdx.y stride over m.
//
//   staging_data : global scratch, only used when gridDim.y > 1; laid out as
//                  stride*gridDim.y sum_dy values followed by stride*gridDim.y
//                  sum_dy_xmu values
//   semaphores   : one counter per blockIdx.x, only used when gridDim.y > 1
//                  (presumably zero-initialized by the caller - TODO confirm)
//
// Uses 2 * MAX_BLOCK_SIZE accscalar_t of static shared memory.
template
   <typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void reduce_bn_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {

  // hide latency with concurrency
  accscalar_t sum_dy[PARALLEL_LOADS];
  accscalar_t sum_dy_xmu[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    sum_dy[i] = accscalar_t(0);
    sum_dy_xmu[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  // Threads past the last channel must stay alive for the barriers below, so
  // they cannot return early; guard the per-channel loads instead so they do
  // not read beyond the end of mean / inv_std. The substituted zeros are never
  // observable: such threads only accumulate zeros and never write a result.
  const bool c_in_range = c_offset < stride;
  auto r_mean = c_in_range ? mean[c_offset] : accscalar_t(0);
  auto factor = c_in_range ? inv_std[c_offset] : accscalar_t(0);

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_input[PARALLEL_LOADS];
    accscalar_t x_grad_output[PARALLEL_LOADS];

    // load multiple data in
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_input[j] = input[address_base];
        x_grad_output[j] = grad_output[address_base];
      } else {
        // out-of-range slot contributes nothing to either sum
        x_input[j] = accscalar_t(0);
        x_grad_output[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate sum_dy / sum_dy_xmu
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      sum_dy[j] += x_grad_output[j];
      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
    }
  }

  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    sum_dy[0] += sum_dy[j];
    sum_dy_xmu[0] += sum_dy_xmu[j];
  }

  // release array of registers
  auto sum_dy_th = sum_dy[0];
  auto sum_dy_xmu_th = sum_dy_xmu[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];

  // after this call the threads with threadIdx.y == 0 hold the block result
  merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);

  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    volatile accscalar_t* staging_sum_dy = staging_data;
    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_sum_dy[address_base] = sum_dy_th;
      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done: the block that observes gridDim.y-1 arrived last
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    // make is_last_block_done visible to the whole block
    __syncthreads();

    // check that all data is now available in global memory
    // (is_last_block_done is uniform across the block, so the barriers inside
    // merge_block_vertical below are reached by all of its threads)
    if (is_last_block_done) {
      sum_dy_th = accscalar_t(0.0);
      sum_dy_xmu_th = accscalar_t(0.0);

      // every thread sums the staged partials of blocks y, y + blockDim.y, ...
      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
      }

      merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
      if (threadIdx.y == 0 && c_offset < stride) {
        if (grad_bias != NULL) {
          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
        }
        if (grad_weight != NULL) {
          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
        }
        sum_dy_o[c_offset] = sum_dy_th;
        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
      }
    }
  } else {
    // single block along y: the block result is already the final result
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      if (grad_bias != NULL) {
        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
      }
      if (grad_weight != NULL) {
        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
      }
      sum_dy_o[c_offset] = sum_dy_th;
      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
    }
  }
}

// Elementwise batch-norm backward for c last tensors. With
// div = sum of numel[0 .. world_size):
//   grad_input = (dy - sum_dy / div
//                 - (x - mean) * inv_std^2 * sum_dy_xmu / div) * weight * inv_std
// The tensors are addressed as 2D (m, c) arrays with element (m, c) at
// index m * stride + c. threadIdx.x/blockIdx.x select the channel c,
// threadIdx.y/blockIdx.y stride over m. weight may be NULL (treated as 1).
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batchnorm_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int reduction_size,
      const int stride) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the last channel have no work. Leaving here (there is no
  // barrier in this kernel) also keeps the per-channel loads below inside the
  // bounds of mean / inv_std / weight / sum_dy / sum_dy_xmu.
  if (c_offset >= stride) {
    return;
  }

  // total element count over all entries of numel
  int64_t div = 0;
  for (int i = 0; i < world_size; i++) {
    div += numel[i];
  }

  auto m_c = mean[c_offset];
  auto m_dy_c = sum_dy[c_offset] / div;
  auto factor_1_c = inv_std[c_offset];
  auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {
        grad_input[address_base] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
            * factor_2_c);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// Host entry point that enqueues welford_kernel on the current CUDA stream
// and returns the two statistics tensors it fills: {out_mean, out_var_biased}.
//
// input: dimension 0 is used as the batch size and dimension 1 as the feature
//        count; the remaining extent comes from get_tensor_spatial_size().
//        NOTE(review): the raw data pointer is handed to the kernel as-is, so
//        the input is presumably expected to be contiguous - confirm at the
//        call sites.
//
// Both outputs have shape {feature_size} and the dtype reported by
// promote_scalartype(input). The launch uses one thread block per feature
// with block_x * block_y <= MAX_BLOCK_SIZE threads.
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  auto space_size = get_tensor_spatial_size(input);
  auto scalar_type = promote_scalartype(input);

  at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
  at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));

  // h_last_pow2(0) is 0, so an empty batch would otherwise make block_y zero
  // and turn the division on the next line into a division by zero.
  int block_y = max(1, min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32)));
  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          out_mean.DATA_PTR<accscalar_t>(),
          out_var_biased.DATA_PTR<accscalar_t>(),
          batch_size,
          feature_size,
          space_size);
    );
  }

  return {out_mean, out_var_biased};
}

// Host wrapper that enqueues batchnorm_forward_kernel on the current CUDA
// stream and returns the freshly allocated output tensor (same shape and
// options as `input`).
//
// input   - dimension 0 is used as batch size, dimension 1 as feature count;
//           the remaining extent comes from get_tensor_spatial_size().
// mean, inv_std - per-feature statistics, read with the accumulation type of
//           the input dtype.
// weight, shift - optional per-feature parameters; a NULL pointer is passed
//           to the kernel for whichever one is absent.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK). In both
// cases `shift` is read with the same element type as `weight`.
at::Tensor batchnorm_forward_CUDA(
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  const auto space_size = get_tensor_spatial_size(input);

  at::Tensor out = at::empty_like(input);

  // Launch geometry: x/z cover the spatial extent, y covers the batch, and
  // grid.x enumerates the features.
  const int space_pow2 = h_last_pow2(space_size);
  const int batch_pow2 = h_last_pow2(batch_size);
  const int threads_x = max(32, min(MAX_BLOCK_SIZE, space_pow2 / 4));
  const int threads_y = max(1, min(MAX_BLOCK_SIZE / threads_x, batch_pow2 / 4));
  const int blocks_y = max(1, min(65535, batch_pow2 / threads_y));
  const int blocks_z = max(1, min(65535, space_pow2 / 4 / threads_x));
  const dim3 block(threads_x, threads_y);
  const dim3 grid(feature_size, blocks_y, blocks_z);
  auto stream = at::cuda::getCurrentCUDAStream();

  const bool half_input_float_weight =
      input.scalar_type() == at::ScalarType::Half &&
      weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float;

  if (!half_input_float_weight) {
    // Uniform precision: weight (when given) must share the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      scalar_t_0* weight_ptr = weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL;
      scalar_t_0* shift_ptr = shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          shift_ptr,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  } else {
    // Mixed precision: weight is known to be present on this path.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* weight_ptr = weight.value().DATA_PTR<accscalar_t>();
      accscalar_t* shift_ptr = shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          shift_ptr,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  }
  return out;
}

// Host wrapper that enqueues reduce_bn_kernel on the current CUDA stream and
// returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}.
//
// grad_output, input - dimension 0 of `input` is used as batch size and
//           dimension 1 as feature count; the remaining extent comes from
//           get_tensor_spatial_size().
// mean, inv_std - per-feature statistics; their options also determine the
//           dtype/device of sum_dy and sum_dy_xmu.
// weight  - optional. When present, grad_weight and grad_bias have shape
//           {feature_size} and weight's options. When absent they are empty
//           ({0}) placeholders and NULL is passed to the kernel for both.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
// The launch uses one thread block per feature.
std::vector<at::Tensor> reduce_bn_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight)
{
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  at::Tensor sum_dy = at::empty({feature_size}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({feature_size}, weight.value().options());
    grad_bias = at::empty({feature_size}, weight.value().options());
  } else {
    // Placeholders so that a defined tensor can be returned.
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }

  auto space_size = get_tensor_spatial_size(input);

  // h_last_pow2(0) is 0, so an empty batch would otherwise make block_y zero
  // and turn the division on the next line into a division by zero.
  int block_y = max(1, min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32)));
  int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed precision: layer parameters use the accumulation type.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  } else {
    if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  }

  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}

// Host wrapper that enqueues batchnorm_backward_kernel on the current CUDA
// stream and returns the freshly allocated grad_input tensor (same shape and
// options as `input`).
//
// grad_output, input - dimension 0 of `input` is used as batch size and
//           dimension 1 as feature count; the remaining extent comes from
//           get_tensor_spatial_size().
// mean, inv_std, sum_dy, sum_dy_xmu - per-feature tensors read with the
//           accumulation type of the input dtype.
// weight  - optional per-feature scale; NULL is passed when absent.
// count   - read through an `int` data pointer; its element count is passed
//           to the kernel alongside it.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
at::Tensor batchnorm_backward_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  const auto space_size = get_tensor_spatial_size(input);

  at::Tensor grad_input = at::empty_like(input);

  // Launch geometry: x/z cover the spatial extent, y covers the batch, and
  // grid.x enumerates the features.
  const int space_pow2 = h_last_pow2(space_size);
  const int batch_pow2 = h_last_pow2(batch_size);
  const int threads_x = max(32, min(MAX_BLOCK_SIZE, space_pow2 / 4));
  const int threads_y = max(1, min(MAX_BLOCK_SIZE / threads_x, batch_pow2 / 4));
  const int blocks_y = max(1, min(65535, batch_pow2 / threads_y));
  const int blocks_z = max(1, min(65535, space_pow2 / 4 / threads_x));
  const dim3 block(threads_x, threads_y);
  const dim3 grid(feature_size, blocks_y, blocks_z);

  auto stream = at::cuda::getCurrentCUDAStream();

  const bool half_input_float_weight =
      input.scalar_type() == at::ScalarType::Half &&
      weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float;

  if (half_input_float_weight) {
    // Mixed precision: weight is known to be present on this path.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* weight_ptr = weight.value().DATA_PTR<accscalar_t>();
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  } else {
    // Uniform precision: weight (when given) must share the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      scalar_t_0* weight_ptr = weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL;
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  }

  return grad_input;
}

// Host wrapper that enqueues welford_kernel_parallel on the current CUDA
// stream and returns {out_mean, out_var, inv_std}.
//
// mean_feature_nodes - dimension 0 is used as world_size and dimension 1 as
//           feature_size.
// var_biased - its options determine the dtype/device of all three outputs,
//           each of shape {feature_size}.
// numel     - read through an `int` data pointer.
// eps       - forwarded unchanged to the kernel.
//
// All three inputs are made contiguous before their data pointers are taken.
// When feature_size is 0 no kernel is launched and the (empty) outputs are
// returned directly.
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased,
                                              const at::Tensor numel,
                                              const float eps) {
  const auto world_size = mean_feature_nodes.size(0);
  const auto feature_size = mean_feature_nodes.size(1);

  at::Tensor out_var = at::empty({feature_size}, var_biased.options());
  at::Tensor inv_std = at::empty_like(out_var);
  at::Tensor out_mean = at::empty_like(out_var);

  // Nothing to compute for an empty feature dimension. Returning here also
  // keeps `feature_size / block` below from dividing by zero, because
  // h_last_pow2(0) is 0 and would make `block` zero.
  if (feature_size == 0) {
    return {out_mean, out_var, inv_std};
  }

  at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
  at::Tensor var_biased_ = var_biased.contiguous();
  at::Tensor numel_ = numel.contiguous();

  // TODO(jie): tile this for memory coalescing!
  const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
  const int grid = std::max<int>(1, feature_size / block);

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
      welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
          mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
          var_biased_.DATA_PTR<scalar_t_0>(),
          numel_.DATA_PTR<int>(),
          out_mean.DATA_PTR<scalar_t_0>(),
          out_var.DATA_PTR<scalar_t_0>(),
          inv_std.DATA_PTR<scalar_t_0>(),
          world_size,
          feature_size,
          eps);
    );
  }

  return {out_mean, out_var, inv_std};
}

// Host wrapper that enqueues welford_kernel_c_last on the current CUDA stream
// and returns the two statistics tensors it fills: {out_mean, out_var_biased}.
//
// The input is viewed as (reduction_size, stride), where stride is the size
// of the last dimension and reduction_size is numel / stride. Both outputs
// have shape {stride} and the dtype reported by promote_scalartype(input).
//
// Launch dimensions come from flexible_launch_configs(). When that yields
// grid.y > 1, a staging buffer of 4 * stride * grid.y elements and a zeroed
// int buffer of grid.x elements are allocated and handed to the kernel;
// otherwise both pointers are nullptr.
// NOTE(review): these presumably let blocks along grid.y combine partial
// results - confirm against the kernel.
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  auto option = input.options().dtype(promote_scalartype(input));

  at::Tensor out_mean = at::empty({stride}, option);
  at::Tensor out_var_biased = at::empty({stride}, option);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  // Scratch tensors stay undefined unless several blocks share a column tile.
  const bool multi_block_y = grid.y > 1;
  at::Tensor staging_data;
  at::Tensor semaphores;
  if (multi_block_y) {
    staging_data = at::empty({4*stride*grid.y}, option);
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = nullptr;
      int* semaphores_ptr = nullptr;
      if (multi_block_y) {
        staging_data_ptr = staging_data.DATA_PTR<accscalar_t>();
        semaphores_ptr = semaphores.DATA_PTR<int>();
      }
      welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          out_mean.DATA_PTR<accscalar_t>(),
          out_var_biased.DATA_PTR<accscalar_t>(),
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }

  return {out_mean, out_var_biased};
}

// Host wrapper that enqueues batchnorm_forward_c_last_kernel on the current
// CUDA stream and returns the freshly allocated output tensor (same shape and
// options as `input`).
//
// The input is viewed as (reduction_size, stride), where stride is the size
// of the last dimension and reduction_size is numel / stride.
//
// z        - optional tensor read with the input's element type; NULL is
//            passed when absent.
// mean, inv_std - per-column statistics read with the accumulation type.
// weight, shift - optional per-column parameters; NULL is passed for
//            whichever one is absent.
// fuse_relu - forwarded unchanged to the kernel.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
at::Tensor batchnorm_forward_c_last_CUDA(
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift,
    const bool fuse_relu) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor out = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  const bool half_input_float_weight =
      input.scalar_type() == at::ScalarType::Half &&
      weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float;

  if (!half_input_float_weight) {
    // Uniform precision: weight (when given) must share the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      scalar_t_0* z_ptr = z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL;
      scalar_t_0* weight_ptr = weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL;
      scalar_t_0* shift_ptr = shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z_ptr,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          shift_ptr,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  } else {
    // Mixed precision: weight is known to be present on this path.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      scalar_t_0* z_ptr = z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL;
      accscalar_t* weight_ptr = weight.value().DATA_PTR<accscalar_t>();
      accscalar_t* shift_ptr = shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z_ptr,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight_ptr,
          shift_ptr,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  }
  return out;
}

// Host wrapper that enqueues reduce_bn_c_last_kernel on the current CUDA
// stream and returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}.
//
// The input is viewed as (reduction_size, stride), where stride is the size
// of the last dimension and reduction_size is numel / stride.
//
// mean, inv_std - per-column statistics; their options also determine the
//           dtype/device of the two sum outputs and of the staging buffer.
// weight  - optional. When present, grad_weight and grad_bias have shape
//           {stride} and weight's options. When absent they are empty ({0})
//           placeholders and NULL is passed to the kernel for both.
//
// Launch dimensions come from flexible_launch_configs(). When that yields
// grid.y > 1, a staging buffer of 2 * stride * grid.y elements and a zeroed
// int buffer of grid.x elements are allocated and handed to the kernel;
// otherwise both pointers are nullptr.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
std::vector<at::Tensor> reduce_bn_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor sum_dy = at::empty({stride}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

  // A defined tensor must be returned even when there is no weight, hence
  // the zero-length placeholders.
  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (!weight.has_value()) {
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  } else {
    grad_weight = at::empty({stride}, weight.value().options());
    grad_bias = at::empty({stride}, weight.value().options());
  }

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  // Scratch tensors stay undefined unless several blocks share a column tile.
  const bool multi_block_y = grid.y > 1;
  at::Tensor staging_data;
  at::Tensor semaphores;
  if (multi_block_y) {
    staging_data = at::empty({2*stride*grid.y}, mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();

  const bool half_input_float_weight =
      input.scalar_type() == at::ScalarType::Half &&
      weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float;

  if (half_input_float_weight) {
    // Mixed precision: weight is known to be present on this path, so the
    // gradient buffers are always real allocations here.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = nullptr;
      int* semaphores_ptr = nullptr;
      if (multi_block_y) {
        staging_data_ptr = staging_data.DATA_PTR<accscalar_t>();
        semaphores_ptr = semaphores.DATA_PTR<int>();
      }
      accscalar_t* grad_weight_ptr = grad_weight.DATA_PTR<accscalar_t>();
      accscalar_t* grad_bias_ptr = grad_bias.DATA_PTR<accscalar_t>();
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          grad_weight_ptr,
          grad_bias_ptr,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  } else {
    // Uniform precision: weight (when given) must share the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = nullptr;
      int* semaphores_ptr = nullptr;
      if (multi_block_y) {
        staging_data_ptr = staging_data.DATA_PTR<accscalar_t>();
        semaphores_ptr = semaphores.DATA_PTR<int>();
      }
      scalar_t_0* grad_weight_ptr = weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL;
      scalar_t_0* grad_bias_ptr = weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL;
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          grad_weight_ptr,
          grad_bias_ptr,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }

  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}

// Host wrapper that enqueues batchnorm_backward_c_last_kernel on the current
// CUDA stream and returns the freshly allocated grad_input tensor (same shape
// and options as `input`).
//
// The input is viewed as (reduction_size, stride), where stride is the size
// of the last dimension and reduction_size is numel / stride.
//
// mean, inv_std, sum_dy, sum_dy_xmu - per-column tensors read with the
//           accumulation type of the input dtype.
// weight  - optional per-column scale; NULL is passed when absent.
// count   - read through an `int` data pointer; its element count is passed
//           to the kernel as world_size.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
at::Tensor batchnorm_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  // The dispatch name appears in the error raised for unsupported dtypes, so
  // it must identify this backward op rather than the forward one.
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  }

  return grad_input;
}

// Host wrapper that enqueues relu_backward_c_last_kernel on the current CUDA
// stream and returns the freshly allocated output tensor (same shape and
// options as `input`).
//
// The input is viewed as (reduction_size, stride), where stride is the size
// of the last dimension and reduction_size is numel / stride.
//
// z        - optional tensor read with the input's element type; NULL is
//            passed when absent.
// mean, inv_std - per-column statistics read with the accumulation type.
// weight, shift - optional per-column parameters; NULL is passed for
//            whichever one is absent.
//
// Two dtype combinations are handled: Half input with Float weight, and
// weight dtype equal to the input dtype (enforced with TORCH_CHECK).
at::Tensor relu_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {

  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor out = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  // The dispatch name appears in the error raised for unsupported dtypes, so
  // it must identify this op rather than the batch-norm forward pass.
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
          "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  }
  return out;
}
