import torch
import q2t
from eden import linear

# Problem size: INPUT below is (BATCH, SEQ, HID) and both layers map HID -> HID.
BATCH = 8
SEQ = 16
HID = 256
# Forwarded as `delayed_amax=` to the Triton layer only.
# NOTE(review): the CUDA layer below is constructed without it — confirm that
# is intended.
DELAYED_AMAX = False
# Forwarded as `four_over_six=` to both layers.
FourOverSix = False

# Random BF16 activations on the GPU, shape (BATCH, SEQ, HID).
INPUT = torch.randn((BATCH, SEQ, HID), device='cuda', dtype=torch.bfloat16)
# NOTE(review): TARGET is not read anywhere in the visible code. It still
# consumes CUDA RNG state, so removing it may change later random draws
# (e.g. grad_o) — confirm before deleting.
TARGET = torch.randn((BATCH, SEQ, 1), device='cuda')

# Two implementations of a HID -> HID `Quartet_II_linear` layer in BF16:
# Triton-backed (q2t) and CUDA-backed (eden.linear).
W1_triton = q2t.Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX, dtype=torch.bfloat16, four_over_six=FourOverSix)
W1_cuda = linear.Quartet_II_linear(HID, HID, device='cuda', dtype=torch.bfloat16, four_over_six=FourOverSix)

# Give all three paths identical weights: rescale the Triton layer's weight
# so its entries have std 1/sqrt(HID), copy it into the CUDA layer, and keep
# an independent copy as the parameter of the plain-matmul reference.
with torch.no_grad():
    scale = HID**0.5 * W1_triton.weight.std()
    W1_triton.weight.div_(scale)
    W1_cuda.weight.copy_(W1_triton.weight)
    w_ref = torch.nn.Parameter(W1_triton.weight.clone())

# Forward through both implementations, then through the reference matmul.
h_triton = W1_triton(INPUT)
h_cuda = W1_cuda(INPUT)
h_ref = torch.matmul(INPUT, w_ref.T)


def compare(got, ref, title):
    """Print and return the relative quadratic error of ``got`` against ``ref``.

    The error is ``mean((got - ref)**2) / mean(ref**2)``. The effective
    bitwidth is ``-log2(err) / 2``. A perfect match gives ``inf`` bits.

    Args:
        got: Tensor under test.
        ref: Reference tensor, broadcast-compatible with ``got``.
        title: Label printed in front of the numbers.

    Returns:
        ``(quad_err, eff_bitwidth)`` as Python floats.
    """
    # Detach so no autograd graph is built for the metric. Upcast because the
    # inputs are often BF16, and its ~8-bit mantissa would otherwise limit the
    # precision of the metric itself.
    got = got.detach().float()
    ref = ref.detach().float()
    rel_err = (got - ref).pow(2).mean() / ref.pow(2).mean()
    quad_err = rel_err.item()
    # Stay in torch so a zero error maps to +inf bits instead of raising
    # (math.log2(0) would raise ValueError).
    eff_bitwidth = (-torch.log2(rel_err) / 2).item()
    print(f"{title}: {quad_err:.4f} mrqe; {eff_bitwidth:.2f} bits")
    return quad_err, eff_bitwidth

print("FWD")
compare(h_triton, h_ref, "BF16 vs triton")
compare(h_cuda, h_ref, "BF16 vs CUDA  ")
compare(h_cuda, h_triton, "CUDA vs triton")


grad_o = torch.randn_like(h_triton)
h_ref.backward(grad_o)
h_triton.backward(grad_o)
h_cuda.backward(grad_o)

print("BWD")
compare(W1_triton.weight.grad, w_ref.grad, "BF16 vs triton")
compare(W1_cuda.weight.grad, w_ref.grad, "BF16 vs CUDA  ")
compare(W1_cuda.weight.grad, W1_triton.weight.grad, "CUDA vs triton")
