from typing import Optional
import torch
from einops import rearrange
from torch.nn import functional as F

try:
    import flashinfer
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from profiling.timers import timed, log_section, summary as profiling_summary
except ImportError:
    from profiling_stub import timed, log_section, summary as profiling_summary

try:
    from pooling_kernel import attn_with_pooling
    from pooling_kernel_opt import attn_with_pooling_optimized
    from gilbert3d_impl import gilbert3d
except ImportError:
    from anonymous_pooling_kernel import attn_with_pooling
    from anonymous_pooling_kernel_opt import attn_with_pooling_optimized
    from anonymous_gilbert3d import gilbert3d


# Module-level tunables read by the attention code below.

# Copied onto AdaptiveBlockSparseAttnTrain at construction; when True the
# module reorders tokens with GilbertRearranger before attention and undoes
# the reordering afterwards.
use_rearrange = False
# Forwarded by adaptive_block_sparse_attn to transfer_attn_to_mask as the
# upper / lower bound on the retained fraction per mask row.
max_retain_ratio = 0.2
min_retain_ratio = 0.2
# Grid extents handed to GilbertRearranger as (w, h, d).
# NOTE(review): presumably the token grid of the video latent -- confirm.
width = 80
height = 60
depth = 21
# Number of leading sequence positions GilbertRearranger treats as text.
text_length = 0


class BackendConfig:
    """String identifiers for the selectable attention backends."""

    # NOTE(review): no code visible in this file selects on ORIGINAL.
    ORIGINAL = "original"
    FLASHINFER_VARIABLE = "flashinfer_variable"


# Selected backend identifier. adaptive_block_sparse_attn names it in a
# `global` statement but nothing visible in this file reads its value.
BACKEND = BackendConfig.FLASHINFER_VARIABLE
# Passed as the `backend` argument of block_sparse_attn_flashinfer.
BACKEND_FA_VERSION = "fa3"


@timed("pad_to_multiple")
def pad_to_multiple(x, multiple):
    """Right-pad ``x`` so that ``x.size(2)`` becomes a multiple of ``multiple``.

    The padding is applied at the end of the second-to-last dimension using
    ``replicate`` mode; the last dimension is left untouched.  When the length
    is already a multiple, ``x`` itself is returned.

    NOTE(review): the length is read from dim 2 while the padding targets
    dim -2, so this assumes a 4-D input where both are the same axis.
    """
    remainder = x.size(2) % multiple
    if remainder == 0:
        return x
    missing = multiple - remainder
    # Pad spec is (last_left, last_right, second_last_left, second_last_right).
    return F.pad(x, (0, 0, 0, missing), mode="replicate")


@timed("random_sample_tokens")
def random_sample_tokens(x, block_size=64, sample_num=8):
    """Pick ``sample_num`` random in-block positions and gather them per block.

    ``x`` is unpacked as ``(B, H, L, D)`` and viewed as ``L // block_size``
    blocks of ``block_size`` tokens.  One set of ``sample_num`` offsets is
    drawn per ``(B, H)`` pair and reused for every block of that pair.

    Returns:
        Tensor of shape ``(B, H, (L // block_size) * sample_num, D)``.
    """
    batch, heads, length, dim = x.size()
    block_count = length // block_size
    blocked = x.view(batch, heads, block_count, block_size, dim)

    # Top-k over uniform noise yields `sample_num` distinct random offsets.
    noise = torch.rand(batch, heads, 1, block_size, device=x.device)
    offsets = torch.topk(noise, sample_num, dim=3)[1]

    # Broadcast the same offsets over all blocks and the feature axis.
    gather_index = offsets.unsqueeze(-1).expand(-1, -1, block_count, -1, dim)
    picked = torch.gather(blocked, 3, gather_index)
    return picked.view(batch, heads, block_count * sample_num, dim)


