import torch
import torch.utils.checkpoint
import numpy as np

from quick_extend.ref_impl.mask_gen import mask_gen_ref
from quick_extend.ref_impl.sparse_attn import sparse_attn_ref, sparse_attn_flash_ref
from quick_extend.triton_impl.mixed_fwd_bwd import hip_attn, mixed_attn
from quick_extend.triton_impl.ste_backward import ste_hip_attn


def run_toy_gd(
        BATCH_SIZE=1,
        NUM_HEADS=32,
        QUERY_LEN=32768,
        KEY_VALUE_LEN=32768,
        HEAD_DIM=128,
        VALUE_DIM=128,
        BLOCK_SIZE_Q=16,
        BLOCK_SIZE_K=2,
        TOP_K_ELEMS=1024,
        QUERY_OFFSET=0,
        START_SINK_TOKENS=2,
        END_SINK_TOKENS=16,
        seed=0,
        experiment='hip',  # one of 'soft', 'hip', 'ste', 'mixed'
        dtype=torch.float32,
        num_iters=100,
        lr=1e-1,
        momentum=0.9,
        max_grad_norm=10.0,
        device="cuda"):
    """Run a toy SGD fit of random q/k/v through one attention implementation.

    Random ``q``, ``k``, ``v`` leaf tensors are optimised with SGD (with
    momentum) so that the attention output ``r`` approaches a fixed random
    target ``grad_noise`` under an MSE loss. This makes it possible to compare
    how well gradients flow through each attention variant.

    Args:
        BATCH_SIZE, NUM_HEADS, QUERY_LEN, KEY_VALUE_LEN, HEAD_DIM, VALUE_DIM:
            Sizes of the random tensors. q/k are
            ``(BATCH_SIZE, NUM_HEADS, LEN, HEAD_DIM)``; v and the target use
            ``VALUE_DIM`` as their last dimension.
        BLOCK_SIZE_Q, BLOCK_SIZE_K, TOP_K_ELEMS, QUERY_OFFSET,
        START_SINK_TOKENS, END_SINK_TOKENS:
            Forwarded to the selected attention implementation. The 'soft'
            reference receives ``TOP_K_ELEMS // BLOCK_SIZE_K``; the other
            variants receive ``TOP_K_ELEMS`` unchanged.
        seed: Seed for ``torch.random.manual_seed`` before tensor creation.
        experiment: Which forward to use: 'soft' (``ref_softmask``),
            'hip' (``hip_attn``), 'ste' (``ste_hip_attn``) or
            'mixed' (``mixed_attn``).
        dtype: dtype of all created tensors.
        num_iters: Number of optimisation steps.
        lr, momentum: SGD hyper-parameters.
        max_grad_norm: Total-norm threshold passed to ``clip_grad_norm_``.
        device: Device on which tensors are created.

    Returns:
        Tuple ``(losses, q_grad_norms, k_grad_norms, v_grad_norms)`` of lists
        of Python floats, one entry per iteration. Gradient norms are recorded
        *before* clipping is applied.

    Raises:
        ValueError: If ``experiment`` is not a supported name.
    """
    supported = ('soft', 'hip', 'ste', 'mixed')
    # Validate before allocating any device memory; an unknown name previously
    # left `r` unbound and crashed mid-loop.
    if experiment not in supported:
        raise ValueError(
            f"unknown experiment {experiment!r}; expected one of {supported}")

    TOP_K = TOP_K_ELEMS // BLOCK_SIZE_K

    torch.random.manual_seed(seed)
    dev = torch.device(device)
    # Creation order (q, k, v, target) is kept fixed so a given seed always
    # produces the same tensors.
    q = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, HEAD_DIM, dtype=dtype, device=dev).requires_grad_()
    k = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM, dtype=dtype, device=dev).requires_grad_()
    v = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM, dtype=dtype, device=dev).requires_grad_()

    # Fixed random regression target for the attention output.
    grad_noise = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, VALUE_DIM, dtype=dtype, device=dev)

    opt = torch.optim.SGD([q, k, v], lr=lr, momentum=momentum)

    def forward():
        # Dispatch to the selected attention implementation.
        if experiment == 'soft':
            return ref_softmask(q, k, v, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET, START_SINK_TOKENS, END_SINK_TOKENS)
        if experiment == 'hip':
            return hip_attn(q, k, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, TOP_K_ELEMS, QUERY_OFFSET, START_SINK_TOKENS, END_SINK_TOKENS)
        if experiment == 'ste':
            return ste_hip_attn(q, k, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, TOP_K_ELEMS, QUERY_OFFSET, START_SINK_TOKENS, END_SINK_TOKENS)
        # 'mixed': the extra positional 0 is passed through exactly as before
        # (its meaning is defined by mixed_attn — TODO document).
        return mixed_attn(q, k, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, TOP_K_ELEMS, QUERY_OFFSET, 0, START_SINK_TOKENS, END_SINK_TOKENS)

    losses = []
    q_grad_norms = []
    k_grad_norms = []
    v_grad_norms = []

    for it in range(num_iters):
        print(f"Iteration {it}")
        opt.zero_grad()

        r = forward()
        loss = (r - grad_noise).pow(2).mean()
        loss.backward()

        losses.append(loss.item())
        q_grad_norms.append(q.grad.norm().item())
        k_grad_norms.append(k.grad.norm().item())
        v_grad_norms.append(v.grad.norm().item())
        print(f"Loss: {losses[-1]} Grad_norm: {q_grad_norms[-1]} {k_grad_norms[-1]} {v_grad_norms[-1]}")

        # Clip after recording so the logged norms are the raw gradient norms.
        torch.nn.utils.clip_grad_norm_([q, k, v], max_grad_norm)
        opt.step()

    return losses, q_grad_norms, k_grad_norms, v_grad_norms


