from typing import Optional

import math
import torch
import triton
import triton.language as tl
from . import fa as dffa


# Default chunk length along the sequence (T) axis. The chunked routines
# below step over T in strides of this size and build one (C, C) triangular
# system per chunk.
BLOCK_SIZE_C = 256


def cum_by_inv(u, w):
    """Solve ``w @ x = u`` for ``x``, treating ``w`` as unit lower-triangular.

    Both operands are promoted to float32 for the solve; the solution is
    cast back to the dtype of ``u`` before being returned.

    Args:
        u: Right-hand side of the system.
        w: Coefficient matrix; it is solved as a lower-triangular system
            whose diagonal is assumed to be all ones.

    Returns:
        The solution ``x`` with the same dtype as ``u``.
    """
    rhs = u.float()
    lower = w.float()
    solution = torch.linalg.solve_triangular(
        lower, rhs, upper=False, unitriangular=True
    )
    return solution.to(u.dtype)


def cum_by_inv_inplace(u, w):
    """Overwrite ``u`` with the result of ``cum_by_inv(u, w)``.

    The solve itself is done out of place; only the final copy writes into
    ``u``. Returns ``None``.
    """
    solved = cum_by_inv(u, w)
    u.copy_(solved)


def cum_by_inv_backward_x(do, w):
    """Solve the adjoint system of ``cum_by_inv`` for an upstream gradient.

    Only the strictly lower part of ``w`` is used. Its conjugate transpose
    is solved as a unit upper-triangular system in float32, and the result
    is cast back to the dtype of ``do``.

    Args:
        do: Upstream gradient, used as the right-hand side.
        w: The matrix that was passed to ``cum_by_inv`` in the forward pass.

    Returns:
        The solution, with the same dtype as ``do``.
    """
    strict_lower = torch.tril(w, diagonal=-1)
    adjoint = strict_lower.mH.float()
    grad = torch.linalg.solve_triangular(
        adjoint, do.float(), upper=True, unitriangular=True
    )
    return grad.to(do.dtype)


def cum_by_inv_backward(do, w, x):
    """Compute gradients of ``cum_by_inv`` with respect to both operands.

    Args:
        do: Upstream gradient for the output of ``cum_by_inv``.
        w: The matrix that was passed to ``cum_by_inv``.
        x: Batched 3-D tensor multiplied (conjugate-transposed) against the
            right-hand-side gradient to form ``dw``.
            NOTE(review): presumably the forward output of ``cum_by_inv``;
            confirm against callers.

    Returns:
        A pair ``(du, dw)``: ``du`` is the gradient for the right-hand side
        and ``dw`` is the gradient for ``w``, restricted to its strictly
        lower triangle.
    """
    # Reuse the shared adjoint solve instead of duplicating it here.
    du = cum_by_inv_backward_x(do, w)
    # Entries on and above the diagonal are zeroed in the gradient for w.
    dw = torch.bmm(-du, x.mH).tril(-1)
    return du, dw


def preattn_chunked_nobw(K, V, C=BLOCK_SIZE_C):
    """Chunked pre-attention forward pass that keeps nothing for backward.

    ``K`` and ``V`` are 4-D ``(BS, NH, T, D)`` tensors. The leading two axes
    are flattened and the sequence axis is processed in chunks of ``C``.
    For each chunk, ``dffa.forward`` is called and the first value it
    returns is used as the triangular system for ``cum_by_inv_inplace`` on
    that chunk of the output buffer.

    NOTE(review): the output buffer starts uninitialised, so ``dffa.forward``
    presumably writes the current chunk of it before the solve reads it;
    confirm against the ``fa`` module.

    Returns:
        The output buffer viewed as ``(BS, NH, T, D)``.
    """
    BS, NH, T, D = K.size()
    keys = K.flatten(0, 1)
    values = V.flatten(0, 1)
    out = torch.empty_like(values)
    scale = 1 / math.log(2) / math.sqrt(D)
    for start in range(0, T, C):
        stop = start + C
        # Current chunk of keys/values against every position up to `stop`.
        w, _ = dffa.forward(
            keys[:, start:stop, :],
            keys[:, :stop, :],
            values[:, start:stop, :],
            out[:, :stop, :],
            scale,
        )
        cum_by_inv_inplace(out[:, start:stop, :], w)
    return out.view(BS, NH, T, D)