@timed("efficient_attn_with_pooling")
def efficient_attn_with_pooling(q, k, v, block_size=128, num_keep=8):
    """Compute the pooled attention map from randomly sampled q/k tokens.

    ``q`` and ``k`` are padded to a multiple of ``block_size`` and reduced to
    ``num_keep`` sampled tokens per block; the samples are then handed to
    ``attn_with_pooling_optimized`` together with ``v``.

    Returns:
        The second element of the pair returned by
        ``attn_with_pooling_optimized``.

    NOTE(review): ``v`` is forwarded unpadded and unsampled while q/k are
    sampled -- confirm the kernel expects that.
    """
    padded_q = pad_to_multiple(q, block_size)
    padded_k = pad_to_multiple(k, block_size)
    sampled_q = random_sample_tokens(padded_q, block_size, num_keep)
    sampled_k = random_sample_tokens(padded_k, block_size, num_keep)

    head_dim = sampled_q.size(-1)
    softmax_scale = 1.0 / (head_dim ** 0.5)
    _attn_out, pooled = attn_with_pooling_optimized(
        sampled_q, sampled_k, v, False, softmax_scale, num_keep
    )
    return pooled


def standard_attn(q, k, v):
    """Dense scaled-dot-product attention with PyTorch's default options."""
    dense_attention = F.scaled_dot_product_attention
    return dense_attention(q, k, v)


class GilbertRearranger:
    """Reorder video tokens along the ``gilbert3d`` curve and undo it again.

    The sequence axis (dim ``-2``) is treated as ``text_length`` text tokens
    followed by the video tokens.  ``rearrange`` emits the video tokens in
    curve order followed by the text tokens; ``reversed_rearrange`` expects
    that layout and restores the text-first, original-order layout.

    Attributes:
        o2c: gather index used by ``rearrange``; ``o2c[i]`` is the flat
            ``x + w * (y + h * z)`` index of the ``i``-th point on the curve.
        c2o: gather index used by ``reversed_rearrange``; ``c2o[j]`` is the
            curve position of the token whose flat index is ``j``.

    Args:
        w, h, d: grid extents passed to ``gilbert3d``.
        tlen: number of text tokens at the start of the sequence.
        device: device on which the index tensors are created.  Defaults to
            ``"cuda"``, which matches the previously hard-coded behaviour.
    """

    def __init__(self, w, h, d, tlen, device="cuda"):
        self.width = w
        self.height = h
        self.depth = d
        self.total = w * h * d
        self.text_length = tlen

        flat_to_curve = self._curve_map(w, h, d)
        o2c = [0] * self.total
        c2o = [0] * self.total
        # _curve_map is keyed by flat index and valued by curve position.
        for flat_idx, curve_pos in flat_to_curve.items():
            o2c[curve_pos] = flat_idx
            c2o[flat_idx] = curve_pos
        self.o2c = torch.tensor(o2c, dtype=torch.long, device=device)
        self.c2o = torch.tensor(c2o, dtype=torch.long, device=device)

    def _curve_map(self, w, h, d):
        """Map each flat index ``x + w * (y + h * z)`` to its curve position.

        Assumes ``gilbert3d(w, h, d)`` yields every ``(x, y, z)`` cell of the
        grid exactly once -- TODO confirm against its implementation.
        """
        return {
            x + w * (y + h * z): position
            for position, (x, y, z) in enumerate(gilbert3d(w, h, d))
        }

    @timed("curve_rearrange")
    def rearrange(self, q, k, v):
        """Return ``(q, k, v)`` with video tokens in curve order, text last."""
        seq = -2

        def reorder(t):
            text = t[..., :self.text_length, :]
            video = t[..., self.text_length:, :]
            return torch.cat((video.index_select(seq, self.o2c), text), dim=seq)

        return reorder(q), reorder(k), reorder(v)

    @timed("curve_reverse")
    def reversed_rearrange(self, out):
        """Invert ``rearrange``: restore original video order, text first."""
        seq = -2
        # Split with an explicit index: slicing with ``-self.text_length``
        # collapses to ``[:0]`` / ``[0:]`` when text_length == 0, which would
        # make the video part empty and treat the whole tensor as text.
        split = out.size(seq) - self.text_length
        video, text = out[..., :split, :], out[..., split:, :]
        restored = video.index_select(seq, self.c2o)
        return torch.cat((text, restored), dim=seq)


