/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

#include <assert.h>
#include <stdlib.h>

// Lane-participation mask with all 32 bits set: every lane of a warp takes
// part. Presumably used as the `mask` argument of the __shfl_*_sync intrinsics
// in the kernel definitions, which are not part of this view. A full mask
// requires that all 32 lanes actually reach the same shuffle call.
// The `u` suffix only spells out the type the unsuffixed literal already had
// (unsigned int); the value and type are unchanged.
#define SHFL_MASK (0xFFFFFFFFu)

// Forward-pass kernel of the light-weight convolution ("lightconv").
// Declaration only; the definition (and the host code that chooses the
// grid/block shape and shared-memory size) is not part of this view.
//
// Template parameters. Meanings are inferred from names only.
// NOTE(review): confirm against the definition.
//   FS        - presumably the filter (kernel) width along the sequence axis.
//   SB        - presumably the number of sequence positions handled per block.
//   padding_l - presumably the left padding applied to the sequence.
//   scalar_t  - element type of `input`, `filters` and `output`.
//
// Parameters:
//   input             - read-only device buffer of scalar_t.
//   filters           - read-only device buffer of scalar_t filter weights.
//   minibatch, sequenceLength, numFeatures
//                     - problem dimensions. The memory layout of input/output
//                       is not visible here; presumably
//                       (minibatch, numFeatures, sequenceLength), TODO confirm.
//   numFiltersInBlock - presumably the number of feature channels sharing one
//                       filter (numFeatures / numHeads). Confirm with caller.
//   output            - device buffer of scalar_t written by the kernel.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void lightconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// Backward kernel of the light-weight convolution that computes the gradient
// with respect to the forward pass's input. Declaration only; the definition
// and its launch configuration are not part of this view. The signature is
// identical in shape to lightconv_forward_kernel.
//
// Template parameters. Meanings are inferred from names only.
// NOTE(review): confirm against the definition.
//   FS        - presumably the filter (kernel) width.
//   SB        - presumably the number of sequence positions per block.
//   padding_l - presumably the forward pass's left padding.
//   scalar_t  - element type of `input`, `filters` and `output`.
//
// Parameters:
//   input             - read-only device buffer of scalar_t.
//                       NOTE(review): given the kernel name, this is
//                       presumably the upstream gradient (d loss / d forward
//                       output), not the forward input. Confirm with caller.
//   filters           - read-only device buffer of scalar_t filter weights.
//   minibatch, sequenceLength, numFeatures
//                     - problem dimensions (layout not visible here).
//   numFiltersInBlock - presumably channels per filter (numFeatures / numHeads).
//   output            - device buffer of scalar_t written by the kernel;
//                       presumably the gradient w.r.t. the forward input.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// First pass of the filter-weight gradient, "_short" variant. The suffix
// presumably marks the variant chosen for short sequences; confirm against the
// host launcher. Declaration only; the definition is not part of this view.
//
// It reads scalar_t data but writes a `float` buffer. That intermediate is
// presumably reduced and converted back to scalar_t by
// lightconv_grad_wrt_weights_secondpass_short_kernel, whose input is a
// `const float*`. Keeping it in float presumably preserves precision when
// scalar_t is a half type. TODO confirm.
//
// Template parameters. Meanings are inferred from names only.
//   FS, SB, padding_l - presumably filter width, sequence positions per block,
//                       and left padding, as in the forward kernel.
//   scalar_t          - element type of `input` and `gradInput`.
//
// Parameters:
//   input             - read-only scalar_t device buffer; presumably the
//                       forward pass's input.
//   gradInput         - read-only scalar_t device buffer; presumably the
//                       upstream gradient w.r.t. the forward output. Confirm
//                       with caller.
//   minibatch, sequenceLength, numFeatures
//                     - problem dimensions (layout not visible here).
//   numFiltersInBlock - presumably channels per filter (numFeatures / numHeads).
//   numHeads          - presumably the number of distinct filters. This is the
//                       only first-pass variant that receives it.
//   output            - float device buffer of partial results written by the
//                       kernel; its size and layout are not visible here.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output);

// Second pass of the filter-weight gradient, "_short" variant. Declaration
// only; the definition is not part of this view.
//
// It reads a `float` buffer (presumably the partials produced by
// lightconv_grad_wrt_weights_firstpass_short_kernel) and writes scalar_t
// results, presumably the final filter gradient. TODO confirm. Unlike the
// first pass, it takes no padding_l template parameter.
//
// Template parameters:
//   FS, SB   - presumably filter width and sequence positions per block.
//   scalar_t - element type of `output`.
//
// Parameters:
//   input             - read-only float device buffer of first-pass partials.
//   minibatch         - batch size, presumably the extent being reduced over.
//   numFiltersInBlock - presumably channels per filter (numFeatures / numHeads).
//   output            - scalar_t device buffer written by the kernel.
// The top-level `const` on the by-value int parameters does not affect callers.
template <int FS, int SB, typename scalar_t>
__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

// First pass of the filter-weight gradient, general (non-"_short") variant.
// Declaration only; the definition is not part of this view.
//
// It reads scalar_t data and writes a `float` buffer of partial results. That
// buffer is presumably consumed by lightconv_grad_wrt_weights_secondpass_kernel,
// whose input is a `const float*`. Unlike the "_short" first pass, it takes no
// numHeads parameter.
//
// Template parameters. Meanings are inferred from names only.
//   FS, SB, padding_l - presumably filter width, sequence positions per block,
//                       and left padding, as in the forward kernel.
//   scalar_t          - element type of `input` and `gradInput`.
//
// Parameters:
//   input             - read-only scalar_t device buffer; presumably the
//                       forward pass's input.
//   gradInput         - read-only scalar_t device buffer; presumably the
//                       upstream gradient w.r.t. the forward output. Confirm
//                       with caller.
//   minibatch, sequenceLength, numFeatures
//                     - problem dimensions (layout not visible here).
//   numFiltersInBlock - presumably channels per filter (numFeatures / numHeads).
//   output            - float device buffer of partial results written by the
//                       kernel; its size and layout are not visible here.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output);

// Second pass of the filter-weight gradient, general (non-"_short") variant.
// Declaration only; the definition is not part of this view.
//
// It reads a `float` buffer (presumably the partials produced by
// lightconv_grad_wrt_weights_firstpass_kernel) and writes scalar_t results,
// presumably the final filter gradient. TODO confirm. It has the same
// parameter list as the "_short" second pass.
//
// Template parameters:
//   FS, SB   - presumably filter width and sequence positions per block.
//   scalar_t - element type of `output`.
//
// Parameters:
//   input             - read-only float device buffer of first-pass partials.
//   minibatch         - batch size, presumably the extent being reduced over.
//   numFiltersInBlock - presumably channels per filter (numFeatures / numHeads).
//   output            - scalar_t device buffer written by the kernel.
// The top-level `const` on the by-value int parameters does not affect callers.
template <int FS, int SB, typename scalar_t>
__global__ void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);
