import math

import torch
import torch.nn.functional as F

from einops import rearrange

from monarch_cpp import monarch_conv_forward, monarch_conv_backward, \
    monarch_conv_forward_16_16_16, monarch_conv_backward_16_16_16, \
    monarch_conv_forward_32_16_16, monarch_conv_backward_32_16_16, \
    monarch_conv_forward_16_32_32, monarch_conv_backward_16_32_32, \
    monarch_conv_forward_32_32_32, monarch_conv_backward_32_32_32, \
    monarch_conv_forward_32_32_32_complex, monarch_conv_backward_32_32_32_complex
from butterfly_cuda import butterfly_forward, butterfly_ifft_forward

def fft_matrix(N):
    """Build the dense N x N forward DFT matrix.

    Entry (k, n) of the result is ``exp(-2j * pi * n * k / N)``.
    """
    cols = torch.arange(N)
    rows = cols.view(-1, 1)
    # Broadcasting a row vector against a column vector yields the N x N grid.
    phase = -2j * torch.pi * cols * rows / N
    return torch.exp(phase)

def compute_twiddle_factors_fft(n, m):
    """Return the n x m forward twiddle factors.

    Entry (i, j) is ``exp(-2j * pi * i * j / (n * m))``.
    """
    row_idx = torch.arange(n).view(-1, 1)
    col_idx = torch.arange(m)
    total = n * m
    phase = -2j * torch.pi * row_idx * col_idx / total
    return torch.exp(phase)

def ifft_matrix(N):
    """Build the dense N x N inverse DFT matrix, without the 1/N scaling.

    Entry (k, n) of the result is ``exp(2j * pi * n * k / N)``.
    """
    cols = torch.arange(N)
    rows = cols.view(-1, 1)
    # Broadcasting a row vector against a column vector yields the N x N grid.
    phase = 2j * torch.pi * cols * rows / N
    return torch.exp(phase)

def compute_twiddle_factors_ifft(n, m):
    """Return the n x m inverse twiddle factors.

    Entry (i, j) is ``exp(2j * pi * i * j / (n * m))``.
    """
    row_idx = torch.arange(n).view(-1, 1)
    col_idx = torch.arange(m)
    total = n * m
    phase = 2j * torch.pi * row_idx * col_idx / total
    return torch.exp(phase)

def monarch_outer_dft(x, f_sqrt_N_fft, twiddle_factors_fft, sqrt_N):
    """Apply the DFT matrix along the second-to-last axis, then the twiddles.

    ``x`` is transposed so the matrix can be applied on the right, transposed
    back, multiplied elementwise by ``twiddle_factors_fft`` and returned as a
    contiguous tensor. ``sqrt_N`` is accepted but not used.
    """
    transformed = (x.transpose(-1, -2) @ f_sqrt_N_fft).transpose(-1, -2)
    return (transformed * twiddle_factors_fft).contiguous()

def monarch_outer_idft(x, f_sqrt_N_ifft, twiddle_factors_ifft, sqrt_N):
    """Apply the twiddles, then the inverse DFT matrix along the second-to-last axis.

    ``x`` is multiplied elementwise by ``twiddle_factors_ifft``, transposed so
    the matrix can be applied on the right, transposed back and returned as a
    contiguous tensor. ``sqrt_N`` is accepted but not used.
    """
    twiddled = x * twiddle_factors_ifft
    transformed = (twiddled.transpose(-1, -2) @ f_sqrt_N_ifft).transpose(-1, -2)
    return transformed.contiguous()

