#include <torch/extension.h>

#include <vector>
#include <stdio.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#include "monarch_cuda_shared_bf16_load_frags.h"
#include "monarch_cuda_shared_bf16_complex_mul.h"
#include "monarch_cuda_shared_bf16_matmuls.h"
// Makes the wmma:: names used by every fragment type below resolvable.
using namespace nvcuda;

// Complex number whose components are at::BFloat16 (PyTorch-side type).
// Not referenced by the functions visible in this file.
using complex_bfloat16_t = typename c10::complex<at::BFloat16>;

// Tile extents handed to every wmma::fragment in this file.
// NOTE(review): the fragment types below pass them in the order
// WMMA_M, WMMA_K, WMMA_N. All three are 16, so the order makes no difference
// today, but it would matter if the values ever diverged.
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// Disabled tuning constants, kept for reference.
// #define TILE_SIZE 4
// #define SHMEM_SIZE 256 * TILE_SIZE
// #define SEQUENCE_SIZE 256
// Not referenced by the functions visible in this file.
#define WARP_SIZE 32


#ifndef MONARCH_CUDA_BF16_
#define MONARCH_CUDA_BF16_

// complex_matmul (overload taking k_frag)
//
// 1. Declares a local matrix_a fragment array and lets load_a_frag fill it.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards the result, together with b_frag and both accumulator arrays,
//    to _complex_matmul.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag
//   out_trans, output_to_shmem forwarded to _complex_matmul
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   passed to both load_a_frag and _complex_matmul
//   sqrt_N, N        forwarded untouched
//   b_frag           forwarded to _complex_matmul
//   acc_frag_1       forwarded to _complex_matmul
//   acc_frag_half    passed to both load_a_frag and _complex_matmul
//   k_frag           only read, as the second factor of the element-wise product
//   out_layout       forwarded to _complex_matmul
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag   = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_a fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.

  // a_tiles <- a_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      AFrag &a0 = a_tiles[ti][tj][0];
      AFrag &a1 = a_tiles[ti][tj][1];
      AFrag &k0 = k_frag[ti][tj][0];
      AFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(a0.x[lo], a0.x[hi]),
          __nv_bfloat162(a1.x[lo], a1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &a0.x[lo],
          &a1.x[lo],
          &a0.x[hi],
          &a1.x[hi]
        );
      }
    }
  }

  _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul (overload without k_frag)
//
// Declares a local matrix_a fragment array, lets load_a_frag fill it, and
// forwards it with b_frag and both accumulator arrays to _complex_matmul.
// Nothing is done to the loaded fragments in between.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag
//   out_trans, output_to_shmem forwarded to _complex_matmul
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   passed to both load_a_frag and _complex_matmul
//   sqrt_N, N        forwarded untouched
//   b_frag, acc_frag_1 forwarded to _complex_matmul
//   acc_frag_half    passed to both load_a_frag and _complex_matmul
//   out_layout       forwarded to _complex_matmul
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_load_b
//
// Mirror image of the k_frag overload of complex_matmul: here the matrix_b
// operand is loaded locally and the matrix_a operand is supplied by the caller.
//
// 1. Declares a local matrix_b fragment array and lets load_b_frag fill it.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards a_frag and the scaled local array to _complex_matmul.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   b_trans, b_frag_from_acc forwarded to load_b_frag
//   out_trans, output_to_shmem forwarded to _complex_matmul
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   b_real, b_imag   passed to load_b_frag and also given to _complex_matmul
//                    as its first two pointer arguments
//   sqrt_N, N        forwarded untouched
//   a_frag           caller-provided matrix_a fragments, forwarded
//   acc_frag_1, acc_frag_half forwarded to _complex_matmul (load_b_frag does
//                    not receive acc_frag_half)
//   k_frag           only read, as the second factor of the element-wise product
//   out_layout       forwarded to _complex_matmul
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool b_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool b_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_load_b(
   __nv_bfloat16* b_real,
   __nv_bfloat16* b_imag,
   int sqrt_N,
   int N,
   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
   wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
   wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
   wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
   wmma::layout_t out_layout = wmma::mem_row_major)
{
  using BFrag   = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_b fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  BFrag b_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_b_frag<BLayout, b_trans, MATMUL_WARP_WIDTH, b_frag_from_acc>(
    b_real, b_imag, sqrt_N, N, b_tiles);

  // NOTE: no __syncthreads() is issued before or after the element-wise product.

  // b_tiles <- b_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      BFrag &b0 = b_tiles[ti][tj][0];
      BFrag &b1 = b_tiles[ti][tj][1];
      BFrag &k0 = k_frag[ti][tj][0];
      BFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(b0.x[lo], b0.x[hi]),
          __nv_bfloat162(b1.x[lo], b1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &b0.x[lo],
          &b1.x[lo],
          &b0.x[hi],
          &b1.x[hi]
        );
      }
    }
  }

  _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    b_real, b_imag, sqrt_N, N, a_frag, b_tiles, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_r2c_load_b
