#include <torch/extension.h>

#include <vector>
#include <stdio.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
using namespace nvcuda;

using complex_half_t = typename c10::complex<at::Half>;

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// #define TILE_SIZE 4
// #define SHMEM_SIZE 256 * TILE_SIZE
// #define SEQUENCE_SIZE 256
#define WARP_SIZE 32

#ifndef MONARCH_CUDA_H_
#define MONARCH_CUDA_H_

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex tile product C = A * B on tensor cores. Every operand is split
   // into a real plane ([..][0]) and an imaginary plane ([..][1]):
   //    Re(C) = -(Im A * Im B) + Re A * Re B
   //    Im(C) =   Re A * Im B  + Im A * Re B
   // The accumulators are half precision, so the order of the mma_sync calls
   // below is part of the numerical result and is kept fixed.
   // N is accepted for signature uniformity but is not read.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         auto &re_acc = acc_frag_1[row][col][0];
         auto &im_acc = acc_frag_1[row][col][1];

         // Real part: sum Im(A)*Im(B) over k, flip its sign, then add Re(A)*Re(B).
         wmma::fill_fragment(re_acc, __float2half(0.0f));
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][1], b_frag[kk][col][1], re_acc);
         for (int e = 0; e < re_acc.num_elements; ++e)
            re_acc.x[e] = __hneg(re_acc.x[e]);
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][0], b_frag[kk][col][0], re_acc);

         // Imaginary part: Re(A)*Im(B) summed first, then Im(A)*Re(B).
         wmma::fill_fragment(im_acc, __float2half(0.0f));
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(im_acc, a_frag[row][kk][0], b_frag[kk][col][1], im_acc);
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(im_acc, a_frag[row][kk][1], b_frag[kk][col][0], im_acc);
      }
   }

   if (!output_to_shmem) return;

   // Write every 16x16 result tile back over a_real / a_imag using row stride
   // sqrt_N; out_trans swaps the tile coordinates of the destination.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * sqrt_N + row * WMMA_N
            : row * WMMA_M * sqrt_N + col * WMMA_N;
         wmma::store_matrix_sync(a_real + offset, acc_frag_1[row][col][0], sqrt_N, out_layout);
         wmma::store_matrix_sync(a_imag + offset, acc_frag_1[row][col][1], sqrt_N, out_layout);
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_r2c(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Real-times-complex tile product. Only the real plane of A (a_frag[..][0])
   // is read:
   //    acc[..][0] = Re A * Re B
   //    acc[..][1] = Re A * Im B
   // Each half-precision accumulator is zeroed, then summed over the k tiles
   // in ascending order. N is accepted but not read.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int part = 0; part < 2; ++part) {
            auto &acc = acc_frag_1[row][col][part];
            wmma::fill_fragment(acc, __float2half(0.0f));
            for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
               wmma::mma_sync(acc, a_frag[row][kk][0], b_frag[kk][col][part], acc);
         }
      }
   }

   if (!output_to_shmem) return;

   // Part 0 goes to a_real and part 1 to a_imag, both with row stride sqrt_N;
   // out_trans swaps the destination tile coordinates.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * sqrt_N + row * WMMA_N
            : row * WMMA_M * sqrt_N + col * WMMA_N;
         for (int part = 0; part < 2; ++part) {
            half *dst = (part == 0) ? a_real : a_imag;
            wmma::store_matrix_sync(dst + offset, acc_frag_1[row][col][part], sqrt_N, out_layout);
         }
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_r2c_load_b(
    half *b_real,
    half *b_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex-times-real tile product. Only the real plane of B (b_frag[..][0])
   // is read:
   //    acc[..][0] = Re A * Re B
   //    acc[..][1] = Im A * Re B
   // Each half-precision accumulator is zeroed, then summed over the k tiles
   // in ascending order. N is accepted but not read.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int part = 0; part < 2; ++part) {
            auto &acc = acc_frag_1[row][col][part];
            wmma::fill_fragment(acc, __float2half(0.0f));
            for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
               wmma::mma_sync(acc, a_frag[row][kk][part], b_frag[kk][col][0], acc);
         }
      }
   }

   if (!output_to_shmem) return;

   // Part 0 goes to b_real and part 1 to b_imag, both with row stride sqrt_N;
   // out_trans swaps the destination tile coordinates.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * sqrt_N + row * WMMA_N
            : row * WMMA_M * sqrt_N + col * WMMA_N;
         for (int part = 0; part < 2; ++part) {
            half *dst = (part == 0) ? b_real : b_imag;
            wmma::store_matrix_sync(dst + offset, acc_frag_1[row][col][part], sqrt_N, out_layout);
         }
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_r2c_256(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Same computation as _complex_matmul_r2c (acc[..][0] = Re A * Re B,
   // acc[..][1] = Re A * Im B), but the output row stride is fixed at 256.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 256;

   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int part = 0; part < 2; ++part) {
            auto &acc = acc_frag_1[row][col][part];
            wmma::fill_fragment(acc, __float2half(0.0f));
            for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
               wmma::mma_sync(acc, a_frag[row][kk][0], b_frag[kk][col][part], acc);
         }
      }
   }

   if (!output_to_shmem) return;

   // Part 0 goes to a_real and part 1 to a_imag; out_trans swaps tile coordinates.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * kRowStride + row * WMMA_N
            : row * WMMA_M * kRowStride + col * WMMA_N;
         for (int part = 0; part < 2; ++part) {
            half *dst = (part == 0) ? a_real : a_imag;
            wmma::store_matrix_sync(dst + offset, acc_frag_1[row][col][part], kRowStride, out_layout);
         }
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_r2c_1024(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Same computation as _complex_matmul_r2c (acc[..][0] = Re A * Re B,
   // acc[..][1] = Re A * Im B), but the output row stride is fixed at 1024.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 1024;

   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int part = 0; part < 2; ++part) {
            auto &acc = acc_frag_1[row][col][part];
            wmma::fill_fragment(acc, __float2half(0.0f));
            for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
               wmma::mma_sync(acc, a_frag[row][kk][0], b_frag[kk][col][part], acc);
         }
      }
   }

   if (!output_to_shmem) return;

   // Part 0 goes to a_real and part 1 to a_imag; out_trans swaps tile coordinates.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * kRowStride + row * WMMA_N
            : row * WMMA_M * kRowStride + col * WMMA_N;
         for (int part = 0; part < 2; ++part) {
            half *dst = (part == 0) ? a_real : a_imag;
            wmma::store_matrix_sync(dst + offset, acc_frag_1[row][col][part], kRowStride, out_layout);
         }
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_c2r(
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex-times-complex tile product that keeps only the real part:
   //    acc[..][0] = -(Im A * Im B) + Re A * Re B
   // ([..][1] fragments are the imaginary planes). acc_frag_1[..][1] is not
   // touched. The half-precision order (Im*Im, negate, then Re*Re) is kept.
   // N is accepted but not read.
   #pragma unroll
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         auto &re_acc = acc_frag_1[row][col][0];
         wmma::fill_fragment(re_acc, __float2half(0.0f));

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][1], b_frag[kk][col][1], re_acc);

         for (int e = 0; e < re_acc.num_elements; ++e)
            re_acc.x[e] = __hneg(re_acc.x[e]);

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][0], b_frag[kk][col][0], re_acc);
      }
   }

   if (!output_to_shmem) return;

   // Real result tiles go to a_real_out with row stride sqrt_N; out_trans
   // swaps the destination tile coordinates.
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * sqrt_N + row * WMMA_N
            : row * WMMA_M * sqrt_N + col * WMMA_N;
         wmma::store_matrix_sync(a_real_out + offset, acc_frag_1[row][col][0], sqrt_N, out_layout);
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_c2r_256(
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Same computation as _complex_matmul_c2r (real part only:
   // -(Im A * Im B) + Re A * Re B), but the output row stride is fixed at 256.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 256;

   #pragma unroll
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         auto &re_acc = acc_frag_1[row][col][0];
         wmma::fill_fragment(re_acc, __float2half(0.0f));

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][1], b_frag[kk][col][1], re_acc);

         for (int e = 0; e < re_acc.num_elements; ++e)
            re_acc.x[e] = __hneg(re_acc.x[e]);

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][0], b_frag[kk][col][0], re_acc);
      }
   }

   if (!output_to_shmem) return;

   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * kRowStride + row * WMMA_N
            : row * WMMA_M * kRowStride + col * WMMA_N;
         wmma::store_matrix_sync(a_real_out + offset, acc_frag_1[row][col][0], kRowStride, out_layout);
      }
   }
}

