import torch
from tqdm import tqdm, trange
torch.set_float32_matmul_precision('high')
from q2t import *

# ---------------------------------------------------------------------------
# Experiment 1: relative error of an averaged quantized matmul.
#
# For each number of accumulation steps, the product Aq @ Bq.T of freshly
# quantized copies of A and B is averaged and compared with the exact product
# A @ B.T. The result is printed as -log2(relative MSE) / 2, labelled "bits".
# ---------------------------------------------------------------------------
M = 1024
N = 1024
K = 1024
HADAMARD_DIM = 128

A = torch.randn((M, K), device='cuda')
B = torch.randn((N, K), device='cuda')
ht = get_hadamard_matrix(HADAMARD_DIM, A.dtype, A.device)


def _quantize_with_fresh_amax(x, hadamard):
    """Quantize ``x`` via ``eden_1x16s_fp4_kernel_wrapper`` with a fresh amax.

    The amax handed to the wrapper is the largest absolute value of ``x`` after
    it is viewed as rows of length ``hadamard.size(0)`` and multiplied by
    ``hadamard.T``.

    Args:
        x: Tensor to quantize; its element count must be divisible by
            ``hadamard.size(0)`` for the ``view`` to succeed.
        hadamard: Square matrix used both for the amax estimate and by the
            wrapper.

    Returns:
        Whatever the wrapper returns, unpacked by the caller as
        ``(quantized_tensor, amax_buffer)``.
    """
    amax = (x.view(-1, hadamard.size(0)) @ hadamard.T).abs().max()
    # NOTE(review): the positional arguments 1.0 and 16 are forwarded verbatim;
    # their meaning is defined by the wrapper in q2t -- confirm there.
    return eden_1x16s_fp4_kernel_wrapper(
        x,
        hadamard,
        1.0,
        16,
        current_amax=amax,
    )


with torch.no_grad():
    # The exact product and its mean square do not depend on the loop
    # variables, so they are computed once instead of on every iteration.
    reference = A @ B.T
    reference_power = reference.pow(2).mean()

    for acc_steps in tqdm([1, 4, 16, 64, 256, 1024], desc="Iterating steps"):
        accumulator = torch.zeros_like(reference)
        for _ in trange(acc_steps, leave=False):
            # A new matrix is derived from the previous one on every step and
            # shared by both operands of this step.
            ht = rerotate_hadamard(ht)

            Aq, A_amax_buffer = _quantize_with_fresh_amax(A, ht)
            Bq, B_amax_buffer = _quantize_with_fresh_amax(B, ht)

            accumulator += Aq @ Bq.T
        accumulator /= acc_steps

        # Relative mean-squared error of the averaged product.
        quad_err = (accumulator - reference).pow(2).mean() / reference_power
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        print(f"{acc_steps}: {eff_bitwidth:.2f} bits")


# ---------------------------------------------------------------------------
# Experiment 2: error of averaged weight gradients in a small 3-layer MLP.
#
# Reference gradients are taken with ``disable_backward_quant = True`` on every
# layer. The flag is then set to False and the gradients accumulated over
# ``acc_steps`` backward passes are averaged and compared with the reference.
# NOTE(review): the exact effect of ``disable_backward_quant`` is defined by
# Quartet_II_linear in q2t -- presumably it bypasses quantization in the
# backward pass; confirm there.
# ---------------------------------------------------------------------------
BATCH = 8
SEQ = 16
HID = 256
DELAYED_AMAX = True

INPUT = torch.randn((BATCH, SEQ, HID), device='cuda')
TARGET = torch.randn((BATCH, SEQ, 1), device='cuda')

W1 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)
W2 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)
W3 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)

# (label, layer) pairs so that per-layer work is written once.
_NAMED_LAYERS = (("W1", W1), ("W2", W2), ("W3", W3))

with torch.no_grad():
    # Rescale each weight matrix so that its std becomes 1 / sqrt(HID).
    for _, _layer in _NAMED_LAYERS:
        _layer.weight /= (HID**0.5 * _layer.weight.std())