class MatrixFFTConv(torch.nn.Module):
    """FFT convolution that mixes channels with an h x h matmul per frequency.

    The channel axis is factored as ``(hh, h, h)``; at every frequency bin the
    input blocks are matrix-multiplied with the kernel blocks.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h

    def forward(self, x, k, N):
        """Convolve ``x`` with ``k`` using length-``N`` FFTs.

        The result is truncated to the input's last-axis length, reduced to
        its real part and cast back to the input dtype.
        """
        block = self.h
        seq_len = x.shape[-1]
        in_dtype = x.dtype

        x_spec = torch.fft.fft(x.float(), n = N)
        k_spec = torch.fft.fft(k, n = N)

        # Move frequency ahead of the (hh, h1, h2) blocks so matmul acts on
        # the trailing h1 x h2 matrices.
        x_blocks = rearrange(x_spec, 'b (hh h1 h2) l -> b l hh h1 h2', h1=block, h2=block)
        k_blocks = rearrange(k_spec, '(hh h1 h2) l -> l hh h1 h2', h1=block, h2=block)
        y_blocks = torch.matmul(x_blocks, k_blocks)

        y_spec = rearrange(y_blocks, 'b l hh h1 h2 -> b (hh h1 h2) l', h1=block, h2=block)
        y_time = torch.fft.ifft(y_spec, n = N)
        return y_time[..., :seq_len].real.to(in_dtype)
    
class PartialFFTConv(torch.nn.Module):
    """FFT convolution that uses only the first ``N_partial`` kernel taps."""

    def __init__(self, N_partial):
        super().__init__()
        self.N_partial = N_partial

    def forward(self, x, k):
        """Convolve ``x`` with the truncated kernel ``k[..., :N_partial]``.

        Both operands are transformed with a real FFT of twice the input
        length; the product is inverted, cut back to the input length and
        cast to the input dtype.
        """
        seq_len = x.shape[-1]
        fft_len = 2 * seq_len
        in_dtype = x.dtype

        kernel_head = k[..., :self.N_partial]
        x_spec = torch.fft.rfft(x.float(), n = fft_len)
        k_spec = torch.fft.rfft(kernel_head, n = fft_len)

        y_time = torch.fft.irfft(x_spec * k_spec, n = fft_len)
        return y_time[..., :seq_len].to(in_dtype)
    
class LowPassFFTConv(torch.nn.Module):
    """FFT convolution that zeroes the kernel spectrum from bin ``N_partial // 2`` on."""

    def __init__(self, N_partial):
        super().__init__()
        self.N_partial = N_partial

    def forward(self, x, k):
        """Convolve ``x`` with ``k`` after discarding the kernel's upper bins.

        Both operands are transformed with a real FFT of twice the input
        length; kernel bins at index ``N_partial // 2`` and above are set to
        zero before the product is inverted, cut back to the input length and
        cast to the input dtype.
        """
        seq_len = x.shape[-1]
        fft_len = 2 * seq_len
        in_dtype = x.dtype

        x_spec = torch.fft.rfft(x.float(), n = fft_len)
        k_spec = torch.fft.rfft(k, n = fft_len)

        # k_spec is a fresh tensor, so zeroing it in place leaves k untouched.
        cutoff = self.N_partial // 2
        k_spec[..., cutoff:] = 0

        y_time = torch.fft.irfft(x_spec * k_spec, n = fft_len)
        return y_time[..., :seq_len].to(in_dtype)

