import torch
import torch.nn.functional as F
import math
from utils import get_time, check_close
import socket
import os
from flash_attn import flash_attn_func
import dfpa
import dfpa.serial as df_serial


def flash_attn(q, k, v, causal=False):
    """Call ``flash_attn_func`` with axes 1 and 2 of every input swapped.

    Callers in this file pass tensors created as ``(B, H, T, D)``; each of
    ``q``, ``k`` and ``v`` is permuted with ``(0, 2, 1, 3)`` before the call
    and the result is permuted back the same way, so the returned tensor uses
    the caller's axis order.

    Args:
        q, k, v: 4-D tensors (the permutation names four axes).
        causal: Forwarded unchanged to ``flash_attn_func``.
    """
    q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
    out = flash_attn_func(q, k, v, causal=causal)
    return out.permute(0, 2, 1, 3)


def naive_attn(K_chunk, K_prev, V_prev):
    """Unmasked softmax attention of ``K_chunk`` (as queries) over ``K_prev``/``V_prev``.

    Scores are ``K_chunk @ K_prev^T`` divided by the square root of
    ``K_chunk``'s last dimension.

    Returns:
        A pair ``(output, lse)`` where ``output`` is ``softmax(scores) @ V_prev``
        and ``lse`` is the log-sum-exp of the scores over the last axis.
    """
    logits = K_chunk @ K_prev.transpose(-1, -2) / math.sqrt(K_chunk.shape[-1])
    lse = torch.logsumexp(logits, dim=-1)
    weights = torch.softmax(logits, dim=-1)
    return weights @ V_prev, lse


def naive_implementation(k, n, d_model):
    """Reference recurrence, evaluated one position at a time.

    Row 0 of the output is row 0 of ``n``. For every later position ``t``::

        v[t] = n[t] - sum_{s < t} softmax_s(k[s] . k[t] / sqrt(d_model)) * v[s]

    Args:
        k: 4-D tensor; its third axis is the sequence axis iterated over.
        n: Tensor indexed the same way as ``k``; the output has its shape.
        d_model: Value whose square root scales the scores.

    Returns:
        A new tensor shaped like ``n``.
    """
    _, _, seq_len, _ = k.shape
    v = torch.zeros_like(n)
    if seq_len > 0:
        # The first position has no history to subtract.
        v[:, :, 0] = n[:, :, 0]
    for t in range(1, seq_len):
        history = k[:, :, :t]
        current = k[:, :, t].unsqueeze(-1)
        scores = torch.matmul(history, current).squeeze(-1) / math.sqrt(d_model)
        probs = F.softmax(scores, dim=-1)
        # Weighted sum of the already-computed outputs for positions < t.
        correction = (probs.unsqueeze(-1) * v[:, :, :t]).sum(dim=-2)
        v[:, :, t] = n[:, :, t] - correction
    return v


def _naive_loop(T, w, n, v):
    """Fill ``v`` in place with ``v[t] = n[t] - sum_{s < t} w[t, s] * v[s]``.

    Args:
        T: Number of positions to process along axis 2.
        w: Weights; only the entries ``w[:, :, t, :t]`` are read.
        n: Input rows.
        v: Output buffer, written row by row.

    Returns:
        The same ``v`` object that was passed in.
    """
    v[:, :, 0] = n[:, :, 0]
    for t in range(1, T):
        weighted_past = w[:, :, t, :t].unsqueeze(-1) * v[:, :, :t]
        v[:, :, t] = n[:, :, t] - weighted_past.sum(dim=-2)
    return v


class DebFun(torch.autograd.Function):
    """Identity autograd function that prints the gradient flowing back through it."""

    @staticmethod
    def forward(ctx, x):
        # Pass the input through untouched.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Dump the incoming gradient between two marker lines, then pass it on.
        print(
            "====== debfun output =======",
            grad_output,
            ">>>>>>>>>>>>>>>>>>>>>>>>>>",
            sep="\n",
        )
        return grad_output


def naive_loop(k, n, d_model):
    """Run the sequential recurrence with all weights computed up front.

    The score matrix ``k @ k^T / sqrt(d_model)`` is passed through
    ``df_serial.naive_tril_softmax`` (presumably a softmax restricted to the
    lower triangle — confirm against ``dfpa.serial``), and ``_naive_loop``
    then consumes the strictly-lower entries row by row.

    Returns:
        A new tensor shaped like ``n``.
    """
    scores = k @ k.permute(0, 1, 3, 2) / math.sqrt(d_model)
    weights = df_serial.naive_tril_softmax(scores)
    # _naive_loop writes into the buffer it is given and returns that buffer.
    return _naive_loop(n.size(2), weights, n, torch.zeros_like(n))