head = torch.randn(HID, 1, device='cuda')


def _mlp_loss():
    """Return the MSE between ``W3(relu(W2(relu(W1(INPUT))))) @ head`` and TARGET."""
    hidden = torch.nn.functional.relu(W1(INPUT))
    hidden = torch.nn.functional.relu(W2(hidden))
    hidden = W3(hidden)
    return (hidden @ head - TARGET).pow(2).mean()


def _reset_weight_grads():
    """Drop the accumulated ``.grad`` of every layer's weight."""
    for _, layer in _NAMED_LAYERS:
        layer.weight.grad = None


def _set_disable_backward_quant(flag):
    """Assign ``flag`` to ``disable_backward_quant`` on every layer."""
    for _, layer in _NAMED_LAYERS:
        layer.disable_backward_quant = flag


# --- Reference gradients ----------------------------------------------------
_reset_weight_grads()
_set_disable_backward_quant(True)

loss = _mlp_loss()
loss.backward()

w1_ref_grad, w2_ref_grad, w3_ref_grad = (
    layer.weight.grad.clone().detach() for _, layer in _NAMED_LAYERS
)
_REF_GRADS = (w1_ref_grad, w2_ref_grad, w3_ref_grad)

# --- Averaged gradients with the flag switched back off ---------------------
_set_disable_backward_quant(False)

# One forward pass is reused for every backward call below, which is why
# retain_graph=True is needed.
loss = _mlp_loss()

for acc_steps in [1, 4, 16, 64, 256, 1024, 4096]:
    print(f"{acc_steps=}:")
    _reset_weight_grads()
    # Repeated backward calls sum into .grad; the sum is averaged afterwards.
    for _ in range(acc_steps):
        loss.backward(retain_graph=True)
    with torch.no_grad():
        for (_name, _layer), _ref_grad in zip(_NAMED_LAYERS, _REF_GRADS):
            _layer.weight.grad /= acc_steps
            _grad = _layer.weight.grad

            # Relative MSE against the reference, reported as -log2(err) / 2.
            quad_err = (_grad - _ref_grad).pow(2).mean() / _ref_grad.pow(2).mean()
            eff_bitwidth = (-torch.log2(quad_err) / 2).item()
            # NOTE(review): despite the printed label, this is
            # <grad, ref> / <ref, ref> (the projection coefficient onto the
            # reference), not a norm-normalised cosine similarity. The formula
            # and label are kept as they were; confirm which one is intended.
            cosine = (_grad.flatten() @ _ref_grad.flatten()) / (_ref_grad.flatten() @ _ref_grad.flatten())
            print(f"\t{_name} grad err: {eff_bitwidth:.2f} bits, {cosine:.3f} cosine")