@timed("transfer_mask_energy")
def transfer_attn_to_mask(attn, mode="energy",
                          init_k=None, max_retain_ratio=0.7,
                          min_retain_ratio=0.1, energy_threshold=0.95):
    """Convert pooled attention scores into a boolean keep-mask.

    Each row along the last axis is sorted in descending order and its
    cumulative sum is compared against ``energy_threshold`` times the row
    total.  The index ``k`` of the first position reaching that level is
    clamped into ``[min_keep, max_keep]`` and the ``k`` largest entries of the
    row are marked ``True``.  The last two columns and the last two rows of
    the mask are always ``True``.

    Args:
        attn: 4-D score tensor; the last axis is the one being selected over.
        mode: accepted for interface compatibility; currently ignored.
        init_k: accepted for interface compatibility; currently ignored.
        max_retain_ratio: upper bound on ``k`` as a fraction of the last-axis
            length (at least 1 entry).
        min_retain_ratio: lower bound on ``k`` as a fraction of the last-axis
            length (at least 1 entry).  If it exceeds the upper bound, the
            upper bound wins.
        energy_threshold: fraction of the row total used as the cut-off.

    Returns:
        Boolean tensor with the same shape as ``attn``.
    """
    n_cols = attn.size(-1)

    sorted_attn, order = torch.sort(attn, dim=-1, descending=True)
    cum = torch.cumsum(sorted_attn, dim=-1)
    reached = cum >= energy_threshold * cum[..., -1:]
    # NOTE(review): argmax gives the *index* of the first entry reaching the
    # threshold, so keeping `k_idx` entries stops one short of it; confirm
    # whether `k_idx + 1` was intended.
    k_idx = torch.argmax(reached.int(), dim=-1)

    # The bounds are plain Python numbers; torch.clamp() only accepts a
    # Tensor as input, so compute them as ints and clamp the tensor instead.
    min_keep = max(int(n_cols * min_retain_ratio), 1)
    max_keep = max(int(n_cols * max_retain_ratio), 1)
    # Lower bound first, then upper bound, so the upper bound takes priority.
    k_idx = k_idx.clamp(min=min_keep).clamp(max=max_keep)

    # keep[..., j] is True for the first k_idx sorted positions of each row.
    pos = torch.arange(n_cols, device=attn.device)
    keep = pos < k_idx.unsqueeze(-1)

    # Scatter the decisions from sorted order back to original positions.
    mask = torch.zeros_like(attn, dtype=torch.bool)
    mask.scatter_(-1, order, keep)
    mask[..., -2:] = True
    mask[:, :, -2:, :] = True
    return mask


@timed("flashinfer_variable")
def block_sparse_attn_flashinfer(q, k, v, block_mask, block_size=128, backend="fa3"):
    """Run block-sparse attention via FlashInfer's variable-block wrapper.

    ``q`` is unpacked as ``(B, H, L, D)``; batch and head axes are merged so
    that every (batch, head) pair is presented to FlashInfer as its own head.
    ``k`` and ``v`` are reshaped with the sizes taken from ``q``.  Every block
    row and column is declared to be ``block_size`` tokens long.

    Args:
        q, k, v: attention inputs.
        block_mask: mask whose last two axes index (row block, column block).
        block_size: token count declared for every block.
        backend: forwarded to the FlashInfer wrapper constructor.

    Returns:
        Attention output reshaped to ``(B, H, L, D)``.

    Raises:
        RuntimeError: if the ``flashinfer`` import failed at module load.

    NOTE(review): four scratch buffers (128 + 512 + 256 + 256 MiB) and a new
    wrapper are created on every call; consider reusing them across calls.
    """
    if not FLASHINFER_AVAILABLE:
        raise RuntimeError("flashinfer not installed")

    batch, heads, seq_len, head_dim = q.shape
    merged = batch * heads
    dev = q.device

    flat_q = q.reshape(merged, seq_len, head_dim)
    flat_k = k.reshape(merged, seq_len, head_dim)
    flat_v = v.reshape(merged, seq_len, head_dim)

    mask_map = block_mask.reshape(merged, block_mask.size(-2), block_mask.size(-1))
    row_sizes = torch.full(
        (merged, mask_map.size(1)), block_size, dtype=torch.int32, device=dev
    )
    col_sizes = torch.full(
        (merged, mask_map.size(2)), block_size, dtype=torch.int32, device=dev
    )

    def scratch(mib):
        # Uninitialised byte buffer of `mib` MiB on the input's device.
        return torch.empty(mib * 1024 * 1024, dtype=torch.uint8, device=dev)

    float_buf = scratch(128)
    int_buf = scratch(512)
    indices_buf = scratch(256)
    indptr_buf = scratch(256)

    attn = flashinfer.VariableBlockSparseAttentionWrapper(float_buf, backend=backend)
    attn.reset_workspace_buffer(float_buf, int_buf, indices_buf, indptr_buf)

    # plan() must precede run(); it consumes the mask and block layout.
    attn.plan(
        block_mask_map=mask_map,
        block_row_sz=row_sizes,
        block_col_sz=col_sizes,
        num_qo_heads=merged,
        num_kv_heads=merged,
        head_dim=head_dim,
        causal=False,
        use_fp16_qk_reduction=True,
        non_blocking=True,
        q_data_type=q.dtype,
        kv_data_type=q.dtype,
    )

    result = attn.run(flat_q, flat_k, flat_v)
    return result.reshape(batch, heads, seq_len, head_dim)


