import torch
from torch.profiler import profile, ProfilerActivity, record_function, tensorboard_trace_handler

from attention import naive_attention, attention
from mlp import mlp, naive_mlp
from xentropy import xentropy, naive_xentropy

# Output directory for the TensorBoard trace files written by every profiling
# run below (a relative path, so it resolves against the current working directory).
log_dir = './log'

def wrap(fun, mock, *inputs):
    """Run a single forward and backward pass of ``fun`` on ``inputs``.

    The output of ``fun(*inputs)`` is multiplied elementwise by ``mock`` and
    summed to a scalar, so ``backward()`` can be called on it.  Gradients are
    accumulated into the ``.grad`` of every input that requires grad.  Nothing
    is returned.
    """
    output = fun(*inputs)
    scalar = (output * mock).sum()
    scalar.backward()

def check_xentropy(fun=None):
    """Profile CPU time and memory of one forward+backward pass of a cross-entropy kernel.

    Random inputs are built and one unprofiled warm-up pass is run.  A second
    pass is then profiled under the ``"MonoidReduce"`` label.  A TensorBoard
    trace is written to ``log_dir``, and the 10 ops with the highest self CPU
    memory usage are printed.

    Args:
        fun: cross-entropy implementation to profile, called as ``fun(p, c, t)``.
            Defaults to ``xentropy``.  Pass ``naive_xentropy`` to profile the
            reference version without editing this function.
    """
    if fun is None:
        fun = xentropy

    L, R, D = 16*1024, 16*1024, 256
    p = torch.randn(L, D, requires_grad=True)
    c = torch.randn(R, D, requires_grad=True)
    t = torch.randint(R, (L,))  # one target index in [0, R) per row of p
    # NOTE(review): mock has length R.  The kernel's output shape is not visible
    # here; if it is one loss per row of p (length L), this only lines up
    # because L == R.  TODO confirm.
    mock = torch.randn(R)

    # Warm-up pass outside the profiler keeps first-call overhead out of the
    # profile.  Its gradients remain in p.grad / c.grad, and the profiled pass
    # accumulates into them.
    wrap(fun, mock, p, c, t)
    with profile(
        activities=[ProfilerActivity.CPU],
        profile_memory=True,
        on_trace_ready=tensorboard_trace_handler(log_dir),
    ) as prof:
        with record_function("MonoidReduce"):
            wrap(fun, mock, p, c, t)
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))


def check_attention(fun=None):
    """Profile CPU time and memory of one forward+backward pass of an attention kernel.

    Random ``q``/``k``/``v`` are built and one unprofiled warm-up pass is run.
    A second pass is then profiled under the ``"MonoidReduce"`` label.  A
    TensorBoard trace is written to ``log_dir``, and the 10 ops with the
    highest self CPU memory usage are printed.

    Args:
        fun: attention implementation to profile, called as ``fun(q, k, v)``.
            Defaults to ``naive_attention``, which this check has always
            profiled.  Pass ``attention`` to profile the other variant.
    """
    if fun is None:
        fun = naive_attention

    B, D = 32*1024, 128
    q = torch.randn(B, D, requires_grad=True)
    k = torch.randn(B, D, requires_grad=True)
    v = torch.randn(B, D, requires_grad=True)

    mock = torch.randn(B, D)

    # Warm-up pass outside the profiler, as check_xentropy and check_mlp do.
    # Without it, first-call overhead is recorded and the numbers are not
    # comparable with the other checks.
    wrap(fun, mock, q, k, v)
    with profile(
        activities=[ProfilerActivity.CPU],
        profile_memory=True,
        on_trace_ready=tensorboard_trace_handler(log_dir),
    ) as prof:
        with record_function("MonoidReduce"):
            wrap(fun, mock, q, k, v)
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))


def check_mlp(fun=None):
    """Profile CPU time and memory of one forward+backward pass of an MLP kernel.

    Random inputs are built and one unprofiled warm-up pass is run.  A second
    pass is then profiled under the ``"MonoidReduce"`` label.  A TensorBoard
    trace is written to ``log_dir``, and the 10 ops with the highest self CPU
    memory usage are printed.

    Args:
        fun: MLP implementation to profile, called as ``fun(x, p, q)``.
            Defaults to ``mlp``.  Pass ``naive_mlp`` to profile the reference
            version without editing this function.
    """
    if fun is None:
        fun = mlp

    B, D, K = 8*1024, 128, 8*1024
    x = torch.randn(B, D, requires_grad=True)
    p = torch.randn(K, D, requires_grad=True)
    q = torch.randn(K, D, requires_grad=True)

    mock = torch.randn(B, D)

    # Use the shared module-level `wrap` helper instead of a local closure
    # that duplicated it.  The warm-up pass keeps first-call overhead out of
    # the profile.
    wrap(fun, mock, x, p, q)

    with profile(
        activities=[ProfilerActivity.CPU],
        profile_memory=True,
        on_trace_ready=tensorboard_trace_handler(log_dir),
    ) as prof:
        with record_function("MonoidReduce"):
            wrap(fun, mock, x, p, q)
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

if __name__ == '__main__':
    import argparse

    # Map CLI names to the profiling checks defined above.  'attention' stays
    # the default, so running the script with no arguments still profiles
    # attention only.
    checks = {
        'attention': check_attention,
        'mlp': check_mlp,
        'xentropy': check_xentropy,
    }
    parser = argparse.ArgumentParser(
        description='Profile CPU time and memory of one kernel forward+backward pass.')
    parser.add_argument('check', nargs='?', default='attention', choices=sorted(checks),
                        help='which check to run (default: attention)')
    args = parser.parse_args()
    checks[args.check]()
