# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
import functools
import logging

import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch

import triton
import triton.language as tl
from triton import Config as TConfig

from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from tma_autotuning import (
    _NV_CONFIGS,
    ALIGN_SIZE_M,
    CudaUtils,
    early_config_prune,
    TmaDescriptorHelper,
)


# Configure logging.
# NOTE: this executes when the module is imported and calls
# logging.basicConfig, which sets up the *root* logger (INFO level,
# "time - level - message" format). Per the stdlib documentation,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# ==============  Start Triton Kernels ===============


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store

    Every program walks the G groups in order. Group ``g`` owns the rows
    ``[M_start, M_start + m_size)`` of A (``m_size`` is read from
    ``m_sizes[g]``) and the rows ``[g * (N // G), (g + 1) * (N // G))`` of B.
    For each tile the kernel accumulates ``A_tile @ B_tile.T`` over K in
    float32 and stores the result, cast to the element type of ``c_ptr``,
    into an output that is addressed as ``[sum(m_sizes), N // G]``.

    Output tiles of all groups are numbered consecutively ("flat index").
    A program starts at tile ``program_id(0)`` and moves forward by
    ``NUM_SMS`` tiles after each tile it has produced.

    Args:
        a_desc_ptr: TMA descriptor used to load ``[BLOCK_SIZE_M, BLOCK_SIZE_K]``
            tiles of A at ``[row, k]`` coordinates.
        b_desc_ptr: TMA descriptor used to load ``[BLOCK_SIZE_N, BLOCK_SIZE_K]``
            tiles of B at ``[row, k]`` coordinates.
        c_ptr: Output pointer. Its element type is also passed as the dtype
            of both descriptor loads.
        workspace: Buffer in which each program builds its own store
            descriptor at byte/element offset ``program_id(0) * TMA_SIZE``.
        m_sizes: Pointer to the G per-group row counts.
        G: Number of groups.
        M_BUCKET: Not referenced in the kernel body; it only appears in the
            autotune ``key``.
        N: Total number of rows of B; each group uses ``N // G`` of them.
        K: Reduction dimension.
        NUM_SMS: Stride (in tiles) between two tiles handled by one program.
        TMA_SIZE: Size of one descriptor slot inside ``workspace``.
        USE_EPILOGUE_SUBTILING: When true, the accumulator is split in two
            halves along N and written with two TMA stores.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes chosen by the
            autotuner.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    # Per-program descriptor slot. It is derived from the *initial* tbidx,
    # i.e. before tbidx is advanced inside the tile loop below.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    # Initial value is overwritten at the top of the first loop iteration.
    M_start = 0
    # Number of output tiles that belong to the groups already visited.
    processed_tiles = 0
    # Size of individual weight matrix
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        # First row of B that belongs to this group.
        n_start = n_size * g

        if m_size > 0:
            # Process this group

            # Acquire hold on c_desc_ptr for TMA Store
            # The descriptor is rebuilt for every non-empty group: it is
            # based at the group's first output row and sized
            # [m_size, n_size].
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Handle every tile of this group whose flat index equals the
            # current tbidx; tbidx is advanced by NUM_SMS after each tile.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                # Tile index relative to the first tile of this group.
                group_index = tbidx - processed_tiles

                # columnwise
                # (consecutive flat indices move along M first, then along N)
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                # Load coordinates: A row is absolute (M_start included),
                # B row is absolute (n_start included).
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # The store descriptor already starts at this group's first
                # row, so the row coordinate is group-relative here.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # Split the tile into two [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]
                    # halves along N and store them one after the other.
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    a_scale_ptr,
    b_scale_ptr,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_FP8: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store

    Every program walks the G groups in order. Group ``g`` owns the rows
    ``[M_start, M_start + m_size)`` of A, with ``m_size`` read from
    ``m_sizes[g]``. Unlike the Hopper variant, every group multiplies
    against the same N rows of B (the B load row is ``n_offset`` with no
    per-group shift), and the output is addressed as ``[sum(m_sizes), N]``.

    Output tiles of all groups are numbered consecutively; a program starts
    at tile ``program_id(0)`` and advances by ``NUM_SMS`` after each tile.

    Args:
        a_desc_ptr: TMA descriptor used to load ``[BLOCK_SIZE_M, BLOCK_SIZE_K]``
            tiles of A.
        b_desc_ptr: TMA descriptor used to load ``[BLOCK_SIZE_N, BLOCK_SIZE_K]``
            tiles of B.
        c_ptr: Output pointer; its element type is used as the dtype of the
            loads and of the stored tile.
        workspace: Buffer in which each program builds its store descriptor
            at offset ``program_id(0) * TMA_SIZE``.
        m_sizes: Pointer to the G per-group row counts.
        a_scale_ptr, b_scale_ptr, USE_FP8: Not referenced in the kernel body.
        G, N, K: Problem sizes (number of groups, B rows, reduction size).
        M_BUCKET: Not referenced in the body; only part of the autotune key.
        NUM_SMS: Stride (in tiles) between two tiles handled by one program.
        USE_TMA_LOAD, USE_TMA_STORE: Not referenced in the body; TMA is used
            unconditionally.
        TMA_SIZE: Size of one descriptor slot inside ``workspace``.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes chosen by the
            autotuner.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    # Per-program store descriptor slot, taken from the initial tbidx.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    # Number of output tiles that belong to the groups already visited.
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # TMA Store prep
            # The descriptor is based at this group's first output row and
            # sized [m_size, n_size].
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Handle every tile of this group whose flat index equals the
            # current tbidx.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                # Tile index relative to the first tile of this group.
                group_index = tbidx - processed_tiles

                # Consecutive flat indices move along M first, then along N.
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                # Absolute A row for the load; B row is the same for all groups.
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # Group-relative row: the store descriptor already starts at
                # this group's first row. n_offset is reused unchanged.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,
    b_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store

    Same tile scheduling as the TMA forward kernels (groups are visited in
    order, tiles are numbered consecutively, a program advances by
    ``NUM_SMS`` tiles), but A, B and C are accessed through plain pointer
    arithmetic with ``tl.load`` / ``tl.store``:

    - A is addressed with row stride ``K``, rows offset by ``M_start``.
    - B is addressed with row stride ``K``; all groups read the same rows.
    - C is addressed with row stride ``N``, rows offset by ``M_start``.

    Args:
        a_ptr, b_ptr, c_ptr: Raw pointers to A, B and the output.
        workspace, USE_TMA_LOAD, USE_TMA_STORE, TMA_SIZE: Not referenced in
            the kernel body.
        m_sizes: Pointer to the G per-group row counts.
        G, N, K: Problem sizes (number of groups, B rows, reduction size).
        M_BUCKET: Not referenced in the body; only part of the autotune key.
        NUM_SMS: Stride (in tiles) between two tiles handled by one program.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes chosen by the
            autotuner.

    NOTE(review): the loads are masked along M and N only. There is no mask
    along K, so this looks like it assumes K is a multiple of BLOCK_SIZE_K -
    confirm against callers. The loads also pass no ``other=`` value; rows
    that are masked out only feed accumulator entries that the masked store
    never writes.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty
    # Placeholder only; never read in this kernel.
    c_desc_ptr = None

    M_end = 0
    # Number of output tiles that belong to the groups already visited.
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Handle every tile of this group whose flat index equals the
            # current tbidx.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                # Tile index relative to the first tile of this group.
                group_index = tbidx - processed_tiles

                # Consecutive flat indices move along M first, then along N.
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                # m_offset / n_offset are computed but not used below.
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                # Group-relative row indices of this tile in A, row indices
                # in B, and the K indices of the first K block.
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                # Pointers to the first [BLOCK, BLOCK_SIZE_K] block of A and B.
                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking
                    # (rows only: A rows against m_size, B rows against n_size)
                    a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                    b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)

                    # Main matmul
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                c = accumulator.to(c_dtype)

                # Only elements inside the group's [m_size, n_size] output
                # region are written.
                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles


"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""


# ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_desc_ptr,  # [MG, N]
    w_desc_ptr,  # [N, K]
    grad_input_ptr,  # output grad_x [MG, K]
    workspace,  # for TMA store
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to input (dx).
    For the forward pass Y = X @ W.T, the backward for input is:
    grad_X = grad_Y @ W

    This maps to [MG, N] @ [N, K] -> [MG, K]

    Key differences from forward:
    1. W is used directly and not transposed
    2. The reduction dimension is now N (not K)
    3. Output is [M, K] instead of [M, N]

    Scheduling is the same flat tile index scheme as the forward kernels:
    groups are visited in order, the [m_size, K] output tiles of all groups
    are numbered consecutively, and a program starting at tile
    ``program_id(0)`` advances by ``NUM_SMS`` after each tile. The same W
    rows are used for every group.

    Args:
        grad_output_desc_ptr: TMA descriptor used to load
            ``[BLOCK_SIZE_M, BLOCK_SIZE_N]`` tiles of grad_output.
        w_desc_ptr: TMA descriptor used to load
            ``[BLOCK_SIZE_N, BLOCK_SIZE_K]`` tiles of W.
        grad_input_ptr: Output pointer; its element type is used as the dtype
            of the loads and of the stored tile.
        workspace: Buffer in which each program builds its store descriptor
            at offset ``program_id(0) * TMA_SIZE``.
        m_sizes: Pointer to the G per-group row counts.
        G, N, K: Problem sizes (number of groups, reduction size, output
            columns).
        M_BUCKET: Not referenced in the body; only part of the autotune key.
        NUM_SMS: Stride (in tiles) between two tiles handled by one program.
        USE_TMA_LOAD, USE_TMA_STORE: Not referenced in the body; TMA is used
            unconditionally.
        TMA_SIZE: Size of one descriptor slot inside ``workspace``.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes chosen by the
            autotuner.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = grad_input_ptr.dtype.element_ty
    # Per-program store descriptor slot, taken from the initial tbidx.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    # Number of output tiles that belong to the groups already visited.
    processed_tiles = 0

    for g in range(G):
        # Move down along groups - same as forward
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            # tiles for this group - now producing [M, K] output
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            # TMA Store prep for [M, K] output
            # The descriptor is based at this group's first output row and
            # sized [m_size, K].
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=grad_input_ptr + M_start * K,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
                global_size=[m_size, K],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Handle every tile of this group whose flat index equals the
            # current tbidx.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                # Tile index relative to the first tile of this group.
                group_index = tbidx - processed_tiles

                # Different tiling scheme for [M, K] output
                # (consecutive flat indices move along M first, then along K)
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                # for grad_input block [M, K]
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                # Position in full matrix
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                # reduce along N dimension (instead of K in forward)
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    # grad_output block [M, N]
                    grad_output = tl._experimental_descriptor_load(
                        grad_output_desc_ptr,
                        [m_offset, n_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_N],
                        c_dtype,
                    )

                    # weight block [N, K] - no transpose needed
                    w = tl._experimental_descriptor_load(
                        w_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # grad_x = grad_output @ w
                    # reducing along N dimension
                    accumulator += tl.dot(grad_output, w)

                # Store using TMA
                # Group-relative row: the store descriptor already starts at
                # this group's first row. k_offset is reused unchanged.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, k_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS

            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles


# ---- dw flat linear indexed ----


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K]
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N]
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # workspace for TMA store
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for reduction dimension
) -> None:
    """
    Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses flat index structure similar to forward.

    For the forward pass Y = X @ W.T,
    the backward for weights is:
    grad_W = grad_Y.T @ X

    Where:
    - grad_Y is [MG, N]
    - X is [MG, K]
    - grad_W is [N, K]
    - we return [N,K]

    Each program handles the output tiles ``tbidx, tbidx + NUM_SMS, ...`` of
    the [N, K] result. For one output tile it loops over all G groups and,
    inside each non-empty group, over chunks of BLOCK_SIZE_M rows,
    accumulating ``grad_Y_chunk.T @ X_chunk`` in float32. The contributions
    of all groups are summed into the same tile.

    Args:
        x_desc_ptr: With USE_TMA_LOAD, a TMA descriptor for X; otherwise it
            is used as a raw pointer to X with row stride ``K``.
        grad_output_desc_ptr: With USE_TMA_LOAD, a TMA descriptor for
            grad_output; otherwise a raw pointer with row stride ``N``.
        grad_weight_ptr: Output pointer; its element type is used as the
            dtype of the TMA loads and of the stored tile.
        workspace: With USE_TMA_STORE, passed directly as the store
            descriptor (no per-program offset is applied) - presumably the
            caller has prepared a descriptor for grad_weight there; TODO
            confirm against the wrapper.
        m_sizes: Pointer to the G per-group row counts.
        G, N, K: Problem sizes (number of groups, output rows, output cols).
        M_BUCKET, TMA_SIZE: Not referenced in the kernel body.
        NUM_SMS: Stride (in tiles) between two tiles handled by one program.
        USE_TMA_LOAD: Selects TMA descriptor loads vs. masked ``tl.load``.
        USE_TMA_STORE: Selects TMA descriptor store vs. masked ``tl.store``.
        BLOCK_SIZE_N, BLOCK_SIZE_K: Output tile sizes.
        BLOCK_SIZE_M: Chunk size along the reduction (M) dimension.
    """
    # Get thread block index
    tbidx = tl.program_id(0)

    # Get output data type
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in strided manner across SMs
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Calculate tile indices
        # (consecutive tile indices move along N first, then along K)
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global offsets
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Process each group
        M_end = 0
        for g in range(G):
            # Get group boundaries
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Calculate actual block size (handling boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # Apply masks for valid regions
                            # The TMA load always fetches BLOCK_SIZE_M rows,
                            # so rows past the end of this group's chunk must
                            # be excluded from the product.
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out invalid elements
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual load with bounds checking
                            # (x_desc_ptr / grad_output_desc_ptr are used as
                            # plain pointers in this branch)
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Compute partial contribution: grad_W += grad_Y.T @ X
                        # transpose grad_output for the matmul
                        # Both operands are cast to float32 before the dot.
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate
                        accumulator += contribution

        # Store the result
        if USE_TMA_STORE:
            # Store using TMA
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            # grad_weight is addressed with row stride K.
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )


# ======== End Triton kernels ========

# ======== Triton wrapper functions ========

# ----- main forward pass wrapper -----


def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Host-side launcher for the Hopper M*G grouped GEMM forward kernel.

    The launched kernel (`_kernel_mg_forward_hopper`) splits the rows of `w`
    evenly into G consecutive blocks of `N // G` rows, one block per group,
    and multiplies each group of rows of `x` with the transpose of its block.

    Args:
        x: Contiguous 2-D input of shape [M_total, K].
        w: Contiguous weight tensor; `w.shape[0]` is N and `w.shape[1]` must
            equal K.
        m_sizes: Contiguous tensor with one row count per group (length G).
            The kernel consumes the rows of `x` group by group in this order.
        tma_size: Size of one TMA descriptor slot; forwarded to the
            descriptor helper and to the kernel as `TMA_SIZE`.

    Returns:
        Tensor of shape [M_total, N // G] with the dtype and device of `x`.
        When `M_total` is 0 the (empty) tensor is returned without launching
        the kernel.

    Raises:
        NotImplementedError: if `CudaUtils.verify_tma()` reports no TMA support.
        AssertionError: if an input is not contiguous or the K sizes differ.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    # The kernel works on raw pointers / descriptors, so every operand has
    # to be densely laid out.
    for operand in (x, w, m_sizes):
        assert operand.is_contiguous()

    M_total, K = x.shape
    N = w.shape[0]

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # NOTE: the multiple-of-ALIGN_SIZE_M check on the group sizes is
    # deliberately not performed here because it would need a GPU-CPU sync.

    # One block of N // G output columns per row of x.
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()

    # Host-side TMA descriptors for the two load operands.
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    for operand_name in ("x", "w"):
        desc_helper.init_tma_descriptor(operand_name)
    desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # One descriptor slot per launched program for the device-side store
    # descriptors.
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=x.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Called with the selected config: the load descriptors depend on
        # the tile sizes, so they are (re)filled here before the launch.
        for operand_name, operand, num_rows, row_block_key in (
            ("x", x, M_total, "BLOCK_SIZE_M"),
            ("w", w, N, "BLOCK_SIZE_N"),
        ):
            desc_helper.fill_2d_tma_descriptor(
                operand_name,
                operand.data_ptr(),
                num_rows,
                K,
                META[row_block_key],
                META["BLOCK_SIZE_K"],
                operand.element_size(),
            )
        return (NUM_SMS,)

    # Autotune key component: M_total rounded up to a power of two.
    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=False,
    )

    return y


# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.

    Validates the operand shapes, then computes both gradients by calling
    `grouped_gemm_dx_tma` (for grad_x) and `grouped_gemm_dw_tma` (for grad_w).

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether TMA acceleration is requested. NOTE: this flag only
            controls an informational log message when the device lacks TMA
            support; the TMA-based dx/dw wrappers are called in every case.
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: if the K sizes of `x` and `w` differ, if `x` and
            `grad_output` have a different number of rows, or if the sum of
            `m_sizes` does not equal the number of rows of `x`.
        Exception: anything raised by the dx/dw wrappers is logged and
            re-raised unchanged.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Validate shapes before touching the device so that bad inputs fail
    # fast, without paying for the SM-count query below.
    M_total, K_x = x.shape
    M_grad, _ = grad_output.shape
    _, K_w = w.shape

    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes.
    # NOTE: .item() forces a device-to-host synchronization.
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # do this once, seems expensive
    NUM_SMS = CudaUtils.get_num_sms()

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support (only queried when TMA was requested)
    if use_tma and not CudaUtils.verify_tma():
        logging.info("TMA requested but not supported on this device")

    # Compute grad_x using flat linear implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )
    except Exception as e:
        # Log for diagnostics, then propagate the original exception.
        logging.error("Error in grad_x computation: %s", e)
        raise

    # Compute grad_w using flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        # Log for diagnostics, then propagate the original exception.
        logging.error("Error in grad_w computation: %s", e)
        raise

    return grad_x, grad_w


