
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import triton
import triton.language as tl
from einops import rearrange

    
def ctx_save(ctx, requires_grad: bool, *args, **kwargs):
    """Stash state on an autograd ``ctx`` only when a backward pass is expected.

    Args:
        ctx: the autograd context passed to ``Function.forward``.
        requires_grad: when falsy, nothing is stored and the call is a no-op.
        *args: tensors (or ``None``) forwarded to ``ctx.save_for_backward``.
        **kwargs: non-tensor values attached to ``ctx`` as attributes of the same name.
    """
    if not requires_grad:
        return
    ctx.save_for_backward(*args)
    for attr_name, attr_value in kwargs.items():
        setattr(ctx, attr_name, attr_value)

@triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 4}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 4}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 2}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 4}, num_warps=8),

        ],
        key=['NHW', 'C'],
)
@triton.jit
def t_first_triton_forward_kernel(x_seq_ptr: torch.Tensor, # [T+K-1, N, C, H, W]
                    weight_ptr: torch.Tensor, # [C, 1, K]
                    output_ptr: torch.Tensor, # [T, N, C, H, W]
                    NHW, T: tl.constexpr, N, C, H, W, K: tl.constexpr,
                    stride_t, stride_n, stride_c, stride_h, stride_w, 
                    BLOCK_SIZE_NHW: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
    """Depthwise 1-D convolution along the leading (time) axis.

    Computes ``output[t, n, c, h, w] = sum_k weight[c, 0, k] * x_seq[t + k, n, c, h, w]``
    for every ``t < T``. Program ``(i, j)`` handles the flattened ``n*h*w`` lanes
    ``[i*BLOCK_SIZE_NHW, (i+1)*BLOCK_SIZE_NHW)`` and channels
    ``[j*BLOCK_SIZE_C, (j+1)*BLOCK_SIZE_C)`` for all ``T`` output steps.

    One stride set addresses both ``x_seq_ptr`` and ``output_ptr``, so the two
    buffers must share their N/C/H/W/T strides (e.g. both contiguous).
    ``weight_ptr`` is indexed as ``c * K + k`` and must be contiguous.
    ``T`` and ``K`` are ``tl.arange`` extents, so Triton requires them to be
    powers of two.
    """
    nhw_pid = tl.program_id(0)
    c_pid = tl.program_id(1)

    nhw_offset = nhw_pid * BLOCK_SIZE_NHW + tl.arange(0, BLOCK_SIZE_NHW) # [BLOCK_SIZE_NHW]
    c_offset = c_pid * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) # [BLOCK_SIZE_C]
    t_range = tl.arange(0, T) # [T]
    k_range = tl.arange(0, K) # [K]

    # Tail lanes of the last NHW / C tile must be masked logically. With these
    # strides an out-of-range n (or c) index aliases a *valid* element of the
    # next time step (or next n), so a bound on the flat offset would let such
    # lanes read real data and overwrite results owned by another program.
    nhw_mask = nhw_offset < NHW
    c_mask = c_offset < C
    lane_mask = nhw_mask[:, None] & c_mask[None, :] # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    n_idx = nhw_offset // (H * W)
    hw_idx = nhw_offset % (H * W)
    h_idx = hw_idx // W
    w_idx = hw_idx % W

    # Offset of (n, c, h, w) at time step 0.
    base_idx = (n_idx * stride_n + h_idx * stride_h + w_idx * stride_w)[:, None] + c_offset[None, :] * stride_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    weight = tl.load(weight_ptr + c_offset[:, None] * K + k_range[None, :], mask=c_mask[:, None], other=0.0) # [BLOCK_SIZE_C, K]

    # Output step t with tap k reads input time index t + k.
    tk_idx = t_range[:, None] + k_range[None, :] # [T, K]
    x_ptr_idx = base_idx[:, :, None, None] + tk_idx[None, None, :, :] * stride_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T, K]
    x_val = tl.load(x_seq_ptr + x_ptr_idx, mask=lane_mask[:, :, None, None], other=0.0)

    output_val = tl.sum(weight[None, :, None, :] * x_val, axis=3) # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]

    output_ptr_idx = base_idx[:, :, None] + t_range[None, None, :] * stride_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]
    tl.store(output_ptr + output_ptr_idx, output_val, mask=lane_mask[:, :, None])

@triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 4}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 4}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 2}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 4}, num_warps=8),
        ],
        key=['NHW', 'C'],
        reset_to_zero=['grad_w_ptr']
)
@triton.jit
def t_first_triton_gradw_kernel(grad_output_ptr, # [T, N, C, H, W]
                        x_ptr, # [T+K-1, N, C, H, W],
                        grad_w_ptr, # [C, 1, K]
                        NHW, T: tl.constexpr, N, C, H, W, K: tl.constexpr, 
                        stride_t, stride_n, stride_c, stride_h, stride_w,
                        BLOCK_SIZE_NHW: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
    """Accumulate the weight gradient of the time-first depthwise convolution.

    Adds ``sum_{t, n, h, w} grad_output[t, n, c, h, w] * x[t + k, n, c, h, w]``
    into ``grad_w[c, 0, k]`` with ``tl.atomic_add``, because several programs
    (one per NHW tile) contribute to the same channel. ``grad_w`` must be zeroed
    before launch; ``reset_to_zero`` re-zeroes it between autotuning trials.

    One stride set addresses both ``grad_output_ptr`` and ``x_ptr``, so both
    buffers must share their N/C/H/W/T strides (e.g. both contiguous).
    ``T`` and ``K`` are ``tl.arange`` extents and must be powers of two.
    """
    nhw_pid = tl.program_id(0)
    c_pid = tl.program_id(1)

    nhw_offset = nhw_pid * BLOCK_SIZE_NHW + tl.arange(0, BLOCK_SIZE_NHW) # [BLOCK_SIZE_NHW]
    c_offset = c_pid * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) # [BLOCK_SIZE_C]
    t_range = tl.arange(0, T) # [T]
    k_range = tl.arange(0, K) # [K]

    # Out-of-range NHW lanes alias valid data of the next time step. Without a
    # logical mask they would add real products a second time into grad_w.
    nhw_mask = nhw_offset < NHW
    c_mask = c_offset < C
    lane_mask = nhw_mask[:, None] & c_mask[None, :] # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    n_idx = nhw_offset // (H * W)
    hw_idx = nhw_offset % (H * W)
    h_idx = hw_idx // W
    w_idx = hw_idx % W

    # Offset of (n, c, h, w) at time step 0.
    base_idx = (n_idx * stride_n + h_idx * stride_h + w_idx * stride_w)[:, None] + c_offset[None, :] * stride_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    grad_output_ptr_idx = base_idx[:, :, None] + t_range[None, None, :] * stride_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]
    grad_output_val = tl.load(grad_output_ptr + grad_output_ptr_idx, mask=lane_mask[:, :, None], other=0.0)

    tk_idx = t_range[:, None] + k_range[None, :] # [T, K]
    x_ptr_idx = base_idx[:, :, None, None] + tk_idx[None, None, :, :] * stride_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T, K]
    x_val = tl.load(x_ptr + x_ptr_idx, mask=lane_mask[:, :, None, None], other=0.0)

    # Masked lanes load 0 for both operands, so they contribute exactly nothing.
    grad_w_val = tl.sum(grad_output_val[:, :, :, None] * x_val, axis=0) # [BLOCK_SIZE_C, T, K]
    grad_w_val = tl.sum(grad_w_val, axis=1) # [BLOCK_SIZE_C, K]
    grad_w_val = tl.reshape(grad_w_val, (BLOCK_SIZE_C * K, ))

    # Flat index c * K + k is < C * K exactly when c < C.
    grad_w_ptr_idx = tl.reshape(c_offset[:, None] * K + k_range[None, :], (BLOCK_SIZE_C * K, ))
    tl.atomic_add(grad_w_ptr + grad_w_ptr_idx, grad_w_val, mask=grad_w_ptr_idx < C * K)

    
