import torch
from einops import rearrange
import pytest


class FFTConvFuncv2(torch.autograd.Function):
    """FFT-based convolution along the last dimension with a hand-written backward.

    The forward pass zero-pads both operands to ``2 * seqlen``, multiplies
    their real FFTs and keeps the first ``seqlen`` output samples.  For a
    kernel no longer than ``seqlen`` this is the causal linear convolution
    ``y[t] = sum_s u[s] * k[t - s]``.

    NOTE(review): callers in this file pass ``u`` as (batch, hdim, seqlen) and
    ``k`` as (1, hdim, seqlen); other layouts only need to broadcast.
    """

    @staticmethod
    def forward(ctx, u, k):
        """Convolve ``u`` with ``k``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            u: input signal; the last dimension is the sequence.
            k: kernel, broadcastable against ``u``.  When ``u`` has more than
                three dimensions an extra axis is inserted at position 1 of
                ``k`` before broadcasting.

        Returns:
            Tensor in ``k``'s dtype whose last dimension has length ``seqlen``.
        """
        seqlen = u.shape[-1]
        fft_size = 2 * seqlen

        k_was_unsqueezed = u.dim() > 3
        k_b = k.unsqueeze(1) if k_was_unsqueezed else k

        # The kernel spectrum carries the 1/fft_size factor because the
        # inverse transform below (norm="forward") applies no scaling itself.
        k_f = torch.fft.rfft(k_b, n=fft_size) / fft_size
        u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
        y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

        ctx.save_for_backward(u_f, k_f)
        # Shapes/dtypes are needed to hand back gradients that match the inputs.
        ctx.u_shape = u.shape
        ctx.u_dtype = u.dtype
        ctx.k_bcast_shape = k_b.shape
        ctx.k_dtype = k.dtype
        ctx.k_was_unsqueezed = k_was_unsqueezed
        return y

    @staticmethod
    def backward(ctx, dout):
        """Return ``(du, dk)`` for the upstream gradient ``dout``.

        Each gradient is a cross-correlation evaluated in the frequency
        domain, reduced over broadcast dimensions so that it has exactly the
        shape of the corresponding forward input.  A gradient that autograd
        does not need is returned as ``None``.
        """
        u_f, k_f = ctx.saved_tensors
        seqlen = dout.shape[-1]
        fft_size = 2 * seqlen

        dout_f = torch.fft.rfft(dout, n=fft_size)

        du = None
        if ctx.needs_input_grad[0]:
            # k_f already includes 1/fft_size, so no extra scaling is needed.
            du = torch.fft.irfft(dout_f * k_f.conj(), n=fft_size, norm="forward")[
                ..., :seqlen
            ]
            du = du.sum_to_size(ctx.u_shape).to(ctx.u_dtype)

        dk = None
        if ctx.needs_input_grad[1]:
            # u_f and dout_f are both unnormalised and the inverse transform
            # does not scale, so the 1/fft_size factor must be applied here.
            dk = (
                torch.fft.irfft(dout_f * u_f.conj(), n=fft_size, norm="forward")
                / fft_size
            )
            k_len = ctx.k_bcast_shape[-1]
            if k_len <= fft_size:
                dk = dk[..., :k_len]
            else:
                # Kernel samples beyond fft_size were trimmed by rfft in the
                # forward pass and therefore receive zero gradient.
                dk = torch.nn.functional.pad(dk, (0, k_len - fft_size))
            # Sum over every dimension along which k was broadcast.
            dk = dk.sum_to_size(ctx.k_bcast_shape)
            if ctx.k_was_unsqueezed:
                # Undo only the axis inserted in forward; a blanket squeeze()
                # would also drop genuine size-1 dimensions of k.
                dk = dk.squeeze(1)
            dk = dk.to(ctx.k_dtype)

        return du, dk


def fftconv_ref(u, k, k_rev=None):
    """Reference FFT convolution of ``u`` with ``k`` along the last dimension.

    Both operands are transformed with a real FFT of length ``2 * seqlen``;
    the product is inverted and truncated to the first ``seqlen`` samples.

    Args:
        u: input signal; cast to ``k``'s dtype before the transform.
        k: kernel.
        k_rev: optional second kernel whose spectrum is conjugated and added
            to that of ``k`` before the multiplication.

    Returns:
        Tensor whose last dimension has length ``u.shape[-1]``.
    """
    length = u.shape[-1]
    n_fft = 2 * length

    # Kernel spectrum, pre-scaled because the inverse below does not normalise.
    kernel_spec = torch.fft.rfft(k, n=n_fft) / n_fft
    if k_rev is not None:
        rev_spec = torch.fft.rfft(k_rev, n=n_fft) / n_fft
        kernel_spec = kernel_spec + rev_spec.conj()

    # Inputs with more than three dims get an extra broadcast axis on the kernel.
    if u.dim() > 3:
        kernel_spec = kernel_spec.unsqueeze(1)

    signal_spec = torch.fft.rfft(u.to(dtype=k.dtype), n=n_fft)
    full = torch.fft.irfft(signal_spec * kernel_spec, n=n_fft, norm="forward")
    return full[..., :length]


def test_fftconv_bwd():
    """Compare the custom backward's input gradient with autograd and the closed form.

    The scalar being differentiated is the last output sample, whose gradient
    with respect to ``u`` is compared against the time-reversed kernel.
    """
    bsz, hdim, seqlen = 1, 1, 1024

    # Draw order (u first, then k) is kept so seeded runs stay reproducible.
    u = torch.randn(bsz, hdim, seqlen, requires_grad=True)
    k = torch.randn(1, hdim, seqlen)
    flipped_kernel = k[0, 0].flip(0)

    # Gradient obtained by differentiating through the reference implementation.
    y_ref = fftconv_ref(u, k)
    (du_ad,) = torch.autograd.grad(y_ref[0, 0, -1].mean(), u)
    analytical_err_ad = torch.norm(du_ad[0, 0] - flipped_kernel, p=2)

    # Gradient produced by the hand-written backward.
    y_custom = FFTConvFuncv2.apply(u, k)
    (du,) = torch.autograd.grad(y_custom[:, 0, -1].sum(), u)
    analytical_err_custom = torch.norm(du[0, 0] - flipped_kernel, p=2)
    error_ad_custom = torch.norm(du[0, 0] - du_ad[0, 0], p=2)

    assert analytical_err_ad <= 1e-5
    assert analytical_err_custom <= 1e-5
    assert error_ad_custom <= 1e-5
