/* Copyright NVIDIA/apex
   Copyright AlexwellChen
   This kernel is adapted from NVIDIA/apex.
*/
#include <cmath>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "include/type_shim.h" // Used for DISPATCH
#include "include/multi_tensor_apply.cuh" 
#include "include/fused_softsignsgd_kernel.cuh"

// Passed as the first argument to multi_tensor_apply in the host launcher below
// (presumably the threads-per-block of the launch -- confirm in multi_tensor_apply.cuh).
#define BLOCK_SIZE 512
// Number of elements each thread loads, updates and stores per pass of the
// functor's chunk loop (the loop advances by blockDim.x*ILP elements).
#define ILP 4

// Type of the per-thread temporaries in which the functor does its arithmetic,
// regardless of the tensors' storage type T.
using MATH_T = float;

/**
 * Device functor applying one fused SoftSignSGD step to a single chunk of a
 * single tensor. It is handed to multi_tensor_apply<4> by the host launcher.
 *
 * Tensor lists in `tl.addresses` (all indexed by tensor_loc, storage type T):
 *   [0] p            parameters                 (read / written)
 *   [1] g            gradients                  (read only)
 *   [2] exp_avg      running average of g       (read / written)
 *   [3] exp_avg_abs  running average of |g|^power (read / written)
 *
 * Per element, with all arithmetic done in MATH_T:
 *   s      = (1 - beta) * g
 *   s_abs  = (1 - beta) * |g|^power
 *   m      = beta * m + s
 *   v      = beta * v + s_abs
 *   numer  = beta * m + s                       (uses the already-updated m)
 *   denom  = (beta * v + s_abs)^(1/power) + epsilon   (uses the already-updated v)
 *   p      = p * (1 - lr * decay) - lr * numer / denom
 *
 * NOTE(review): numer/denom apply beta a second time on top of the updated
 * moments. This is kept exactly as in the original code; it looks like an
 * intentional look-ahead, but confirm against the reference optimizer.
 *
 * @param chunk_size  number of elements per chunk
 * @param noop_gmem   unused; the early-exit on this flag is deliberately
 *                    disabled so that infs/nans propagate
 * @param tl          metadata mapping this block to a tensor and a chunk
 * @param beta        decay rate shared by both running averages
 * @param lr          learning rate
 * @param decay       weight-decay coefficient (applied as p *= 1 - lr*decay)
 * @param epsilon     added to the denominator
 * @param power       exponent p of the |g|^p average; 1/power is applied to
 *                    the denominator
 */
template<typename T>
struct SoftsignsgdFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta,
    const float lr,
    const float decay,
    const float epsilon,
    const float power
    )
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    const int tensor_loc = tl.block_to_tensor[blockIdx.x];

    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance every pointer to the first element of this block's chunk.
    T* p = (T*)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;

    // The gradient is only ever read, so it is accessed through a const pointer.
    const T* g = (const T*)tl.addresses[1][tensor_loc];
    g += chunk_idx*chunk_size;

    T* exp_avg = (T*)tl.addresses[2][tensor_loc];
    exp_avg += chunk_idx*chunk_size;

    T* exp_avg_abs = (T*)tl.addresses[3][tensor_loc];
    exp_avg_abs += chunk_idx*chunk_size;

    // Elements remaining in the tensor from the start of this chunk.
    n -= chunk_idx*chunk_size;

    // Loop-invariant scalars. Everything stays in single precision (MATH_T is
    // float): a double literal here would silently promote the pow() below to
    // double-precision math.
    const MATH_T one_minus_beta = MATH_T(1) - beta;
    const MATH_T inv_power = MATH_T(1) / power;
    const MATH_T decay_factor = MATH_T(1) - lr * decay;

    for(int i_start = 0;
            i_start < n && i_start < chunk_size;
            i_start += blockDim.x*ILP)
    {
      MATH_T r_p[ILP];
      MATH_T r_g[ILP];
      MATH_T r_exp_avg[ILP];
      MATH_T r_exp_avg_abs[ILP];

      // Load phase: consecutive threads read consecutive elements; lanes that
      // fall outside the chunk/tensor get zeros so the math below stays defined.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        const int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_p[ii] = p[i];
          r_g[ii] = g[i];
          r_exp_avg[ii] = exp_avg[i];
          r_exp_avg_abs[ii] = exp_avg_abs[i];
        } else {
          r_p[ii] = MATH_T(0);
          r_g[ii] = MATH_T(0);
          r_exp_avg[ii] = MATH_T(0);
          r_exp_avg_abs[ii] = MATH_T(0);
        }
      }

      // Compute phase: update both running averages, then the parameter.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        const MATH_T scale_grad = one_minus_beta * r_g[ii];
        const MATH_T scale_grad_abs = one_minus_beta * powf(fabsf(r_g[ii]), power);

        r_exp_avg[ii] = beta * r_exp_avg[ii] + scale_grad;
        r_exp_avg_abs[ii] = beta * r_exp_avg_abs[ii] + scale_grad_abs;

        const MATH_T numers = beta * r_exp_avg[ii] + scale_grad;
        const MATH_T denoms =
            powf(beta * r_exp_avg_abs[ii] + scale_grad_abs, inv_power) + epsilon;
        const MATH_T fractions = numers / denoms;

        // Weight decay first, then the step.
        r_p[ii] = r_p[ii] * decay_factor;
        r_p[ii] = r_p[ii] - lr * fractions;
      }

      // Store phase. The gradient is deliberately NOT written back: it is never
      // modified, and round-tripping it through MATH_T would both waste a global
      // store and round a double-precision gradient to float precision.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        const int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
          exp_avg[i] = r_exp_avg[ii];
          exp_avg_abs[i] = r_exp_avg_abs[ii];
        }
      }
    }
  }
};

// Host-side launcher for the fused SoftSignSGD update.
//
// `tensor_lists` is forwarded to multi_tensor_apply<4>; SoftsignsgdFunctor reads
// its four lists as [0] parameters, [1] gradients, [2] exp_avg, [3] exp_avg_abs.
// The scalar hyper-parameters are forwarded to the functor unchanged.
//
// Raises (via TORCH_CHECK) if `tensor_lists` itself is empty; returns without
// launching anything if the first list holds no tensors.
void multi_tensor_softsignsgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float beta,
  const float lr,
  const float decay,
  const float epsilon,
  const float power)
{
  using namespace at;

  TORCH_CHECK(!tensor_lists.empty(), "tensor list cannot be empty");

  const auto& first_list = tensor_lists.front();
  if (first_list.empty()) {
    // Nothing to update.
    return;
  }

  // The storage type is taken from the very first tensor; every tensor in all
  // four lists is assumed to share it.
  const auto dtype = first_list.front().scalar_type();

  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
    dtype, 0, "softsignsgd",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SoftsignsgdFunctor<scalar_t_0>(),
      beta, lr, decay, epsilon, power); )

  // Surface any error left by the launch(es) above.
  AT_CUDA_CHECK(cudaGetLastError());
}
