//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//

//
// For modifications made inside namespace nvidia (authored by jdemouth):
//
// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>

#define ENABLE_NVIDIA_OPTIMIZATIONS

#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
namespace nvidia {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of threads in a warp; the kernels below reduce across it with full-mask shuffles.
static constexpr int THREADS_PER_WARP{32};

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).

////////////////////////////////////////////////////////////////////////////////////////////////////

// Ceiling division m / n for non-negative m and positive n.
static inline __host__ __device__ int div_up(int m, int n) {
  const int biased = m + (n - 1);
  return biased / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Smallest multiple of n that is >= m (non-negative m, positive n).
static inline __host__ __device__ int round_up(int m, int n) {
  const int chunks = div_up(m, n);
  return chunks * n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Arguments of the forward kernels (lmha_kernel / lmha_low_occupancy_kernel).
// All strides are in elements; the kernels assume the inner-most axis has stride 1.
template< typename T >
struct Lmha_params {

  // Output buffer, laid out as [B, H, L, M].
  T *out;

  // Queries, [B, H, L, E].
  const T *q;
  // Keys, [B, H, L, E].
  const T *k;
  // Values, [B, H, L, M].
  const T *v;

  // Problem sizes.
  int B;  // sequences in the batch
  int L;  // timesteps per sequence
  int H;  // heads
  int E;  // hidden size of Q and K
  int M;  // hidden size of V and of the output

  // Strides of the B, H and L axes of Q.
  int q_stride_B;
  int q_stride_H;
  int q_stride_L;
  // Strides of the B, H and L axes of K.
  int k_stride_B;
  int k_stride_H;
  int k_stride_L;
  // Strides of the B, H and L axes of V.
  int v_stride_B;
  int v_stride_H;
  int v_stride_L;
  // Strides of the B, H and L axes of the output.
  int o_stride_B;
  int o_stride_H;
  int o_stride_L;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Causal linear-attention forward kernel used when B*H is small ("low occupancy").
//
// Launch layout (set up by the launcher lmha_low_occupancy_):
//   grid  = (M, H, B): blockIdx.x is the column vi of V/out, blockIdx.y the head and
//           blockIdx.z the sequence.
//   block = WARPS * THREADS_PER_WARP threads; all shared memory is static.
//
// For its (sequence, head, vi) the block writes out[t, vi] for every timestep t:
//   out[t, vi] = sum_e Q[t, e] * KV_t[e],  KV_t[e] = sum_{s <= t} K[s, e] * V[s, vi]
// When GO_BACKWARD is true the walk starts at t = L-1 and the sum runs over s >= t instead.
//
// Work split: one iteration of the main loop covers COLS_PER_ITER = WARPS * COLS_PER_THREAD
// timesteps ("a chunk"). Warp w owns COLS_PER_THREAD consecutive timesteps of the chunk and
// each lane owns ROWS_PER_THREAD entries of the E dimension (e = lane + ri * THREADS_PER_WARP).
// The prefix sum over timesteps is done in registers, then across warps through smem_reds;
// running_kv carries the total of all previous chunks.
//
// Requirements: E % THREADS_PER_WARP == 0 (static_assert) and params.E <= E; entries with
// e >= params.E are read as zeros.
template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
__global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
void lmha_low_occupancy_kernel(Lmha_params<float> params) {

  // The number of threads per block.
  constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
  // The number of E-dimension entries ("rows") owned by each lane.
  constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
  // The number of timesteps ("columns") processed by one iteration of the main loop.
  constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;

  // Make sure E is a multiple of the warp size.
  static_assert(E % THREADS_PER_WARP == 0, "");

  // Shared memory to store the V values / the outputs of the current chunk of timesteps.
  __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
  // Shared memory buffer to perform the inter-warp reductions (one E-sized row per warp).
  __shared__ float smem_reds[E * WARPS];

  // The sequence processed by that block.
  const int bi = blockIdx.z;
  // The head processed by that block.
  const int hi = blockIdx.y;
  // The hidden cell in the V/output buffers.
  const int vi = blockIdx.x;

  // The linear index of the thread.
  const int tidx = threadIdx.x;

  // Decompose the block in warp/lane.
  const int warp = tidx / THREADS_PER_WARP;
  const int lane = tidx % THREADS_PER_WARP;

  // The base offset loaded by the thread in Q and K.
  int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
  int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;

  // If we walk backward, account for the extra offset.
  if( GO_BACKWARD ) {
    offset_q += (params.L-1)*params.q_stride_L;
    offset_k += (params.L-1)*params.k_stride_L;
  }

  // Position the warp at the beginning of the proper timestep (warp w starts
  // w*COLS_PER_THREAD timesteps into each chunk).
  if( GO_BACKWARD ) {
    offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
    offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
  } else {
    offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
    offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
  }

  // Determine the base pointers for Q and K.
  const float *ptr_q = &params.q[offset_q];
  const float *ptr_k = &params.k[offset_k];

  // Is a given row valid? Rows at or beyond params.E (padding up to E) are read as zeros.
  int valid_qk[ROWS_PER_THREAD];
  #pragma unroll
  for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
    valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
  }

  // The offset to the position loaded by the thread in V (and stored in O).
  int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
  int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;

  // If we walk backward, account for the extra offset.
  if( GO_BACKWARD ) {
    offset_v += (params.L-1)*params.v_stride_L;
    offset_o += (params.L-1)*params.o_stride_L;
  }

  // Thread tidx loads V and stores O for the timestep tidx positions into each chunk (only
  // the first COLS_PER_ITER threads actually do so, see valid_vo below).
  if( GO_BACKWARD ) {
    offset_v -= tidx*params.v_stride_L;
    offset_o -= tidx*params.o_stride_L;
  } else {
    offset_v += tidx*params.v_stride_L;
    offset_o += tidx*params.o_stride_L;
  }

  // Determine the base pointer for V.
  const float *ptr_v = &params.v[offset_v];
  // The output pointer.
  float *ptr_o = &params.out[offset_o];

  // The running KVs: sum of K[s, e] * V[s, vi] over all timesteps of the previous chunks.
  float running_kv[ROWS_PER_THREAD];
  #pragma unroll
  for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
    running_kv[ri] = 0.f;
  }

  // Iterate over the timesteps, one chunk of COLS_PER_ITER at a time. TODO: Use params.loop_count!!!
  for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {

    // Each thread loads a matrix of elements: ROWS_PER_THREAD rows of E for each of the
    // COLS_PER_THREAD timesteps owned by its warp.
    float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];

    // Trigger the memory loads for Q and K.
    #pragma unroll
    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
      #pragma unroll
      for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {

        // For Q/K, each warp loads from various timesteps.
        int ti = iter + warp*COLS_PER_THREAD;
        if( GO_BACKWARD ) {
          ti = params.L - 1 - ti;
        }

        // Is it a valid access? (row inside params.E and timestep inside [0, L)).
        int valid;
        if( GO_BACKWARD ) {
          valid = valid_qk[ri] && ti - ci >= 0;
        } else {
          valid = valid_qk[ri] && ti + ci < params.L;
        }

        // The extra offset to add.
        if( GO_BACKWARD ) {
          offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
          offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
        } else {
          offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
          offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
        }

        // Load Q/K if they are valid.
        q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
        k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
      }
    }

    // For the V tensor, we assign contiguous thread to different loads. So, ti is different.
    int ti = iter + tidx;
    if( GO_BACKWARD ) {
      ti = params.L - 1 - ti;
    }

    // Is it a valid access? Only the first COLS_PER_ITER threads handle V/O.
    int valid_vo = tidx < COLS_PER_ITER;
    if( GO_BACKWARD ) {
      valid_vo &= ti >= 0;
    } else {
      valid_vo &= ti < params.L;
    }

    // Trigger the loads for V.
    float ldg_v = valid_vo ? *ptr_v : 0.f;

    // Move the load pointers to the next chunk.
    if( GO_BACKWARD ) {
      ptr_q -= COLS_PER_ITER*params.q_stride_L;
      ptr_k -= COLS_PER_ITER*params.k_stride_L;
      ptr_v -= COLS_PER_ITER*params.v_stride_L;
    } else {
      ptr_q += COLS_PER_ITER*params.q_stride_L;
      ptr_k += COLS_PER_ITER*params.k_stride_L;
      ptr_v += COLS_PER_ITER*params.v_stride_L;
    }

    // Store to shared memory.
    if( tidx < COLS_PER_ITER ) {
      smem_v[tidx] = ldg_v;
    }

    // Make sure V is in shared memory.
    __syncthreads();

    // Read V from shared memory: the values of the timesteps owned by this warp.
    float v[COLS_PER_THREAD];
    #pragma unroll
    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
      v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
    }

    // Each thread computes local K*V products.
    float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
    #pragma unroll
    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
      #pragma unroll
      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
        kv[ri][ci] = 0.f;
      }
    }

    // Update the K*V^T product.
    #pragma unroll
    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
      #pragma unroll
      for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
        kv[ri][ci] += k[ri][ci] * v[ci];
      }
    }

    // We must perform the prefix sums within the thread-block. Start with the thread
    // (inclusive prefix over the COLS_PER_THREAD timesteps of the warp).
    #pragma unroll
    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
      #pragma unroll
      for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
        kv[ri][ci] += kv[ri][ci-1];
      }
    }

    // Store the per-warp totals (last column of the prefix) to shared memory, one row per warp.
    #pragma unroll
    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
      smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Each thread deals with one or more column(s) of the matrix: turn the per-warp totals
    // into an inclusive prefix over warps, in place (row w = sum of warps 0..w).
    constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
    #pragma unroll
    for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
      if( idx < E ) {
        float sum = smem_reds[idx];
        #pragma unroll
        for( int jj = 1; jj < WARPS; ++jj ) {
          smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
        }
      }
    }

    // Make sure the reductions are stored in shared memory.
    __syncthreads();

    // Each thread updates its partial products with the previous chunks (running_kv) and
    // the earlier warps of this chunk (row warp-1 of smem_reds).
    #pragma unroll
    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
      float sum = running_kv[ri];
      if( warp > 0 ) {
        sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
      }
      #pragma unroll
      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
        kv[ri][ci] += sum;
      }
    }

    // Compute the partial output values for that thread (its rows of the Q . KV dot product).
    float sum[COLS_PER_THREAD];
    #pragma unroll
    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
      sum[ci] = q[0][ci] * kv[0][ci];
      #pragma unroll
      for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
        sum[ci] += q[ri][ci] * kv[ri][ci];
      }
    }

    // Run the parallel reductions inside the warp (xor butterfly: every lane ends with the total).
    #pragma unroll
    for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
      #pragma unroll
      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
        sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
      }
    }

    // Store the final output to shared memory.
    if( lane == 0 ) {
      #pragma unroll
      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
        smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
      }
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Store the output (thread tidx writes the timestep it loaded V for).
    if( valid_vo ) {
      *ptr_o = smem_o[tidx];
    }

    // Each thread updates its running kv with the total of this chunk (last row of smem_reds).
    #pragma unroll
    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
      running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
    }

    // Move to next location.
    if( GO_BACKWARD ) {
      ptr_o -= COLS_PER_ITER*params.o_stride_L;
    } else {
      ptr_o += COLS_PER_ITER*params.o_stride_L;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches lmha_low_occupancy_kernel with WARPS warps per block on a (M, H, B) grid.
//
// Returns 0 when the kernel was launched, 1 when the configuration is not supported
// (grid too large) or the launch itself failed, e.g. because of an empty grid.
template< int E, bool GO_BACKWARD, int WARPS >
int lmha_low_occupancy_(const Lmha_params<float> &params) {

  // Make sure we are not going to launch an invalid grid (gridDim.y/z are limited to 65535).
  if( params.H > 65535 || params.B > 65535 ) {
    return 1;
  }

  // One block per (output column, head, sequence).
  dim3 grid;
  grid.x = params.M;
  grid.y = params.H;
  grid.z = params.B;

  // Launch on PyTorch's current stream so the kernel is ordered with the surrounding torch
  // operations (the legacy default stream does not synchronize with non-blocking streams).
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP, 0, stream>>>(params);

  // A failed launch must not be reported as success; cudaGetLastError also clears the
  // (non-sticky) launch error so it does not surface in an unrelated later call.
  return cudaGetLastError() == cudaSuccess ? 0 : 1;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Picks the number of warps per block for the low-occupancy forward kernel.
//
// `blocks` is the number of (sequence, head) pairs (lmha passes B*H); the kernel launches
// one block per (output column, head, sequence), i.e. params.M * blocks blocks. The fewer
// blocks there are, the wider each block is made (more warps = more timesteps per iteration).
template< int E, bool GO_BACKWARD >
int lmha_low_occupancy_(const Lmha_params<float> &params, int blocks) {
  const int work = params.M * blocks;

  // Plenty of blocks: 4 warps each.
  if( work >= 8*LOW_OCCUPANCY_THRESHOLD ) {
    return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
  }

  // Moderate number of blocks: 8 warps each.
  if( work >= 4*LOW_OCCUPANCY_THRESHOLD ) {
    return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
  }

  // Very few blocks: 16 warps each.
  return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of floats in one of the two shared-memory buffers used by lmha_kernel:
// Q and K (E floats each) followed by V and O (M floats each), where M is params.M padded
// to a multiple of 4 so every buffer keeps the float4 reads of Q/K 16-byte aligned.
template< int E, typename Params >
static inline __device__ __host__ int smem_buffer_elts_(const Params &params) {
  const int padded_m = round_up(params.M, 4);
  return 2*(E + padded_m);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Causal linear-attention forward kernel for the regular case (enough (sequence, head) blocks).
//
// Launch layout (set up by the launcher lmha_):
//   grid  = (H, B): blockIdx.x selects the head, blockIdx.y the sequence.
//   block = round_up(max(E, M*THREADS_PER_HEAD), 32) threads, M = round_up(params.M, 4).
//   dynamic shared memory = 2 * smem_buffer_elts_<E>(params) floats (double buffering); each
//   buffer is laid out as [ Q (E) | K (E) | V (M) | O (M) ].
//
// The block walks the L timesteps one at a time (backward from t = L-1 when GO_BACKWARD is
// true) and writes
//   out[t, m] = sum_e Q[t, e] * KV_t[e, m],  KV_t[e, m] = sum_{s <= t} K[s, e] * V[s, m]
// (s >= t when walking backward). Thread tidx owns output column vo = tidx / THREADS_PER_HEAD
// and keeps its share (FLOAT4s_PER_THREAD float4s) of KV[:, vo] in registers; the
// THREADS_PER_HEAD partial dot products are combined with an xor-shuffle. The global loads
// for the next timestep are issued into registers before the current one is computed.
//
// Requirements: E % 4 == 0 (static_assert), params.E <= E; the shuffle reduction assumes
// THREADS_PER_HEAD is a power of two no larger than 32 (lmha uses 1, 2 and 4).
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
__global__
void lmha_kernel(Lmha_params<float> params) {

  // Make sure E is a multiple of 4.
  static_assert(E % 4 == 0, "");

  // The number of floats per buffer (2 buffers for double-buffering).
  const int smem_buffer_elts = smem_buffer_elts_<E>(params);
  // The M dimension for shared memory.
  const int M = round_up(params.M, 4);

  // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
  extern __shared__ float smem_[];

  // The various shared memory buffers (offsets inside buffer 0; adding
  // smem_curr*smem_buffer_elts selects the active buffer).
  float *smem_q = &smem_[0*E];
  float *smem_k = &smem_[1*E];
  float *smem_v = &smem_[2*E];
  float *smem_o = &smem_[2*E + M];

  // The index of the shared memory buffer (for double-buffering).
  int smem_curr = 0;

  // The sequence processed by that block.
  const int bi = blockIdx.y;
  // The head processed by that block.
  const int hi = blockIdx.x;

  // The linear index of the thread.
  const int tidx = threadIdx.x;

  // The offset to the position loaded by the thread in Q.
  int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
  // The offset to the position loaded by the thread in K.
  int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;

  // If we walk backward, account for the extra offset.
  if( GO_BACKWARD ) {
    offset_q += (params.L-1)*params.q_stride_L;
    offset_k += (params.L-1)*params.k_stride_L;
  }

  // Determine the base pointers for Q and K.
  const float *ptr_q = &params.q[offset_q];
  const float *ptr_k = &params.k[offset_k];

  // The offset to the position loaded by the thread in V and O.
  int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
  int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;

  // If we walk backward, account for the extra offset.
  if( GO_BACKWARD ) {
    offset_v += (params.L-1)*params.v_stride_L;
    offset_o += (params.L-1)*params.o_stride_L;
  }

  // Determine the base pointers for V.
  const float *ptr_v = &params.v[offset_v];

  // Is it an active Q/K thread?
  const int active_qk = tidx < params.E;

  // Trigger the memory loads for Q and K (first timestep; note this happens before the
  // time loop, so it assumes params.L >= 1).
  float ldg_q = 0.f, ldg_k = 0.f;
  if( active_qk ) {
    ldg_q = *ptr_q;
    ldg_k = *ptr_k;
  }

  // Is it an active V thread?
  const int active_v = tidx < params.M;

  // Trigger the memory loads for V.
  float ldg_v = 0.f;
  if( active_v ) {
    ldg_v = *ptr_v;
  }

  // Move the load pointers.
  if( GO_BACKWARD ) {
    ptr_q -= params.q_stride_L;
    ptr_k -= params.k_stride_L;
    ptr_v -= params.v_stride_L;
  } else {
    ptr_q += params.q_stride_L;
    ptr_k += params.k_stride_L;
    ptr_v += params.v_stride_L;
  }

  // The number of FLOAT4s per head.
  constexpr int FLOAT4s_PER_HEAD = E / 4;
  // The number of FLOAT4s per thread.
  constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;

  // The storage for the K*V^T values (this thread's rows of column vo).
  float4 kv[FLOAT4s_PER_THREAD];
  #pragma unroll
  for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
    kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
  }

  // The output pointer.
  float *out_ptr = &params.out[offset_o];

  // Store to shared memory Q and K (threads past params.E store zeros).
  if( tidx < E ) {
    smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
    smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
  }

  // Store to shared memory V. Threads past params.M store zeros (ldg_v stays 0).
  if( tidx < M ) {
    smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
  }

  // vo: the output column of the thread; vi: its index among the THREADS_PER_HEAD threads
  // sharing that column.
  int vo = tidx / THREADS_PER_HEAD;
  int vi = tidx % THREADS_PER_HEAD;

  // Iterate over the timesteps.
  for( int ti = 0; ti < params.L; ++ti ) {

    // Is it the last iteration?
    int is_last = ti == params.L - 1;

    // Trigger the next loads for Q and K.
    if( !is_last && active_qk ) {
      ldg_q = *ptr_q;
      ldg_k = *ptr_k;
    }

    // Trigger the next loads for V.
    if( !is_last && active_v ) {
      ldg_v = *ptr_v;
    }

    // Move the load pointers.
    if( GO_BACKWARD ) {
      ptr_q -= params.q_stride_L;
      ptr_k -= params.k_stride_L;
      ptr_v -= params.v_stride_L;
    } else {
      ptr_q += params.q_stride_L;
      ptr_k += params.k_stride_L;
      ptr_v += params.v_stride_L;
    }

    // Make sure the data is in shared memory. This barrier also guarantees that every thread
    // finished reading the other buffer during the previous step before it is refilled below.
    __syncthreads();

    // Each thread loads 4 values from K.
    float4 k[FLOAT4s_PER_THREAD];
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
      k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
    }

    // Each thread loads a single V value.
    float v = 0.f;
    if( vo < params.M ) {
      v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
    }

    // Update the K*V^T product.
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      kv[ii].x += k[ii].x * v;
      kv[ii].y += k[ii].y * v;
      kv[ii].z += k[ii].z * v;
      kv[ii].w += k[ii].w * v;
    }

    // Load the Q values from shared memory.
    float4 q[FLOAT4s_PER_THREAD];
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
      q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
    }

    // Compute the partial output value for that thread.
    float sum = 0.f;
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      sum += q[ii].x * kv[ii].x;
      sum += q[ii].y * kv[ii].y;
      sum += q[ii].z * kv[ii].z;
      sum += q[ii].w * kv[ii].w;
    }

    // Finalize the computation of the sum (if we have more than 1 thread per head).
    if( THREADS_PER_HEAD > 1 ) {

      // Finalize the sum for each output column (xor butterfly within THREADS_PER_HEAD lanes).
      #pragma unroll
      for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
      }

      // Store to shared memory.
      if( vo < M && vi == 0 ) {
        smem_o[smem_curr*smem_buffer_elts + vo] = sum;
      }

      // Make sure the data is in shared memory.
      __syncthreads();

      // Active threads read the data to store (thread tidx now stores column tidx).
      if( active_v ) {
        sum = smem_o[smem_curr*smem_buffer_elts + tidx];
      }

    } // THREADS_PER_HEAD > 1.

    // Store the output. Only threads with tidx < params.M own an output element.
    if( active_v ) {
      *out_ptr = sum;
    }

    // Move to next location.
    if( GO_BACKWARD ) {
      out_ptr -= params.o_stride_L;
    } else {
      out_ptr += params.o_stride_L;
    }

    // Move the shared memory buffer.
    smem_curr = (smem_curr + 1) % 2;

    // Store to shared memory for Q and K (next timestep).
    if( !is_last && tidx < E ) {
      smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
      smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
    }

    // Store to shared memory for V (next timestep).
    if( !is_last && tidx < M ) {
      smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches lmha_kernel on a (H, B) grid with enough threads to cover both the E elements of
// Q/K and the M*THREADS_PER_HEAD (M = params.M rounded up to 4) output lanes.
//
// Returns 0 when the kernel was launched, 1 when the configuration is not supported
// (block too wide, grid too large) or the launch itself failed (e.g. not enough resources).
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
int lmha_(const Lmha_params<float> &params) {
  // The M dimension rounded up to 4.
  int M = round_up(params.M, 4);

  // The number of threads in the block.
  int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
  if( block > 512 || params.B > 65535 ) {
    return 1;
  }

  // Prepare the kernel: two double-buffered shared-memory buffers.
  dim3 grid(params.H, params.B);
  size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);

  // Launch on PyTorch's current stream so the kernel is ordered with the surrounding torch
  // operations (the legacy default stream does not synchronize with non-blocking streams).
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem, stream>>>(params);

  // A failed launch must not be reported as success; cudaGetLastError also clears the
  // (non-sticky) launch error so it does not surface in an unrelated later call.
  return cudaGetLastError() == cudaSuccess ? 0 : 1;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Forward dispatcher: selects a kernel instantiation from the number of (sequence, head)
// blocks and the hidden size params.E.
//
// Returns 0 when the output has been (or needs no) computing, non-zero when no kernel
// supports this configuration (e.g. params.E > 256) or a launch failed.
template< bool GO_BACKWARD >
int lmha(const Lmha_params<float> &params) {

  // Nothing to compute when the output [B, H, L, M] is empty. Bail out before dispatching:
  // B*H == 0 (or M == 0 on the low-occupancy path) would produce an empty grid, which is an
  // invalid launch configuration, and with L == 0 lmha_kernel would still dereference the
  // Q/K/V pointers once before entering its time loop.
  if( params.B == 0 || params.H == 0 || params.L == 0 || params.M == 0 ) {
    return 0;
  }

  // The regular kernel uses one block per (sequence, head); with too few of them, use the
  // low-occupancy kernels, which also split the M dimension across blocks.
  int blocks = params.B * params.H;

  int res = 1;
  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
    if( params.E <= 32 ) {
      res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
    } else if( params.E <= 64 ) {
      res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
    } else if( params.E <= 128 ) {
      res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
    } else if( params.E <= 256 ) {
      res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
    }
  } else {
    if( params.E <= 32 ) {
      res = lmha_< 32, 1, GO_BACKWARD>(params);
    } else if( params.E <= 48 ) {
      res = lmha_< 48, 1, GO_BACKWARD>(params);
    } else if( params.E <= 64 ) {
      res = lmha_< 64, 1, GO_BACKWARD>(params);
    } else if( params.E <= 128 ) {
      res = lmha_<128, 2, GO_BACKWARD>(params);
    } else if( params.E <= 256 ) {
      res = lmha_<256, 4, GO_BACKWARD>(params);
    }
  }
  return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fills the forward-kernel parameters from the four 4-D tensors q, k, v (inputs) and o
// (output): raw data pointers, element strides of the first three axes (narrowed to int)
// and the problem sizes. The inner-most axis stride is not recorded; the kernels index it
// with unit stride.
template< typename T >
inline void set_params(Lmha_params<T> &params,
                       const torch::Tensor q,
                       const torch::Tensor k,
                       const torch::Tensor v,
                       torch::Tensor       o) {

  // Raw pointers.
  params.out = o.data_ptr<T>();
  params.q   = q.data_ptr<T>();
  params.k   = k.data_ptr<T>();
  params.v   = v.data_ptr<T>();

  // Q strides (axes 0, 1, 2 = B, H, L).
  params.q_stride_B = static_cast<int>(q.stride(0));
  params.q_stride_H = static_cast<int>(q.stride(1));
  params.q_stride_L = static_cast<int>(q.stride(2));
  // K strides.
  params.k_stride_B = static_cast<int>(k.stride(0));
  params.k_stride_H = static_cast<int>(k.stride(1));
  params.k_stride_L = static_cast<int>(k.stride(2));
  // V strides.
  params.v_stride_B = static_cast<int>(v.stride(0));
  params.v_stride_H = static_cast<int>(v.stride(1));
  params.v_stride_L = static_cast<int>(v.stride(2));
  // Output strides.
  params.o_stride_B = static_cast<int>(o.stride(0));
  params.o_stride_H = static_cast<int>(o.stride(1));
  params.o_stride_L = static_cast<int>(o.stride(2));

  // Sizes: q gives B, H, L and E; the last axis of v gives M.
  const int batch  = static_cast<int>(q.size(0));
  const int heads  = static_cast<int>(q.size(1));
  const int length = static_cast<int>(q.size(2));
  const int qk_dim = static_cast<int>(q.size(3));
  const int v_dim  = static_cast<int>(v.size(3));

  params.B = batch;
  params.L = length;
  params.H = heads;
  params.E = qk_dim;
  params.M = v_dim;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Forward entry point: product = causal linear attention of (queries, keys, values).
//
// Returns the result of lmha<false>: 0 on success, non-zero when this optimized path does
// not support the configuration. Tensors whose inner-most axis is not packed are reported
// as unsupported (non-zero) too, because the kernels index that axis with unit stride.
int lmha_fwd(const torch::Tensor queries,
             const torch::Tensor keys,
             const torch::Tensor values,
             torch::Tensor product) {

  // Make sure that we are using the correct GPU device
  torch::DeviceGuard _guard(queries.device());

  // Make sure the inner-most dimension of the tensors is packed. This must be a runtime
  // check: assert() is compiled out under NDEBUG and strided inputs would then be read
  // (and the output written) at the wrong positions.
  if( queries.stride(3) != 1 || keys.stride(3) != 1 ||
      values.stride(3) != 1 || product.stride(3) != 1 ) {
    return 1;
  }

  // The structure of params.
  Lmha_params<float> params;
  set_params(params, queries, keys, values, product);

  // Launch the kernel.
  return lmha<false>(params);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Arguments of lmha_bwd_kernel. Q, K and the K gradient have hidden size E; G, V and the V
// gradient have hidden size M. Strides are in elements; the inner-most axis is assumed packed.
template< typename T >
struct Lmha_bwd_params {

  // Gradient w.r.t. K (output), [B, H, L, E].
  T *out_k;
  // Gradient w.r.t. V (output), [B, H, L, M].
  T *out_v;

  // Queries, [B, H, L, E].
  const T *q;
  // Keys, [B, H, L, E].
  const T *k;
  // Values, [B, H, L, M].
  const T *v;
  // Incoming gradient of the forward output, [B, H, L, M].
  const T *g;

  // Problem sizes.
  int B;  // sequences in the batch
  int L;  // timesteps per sequence
  int H;  // heads
  int M;  // hidden size of V / G
  int E;  // hidden size of Q / K

  // Input strides (B, L and H axes).
  int q_stride_B;
  int q_stride_L;
  int q_stride_H;
  int k_stride_B;
  int k_stride_L;
  int k_stride_H;
  int v_stride_B;
  int v_stride_L;
  int v_stride_H;
  int g_stride_B;
  int g_stride_L;
  int g_stride_H;

  // Output strides (B, L and H axes).
  int out_k_stride_B;
  int out_k_stride_L;
  int out_k_stride_H;
  int out_v_stride_B;
  int out_v_stride_L;
  int out_v_stride_H;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward kernel for causal linear attention: writes the gradients w.r.t. K and V.
//
// Launch layout (set up by the launcher lmha_bwd_):
//   grid  = (H, B): blockIdx.x selects the head, blockIdx.y the sequence.
//   block = 2 * D * THREADS_PER_HEAD threads, split in two halves by so = tidx / (D*TPH):
//     so == 0 reads Q/K and writes out_k (hidden size params.E),
//     so == 1 reads G/V and writes out_v (hidden size params.M).
//
// Walking the timesteps backward from t = L-1, thread (so, oo = si / THREADS_PER_HEAD)
// accumulates in registers, over all t' >= t,
//   so == 0: gq[oo][j] += Q[t', oo] * G[t', j]  and writes out_k[t, oo] = sum_j gq[oo][j] * V[t, j]
//   so == 1: gq[oo][j] += G[t', oo] * Q[t', j]  and writes out_v[t, oo] = sum_j gq[oo][j] * K[t, j]
// which are the K and V gradients of out[t] = sum_{s <= t} (Q[t] . K[s]) V[s] for an
// incoming output gradient G. Each half reads the other half's values from double-buffered
// shared memory and the THREADS_PER_HEAD partial sums are combined with an xor-shuffle.
//
// Requirements: D % 4 == 0 (static_assert) and max(params.E, params.M) <= D; entries beyond
// params.E / params.M are loaded as zeros.
template< int D, int THREADS_PER_HEAD >
__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
void lmha_bwd_kernel(Lmha_bwd_params<float> params) {

  // Make sure D is a multiple of 4.
  static_assert(D % 4 == 0, "");

  // The shared memory buffers. The arrays are read through float4 pointers below, so the
  // struct must be 16-byte aligned (a plain struct of floats only guarantees 4 bytes). Its
  // size (24*D bytes) is already a multiple of 16 since D % 4 == 0, so the layout is unchanged.
  __shared__ struct __align__(16) Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];

  // The index of the shared memory buffer (for double-buffering).
  int smem_curr = 0;

  // The sequence processed by that block.
  const int bi = blockIdx.y;
  // The head processed by that block.
  const int hi = blockIdx.x;

  // The linear index of the thread.
  const int tidx = threadIdx.x;

  // Split the threads into two slices (so = slice, si = index inside the slice).
  int so = tidx / (D*THREADS_PER_HEAD);
  int si = tidx % (D*THREADS_PER_HEAD);

  // The strides for B/L/H for the Q/G tensors.
  int qg_stride_B, qg_stride_L, qg_stride_H;
  if( so == 0 ) {
    qg_stride_B = params.q_stride_B;
    qg_stride_L = params.q_stride_L;
    qg_stride_H = params.q_stride_H;
  } else {
    qg_stride_B = params.g_stride_B;
    qg_stride_L = params.g_stride_L;
    qg_stride_H = params.g_stride_H;
  }

  // The strides for B/L/H for the K/V tensors.
  int kv_stride_B, kv_stride_L, kv_stride_H;
  if( so == 0 ) {
    kv_stride_B = params.k_stride_B;
    kv_stride_L = params.k_stride_L;
    kv_stride_H = params.k_stride_H;
  } else {
    kv_stride_B = params.v_stride_B;
    kv_stride_L = params.v_stride_L;
    kv_stride_H = params.v_stride_H;
  }

  // The hidden size.
  int hidden_size_per_head = 0;
  if( so == 0 ) {
    hidden_size_per_head = params.E;
  } else {
    hidden_size_per_head = params.M;
  }

  // Where to start reading from.
  int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
  int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;

  // We walk backward, account for the extra offset.
  offset_qg += (params.L-1)*qg_stride_L;
  offset_kv += (params.L-1)*kv_stride_L;

  // Determine the base pointers for Q, K, V and G.
  const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
  const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];

  // Is it an active thread?
  const int active = si < hidden_size_per_head;

  // Trigger the memory loads for Q, K, V and G (timestep L-1; this happens before the time
  // loop, so it assumes params.L >= 1).
  float ldg_qg = 0.f, ldg_kv = 0.f;
  if( active ) {
    ldg_qg = *ptr_qg;
    ldg_kv = *ptr_kv;
  }

  // Move the load pointers (backward).
  ptr_qg -= qg_stride_L;
  ptr_kv -= kv_stride_L;

  // The number of FLOAT4s per head.
  constexpr int FLOAT4s_PER_HEAD = D / 4;
  // The number of FLOAT4s per thread.
  constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;

  // The storage for the G*Q^T or Q^T*G values.
  float4 gq[FLOAT4s_PER_THREAD];
  #pragma unroll
  for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
    gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
  }

  // The strides for B/L/H for the output K/V gradient tensors.
  int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
  if( so == 0 ) {
    out_kv_stride_B = params.out_k_stride_B;
    out_kv_stride_L = params.out_k_stride_L;
    out_kv_stride_H = params.out_k_stride_H;
  } else {
    out_kv_stride_B = params.out_v_stride_B;
    out_kv_stride_L = params.out_v_stride_L;
    out_kv_stride_H = params.out_v_stride_H;
  }

  // Where to start writing to.
  int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;

  // We walk backward, account for the extra offset.
  offset_out_kv += (params.L-1)*out_kv_stride_L;

  // The output pointer.
  float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];

  // Store to shared memory (inactive threads store the zeros they loaded).
  if( si < D ) {
    smem_[smem_curr].qg[so*D + si] = ldg_qg;
    smem_[smem_curr].kv[so*D + si] = ldg_kv;
  }

  // The position of the thread in the output dimension (oo) and the first element of the
  // other half's vectors handled by the thread (oi).
  int oo = si / THREADS_PER_HEAD % D;
  int oi = si % THREADS_PER_HEAD * 4;

  // Iterate over the timesteps.
  for( int ti = 0; ti < params.L; ++ti ) {

    // Is it the last iteration?
    int is_last = ti == params.L - 1;

    // Trigger the next loads.
    if( !is_last && active ) {
      ldg_qg = *ptr_qg;
      ldg_kv = *ptr_kv;
    }

    // Move the load pointers.
    ptr_qg -= qg_stride_L;
    ptr_kv -= kv_stride_L;

    // Make sure the data is in shared memory.
    __syncthreads();

    // Each thread loads 4 values from G or Q (the other half's vector).
    float4 g[FLOAT4s_PER_THREAD];
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
      g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
    }

    // Each thread loads a single from Q or G value (its own half, element oo).
    float q = smem_[smem_curr].qg[so*D + oo];

    // Update the G*Q^T or Q*G^T product.
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      gq[ii].x += g[ii].x * q;
      gq[ii].y += g[ii].y * q;
      gq[ii].z += g[ii].z * q;
      gq[ii].w += g[ii].w * q;
    }

    // Load the V or K values from shared memory (the other half's vector).
    float4 v[FLOAT4s_PER_THREAD];
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
      v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
    }

    // Compute the partial output value for that thread.
    float sum = 0.f;
    #pragma unroll
    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
      sum += v[ii].x * gq[ii].x;
      sum += v[ii].y * gq[ii].y;
      sum += v[ii].z * gq[ii].z;
      sum += v[ii].w * gq[ii].w;
    }

    // Finalize the computation of the sum (if we have more than 1 thread per head).
    if( THREADS_PER_HEAD > 1 ) {

      // Finalize the sum for each output element (xor butterfly within THREADS_PER_HEAD lanes).
      #pragma unroll
      for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
      }

      // Store to shared memory.
      if( oi == 0 ) {
        smem_[smem_curr].out_kv[so*D + oo] = sum;
      }

      // Make sure the data is in shared memory.
      __syncthreads();

      // Active threads read the data to store (thread si now stores element si).
      if( si < hidden_size_per_head ) {
        sum = smem_[smem_curr].out_kv[so*D + si];
      }

    } // THREADS_PER_HEAD > 1.

    // Store the output. Only threads with si < hidden_size_per_head own an output element.
    if( si < hidden_size_per_head ) {
      *ptr_out_kv = sum;
    }

    // Move to next location.
    ptr_out_kv -= out_kv_stride_L;

    // Move the shared memory buffer.
    smem_curr = (smem_curr + 1) % 2;

    // Store to shared memory the Q/G and K/V values of the next timestep.
    if( !is_last && si < D ) {
      smem_[smem_curr].qg[so*D + si] = ldg_qg;
      smem_[smem_curr].kv[so*D + si] = ldg_kv;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int D, int THREADS_PER_HEAD >
int lmha_bwd_(const Lmha_bwd_params<float> &params) {
  // Launch the fused K/V gradient kernel specialized for a head size of D.
  //
  // Grid:  one CTA per (head, batch) pair -- heads on grid.x, batches on grid.y.
  // Block: 2 * D * THREADS_PER_HEAD threads.
  //
  // Returns 1 without launching when the block would have 1024 or more threads, or when
  // B exceeds the 65535 limit of grid.y. Otherwise returns 0 right after the launch
  // statement (launch errors are not checked here).
  //
  // NOTE(review): the <256, 4> instantiation needs 2048 threads and therefore always
  // returns 1; `>= 1024` also rejects an exactly-1024-thread block, which CUDA allows.
  constexpr int kThreadsPerCta = 2 * D * THREADS_PER_HEAD;
  const bool block_too_large = kThreadsPerCta >= 1024;
  const bool grid_y_too_large = params.B > 65535;
  if( block_too_large || grid_y_too_large ) {
    return 1;
  }
  lmha_bwd_kernel<D, THREADS_PER_HEAD><<<dim3(params.H, params.B), kThreadsPerCta>>>(params);
  return 0;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

int lmha_bwd(const Lmha_bwd_params<float> &params) {
  // Dispatch the fused K/V gradient kernel on the larger of the key (E) and value (M)
  // head sizes. Returns 0 when a kernel was launched and non-zero when the caller must
  // use another path.

  // Too few (batch, head) pairs: this path is not used at low occupancy.
  if( params.B * params.H < LOW_OCCUPANCY_THRESHOLD ) {
    return 1;
  }

  // Smallest specialization whose head size covers both E and M.
  const int widest = params.E > params.M ? params.E : params.M;
  if( widest <= 32 ) {
    return lmha_bwd_< 32, 1>(params);
  }
  if( widest <= 64 ) {
    return lmha_bwd_< 64, 1>(params);
  }
  if( widest <= 128 ) {
    return lmha_bwd_<128, 2>(params);
  }
  if( widest <= 256 ) {
    return lmha_bwd_<256, 4>(params);
  }

  // No specialization is wide enough.
  return 1;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

int lmha_bwd(const torch::Tensor queries,
             const torch::Tensor keys,
             const torch::Tensor values,
             const torch::Tensor grad_out,
             torch::Tensor grad_queries,
             torch::Tensor grad_keys,
             torch::Tensor grad_values) {
  // Backward pass of causal linear attention on the NVIDIA kernels.
  //
  // 1. grad_queries is computed with lmha<false>.
  // 2. grad_keys and grad_values are computed together by the fused lmha_bwd kernel.
  //    If that kernel declines, they are computed with two separate lmha<true> calls.
  //
  // Returns 0 on success and non-zero when this path cannot handle the inputs.
  // A non-zero return can happen AFTER step 1 (and after the grad_keys call of the
  // separate-kernel branch) already wrote its output.

  // Make sure that we are using the correct GPU device
  torch::DeviceGuard _guard(queries.device());

  // The kernels address the inner-most dimension as densely packed. An `assert` would
  // vanish under NDEBUG and let a strided tensor be read silently as if it were packed.
  // Decline instead, so the caller can use the stride-aware generic kernels.
  auto inner_dim_is_packed = [](const torch::Tensor &t) { return t.stride(3) == 1; };
  if( !inner_dim_is_packed(queries)      ||
      !inner_dim_is_packed(keys)         ||
      !inner_dim_is_packed(values)       ||
      !inner_dim_is_packed(grad_out)     ||
      !inner_dim_is_packed(grad_queries) ||
      !inner_dim_is_packed(grad_keys)    ||
      !inner_dim_is_packed(grad_values) ) {
    return 1;
  }

  // Extract the dimensions.
  int N = queries.size(0);
  int H = queries.size(1);
  int L = queries.size(2);
  int E = queries.size(3);
  int M = values.size (3);

  // Gradient on Q.

  // The structure of params.
  Lmha_params<float> params;
  set_params(params, grad_out, values, keys, grad_queries);

  // Launch the kernel.
  int res = lmha<false>(params);
  if( res ) {
    return res;
  }

  // Gradient on K and V together.

  Lmha_bwd_params<float> bwd_params;
  bwd_params.out_k = grad_keys.data_ptr<float>();
  bwd_params.out_v = grad_values.data_ptr<float>();
  bwd_params.q = queries.data_ptr<float>();
  bwd_params.k = keys.data_ptr<float>();
  bwd_params.v = values.data_ptr<float>();
  bwd_params.g = grad_out.data_ptr<float>();

  bwd_params.B = N;
  bwd_params.L = L;
  bwd_params.H = H;
  bwd_params.E = E;
  bwd_params.M = M;

  bwd_params.q_stride_B = queries.stride(0);
  bwd_params.q_stride_H = queries.stride(1);
  bwd_params.q_stride_L = queries.stride(2);
  bwd_params.k_stride_B = keys.stride(0);
  bwd_params.k_stride_H = keys.stride(1);
  bwd_params.k_stride_L = keys.stride(2);
  bwd_params.v_stride_B = values.stride(0);
  bwd_params.v_stride_H = values.stride(1);
  bwd_params.v_stride_L = values.stride(2);
  bwd_params.g_stride_B = grad_out.stride(0);
  bwd_params.g_stride_H = grad_out.stride(1);
  bwd_params.g_stride_L = grad_out.stride(2);

  bwd_params.out_k_stride_B = grad_keys.stride(0);
  bwd_params.out_k_stride_H = grad_keys.stride(1);
  bwd_params.out_k_stride_L = grad_keys.stride(2);
  bwd_params.out_v_stride_B = grad_values.stride(0);
  bwd_params.out_v_stride_H = grad_values.stride(1);
  bwd_params.out_v_stride_L = grad_values.stride(2);

  // Try to run the fused kernel.
  int fallback = lmha_bwd(bwd_params);

  // If it failed, fallback on separate kernels for K and V.
  if( fallback ) {

    // Gradient on K.

    // Launch the kernel.
    set_params(params, values, grad_out, queries, grad_keys);
    res = lmha<true>(params);
    if( res ) {
      return res;
    }

    // Gradient on V.

    // Launch the kernel.
    set_params(params, keys, queries, grad_out, grad_values);
    return lmha<true>(params);
  }

  // It worked...
  return 0;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace nvidia
#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS

////////////////////////////////////////////////////////////////////////////////////////////////////

// Rank-4 float tensor accessor with 32-bit indexing, used by the generic kernels below.
using float_accessor = torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>;

__device__ void get_result(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    float_accessor kv,
    float_accessor result,
    const int n,
    const int h,
    const int e,
    const int m,
    const int L
) {
    // Walks all L time steps for the state element kv[n][h][e][m]:
    //   kv[e][m]       += keys[l][e] * values[l][m]
    //   result[l][m]   += queries[l][e] * kv[e][m]   (atomically, summed over e by callers)
    // NOTE(review): not called by any code in the visible part of this file.
    for (int step = 0; step < L; ++step) {
        kv[n][h][e][m] += keys[n][h][step][e] * values[n][h][step][m];
        __syncthreads();
        const float contribution = queries[n][h][step][e] * kv[n][h][e][m];
        atomicAdd(&result[n][h][step][m], contribution);
    }
}


// Generic causal linear-attention forward kernel (used when the NVIDIA path declines).
//
// Processes one chunk of at most T time steps starting at l_offset. For every step
// l = l_offset + t it updates the running state kv[e][m] += K[l][e] * V[l][m], then adds
// sum_e Q[l][e] * kv[e][m] into result[l][m].
//
// Launch layout (set up by causal_dot_product_):
//   grid.x   = N * H * blocks_per_sequence. Each block owns E_per_block rows of the
//              E x M state of one (n, h) pair.
//   blockDim = E_per_block * M threads, one per (e, m) state element.
//   dynamic shared memory = ((E_per_block + 1) * M + T * (M + 2 * E_per_block)) floats.
// `kv` carries the state from one chunk to the next. `result` is accumulated with
// atomicAdd, so it must be zero-initialized by the caller.
__global__ void causal_dot_product_kernel(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    float_accessor kv,
    float_accessor result,
    const int N,
    const int H,
    const int L,
    const int E,
    const int M,
    const int E_per_block,
    const int blocks_per_sequence,
    const int T,
    const int l_offset
) {
    const int sequence_index = blockIdx.x / blocks_per_sequence;
    int n = sequence_index / H;
    int h = sequence_index % H;

    int e_local = threadIdx.x / M;
    int e_start = ((blockIdx.x % blocks_per_sequence) * E_per_block);
    int e = e_start + e_local;
    int m = threadIdx.x % M;

    // When E is not a multiple of E_per_block, rows past E in the last block of a sequence
    // own no state element. Such threads must stay alive (rather than return early), so
    // that every thread of the block reaches each __syncthreads() below.
    const bool active = (n < N) && (e < E);

    // Shared memory layout: kv slice, per-m partial results, then this chunk's values,
    // keys and queries.
    const int shared_kv_size = E_per_block * M;
    extern __shared__ float shared_mem[];
    float* shared_kv = shared_mem;
    float* shared_results = shared_mem + shared_kv_size;
    float* shared_values = shared_results + M;
    float* shared_keys = shared_values + M*T;
    float* shared_queries = shared_keys + E_per_block*T;

    if (threadIdx.x < M) {
        shared_results[threadIdx.x] = 0.0f;
    }

    // Number of valid steps in this chunk (the last chunk may be shorter than T).
    int t_end = (T + l_offset) <= L ? T : L - l_offset;
    for (int i = threadIdx.x; i < (t_end*M); i += blockDim.x)
    {
        int t = int(i / M) + l_offset;
        int d = i % M;
        shared_values[i] = values[n][h][t][d];
    }
    for (int i = threadIdx.x; i < (t_end*E_per_block); i += blockDim.x)
    {
        int t = int(i / E_per_block) + l_offset;
        int d = (i % E_per_block) + e_start;
        if (d < E) {
            shared_keys[i] = keys[n][h][t][d];
            shared_queries[i] = queries[n][h][t][d];
        }
    }
    __syncthreads();

    if (active) {
        shared_kv[threadIdx.x] = kv[n][h][e][m];
    }
    for (int t=0; t<t_end; t++) {
        int l = t + l_offset;
        if (active) {
            shared_kv[e_local*M + m] += shared_keys[t*E_per_block + e_local] * shared_values[t*M + m];
        }
        // Also orders the previous step's reset of shared_results before this step's adds.
        __syncthreads();
        if (active) {
            float res = shared_queries[t*E_per_block + e_local] * shared_kv[e_local*M + m];
            atomicAdd(
                &shared_results[m],
                res
            );
        }
        __syncthreads();
        // One thread per m flushes this block's partial sum (over its rows) to global memory.
        if (active && threadIdx.x < M) {
            float r1 = shared_results[threadIdx.x];
            atomicAdd(
                &result[n][h][l][m],
                r1
            );
            shared_results[threadIdx.x] = 0.0f;
        }
    }
    __syncthreads();
    if (active) {
        kv[n][h][e][m] = shared_kv[e_local*M + m];
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host launcher for the generic forward kernel.
//
// Splits the E x M state of each (n, h) pair across blocks_per_sequence blocks of
// E_per_block * M threads (at most 1024). It walks the sequence in chunks of T steps, so
// that one chunk fits in 48KB of shared memory.
// `product` is accumulated into with atomicAdd, so the caller must pass it zero-initialized.
// Throws (TORCH_CHECK) when M > 1024 or when a kernel launch fails.
void causal_dot_product_(const torch::Tensor queries,
                         const torch::Tensor keys,
                         const torch::Tensor values,
                         torch::Tensor product) {
    // Make sure that we are using the correct GPU device
    torch::DeviceGuard _guard(queries.device());

    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Empty input: nothing to add to `product`. Returning here also avoids dividing by zero
    // in the launch configuration below, and launching an empty (invalid) grid.
    if (N == 0 || H == 0 || L == 0 || E == 0 || M == 0) {
        return;
    }
    // A block must hold at least one full row of M threads; otherwise MUL_PER_BLOCK becomes 0.
    TORCH_CHECK(M <= 1024,
                "causal_dot_product: value dimension must be at most 1024, got ", M);

    auto kv = torch::zeros({N, H, E, M}, queries.options());

    // Largest block size, rounded down to a multiple of M so each block owns whole rows
    // of the E x M state.
    const int max_threads = 1024;
    const int MUL_PER_BLOCK = int(min(max_threads, E*M) / M) * M;
    const int E_per_block = MUL_PER_BLOCK / M;
    const int blocks_per_sequence = ((E*M) + MUL_PER_BLOCK - 1) / MUL_PER_BLOCK;
    const int blocks = N*H*blocks_per_sequence;

    // Shared memory is limited to 48KB (12K floats) per block. A constant part (kv slice and
    // results) plus a per-time-step part (values, keys, queries) sets the chunk length T.
    const int shared_mem_floats = 12 * 1024;
    const int shared_mem_const = (E_per_block + 1)*M;
    const int shared_mem_per_time = (M + 2*E_per_block);
    const int T = int((shared_mem_floats - shared_mem_const) / shared_mem_per_time);
    const int shared_mem_forward = ((T*shared_mem_per_time) + shared_mem_const) * sizeof(float);

    for (int l_offset=0; l_offset < L; l_offset += T) {
        causal_dot_product_kernel
            <<<blocks, MUL_PER_BLOCK, shared_mem_forward>>>(
            queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            kv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            N, H, L, E, M, E_per_block, blocks_per_sequence, T, l_offset
        );
        // Surface launch-configuration errors here instead of at an unrelated later call.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "causal_dot_product_kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void causal_dot_product(const torch::Tensor queries,
                        const torch::Tensor keys,
                        const torch::Tensor values,
                        torch::Tensor product) {
  // Forward entry point. When compiled in, try the NVIDIA kernels first; a non-zero status
  // from them means they declined, and the generic kernels are used instead.
  bool use_generic_kernels = true;
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
  use_generic_kernels = nvidia::lmha_fwd(queries, keys, values, product) != 0;
#endif
  if( !use_generic_kernels ) {
    return;
  }
  causal_dot_product_(queries, keys, values, product);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Generic backward kernel for the query and key gradients of causal linear attention
// (used when the NVIDIA path declines). It runs two recurrences in lock-step over one
// chunk of at most T steps:
//   forward in time,  l   = l_offset + t:  kv[e][m]       += K[l][e] * V[l][m]
//                                          grad_Q[l][e]   += sum_m grad_out[l][m] * kv[e][m]
//   backward in time, l_b = L - 1 - l:     kv_bw[e][m]    += Q[l_b][e] * grad_out[l_b][m]
//                                          grad_K[l_b][e] += sum_m V[l_b][m] * kv_bw[e][m]
//
// Launch layout (set up by causal_dot_backward_):
//   grid.x   = N * H * blocks_per_sequence. Each block owns M_per_block columns of the
//              E x M states of one (n, h) pair.
//   blockDim = M_per_block * E threads, one per (m, e) state element.
// Dynamic shared memory, in floats:
//   kv, kv_bw                                  2 * E * M_per_block
//   results, results_bw                        2 * E
//   keys, queries_bw                           2 * E * T
//   values, gradout, values_bw, gradout_bw     4 * M_per_block * T
//   total = 2*E*(M_per_block + 1) + T * 2*(E + 2*M_per_block)
// `kv` and `kv_backwards` carry the states between chunks. grad_queries and grad_keys are
// accumulated with atomicAdd and must be zero-initialized by the caller.
__global__ void causal_dot_backward_query_key_kernel(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    const float_accessor grad_out,
    float_accessor kv,
    float_accessor kv_backwards,
    float_accessor grad_queries,
    float_accessor grad_keys,
    int N,
    int H,
    int L,
    int E,
    int M,
    const int M_per_block,
    const int blocks_per_sequence,
    const int T,
    const int l_offset
) {
    const int sequence_index = blockIdx.x / blocks_per_sequence;
    int n = sequence_index / H;
    int h = sequence_index % H;

    int m_local = threadIdx.x / E;
    int m_start = ((blockIdx.x % blocks_per_sequence)*M_per_block);
    int m = m_start + m_local;
    int e = threadIdx.x % E;

    // When M is not a multiple of M_per_block, columns past M in the last block own no
    // state element. Such threads must stay alive (rather than return early), so that
    // every thread of the block reaches each __syncthreads() below.
    const bool active = (n < N) && (m < M);

    // Shared memory layout (see the header comment).
    extern __shared__ float shared_mem[];
    const int shared_kv_size = M_per_block * E;
    float* shared_kv = shared_mem;
    float* shared_kv_bw = shared_mem + shared_kv_size;
    float* shared_results = shared_kv_bw + shared_kv_size;
    float* shared_results_bw = shared_results + E;
    float* shared_keys = shared_results_bw + E;
    float* shared_values = shared_keys + E*T;
    float* shared_gradout = shared_values + M_per_block*T;
    float* shared_queries_bw = shared_gradout + M_per_block*T;
    float* shared_values_bw = shared_queries_bw + E*T;
    float* shared_gradout_bw = shared_values_bw + M_per_block*T;

    if (threadIdx.x < E) {
        shared_results[threadIdx.x] = 0.0f;
        shared_results_bw[threadIdx.x] = 0.0f;
    }

    // Number of valid steps in this chunk (the last chunk may be shorter than T).
    int t_end = (T + l_offset) <= L ? T : (L - l_offset);
    for (int i = threadIdx.x; i < (t_end*M_per_block); i += blockDim.x)
    {
        int t = int(i / M_per_block) + l_offset;
        int t_bw = L - t - 1;
        int d = (i % M_per_block) + m_start;
        if (d < M) {
            shared_values[i] = values[n][h][t][d];
            shared_gradout[i] = grad_out[n][h][t][d];
            shared_values_bw[i] = values[n][h][t_bw][d];
            shared_gradout_bw[i] = grad_out[n][h][t_bw][d];
        }
    }
    for (int i = threadIdx.x; i < (t_end*E); i += blockDim.x)
    {
        int t = int(i / E) + l_offset;
        int t_bw = L - t - 1;
        int d = (i % E);
        shared_keys[i] = keys[n][h][t][d];
        shared_queries_bw[i] = queries[n][h][t_bw][d];
    }
    __syncthreads();

    if (active) {
        shared_kv[threadIdx.x] = kv[n][h][e][m];
        shared_kv_bw[threadIdx.x] = kv_backwards[n][h][e][m];
    }

    for (int t=0; t<t_end; t++) {
        int l = t + l_offset;
        int l_b = L - l -1;
        if (active) {
            shared_kv[m_local*E + e] += shared_keys[t*E + e] * shared_values[t*M_per_block + m_local];
            shared_kv_bw[m_local*E + e] += shared_queries_bw[t*E + e] * shared_gradout_bw[t*M_per_block + m_local];
        }
        // Also orders the previous step's reset of the results before this step's adds.
        __syncthreads();
        if (active) {
            float res = shared_gradout[t*M_per_block + m_local] * shared_kv[m_local*E + e];
            float res_bw = shared_values_bw[t*M_per_block + m_local] * shared_kv_bw[m_local*E + e];
            atomicAdd(
                &shared_results[e],
                res
            );
            atomicAdd(
                &shared_results_bw[e],
                res_bw
            );
        }
        __syncthreads();
        // One thread per e flushes this block's partial sums (over its columns) to global memory.
        if (active && threadIdx.x < E) {
            float rq = shared_results[threadIdx.x];
            float rk = shared_results_bw[threadIdx.x];
            atomicAdd(
                &grad_queries[n][h][l][e],
                rq
            );
            atomicAdd(
                &grad_keys[n][h][l_b][e],
                rk
            );
            shared_results[threadIdx.x] = 0.0f;
            shared_results_bw[threadIdx.x] = 0.0f;
        }
    }
    __syncthreads();
    if (active) {
        kv[n][h][e][m] = shared_kv[m_local*E + e];
        kv_backwards[n][h][e][m] = shared_kv_bw[m_local*E + e];
    }
}


// Generic backward kernel for the value gradient of causal linear attention (used when
// the NVIDIA path declines). It walks one chunk of at most T steps backwards in time.
// For l_b = L - 1 - (l_offset + t):
//   kv[e][m]         += Q[l_b][e] * grad_out[l_b][m]
//   grad_V[l_b][m]   += sum_e K[l_b][e] * kv[e][m]
//
// Launch layout (set up by causal_dot_backward_):
//   grid.x   = N * H * blocks_per_sequence. Each block owns E_per_block rows of the
//              E x M state of one (n, h) pair.
//   blockDim = E_per_block * M threads, one per (e, m) state element.
//   dynamic shared memory = ((E_per_block + 1) * M + T * (M + 2 * E_per_block)) floats.
// `kv` carries the state between chunks. grad_values is accumulated with atomicAdd and must
// be zero-initialized by the caller. `values` and `grad_keys` are not used.
__global__ void causal_dot_backward_value_kernel(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    const float_accessor grad_out,
    float_accessor kv,
    float_accessor grad_keys,
    float_accessor grad_values,
    int N,
    int H,
    int L,
    int E,
    int M,
    int E_per_block,
    int blocks_per_sequence,
    int T,
    int l_offset
) {
    const int sequence_index = blockIdx.x / blocks_per_sequence;
    int n = sequence_index / H;
    int h = sequence_index % H;

    int e_local = threadIdx.x / M;
    int e_start = ((blockIdx.x % blocks_per_sequence) * E_per_block);
    int e = e_start + e_local;
    int m = threadIdx.x % M;

    // When E is not a multiple of E_per_block, rows past E in the last block own no state
    // element. Such threads must stay alive (rather than return early), so that every thread
    // of the block reaches each __syncthreads() below.
    const bool active = (n < N) && (e < E);

    // Shared memory layout: kv slice, per-m partial results, then this chunk's grad_out,
    // keys and queries (all taken at the reversed time index).
    const int shared_kv_size = E_per_block * M;
    extern __shared__ float shared_mem[];
    float* shared_kv = shared_mem;
    float* shared_results = shared_mem + shared_kv_size;
    float* shared_gradout = shared_results + M;
    float* shared_keys = shared_gradout + M*T;
    float* shared_queries = shared_keys + E_per_block*T;

    if (threadIdx.x < M) {
        shared_results[threadIdx.x] = 0.0f;
    }

    // Number of valid steps in this chunk (the last chunk may be shorter than T).
    int t_end = (T + l_offset) <= L ? T : L - l_offset;
    for (int i = threadIdx.x; i < (t_end*M); i += blockDim.x)
    {
        int t = int(i / M) + l_offset;
        int t_bw = L - 1 - t;
        int d = i % M;
        shared_gradout[i] = grad_out[n][h][t_bw][d];
    }
    for (int i = threadIdx.x; i < (t_end*E_per_block); i += blockDim.x)
    {
        int t = int(i / E_per_block) + l_offset;
        int t_bw = L - 1 - t;
        int d = (i % E_per_block) + e_start;
        if (d < E) {
            shared_keys[i] = keys[n][h][t_bw][d];
            shared_queries[i] = queries[n][h][t_bw][d];
        }
    }
    __syncthreads();

    if (active) {
        shared_kv[threadIdx.x] = kv[n][h][e][m];
    }
    for (int t=0; t<t_end; t++) {
        int l = t + l_offset;
        int l_b = L - l -1;
        if (active) {
            shared_kv[e_local*M + m] += shared_queries[t*E_per_block + e_local] * shared_gradout[t*M + m];
        }
        // Also orders the previous step's reset of shared_results before this step's adds.
        __syncthreads();
        if (active) {
            float res = shared_keys[t*E_per_block + e_local] * shared_kv[e_local*M + m];
            atomicAdd(
                &shared_results[m],
                res
            );
        }
        __syncthreads();
        // One thread per m flushes this block's partial sum (over its rows) to global memory.
        if (active && threadIdx.x < M) {
            float r1 = shared_results[threadIdx.x];
            atomicAdd(
                &grad_values[n][h][l_b][m],
                r1
            );
            shared_results[threadIdx.x] = 0.0f;
        }
    }
    __syncthreads();
    if (active) {
        kv[n][h][e][m] = shared_kv[e_local*M + m];
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host launcher for the generic backward kernels.
//
// Pass 1 computes grad_queries and grad_keys with causal_dot_backward_query_key_kernel,
// splitting the E x M states by columns (M_per_block per block).
// Pass 2 computes grad_values with causal_dot_backward_value_kernel, splitting by rows
// (E_per_block per block).
// Both passes walk the sequence in chunks of T steps sized to fit 48KB of shared memory.
// The gradients are accumulated with atomicAdd, so the caller must pass them
// zero-initialized. Throws (TORCH_CHECK) when E or M exceeds 1024, or when a launch fails.
void causal_dot_backward_(const torch::Tensor queries,
                          const torch::Tensor keys,
                          const torch::Tensor values,
                          const torch::Tensor grad_out,
                          torch::Tensor grad_queries,
                          torch::Tensor grad_keys,
                          torch::Tensor grad_values) {

    // Make sure that we are using the correct GPU device
    torch::DeviceGuard _guard(queries.device());

    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Empty input: every gradient is an empty sum. Returning here also avoids dividing by
    // zero in the launch configuration below, and launching an empty (invalid) grid.
    if (N == 0 || H == 0 || L == 0 || E == 0 || M == 0) {
        return;
    }
    // Pass 1 needs whole rows of E threads per block and pass 2 whole rows of M threads;
    // a block holds at most 1024 threads.
    TORCH_CHECK(E <= 1024 && M <= 1024,
                "causal_dot_backward: key and value dimensions must be at most 1024, got E=",
                E, ", M=", M);

    auto kv = torch::zeros({N, H, E, M}, queries.options());
    auto kv_backward = torch::zeros({N, H, E, M}, queries.options());

    const int threads = 1024;
    // Shared memory budget per block: 48KB, i.e. 12K floats.
    const int shared_mem_floats = 12 * 1024;

    // ---- Pass 1: gradients of the queries and keys. ----
    int MUL_PER_BLOCK = min(threads, E*M);
    // make sure that MUL_PER_BLOCK is divisible by E;
    MUL_PER_BLOCK = int(MUL_PER_BLOCK / E) *  E;
    const int blocks_per_sequence = ((E*M) + MUL_PER_BLOCK -1) / MUL_PER_BLOCK;
    const int M_per_block = MUL_PER_BLOCK / E;
    int blocks  = N*H*blocks_per_sequence;

    // Forward memory
    // keys: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E
    // Backward memory
    // queries: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E
    // Total memory
    // 2*((E + 2*M_per_block)*T + (E+1)*M_per_block)
    int shared_mem_const = 2*E*(1+M_per_block);
    int shared_mem_per_time = 2*(E + 2*M_per_block);
    int T = int((shared_mem_floats - shared_mem_const) / shared_mem_per_time);
    const int shared_mem_qk_backward = ((T*shared_mem_per_time) + shared_mem_const) * sizeof(float);
    for (int l_offset=0; l_offset < L; l_offset += T) {
        causal_dot_backward_query_key_kernel
            <<<blocks, MUL_PER_BLOCK, shared_mem_qk_backward>>>(
            queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            kv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            kv_backward.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            N, H, L, E, M, M_per_block, blocks_per_sequence, T, l_offset
        );
        const cudaError_t qk_err = cudaGetLastError();
        TORCH_CHECK(qk_err == cudaSuccess,
                    "causal_dot_backward_query_key_kernel launch failed: ",
                    cudaGetErrorString(qk_err));
    }

    // ---- Pass 2: gradient of the values. ----
    int MPB = min(threads, E*M);
    // make sure that MPB is divisible by M;
    MPB = int(MPB / M) *  M;
    const int blocks_per_sequence_value = ((E*M) + MPB - 1)/ MPB;
    const int E_per_block = MPB / M;
    const int blocks_value  = N*H*blocks_per_sequence_value;

    shared_mem_const = (E_per_block + 1)*M;
    shared_mem_per_time = (M + 2*E_per_block);
    T = int((shared_mem_floats - shared_mem_const) / shared_mem_per_time);
    const int shared_mem_v_backward = ((T*shared_mem_per_time) + shared_mem_const) * sizeof(float);
    // Reuse the forward-state buffer as the value pass's backward-in-time state.
    kv.zero_();
    for (int l_offset=0; l_offset < L; l_offset += T) {
        causal_dot_backward_value_kernel
            <<<blocks_value, MPB, shared_mem_v_backward>>>(
            queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            kv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            N, H, L, E, M, E_per_block, blocks_per_sequence_value, T, l_offset
        );
        const cudaError_t v_err = cudaGetLastError();
        TORCH_CHECK(v_err == cudaSuccess,
                    "causal_dot_backward_value_kernel launch failed: ",
                    cudaGetErrorString(v_err));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward entry point. When compiled in, the NVIDIA kernels are tried first; a non-zero
// status means they declined, and the generic kernels compute all three gradients instead.
void causal_dot_backward(const torch::Tensor queries,
                         const torch::Tensor keys,
                         const torch::Tensor values,
                         const torch::Tensor grad_out,
                         torch::Tensor grad_queries,
                         torch::Tensor grad_keys,
                         torch::Tensor grad_values) {
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
  int fallback = nvidia::lmha_bwd(queries,
                                  keys,
                                  values,
                                  grad_out,
                                  grad_queries,
                                  grad_keys,
                                  grad_values);
  if( fallback ) {
    // nvidia::lmha_bwd can report failure after some of its kernels already wrote their
    // outputs: grad_queries is computed first, and in its separate-kernel branch grad_keys
    // is written before grad_values is attempted. The generic kernels below accumulate with
    // atomicAdd, so reset the outputs to avoid adding on top of those partial results.
    grad_queries.zero_();
    grad_keys.zero_();
    grad_values.zero_();
  }
#else
  int fallback = 1;
#endif
  if( fallback ) {
    causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    // Expose the forward and backward host entry points to Python.
    mod.def("causal_dot_product",
            &causal_dot_product,
            "Compute the weighted sum of values but attending only to previous values.")
       .def("causal_dot_backward",
            &causal_dot_backward,
            "Compute the gradients for the causal dot product.");
}