@timed("adaptive_block_sparse")
def adaptive_block_sparse_attn(q, k, v, use_flashinfer=None):
    """Block-sparse attention whose mask is derived from pooled attention.

    Steps:
      1. Pad the sequence axis (dim 2) of q/k/v to a multiple of the block
         size when needed.
      2. Under ``torch.no_grad()``, compute the pooled attention map and turn
         it into a boolean block mask using the module-level
         ``max_retain_ratio`` / ``min_retain_ratio`` and an energy threshold
         of 0.7.
      3. Run ``block_sparse_attn_flashinfer`` with that mask and the
         module-level ``BACKEND_FA_VERSION``.
      4. Trim the output back to the original sequence length.

    Args:
        q, k, v: attention inputs with the sequence on dim 2 (they are
            unpacked as ``(B, H, L, D)`` by the FlashInfer helper).
        use_flashinfer: accepted for interface compatibility; currently
            ignored.

    Returns:
        ``(out, sparsity)`` where ``sparsity`` is a Python float equal to one
        minus the fraction of ``True`` entries in the block mask.
    """
    block_size = 128
    L = q.size(2)

    if L % block_size != 0:
        # pad_to_multiple already pads dim 2.  Transposing to (B, L, H, D)
        # first made it pad the head axis and left the sequence unpadded.
        q = pad_to_multiple(q, block_size)
        k = pad_to_multiple(k, block_size)
        v = pad_to_multiple(v, block_size)
    Lp = q.size(2)

    # Mask selection is not differentiated through.
    with torch.no_grad():
        pooling = efficient_attn_with_pooling(q, k, v, block_size=block_size)
        mask = transfer_attn_to_mask(
            pooling,
            mode="energy",
            max_retain_ratio=max_retain_ratio,
            min_retain_ratio=min_retain_ratio,
            energy_threshold=0.7,
        )
        sparsity = 1 - mask.float().mean()

    out = block_sparse_attn_flashinfer(
        q, k, v, mask, block_size=block_size, backend=BACKEND_FA_VERSION
    )

    if Lp != L:
        # Drop the rows that only exist because of padding.
        out = out[:, :, :L, :]

    return out, float(sparsity)


class AdaptiveBlockSparseAttnTrain(torch.nn.Module):
    """``nn.Module`` wrapper around ``adaptive_block_sparse_attn``.

    The module-level ``width``, ``height``, ``depth``, ``text_length`` and
    ``use_rearrange`` values are captured at construction time.  After every
    forward pass ``last_sparsity`` holds the sparsity reported by the
    attention call.
    """

    def __init__(self):
        super().__init__()
        # Built unconditionally, even when reordering is switched off.
        self.g = GilbertRearranger(width, height, depth, text_length)
        self.use_rearrange = use_rearrange
        self.last_sparsity = None

    def forward(self, q, k, v):
        """Attend over ``q, k, v``, optionally in Gilbert-curve token order."""
        if self.use_rearrange:
            q, k, v = self.g.rearrange(q, k, v)

        attended, sparsity = adaptive_block_sparse_attn(q, k, v)
        self.last_sparsity = sparsity

        if not self.use_rearrange:
            return attended
        return self.g.reversed_rearrange(attended)


def attention_profile(_: Optional[float] = None):
    """Return the profiling summary as a plain dictionary.

    The positional argument is ignored.  Every entry yielded by
    ``profiling_summary()`` becomes a dict with its ``name``, ``calls``,
    ``total_ms`` and ``avg_ms`` attributes.

    Returns:
        ``{"stats": [...]}`` with one dict per summary entry, in order.
    """
    rows = []
    for entry in profiling_summary():
        rows.append(
            {
                "name": entry.name,
                "calls": entry.calls,
                "total_ms": entry.total_ms,
                "avg_ms": entry.avg_ms,
            }
        )
    return {"stats": rows}
