import torch
import pytest
from eden.quant import transform_had128


def reference_transform(x, h, transpose=False):
    """Pure-PyTorch oracle: apply ``h`` to every consecutive 128-element chunk.

    The (optionally transposed) input is viewed as rows of 128 elements. Each
    row ``v`` is replaced by ``v @ h.T`` (equivalently ``h @ v``), and the
    result is reshaped back to the shape of the (optionally transposed) input.

    Args:
        x: Input tensor whose element count must be a multiple of 128
            (otherwise ``reshape`` raises).
        h: Matrix applied to each chunk; used as ``h.T`` on the right of a
            ``(-1, 128)`` matrix, so it needs 128 columns. For the output to
            reshape back to ``x``'s shape it must also have 128 rows.
        transpose: If true, operate on ``x.T`` instead of ``x``.

    Returns:
        Tensor shaped like ``x.T`` when ``transpose`` is true, else like ``x``.
    """
    source = x.T if transpose else x
    chunks = source.reshape(-1, 128)
    transformed = torch.matmul(chunks, h.T)
    return transformed.reshape(source.shape)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize("shape", [(128, 128), (256, 128), (128, 1024), (1024, 640)])
@pytest.mark.parametrize("transpose", [False, True])
def test_transform(shape, transpose):
    """Compare ``transform_had128`` with ``reference_transform`` on bf16 CUDA data.

    ``h`` is a random 128x128 bf16 matrix (not a Hadamard matrix), so the kernel
    must match the reference for an arbitrary ``h``. Both the plain and the
    ``transpose=True`` paths are checked for every shape.
    """
    # Dedicated seeded generator: any failure is reproducible, and the global
    # RNG state seen by other tests is left untouched.
    gen = torch.Generator(device="cuda").manual_seed(0)
    x = torch.randn(*shape, generator=gen, device="cuda", dtype=torch.bfloat16)
    h = torch.randn(128, 128, generator=gen, device="cuda", dtype=torch.bfloat16)

    expected = reference_transform(x, h, transpose=transpose)
    actual = transform_had128(h=h, x=x, transpose=transpose)

    # torch.allclose's defaults (rtol=1e-5, atol=1e-8) are far tighter than
    # bf16 precision (~2-3 significant digits). Any difference in accumulation
    # order between the kernel and the reference matmul would fail spuriously.
    # assert_close picks dtype-aware tolerances, still checks shape/dtype/device,
    # and reports the location and size of any mismatch.
    torch.testing.assert_close(actual, expected)