class DFPreAttn(torch.autograd.Function):
    """Autograd-enabled chunked pre-attention.

    The forward pass mirrors ``preattn_chunked_nobw`` but also records the
    per-chunk triangular systems and the second value returned by
    ``dffa.forward`` so that the backward pass can reuse them.
    """

    @staticmethod
    def forward(ctx, ko, vo, C=BLOCK_SIZE_C):
        """Run the chunked forward pass and stash state for backward.

        Args:
            ko: Keys, shape ``(BS, NH, T, D)``.
            vo: Values, same leading layout as ``ko``.
            C: Chunk length along ``T``; ``T`` must be a multiple of it.

        Returns:
            The result viewed with the shape of ``vo``.

        Raises:
            ValueError: If ``C`` is not positive or does not divide ``T``.
        """
        BS, NH, T, D = ko.size()
        # `ws` has exactly T // C slots and backward walks T in whole chunks,
        # so a ragged tail cannot be represented. Fail early and clearly
        # instead of hitting an IndexError halfway through the loop.
        if C <= 0 or T % C != 0:
            raise ValueError(
                f"sequence length T={T} must be a multiple of chunk size C={C}"
            )
        k = ko.flatten(0, 1)
        v = vo.flatten(0, 1)
        u = torch.empty_like(v)
        ws = torch.empty(T // C, BS * NH, C, C, device=k.device, dtype=k.dtype)
        lses = torch.empty(BS * NH, T, device=k.device, dtype=torch.float)

        scale = 1 / math.log(2) / math.sqrt(D)
        for i in range(0, T, C):
            w, lse = dffa.forward(
                k[:, i:i + C, :],
                k[:, :i + C, :],
                v[:, i:i + C, :],
                u[:, :i + C, :],
                scale,
            )
            ws[i // C].copy_(w)
            lses[:, i:i + C].copy_(lse)
            cum_by_inv_inplace(u[:, i:i + C, :], w)

        ctx.save_for_backward(ko, vo, u, ws, lses)
        # Non-tensor state goes on ctx directly; backward must use the same
        # chunking as forward.
        ctx.chunk_size = C
        return u.view_as(vo)

    @staticmethod
    def backward(ctx, grad_o):
        """Propagate ``grad_o`` back to the keys and values.

        Chunks are visited from last to first: the gradient of each chunk
        depends on the value gradients of all later chunks.

        Returns:
            ``(grad_k, grad_v, None)``. The trailing ``None`` belongs to the
            non-tensor ``C`` argument; autograd drops surplus ``None``
            entries when ``C`` was not passed to ``apply``.
        """
        ko, vo, u, ws, lses = ctx.saved_tensors
        C = ctx.chunk_size
        BS, NH, T, D = ko.size()
        k = ko.flatten(0, 1)
        grad_o = grad_o.flatten(0, 1)
        grad_v = torch.empty_like(vo.flatten(0, 1))

        qk_scale = 1 / math.sqrt(D)
        fa_scale = qk_scale / math.log(2)
        for i in range(T - C, -1, -C):
            do = grad_o[:, i:i + C, :]
            if i < T - C:
                # Contribution flowing back from all later chunks. An
                # equivalent (slower) formulation, with rest = i + C:
                #   a = torch.bmm(k[:, i:rest], k[:, rest:].transpose(1, 2)) * fa_scale
                #   p = torch.exp2(a - lses[:, None, rest:])
                #   du = torch.bmm(p, grad_v[:, rest:, :])
                du = dffa.backward_u_chunk(
                    k[:, i:i + C, :],
                    k[:, i + C:, :],
                    lses[:, i + C:],
                    grad_v[:, i + C:, :],
                    fa_scale,
                )
                do = do - du
            grad_v[:, i:i + C, :].copy_(cum_by_inv_backward_x(do, ws[i // C]))

        # Skip the key gradient entirely when autograd does not need it.
        grad_k = None
        if ctx.needs_input_grad[0]:
            grad_k = dffa.backward_k(
                k, u, lses, grad_v, qk_scale, fa_scale
            ).view_as(ko)
        return grad_k, grad_v.view_as(vo), None


def preattn(k: torch.Tensor, v: torch.Tensor, C=BLOCK_SIZE_C, use_cuda_graph=False):
    """Dispatch to the appropriate chunked pre-attention implementation.

    Args:
        k: Keys, shape ``(BS, NH, T, D)``.
        v: Values, same layout as ``k``.
        C: Chunk length along ``T``; forwarded to whichever path is taken.
        use_cuda_graph: If true, use the CUDA-graph cached forward-only path
            regardless of ``requires_grad``.

    Returns:
        The pre-attention output from the selected implementation.
    """
    if use_cuda_graph:
        return preattn_chunked_graph(k, v, C)
    # Take the forward-only path only when neither input needs a gradient;
    # checking k alone would silently drop the gradient for v.
    if not (k.requires_grad or v.requires_grad):
        return preattn_chunked_nobw(k, v, C=C)
    return DFPreAttn.apply(k, v, C)


# Cache of graphed callables, keyed by everything that shapes the capture:
# input sizes, chunk size, dtypes and device.
_chunked_graphs = {}


def preattn_chunked_graph(K, V, C=BLOCK_SIZE_C):
    """Run ``preattn_chunked_nobw`` through a cached CUDA-graphed callable.

    A callable is captured on first use for each distinct combination of
    input shape, chunk size, dtype and device, and reused afterwards.

    Args:
        K: Keys, shape ``(BS, NH, T, D)``; also used as a capture sample.
        V: Values; also used as a capture sample.
        C: Chunk length along ``T`` baked into the captured callable.

    Returns:
        The output of the graphed callable for ``(K, V)``.
    """
    BS, NH, T, D = K.size()
    # dtype/device are part of the key so a graph captured for one kind of
    # input is never replayed for another.
    key = (BS, NH, T, D, C, K.dtype, V.dtype, K.device)
    graphed = _chunked_graphs.get(key)
    if graphed is None:
        # Bind C explicitly: only (K, V) are passed through the graphed
        # callable, so without this the default chunk size would always be
        # captured regardless of the requested C.
        def chunked(k, v):
            return preattn_chunked_nobw(k, v, C=C)

        graphed = torch.cuda.make_graphed_callables(chunked, (K, V))
        _chunked_graphs[key] = graphed
    return graphed(K, V)