class FlashFFTConv(torch.nn.Module):
    """Module that precomputes the constants needed by ``FlashFFTConvFunc``.

    For a fixed FFT size ``seqlen`` the constructor builds small DFT / inverse
    DFT matrices and twiddle-factor tables, converts them to ``dtype`` and
    registers them as buffers. Supported sizes are 256, 1024, 4096, 8192,
    16384, 32768, ``32 * 32768`` and ``64 * 32768``; any other value raises
    ``NotImplementedError``. ``dtype`` must be ``torch.float16`` or
    ``torch.bfloat16``.
    """

    def __init__(self, seqlen, dtype=torch.float16):
        super().__init__()
        assert dtype in (torch.bfloat16, torch.float16)
        self.seqlen = seqlen
        self.dtype = dtype

        def as_real(complex_tensor):
            # Complex tensor -> trailing (real, imag) axis, in the target dtype.
            return torch.view_as_real(complex_tensor).to(dtype)

        def component(part):
            # Real or imaginary plane stored as its own contiguous buffer.
            return part.to(dtype).contiguous()

        def register_all(**buffers):
            # Keyword order is preserved, so buffers register in listed order.
            for name, tensor in buffers.items():
                self.register_buffer(name, tensor)

        if seqlen in (256, 1024):
            self.N = seqlen
            self.sqrt_N = int(math.sqrt(seqlen))
            side = self.sqrt_N
            register_all(
                f_sqrt_N_fft=as_real(fft_matrix(side)),
                f_sqrt_N_ifft=as_real(ifft_matrix(side)),
                twiddle_factors_fft=as_real(compute_twiddle_factors_fft(side, side) / seqlen),
                twiddle_factors_ifft=as_real(compute_twiddle_factors_ifft(side, side)),
            )
        elif seqlen == 4096:
            self.N = seqlen
            self.sqrt_N = 16
            self.sqrt_N_256 = 256
            register_all(
                f_sqrt_N_fft=as_real(fft_matrix(16)),
                f_sqrt_N_ifft=as_real(ifft_matrix(16)),
                twiddle_factors_fft_16_16=as_real(compute_twiddle_factors_fft(16, 16)),
                twiddle_factors_ifft_16_16=as_real(compute_twiddle_factors_ifft(16, 16)),
                twiddle_factors_fft_16_256=as_real(compute_twiddle_factors_fft(16, 256) / seqlen),
                twiddle_factors_ifft_16_256=as_real(compute_twiddle_factors_ifft(16, 256)),
            )
        elif seqlen == 8192:
            self.N = seqlen
            self.N1 = 32
            self.N2 = 16
            register_all(
                f_32_fft=as_real(fft_matrix(32)),
                f_32_ifft=as_real(ifft_matrix(32)),
                f_16_fft=as_real(fft_matrix(16)),
                f_16_ifft=as_real(ifft_matrix(16)),
                twiddle_factors_fft_16_16=as_real(compute_twiddle_factors_fft(16, 16)),
                twiddle_factors_ifft_16_16=as_real(compute_twiddle_factors_ifft(16, 16)),
                twiddle_factors_fft_32_256=as_real(compute_twiddle_factors_fft(32, 256) / seqlen),
                twiddle_factors_ifft_32_256=as_real(compute_twiddle_factors_ifft(32, 256)),
            )
        elif seqlen == 16384:
            self.N = seqlen
            self.N1 = 16
            self.N2 = 32
            register_all(
                f_16_fft=as_real(fft_matrix(16)),
                f_16_ifft=as_real(ifft_matrix(16)),
                f_32_fft=as_real(fft_matrix(32)),
                f_32_ifft=as_real(ifft_matrix(32)),
                twiddle_factors_fft_32_32=as_real(compute_twiddle_factors_fft(32, 32)),
                twiddle_factors_ifft_32_32=as_real(compute_twiddle_factors_ifft(32, 32)),
                twiddle_factors_fft_16_1K=as_real(compute_twiddle_factors_fft(16, 1024) / seqlen),
                twiddle_factors_ifft_16_1K=as_real(compute_twiddle_factors_ifft(16, 1024)),
            )
        elif seqlen == 32768:
            self.N = seqlen
            self.N1 = 32
            self.N2 = 32
            register_all(
                f_32_fft=as_real(fft_matrix(32)),
                f_32_ifft=as_real(ifft_matrix(32)),
                twiddle_factors_fft_32_32=as_real(compute_twiddle_factors_fft(32, 32)),
                twiddle_factors_ifft_32_32=as_real(compute_twiddle_factors_ifft(32, 32)),
                twiddle_factors_fft_32_1K=as_real(compute_twiddle_factors_fft(32, 1024) / seqlen),
                twiddle_factors_ifft_32_1K=as_real(compute_twiddle_factors_ifft(32, 1024)),
            )
        elif seqlen == 32 * 32768: #1M
            self.N = seqlen
            # Outer 32-point stage; kept as separate real / imaginary buffers.
            outer_fft = compute_twiddle_factors_fft(32, 32768) / 32
            outer_ifft = compute_twiddle_factors_ifft(32, 32768)
            register_all(
                f_32_fft=as_real(fft_matrix(32)),
                f_32_ifft=as_real(ifft_matrix(32)),
                twiddle_factors_fft_32_32=as_real(compute_twiddle_factors_fft(32, 32)),
                twiddle_factors_ifft_32_32=as_real(compute_twiddle_factors_ifft(32, 32)),
                twiddle_factors_fft_32_1K=as_real(compute_twiddle_factors_fft(32, 1024) / 32768),
                twiddle_factors_ifft_32_1K=as_real(compute_twiddle_factors_ifft(32, 1024)),
                twiddle_factors_fft_real=component(outer_fft.real),
                twiddle_factors_ifft_real=component(outer_ifft.real),
                twiddle_factors_fft_imag=component(outer_fft.imag),
                twiddle_factors_ifft_imag=component(outer_ifft.imag),
            )
        elif seqlen == 64 * 32768: #2M
            self.N = seqlen
            # Outer 64-point stage; kept as separate real / imaginary buffers.
            outer_fft = compute_twiddle_factors_fft(64, 32768) / 64
            outer_ifft = compute_twiddle_factors_ifft(64, 32768)
            register_all(
                f_32_fft=as_real(fft_matrix(32)),
                f_32_ifft=as_real(ifft_matrix(32)),
                f_64_fft=as_real(fft_matrix(64)),
                f_64_ifft=as_real(ifft_matrix(64)),
                twiddle_factors_fft_32_32=as_real(compute_twiddle_factors_fft(32, 32)),
                twiddle_factors_ifft_32_32=as_real(compute_twiddle_factors_ifft(32, 32)),
                twiddle_factors_fft_32_1K=as_real(compute_twiddle_factors_fft(32, 1024) / 32768),
                twiddle_factors_ifft_32_1K=as_real(compute_twiddle_factors_ifft(32, 1024)),
                twiddle_factors_fft_real=component(outer_fft.real),
                twiddle_factors_ifft_real=component(outer_ifft.real),
                twiddle_factors_fft_imag=component(outer_fft.imag),
                twiddle_factors_ifft_imag=component(outer_ifft.imag),
            )
        else:
            raise NotImplementedError(f'seqlen {seqlen} not supported')

    def forward(self, u, k):
        """Run the convolution of ``u`` with kernel ``k`` via ``FlashFFTConvFunc``.

        The inputs are passed through unchanged; this module itself is handed
        over as the container of the precomputed buffers.
        """
        return FlashFFTConvFunc.apply(u, k, self)


