import math
import typing

import flash_attn
import flash_attn.layers.rotary
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Flags required to enable jit fusion kernels.
# These are private ``torch._C`` hooks, executed once at import time and
# therefore before the ``@torch.jit.script`` functions below are compiled.
# Going by their names, the first two switch the TorchScript profiling
# mode/executor off and the last two allow operator fusion on CPU and GPU.
# NOTE(review): private API — confirm it still exists in the installed torch.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Compute ``residual + scale * dropout(x + bias)``.

    Args:
        x: Input tensor.
        bias: Optional tensor added to ``x`` before dropout; skipped if None.
        scale: Multiplier applied to the dropout output.
        residual: Optional tensor added to the scaled result; skipped if None.
        prob: Dropout probability.
        training: Whether dropout is active.

    Returns:
        The scaled (and optionally residual-added) tensor.
    """
    # Plain ``if`` statements on purpose: the scripted wrappers in this file
    # call this function, so it has to stay within the TorchScript subset.
    pre_dropout = x
    if bias is not None:
        pre_dropout = x + bias

    gated = scale * F.dropout(pre_dropout, p=prob, training=training)

    if residual is None:
        return gated
    return residual + gated


def get_bias_dropout_add_scale(training):
    """Return a five-argument wrapper around ``bias_dropout_add_scale``.

    The returned callable takes ``(x, bias, scale, residual, prob)`` and
    forwards them together with the ``training`` flag captured here.
    """

    def _bias_dropout_add(x, bias, scale, residual, prob):
        return bias_dropout_add_scale(
            x=x,
            bias=bias,
            scale=scale,
            residual=residual,
            prob=prob,
            training=training,
        )

    return _bias_dropout_add


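# NOTE: `modulate` is defined twice in this file. Python has no function
# overloading, so the later definition rebinds the module-level name.
# `modulate_fused` below is scripted while this version is still bound.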
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are used as given; any broadcasting against ``x``
    follows the usual tensor broadcasting rules.
    """
    gain = 1 + scale
    return x * gain + shift


@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Scripted ``bias_dropout_add_scale`` with dropout enabled."""
    training = True
    return bias_dropout_add_scale(x, bias, scale, residual, prob, training)


@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Scripted ``bias_dropout_add_scale`` with dropout disabled."""
    training = False
    return bias_dropout_add_scale(x, bias, scale, residual, prob, training)


@torch.jit.script
def modulate_fused(
    x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Scripted wrapper around ``modulate``.

    The wrapped function is whichever ``modulate`` is bound when this
    decorator runs, i.e. the definition that precedes it in the file.
    """
    modulated = modulate(x, shift, scale)
    return modulated


class Rotary(torch.nn.Module):
    """Rotary position-embedding tables, cached per sequence length and device.

    ``forward`` returns a ``(cos, sin)`` pair laid out as
    ``(batch=1, seq_len, qkv=3, head=1, dim)``. The slice for the third
    ``qkv`` entry (values) is set to ``cos = 1`` / ``sin = 0`` so that the
    rotation leaves ``v`` unchanged.

    Args:
        dim: Rotary dimension; frequencies are built from ``arange(0, dim, 2)``.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return cached ``(cos, sin)`` tables for ``x.shape[seq_dim]`` positions.

        The tables are rebuilt when the sequence length changes or when ``x``
        lives on a different device than the cached tables; otherwise the very
        same tensor objects are returned again.
        """
        seq_len = x.shape[seq_dim]
        # The cache used to be keyed on seq_len alone, so after the input moved
        # to another device the stale tables from the old device were returned.
        cache_is_stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if cache_is_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            # This makes the transformation on v an identity.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached


def rotate_half(x):
    """Split the last dimension at ``size // 2`` and return ``cat(-second, first)``."""
    split = x.shape[-1] // 2
    first = x[..., :split]
    second = x[..., split:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary embeddings to a packed ``qkv`` tensor via flash-attn.

    ``cos`` / ``sin`` are indexed as 5-D tables; only batch 0, qkv slot 0,
    head 0 and the first half of the last dimension are passed on.
    NOTE(review): the trailing underscore of ``apply_rotary_emb_qkv_``
    suggests it works in place — confirm against the installed flash-attn.
    """
    cos_half = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin_half = sin[0, :, 0, 0, : sin.shape[-1] // 2]
    rotary = flash_attn.layers.rotary
    return rotary.apply_rotary_emb_qkv_(qkv, cos_half, sin_half)


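# NOTE: this rebinds the module-level name `modulate` defined earlier in the
# file (Python has no overloading). This variant unsqueezes shift/scale at dim 1.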
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift`` with ``shift``/``scale`` unsqueezed at dim 1."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset


#################################################################################
#                                  Layers                                       #
#################################################################################
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned scale only.

    Normalisation runs in float32 with CUDA autocast disabled; the learned
    ``weight`` (initialised to ones) is applied afterwards. There is no bias.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones([dim]))

    def forward(self, x):
        """Normalise ``x`` over its last ``dim`` features and rescale by ``weight``."""
        with torch.amp.autocast("cuda", enabled=False):
            normalised = F.layer_norm(x.float(), [self.dim])
        # Two leading singleton axes so the weight broadcasts over the input.
        gain = self.weight[None, None, :]
        return normalised * gain


def residual_linear(x, W, x_skip, residual_scale):
    """Return ``x_skip + residual_scale * (x @ W.T)`` using one fused ``addmm``.

    ``x`` and ``x_skip`` are flattened to 2-D views for the matmul and the
    result is viewed back to ``x``'s leading dimensions with ``W.shape[0]``
    output features.
    """
    out_features = W.shape[0]
    in_features = W.shape[1]
    flat_skip = x_skip.view(-1, out_features)
    flat_x = x.view(-1, in_features)
    flat_out = torch.addmm(flat_skip, flat_x, W.T, alpha=residual_scale)
    return flat_out.view(*x.shape[:-1], out_features)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to vectors: sinusoidal features followed by an MLP.

    Args:
        hidden_size: Width of both MLP layers' outputs.
        frequency_embedding_size: Size of the sinusoidal feature vector fed
            into the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal timestep features.

        :param t: a 1-D Tensor of N indices, one per batch element; values
                  may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor laid out as ``[cos | sin]``, with one
                 trailing zero column when ``dim`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2 != 0:
            # cos and sin cover 2 * half columns; pad the remaining one.
            zero_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_column], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` and project them through the MLP."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)


class LabelEmbedder(nn.Module):
    """Look up class labels in a learned embedding table.

    The table holds ``num_classes + 1`` rows, i.e. one index beyond the class
    range. NOTE(review): presumably the extra row is reserved for a
    "no label" token used by classifier-free guidance; no label dropout is
    performed in this class itself.
    """

    def __init__(self, num_classes, cond_size):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, cond_size)
        self.num_classes = num_classes

        # TODO think of initializing with 0.02 std deviation like in original DiT paper

    def forward(self, labels):
        """Return the embedding rows selected by ``labels``."""
        return self.embedding_table(labels)


#################################################################################
#                                 Core Model                                    #
#################################################################################


class DDiTBlock(nn.Module):
    """Transformer block with adaLN conditioning and an optional cached-qkv path.

    Each forward pass runs: adaLN-modulated LayerNorm -> packed qkv projection
    -> rotary embedding -> flash attention -> gated residual, followed by an
    adaLN-modulated MLP with a second gated residual.

    The adaLN projection is zero-initialised, so all six modulation chunks
    (shift/scale/gate for attention and for the MLP) start at zero.

    ``cached_x`` / ``cached_qkv`` hold the modulated input and the qkv
    projection from an earlier call; they are plain attributes (not buffers)
    and are only touched when ``cache_xqkv`` or ``use_cache`` is set.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        # NOTE: dropout1 / dropout2 are not referenced by forward() in this
        # class; dropout is applied through the float `self.dropout` instead.
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

        # One projection yields all six modulation tensors; zero init makes
        # them all zero until trained.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

        # Filled by forward() when cache_xqkv / use_cache are requested.
        self.cached_x = None
        self.cached_qkv = None

    def _get_bias_dropout_scale(self):
        """Pick the scripted gated-residual function matching ``self.training``."""
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, seqlens=None, exclude_last_token=False,
            update_indices=None, use_cache=False, cache_xqkv=False, seq_len=None):
        """Run the block.

        Args:
            x: Input activations, indexed as (batch, seq, dim). When
                ``update_indices`` is given together with ``use_cache``, only
                the tokens at those positions are expected here.
            rotary_cos_sin: ``(cos, sin)`` pair consumed by
                ``apply_rotary_pos_emb``.
            c: Conditioning input to the adaLN projection.
            seqlens: Optional per-sequence lengths for the varlen attention
                call; when None every sequence is taken to be ``seq_len`` long.
            exclude_last_token: If True, rotary embedding is applied to all
                positions except the last one ("for di4c").
            update_indices: Optional (batch, M) positions whose queries are
                attended; selects the partial-update attention path.
            use_cache: With ``update_indices``, scatter the new ``x``/``qkv``
                into the cached full-sequence tensors before attention.
            cache_xqkv: Store this call's modulated ``x`` and ``qkv`` in the cache.
            seq_len: Key/value sequence length; defaults to ``x.shape[1]``.

        Returns:
            The block output, with the same leading shape as the residual
            stream ``x`` passed in.
        """
        # batch_size, seq_len = x.shape[0], x.shape[1]
        batch_size = x.shape[0]
        if seq_len is None:
            seq_len = x.shape[1]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        # [:, None] inserts an axis at dim 1 so the chunks broadcast over it.
        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
            self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        )

        # attention operation
        x_skip = x.clone()
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        qkv = self.attn_qkv(x)

        if cache_xqkv:
            self.cached_x = x.clone()
            self.cached_qkv = qkv.clone()
        if (update_indices is not None) and use_cache:
            # input x: updated indices only
            # Scatter the fresh rows into the cached full-sequence tensors
            # (in place), then continue with the full-sequence versions.
            batch_indices = batch_indices_from_uind(update_indices)
            self.cached_x[batch_indices, update_indices] = x
            x = self.cached_x
            self.cached_qkv[batch_indices, update_indices] = qkv
            qkv = self.cached_qkv

            # Re-clone so later in-place edits of x / qkv (e.g. the rotary
            # write below) do not alias the cache.
            self.cached_x = x.clone()
            self.cached_qkv = qkv.clone()

        qkv = rearrange(
            qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads
        )
        with torch.amp.autocast("cuda", enabled=False):
            cos, sin = rotary_cos_sin
            if exclude_last_token: # for di4c
                # Rotate every position but the last; the clone keeps the
                # rotary call from writing into a view of qkv directly.
                qkv[:,:-1] = apply_rotary_pos_emb(
                    qkv[:,:-1].clone(), cos.to(qkv.dtype), sin.to(qkv.dtype))
            else:
                qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        
        if update_indices is None:
            # Full self-attention over every position.
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            if seqlens is None:
                # Equal-length sequences: boundaries 0, seq_len, 2*seq_len, ...
                cu_seqlens = torch.arange(
                    0,
                    (batch_size + 1) * seq_len,
                    step=seq_len,
                    dtype=torch.int32,
                    device=qkv.device,
                )
            else:
                # NOTE(review): flash-attn documents cu_seqlens as int32 of
                # shape (batch + 1,) starting at 0; a plain cumsum of lengths
                # has no leading zero — confirm what callers pass as seqlens.
                cu_seqlens = seqlens.cumsum(-1)
            x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
            # x = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, seq_len, 0.0, causal=False
            )

            x = rearrange(x, "(b s) h d -> b s (h d)", b=batch_size)
        else:
            # Partial update: queries only at update_indices, keys/values over
            # the whole sequence of length seq_len.
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seq_len,
                step=seq_len,
                dtype=torch.int32,
                device=qkv.device,
            )
            
            x = custom_flash_attn_with_update_indices_direct(
                qkv, 
                cu_seqlens=cu_seqlens_k,
                seq_len=seq_len,
                update_indices=update_indices
            )
            # x = x.reshape(batch_size, update_indices.size(1), -1)
            x = rearrange(x, "(b s) h d -> b s (h d)", b=batch_size)

        # x_skip + gate_msa * dropout(attn_out(x))
        x = bias_dropout_scale_fn(
            self.attn_out(x), None, gate_msa, x_skip, self.dropout
        )

        # mlp operation
        x = bias_dropout_scale_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None,
            gate_mlp,
            x,
            self.dropout,
        )
        return x


class EmbeddingLayer(nn.Module):
    """Learned ``(vocab_dim, dim)`` lookup table indexed directly by token ids."""

    def __init__(self, dim, vocab_dim):
        super().__init__()
        table = torch.empty((vocab_dim, dim))
        self.embedding = nn.Parameter(table)
        # Kaiming-uniform init with a = sqrt(5), applied in place.
        torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x):
        """Return the table rows selected by the index tensor ``x``."""
        rows = self.embedding[x]
        return rows


class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    Both the output projection and the adaLN projection are zero-initialised,
    so a freshly constructed layer outputs zeros.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

        for zeroed in (self.linear, self.adaLN_modulation):
            zeroed.weight.data.zero_()
            zeroed.bias.data.zero_()

    def forward(self, x, c):
        """Modulate the normalised ``x`` with shift/scale derived from ``c`` and project."""
        # Insert an axis at dim 1, then split the last axis into shift | scale.
        modulation = self.adaLN_modulation(c)[:, None]
        shift, scale = modulation.chunk(2, dim=2)
        normed = modulate_fused(self.norm_final(x), shift, scale)
        return self.linear(normed)


class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Diffusion transformer over token indices, conditioned on a noise level.

    Pipeline: token embedding -> ``n_blocks`` x ``DDiTBlock`` (with rotary
    embeddings and adaLN conditioning on ``sigma``) -> ``DDitFinalLayer``
    producing ``vocab_size`` outputs per position.

    When ``config.is_di4c`` is truthy, one extra latent token built from a
    random scalar is appended to the sequence before the blocks and sliced
    off again from the output.

    Args:
        config: OmegaConf config (or a ``dict``, which is converted). Reads
            ``config.model.{hidden_size, cond_dim, n_heads, n_blocks, dropout,
            scale_by_sigma}`` and the optional top-level ``is_di4c``.
        vocab_size: Number of token ids; also the output width.
    """

    def __init__(self, config, vocab_size: int):
        super().__init__()
        # isinstance rather than an exact type comparison, so dict subclasses
        # (e.g. OrderedDict) are converted as well.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size, vocab_size)
        self.sigma_map = TimestepEmbedder(config.model.cond_dim)
        self.rotary_emb = Rotary(config.model.hidden_size // config.model.n_heads)

        self.blocks = nn.ModuleList(
            [
                DDiTBlock(
                    config.model.hidden_size,
                    config.model.n_heads,
                    config.model.cond_dim,
                    dropout=config.model.dropout,
                )
                for _ in range(config.model.n_blocks)
            ]
        )

        self.output_layer = DDitFinalLayer(
            config.model.hidden_size, vocab_size, config.model.cond_dim
        )
        self.scale_by_sigma = config.model.scale_by_sigma
        if "is_di4c" in config:
            self.is_di4c = config.is_di4c
        else:
            # Also writes the default back into the config object.
            self.is_di4c = config.is_di4c = False
        if self.is_di4c:
            # Added for Di4C: projects the latent's sinusoidal features to the
            # model width so it can be appended as one extra token.
            self.latent_feature_dim = 128
            self.latent_projection = nn.Sequential(
                nn.Linear(in_features=self.latent_feature_dim, out_features=self.latent_feature_dim*4),
                nn.GELU(),
                nn.Linear(self.latent_feature_dim*4, config.model.hidden_size)
            )

    def _get_bias_dropout_scale(self):
        """Pick the scripted gated-residual function matching ``self.training``."""
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, indices, sigma, update_indices=None, use_cache=False, cache_xqkv=False):
        """Compute per-position outputs for ``indices`` at noise level ``sigma``.

        Args:
            indices: Token ids, indexed as (batch, seq).
            sigma: Per-example noise level fed to the timestep embedder.
            update_indices: Optional (batch, M) positions; when given, only
                these positions are passed through the blocks and returned.
            use_cache: Forwarded to every block (reuse cached x/qkv).
            cache_xqkv: Forwarded to every block (store x/qkv in the cache).

        Returns:
            Tensor with ``vocab_size`` features per position, covering
            ``update_indices.size(1)`` positions if ``update_indices`` is
            given, otherwise ``indices.size(1)`` positions.
        """
        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(sigma))

        # Rotary tables are built for the token positions only, i.e. before
        # the optional latent token is appended.
        rotary_cos_sin = self.rotary_emb(x)
        seq_len = indices.size(1)

        if self.is_di4c: # Di4C
            # One uniform random scalar per example, drawn on every call.
            z = torch.rand(x.size(0)).to(x.device)
            z_emb = transformer_timestep_embedding(
                z.view(-1) * 1000, self.latent_feature_dim
            )
            z_emb = self.latent_projection(z_emb)
            x = torch.cat([x, z_emb[:,None,:]], dim=1)
            seq_len += 1

        if update_indices is not None:
            # Keep only the positions that are being updated.
            batch_indices = batch_indices_from_uind(update_indices)
            x = x[batch_indices, update_indices]

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, seqlens=None, exclude_last_token=self.is_di4c,
                        update_indices=update_indices, use_cache=use_cache, cache_xqkv=cache_xqkv, seq_len=seq_len)
            x = self.output_layer(x, c)

        # Slicing drops the appended latent token (if any) from the output.
        if update_indices is not None:
            return x[:, :update_indices.size(1), :]
        return x[:, :indices.size(1), :]

# From https://github.com/yang-song/score_sde_pytorch/ which is from
#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings laid out as ``[sin | cos]``.

    Args:
        timesteps: 1-D tensor of N timestep values.
        embedding_dim: Output width. Odd widths get one trailing zero column.
        max_positions: Controls the lowest frequency; the frequencies run
            geometrically from 1 down to ``1 / max_positions``.

    Returns:
        Float tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is 0
    # and the division used to raise ZeroDivisionError. The step is irrelevant
    # there (it multiplies index 0), so clamp the divisor to 1. For
    # half_dim >= 2 the value is unchanged.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


### The code below is taken and modified from https://github.com/Dao-AILab/flash-attention/tree/4d3d2ff2163ac011bce1b16a2eb2ca90a75f9628
### The copyright notice of the original code is reproduced here:

'''
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

from typing import Optional, Sequence, Tuple, Union
import flash_attn_2_cuda as flash_attn_gpu

# def flash_attn_varlen_qkvpacked_func(
#     qkv,
#     cu_seqlens,
#     max_seqlen,
#     dropout_p=0.0,
#     softmax_scale=None,
#     causal=False,
#     window_size=(-1, -1),  # -1 means infinite context window
#     softcap=0.0, # 0.0 means deactivated
#     alibi_slopes=None,
#     deterministic=False,
#     return_attn_probs=False,
#     cu_seqlens_q=None,
# ):
#     return FlashAttnVarlenQKVPackedFunc.apply(
#         qkv,
#         cu_seqlens,
#         max_seqlen,
#         dropout_p,
#         softmax_scale,
#         causal,
#         window_size,
#         softcap,
#         alibi_slopes,
#         deterministic,
#         return_attn_probs,
#         torch.is_grad_enabled(),
#         cu_seqlens_q,
#     )

# class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
#     @staticmethod
#     def forward(
#         ctx,
#         qkv,
#         cu_seqlens,
#         max_seqlen,
#         dropout_p,
#         softmax_scale,
#         causal,
#         window_size,
#         softcap,
#         alibi_slopes,
#         deterministic,
#         return_softmax,
#         is_grad_enabled,
#         cu_seqlens_q=None,
#     ):
#         is_grad = is_grad_enabled and qkv.requires_grad
#         if softmax_scale is None:
#             softmax_scale = qkv.shape[-1] ** (-0.5)
#         if cu_seqlens_q is None:
#             cu_seqlens_q = cu_seqlens
#         q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
#         head_size_og = q.size(2)
#         if head_size_og % 8 != 0:
#             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
#             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
#             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
#         out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
#             q,
#             k,
#             v,
#             cu_seqlens_q,
#             cu_seqlens,
#             max_seqlen,
#             max_seqlen,
#             dropout_p,
#             softmax_scale,
#             causal=causal,
#             window_size_left=window_size[0],
#             window_size_right=window_size[1],
#             softcap=softcap,
#             alibi_slopes=alibi_slopes,
#             return_softmax=return_softmax and dropout_p > 0,
#             block_table=None,
#         )
#         if is_grad:
#             ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
#             ctx.dropout_p = dropout_p
#             ctx.max_seqlen = max_seqlen
#             ctx.softmax_scale = softmax_scale
#             ctx.causal = causal
#             ctx.window_size = window_size
#             ctx.softcap = softcap
#             ctx.alibi_slopes = alibi_slopes
#             ctx.deterministic = deterministic
#         out = out_padded[..., :head_size_og]
#         return out if not return_softmax else (out, softmax_lse, S_dmask)
#
# def _wrapped_flash_attn_varlen_forward( # original: _flash_attn_forward
#     q: torch.Tensor,
#     k: torch.Tensor,
#     v: torch.Tensor,
#     cu_seqlens_q: torch.Tensor,
#     cu_seqlens_k: torch.Tensor,
#     max_seqlen_q: int,
#     max_seqlen_k: int,
#     dropout_p: float,
#     softmax_scale: float,
#     causal: bool,
#     window_size_left: int = -1,
#     window_size_right: int = -1,
#     softcap: float = 0.0,
#     alibi_slopes: Optional[torch.Tensor] = None,
#     return_softmax: bool = False,
#     block_table: Optional[torch.Tensor] = None,
#     leftpad_k: Optional[torch.Tensor] = None,
#     seqused_k: Optional[torch.Tensor] = None,
#     zero_tensors: bool = False,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
#     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
#     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
#         q,
#         k,
#         v,
#         None,
#         cu_seqlens_q,
#         cu_seqlens_k,
#         seqused_k,
#         leftpad_k,
#         block_table,
#         alibi_slopes,
#         max_seqlen_q,
#         max_seqlen_k,
#         dropout_p,
#         softmax_scale,
#         zero_tensors,
#         causal,
#         window_size_left,
#         window_size_right,
#         softcap,
#         return_softmax,
#         None,
#     )
#     # if out.isnan().any() or softmax_lse.isnan().any():
#     #     breakpoint()
#     return out, softmax_lse, S_dmask, rng_state

def maybe_contiguous(x):
    """Return ``x.contiguous()`` if its last-dimension stride is not 1.

    ``None`` and tensors whose last stride is already 1 are returned as is.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()

def batch_indices_from_uind(uind):
    """Return a ``(B, M)`` tensor whose row ``b`` is filled with ``b``.

    ``uind`` is a ``(B, M)`` index tensor; the result lives on the same
    device and pairs with it for advanced indexing over (batch, position).
    """
    n_rows, n_cols = uind.shape
    row_ids = torch.arange(n_rows, device=uind.device)
    return row_ids.unsqueeze(1).expand(-1, n_cols)


def custom_flash_attn_with_update_indices_direct(qkv, cu_seqlens, seq_len, update_indices, cu_seqlens_q=None):
    """Flash attention where only the positions in ``update_indices`` act as queries.

    Calls ``flash_attn_gpu.varlen_fwd`` directly: queries are gathered at
    ``update_indices`` while keys and values cover every position. All three
    are detached, so no gradient flows through this path.

    Args:
        qkv: Packed input tensor (batch_size, seq_len, 3, n_heads, head_dim).
        cu_seqlens: Cumulative sequence lengths for keys/values.
        seq_len: Key/value sequence length (passed as ``max_seqlen_k``).
        update_indices: (batch_size, M) positions to use as queries.
        cu_seqlens_q: Cumulative sequence lengths for the queries; computed
            for equal-length blocks of M when None.

    Returns:
        Attention output for the selected queries, flattened over
        (batch, M) in the leading dimension, with the original head size.
    """
    # Gather only the queries at update_indices; use every key and value.
    batch_indices = batch_indices_from_uind(update_indices)
    q = rearrange(qkv[batch_indices, update_indices, 0].detach(), "b s ... -> (b s) ...")
    k = rearrange(qkv[:, :, 1].detach(), "b s ... -> (b s) ...")
    v = rearrange(qkv[:, :, 2].detach(), "b s ... -> (b s) ...")

    head_size_og = q.size(-1)
    # The scale must come from the ORIGINAL head size. It used to be computed
    # after padding, which gave a different scale than the regular flash-attn
    # path whenever head_dim was not a multiple of 8.
    softmax_scale = head_size_og ** (-0.5)

    # Pad the head dimension up to a multiple of 8 for the kernel.
    if head_size_og % 8 != 0:
        q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
        k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
        v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    # Query boundaries: every example contributes update_indices.size(1) queries.
    if cu_seqlens_q is None:
        batch_size = qkv.size(0)
        update_seq_len = update_indices.size(1)
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * update_seq_len,
            step=update_seq_len,
            dtype=torch.int32,
            device=qkv.device,
        )

    # Fixed kernel settings: no dropout, non-causal, unlimited window.
    dropout_p = 0.0
    causal = False
    window_size_left = -1
    window_size_right = -1
    softcap = 0.0
    alibi_slopes = None
    return_softmax = False
    zero_tensors = False

    # NOTE(review): this positional argument list matches the flash-attn
    # revision named in the attribution comment above; other versions of
    # flash_attn_2_cuda may expect a different signature.
    out, *_ = flash_attn_gpu.varlen_fwd(
        q,                  # queries
        k,                  # keys
        v,                  # values
        None,               # out (allocated by the kernel when None)
        cu_seqlens_q,       # cumulative query lengths
        cu_seqlens,         # cumulative key/value lengths
        None,               # seqused_k
        None,               # leftpad_k
        None,               # block_table
        alibi_slopes,       # alibi_slopes
        update_indices.size(1),  # max_seqlen_q
        seq_len,            # max_seqlen_k
        dropout_p,          # dropout probability
        softmax_scale,      # softmax scaling factor
        zero_tensors,       # zero_tensors flag
        causal,             # causal mask
        window_size_left,   # left window size
        window_size_right,  # right window size
        softcap,            # softcap
        return_softmax,     # whether to return the softmax
        None,               # rng_state
    )

    # Drop the padding so callers see the original head size.
    out = out[..., :head_size_og]
    return out