def pyinv(k, v, d_model):
    """Solve the recurrence as one unit-lower-triangular system per (batch, head).

    The first two axes of ``k`` and ``v`` are merged, the weight matrix is
    built with ``df_serial.naive_tril_softmax`` from ``k @ k^T / sqrt(d_model)``,
    and ``torch.linalg.solve_triangular`` is run in float32 with
    ``unitriangular=True`` (the diagonal is taken to be ones). The solution is
    cast back to ``k``'s dtype and reshaped to ``k``'s original 4-D shape.
    """
    batch, heads, seq_len, dim = k.size()
    keys = k.flatten(0, 1)
    rhs = v.flatten(0, 1)
    weights = df_serial.naive_tril_softmax(
        keys @ keys.transpose(-1, -2) / math.sqrt(d_model)
    )
    solution = torch.linalg.solve_triangular(
        weights.float(),
        rhs.float(),
        upper=False,
        unitriangular=True,
    )
    return solution.to(keys.dtype).reshape(batch, heads, seq_len, dim)


def blocked_loop(K, N, d_model, C):
    """Call ``flash_attn`` once per block of ``C`` positions along axis 2.

    For the block starting at ``lo`` the queries are ``K[:, :, lo:lo + C]``
    and the keys/values are every row of ``K``/``N`` up to the end of that
    block. The call is made without ``causal``, so inside a block each query
    also sees the later rows of the same block.

    Args:
        K: Supplies both queries and keys.
        N: Supplies the values; the output has its shape.
        d_model: Unused.
        C: Block length.
    """
    seq_len = K.size(2)
    out = torch.zeros_like(N)
    for lo in range(0, seq_len, C):
        hi = lo + C
        out[:, :, lo:hi] = flash_attn(K[:, :, lo:hi], K[:, :, :hi], N[:, :, :hi])
    return out