template <typename ALayout, typename BLayout, bool out_trans, int MATMUL_WARP_WIDTH, bool output_to_shmem>
__device__ __forceinline__ void _complex_matmul_c2r_1024(
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Same computation as _complex_matmul_c2r (real part only:
   // -(Im A * Im B) + Re A * Re B), but the output row stride is fixed at 1024.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 1024;

   #pragma unroll
   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         auto &re_acc = acc_frag_1[row][col][0];
         wmma::fill_fragment(re_acc, __float2half(0.0f));

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][1], b_frag[kk][col][1], re_acc);

         for (int e = 0; e < re_acc.num_elements; ++e)
            re_acc.x[e] = __hneg(re_acc.x[e]);

         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk)
            wmma::mma_sync(re_acc, a_frag[row][kk][0], b_frag[kk][col][0], re_acc);
      }
   }

   if (!output_to_shmem) return;

   for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         const int offset = out_trans
            ? col * WMMA_M * kRowStride + row * WMMA_N
            : row * WMMA_M * kRowStride + col * WMMA_N;
         wmma::store_matrix_sync(a_real_out + offset, acc_frag_1[row][col][0], kRowStride, out_layout);
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Fill the real ([..][0]) and imaginary ([..][1]) matrix_a fragments,
   // either from memory or from the accumulators of a previous matmul.
   // N is accepted but not read.
   if (!a_frag_from_acc) {
      // Load each 16x16 tile from a_real / a_imag with row stride sqrt_N;
      // a_trans swaps the tile coordinates of the source.
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * sqrt_N + row * WMMA_K
               : row * WMMA_K * sqrt_N + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, sqrt_N);
            wmma::load_matrix_sync(a_frag[row][kk][1], a_imag + offset, sqrt_N);
         }
      }
      return;
   }

   // Register-to-register reuse: each accumulator element is written to both
   // x[e] and x[e + num_elements] of the matching matrix_a fragment.
   // NOTE(review): this relies on the architecture-specific element layout of
   // wmma fragments (the CUDA docs leave the mapping unspecified) — confirm on
   // the target SM.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         for (int part = 0; part < 2; ++part) {
            const auto &src = acc_frag_1[row][col][part];
            auto &dst = a_frag[row][col][part];
            for (int e = 0; e < src.num_elements; ++e) {
               dst.x[e] = src.x[e];
               dst.x[e + src.num_elements] = src.x[e];
            }
         }
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag_256(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Same as load_a_frag, but memory loads use a fixed row stride of 256.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 256;

   if (!a_frag_from_acc) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * kRowStride + row * WMMA_K
               : row * WMMA_K * kRowStride + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, kRowStride);
            wmma::load_matrix_sync(a_frag[row][kk][1], a_imag + offset, kRowStride);
         }
      }
      return;
   }

   // Copy each accumulator element into x[e] and x[e + num_elements] of the
   // matrix_a fragment. NOTE(review): depends on the architecture-specific
   // fragment layout — confirm on the target SM.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         for (int part = 0; part < 2; ++part) {
            const auto &src = acc_frag_1[row][col][part];
            auto &dst = a_frag[row][col][part];
            for (int e = 0; e < src.num_elements; ++e) {
               dst.x[e] = src.x[e];
               dst.x[e + src.num_elements] = src.x[e];
            }
         }
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag_1024(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Same as load_a_frag, but memory loads use a fixed row stride of 1024.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 1024;

   if (!a_frag_from_acc) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * kRowStride + row * WMMA_K
               : row * WMMA_K * kRowStride + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, kRowStride);
            wmma::load_matrix_sync(a_frag[row][kk][1], a_imag + offset, kRowStride);
         }
      }
      return;
   }

   // Copy each accumulator element into x[e] and x[e + num_elements] of the
   // matrix_a fragment. NOTE(review): depends on the architecture-specific
   // fragment layout — confirm on the target SM.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         for (int part = 0; part < 2; ++part) {
            const auto &src = acc_frag_1[row][col][part];
            auto &dst = a_frag[row][col][part];
            for (int e = 0; e < src.num_elements; ++e) {
               dst.x[e] = src.x[e];
               dst.x[e + src.num_elements] = src.x[e];
            }
         }
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_b_frag_r2c(
    const half *b_real,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, ALayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Load only the real plane of B into b_frag[..][0] (row stride sqrt_N);
   // b_frag[..][1] is left untouched. a_trans swaps the source tile
   // coordinates. N, acc_frag_1 and a_frag_from_acc are accepted but unused.
   for (int tile_r = 0; tile_r < MATMUL_WARP_WIDTH; ++tile_r) {
      for (int tile_c = 0; tile_c < MATMUL_WARP_WIDTH; ++tile_c) {
         const int offset = a_trans
            ? tile_c * WMMA_K * sqrt_N + tile_r * WMMA_K
            : tile_r * WMMA_K * sqrt_N + tile_c * WMMA_K;
         wmma::load_matrix_sync(b_frag[tile_r][tile_c][0], b_real + offset, sqrt_N);
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_b_frag(
    half *b_real,
    half *b_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, ALayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Load the real ([..][0]) and imaginary ([..][1]) planes of B from memory
   // with row stride sqrt_N; a_trans swaps the source tile coordinates.
   // N, acc_frag_1 and a_frag_from_acc are accepted but unused.
   for (int tile_r = 0; tile_r < MATMUL_WARP_WIDTH; ++tile_r) {
      for (int tile_c = 0; tile_c < MATMUL_WARP_WIDTH; ++tile_c) {
         const int offset = a_trans
            ? tile_c * WMMA_K * sqrt_N + tile_r * WMMA_K
            : tile_r * WMMA_K * sqrt_N + tile_c * WMMA_K;
         auto &tile = b_frag[tile_r][tile_c];
         wmma::load_matrix_sync(tile[0], b_real + offset, sqrt_N);
         wmma::load_matrix_sync(tile[1], b_imag + offset, sqrt_N);
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag_r2c(
    const half *a_real,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Fill only the real plane a_frag[..][0]; a_frag[..][1] is left untouched.
   // N is accepted but not read.
   if (!a_frag_from_acc) {
      // Load from a_real with row stride sqrt_N; a_trans swaps tile coordinates.
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * sqrt_N + row * WMMA_K
               : row * WMMA_K * sqrt_N + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, sqrt_N);
         }
      }
      return;
   }

   // Copy each real accumulator element into x[e] and x[e + num_elements] of
   // the matrix_a fragment. NOTE(review): depends on the architecture-specific
   // fragment layout — confirm on the target SM.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         const auto &src = acc_frag_1[row][col][0];
         auto &dst = a_frag[row][col][0];
         for (int e = 0; e < src.num_elements; ++e) {
            dst.x[e] = src.x[e];
            dst.x[e + src.num_elements] = src.x[e];
         }
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag_r2c_256(
    const half *a_real,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Same as load_a_frag_r2c, but memory loads use a fixed row stride of 256.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 256;

   if (!a_frag_from_acc) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * kRowStride + row * WMMA_K
               : row * WMMA_K * kRowStride + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, kRowStride);
         }
      }
      return;
   }

   // Copy each real accumulator element into x[e] and x[e + num_elements].
   // NOTE(review): depends on the architecture-specific fragment layout.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         const auto &src = acc_frag_1[row][col][0];
         auto &dst = a_frag[row][col][0];
         for (int e = 0; e < src.num_elements; ++e) {
            dst.x[e] = src.x[e];
            dst.x[e + src.num_elements] = src.x[e];
         }
      }
   }
}