class FlashFFTConvFunc(torch.autograd.Function):
    """Autograd function backing ``FlashFFTConv``.

    Both passes dispatch on ``fftconv_data.seqlen`` to the extension kernel
    written for that FFT size (``monarch_conv_*`` from ``monarch_cpp``, plus
    ``butterfly_*`` from ``butterfly_cuda`` for the two largest sizes). The
    kernel is transformed with ``torch.fft.fft`` and permuted into the layout
    each kernel expects; the backward pass undoes that permutation before
    mapping the kernel gradient back to the time domain.
    """

    @staticmethod
    def forward(ctx, u, k, fftconv_data):
        """Compute the FFT convolution of ``u`` with ``k``.

        Args:
            ctx: autograd context; receives ``fftconv_data``, the kernel
                length and the tensors needed by ``backward``.
            u: input, unpacked as ``(B, H, L)``.
            k: convolution kernel; its length-``seqlen`` FFT is reshaped with
                ``H`` as the leading axis.
            fftconv_data: object exposing ``seqlen``, ``dtype`` and the
                precomputed buffers (a ``FlashFFTConv`` instance).

        Returns:
            The kernel output. For the two largest sizes ``u`` is zero-padded
            to ``N`` first and the result is truncated back to ``L``.

        Raises:
            NotImplementedError: if ``fftconv_data.seqlen`` is not supported.
        """
        B, H, L = u.shape

        # replace this with a kernel
        k_f = torch.fft.fft(k, n=fftconv_data.seqlen)

        ctx.fftconv_data = fftconv_data
        # Remember the kernel's own length so backward can return a gradient
        # shaped like k even when k and u differ in length.
        ctx.k_len = k.shape[-1]

        if fftconv_data.seqlen in [256, 1024]:
            N = fftconv_data.N
            sqrt_N = fftconv_data.sqrt_N

            k_f_permuted = torch.view_as_real(k_f.reshape(H, sqrt_N, sqrt_N).transpose(-1, -2).reshape(H, N)).to(fftconv_data.dtype).contiguous()

            ctx.save_for_backward(u, k_f_permuted)

            return monarch_conv_forward(
                u, k_f_permuted,
                fftconv_data.f_sqrt_N_fft, fftconv_data.twiddle_factors_fft,
                fftconv_data.f_sqrt_N_ifft, fftconv_data.twiddle_factors_ifft,
                N, L, sqrt_N
            )
        elif fftconv_data.seqlen == 4096:
            N = fftconv_data.N
            sqrt_N = fftconv_data.sqrt_N
            sqrt_N_256 = fftconv_data.sqrt_N_256

            k_f_permuted = torch.view_as_real(k_f.reshape(H, sqrt_N_256, sqrt_N).transpose(-1, -2).reshape(H, sqrt_N, sqrt_N, sqrt_N).transpose(-1, -2).reshape(H, N)).to(fftconv_data.dtype).contiguous()

            ctx.save_for_backward(u, k_f_permuted)

            out = monarch_conv_forward_16_16_16(
                u, k_f_permuted,
                fftconv_data.f_sqrt_N_fft,
                fftconv_data.twiddle_factors_fft_16_256, fftconv_data.twiddle_factors_fft_16_16,
                fftconv_data.f_sqrt_N_ifft,
                fftconv_data.twiddle_factors_ifft_16_256, fftconv_data.twiddle_factors_ifft_16_16,
                N, L, sqrt_N_256, sqrt_N
            )

            return out
        elif fftconv_data.seqlen == 8192:
            N = fftconv_data.N

            k_f_permuted = torch.view_as_real(k_f.reshape(H, 256, 32).transpose(-1, -2).reshape(H, 32, 16, 16).transpose(-1, -2).reshape(H, N)).to(fftconv_data.dtype).contiguous()

            ctx.save_for_backward(u, k_f_permuted)

            return monarch_conv_forward_32_16_16(
                u, k_f_permuted,
                fftconv_data.f_32_fft, fftconv_data.f_16_fft,
                fftconv_data.twiddle_factors_fft_32_256, fftconv_data.twiddle_factors_fft_16_16,
                fftconv_data.f_32_ifft, fftconv_data.f_16_ifft,
                fftconv_data.twiddle_factors_ifft_32_256, fftconv_data.twiddle_factors_ifft_16_16,
                N, L
            )
        elif fftconv_data.seqlen == 16384:
            N = fftconv_data.N

            k_f_permuted = torch.view_as_real(k_f.reshape(H, 1024, 16).transpose(-1, -2).reshape(H, 16, 32, 32).transpose(-1, -2).reshape(H, N)).to(fftconv_data.dtype).contiguous()

            ctx.save_for_backward(u, k_f_permuted)

            return monarch_conv_forward_16_32_32(
                u, k_f_permuted,
                fftconv_data.f_16_fft, fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_16_1K, fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_16_ifft, fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_16_1K, fftconv_data.twiddle_factors_ifft_32_32,
                N, L
            )
        elif fftconv_data.seqlen == 32768:
            N = fftconv_data.N

            k_f_permuted = torch.view_as_real(k_f.reshape(H, 1024, 32).transpose(-1, -2).reshape(H, 32, 32, 32).transpose(-1, -2).reshape(H, N)).to(fftconv_data.dtype).contiguous()

            ctx.save_for_backward(u, k_f_permuted)
            return monarch_conv_forward_32_32_32(
                u, k_f_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                N, L
            )
        elif fftconv_data.seqlen == 32 * 32768:
            N = fftconv_data.N

            # Shorter inputs are zero-padded up to the full FFT size.
            if L < N:
                pad_shape = (0, N - L)
                u = F.pad(u, pad_shape, 'constant', 0)

            k_f_permuted = k_f.reshape(H, 32768, 32).transpose(-1, -2).reshape(H, N)
            k_f_double_permuted = torch.view_as_real(k_f_permuted.reshape(H, 32, 1024, 32).transpose(-1, -2).reshape(H, 32, 32, 32, 32).transpose(-1, -2).reshape(H * 32, 32768)).contiguous().to(fftconv_data.dtype)

            # Always save: backward unpacks ctx.saved_tensors unconditionally,
            # so gradients must also work when the module is in eval mode.
            ctx.save_for_backward(u, k_f_double_permuted)

            x = u.reshape(B, H, 32, 32768)
            x_half_real, x_half_imag = butterfly_forward(
                x,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )

            x_half_real = x_half_real.reshape(B, H * 32, 32768)
            x_half_imag = x_half_imag.reshape(B, H * 32, 32768)

            out_half_real, out_half_imag = monarch_conv_forward_32_32_32_complex(
                x_half_real, x_half_imag, k_f_double_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                32768, 32768
            )

            out_half_real = out_half_real.reshape(B, H, 32, 32768)
            out_half_imag = out_half_imag.reshape(B, H, 32, 32768)

            out_half = butterfly_ifft_forward(
                out_half_real, out_half_imag,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_real,
                fftconv_data.twiddle_factors_ifft_imag
            )

            x = out_half.reshape(B, H, N)

            return x[..., :L]

        elif fftconv_data.seqlen == 64 * 32768:
            N = fftconv_data.N

            # Shorter inputs are zero-padded up to the full FFT size.
            if L < N:
                pad_shape = (0, N - L)
                u = F.pad(u, pad_shape, 'constant', 0)

            k_f_permuted = k_f.reshape(H, 32768, 64).transpose(-1, -2).reshape(H, N)
            k_f_double_permuted = torch.view_as_real(k_f_permuted.reshape(H, 64, 1024, 32).transpose(-1, -2).reshape(H, 64, 32, 32, 32).transpose(-1, -2).reshape(H * 64, 32768)).contiguous().to(fftconv_data.dtype)

            ctx.save_for_backward(u, k_f_double_permuted)

            x = u.reshape(B, H, 64, 32768)
            x_half_real, x_half_imag = butterfly_forward(
                x,
                fftconv_data.f_64_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )

            x_half_real = x_half_real.reshape(B, H * 64, 32768)
            x_half_imag = x_half_imag.reshape(B, H * 64, 32768)

            out_half_real, out_half_imag = monarch_conv_forward_32_32_32_complex(
                x_half_real, x_half_imag, k_f_double_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                32768, 32768
            )

            out_half_real = out_half_real.reshape(B, H, 64, 32768)
            out_half_imag = out_half_imag.reshape(B, H, 64, 32768)

            out_half = butterfly_ifft_forward(
                out_half_real, out_half_imag,
                fftconv_data.f_64_ifft,
                fftconv_data.twiddle_factors_ifft_real,
                fftconv_data.twiddle_factors_ifft_imag
            )

            x = out_half.reshape(B, H, N)

            return x[..., :L]

        else:
            raise NotImplementedError(f'seqlen {fftconv_data.seqlen} not supported for FlashFFTConv fwd')

    @staticmethod
    def backward(ctx, dout):
        """Compute gradients with respect to ``u`` and ``k``.

        Args:
            ctx: autograd context populated by ``forward``.
            dout: gradient of the output, unpacked as ``(B, H, L)``.

        Returns:
            ``(du, dk, None)``. ``dk`` is the real part of the inverse FFT of
            the un-permuted frequency-domain kernel gradient, truncated to the
            kernel length recorded in ``forward``. ``fftconv_data`` receives
            no gradient.

        Raises:
            NotImplementedError: if ``fftconv_data.seqlen`` is not supported.
        """
        fftconv_data = ctx.fftconv_data

        B, H, L = dout.shape
        dout = dout.contiguous()

        # The second saved tensor is the permuted kernel spectrum; in the two
        # largest branches it is the doubly permuted variant.
        u, k_f_permuted = ctx.saved_tensors
        # Slice the kernel gradient to k's length, not to dout's length.
        k_len = ctx.k_len

        if fftconv_data.seqlen in [256, 1024]:
            N = fftconv_data.N
            sqrt_N = fftconv_data.sqrt_N

            du, dk_f_permuted = monarch_conv_backward(
                dout, u, k_f_permuted,
                fftconv_data.f_sqrt_N_fft, fftconv_data.twiddle_factors_fft,
                fftconv_data.f_sqrt_N_ifft, fftconv_data.twiddle_factors_ifft,
                N, L, sqrt_N
            )
            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, sqrt_N, sqrt_N).transpose(-1, -2).reshape(H, N),
                norm='forward', n=N
            ).real[..., :k_len]

            return du, dk_f, None
        elif fftconv_data.seqlen == 4096:
            N = fftconv_data.N
            sqrt_N = fftconv_data.sqrt_N
            sqrt_N_256 = fftconv_data.sqrt_N_256

            du, dk_f_permuted = monarch_conv_backward_16_16_16(
                dout, u, k_f_permuted,
                fftconv_data.f_sqrt_N_fft,
                fftconv_data.twiddle_factors_fft_16_256, fftconv_data.twiddle_factors_fft_16_16,
                fftconv_data.f_sqrt_N_ifft,
                fftconv_data.twiddle_factors_ifft_16_256, fftconv_data.twiddle_factors_ifft_16_16,
                N, L, sqrt_N_256, sqrt_N
            )

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, sqrt_N, sqrt_N, sqrt_N).transpose(-1, -2).reshape(H, sqrt_N, sqrt_N_256).transpose(-1, -2).reshape(H, N),
                norm='forward', n=N
            ).real[..., :k_len]

            return du, dk_f, None
        elif fftconv_data.seqlen == 8192:
            N = fftconv_data.N

            du, dk_f_permuted = monarch_conv_backward_32_16_16(
                dout, u, k_f_permuted,
                fftconv_data.f_32_fft, fftconv_data.f_16_fft,
                fftconv_data.twiddle_factors_fft_32_256, fftconv_data.twiddle_factors_fft_16_16,
                fftconv_data.f_32_ifft, fftconv_data.f_16_ifft,
                fftconv_data.twiddle_factors_ifft_32_256, fftconv_data.twiddle_factors_ifft_16_16,
                N, L
            )

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, 32, 16, 16).transpose(-1, -2).reshape(H, 32, 256).transpose(-1, -2).reshape(H, N),
                norm='forward', n=N
            ).real[..., :k_len]

            return du, dk_f, None
        elif fftconv_data.seqlen == 16384:
            N = fftconv_data.N

            du, dk_f_permuted = monarch_conv_backward_16_32_32(
                dout, u, k_f_permuted,
                fftconv_data.f_16_fft, fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_16_1K, fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_16_ifft, fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_16_1K, fftconv_data.twiddle_factors_ifft_32_32,
                N, L
            )

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, 16, 32, 32).transpose(-1, -2).reshape(H, 16, 1024).transpose(-1, -2).reshape(H, N),
                norm='forward', n=N
            ).real[..., :k_len]

            return du, dk_f, None
        elif fftconv_data.seqlen == 32768:
            N = fftconv_data.N

            du, dk_f_permuted = monarch_conv_backward_32_32_32(
                dout, u, k_f_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                N, L
            )

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, 32, 32, 32).transpose(-1, -2).reshape(H, 32, 1024).transpose(-1, -2).reshape(H, N),
                norm='forward', n=N
            ).real[..., :k_len]

            return du, dk_f, None
        elif fftconv_data.seqlen == 32 * 32768:
            N = fftconv_data.N

            # forward zero-pads u to N and truncates its output to L; mirror
            # that here by zero-padding dout to N and truncating dx to L.
            # The saved u is already the padded, length-N tensor.
            if L < N:
                dout = F.pad(dout, (0, N - L), 'constant', 0)

            x = u.reshape(B, H, 32, 32768)
            dout = dout.reshape(B, H, 32, 32768)

            x_half_real, x_half_imag = butterfly_forward(
                x,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )
            x_half_real = x_half_real.reshape(B, H * 32, 32768)
            x_half_imag = x_half_imag.reshape(B, H * 32, 32768)

            dout_half_real, dout_half_imag = butterfly_forward(
                dout,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )
            dout_half_real = dout_half_real.reshape(B, H * 32, 32768)
            dout_half_imag = dout_half_imag.reshape(B, H * 32, 32768)

            dx_half_real, dx_half_imag, dk_f_permuted = monarch_conv_backward_32_32_32_complex(
                dout_half_real, dout_half_imag,
                x_half_real, x_half_imag, k_f_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                32768, 32768
            )

            dx_half_real = dx_half_real.reshape(B, H, 32, 32768)
            dx_half_imag = dx_half_imag.reshape(B, H, 32, 32768)

            dx_half = butterfly_ifft_forward(
                dx_half_real, dx_half_imag,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_real,
                fftconv_data.twiddle_factors_ifft_imag
            )

            dx = dx_half.reshape(B, H, N)[..., :L]

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, 32, 32, 32, 32).transpose(-1, -2).reshape(H, 32, 32, 1024).transpose(-1, -2).reshape(H, 32, 32768).transpose(-1, -2).reshape(H, N) * 32,
                norm='forward', n=N
            ).real[..., :k_len]

            return dx, dk_f, None
        elif fftconv_data.seqlen == 64 * 32768:
            N = fftconv_data.N

            # forward zero-pads u to N and truncates its output to L; mirror
            # that here by zero-padding dout to N and truncating dx to L.
            # The saved u is already the padded, length-N tensor.
            if L < N:
                dout = F.pad(dout, (0, N - L), 'constant', 0)

            x = u.reshape(B, H, 64, 32768)
            dout = dout.reshape(B, H, 64, 32768)

            x_half_real, x_half_imag = butterfly_forward(
                x,
                fftconv_data.f_64_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )
            x_half_real = x_half_real.reshape(B, H * 64, 32768)
            x_half_imag = x_half_imag.reshape(B, H * 64, 32768)

            dout_half_real, dout_half_imag = butterfly_forward(
                dout,
                fftconv_data.f_64_fft,
                fftconv_data.twiddle_factors_fft_real,
                fftconv_data.twiddle_factors_fft_imag
            )
            dout_half_real = dout_half_real.reshape(B, H * 64, 32768)
            dout_half_imag = dout_half_imag.reshape(B, H * 64, 32768)

            dx_half_real, dx_half_imag, dk_f_permuted = monarch_conv_backward_32_32_32_complex(
                dout_half_real, dout_half_imag,
                x_half_real, x_half_imag, k_f_permuted,
                fftconv_data.f_32_fft,
                fftconv_data.twiddle_factors_fft_32_1K,
                fftconv_data.twiddle_factors_fft_32_32,
                fftconv_data.f_32_ifft,
                fftconv_data.twiddle_factors_ifft_32_1K,
                fftconv_data.twiddle_factors_ifft_32_32,
                32768, 32768
            )

            dx_half_real = dx_half_real.reshape(B, H, 64, 32768)
            dx_half_imag = dx_half_imag.reshape(B, H, 64, 32768)

            dx_half = butterfly_ifft_forward(
                dx_half_real, dx_half_imag,
                fftconv_data.f_64_ifft,
                fftconv_data.twiddle_factors_ifft_real,
                fftconv_data.twiddle_factors_ifft_imag
            )

            dx = dx_half.reshape(B, H, N)[..., :L]

            dk_f = torch.fft.ifft(
                torch.view_as_complex(dk_f_permuted.to(torch.float32)).reshape(H, 64, 32, 32, 32).transpose(-1, -2).reshape(H, 64, 32, 1024).transpose(-1, -2).reshape(H, 64, 32768).transpose(-1, -2).reshape(H, N) * 64,
                norm='forward', n=N
            ).real[..., :k_len]

            return dx, dk_f, None
        else:
            raise NotImplementedError(f'seqlen {fftconv_data.seqlen} not supported for FlashFFTConv bwd')
