/* Copyright 2021 The LightSeq Team
   Copyright NVIDIA/apex
   Copyright AlexwellChen
   This kernel is adapted from NVIDIA/apex and LightSeq Team
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

#include <cmath>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include "include/type_shim.h"
#include "include/fused_softsignsgd_kernel.cuh"


/*
 * Generic fused SoftSignSGD step, applied element-wise and in place.
 *
 * For every element i (all intermediate math is carried out in T, so half
 * precision gradients are widened once and never rounded back mid-update):
 *   m_i   = b * m_i + (1 - b) * g_i
 *   a_i   = b * a_i + (1 - b) * |g_i|^power
 *   numer = b * m_i + (1 - b) * g_i          (m_i is the value just written)
 *   denom = (b * a_i + (1 - b) * |g_i|^power)^(1 / power) + eps
 *   p_i   = p_i * (1 - lr * decay) - lr * numer / denom
 *
 * Parameters:
 *   p            parameters, updated in place
 *   p_copy       optional GRAD_T copy of the updated parameters; NULL to skip
 *   g            gradients (read only by this kernel)
 *   exp_avg      running average of the gradient (m), updated in place
 *   exp_avg_abs  running average of |g|^power (a), updated in place
 *   b, lr, decay, eps, power   scalar hyper-parameters used as shown above
 *   total_size   number of elements in each buffer
 *
 * Launch: only the x dimension of the grid and block is used. Each thread
 * walks the buffers with a stride of gridDim.x * blockDim.x, so every element
 * is processed regardless of how many blocks are launched.
 */
template <typename T, typename GRAD_T>
__global__ void softsignsgd_cuda_kernel(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if
                                  // not needed
    GRAD_T* __restrict__ g, T* __restrict__ exp_avg, T* __restrict__ exp_avg_abs,
    const float b, const float lr, const float decay, const float eps, const float power, const size_t total_size
    ){
    // 64-bit indexing: avoids signed/unsigned comparison against total_size and
    // overflow of blockIdx.x * blockDim.x.
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

    // Loop-invariant scalars. The float expressions are evaluated exactly as
    // written in the update rule and only then widened to T.
    const T beta = static_cast<T>(b);
    const T one_minus_beta = static_cast<T>(1 - b);
    const T keep = static_cast<T>(1 - lr * decay);
    const T step = static_cast<T>(lr);
    const T epsilon = static_cast<T>(eps);
    const T exponent = static_cast<T>(power);
    const T inv_exponent = static_cast<T>(1) / exponent;

    for (size_t i = first; i < total_size; i += stride) {
        const T grad = static_cast<T>(g[i]);
        const T scale_grad = one_minus_beta * grad;
        const T scale_grad_abs = one_minus_beta * pow(abs(grad), exponent);

        // Update both running averages and persist them.
        const T avg = beta * exp_avg[i] + scale_grad;
        const T avg_abs = beta * exp_avg_abs[i] + scale_grad_abs;
        exp_avg[i] = avg;
        exp_avg_abs[i] = avg_abs;

        // The step direction deliberately re-applies the decay to the freshly
        // updated averages (same formula as the stored-state read-back).
        const T numer = beta * avg + scale_grad;
        const T denom = pow(beta * avg_abs + scale_grad_abs, inv_exponent) + epsilon;

        const T updated = p[i] * keep - step * (numer / denom);
        p[i] = updated;

        if (p_copy != NULL) p_copy[i] = static_cast<GRAD_T>(updated);
    }
}

/*
 * Applies one SoftSignSGD step to a single fp32 element, in place.
 *
 *   avg     = b * avg     + (1 - b) * grad
 *   avg_abs = b * avg_abs + (1 - b) * |grad|^power
 *   param   = param * keep
 *             - lr * (b * avg + (1 - b) * grad)
 *                  / ((b * avg_abs + (1 - b) * |grad|^power)^inv_power + eps)
 *
 * one_minus_b, keep (= 1 - lr * decay) and inv_power (= 1 / power) are passed
 * in precomputed so callers can hoist them out of their loops.
 */