def bench_shape(
        in_dim, out_dim, batch_size, seq_len,
        weight_dtype=torch.float32, act_dtype=torch.bfloat16, device='cuda',
        warmup=10, rep=100,
        compile=True, compile_kwargs=None,
):
    """Benchmark one ``Quartet_II_linear`` layer with ``triton.testing.do_bench``.

    A layer ``Quartet_II_linear(in_dim, out_dim, four_over_six=True, ...)`` is
    built and applied to a random input of shape
    ``(batch_size, seq_len, in_dim)``. The forward pass is timed with gradients
    disabled, then forward followed by ``output.backward(grad)`` is timed with
    gradients enabled.

    Args:
        in_dim: Size of the last input dimension.
        out_dim: Output size passed to the layer constructor.
        batch_size: First dimension of the random input.
        seq_len: Second dimension of the random input.
        weight_dtype: ``dtype`` passed to the layer constructor.
        act_dtype: dtype of the random input tensor.
        device: Device for both the input and the layer.
        warmup: Forwarded to ``do_bench`` as ``warmup``.
        rep: Forwarded to ``do_bench`` as ``rep``.
        compile: If true, wrap the layer and the forward+backward step in
            ``torch.compile``; otherwise both run uncompiled.
        compile_kwargs: Keyword arguments for ``torch.compile``. Defaults to
            ``{'dynamic': False, 'mode': 'reduce-overhead', 'fullgraph': False}``.

    Returns:
        ``{"forward_ms": ..., "total_ms": ...}`` holding the two values
        returned by ``do_bench`` for forward-only and forward+backward.
    """
    # NOTE(review): ``triton`` is not imported explicitly in the visible part
    # of this file; presumably it is re-exported by ``from q2t import *`` --
    # confirm.
    if compile_kwargs is None:
        compile_kwargs = {'dynamic': False, 'mode': 'reduce-overhead', 'fullgraph': False}

    x = torch.randn(batch_size, seq_len, in_dim, device=device, dtype=act_dtype)

    linear = Quartet_II_linear(in_dim, out_dim, four_over_six=True, device=device, dtype=weight_dtype)
    if compile:
        linear = torch.compile(linear, **compile_kwargs)

    # Forward only. The context manager restores the caller's grad mode even
    # if the benchmark raises, unlike toggling the global switch by hand.
    with torch.no_grad():
        forward_time = triton.testing.do_bench(
            lambda: linear(x), warmup=warmup, rep=rep,
        )
        # Only the shape/dtype/device of the output are needed for the
        # upstream gradient fed to backward().
        grad = torch.randn_like(linear(x))

    def forward_backward(x, grad):
        output = linear(x)
        output.backward(grad)

    # Fall back to the plain function when compilation is disabled; previously
    # the benchmarked callable was only bound when ``compile`` was true, so
    # ``compile=False`` raised NameError.
    if compile:
        step = torch.compile(forward_backward, **compile_kwargs)
    else:
        step = forward_backward

    # Forward + backward.
    with torch.enable_grad():
        total_time = triton.testing.do_bench(
            lambda: step(x, grad), warmup=warmup, rep=rep,
        )

    return {
        "forward_ms": forward_time,
        "total_ms": total_time,
    }


from tqdm.auto import tqdm, trange

# Batch size and sequence length used for every benchmarked projection.
BATCH_SIZE = 8
SEQ_LEN = 2048

# Model size label -> list of projection shapes. Each tuple is
# (out_dim, in_dim): the driver below passes shape[1] as in_dim and shape[0]
# as out_dim to bench_shape.
shapes = {
    # Original column note: "Q K V Down Up Gate Down".
    # NOTE(review): presumably fused QKV, attention output, fused up/gate and
    # MLP down projections, in that order -- confirm.
    # "100M": [(1024 * 3, 1024), (1024, 1024), (2816 * 2, 1024), (1024, 2816)],
    # "800M": [(2048 * 3, 2048), (2048, 2048), (5632 * 2, 2048), (2048, 5632)],
    "3B": [(3072 * 3, 3072), (3072, 3072), (8192 * 2, 3072), (3072, 8192)],
    "7B": [(4096 * 3, 4096), (4096, 4096), (11008 * 2, 4096), (4096, 11008)],
    "22B": [(6144 * 3, 6144), (6144, 6144), (16384 * 2, 6144), (6144, 16384)],
    # "52B": [(8192 * 3, 8192), (8192, 8192), (22016 * 2, 8192), (8192, 22016)],
}

# Cache keyed by the shape tuple so a shape is benchmarked at most once.
shape_to_result = {}

for _, layer_shapes in tqdm(shapes.items(), desc="Iterating model sizes"):
    for shape in tqdm(layer_shapes, desc="Iterating shapes", leave=False):
        if shape in shape_to_result:
            continue
        out_features, in_features = shape
        shape_to_result[shape] = bench_shape(
            in_features, out_features, BATCH_SIZE, SEQ_LEN,
        )


# Per model size, add up the cached timings of all of its shapes.
for size, layer_shapes in shapes.items():
    timings = [shape_to_result[shape] for shape in layer_shapes]
    forward_latency = sum(entry['forward_ms'] for entry in timings)
    total_latency = sum(entry['total_ms'] for entry in timings)

    print(f"{size:5}: {forward_latency:6.2f} ms forward, {total_latency:6.2f} ms forward+backward")