def ref_softmask(q, k, v, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET, START_SINK_TOKENS, END_SINK_TOKENS):
    """Reference forward for the 'soft' experiment.

    A mask is produced by ``mask_gen_ref`` from ``q`` and ``k``. The call is
    wrapped in ``torch.utils.checkpoint.checkpoint``, so it is recomputed
    during backward instead of keeping its intermediates alive. The mask is
    then fed to ``sparse_attn_flash_ref`` together with ``v``.

    Only the first element of ``sparse_attn_flash_ref``'s two-element result
    is returned; the second is discarded.

    NOTE(review): the literal ``True`` passed to ``mask_gen_ref`` and the
    ``32, 32`` passed to ``sparse_attn_flash_ref`` are opaque here. Presumably
    they are a soft-mask flag and kernel tile sizes — confirm against the
    ref_impl signatures.
    """
    sinks = (START_SINK_TOKENS, END_SINK_TOKENS)

    mask = torch.utils.checkpoint.checkpoint(
        mask_gen_ref,
        q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET, *sinks, True,
        use_reentrant=True,
    )

    output, _ = sparse_attn_flash_ref(
        q, k, mask, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, 32, 32, QUERY_OFFSET, *sinks,
    )
    return output


if __name__ == '__main__':
    import os

    # Create the output directory. exist_ok=True makes this idempotent and
    # avoids the race between a separate existence check and the mkdir.
    os.makedirs("outputs", exist_ok=True)

    # Small toy problem shared by every seed/experiment combination.
    query_len = 512
    key_value_len = 512
    toy_config = dict(
        BATCH_SIZE=2,
        NUM_HEADS=2,
        QUERY_LEN=query_len,
        KEY_VALUE_LEN=key_value_len,
        HEAD_DIM=32,
        VALUE_DIM=32,
        BLOCK_SIZE_Q=16,
        BLOCK_SIZE_K=2,
        TOP_K_ELEMS=128,
        # Written as KV length minus query length (originally `512 - 512`),
        # i.e. 0 for this config.
        QUERY_OFFSET=key_value_len - query_len,
        START_SINK_TOKENS=32,
        END_SINK_TOKENS=32,
    )

    for seed in [0, 1, 2, 3, 4]:
        print(f"Seed: {seed}")
        for experiment in ['soft', 'hip', 'ste']:
            print(f"Experiment: {experiment}")
            losses, q_grad_norms, k_grad_norms, v_grad_norms = run_toy_gd(
                **toy_config,
                seed=seed,
                experiment=experiment,
            )
            # Stacked into a (4, num_iters) array: loss, then q/k/v grad norms.
            np.save(f"outputs/toy_{experiment}_{seed}.npy", np.array([losses, q_grad_norms, k_grad_norms, v_grad_norms]))