template <typename ALayout, bool a_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc>
__device__ __forceinline__ void load_a_frag_r2c_1024(
    const half *a_real,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2])
{
   // Same as load_a_frag_r2c, but memory loads use a fixed row stride of 1024.
   // sqrt_N and N are accepted and ignored.
   constexpr int kRowStride = 1024;

   if (!a_frag_from_acc) {
      for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
         for (int kk = 0; kk < MATMUL_WARP_WIDTH; ++kk) {
            const int offset = a_trans
               ? kk * WMMA_K * kRowStride + row * WMMA_K
               : row * WMMA_K * kRowStride + kk * WMMA_K;
            wmma::load_matrix_sync(a_frag[row][kk][0], a_real + offset, kRowStride);
         }
      }
      return;
   }

   // Copy each real accumulator element into x[e] and x[e + num_elements].
   // NOTE(review): depends on the architecture-specific fragment layout.
   for (int row = 0; row < MATMUL_WARP_WIDTH; ++row) {
      for (int col = 0; col < MATMUL_WARP_WIDTH; ++col) {
         const auto &src = acc_frag_1[row][col][0];
         auto &dst = a_frag[row][col][0];
         for (int e = 0; e < src.num_elements; ++e) {
            dst.x[e] = src.x[e];
            dst.x[e + src.num_elements] = src.x[e];
         }
      }
   }
}

