# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch
import triton
import triton.language as tl

from fla.ops.utils import prepare_chunk_indices
from fla.utils import IS_AMD, autotune_cache_kwargs, get_multiprocessor_count, input_guard, tensor_cache

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if IS_AMD else [2, 4, 8, 16, 32]


def token_shift_ref(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    if cu_seqlens is not None:
        # Variable length mode with cu_seqlens
        assert x.dim() == 3, "Input must be [B, T, D]"
        B, T, D = x.shape
        assert B == 1, "Batch size must be 1 when using cu_seqlens"

        result = torch.zeros_like(x)
        N = cu_seqlens.shape[0] - 1

        for i in range(N):
            start = cu_seqlens[i].item()
            end = cu_seqlens[i+1].item()
            seq_len = end - start

            if seq_len <= 1:
                # For sequences of length 1 or 0, delta is simply -x
                result[0, start:end] = -x[0, start:end]
            else:
                # For longer sequences, handle padding manually
                shifted = torch.zeros_like(x[0, start:end])
                shifted[1:] = x[0, start:end-1]
                delta = shifted - x[0, start:end]
                result[0, start:end] = delta

        return result
    else:
        time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
        shifted = time_shift(x)
        delta = shifted - x
        return delta


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_INITIAL_STATE': lambda args: args['cache'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS_AUTOTUNE
        for num_stages in [1, 2, 3]
    ],
    key=['BD'],
    **autotune_cache_kwargs,
)
@triton.jit
def token_shift_fwd_kernel_short(
    x,
    y,
    cu_seqlens,
    cache,
    cache_out,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_DECODE: tl.constexpr,
):
    i_b, i_t = tl.program_id(0), tl.program_id(1)

    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        g_t = i_t + bos

        if g_t >= eos:
            return

        is_first_pos = (i_t == 0)
        is_last_pos = (g_t == eos - 1)
    else:
        g_t = i_t
        is_first_pos = (g_t == 0)
        is_last_pos = (g_t == T - 1)

    o_d = tl.arange(0, BD)
    m_d = o_d < D

    if IS_VARLEN:
        base_offset = g_t * D + o_d
    else:
        base_offset = i_b * T*D + g_t * D + o_d

    b_x = tl.load(x + base_offset, mask=m_d)
    if IS_VARLEN:
        cache_offset = i_n * D + o_d  # i_n is seq index
    else:
        cache_offset = i_b * D + o_d  # i_b is batch index

    if IS_DECODE and USE_INITIAL_STATE:
        b_cache = tl.load(cache + cache_offset, mask=m_d)
        delta = b_cache - b_x
        tl.store(y + base_offset, delta, mask=m_d)
        if STORE_FINAL_STATE:
            tl.store(cache_out + cache_offset, b_x, mask=m_d)
        return

    if is_first_pos:
        # First position in sequence: delta = -hidden_states
        if USE_INITIAL_STATE:
            # cache shape: [N, D]
            b_cache = tl.load(cache + cache_offset, mask=m_d)
            delta = b_cache - b_x
            tl.store(y + base_offset, delta, mask=m_d)
        else:
            tl.store(y + base_offset, -b_x, mask=m_d)
        return

    # Other positions: delta = prev - curr
    if IS_VARLEN:
        prev_offset = (g_t-1) * D + o_d
    else:
        prev_offset = i_b * T*D + (g_t-1) * D + o_d

    prev_values = tl.load(x + prev_offset, mask=m_d)
    delta = prev_values - b_x
    tl.store(y + base_offset, delta, mask=m_d)
    if STORE_FINAL_STATE:
        if is_last_pos:
            tl.store(cache_out + cache_offset, b_x, mask=m_d)


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_INITIAL_STATE': lambda args: args['cache'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS_AUTOTUNE
        for num_stages in [1, 2, 3]
    ],
    key=['BD', 'NB'],
    **autotune_cache_kwargs,
)
@triton.jit
def token_shift_fwd_kernel_long(
    x,
    y,
    cu_seqlens,
    chunk_indices,
    cache,
    cache_out,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    BT: tl.constexpr,
    NB: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
):
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), \
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
        t_start = i_t * BT
        t_end = tl.minimum(t_start + BT, eos - bos)
    else:
        i_n = i_b
        bos, eos = i_b * T, (i_b + 1) * T
        t_start = i_t * BT
        t_end = tl.minimum(t_start + BT, T)

    o_d = i_d * BD + tl.arange(0, BD)
    m_d = o_d < D

    for t in range(t_start, t_end):
        global_t = bos + t
        offset = global_t * D + o_d
        b_x = tl.load(x + offset, mask=m_d)
        is_first = (global_t == bos)
        if is_first:
            if USE_INITIAL_STATE:
                # cache shape: [N, D]
                cache_off = i_n * D + o_d if IS_VARLEN else i_b * D + o_d
                b_cache = tl.load(cache + cache_off, mask=m_d)
                delta = b_cache - b_x
            else:
                delta = -b_x
        else:
            prev_off = offset - D
            b_prev = tl.load(x + prev_off, mask=m_d)
            delta = b_prev - b_x

        tl.store(y + offset, delta, mask=m_d)

        if STORE_FINAL_STATE:
            if global_t == eos - 1:
                cache_out_off = i_n * D + o_d if IS_VARLEN else i_b * D + o_d
                tl.store(cache_out + cache_out_off, b_x, mask=m_d)


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_INITIAL_STATE': lambda args: args['grad_cache_out'] is not None,
    'HAS_DCACHE': lambda args: args['grad_cache_in'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS_AUTOTUNE
        for num_stages in [1, 2, 3]
    ],
    key=['BD'],
    **autotune_cache_kwargs,
)
@triton.jit
def token_shift_bwd_kernel_short(
    dx,
    dy,
    cu_seqlens,
    grad_cache_in,
    grad_cache_out,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    HAS_DCACHE: tl.constexpr,
):
    i_b, i_t = tl.program_id(0), tl.program_id(1)

    if IS_VARLEN:
        i_n = i_b
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        g_t = i_t + bos
        if g_t >= eos:
            return
        is_first_pos = (g_t == bos)
        is_last_pos = (g_t == eos - 1)
    else:
        g_t = i_t
        is_first_pos = (g_t == 0)
        is_last_pos = (g_t == T - 1)

    o_d = tl.arange(0, BD)
    m_d = o_d < D

    if IS_VARLEN:
        base_offset = g_t * D + o_d
        # This should not be used for varlen
        cache_off = i_n * D + o_d
    else:
        base_offset = i_b * T * D + g_t * D + o_d
        cache_off = i_b * D + o_d

    b_dy = tl.load(dy + base_offset, mask=m_d)

    if is_last_pos:
        # grad = -grad_delta[t] + grad_cache_in（from next rank）
        if HAS_DCACHE:
            b_dy_cache = tl.load(grad_cache_in + cache_off, mask=m_d)
            b_dx = -b_dy + b_dy_cache
        else:
            b_dx = -b_dy
    else:
        # grad = -grad_delta[t] + grad_delta[t+1]
        if IS_VARLEN:
            next_offset = (g_t + 1) * D + o_d
        else:
            next_offset = i_b * T * D + (g_t + 1) * D + o_d
        b_dx = -b_dy + tl.load(dy + next_offset, mask=m_d)

    tl.store(dx + base_offset, b_dx, mask=m_d)

    if USE_INITIAL_STATE:
        if is_first_pos:
            tl.store(grad_cache_out + cache_off, b_dy, mask=m_d)


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_INITIAL_STATE': lambda args: args['grad_cache_out'] is not None,
    'HAS_DCACHE': lambda args: args['grad_cache_in'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS_AUTOTUNE
        for num_stages in [1, 2, 3]
    ],
    key=['BD', 'NB'],
    **autotune_cache_kwargs,
)
@triton.jit
def token_shift_bwd_kernel_long(
    dx,
    dy,
    cu_seqlens,
    chunk_indices,
    grad_cache_in,
    grad_cache_out,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    BT: tl.constexpr,
    NB: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    HAS_DCACHE: tl.constexpr,
):
    i_d, i_t_blk, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    if IS_VARLEN:
        i_n, i_t_blk = tl.load(chunk_indices + i_t_blk * 2).to(tl.int32), \
            tl.load(chunk_indices + i_t_blk * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
        t_start = i_t_blk * BT
        t_end = tl.minimum(t_start + BT, eos - bos)
    else:
        bos, eos = i_b * T, (i_b + 1) * T
        t_start = i_t_blk * BT
        t_end = tl.minimum(t_start + BT, T)

    o_d = i_d * BD + tl.arange(0, BD)
    m_d = o_d < D
    cache_off = i_n * D + o_d if IS_VARLEN else i_b * D + o_d

    for t in range(t_start, t_end):
        global_t = bos + t
        offset = global_t * D + o_d
        b_dy = tl.load(dy + offset, mask=m_d)

        if global_t == eos - 1:
            if HAS_DCACHE:
                b_dy_cache = tl.load(grad_cache_in + cache_off, mask=m_d)
                b_dx = -b_dy + b_dy_cache
            else:
                b_dx = -b_dy
        else:
            next_off = offset + D
            b_dx = -b_dy + tl.load(dy + next_off, mask=m_d)

        tl.store(dx + offset, b_dx, mask=m_d)

        if USE_INITIAL_STATE:
            if global_t == bos:
                tl.store(grad_cache_out + cache_off, b_dy, mask=m_d)


@tensor_cache
def prepare_maxlens(cu_seqlens: torch.LongTensor) -> int:
    return torch.max(cu_seqlens.diff()).item()


def token_shift_fwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    cache: torch.Tensor | None = None,
    output_cache: bool = False,
    chunk_indices: torch.LongTensor | None = None,
) -> torch.Tensor:
    B, T, D = x.shape
    y = torch.empty_like(x)
    use_short_kernel = T <= 4096

    if cu_seqlens is not None:
        T = prepare_maxlens(cu_seqlens)
        N = len(cu_seqlens) - 1
    else:
        N = B

    if output_cache:
        cache_out = torch.empty((N, D), device=x.device, dtype=x.dtype)
    else:
        cache_out = None

    if use_short_kernel:
        if cu_seqlens is not None:
            N = len(cu_seqlens) - 1
        else:
            N = B
        BD = triton.next_power_of_2(D)
        grid = (N, T)
        IS_DECODE = T == 1 or (B == 1 and T == N)
        token_shift_fwd_kernel_short[grid](
            x=x,
            y=y,
            cu_seqlens=cu_seqlens,
            cache=cache,
            cache_out=cache_out,
            T=T,
            D=D,
            BD=BD,
            STORE_FINAL_STATE=output_cache,
            IS_DECODE=IS_DECODE,
        )
    else:
        BT = min(64, triton.next_power_of_2(triton.cdiv(max(16, B*T), get_multiprocessor_count(x.device.index))))
        if chunk_indices is None and cu_seqlens is not None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

        BD = triton.next_power_of_2(D)
        NB = triton.cdiv(B*T, 1024)

        def grid(meta): return (triton.cdiv(D, meta['BD']), NT, N)
        token_shift_fwd_kernel_long[grid](
            x,
            y,
            cu_seqlens,
            chunk_indices,
            cache,
            cache_out,
            T,
            D=D,
            BD=BD,
            BT=BT,
            NB=NB,
            STORE_FINAL_STATE=output_cache,
        )

    return y, N, T, use_short_kernel, cache_out