//
// Declares a local matrix_b fragment array, lets load_b_frag_r2c fill it from
// the single pointer b_real_input, and forwards it together with the
// caller-provided a_frag to _complex_matmul_r2c_load_b.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   b_trans, b_frag_from_acc forwarded to load_b_frag_r2c
//   out_trans, output_to_shmem forwarded to _complex_matmul_r2c_load_b
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   b_real_input     only handed to load_b_frag_r2c
//   a_real, a_imag   only handed to _complex_matmul_r2c_load_b
//   sqrt_N, N        forwarded untouched to both helpers
//   a_frag, acc_frag_1, acc_frag_half forwarded to _complex_matmul_r2c_load_b
//   out_layout       forwarded to _complex_matmul_r2c_load_b
template <typename ALayout, typename BLayout, bool b_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool b_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_load_b(
  __nv_bfloat16 *b_real_input,
  __nv_bfloat16* a_real,
  __nv_bfloat16* a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using BFrag = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout>;

  BFrag b_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_b_frag_r2c<BLayout, b_trans, MATMUL_WARP_WIDTH, b_frag_from_acc>(
    b_real_input, sqrt_N, N, b_tiles);

  _complex_matmul_r2c_load_b<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_frag, b_tiles, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_r2c_256
//
// Declares a local matrix_a fragment array, lets load_a_frag_r2c_256 fill it
// from the read-only pointer a_real_input, and forwards it with b_frag and
// both accumulator arrays to _complex_matmul_r2c_256.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_r2c_256
//   out_trans, output_to_shmem forwarded to _complex_matmul_r2c_256
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real_input     const input, only handed to load_a_frag_r2c_256
//   a_real, a_imag   only handed to _complex_matmul_r2c_256
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_r2c_256
//   acc_frag_half    passed to both helpers
//   out_layout       forwarded to _complex_matmul_r2c_256
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_256(
  const __nv_bfloat16 *a_real_input,
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_r2c_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real_input, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul_r2c_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_r2c_1024
//
// Declares a local matrix_a fragment array, lets load_a_frag_r2c_1024 fill it
// from the read-only pointer a_real_input, and forwards it with b_frag and
// both accumulator arrays to _complex_matmul_r2c_1024.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_r2c_1024
//   out_trans, output_to_shmem forwarded to _complex_matmul_r2c_1024
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real_input     const input, only handed to load_a_frag_r2c_1024
//   a_real, a_imag   only handed to _complex_matmul_r2c_1024
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_r2c_1024
//   acc_frag_half    passed to both helpers
//   out_layout       forwarded to _complex_matmul_r2c_1024
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_1024(
  const __nv_bfloat16 *a_real_input,
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_r2c_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real_input, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul_r2c_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2c_1024 (overload using one pointer pair)
//
// Declares a local matrix_a fragment array, lets load_a_frag_1024 fill it, and
// forwards it with b_frag and both accumulator arrays to _complex_matmul_1024.
// The same a_real / a_imag pointers are given to the loader and to the multiply.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_1024
//   out_trans, output_to_shmem forwarded to _complex_matmul_1024
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   passed to both load_a_frag_1024 and _complex_matmul_1024
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_1024
//   acc_frag_half    passed to both helpers
//   out_layout       forwarded to _complex_matmul_1024
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2c_1024(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2c_1024 (overload with separate const inputs and outputs)
//
// Declares a local matrix_a fragment array, lets load_a_frag_1024 fill it from
// the read-only a_real_inp / a_imag_inp pointers, and forwards it to
// _complex_matmul_1024, which receives a_real_out / a_imag_out instead.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_1024
//   out_trans, output_to_shmem forwarded to _complex_matmul_1024
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real_inp, a_imag_inp  const inputs, only handed to load_a_frag_1024
//   a_real_out, a_imag_out  only handed to _complex_matmul_1024
//   sqrt_N, N               forwarded untouched to both helpers
//   b_frag, acc_frag_1      forwarded to _complex_matmul_1024
//   acc_frag_half           passed to both helpers
//   out_layout              forwarded to _complex_matmul_1024
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2c_1024(
  const __nv_bfloat16 *a_real_inp,
  const __nv_bfloat16 *a_imag_inp,
  __nv_bfloat16 *a_real_out,
  __nv_bfloat16 *a_imag_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real_inp, a_imag_inp, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, a_imag_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2c_1024 (overload with separate inputs/outputs and k_frag)
//
// 1. Declares a local matrix_a fragment array and lets load_a_frag_1024 fill
//    it from a_real_inp / a_imag_inp.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards the result to _complex_matmul_1024, which receives
//    a_real_out / a_imag_out.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_1024
//   out_trans, output_to_shmem forwarded to _complex_matmul_1024
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real_inp, a_imag_inp  only handed to load_a_frag_1024
//   a_real_out, a_imag_out  only handed to _complex_matmul_1024
//   sqrt_N, N               forwarded untouched to both helpers
//   b_frag, acc_frag_1      forwarded to _complex_matmul_1024
//   acc_frag_half           passed to both helpers
//   k_frag                  only read, second factor of the element-wise product
//   out_layout              forwarded to _complex_matmul_1024
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2c_1024(
  __nv_bfloat16 *a_real_inp,
  __nv_bfloat16 *a_imag_inp,
  __nv_bfloat16 *a_real_out,
  __nv_bfloat16 *a_imag_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag   = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_a fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real_inp, a_imag_inp, sqrt_N, N, acc_frag_half, a_tiles);

  // a_tiles <- a_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      AFrag &a0 = a_tiles[ti][tj][0];
      AFrag &a1 = a_tiles[ti][tj][1];
      AFrag &k0 = k_frag[ti][tj][0];
      AFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(a0.x[lo], a0.x[hi]),
          __nv_bfloat162(a1.x[lo], a1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &a0.x[lo],
          &a1.x[lo],
          &a0.x[hi],
          &a1.x[hi]
        );
      }
    }
  }

  _complex_matmul_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, a_imag_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2r_256 (overload without k_frag)
//
// Declares a local matrix_a fragment array, lets load_a_frag_256 fill it from
// a_real / a_imag, and forwards it to _complex_matmul_c2r_256, which receives
// the single pointer a_real_out.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_256
//   out_trans, output_to_shmem forwarded to _complex_matmul_c2r_256
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   only handed to load_a_frag_256
//   a_real_out       only handed to _complex_matmul_c2r_256
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_c2r_256
//   acc_frag_half    passed to both helpers
//   out_layout       forwarded to _complex_matmul_c2r_256
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_256(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  __nv_bfloat16 *a_real_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the multiply.
  _complex_matmul_c2r_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2r_256 (overload taking k_frag)
//
// 1. Declares a local matrix_a fragment array and lets load_a_frag_256 fill it
//    from a_real / a_imag.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards the result to _complex_matmul_c2r_256, which receives the single
//    pointer a_real_out.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_256
//   out_trans, output_to_shmem forwarded to _complex_matmul_c2r_256
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   only handed to load_a_frag_256
//   a_real_out       only handed to _complex_matmul_c2r_256
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_c2r_256
//   acc_frag_half    passed to both helpers
//   k_frag           only read, second factor of the element-wise product
//   out_layout       forwarded to _complex_matmul_c2r_256
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_256(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  __nv_bfloat16 *a_real_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag   = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_a fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the product.

  // a_tiles <- a_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      AFrag &a0 = a_tiles[ti][tj][0];
      AFrag &a1 = a_tiles[ti][tj][1];
      AFrag &k0 = k_frag[ti][tj][0];
      AFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(a0.x[lo], a0.x[hi]),
          __nv_bfloat162(a1.x[lo], a1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &a0.x[lo],
          &a1.x[lo],
          &a0.x[hi],
          &a1.x[hi]
        );
      }
    }
  }

  _complex_matmul_c2r_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2r_1024
//
// 1. Declares a local matrix_a fragment array and lets load_a_frag_1024 fill
//    it from a_real / a_imag.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards the result to _complex_matmul_c2r_1024, which receives the
//    single pointer a_real_out.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag_1024
//   out_trans, output_to_shmem forwarded to _complex_matmul_c2r_1024
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   only handed to load_a_frag_1024
//   a_real_out       only handed to _complex_matmul_c2r_1024
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_c2r_1024
//   acc_frag_half    passed to both helpers
//   k_frag           only read, second factor of the element-wise product
//   out_layout       forwarded to _complex_matmul_c2r_1024
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_1024(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  __nv_bfloat16 *a_real_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag   = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_a fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the product.

  // a_tiles <- a_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      AFrag &a0 = a_tiles[ti][tj][0];
      AFrag &a1 = a_tiles[ti][tj][1];
      AFrag &k0 = k_frag[ti][tj][0];
      AFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(a0.x[lo], a0.x[hi]),
          __nv_bfloat162(a1.x[lo], a1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &a0.x[lo],
          &a1.x[lo],
          &a0.x[hi],
          &a1.x[hi]
        );
      }
    }
  }

  _complex_matmul_c2r_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