// Complex matmul with a pointwise pre-scaling of A:
//   1. stage A into matrix_a fragments (from a_real/a_imag in memory, or from
//      acc_frag_1 when a_frag_from_acc) via load_a_frag;
//   2. multiply A element-wise by k_frag using complex_mul_half2;
//   3. multiply the result by the pre-loaded b_frag via _complex_matmul, which
//      leaves the product in acc_frag_1 and, when output_to_shmem, stores it
//      over a_real / a_imag.
// Fragment index [..][0] is the real plane and [..][1] the imaginary plane.
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{

   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag [MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
   load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, a_frag);

   // __syncthreads();
   // multiply a_frag by k_frag
   // Elements are processed in adjacent pairs packed into __half2.
   // NOTE(review): the loop bound is taken from acc_frag_1's num_elements, not
   // a_frag's. load_a_frag's from-accumulator path writes a_frag elements up to
   // 2 * (accumulator num_elements), so this loop appears to scale only the
   // first half of each a_frag fragment. That is correct only if the target
   // architecture's mma reads just that half (the CUDA docs require element-wise
   // ops to be applied uniformly to all fragment elements) — confirm on the
   // target SM before porting.
   // complex_mul_half2 is defined outside this view; presumably it writes the
   // real/imag parts of (a * k) for both packed lanes through the four output
   // pointers — verify against its definition.
   for (int j_a = 0; j_a < MATMUL_WARP_WIDTH; j_a++) {
      for (int k = 0; k < MATMUL_WARP_WIDTH; k++) {
         for (int i = 0; i < acc_frag_1[j_a][k][0].num_elements / 2; i++) {
            complex_mul_half2(
               __half2(a_frag[j_a][k][0].x[2 * i], a_frag[j_a][k][0].x[2 * i + 1]),
               __half2(a_frag[j_a][k][1].x[2 * i], a_frag[j_a][k][1].x[2 * i + 1]),
               __half2(k_frag[j_a][k][0].x[2 * i], k_frag[j_a][k][0].x[2 * i + 1]),
               __half2(k_frag[j_a][k][1].x[2 * i], k_frag[j_a][k][1].x[2 * i + 1]),
               &a_frag[j_a][k][0].x[2 * i], 
               &a_frag[j_a][k][1].x[2 * i],
               &a_frag[j_a][k][0].x[2 * i + 1],
               &a_frag[j_a][k][1].x[2 * i + 1]
            );
         }
      }
   }

   // Complex product (scaled A) * B; results land in acc_frag_1.
   _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real, a_imag, sqrt_N, N, a_frag, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul(
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex matmul without pointwise scaling. A is staged into matrix_a
   // fragments (from a_real/a_imag, or from acc_frag_1 when a_frag_from_acc)
   // and multiplied by the pre-loaded b_frag. The result is left in acc_frag_1
   // and, when output_to_shmem, written over a_real / a_imag.
   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_tiles[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(
      a_real, a_imag, sqrt_N, N, acc_frag_1, a_tiles);

   _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(
      a_real, a_imag, sqrt_N, N, a_tiles, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool b_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool b_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_load_b(
    half *b_real,
    half *b_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex B operand built here (A is supplied by the caller):
   // [.][.][0] = real plane, [.][.][1] = imaginary plane.
   using b_fragment_t = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout>;
   b_fragment_t rhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   // Populate B from b_real/b_imag or acc_frag_1 (selection is up to load_b_frag and
   // the b_frag_from_acc flag — defined elsewhere, confirm there).
   load_b_frag<BLayout, b_trans, MATMUL_WARP_WIDTH, b_frag_from_acc>(b_real, b_imag, sqrt_N, N, acc_frag_1, rhs);

   // Complex A * B into acc_frag_1; b_real/b_imag are forwarded as the pointer pair.
   _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(b_real, b_imag, sqrt_N, N, a_frag, rhs, acc_frag_1, out_layout);
}

// Complex matmul where B is loaded here and then scaled element-wise by k_frag
// (complex product) before being multiplied against the caller-supplied A.
// Fragment arrays are indexed [tile_row][tile_col][0 = real, 1 = imag].
template <typename ALayout, typename BLayout, bool b_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool b_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_load_b(
    half *b_real,
    half *b_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
   load_b_frag<BLayout, b_trans, MATMUL_WARP_WIDTH, b_frag_from_acc>(b_real, b_imag, sqrt_N, N, acc_frag_1, b_frag);

   // b_frag *= k_frag, complex and element-wise. b_frag and k_frag have the exact
   // same fragment type, so storage slot t of one corresponds to slot t of the other
   // even though the fragment-to-matrix mapping is unspecified. The loop must be
   // bounded by the matrix_b fragment's own num_elements: accumulator and operand
   // fragments are not guaranteed to hold the same number of elements, so using
   // the accumulator's count can leave part of each B fragment un-multiplied.
   for (int j_a = 0; j_a < MATMUL_WARP_WIDTH; j_a++) {
      for (int k = 0; k < MATMUL_WARP_WIDTH; k++) {
         for (int i = 0; i < b_frag[j_a][k][0].num_elements / 2; i++) {
            // Two adjacent slots are processed at once as packed __half2 lanes.
            complex_mul_half2(
               __half2(b_frag[j_a][k][0].x[2 * i], b_frag[j_a][k][0].x[2 * i + 1]),
               __half2(b_frag[j_a][k][1].x[2 * i], b_frag[j_a][k][1].x[2 * i + 1]),
               __half2(k_frag[j_a][k][0].x[2 * i], k_frag[j_a][k][0].x[2 * i + 1]),
               __half2(k_frag[j_a][k][1].x[2 * i], k_frag[j_a][k][1].x[2 * i + 1]),
               &b_frag[j_a][k][0].x[2 * i],
               &b_frag[j_a][k][1].x[2 * i],
               &b_frag[j_a][k][0].x[2 * i + 1],
               &b_frag[j_a][k][1].x[2 * i + 1]
            );
         }
      }
   }

   // Complex A * (B .* K) into acc_frag_1; b_real/b_imag are forwarded as the pointer pair.
   _complex_matmul<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(b_real, b_imag, sqrt_N, N, a_frag, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c(
    const half *a_real_input,
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Real-to-complex variant: A is loaded from the single a_real_input pointer
   // (no imaginary input), and the product is handed the a_real/a_imag pair.
   using a_fragment_t = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout>;
   a_fragment_t lhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag_r2c<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real_input, sqrt_N, N, acc_frag_1, lhs);
   _complex_matmul_r2c<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real, a_imag, sqrt_N, N, lhs, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool b_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool b_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_load_b(
    const half *b_real_input,
    half *b_real,
    half *b_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Real-to-complex variant with B built here from the single b_real_input pointer;
   // A comes from the caller. b_real/b_imag are forwarded to the product step.
   using b_fragment_t = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout>;
   b_fragment_t rhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_b_frag_r2c<BLayout, b_trans, MATMUL_WARP_WIDTH, b_frag_from_acc>(b_real_input, sqrt_N, N, acc_frag_1, rhs);
   _complex_matmul_r2c_load_b<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(b_real, b_imag, sqrt_N, N, a_frag, rhs, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_256(
    const half *a_real_input,
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // "_256" flavour of the real-to-complex matmul: identical shape to
   // complex_matmul_r2c but routed through the *_256 load and product helpers
   // (their size-specific behaviour lives elsewhere — confirm there).
   using a_fragment_t = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout>;
   a_fragment_t lhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag_r2c_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real_input, sqrt_N, N, acc_frag_1, lhs);
   _complex_matmul_r2c_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real, a_imag, sqrt_N, N, lhs, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_r2c_1024(
    const half *a_real_input,
    half *a_real,
    half *a_imag,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // "_1024" flavour of the real-to-complex matmul: same structure as
   // complex_matmul_r2c, dispatching to the *_1024 load and product helpers
   // (size-specific details are defined elsewhere — confirm there).
   using a_fragment_t = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout>;
   a_fragment_t lhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag_r2c_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real_input, sqrt_N, N, acc_frag_1, lhs);
   _complex_matmul_r2c_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real, a_imag, sqrt_N, N, lhs, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r(
    half *a_real,
    half *a_imag,
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // Complex-to-real variant: A is read as a complex (a_real, a_imag) pair and the
   // product step only receives the single a_real_out pointer.
   using a_fragment_t = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout>;
   a_fragment_t lhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, lhs);
   _complex_matmul_c2r<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real_out, sqrt_N, N, lhs, b_frag, acc_frag_1, out_layout);
}

template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_256(
    half *a_real,
    half *a_imag,
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   // "_256" complex-to-real matmul: complex A from (a_real, a_imag) via
   // load_a_frag_256, real result routed to a_real_out by _complex_matmul_c2r_256.
   using a_fragment_t = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout>;
   a_fragment_t lhs[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];

   load_a_frag_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, lhs);
   _complex_matmul_c2r_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real_out, sqrt_N, N, lhs, b_frag, acc_frag_1, out_layout);
}

// "_256" complex-to-real matmul with an element-wise complex scale: A is loaded
// via load_a_frag_256, multiplied element-wise by k_frag, then multiplied by B
// with the real result handed to a_real_out.
// Fragment arrays are indexed [tile_row][tile_col][0 = real, 1 = imag].
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_256(
    half *a_real,
    half *a_imag,
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
   load_a_frag_256<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, a_frag);

   // a_frag *= k_frag, complex and element-wise. Both arrays share the exact
   // fragment type, so storage slot t of one corresponds to slot t of the other
   // even though the fragment-to-matrix mapping is unspecified. The loop is bounded
   // by the matrix_a fragment's own num_elements: accumulator and operand fragments
   // are not guaranteed to hold the same number of elements, so using the
   // accumulator's count can leave part of each A fragment un-multiplied.
   for (int j_a = 0; j_a < MATMUL_WARP_WIDTH; j_a++) {
      for (int k = 0; k < MATMUL_WARP_WIDTH; k++) {
         for (int i = 0; i < a_frag[j_a][k][0].num_elements / 2; i++) {
            // Two adjacent slots are processed at once as packed __half2 lanes.
            complex_mul_half2(
               __half2(a_frag[j_a][k][0].x[2 * i], a_frag[j_a][k][0].x[2 * i + 1]),
               __half2(a_frag[j_a][k][1].x[2 * i], a_frag[j_a][k][1].x[2 * i + 1]),
               __half2(k_frag[j_a][k][0].x[2 * i], k_frag[j_a][k][0].x[2 * i + 1]),
               __half2(k_frag[j_a][k][1].x[2 * i], k_frag[j_a][k][1].x[2 * i + 1]),
               &a_frag[j_a][k][0].x[2 * i],
               &a_frag[j_a][k][1].x[2 * i],
               &a_frag[j_a][k][0].x[2 * i + 1],
               &a_frag[j_a][k][1].x[2 * i + 1]
            );
         }
      }
   }

   _complex_matmul_c2r_256<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real_out, sqrt_N, N, a_frag, b_frag, acc_frag_1, out_layout);
}

// "_1024" complex-to-real matmul with an element-wise complex scale: A is loaded
// via load_a_frag_1024, multiplied element-wise by k_frag, then multiplied by B
// with the real result handed to a_real_out.
// Fragment arrays are indexed [tile_row][tile_col][0 = real, 1 = imag].
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r_1024(
    half *a_real,
    half *a_imag,
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
   load_a_frag_1024<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, a_frag);

   // a_frag *= k_frag, complex and element-wise. Both arrays share the exact
   // fragment type, so storage slot t of one corresponds to slot t of the other
   // even though the fragment-to-matrix mapping is unspecified. The loop is bounded
   // by the matrix_a fragment's own num_elements: accumulator and operand fragments
   // are not guaranteed to hold the same number of elements, so using the
   // accumulator's count can leave part of each A fragment un-multiplied.
   for (int j_a = 0; j_a < MATMUL_WARP_WIDTH; j_a++) {
      for (int k = 0; k < MATMUL_WARP_WIDTH; k++) {
         for (int i = 0; i < a_frag[j_a][k][0].num_elements / 2; i++) {
            // Two adjacent slots are processed at once as packed __half2 lanes.
            complex_mul_half2(
               __half2(a_frag[j_a][k][0].x[2 * i], a_frag[j_a][k][0].x[2 * i + 1]),
               __half2(a_frag[j_a][k][1].x[2 * i], a_frag[j_a][k][1].x[2 * i + 1]),
               __half2(k_frag[j_a][k][0].x[2 * i], k_frag[j_a][k][0].x[2 * i + 1]),
               __half2(k_frag[j_a][k][1].x[2 * i], k_frag[j_a][k][1].x[2 * i + 1]),
               &a_frag[j_a][k][0].x[2 * i],
               &a_frag[j_a][k][1].x[2 * i],
               &a_frag[j_a][k][0].x[2 * i + 1],
               &a_frag[j_a][k][1].x[2 * i + 1]
            );
         }
      }
   }

   _complex_matmul_c2r_1024<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real_out, sqrt_N, N, a_frag, b_frag, acc_frag_1, out_layout);
}