__device__ __forceinline__ void softsignsgd_update_f32(
    float& param, const float grad, float& avg, float& avg_abs,
    const float b, const float one_minus_b, const float keep, const float lr,
    const float eps, const float power, const float inv_power) {
    const float scale_grad = one_minus_b * grad;
    const float scale_grad_abs = one_minus_b * powf(fabsf(grad), power);

    avg = b * avg + scale_grad;
    avg_abs = b * avg_abs + scale_grad_abs;

    // Single-precision powf throughout: a double exponent here would promote
    // the whole call to double-precision pow.
    const float numer = b * avg + scale_grad;
    const float denom = powf(b * avg_abs + scale_grad_abs, inv_power) + eps;

    param = param * keep - lr * (numer / denom);
}

/*
 * fp32 specialization of the fused SoftSignSGD step using 128-bit (float4)
 * loads and stores.
 *
 * Parameters have the same meaning as in the generic kernel: p, exp_avg and
 * exp_avg_abs are updated in place, g is only read, and p_copy (if not NULL)
 * receives a copy of the updated parameters.
 *
 * Work distribution:
 *   - Groups of four consecutive elements are processed as float4 when all
 *     of p, g, exp_avg and exp_avg_abs are 16-byte aligned.
 *   - The remaining total_size % 4 elements (or every element when any
 *     buffer is not 16-byte aligned) are processed one float at a time, so
 *     no access ever goes past total_size.
 *   - Both phases stride by gridDim.x * blockDim.x, so every element is
 *     covered for any 1-D launch configuration.
 */
template <>
__global__ void softsignsgd_cuda_kernel<float, float>(
    float* __restrict__ p,
    float* __restrict__ p_copy,  // For mixed precision training, pass NULL if
                                  // not needed
    float* __restrict__ g, float* __restrict__ exp_avg, float* __restrict__ exp_avg_abs, 
    const float b, const float lr, const float decay, const float eps, const float power, const size_t total_size){

        const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
        const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

        // Loop-invariant scalars.
        const float one_minus_b = 1.0f - b;
        const float keep = 1.0f - lr * decay;
        const float inv_power = 1.0f / power;

        // float4 accesses require 16-byte alignment; OR-ing the addresses lets
        // one test cover all four buffers.
        const size_t addr_bits = reinterpret_cast<size_t>(p) |
                                 reinterpret_cast<size_t>(g) |
                                 reinterpret_cast<size_t>(exp_avg) |
                                 reinterpret_cast<size_t>(exp_avg_abs);
        const size_t vec_count =
            (addr_bits % sizeof(float4) == 0) ? total_size / 4 : 0;

        float4* p4_ptr = reinterpret_cast<float4*>(p);
        const float4* g4_ptr = reinterpret_cast<const float4*>(g);
        float4* exp_avg4_ptr = reinterpret_cast<float4*>(exp_avg);
        float4* exp_avg_abs4_ptr = reinterpret_cast<float4*>(exp_avg_abs);

        // Vectorized phase: four elements per iteration.
        for (size_t v = first; v < vec_count; v += stride) {
            float4 p4 = p4_ptr[v];
            const float4 g4 = g4_ptr[v];
            float4 avg4 = exp_avg4_ptr[v];
            float4 abs4 = exp_avg_abs4_ptr[v];

            softsignsgd_update_f32(p4.x, g4.x, avg4.x, abs4.x, b, one_minus_b,
                                   keep, lr, eps, power, inv_power);
            softsignsgd_update_f32(p4.y, g4.y, avg4.y, abs4.y, b, one_minus_b,
                                   keep, lr, eps, power, inv_power);
            softsignsgd_update_f32(p4.z, g4.z, avg4.z, abs4.z, b, one_minus_b,
                                   keep, lr, eps, power, inv_power);
            softsignsgd_update_f32(p4.w, g4.w, avg4.w, abs4.w, b, one_minus_b,
                                   keep, lr, eps, power, inv_power);

            // The gradient is unchanged, so it is not written back.
            p4_ptr[v] = p4;
            exp_avg4_ptr[v] = avg4;
            exp_avg_abs4_ptr[v] = abs4;

            if (p_copy != NULL) {
                // Scalar stores: p_copy's alignment is not part of the check above.
                const size_t base = v * 4;
                p_copy[base] = p4.x;
                p_copy[base + 1] = p4.y;
                p_copy[base + 2] = p4.z;
                p_copy[base + 3] = p4.w;
            }
        }

        // Scalar phase: the tail that does not fill a float4, or everything
        // when the vectorized phase was disabled.
        for (size_t i = vec_count * 4 + first; i < total_size; i += stride) {
            float param = p[i];
            float avg = exp_avg[i];
            float avg_abs = exp_avg_abs[i];

            softsignsgd_update_f32(param, g[i], avg, avg_abs, b, one_minus_b,
                                   keep, lr, eps, power, inv_power);

            p[i] = param;
            exp_avg[i] = avg;
            exp_avg_abs[i] = avg_abs;

            if (p_copy != NULL) p_copy[i] = param;
        }
}

