# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py

import triton
import triton.language as tl

from xformers.triton.k_activations import (
    gelu,
    gelu_grad,
    leaky_relu,
    leaky_relu_grad,
    relu,
    relu_grad,
    smelu,
    smelu_grad,
    squared_relu,
    squared_relu_grad,
)

# Autotuning search space shared by both kernels: only the number of warps
# varies, no meta-parameter is overridden by the configs themselves.
_configs = [
    triton.Config({}, num_warps=warps)
    for warps in (1, 2, 4, 8, 16)
]


# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_fw(
    Y, X, BIAS, SEEDS,
    stride,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Fused (optional bias) + (optional activation) + dropout, forward pass.

    Each program handles one (BLOCK_M, BLOCK_N) tile, program axis 0 walks the
    rows and program axis 1 walks the columns.

    Y : Output  (M, N), row stride `stride`, unit column stride
    X : Input   (M, N), same addressing as Y
    BIAS        (N,), only read when USE_BIAS is set
    SEEDS       one seed per column tile, indexed by the program id on axis 1
    p : dropout probability, an element is kept when its random draw is > p
        and kept elements are scaled by 1 / (1 - p)
    is_fp16 : not read by the kernel body, only used as an autotune key
    ACTIVATION : 1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu,
                 any other value skips the activation
    SIZE_RAND_BLOCK : BLOCK_M * BLOCK_N, filled in by the heuristic

    NOTE: k_dropout_bw regenerates the mask from the same seeds and offsets,
    the random offset computation must stay identical in both kernels.
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    x_ptrs = X + rows[:, None] * stride + cols[None, :]
    y_ptrs = Y + rows[:, None] * stride + cols[None, :]

    # out-of-bounds guards for partial tiles
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    x = tl.load(x_ptrs, mask=block_mask, other=0.0)

    # optionally apply a fused bias, broadcast over the rows of the tile
    if USE_BIAS:
        bias = tl.load(BIAS + cols[None, :], mask=col_mask, other=0.)
        x += bias

    # optional: fused activation
    if ACTIVATION == 1:
        x = relu(x)
    elif ACTIVATION == 2:
        x = leaky_relu(x)
    elif ACTIVATION == 3:
        x = gelu(x)
    elif ACTIVATION == 4:
        x = squared_relu(x)
    elif ACTIVATION == 5:
        x = smelu(x)

    # get the random keep mask
    # The seed only depends on the column tile, so the offsets must depend on
    # the row tile: each row tile gets its own disjoint window of
    # SIZE_RAND_BLOCK counters, otherwise all the row tiles of a given column
    # tile would share the exact same drop mask.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * SIZE_RAND_BLOCK
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    keep_mask = r > p

    # prune and normalize in one go
    keep = tl.view(keep_mask, x.shape)
    output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)

    tl.store(y_ptrs, output, mask=block_mask)  # output


# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_bw(
    GRAD_IN, GRAD_BIAS, GRAD_OUT,
    INPUTS, BIAS, SEEDS,
    stride_grad, stride_inputs,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,  # heuristics
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    TRAINABLE_BIAS: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Backward pass of the fused (bias) + (activation) + dropout kernel.

    Each program handles one (BLOCK_M, BLOCK_N) tile, program axis 0 walks the
    rows and program axis 1 walks the columns.

    GRAD_OUT    (M, N), incoming gradient, row stride `stride_grad`
    GRAD_IN     (M, N), written by this kernel, row stride `stride_grad`
    GRAD_BIAS   written at offset `row_id * N + cols` when TRAINABLE_BIAS is
                set, i.e. one row of N partial sums per row tile
                (presumably reduced over the row tiles by the caller)
    INPUTS      (M, N), row stride `stride_inputs`, only read when ACTIVATION
                is non-zero
    BIAS        (N,), only read when both USE_BIAS and ACTIVATION are set
    SEEDS       one seed per column tile, indexed by the program id on axis 1
    p : dropout probability, an element is kept when its random draw is > p
        and kept elements are scaled by 1 / (1 - p)
    is_fp16 : not read by the kernel body, only used as an autotune key
    ACTIVATION : 0 for none, else 1 relu, 2 leaky_relu, 3 gelu,
                 4 squared_relu, 5 smelu

    NOTE: the drop mask is not saved by k_dropout_fw, it is regenerated here
    from the same seeds and offsets, the random offset computation must stay
    identical in both kernels.
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]
    grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]
    input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]

    # out-of-bounds guards for partial tiles
    col_mask = cols[None, :] < N
    block_mask = (rows[:, None] < M) & col_mask
    p_scale = 1. / (1. - p)

    grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)

    # optional: fused activation, the pre-activation values are recomputed
    # from the saved inputs (+ bias) to evaluate the activation gradient
    if ACTIVATION:
        inputs = tl.load(input_ptrs, mask=block_mask, other=0.)

        # optionally apply a fused bias, it is only needed on this path
        if USE_BIAS:
            bias = tl.load(BIAS + cols[None, :], mask=col_mask, other=0.)
            inputs += bias

        if ACTIVATION == 1:
            act_grad = relu_grad(inputs)
        elif ACTIVATION == 2:
            act_grad = leaky_relu_grad(inputs)
        elif ACTIVATION == 3:
            act_grad = gelu_grad(inputs)
        elif ACTIVATION == 4:
            act_grad = squared_relu_grad(inputs)
        elif ACTIVATION == 5:
            act_grad = smelu_grad(inputs)

        grad_out *= act_grad

    # randomly prune (and scale) the resulting buffer, possibly a no-op
    # note that even if we did not save the mask from the FW pass, it is generated
    # from the same seeds and offsets, so the same drop mask is applied here.
    # The offsets depend on the row tile so that the row tiles sharing a
    # column seed do not all end up with the same mask.
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * SIZE_RAND_BLOCK
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    r = tl.view(r, grad_out.shape)
    output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)

    # write-back
    tl.store(grad_in_ptrs, output, mask=block_mask)

    # optionally write this tile's contribution to the bias gradient:
    # masked-out rows were loaded as 0 so they do not contribute to the sum
    if TRAINABLE_BIAS:
        grad_bias = tl.sum(output, axis=0).to(tl.float32)
        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
        tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
