import torch
import triton
import triton.language as tl
from typing import Optional, Tuple, Union
import os
# Import-time side effect: turn on Triton's autotuning report for the whole
# process. NOTE(review): per the Triton docs this prints the selected config
# and tuning time after each autotune run; confirm that always forcing it on
# (overriding any value the user exported) is intended outside of debugging.
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'

@triton.jit
def silu(x):
    """Elementwise SiLU activation: the input scaled by its own sigmoid."""
    gate = tl.sigmoid(x)
    return x * gate

# Autotuning search space for the kernel below: every block size is paired
# with every warp count. The warp count varies slowest, so the resulting order
# is (32..512 @ 1 warp), (32..512 @ 2 warps), (32..512 @ 4 warps).
norm_configs = [
    triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps)
    for warps in (1, 2, 4)
    for block_size in (32, 64, 128, 256, 512)
]
@triton.autotune(
    configs=norm_configs,
    key=["hidden_dim"],  # re-tune whenever the vector length changes
    # The kernel *accumulates* into norm_squared_ptr via atomic adds. While
    # benchmarking, the autotuner launches the kernel many times on the same
    # arguments; without a reset, each of those launches would add its sum on
    # top of the previous ones and the first call for a given hidden_dim
    # would return an inflated norm.
    reset_to_zero=["norm_squared_ptr"],
)
@triton.jit
def swiglu_norm_kernel(
    a_ptr,                # First input tensor pointer (gate branch)
    b_ptr,                # Second input tensor pointer (linear branch)
    c_ptr,                # Output tensor pointer for activated values
    norm_squared_ptr,     # Pointer to the scalar accumulating sum(c**2)
    hidden_dim: tl.constexpr,  # Number of valid elements in a, b and c
    BLOCK_SIZE: tl.constexpr,  # Elements handled by one program
):
    """Fused SwiGLU and squared-norm reduction over a flat vector.

    Each program handles one BLOCK_SIZE-wide slice: it computes
    ``c = silu(a) * b``, stores ``c``, and atomically adds the slice's
    ``sum(c * c)`` (accumulated in float32) into ``norm_squared_ptr``.

    All three data pointers are indexed as ``ptr + offset``, i.e. the buffers
    are treated as unit-stride vectors of ``hidden_dim`` elements.
    ``norm_squared_ptr`` is only ever added to, so the caller must zero it
    before launch.
    """
    block_idx = tl.program_id(0)

    # Element indices owned by this program; the last block may be partial.
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim

    # The gate branch is promoted to float32 before the sigmoid; the linear
    # branch keeps its storage dtype.
    a_vals = tl.load(a_ptr + offsets, mask=mask, other=0).to(tl.float32)
    b_vals = tl.load(b_ptr + offsets, mask=mask, other=0)

    # SwiGLU: SiLU(a) * b, with the activation cast back to b's dtype before
    # the product so the stored value is computed in that dtype.
    c_vals = silu(a_vals).to(b_vals.dtype) * b_vals
    tl.store(c_ptr + offsets, c_vals, mask=mask)

    # Square the value that was actually stored (after rounding to b's dtype),
    # but accumulate in float32.
    c_f32 = c_vals.to(tl.float32)

    # Out-of-range lanes are excluded explicitly rather than relying on the
    # zero fill of the masked loads.
    block_sum = tl.sum(tl.where(mask, c_f32 * c_f32, 0.0))

    # One atomic per program folds the partial sum into the global scalar.
    tl.atomic_add(norm_squared_ptr, block_sum)


def swiglu_norm(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute SwiGLU and the squared L2 norm of its result in one kernel launch.

    Args:
        a: First input tensor of shape (1,1,hidden_dim); SiLU is applied to it.
        b: Second input tensor of shape (1,1,hidden_dim); multiplied with SiLU(a).

    Returns:
        Tuple of (activated_tensor, norm_squared):
            activated_tensor has the shape and dtype of ``a``;
            norm_squared is a float32 tensor of shape (1,) on ``a``'s device
            holding the sum of squares of the activated values.

    Raises:
        AssertionError: if the shapes differ or are not (1,1,hidden_dim).
    """
    # Verify input shapes. The rank is checked explicitly: a (1,1,H,W) tensor
    # would otherwise pass and only its first H elements would be processed.
    assert a.shape == b.shape, f"Input shapes must match: {a.shape} vs {b.shape}"
    assert a.dim() == 3 and a.shape[0] == 1 and a.shape[1] == 1, (
        f"Input must have shape (1,1,hidden_dim), got {a.shape}"
    )

    hidden_dim = a.shape[2]

    # The kernel addresses elements as `ptr + offset`, i.e. with unit stride.
    # .contiguous() is a no-op for tensors that already satisfy this and makes
    # a packed copy for strided views (e.g. x[..., ::2]).
    a = a.contiguous()
    b = b.contiguous()

    # Output buffer for the activated values.
    c = torch.empty_like(a)

    # Accumulator the kernel atomically adds into; it must start at zero.
    norm_squared = torch.zeros(1, dtype=torch.float32, device=a.device)

    # Nothing to compute for an empty vector: skip autotuning and the launch.
    if hidden_dim == 0:
        return c, norm_squared

    # One program per BLOCK_SIZE elements; BLOCK_SIZE is chosen by the autotuner.
    grid = lambda META: (
        triton.cdiv(hidden_dim, META["BLOCK_SIZE"]),
    )

    swiglu_norm_kernel[grid](
        a,
        b,
        c,
        norm_squared,
        hidden_dim=hidden_dim,
    )

    return c, norm_squared