def chunked_inv(K, N, d_model, C):
    """Build the per-chunk matrix ``I + softmax(tril(K_c K_c^T, -1) / sqrt(d_model))``.

    The step that would invert that matrix and write into the output is
    disabled (see the commented lines in the loop), so the function performs
    the per-chunk matrix construction and returns an all-zero tensor.

    Args:
        K: Tensor of shape ``(B, H, T, D)``.
        N: Unused while the inversion step is disabled.
        d_model: Value whose square root scales the scores.
        C: Chunk length; only the ``T // C`` complete chunks are visited.

    Returns:
        A zero tensor with the shape, dtype and device of ``K``.
    """
    # Allocate next to K instead of on the global default device/dtype, so the
    # function also works when K does not live on the default device.
    V = torch.zeros_like(K)
    T = K.shape[2]
    identity = torch.eye(C, device=K.device, dtype=K.dtype).view(1, 1, C, C)
    for chunk_num in range(T // C):
        start = chunk_num * C
        end = start + C
        K_chunk = K[:, :, start:end]
        # tril(..., -1) zeroes the diagonal and upper triangle (it does not set
        # them to -inf), so those positions still take part in the softmax.
        A = torch.tril(K_chunk @ K_chunk.transpose(-2, -1), -1) / math.sqrt(d_model)
        A = F.softmax(A, dim=-1)
        Ti = identity + A  # built but not consumed while the lines below stay disabled
        # Ti is lower triangular, so forward substitution could invert it
        # quickly using GPU parallelism, O(C^3):
        # Ti_inverse = torch.inverse(Ti.float())
        # V[:, :, start:end] = Ti_inverse.to(K.dtype) @ P     # O(C^2 D)
    return V


def optimized_chunked_implementation(K, N, d_model, C):
    """Chunked evaluation of ``v[t] = n[t] - sum_{s<t} softmax_s(k_t . k_s) v[s]``.

    The sequence is processed in chunks of ``C`` rows. For each chunk the
    softmax over all earlier positions is split into two parts:

    * earlier chunks, handled by ``naive_attn`` against the outputs already
      written to ``V``;
    * earlier rows of the same chunk, which form a strictly lower-triangular
      matrix ``A``.

    The two parts are recombined through their log-sum-exp values, and the
    chunk's outputs follow from ``(I + A) V_chunk = P``.

    Args:
        K: Tensor of shape ``(B, H, T, D)``.
        N: Tensor indexed like ``K``.
        d_model: Value whose square root scales the in-chunk scores. Note that
            ``naive_attn`` scales by the square root of ``K``'s last dimension
            instead, so the two parts agree only when ``d_model == D``.
        C: Chunk length. Only the ``T // C`` complete chunks are processed;
            any trailing ``T % C`` rows of the result stay zero.

    Returns:
        Tensor with the shape, dtype and device of ``K``.

    Cost noted by the original author: O(T/C * (T*C*D + C^2*D)), plus O(C^3)
    per chunk for the matrix inverse.
    """
    T = K.shape[2]
    # Allocate the output next to K; the mask and identity already follow
    # K.device, and the output buffer has to match them.
    V = torch.zeros_like(K)
    scale = math.sqrt(d_model)
    # True on and above the diagonal: a row may only look at earlier rows.
    blocked = torch.tril(torch.ones(C, C, device=K.device), diagonal=-1) == 0
    identity = torch.eye(C, device=K.device, dtype=K.dtype)
    for chunk_num in range(T // C):
        start = chunk_num * C
        end = start + C
        K_chunk = K[:, :, start:end]
        N_chunk = N[:, :, start:end]
        A = (K_chunk @ K_chunk.transpose(-2, -1)).masked_fill(blocked, float("-inf")) / scale  # O(C^2 D)
        if chunk_num > 0:
            # Attention of this chunk over everything before it: O(T C D).
            prev_out, lse_prev = naive_attn(K_chunk, K[:, :, :start], V[:, :, :start])
            lse_local = torch.logsumexp(A, dim=-1)
            # Share of the total softmax mass that belongs to earlier chunks.
            P = N_chunk - prev_out * (1 / (1 + torch.exp((lse_local - lse_prev).unsqueeze(-1))))
            # Share of the total softmax mass that belongs to this chunk.
            A = F.softmax(A, dim=-1) * (1 / (1 + torch.exp((lse_prev - lse_local).unsqueeze(-1))))
        else:
            A = F.softmax(A, dim=-1)
            P = N_chunk
        # Row 0 is fully masked, so its softmax is NaN; it has no in-chunk
        # predecessors and must contribute nothing.
        A[:, :, 0, :] = 0
        Ti = identity + A
        # Ti is lower triangular, so forward substitution could invert it
        # quickly using GPU parallelism; a general inverse is used here, O(C^3).
        Ti_inverse = torch.inverse(Ti.float())
        V[:, :, start:end] = Ti_inverse.to(K.dtype) @ P  # O(C^2 D)
    return V


def verify_equivalence(device='cuda', dtype=torch.bfloat16, profile=False, check=True):
    """Compare the reference implementations, time the blocked variants, optionally profile.

    Args:
        device: Installed as the global torch default device.
        dtype: Installed as the global torch default dtype.
        profile: When true, finish with a ``torch.profiler`` run whose trace
            is written under ``./profile/step3``.
        check: When true, compare ``naive_implementation`` against
            ``naive_loop`` with ``check_close`` before timing anything.

    Side effects: changes the global torch default device/dtype and prints
    timing lines to stdout.
    """
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    B = 2
    H = 32
    T = 8192
    D = 128
    C = 32
    d_model = D
    K = torch.randn(B, H, T, D)
    N = torch.randn(B, H, T, D)
    if check:
        naive_output = naive_implementation(K, N, d_model)
        optimized_output = naive_loop(K, N, d_model)
        # optimized_output = optimized_chunked_implementation(K, N, d_model, C)
        check_close(naive_output, optimized_output)

    print(f'Config: B {B} H {H} T {T} D {D} C {C}')
    # Fresh random inputs for each timed call; get_time is a project helper,
    # presumably it calls `pre` and feeds the result to the timed function.
    gen_input = lambda: [torch.randn(B, H, T, D) for _ in range(2)]
    # t1 = get_time(lambda kn: naive_loop(kn[0], kn[1], d_model), pre=gen_input)
    # print(f'Naive   time {t1 * 1e3:.3f} ms')
    t = get_time(lambda kn: flash_attn(kn[0], kn[1], kn[1], causal=True), pre=gen_input)
    print(f'Ref     time {t * 1e3:.3f} ms')
    # NOTE(review): this loop rebinds the local C defined above; once it ends
    # C is 1024, not the 32 printed in the Config line.
    for C in [32, 64, 128, 256, 512, 1024]:
        # t = get_time(lambda kn: optimized_chunked_implementation(kn[0], kn[1], d_model, C), pre=gen_input)
        # print(f'C {C}  time {t * 1e3:.3f} ms')
        t = get_time(lambda kn: blocked_loop(kn[0], kn[1], d_model, C), pre=gen_input)
        print(f'bC {C}  time {t * 1e3:.3f} ms')
        t = get_time(lambda kn: chunked_inv(kn[0], kn[1], d_model, C), pre=gen_input)
        print(f'cC {C}  time {t * 1e3:.3f} ms')
    if not profile:
        return
    with torch.profiler.profile(
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"./profile/step3",
                worker_name=f"{socket.gethostname()}_{os.getpid()}_rank{0}",
                use_gzip=True,
            ),
            record_shapes=True,
            with_stack=True,
            profile_memory=True,
        ):
        y = naive_implementation(K, N, d_model)
        # NOTE(review): C here is the last value from the timing loop (1024);
        # confirm that is the chunk size meant to be profiled.
        y = optimized_chunked_implementation(K, N, d_model, C)


def check_attn(device='cuda', dtype=torch.bfloat16):
    """Forward check of ``dfpa.preattn`` against ``naive_loop``.

    Seeds both RNGs with 10, draws one ``(1, 1, 1024, 128)`` pair, then uses
    ``check_close`` to compare the plain ``dfpa.preattn`` output with the
    ``naive_loop`` reference and with the ``use_cuda_graph=True`` output.

    Side effect: changes the global torch default device and dtype.
    """
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    seed = 10
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    B, H, T, D = 1, 1, 1024, 128
    keys = torch.randn(B, H, T, D)
    values = torch.randn(B, H, T, D)

    reference = naive_loop(keys, values, D)
    chunked = dfpa.preattn(keys, values)
    graphed = dfpa.preattn(keys, values, use_cuda_graph=True)

    check_close(reference, chunked, print_details=False, name='out delt cnk')
    check_close(graphed, chunked, print_details=False, name='out delt cudagraph cnk')


def clone_with_grad(w):
    """Return a copy of ``w`` that is a fresh leaf tensor with ``requires_grad=True``.

    The copy is detached from whatever graph produced ``w``, so gradients
    accumulated on it never flow back into ``w``.
    """
    return w.detach().clone().requires_grad_(True)


def check_backward(device='cuda', dtype=torch.bfloat16):
    """Backward check of ``dfpa.preattn`` against ``pyinv``.

    Seeds both RNGs with 2, draws inputs of shape ``(2, 8, 1024, 128)`` plus a
    uniform upstream gradient, runs forward and backward through both
    implementations on independent leaf copies, and compares the gradients
    with respect to both inputs using ``check_close``.

    Side effect: changes the global torch default device and dtype.
    """
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    torch.manual_seed(2)
    torch.cuda.manual_seed(2)
    B, H, T, D = 2, 8, 1024, 128
    with torch.no_grad():
        K = torch.randn(B, H, T, D)
        N = torch.randn(B, H, T, D)
        grad_y = torch.rand_like(K)

    # Each implementation gets its own leaves so the gradients stay separate.
    k_dfpa, n_dfpa = clone_with_grad(K), clone_with_grad(N)
    k_ref, n_ref = clone_with_grad(K), clone_with_grad(N)

    dfpa.preattn(k_dfpa, n_dfpa).backward(grad_y)
    pyinv(k_ref, n_ref, D).backward(grad_y)

    check_close(k_ref.grad, k_dfpa.grad, name='dk inv vs cnk', print_details=False)
    check_close(n_ref.grad, n_dfpa.grad, name='dv inv vs cnk', print_details=False)

def compare_attn(device='cuda', dtype=torch.bfloat16):
    """Print forward and backward timings for the implementations in this file.

    Forward timings use fresh random ``(2, 32, 8192, 128)`` inputs produced by
    ``gen_input`` and cover ``naive_attn``, ``pyinv``, ``df_serial.preattn``
    and ``dfpa.preattn``, followed by a plain matmul and causal ``flash_attn``
    as reference points. Backward timings reuse one fixed input pair and one
    fixed upstream gradient for ``dfpa.preattn``, ``flash_attn`` and ``pyinv``.

    Side effects: changes the global torch default device/dtype and prints
    one line per measurement.
    """
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    B, H, T, D = 2, 32, 8192, 128

    def gen_input():
        first = torch.randn(B, H, T, D)
        second = torch.randn(B, H, T, D)
        first.requires_grad = True
        second.requires_grad = True
        return first, second

    def fwd_naive(kn):
        return naive_attn(kn[0], kn[1], kn[1])

    def fwd_trsv(kn):
        return pyinv(kn[0], kn[1], D)

    def fwd_serial(kn):
        return df_serial.preattn(kn[0], kn[1])

    def fwd_dfpa(kn):
        return dfpa.preattn(kn[0], kn[1])

    def fwd_flash(kn):
        return flash_attn(kn[0], kn[1], kn[1], causal=True)

    def run_backward(y):
        return y.backward(dy)

    k, n = gen_input()
    dy = torch.rand_like(k)

    t = get_time(fwd_naive, pre=gen_input)
    print(f'Naive Normal     time {t * 1e3:.3f} ms')

    t = get_time(fwd_trsv, pre=gen_input)
    print(f'Torch trsv       time {t * 1e3:.3f} ms')

    t = get_time(fwd_serial, pre=gen_input)
    print(f'Serial           time {t * 1e3:.3f} ms')

    t = get_time(fwd_dfpa, pre=gen_input)
    print(f'DFPA             time {t * 1e3:.3f} ms')

    # Dense matmul as a raw-throughput reference point.
    a = torch.randn(B, T, H * D)
    w = torch.randn(H * D, H * D * 4)

    def fwd_matmul(_):
        return a @ w

    t = get_time(fwd_matmul)
    print(f'Matmul         ref time {t * 1e3:.3f} ms')

    t = get_time(fwd_flash, pre=gen_input)
    print(f'FlashAttention ref time {t * 1e3:.3f} ms')

    # Backward passes: `pre` rebuilds the forward output, only backward is timed
    # (assuming get_time excludes `pre` from the measurement — confirm in utils).
    t = get_time(run_backward, pre=lambda: dfpa.preattn(k, n))
    print(f'DFPA       backward time {t * 1e3:.3f} ms')

    t = get_time(run_backward, pre=lambda: flash_attn(k, n, n, causal=True))
    print(f'Flashattn  backward time {t * 1e3:.3f} ms')

    t = get_time(run_backward, pre=lambda: pyinv(k, n, D))
    print(f'Torch trsv backward time {t * 1e3:.3f} ms')


def profile_attn(device='cuda', dtype=torch.bfloat16):
    """Time ``dfpa.preattn`` for several chunk sizes, then record a profiler trace.

    Steps, in order:
      1. forward timing for ``C`` in 128, 256, 512 and for ``use_cuda_graph=True``;
      2. backward timing on one fixed input pair;
      3. one forward (``C=256``) plus backward under ``torch.profiler``, with
         the trace written under ``./profile/trial-0``.

    Side effects: changes the global torch default device/dtype, prints the
    timings and writes the trace files.
    """
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    B, H, T, D = 2, 32, 8192, 128
    w1 = torch.randn(D, D * 8)
    w2 = torch.randn(D * 8, D)

    def gen_input():
        # Random tensors pushed through ten rounds of the two projections
        # before being marked as requiring grad.
        a = torch.randn(B, H, T, D)
        b = torch.randn(B, H, T, D)
        for _ in range(10):
            a = a @ w1 @ w2
            b = b @ w1 @ w2
        a.requires_grad = True
        b.requires_grad = True
        return a, b

    def run_graph(kn):
        return dfpa.preattn(kn[0], kn[1], use_cuda_graph=True)

    def run_backward(y):
        return y.backward(dy)

    for c in [128, 256, 512]:
        def run_chunked(kn):
            return dfpa.preattn(kn[0], kn[1], C=c)

        t = get_time(run_chunked, pre=gen_input)
        print(f'chunksize {c}  time {t * 1e3:.3f} ms')
    t = get_time(run_graph, pre=gen_input)
    print(f'cuda graph (c=256) time {t * 1e3:.3f} ms')

    a = torch.randn(B, H, T, D)
    b = torch.randn(B, H, T, D)
    a.requires_grad = True
    b.requires_grad = True
    # One forward pass to obtain an output whose shape the upstream gradient copies.
    y = dfpa.preattn(a, b)
    dy = torch.rand_like(y)
    t = get_time(run_backward, pre=lambda: dfpa.preattn(a, b))
    print(f'backward time {t * 1e3:.3f} ms')

    trace_handler = torch.profiler.tensorboard_trace_handler(
        f"./profile/trial-0",
        worker_name=f"{socket.gethostname()}_{os.getpid()}_rank{0}",
        use_gzip=True,
    )
    with torch.profiler.profile(
        on_trace_ready=trace_handler,
        record_shapes=True,
        with_stack=True,
        profile_memory=True,
    ):
        y0 = dfpa.preattn(*gen_input(), C=256)
        y0.backward(dy)


if __name__ == "__main__":
    # Script entry point: correctness checks first, then benchmarks.
    # Forward and backward comparisons against the reference implementations.
    check_attn()
    check_backward()
    # Timing comparison followed by the profiler run.
    compare_attn()
    profile_attn()
    # Disabled: verify_equivalence(profile=False)
 
