from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F

from torchtitan.kernels.triton.fused_swiglu_ffn import launch_fused_swiglu_ffn_forward


class FusedSwiGLUFFNFunction(torch.autograd.Function):
    """Autograd wrapper around the fused Triton SwiGLU FFN forward kernel.

    The backward pass recomputes activations from the saved inputs and
    differentiates ``y = (silu(x @ w1) * (x @ w3)) @ w2``. It therefore assumes
    the forward kernel computes exactly that expression.

    Call via ``FusedSwiGLUFFNFunction.apply(x, w1, w3, w2, out_of_place,
    block_m, block_k, block_h, block_d, num_warps, num_stages)``. All
    arguments are passed positionally.
    """

    @staticmethod
    def _validate_inputs(
        x: torch.Tensor,
        w1: torch.Tensor,
        w3: torch.Tensor,
        w2: torch.Tensor,
    ) -> tuple[int, int, int, int, int]:
        """Check ranks, dtype, shapes and devices; return ``(B, S, K, H, D)``.

        Raises:
            ValueError: if any input violates the v0 kernel's requirements.
                Explicit raises are used instead of ``assert`` so the checks
                survive ``python -O``.
        """
        if x.dim() != 3:
            raise ValueError(f"x must be 3D [B, S, K], got shape {tuple(x.shape)}")
        if x.dtype != torch.bfloat16:
            raise ValueError(f"Only bf16 is supported in v0, got {x.dtype}")
        B, S, K = x.shape

        if w1.dim() != 2 or w1.shape[0] != K:
            raise ValueError(
                f"w1 must be [K={K}, H], got shape {tuple(w1.shape)}"
            )
        H = w1.shape[1]
        if tuple(w3.shape) != (K, H):
            raise ValueError(
                f"w3 must match w1's shape {(K, H)}, got {tuple(w3.shape)}"
            )
        if w2.dim() != 2 or w2.shape[0] != H:
            raise ValueError(
                f"w2 must be [H={H}, D], got shape {tuple(w2.shape)}"
            )
        D = w2.shape[1]
        if D != K:
            raise ValueError("Output dim must equal input dim in this implementation")

        # The Triton launcher needs every operand on the same CUDA device.
        tensors = (x, w1, w3, w2)
        if not all(t.is_cuda for t in tensors):
            raise ValueError("x, w1, w3 and w2 must all be CUDA tensors")
        if any(t.device != x.device for t in tensors[1:]):
            raise ValueError("x, w1, w3 and w2 must all be on the same device")
        return B, S, K, H, D

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # [B, S, K]
        w1: torch.Tensor,  # [K, H]
        w3: torch.Tensor,  # [K, H]
        w2: torch.Tensor,  # [H, K]
        out_of_place: bool = True,
        block_m: int = 128,
        block_k: int = 64,
        block_h: int = 128,
        block_d: int = 128,
        num_warps: int = 8,
        num_stages: int = 2,
    ) -> torch.Tensor:
        """Run the fused kernel and return ``y`` of shape [B, S, K].

        Args:
            x: [B, S, K] bf16 CUDA tensor.
            w1, w3: [K, H] weights on the same device as ``x``.
            w2: [H, K] weight on the same device as ``x``.
            out_of_place: unused; v0 always writes into a freshly allocated output.
            block_m, block_k, block_h, block_d: forwarded to the launcher as
                ``BLOCK_M``/``BLOCK_K``/``BLOCK_H``/``BLOCK_D``.
            num_warps, num_stages: forwarded to the launcher unchanged.

        Raises:
            ValueError: on unsupported dtype, mismatched shapes, or non-CUDA /
                mixed-device inputs.
        """
        B, S, K, H, D = FusedSwiGLUFFNFunction._validate_inputs(x, w1, w3, w2)
        M = B * S

        # Ensure contiguous 2D inputs so the kernel's simple stride math is valid.
        x2d = x.reshape(M, K).contiguous()
        w1c = w1.contiguous()
        w3c = w3.contiguous()
        w2c = w2.contiguous()

        y2d = torch.empty((M, D), device=x.device, dtype=x.dtype)

        launch_fused_swiglu_ffn_forward(
            x2d, w1c, w3c, w2c, y2d,
            BLOCK_M=block_m,
            BLOCK_K=block_k,
            BLOCK_H=block_h,
            BLOCK_D=block_d,
            num_warps=num_warps,
            num_stages=num_stages,
        )

        # Save only the inputs; backward recomputes activations (v0 trade-off:
        # less activation memory, extra matmuls in backward).
        ctx.save_for_backward(x, w1, w3, w2)
        ctx.dimensions = (B, S, K, H, D)
        return y2d.view(B, S, D)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return gradients for ``(x, w1, w3, w2)`` plus ``None`` for the 7 config args.

        Activations are recomputed in fp32 from the saved inputs. Gradients
        are cast back to each input's dtype. Work is skipped for inputs that
        do not require a gradient.
        """
        x, w1, w3, w2 = ctx.saved_tensors
        B, S, K, H, D = ctx.dimensions
        need_x, need_w1, need_w3, need_w2 = ctx.needs_input_grad[:4]
        M = B * S

        grad_x = grad_w1 = grad_w3 = grad_w2 = None

        # Disable autocast so the fp32 casts below are honoured by the matmuls.
        with torch.autocast(device_type=x.device.type, enabled=False):
            x32 = x.reshape(M, K).to(torch.float32)
            w1_32 = w1.to(torch.float32)
            w3_32 = w3.to(torch.float32)
            gy32 = grad_out.reshape(M, D).to(torch.float32)

            a = x32 @ w1_32  # [M, H]
            b = x32 @ w3_32  # [M, H]
            sig = torch.sigmoid(a)
            silu_a = a * sig

            if need_w2:
                # y = c @ w2 with c = silu(a) * b
                grad_w2 = (silu_a * b).t() @ gy32  # [H, D]

            if need_x or need_w1 or need_w3:
                dc = gy32 @ w2.to(torch.float32).t()  # [M, H]
                # silu'(a) = sig * (1 + a * (1 - sig))
                da = dc * b * (sig * (1.0 + a * (1.0 - sig)))
                db = dc * silu_a

                if need_w1:
                    grad_w1 = x32.t() @ da  # [K, H]
                if need_w3:
                    grad_w3 = x32.t() @ db  # [K, H]
                if need_x:
                    grad_x = (da @ w1_32.t() + db @ w3_32.t()).view(B, S, K)

        # Cast grads back to the dtypes of the corresponding inputs.
        if grad_w2 is not None:
            grad_w2 = grad_w2.to(w2.dtype)
        if grad_w1 is not None:
            grad_w1 = grad_w1.to(w1.dtype)
        if grad_w3 is not None:
            grad_w3 = grad_w3.to(w3.dtype)
        if grad_x is not None:
            grad_x = grad_x.to(x.dtype)

        # forward inputs: x, w1, w3, w2, out_of_place, block_m, block_k, block_h, block_d, num_warps, num_stages
        return grad_x, grad_w1, grad_w3, grad_w2, None, None, None, None, None, None, None


# @torch.compile
def fused_swiglu_ffn_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    w3: torch.Tensor,
    w2: torch.Tensor,
    out_of_place: bool = True,
    block_m: int = 64,
    block_k: int = 64,
    block_h: int = 64,
    block_d: int = 64,
    num_warps: int = 4,
    num_stages: int = 1,
) -> torch.Tensor:
    """Compute the fused SwiGLU FFN with autograd support.

    Thin entry point over ``FusedSwiGLUFFNFunction.apply``. The default launch
    configuration here is used instead of the Function's own defaults.

    Args:
        x: input activations, [B, S, K], bf16.
        w1: [K, H] bf16 weight (the branch passed through SiLU in backward).
        w3: [K, H] bf16 weight.
        w2: [H, K] bf16 output projection.
        out_of_place: accepted for API stability; currently unused (v0 is
            always out-of-place).
        block_m, block_k, block_h, block_d: forwarded to the kernel launcher
            as its BLOCK_* parameters.
        num_warps, num_stages: forwarded to the kernel launcher unchanged.

    Returns:
        Output tensor of shape [B, S, K].
    """
    # Arguments go through positionally, so this order must mirror the
    # Function's forward signature exactly.
    operands = (x, w1, w3, w2)
    launch_cfg = (block_m, block_k, block_h, block_d, num_warps, num_stages)
    return FusedSwiGLUFFNFunction.apply(*operands, out_of_place, *launch_cfg)