/*
 * Host entry point: launches one fused SoftSignSGD step on the current CUDA
 * stream.
 *
 * Parameters:
 *   p            parameter tensor, updated in place
 *   p_copy       optional copy of the updated parameters in the gradient
 *                dtype; pass an empty tensor to skip (only used when the
 *                gradient is half precision)
 *   g            gradient tensor; its dtype selects the kernel
 *   exp_avg      running average of the gradient, updated in place
 *   exp_avg_abs  running average of |g|^power, updated in place
 *   beta, lr, decay, eps, power   scalar hyper-parameters forwarded unchanged
 *
 * The kernels treat every tensor as a flat array of p.numel() elements.
 * Raises (via AT_ASSERTM / AT_CUDA_CHECK) if the parameter tensor cannot be
 * indexed with 32 bits, if tensor sizes disagree, if half gradients are paired
 * with non-float parameters, or if the launch reports an error.
 */
void fused_softsignsgd_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& g, 
          at::Tensor& exp_avg, at::Tensor& exp_avg_abs, 
          float beta, float lr, float decay, float eps, float power){
    // Validate before narrowing anything derived from numel().
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
              "parameter tensor is too large to be indexed with int32");

    const int64_t numel = p.numel();
    AT_ASSERTM(g.numel() == numel && exp_avg.numel() == numel &&
                   exp_avg_abs.numel() == numel,
              "gradient and optimizer state must have as many elements as the parameter");
    AT_ASSERTM(p_copy.numel() == 0 || p_copy.numel() == numel,
              "p_copy must be empty or have as many elements as the parameter");

    // Nothing to update; also avoids launching with a zero-sized grid.
    if (numel == 0) return;

    const size_t total_size = static_cast<size_t>(numel);

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.scalar_type() == at::ScalarType::Half) {
        const int block_dim = 1024;
        // One thread per element (ceil-div, computed in 64 bits).
        const int grid_dim = static_cast<int>((numel + block_dim - 1) / block_dim);
        const dim3 blocks(grid_dim);
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float,
                  "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        DISPATCH_FLOAT_AND_HALF(
            g.scalar_type(), 0, "softsignsgd_cuda_kernel",
            using accscalar_t = at::acc_type<scalar_t_0, true>;
            softsignsgd_cuda_kernel<accscalar_t, scalar_t_0>
            <<<blocks, block_dim, 0, stream>>>(
                p.data_ptr<accscalar_t>(),
                p_copy.numel() ? p_copy.data_ptr<scalar_t_0>() : NULL,
                g.data_ptr<scalar_t_0>(), exp_avg.data_ptr<accscalar_t>(), exp_avg_abs.data_ptr<accscalar_t>(),
                beta, lr, decay, eps, power, total_size
                );
            );
    } else {
        using namespace at;
        const int block_dim = 1024;
        // The <float, float> specialization processes elements in groups of
        // four per thread, while the generic kernel (used for double) handles
        // one element per thread, so the thread count depends on the dtype.
        const int64_t elems_per_thread =
            (g.scalar_type() == at::ScalarType::Float) ? 4 : 1;
        const int64_t threads_needed =
            (numel + elems_per_thread - 1) / elems_per_thread;
        const int grid_dim =
            static_cast<int>((threads_needed + block_dim - 1) / block_dim);
        const dim3 blocks(grid_dim);
        DISPATCH_DOUBLE_AND_FLOAT(
            g.scalar_type(), 0, "softsignsgd_cuda_kernel",
            softsignsgd_cuda_kernel<scalar_t_0, scalar_t_0>
            <<<blocks, block_dim, 0, stream>>>(
                p.data_ptr<scalar_t_0>(),
                NULL,
                g.data_ptr<scalar_t_0>(), exp_avg.data_ptr<scalar_t_0>(), exp_avg_abs.data_ptr<scalar_t_0>(),
                beta, lr, decay, eps, power, total_size
            );
        );
    }
    // Kernel launches do not return errors directly; surface launch failures.
    AT_CUDA_CHECK(cudaGetLastError());
}

