import torch
import pytest

from torchtitan.modules.fused_ffn import fused_swiglu_ffn_forward


def reference_ffn(x, w1, w3, w2):
    """Return the unfused SwiGLU feed-forward ``(silu(x @ w1) * (x @ w3)) @ w2``.

    Serves as the ground truth for the fused kernel. Only plain PyTorch ops are
    used, so every intermediate is rounded in whatever dtype the inputs carry.

    Args:
        x: input activations; its last dimension must match the first
            dimension of ``w1`` and ``w3``.
        w1: gate projection; ``x @ w1`` is passed through SiLU.
        w3: up projection; ``x @ w3`` is multiplied elementwise with the gate.
        w2: down projection applied to the gated product.

    Returns:
        The tensor ``(silu(x @ w1) * (x @ w3)) @ w2``.
    """
    gate = torch.nn.functional.silu(torch.matmul(x, w1))
    up = torch.matmul(x, w3)
    return torch.matmul(gate * up, w2)


@pytest.mark.cuda
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@torch.no_grad()
def test_fused_ffn_matches_reference():
    """Check ``fused_swiglu_ffn_forward`` against the unfused ``reference_ffn``.

    The fused kernel runs on bf16 CUDA inputs. The reference is evaluated in
    float32 on the same values (bf16 -> fp32 is exact), so it does not inherit
    bf16 rounding of the intermediate products. The comparison is the max
    absolute error normalized by the reference's max magnitude, which keeps the
    threshold meaningful regardless of output scale.
    """
    torch.manual_seed(0)
    B, S, K = 2, 16, 128
    H = 256

    x = torch.randn(B, S, K, device="cuda", dtype=torch.bfloat16)
    # Scale weights by 1/sqrt(fan_in) so activations stay O(1). With
    # unit-variance weights the outputs are on the order of 1e3, where bf16
    # spacing alone is several units and no small absolute tolerance can hold.
    w1 = (torch.randn(K, H, device="cuda") * K**-0.5).to(torch.bfloat16)
    w3 = (torch.randn(K, H, device="cuda") * K**-0.5).to(torch.bfloat16)
    w2 = (torch.randn(H, K, device="cuda") * H**-0.5).to(torch.bfloat16)

    # Upcast the bf16-exact inputs so the ground truth is computed in fp32.
    y_ref = reference_ffn(
        x.float().reshape(-1, K), w1.float(), w3.float(), w2.float()
    ).reshape(B, S, K)

    y = fused_swiglu_ffn_forward(x, w1, w3, w2)
    assert y.shape == y_ref.shape, (
        f"shape {tuple(y.shape)} != expected {tuple(y_ref.shape)}"
    )

    max_abs = (y.float() - y_ref).abs().max().item()
    ref_scale = y_ref.abs().max().item()
    rel_err = max_abs / ref_scale
    assert rel_err < 2e-2, (
        f"max_abs={max_abs} (rel={rel_err}) too large vs ref_scale={ref_scale}"
    )