// Complex-to-real matmul with an element-wise complex scale: A is loaded from
// (a_real, a_imag), multiplied element-wise by k_frag, then multiplied by B with
// the real result handed to a_real_out.
// Fragment arrays are indexed [tile_row][tile_col][0 = real, 1 = imag].
template <typename ALayout, typename BLayout, bool a_trans, bool out_trans, int MATMUL_WARP_WIDTH, bool a_frag_from_acc, bool output_to_shmem>
__device__ __forceinline__ void complex_matmul_c2r(
    half *a_real,
    half *a_imag,
    half *a_real_out,
    int sqrt_N,
    int N,
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_K, WMMA_N, half, BLayout> b_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_K, WMMA_N, half> acc_frag_1[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> k_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2],
    wmma::layout_t out_layout = wmma::mem_row_major)
{
   wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_K, WMMA_N, half, ALayout> a_frag[MATMUL_WARP_WIDTH][MATMUL_WARP_WIDTH][2];
   load_a_frag<ALayout, a_trans, MATMUL_WARP_WIDTH, a_frag_from_acc>(a_real, a_imag, sqrt_N, N, acc_frag_1, a_frag);

   // a_frag *= k_frag, complex and element-wise. Both arrays share the exact
   // fragment type, so storage slot t of one corresponds to slot t of the other
   // even though the fragment-to-matrix mapping is unspecified. The loop is bounded
   // by the matrix_a fragment's own num_elements: accumulator and operand fragments
   // are not guaranteed to hold the same number of elements, so using the
   // accumulator's count can leave part of each A fragment un-multiplied.
   for (int j_a = 0; j_a < MATMUL_WARP_WIDTH; j_a++) {
      for (int k = 0; k < MATMUL_WARP_WIDTH; k++) {
         for (int i = 0; i < a_frag[j_a][k][0].num_elements / 2; i++) {
            // Two adjacent slots are processed at once as packed __half2 lanes.
            complex_mul_half2(
               __half2(a_frag[j_a][k][0].x[2 * i], a_frag[j_a][k][0].x[2 * i + 1]),
               __half2(a_frag[j_a][k][1].x[2 * i], a_frag[j_a][k][1].x[2 * i + 1]),
               __half2(k_frag[j_a][k][0].x[2 * i], k_frag[j_a][k][0].x[2 * i + 1]),
               __half2(k_frag[j_a][k][1].x[2 * i], k_frag[j_a][k][1].x[2 * i + 1]),
               &a_frag[j_a][k][0].x[2 * i],
               &a_frag[j_a][k][1].x[2 * i],
               &a_frag[j_a][k][0].x[2 * i + 1],
               &a_frag[j_a][k][1].x[2 * i + 1]
            );
         }
      }
   }

   _complex_matmul_c2r<ALayout, BLayout, out_trans, MATMUL_WARP_WIDTH, output_to_shmem>(a_real_out, sqrt_N, N, a_frag, b_frag, acc_frag_1, out_layout);
}