# ----- dx backward pass wrapper -----


def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Backward pass wrapper computing the gradient with respect to the input (dx).

    Validates the operands, prepares TMA descriptors for `grad_output` and `w`,
    allocates the per-program TMA store workspace and launches
    `_kernel_mg_dx_tma` on a 1-D grid of `num_sms` programs.

    Args:
        grad_output: Gradient of output, shape [M_total, N]. Must be contiguous.
        w: Weight tensor, shape [N, K]. Must be contiguous.
        m_sizes: Group sizes tensor, shape [G]. Must be contiguous and sum to
            M_total.
        num_sms: Number of programs to launch; also determines the workspace
            size (`num_sms * tma_size` bytes).
        tma_size: Size of a TMA descriptor, forwarded to `TmaDescriptorHelper`
            and to the kernel as `TMA_SIZE`.

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K], allocated on
        the device of `grad_output` with the dtype of `grad_output`.

    Raises:
        NotImplementedError: If `CudaUtils.verify_tma()` reports no TMA support.
        AssertionError: If an input is not contiguous, the N dimensions of
            `grad_output` and `w` differ, or `m_sizes` does not sum to M_total.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total
    # (.item() pulls the sum back as a Python number)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms
    # This wrapper always takes the TMA load/store path; the flags are passed
    # to the kernel positionally below.
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for TMA store: one descriptor-sized slot per program
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # The block sizes are only known once a config has been selected, so
        # the descriptors are filled here, from the launch-time META dict.
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x