def token_shift_bwd(
    dy: torch.Tensor,
    N: int,
    T: int,
    dcache: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    use_short_kernel: bool = True,
    has_init_cache: bool = False,
    chunk_indices: torch.LongTensor | None = None,
) -> torch.Tensor:
    D = dy.shape[2]
    BD = triton.next_power_of_2(D)
    dx = torch.empty_like(dy)
    if has_init_cache:
        grad_cache_out = torch.empty((N, D), device=dy.device, dtype=dy.dtype)
    else:
        grad_cache_out = None
    if use_short_kernel:
        grid = (N, T)
        token_shift_bwd_kernel_short[grid](
            dy=dy,
            dx=dx,
            cu_seqlens=cu_seqlens,
            grad_cache_in=dcache,
            grad_cache_out=grad_cache_out,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        BT = min(64, triton.next_power_of_2(triton.cdiv(max(16, dy.numel() // D),
                                                        get_multiprocessor_count(dy.device.index))))
        if chunk_indices is None and cu_seqlens is not None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
        NB = triton.cdiv(N * dy.shape[1], 1024)
        BD = triton.next_power_of_2(D)

        def grid(meta): return (triton.cdiv(D, meta['BD']), NT, N)
        token_shift_bwd_kernel_long[grid](
            dx,
            dy,
            cu_seqlens,
            chunk_indices,
            dcache,
            grad_cache_out,
            T,
            D=D,
            BD=BD,
            BT=BT,
            NB=NB,
        )
    return dx, grad_cache_out


class TokenShift(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx, x: torch.Tensor, cu_seqlens: torch.Tensor | None = None,
                cache: torch.Tensor | None = None, output_cache: bool = False,
                chunk_indices: torch.LongTensor | None = None):
        output, N, T, use_short_kernel, cache_out = token_shift_fwd(x, cu_seqlens, cache, output_cache, chunk_indices)
        ctx.cu_seqlens = cu_seqlens
        ctx.chunk_indices = chunk_indices
        ctx.N = N
        ctx.T = T
        ctx.use_short_kernel = use_short_kernel
        ctx.has_cache = cache is not None
        return output, cache_out

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor, dcache: torch.Tensor | None = None):
        dx, grad_cache = token_shift_bwd(dy, ctx.N, ctx.T, dcache, ctx.cu_seqlens,
                                         ctx.use_short_kernel, ctx.has_cache, ctx.chunk_indices)
        return dx, None, grad_cache, None, None


def token_shift(
    x: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    cache: torch.Tensor | None = None,
    output_cache: bool = False,
    chunk_indices: torch.LongTensor | None = None,
):
    """
    Token-shift operation implemented with Triton kernels.

    Args:
        x: Input tensor of shape [B, T, D] (or [1, T, D] when `cu_seqlens` is supplied).
        cu_seqlens: Optional cumulative sequence lengths of shape [B + 1].
                    When supplied, `x.shape[0]` must be 1 and `x.dim()` must be 3.
        cache: Optional cache tensor of shape [N, D] that holds the last token
               from the previous call.
        output_cache: Whether to return the updated cache alongside the output.
                      In previous versions this parameter did not exist and the
                      cache was always dropped; to preserve backward compatibility
                      the default is False.

    Returns:
        output: Tensor of shape [B, T, D] after applying the token-shift.

        cache_out: Tensor of shape [B, 1, D] containing the last token that
                   should be fed as `cache` in the next call.  Only returned
                   when `output_cache=True`.
    """
    if cu_seqlens is not None:
        assert x.dim() == 3, "Input must be [B, T, D]"
        assert x.shape[0] == 1, "Batch size must be 1 when using cu_seqlens"

    output, cache_out = TokenShift.apply(x, cu_seqlens, cache, output_cache, chunk_indices)
    if output_cache:
        return output, cache_out
    else:
        return output