// Scalar complex product c = a * b on at::Half operands.
//   real: a_r*b_r - a_i*b_i  (evaluated with at::Half operators)
//   imag: fma(a_i, b_r, a_r*b_i)  (the a_r*b_i term is formed first, then fused)
__device__ __forceinline__ void complex_mul(at::Half a_real, at::Half a_imag, at::Half b_real, at::Half b_imag, at::Half *c_real, at::Half *c_imag) {
   const __half re = __half(a_real * b_real - a_imag * b_imag);
   const __half cross = __half(a_real * b_imag);
   const __half im = __hfma(__half(a_imag), __half(b_real), cross);

   *c_real = re;
   *c_imag = im;
}

// Complex product c = a * b where a arrives as two floats and b as at::Half.
// a is narrowed to at::Half first; the arithmetic then matches complex_mul:
//   real: a_r*b_r - a_i*b_i,  imag: fma(a_i, b_r, a_r*b_i).
__device__ __forceinline__ void complex_mul_float_half(float a_real, float a_imag, at::Half b_real, at::Half b_imag, at::Half *c_real, at::Half *c_imag) {
   const at::Half ar = at::Half(a_real);
   const at::Half ai = at::Half(a_imag);

   const __half re = __half(ar * b_real - ai * b_imag);
   const __half im = __hfma(__half(ai), __half(b_real), __half(ar * b_imag));

   *c_real = re;
   *c_imag = im;
}