// complex_matmul_c2r
//
// 1. Declares a local matrix_a fragment array and lets load_a_frag fill it
//    from a_real / a_imag.
// 2. Multiplies that array element-wise by k_frag with complex_mul_bfloat162,
//    writing the products back into the local array.
// 3. Forwards the result to _complex_matmul_c2r, which receives the single
//    pointer a_real_out.
//
// Template parameters (all forwarded unchanged):
//   ALayout, BLayout         layouts of the matrix_a / matrix_b fragments
//   a_trans, a_frag_from_acc forwarded to load_a_frag
//   out_trans, output_to_shmem forwarded to _complex_matmul_c2r
//   MATMUL_WARP_WIDTH        extent of the first two fragment-array dimensions
//
// Arguments:
//   a_real, a_imag   only handed to load_a_frag
//   a_real_out       only handed to _complex_matmul_c2r
//   sqrt_N, N        forwarded untouched to both helpers
//   b_frag, acc_frag_1 forwarded to _complex_matmul_c2r
//   acc_frag_half    passed to both helpers
//   k_frag           only read, second factor of the element-wise product
//   out_layout       forwarded to _complex_matmul_c2r
//
// The trailing [2] index presumably selects real (0) / imaginary (1) parts
// -- TODO confirm against complex_mul_bfloat162.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r(
  __nv_bfloat16 *a_real,
  __nv_bfloat16 *a_imag,
  __nv_bfloat16 *a_real_out,
  int sqrt_N,
  int N,
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_half[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
  wmma::layout_t out_layout = wmma::mem_row_major)
{
  using AFrag   = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, __nv_bfloat16, ALayout>;
  using AccFrag = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, float>;

  // Number of bf16 pairs visited per fragment. The count is taken from the
  // float accumulator fragment type, not from the matrix_a fragment being
  // modified.
  // NOTE(review): confirm both fragment types hold the same number of
  // elements per thread.
  constexpr int pair_count = AccFrag::num_elements / 2;

  AFrag a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
  load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
    a_real, a_imag, sqrt_N, N, acc_frag_half, a_tiles);

  // NOTE: no __syncthreads() is issued between the load and the product.

  // a_tiles <- a_tiles * k_frag, two consecutive fragment elements per call.
  for (int ti = 0; ti < MATMUL_WARP_WIDTH; ++ti) {
    for (int tj = 0; tj < MATMUL_WARP_WIDTH; ++tj) {
      AFrag &a0 = a_tiles[ti][tj][0];
      AFrag &a1 = a_tiles[ti][tj][1];
      AFrag &k0 = k_frag[ti][tj][0];
      AFrag &k1 = k_frag[ti][tj][1];

      for (int p = 0; p < pair_count; ++p) {
        const int lo = 2 * p;
        const int hi = lo + 1;
        complex_mul_bfloat162(
          __nv_bfloat162(a0.x[lo], a0.x[hi]),
          __nv_bfloat162(a1.x[lo], a1.x[hi]),
          __nv_bfloat162(k0.x[lo], k0.x[hi]),
          __nv_bfloat162(k1.x[lo], k1.x[hi]),
          &a0.x[lo],
          &a1.x[lo],
          &a0.x[hi],
          &a1.x[hi]
        );
      }
    }
  }

  _complex_matmul_c2r<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
    a_real_out, sqrt_N, N, a_tiles, b_frag, acc_frag_1, acc_frag_half, out_layout);
}

#endif