class T_First_Triton_ATGF(torch.autograd.Function):
    """Depthwise temporal convolution for time-first inputs, backed by Triton.

    Accepted input layouts (time is the leading dimension):
    ``[T, N, C, H, W]``, ``[T, N, C, H]`` (treated as ``W == 1``) and
    ``[T, N, C]`` (treated as ``H == W == 1``). ``weight`` is read as
    ``[C, 1, K]``. ``bias`` is optional and is broadcast as ``bias.view(1, 1, -1, 1, 1)``.
    """

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, weight: torch.Tensor, padding: int, bias: torch.Tensor=None):
        """Compute ``out[t] = sum_k weight[:, 0, k] * x_pad[t + k] (+ bias)``.

        ``x_pad`` is ``x_seq`` front-padded with ``padding`` zero time steps.
        The kernel reads ``T + K - 1`` padded steps, and the backward pass
        (time-reversed weight on an end-padded gradient) is only the exact
        gradient when ``padding == K - 1``, i.e. a causal convolution.
        Returns a tensor with the same shape as ``x_seq``.
        """
        # Decide whether backward can run using the caller's tensors. forward()
        # executes with grad mode disabled, so the unsqueezed views created
        # below never report requires_grad=True. Checking x_seq alone would
        # also miss the case where only weight/bias need gradients.
        needs_grad = x_seq.requires_grad or weight.requires_grad or (bias is not None and bias.requires_grad)

        shape = x_seq.shape
        if x_seq.dim() == 4:
            x_seq = x_seq.unsqueeze(-1)
        elif x_seq.dim() == 3:
            x_seq = x_seq.unsqueeze(-1).unsqueeze(-1)

        T, N, C, H, W = x_seq.shape
        NHW = N * H * W
        K = weight.shape[2]

        # The kernel addresses the padded input and the output with ONE stride
        # set, and the weight as c * K + k, so all three must be contiguous.
        x_seq_pad = F.pad(x_seq.contiguous(), (0, 0, 0, 0, 0, 0, 0, 0, padding, 0)).contiguous()
        weight = weight.contiguous()
        # Every element is written by the kernel, so no zero-fill is needed.
        output = torch.empty((T, N, C, H, W), device=x_seq.device, dtype=x_seq.dtype)
        stride = output.stride()  # identical to x_seq_pad.stride(): same trailing dims, both contiguous

        grid = lambda META: (triton.cdiv(NHW, META['BLOCK_SIZE_NHW']), triton.cdiv(C, META['BLOCK_SIZE_C']))

        t_first_triton_forward_kernel[grid](x_seq_pad, weight, output, NHW, T, N, C, H, W, K, stride[0], stride[1], stride[2], stride[3], stride[4])

        if bias is not None:
            output += bias.view(1, 1, -1, 1, 1)

        ctx_save(ctx, needs_grad, x_seq_pad, weight, bias, T=T, H=H, W=W, N=N, C=C, K=K, P=padding, grid=grid, shape=shape)
        return output.view(shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(x_seq, weight, padding, bias)``.

        ``grad_input`` reuses the forward kernel on the end-padded gradient with
        the time-reversed weight. ``grad_weight`` is accumulated with atomics in
        at least float32 precision and then cast to the weight's dtype.
        Gradients that autograd does not need are returned as ``None``.
        """
        x_seq_pad, weight, bias = ctx.saved_tensors
        T, N, C, H, W, K, P, grid = ctx.T, ctx.N, ctx.C, ctx.H, ctx.W, ctx.K, ctx.P, ctx.grid
        shape = ctx.shape
        NHW = N * H * W

        if grad_output.dim() == 4:
            grad_output = grad_output.unsqueeze(-1)
        elif grad_output.dim() == 3:
            grad_output = grad_output.unsqueeze(-1).unsqueeze(-1)

        # Autograd may pass a non-contiguous (even stride-0 expanded) gradient.
        # Both kernels below use the contiguous strides of grad_output_pad for
        # grad_output as well, so materialize it densely.
        grad_output = grad_output.contiguous()
        grad_output_pad = F.pad(grad_output, (0, 0, 0, 0, 0, 0, 0, 0, 0, P)).contiguous()
        stride = grad_output_pad.stride()

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.empty((T, N, C, H, W), device=grad_output.device, dtype=grad_output.dtype)
            t_first_triton_forward_kernel[grid](grad_output_pad, weight.flip(2).contiguous(), grad_input, NHW, T, N, C, H, W, K, stride[0], stride[1], stride[2], stride[3], stride[4])
            grad_input = grad_input.view(shape)

        if ctx.needs_input_grad[1]:
            # fp16/bf16 -> fp32 accumulation; fp32/fp64 keep their precision.
            acc_dtype = torch.promote_types(weight.dtype, torch.float32)
            grad_weight = torch.zeros([C, 1, K], device=grad_output.device, dtype=acc_dtype)
            t_first_triton_gradw_kernel[grid](grad_output, x_seq_pad, grad_weight, NHW, T, N, C, H, W, K, stride[0], stride[1], stride[2], stride[3], stride[4])
            grad_weight = grad_weight.to(weight.dtype)

        if bias is not None and ctx.needs_input_grad[3]:
            # Sum over every dimension except the channel axis (dim 2).
            grad_bias = grad_output.sum(dim=(0, 1, 3, 4))

        return grad_input, grad_weight, None, grad_bias


@triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 4}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 4}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 2}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 4}, num_warps=8),
        ], 
        key=['NHW', 'C']
)
@triton.jit
def t_last_triton_forward_kernel(x_seq_ptr: torch.Tensor, # [N, C, H, W, T+K-1]
                weight_ptr: torch.Tensor, # [C, 1, K]
                output_ptr: torch.Tensor, # [N, C, H, W, T]
                NHW, N, C, H, W, T: tl.constexpr, K: tl.constexpr,
                stride_x_seq_n, stride_x_seq_c, stride_x_seq_h, stride_x_seq_w, stride_x_seq_t, 
                stride_output_n, stride_output_c, stride_output_h, stride_output_w, stride_output_t,
                BLOCK_SIZE_NHW: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
    """Depthwise 1-D convolution along the trailing (time) axis.

    Computes ``output[n, c, h, w, t] = sum_k weight[c, 0, k] * x_seq[n, c, h, w, t + k]``
    for every ``t < T``. Program ``(i, j)`` handles flattened ``n*h*w`` lanes
    ``[i*BLOCK_SIZE_NHW, (i+1)*BLOCK_SIZE_NHW)`` and channels
    ``[j*BLOCK_SIZE_C, (j+1)*BLOCK_SIZE_C)`` for all ``T`` output steps.

    Input and output are addressed through their own stride sets, so any
    layout works. ``weight_ptr`` is indexed as ``c * K + k`` and must be
    contiguous. ``T`` and ``K`` are ``tl.arange`` extents and must be powers of two.
    """
    nhw_pid = tl.program_id(0)
    c_pid = tl.program_id(1)

    nhw_offset = nhw_pid * BLOCK_SIZE_NHW + tl.arange(0, BLOCK_SIZE_NHW) # [BLOCK_SIZE_NHW]
    c_offset = c_pid * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) # [BLOCK_SIZE_C]
    t_range = tl.arange(0, T) # [T]
    k_range = tl.arange(0, K) # [K]

    # Tail lanes are masked logically. An out-of-range c index aliases a valid
    # element of the next n, so a flat-offset bound would let that lane
    # overwrite a result owned by another program.
    nhw_mask = nhw_offset < NHW
    c_mask = c_offset < C
    lane_mask = nhw_mask[:, None] & c_mask[None, :] # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    n_idx = nhw_offset // (H * W)
    hw_idx = nhw_offset % (H * W)
    h_idx = hw_idx // W
    w_idx = hw_idx % W

    weight = tl.load(weight_ptr + c_offset[:, None] * K + k_range[None, :], mask=c_mask[:, None], other=0.0) # [BLOCK_SIZE_C, K]

    # Offset of (n, c, h, w) at time step 0 in the padded input.
    x_base_idx = (n_idx * stride_x_seq_n + h_idx * stride_x_seq_h + w_idx * stride_x_seq_w)[:, None] + c_offset[None, :] * stride_x_seq_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]
    tk_idx = t_range[:, None] + k_range[None, :] # [T, K]; output step t with tap k reads t + k
    x_ptr_idx = x_base_idx[:, :, None, None] + tk_idx[None, None, :, :] * stride_x_seq_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T, K]

    x_val = tl.load(x_seq_ptr + x_ptr_idx, mask=lane_mask[:, :, None, None], other=0.0)

    output_val = weight[None, :, None, :] * x_val # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T, K]
    output_val = tl.sum(output_val, axis=3) # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]

    out_base_idx = (n_idx * stride_output_n + h_idx * stride_output_h + w_idx * stride_output_w)[:, None] + c_offset[None, :] * stride_output_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]
    output_ptr_idx = out_base_idx[:, :, None] + t_range[None, None, :] * stride_output_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]

    tl.store(output_ptr + output_ptr_idx, output_val, mask=lane_mask[:, :, None])

@triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 32, 'BLOCK_SIZE_C': 4}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 2}, num_warps=4),
            triton.Config({'BLOCK_SIZE_NHW': 64, 'BLOCK_SIZE_C': 4}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 2}, num_warps=8),
            triton.Config({'BLOCK_SIZE_NHW': 128, 'BLOCK_SIZE_C': 4}, num_warps=8),
        ], 
        key=['NHW', 'C'],
        reset_to_zero=['grad_w_ptr']
)
@triton.jit
def t_last_triton_gradw_kernel(x_ptr, # [N, C, H, W, T+K-1],
                        grad_output_ptr, # [N, C, H, W, T]
                        grad_w_ptr, # [C, 1, K]
                        NHW, N, C, H, W, T: tl.constexpr, K: tl.constexpr, 
                        stride_grad_output_n, stride_grad_output_c, stride_grad_output_h, stride_grad_output_w, stride_grad_output_t,
                        stride_x_n, stride_x_c, stride_x_h, stride_x_w, stride_x_t,
                        BLOCK_SIZE_NHW: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
    """Accumulate the weight gradient of the time-last depthwise convolution.

    Adds ``sum_{n, h, w, t} grad_output[n, c, h, w, t] * x[n, c, h, w, t + k]``
    into ``grad_w[c, 0, k]`` with ``tl.atomic_add``, because several programs
    (one per NHW tile) contribute to the same channel. ``grad_w`` must be zeroed
    before launch; ``reset_to_zero`` re-zeroes it between autotuning trials.
    ``x`` and ``grad_output`` are addressed through their own stride sets.
    ``T`` and ``K`` are ``tl.arange`` extents and must be powers of two.
    """
    nhw_pid = tl.program_id(0)
    c_pid = tl.program_id(1)

    nhw_offset = nhw_pid * BLOCK_SIZE_NHW + tl.arange(0, BLOCK_SIZE_NHW) # [BLOCK_SIZE_NHW]
    c_offset = c_pid * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) # [BLOCK_SIZE_C]
    t_range = tl.arange(0, T) # [T]
    k_range = tl.arange(0, K) # [K]

    # Logical lane masks: tail lanes must contribute exactly 0 to the reduction
    # (masked loads below use other=0.0 instead of undefined values).
    nhw_mask = nhw_offset < NHW
    c_mask = c_offset < C
    lane_mask = nhw_mask[:, None] & c_mask[None, :] # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]

    n_idx = nhw_offset // (H * W)
    hw_idx = nhw_offset % (H * W)
    h_idx = hw_idx // W
    w_idx = hw_idx % W

    grad_output_base_idx = (n_idx * stride_grad_output_n + h_idx * stride_grad_output_h + w_idx * stride_grad_output_w)[:, None] + c_offset[None, :] * stride_grad_output_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]
    grad_output_ptr_idx = grad_output_base_idx[:, :, None] + t_range[None, None, :] * stride_grad_output_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T]

    grad_output_val = tl.load(grad_output_ptr + grad_output_ptr_idx, mask=lane_mask[:, :, None], other=0.0)

    x_base_idx = (n_idx * stride_x_n + h_idx * stride_x_h + w_idx * stride_x_w)[:, None] + c_offset[None, :] * stride_x_c # [BLOCK_SIZE_NHW, BLOCK_SIZE_C]
    tk_idx = t_range[:, None] + k_range[None, :] # [T, K]
    x_ptr_idx = x_base_idx[:, :, None, None] + tk_idx[None, None, :, :] * stride_x_t # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, T, K]

    x_val = tl.load(x_ptr + x_ptr_idx, mask=lane_mask[:, :, None, None], other=0.0)

    grad_w_val = grad_output_val[:, :, :, None] * x_val
    grad_w_val = tl.sum(grad_w_val, axis=2) # [BLOCK_SIZE_NHW, BLOCK_SIZE_C, K]
    grad_w_val = tl.sum(grad_w_val, axis=0) # [BLOCK_SIZE_C, K]
    grad_w_val = tl.reshape(grad_w_val, (BLOCK_SIZE_C * K, ))

    # Flat index c * K + k is < C * K exactly when c < C.
    grad_w_ptr_idx = tl.reshape(c_offset[:, None] * K + k_range[None, :], (BLOCK_SIZE_C * K, ))
    tl.atomic_add(grad_w_ptr + grad_w_ptr_idx, grad_w_val, mask=grad_w_ptr_idx < C * K)


class T_Last_Triton_ATGF(torch.autograd.Function):
    """Depthwise temporal convolution for time-last inputs, backed by Triton.

    Accepted input layouts (time is the trailing dimension):
    ``[N, C, H, W, T]``, ``[N, C, H, T]`` (treated as ``W == 1``) and
    ``[N, C, T]`` (treated as ``H == W == 1``). ``weight`` is read as
    ``[C, 1, K]``. ``bias`` is optional and is broadcast as ``bias.view(1, -1, 1, 1, 1)``.
    """

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, weight: torch.Tensor, padding: int, bias: torch.Tensor=None):
        """Compute ``out[..., t] = sum_k weight[:, 0, k] * x_pad[..., t + k] (+ bias)``.

        ``x_pad`` is ``x_seq`` front-padded with ``padding`` zeros along time.
        The kernel reads ``T + K - 1`` padded steps, and the backward pass is
        only the exact gradient when ``padding == K - 1`` (causal).
        Returns a tensor with the same shape as ``x_seq``.
        """
        # Decide whether backward can run using the caller's tensors. forward()
        # executes with grad mode disabled, so the unsqueezed views created
        # below never report requires_grad=True. Checking x_seq alone would
        # also miss the case where only weight/bias need gradients.
        needs_grad = x_seq.requires_grad or weight.requires_grad or (bias is not None and bias.requires_grad)

        shape = x_seq.shape
        if x_seq.dim() == 4:
            x_seq = x_seq.unsqueeze(-2)
        elif x_seq.dim() == 3:
            x_seq = x_seq.unsqueeze(-2).unsqueeze(-2)

        N, C, H, W, T = x_seq.shape
        NHW = N * H * W
        K = weight.shape[2]

        # The kernel indexes the weight as c * K + k.
        weight = weight.contiguous()
        # Every element is written by the kernel, so no zero-fill is needed.
        output = torch.empty(x_seq.shape, device=x_seq.device, dtype=x_seq.dtype)
        x_seq_pad = F.pad(x_seq, (padding, 0))

        stride_output = output.stride()
        stride_x_seq_pad = x_seq_pad.stride()

        grid = lambda META: (triton.cdiv(NHW, META['BLOCK_SIZE_NHW']), triton.cdiv(C, META['BLOCK_SIZE_C']))

        t_last_triton_forward_kernel[grid](x_seq_pad, weight, output, NHW, N, C, H, W, T, K, stride_x_seq_pad[0], stride_x_seq_pad[1], stride_x_seq_pad[2], stride_x_seq_pad[3], stride_x_seq_pad[4], stride_output[0], stride_output[1], stride_output[2], stride_output[3], stride_output[4])

        if bias is not None:
            output += bias.view(1, -1, 1, 1, 1)

        ctx_save(ctx, needs_grad, x_seq_pad, weight, bias, N=N, C=C, H=H, W=W, T=T, K=K, P=padding, grid=grid, shape=shape)
        return output.view(shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(x_seq, weight, padding, bias)``.

        ``grad_input`` reuses the forward kernel on the end-padded gradient with
        the time-reversed weight. ``grad_weight`` is accumulated with atomics in
        at least float32 precision (matching the time-first implementation) and
        then cast to the weight's dtype. Unneeded gradients are ``None``.
        """
        x_seq_pad, weight, bias = ctx.saved_tensors
        N, C, H, W, T, K, P, grid = ctx.N, ctx.C, ctx.H, ctx.W, ctx.T, ctx.K, ctx.P, ctx.grid
        shape = ctx.shape
        NHW = N * H * W

        if grad_output.dim() == 4:
            grad_output = grad_output.unsqueeze(-2)
        elif grad_output.dim() == 3:
            grad_output = grad_output.unsqueeze(-2).unsqueeze(-2)

        # Autograd may pass a non-contiguous (even stride-0 expanded) gradient;
        # a dense copy keeps every element offset inside the buffer's numel.
        grad_output = grad_output.contiguous()

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_output_pad = F.pad(grad_output, (0, P))
            grad_input = torch.empty((N, C, H, W, T), device=grad_output.device, dtype=grad_output.dtype)
            stride_grad_output_pad = grad_output_pad.stride()
            stride_grad_input = grad_input.stride()
            t_last_triton_forward_kernel[grid](grad_output_pad, weight.flip(2).contiguous(), grad_input, NHW, N, C, H, W, T, K, stride_grad_output_pad[0], stride_grad_output_pad[1], stride_grad_output_pad[2], stride_grad_output_pad[3], stride_grad_output_pad[4], stride_grad_input[0], stride_grad_input[1], stride_grad_input[2], stride_grad_input[3], stride_grad_input[4])
            grad_input = grad_input.view(shape)

        if ctx.needs_input_grad[1]:
            # fp16/bf16 -> fp32 accumulation; fp32/fp64 keep their precision.
            acc_dtype = torch.promote_types(weight.dtype, torch.float32)
            grad_weight = torch.zeros([C, 1, K], device=grad_output.device, dtype=acc_dtype)
            stride_grad_output = grad_output.stride()
            stride_x_seq = x_seq_pad.stride()
            t_last_triton_gradw_kernel[grid](x_seq_pad, grad_output, grad_weight, NHW, N, C, H, W, T, K, stride_grad_output[0], stride_grad_output[1], stride_grad_output[2], stride_grad_output[3], stride_grad_output[4], stride_x_seq[0], stride_x_seq[1], stride_x_seq[2], stride_x_seq[3], stride_x_seq[4])
            grad_weight = grad_weight.to(weight.dtype)

        if bias is not None and ctx.needs_input_grad[3]:
            # Sum over every dimension except the channel axis (dim 1).
            grad_bias = grad_output.sum(dim=(0, 2, 3, 4))

        return grad_input, grad_weight, None, grad_bias

def t_last_triton_impl(x_seq, weight, padding, bias):
    """Apply the time-last Triton depthwise temporal convolution.

    Thin functional wrapper around ``T_Last_Triton_ATGF.apply``. ``x_seq`` has
    time as its last dimension; ``bias`` may be ``None``.
    """
    conv_fn = T_Last_Triton_ATGF.apply
    return conv_fn(x_seq, weight, padding, bias)

def t_first_triton_impl(x_seq, weight, padding, bias):
    """Apply the time-first Triton depthwise temporal convolution.

    Thin functional wrapper around ``T_First_Triton_ATGF.apply``. ``x_seq`` has
    time as its leading dimension; ``bias`` may be ``None``.
    """
    return T_First_Triton_ATGF.apply(x_seq, weight, padding, bias)


# Backward-compatible alias for the original, misspelled public name.
t_frist_triton_impl = t_first_triton_impl