// Two complex products at once, one per __half2 lane, using split real/imag vectors:
//   (a_r + i a_i)(b_r + i b_i) = (a_r b_r - a_i b_i) + i (a_i b_r + a_r b_i)
__device__ __forceinline__ void complex_mul_half2(__half2 a_real, __half2 a_imag, __half2 b_real, __half2 b_imag, __half2 *c_real, __half2 *c_imag) {
   const __half2 ac = __hmul2(a_real, b_real);
   const __half2 bd = __hmul2(a_imag, b_imag);
   const __half2 ad = __hmul2(a_real, b_imag);

   const __half2 re = __hsub2(ac, bd);
   const __half2 im = __hfma2(a_imag, b_real, ad);

   *c_real = re;
   *c_imag = im;
}

// Two lane-wise complex products (as in the __half2-output overload), with the
// results unpacked into four scalar destinations: lane .x -> *_0, lane .y -> *_1.
__device__ __forceinline__ void complex_mul_half2(__half2 a_real, __half2 a_imag, __half2 b_real, __half2 b_imag, __half *c_real_0, __half *c_imag_0, __half *c_real_1, __half *c_imag_1) {
   const __half2 ac = __hmul2(a_real, b_real);
   const __half2 bd = __hmul2(a_imag, b_imag);
   const __half2 ad = __hmul2(a_real, b_imag);

   const __half2 re = __hsub2(ac, bd);
   const __half2 im = __hfma2(a_imag, b_real, ad);

   *c_real_0 = re.x;
   *c_imag_0 = im.x;
   *c_real_1 = re.y;
   *c_imag_1 = im.y;
}