# ======== dw wrapper function ==========


def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Weight-gradient (dw) wrapper for the grouped GEMM backward pass.

    For the forward pass Y = X @ W.T, the backward for weights is:
    grad_W = grad_Y.T @ X

    The wrapper makes the inputs contiguous, checks their shapes, prepares the
    TMA descriptors and launches `_kernel_mg_dw_tma` on a 1-D grid of
    `num_sms` programs.

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]; must sum to M_total
        num_sms: Number of programs in the launch grid
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K], created
        zero-filled on the device of `x` with the dtype of `x`.

    Raises:
        AssertionError: If the M dimensions of `x` and `grad_output` differ or
            `m_sizes` does not sum to M_total.
    """
    # Hardware probe; its result is not consulted yet (see the TODO below).
    has_tma_support = CudaUtils.verify_tma()

    num_groups = m_sizes.shape[0]

    # Work on contiguous copies/views of every operand.
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # The group sizes have to account for every row of x.
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Output starts at zero, shape [N, K].
    # NOTE(review): presumably the kernel accumulates into it - confirm.
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO  - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # A descriptor helper is only required when at least one TMA path is active.
    any_tma = USE_TMA_LOAD or USE_TMA_STORE
    desc_helper = TmaDescriptorHelper(tma_size=tma_size) if any_tma else None

    # Load side: TMA descriptors, or the raw tensors when TMA load is off.
    if USE_TMA_LOAD:
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("grad_output")
        x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
        grad_output_desc = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    else:
        x_desc, grad_output_desc = x, grad_output

    # Store side: the grad_w descriptor is handed to the kernel in the
    # workspace slot; otherwise a 1-byte placeholder tensor is passed.
    if USE_TMA_STORE:
        desc_helper.init_tma_descriptor("grad_w")
        workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
    else:
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    def grid(META):
        # Descriptors are filled at launch time because the block sizes come
        # from the META dict of the selected config.
        if USE_TMA_LOAD:
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K_x,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )
            desc_helper.fill_2d_tma_descriptor(
                "grad_output",
                grad_output.data_ptr(),
                M_total,
                N,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_N"],
                grad_output.element_size(),
            )

        if USE_TMA_STORE:
            desc_helper.fill_2d_tma_descriptor(
                "grad_w",
                grad_w.data_ptr(),
                N,
                K_x,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                grad_w.element_size(),
            )

        # One program per SM
        return (NUM_SMS,)

    # Bucketed M passed to the kernel as M_BUCKET
    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        num_groups,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w


# ======== End Backwards Wrapper Functions =============

# ======== PyTorch wrapper functions ========


class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    The forward pass delegates to `grouped_gemm_forward` and the backward pass
    to `grouped_gemm_backward`. The `using_fp8` argument is accepted so that
    the six-argument call made by `mg_grouped_gemm` binds correctly, but the
    forward pass currently always runs without FP8 quantization.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available);
                stored on ctx and forwarded to the backward pass
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Whether to use FP8 quantization. Currently ignored:
                the forward call always passes `using_fp8=False`.

        Returns:
            Output tensor, shape [M_total, N]
        """

        # Use regular forward without quantization
        output = grouped_gemm_forward(
            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
        )

        # Save inputs and parameters for backward pass (tensors go through
        # save_for_backward once; plain Python values are kept on ctx)
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients, one entry per forward() argument after ctx:
                - grad_x: Gradient with respect to x, shape [M_total, K]
                - grad_w: Gradient with respect to w, shape [N, K]
                - None: Gradient with respect to m_sizes (not differentiable)
                - None: Gradient with respect to use_tma (not differentiable)
                - None: Gradient with respect to tma_size (not differentiable)
                - None: Gradient with respect to using_fp8 (not differentiable)

        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=ctx.use_tma,
            tma_size=ctx.tma_size,
        )

        # Return gradients for all inputs (None for non-differentiable parameters)
        return grad_x, grad_w, None, None, None, None


def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Differentiable entry point for the M*G grouped GEMM.

    Forwards all six arguments positionally, in signature order, to
    `GroupedGEMM_mg.apply` and returns its result.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Whether to use FP8 quantization

    Returns:
        Output tensor, shape [M_total, N]
    """
    autograd_inputs = (x, w, m_sizes, use_tma, tma_size, using_fp8)
    return GroupedGEMM_mg.apply(*autograd_inputs)
