/* Copyright 2021 The LightSeq Team
   Copyright NVIDIA/apex
   Copyright AlexwellChen
   This kernel is adapted from NVIDIA/apex and LightSeq Team
*/
#include <ATen/ATen.h>
#include <torch/extension.h>

// CUDA forward declarations.
// The functions below are only declared here; their definitions live in
// another translation unit (presumably a companion .cu file — TODO confirm).
// Their parameter types must match those definitions exactly: a mismatch
// declares a different (mangled) overload and the extension fails to link.

/// Single-tensor SoftSignSGD update entry point (implemented in CUDA).
///
/// Every tensor argument is taken by non-const reference, so the
/// implementation is free to modify any of them in place.
/// NOTE(review): the descriptions below are inferred from parameter names
/// only — confirm them against the kernel implementation.
/// @param p            presumably the parameter tensor being optimized.
/// @param p_copy       presumably a second copy of the parameters (e.g. in
///                     another precision) kept in sync with p — TODO confirm.
/// @param g            presumably the gradient w.r.t. p.
/// @param exp_avg      presumably a running-average optimizer state buffer.
/// @param exp_avg_abs  presumably a running average of absolute values
///                     (optimizer state buffer).
/// @param beta         averaging coefficient, presumably for the exp_avg*
///                     buffers.
/// @param lr           presumably the learning rate.
/// @param decay        presumably the weight-decay coefficient.
/// @param eps          small constant, presumably for numerical stability.
/// @param power        exponent used by the update rule — semantics not
///                     visible from this file; TODO document.
void fused_softsignsgd_cuda(
    at::Tensor& p, at::Tensor& p_copy, at::Tensor& g, 
    at::Tensor& exp_avg, at::Tensor& exp_avg_abs,
    float beta, float lr, float decay, float eps, float power);

/// Multi-tensor SoftSignSGD entry point (implemented in CUDA, defined in
/// another translation unit).
///
/// Takes the same scalar hyperparameters, in the same order, as
/// fused_softsignsgd_cuda. It receives tensors grouped in lists rather than
/// as individual tensors. The leading (chunk_size, noop_flag, tensor_lists)
/// arguments mirror the NVIDIA/apex multi-tensor kernel convention, which
/// the file header says this code is adapted from.
/// NOTE(review): confirm the exact semantics against the kernel.
///
/// NOTE(review): noop_flag and tensor_lists are taken by value, so the
/// nested std::vector is copied on every call. Passing them by const
/// reference would avoid that copy, but only if the out-of-file definition
/// is changed in the same way. Otherwise the build breaks at link time.
///
/// @param chunk_size    presumably the number of elements each chunk
///                      processes — TODO confirm.
/// @param noop_flag     presumably a device-side flag tensor that, as in
///                      apex, signals the update should be skipped
///                      (e.g. on overflow) — TODO confirm.
/// @param tensor_lists  list of tensor lists; presumably one inner list per
///                      role (params, grads, state buffers, ...), in an
///                      order fixed by the kernel — TODO confirm.
/// @param beta          averaging coefficient (see fused_softsignsgd_cuda).
/// @param lr            presumably the learning rate.
/// @param decay         presumably the weight-decay coefficient.
/// @param epsilon       small stability constant. It is called `eps` in
///                      fused_softsignsgd_cuda; same role presumed.
/// @param power         exponent used by the update rule — TODO document.
void multi_tensor_softsignsgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float beta,
    const float lr,
    const float decay,
    const float epsilon,
    const float power);