// Lane-wise a * conj(b) (b_imag is effectively negated) on two packed complex values:
//   (a_r + i a_i)(b_r - i b_i) = (a_r b_r + a_i b_i) + i (a_i b_r - a_r b_i)
// Lane .x is written to *c_0 and lane .y to *c_1.
__device__ __forceinline__ void complex_mul_conj_half2(__half2 a_real, __half2 a_imag, __half2 b_real, __half2 b_imag, c10::complex<__half> *c_0, c10::complex<__half> *c_1) {
   const __half2 bd = __hmul2(a_imag, b_imag);
   const __half2 re = __hfma2(a_real, b_real, bd);
   const __half2 im = __hsub2(__hmul2(a_imag, b_real), __hmul2(a_real, b_imag));

   *c_0 = c10::complex<__half>(re.x, im.x);
   *c_1 = c10::complex<__half>(re.y, im.y);
}

// Lane-wise a * conj(b) (b_imag is effectively negated), where b is given as two
// separate complex scalars. They are packed into real/imag __half2 vectors
// (b_0 -> lane .x, b_1 -> lane .y) and forwarded to the __half2-operand overload.
__device__ __forceinline__ void complex_mul_conj_half2(__half2 a_real, __half2 a_imag, c10::complex<__half> b_0, c10::complex<__half> b_1, c10::complex<__half> *c_0, c10::complex<__half> *c_1) {
   complex_mul_conj_half2(
      a_real,
      a_imag,
      __half2(b_0.real(), b_1.real()),
      __half2(b_0.imag(), b_1.imag()),
      c_0,
      c_1);
}

// Lane-wise a * conj(b) with split real/imag __half2 inputs and outputs:
//   real = a_r b_r + a_i b_i,  imag = a_i b_r - a_r b_i
__device__ __forceinline__ void complex_mul_conj_half2(__half2 a_real, __half2 a_imag, __half2 b_real, __half2 b_imag, __half2 *c_real, __half2 *c_imag) {
   const __half2 bd = __hmul2(a_imag, b_imag);
   const __half2 re = __hfma2(a_real, b_real, bd);

   const __half2 bc = __hmul2(a_imag, b_real);
   const __half2 ad = __hmul2(a_real, b_imag);
   const __half2 im = __hsub2(bc, ad);

   *c_real = re;
   *c_imag = im;
}

#endif