#include <torch/library.h>

#include<torch/all.h>

#include <ATen/autocast_mode.h>
#include <c10/core/DispatchKey.h>

#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>

#include <nvtx3/nvToolsExt.h>

#include "hybrid_sp.h"

#include "perf_instrumentation.h"



// Set to 1 (or build with -DPROFILE_TIME=1) to print cudaEvent-based timings of
// ff_forward_cuda_gated to stdout. Guarded so a build flag can override the
// default instead of clashing with a redefinition.
#ifndef PROFILE_TIME
#define PROFILE_TIME 0
#endif

// Reference-counted handle to the hybrid sparse structure (ELL part + dense
// tail) that is passed between the forward and backward operators.
using HybridSpPtr = c10::intrusive_ptr<hybrid_sp_t>;



// CUDA implementation of sparse_ops::ff_forward_gated.
//
// All work is enqueued on the current CUDA stream of X's device:
//   1. L = einsum("bmn,kn->bmk", X, G) is packed by pack_sparse /
//      create_hybrid_sparse into the hybrid sparse structure P;
//      create_hybrid_sparse also receives the float32 scalar accumulators
//      l0 and l1.
//   2. P's overflow counter is copied asynchronously into P->hN and
//      P->_counter_copy_ev is recorded, so ff_backward_cuda_gated can wait on
//      that single copy instead of this function synchronizing.
//   3. R (a copy of P with its values reset) is filled by
//      new_product_as_sparse_sma from X and K, T is produced by
//      sparse_elementwise(R, P), and out is computed by
//      sparse_dense_gemm_hybrid_dense from T and V.
//
// X must be 3-D; its first two dims are flattened into m rows. n and k are
// taken from K.size(0) and K.size(1). Returns (out, l0, l1, P, R, T) where
// out is bfloat16 with X's shape.
// NOTE(review): input dtypes are not validated here, while the backward pass
// requires bfloat16 X/K/V — presumably these kernels expect bfloat16 too.
std::tuple<at::Tensor, at::Tensor, at::Tensor, HybridSpPtr, HybridSpPtr, HybridSpPtr> ff_forward_cuda_gated(const at::Tensor& X, const at::Tensor& G, const at::Tensor& K,  const at::Tensor& V, int64_t out_size)  {
    // m flattens X's first two dims and the einsum spec needs a rank-3 X;
    // check before starting the timer so a failure leaves no open PERF_START.
    TORCH_CHECK(X.dim() == 3, "X: expected a 3-D tensor but got ", X.dim(), " dims");

    PERF_START("ff_forward_gated_total", 0);

    int m = X.size(0) * X.size(1);
    int n = K.size(0);
    int k = K.size(1);

    auto device = X.get_device();
    auto stream = at::cuda::getCurrentCUDAStream(device);

#if PROFILE_TIME
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
#endif

    auto options_i32 = at::TensorOptions().dtype(torch::kInt32).device(X.device());
    auto options_fp32 = at::TensorOptions().dtype(torch::kFloat).device(X.device());
    auto options_bf16 = at::TensorOptions().dtype(torch::kBFloat16).device(X.device());

    // Scratch buffers consumed by pack_sparse / create_hybrid_sparse.
    at::Tensor P_s = at::empty({out_size}, options_i32);
    at::Tensor N = at::zeros({}, options_i32);

    // Scalar outputs, accumulated by create_hybrid_sparse.
    at::Tensor l0 = at::zeros({}, options_fp32);
    at::Tensor l1 = at::zeros({}, options_fp32);

    at::Tensor L = torch::einsum("bmn,kn->bmk", {X, G});
    pack_sparse(P_s, N, L, m, n, out_size, stream);

    auto P = c10::make_intrusive<hybrid_sp_t>(m, n, P_s.device());
    create_hybrid_sparse(P_s, N, P.get(), l0, l1, m, n, stream);

#if PROFILE_TIME
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("cutlass: %f\n", ms);
    // Restart the timer for the sparse section below.
    cudaEventRecord(start, stream);
#endif

    // Copy the overflow counter to pinned host memory without blocking; the
    // backward pass waits on _counter_copy_ev before reading P->hN.
    cudaError_t err = cudaMemcpyAsync(
        P->hN.data_ptr<int>(),
        P->overflow_counter(),
        sizeof(int),
        cudaMemcpyDeviceToHost,
        stream);
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed: ", cudaGetErrorString(err));

    err = cudaEventCreateWithFlags(&P->_counter_copy_ev, cudaEventDisableTiming);
    TORCH_CHECK(err == cudaSuccess, "cudaEventCreate failed: ", cudaGetErrorString(err));

    err = cudaEventRecord(P->_counter_copy_ev, stream);
    TORCH_CHECK(err == cudaSuccess, "cudaEventRecord failed: ", cudaGetErrorString(err));

    at::Tensor out2 = at::zeros(X.sizes(), options_bf16);

    // R and T start as copies of P (R with its values cleared).
    // NOTE(review): the copies also carry P->_counter_copy_ev; confirm that
    // hybrid_sp_t's destructor does not destroy that event once per copy.
    auto R = c10::make_intrusive<hybrid_sp_t>(*P);
    R->reset_vals();
    at::Tensor acc_init = at::zeros({}, options_fp32);
    auto T = c10::make_intrusive<hybrid_sp_t>(*P);

    // The host has not waited for the overflow-counter copy, so the kernels
    // below process the full tail capacity.
    R->_dense_active_rows = TAIL_CAPACITY_ROWS;
    P->_dense_active_rows = TAIL_CAPACITY_ROWS;
    T->_dense_active_rows = TAIL_CAPACITY_ROWS;

    new_product_as_sparse_sma(R.get(), X, K, acc_init, m, n, k, stream);
    sparse_elementwise(T.get(), R.get(), P.get(), m, n, stream);
    sparse_dense_gemm_hybrid_dense(out2, T.get(), V, m, k, n, false, stream);

#if PROFILE_TIME
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    // Previously printed an undeclared `l0_e`, which broke PROFILE_TIME builds.
    std::cout << "sparse: " << ms << " " << l0.item<float>() << " " << l1.item<float>() << std::endl;
    // Release the timing events (previously leaked on every call).
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
#endif

    PERF_STOP("ff_forward_gated_total");

    return {out2, l0, l1, P, R, T};
}



// CUDA implementation of sparse_ops::ff_backward_gated.
//
// Inputs:
//   X, G, K, V  — the forward's dense inputs (X, K, V, G must be bfloat16).
//   P, R, T     — the HybridSp objects returned by ff_forward_cuda_gated.
//   gR_fp       — gradient of the forward's `out` (bfloat16; made contiguous).
//   gl0, gl1    — gradients of l0 and l1. gl0 is currently unused.
// Returns bfloat16 (dX, dG, dK, dV).
//
// Host synchronization is limited to two one-int device->host copies of
// overflow counters (P's, recorded by the forward pass, and T_t's, recorded
// here). Each count is rounded up to a multiple of 128 and used as
// _dense_active_rows for the kernels that follow.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> ff_backward_cuda_gated(const at::Tensor& X, const at::Tensor& G, const at::Tensor& K, const at::Tensor& V, HybridSpPtr P, HybridSpPtr R, HybridSpPtr T, const at::Tensor& gR_fp,const at::Tensor& gl0,const at::Tensor& gl1) {
    // Validate before starting the timer so a failed check leaves no open
    // PERF_START behind.
    TORCH_CHECK(
      X.scalar_type() == at::kBFloat16,
      "X: Expected a Float (bfloat16) tensor but got ", X.scalar_type());
    // G feeds the dU x G GEMM below, so it must match the other operands.
    TORCH_CHECK(
      G.scalar_type() == at::kBFloat16,
      "G: Expected a Float (bfloat16) tensor but got ", G.scalar_type());
    TORCH_CHECK(
      K.scalar_type() == at::kBFloat16,
      "K: Expected a Float (bfloat16) tensor but got ", K.scalar_type());
    TORCH_CHECK(
      V.scalar_type() == at::kBFloat16,
      "V: Expected a Float (bfloat16) tensor but got ", V.scalar_type());
    TORCH_CHECK(
      gR_fp.scalar_type() == at::kBFloat16,
      "gR: Expected a Float (bfloat16) tensor but got ", gR_fp.scalar_type());

    PERF_START("ff_backward_gated_total", 0);

    int m = X.size(0) * X.size(1);
    int n = K.size(0);
    int k = K.size(1);

    auto device = X.get_device();
    auto stream = at::cuda::getCurrentCUDAStream(device);

    auto gR = gR_fp.contiguous();

    // l0 does not contribute to any gradient computed here.
    (void)gl0;

    // All gradients are produced in bfloat16.
    auto options = at::TensorOptions().dtype(torch::kBFloat16).device(X.device());
    at::Tensor dX_r = at::zeros(X.sizes(), options);
    at::Tensor dX_u = at::zeros(X.sizes(), options);
    at::Tensor dG = at::zeros(G.sizes(), options);
    at::Tensor dK = at::zeros(K.sizes(), options);
    at::Tensor dV = at::zeros({n, k}, options);

    // Turns a host-side tail-row count into a _dense_active_rows value:
    // rounded up to a multiple of 128, then clamped to TAIL_CAPACITY_ROWS
    // (the row count the forward pass processes, assumed to be the tail
    // buffer's capacity — confirm) so narrow() and the kernels stay in bounds.
    auto to_active_rows = [](int count) -> int {
        int64_t rows = (static_cast<int64_t>(count) + 127) & ~int64_t(127);
        if (rows >= TAIL_CAPACITY_ROWS) printf("Exceeding capacity %d!!!\n", static_cast<int>(rows));
        if (rows > TAIL_CAPACITY_ROWS) rows = TAIL_CAPACITY_ROWS;
        return static_cast<int>(rows);
    };

    // Owns the event recorded after the T_t counter copy so it is destroyed
    // on every exit path (it used to be leaked on each call).
    struct ScopedEvent {
        cudaEvent_t handle = nullptr;
        ~ScopedEvent() { if (handle != nullptr) cudaEventDestroy(handle); }
    } ev;

    // Wait only for the overflow-counter copy enqueued by the forward pass.
    cudaError_t err = cudaEventSynchronize(P->_counter_copy_ev);
    TORCH_CHECK(err == cudaSuccess, "cudaEventSynchronize failed: ", cudaGetErrorString(err));

    int v = to_active_rows(*(P->hN.data_ptr<int>()));
    R->_dense_active_rows = v;
    P->_dense_active_rows = v;
    T->_dense_active_rows = v;

    hybrid_sp_t T_t(n, m, X.device());

    PERF_START("transpose1_T", stream);
    transpose_hybrid_dense(*T, T_t, m, n, stream);
    PERF_STOP("transpose1_T");

    // Start copying T_t's overflow counter now; it is only needed for the
    // transposed GEMMs at the end, so other work overlaps with the copy.
    auto hN = at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kInt).pinned_memory(true));
    err = cudaMemcpyAsync(
        hN.data_ptr<int>(),
        T_t.overflow_counter(),
        sizeof(int),
        cudaMemcpyDeviceToHost,
        stream);
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed: ", cudaGetErrorString(err));

    cudaEvent_t created = nullptr;
    err = cudaEventCreateWithFlags(&created, cudaEventDisableTiming);
    TORCH_CHECK(err == cudaSuccess, "cudaEventCreate failed: ", cudaGetErrorString(err));
    ev.handle = created;

    err = cudaEventRecord(ev.handle, stream);
    TORCH_CHECK(err == cudaSuccess, "cudaEventRecord failed: ", cudaGetErrorString(err));

    // dT: same sparsity as T, values from new_product_as_sparse_sma(gR, V).
    hybrid_sp_t dT(*T);
    dT.reset_vals();

    auto options_fp32 = at::TensorOptions().dtype(torch::kFloat).device(X.device());
    at::Tensor acc_init = at::zeros({}, options_fp32);

    PERF_START("as_sparse_dT", stream);
    new_product_as_sparse_sma(&dT, gR, V, acc_init, m, n, k, stream);
    PERF_STOP("as_sparse_dT");

    // dR = sparse_elementwise(dT, P).
    hybrid_sp_t dR(*T);
    dR.reset_vals();

    PERF_START("elemwise_dR", stream);
    sparse_elementwise(&dR, &dT, P.get(), m, n, stream);
    PERF_STOP("elemwise_dR");

    // dU = dT * R + gl1 / m, built directly from the ELL and dense-tail parts.
    hybrid_sp_t dU(*T);
    dU.reset_vals();

    at::Tensor gl1_per_row = gl1 * (1.0f / m);
    dU._ell_values = (dT._ell_values * R->_ell_values + gl1_per_row).to(torch::kBFloat16);
    dU._ell_col_indices = P->_ell_col_indices;
    // Dense tail: same product in float32, then masked to the positions where
    // P's tail is positive.
    dU._tail_dense =
        (((dT._tail_dense.narrow(0, 0, dU._dense_active_rows) *
           R->_tail_dense.narrow(0, 0, R->_dense_active_rows)).to(torch::kFloat) +
          gl1_per_row) *
         (P->_tail_dense.narrow(0, 0, P->_dense_active_rows) > 0)).to(torch::kBFloat16);

    // Transpose dR and dU reusing T_t's tail maps, since all share T's pattern.
    hybrid_sp_t dR_t(n, m, X.device());
    PERF_START("transpose2_dR", stream);
    transpose_hybrid_dense(dR, dR_t, m, n, stream,
                           T_t.tail_dense_map(),
                           T_t.tail_dense_map_reverse());
    PERF_STOP("transpose2_dR");

    hybrid_sp_t dU_t(n, m, X.device());
    PERF_START("transpose3_dU", stream);
    transpose_hybrid_dense(dU, dU_t, m, n, stream,
                           T_t.tail_dense_map(),
                           T_t.tail_dense_map_reverse());
    PERF_STOP("transpose3_dU");

    PERF_START("gemm_dR_K", stream);
    sparse_dense_gemm_hybrid_dense(dX_r, &dR, K, m, k, n, true, stream);
    PERF_STOP("gemm_dR_K");

    PERF_START("gemm_dU_G", stream);
    sparse_dense_gemm_hybrid_dense(dX_u, &dU, G, m, k, n, true, stream);
    PERF_STOP("gemm_dU_G");

    auto dX = dX_r + dX_u;

    // The transposed GEMMs need T_t's tail-row count.
    err = cudaEventSynchronize(ev.handle);
    TORCH_CHECK(err == cudaSuccess, "cudaEventSynchronize failed: ", cudaGetErrorString(err));

    v = to_active_rows(*hN.data_ptr<int>());
    dR_t._dense_active_rows = v;
    dU_t._dense_active_rows = v;
    T_t._dense_active_rows = v;

    PERF_START("gemm_dR_t_X", stream);
    sparse_dense_gemm_hybrid_dense(dK, &dR_t, X, n, k, m, false, stream, X);
    PERF_STOP("gemm_dR_t_X");

    PERF_START("gemm_dU_t_X", stream);
    sparse_dense_gemm_hybrid_dense(dG, &dU_t, X, n, k, m, false, stream, X);
    PERF_STOP("gemm_dU_t_X");

    PERF_START("gemm_T_t_gR", stream);
    sparse_dense_gemm_hybrid_dense(dV, &T_t, gR, n, k, m, false, stream, gR);
    PERF_STOP("gemm_T_t_gR");

    PERF_STOP("ff_backward_gated_total");

    return {dX, dG, dK, dV};
}



// Meta-device kernel for sparse_ops::ff_forward_gated: allocates outputs with
// the shapes and dtypes that ff_forward_cuda_gated returns, without computing.
//   out    — bfloat16, X's shape (the CUDA kernel allocates it in bfloat16;
//            this used to report float32).
//   l0, l1 — float32 scalars.
//   P/R/T  — one placeholder HybridSp sized like the real P:
//            (X.size(0) * X.size(1)) x K.size(0).
std::tuple<at::Tensor, at::Tensor, at::Tensor, HybridSpPtr, HybridSpPtr, HybridSpPtr> ff_forward_meta_gated(const at::Tensor& X, const at::Tensor& G, const at::Tensor& K,  const at::Tensor& V, int64_t out_size)  {
    int m = X.size(0) * X.size(1);
    int n = K.size(0);

    auto options_bf16 = at::TensorOptions().dtype(torch::kBFloat16).device(X.device());
    at::Tensor out = at::empty(X.sizes(), options_bf16);

    auto options_fp32 = at::TensorOptions().dtype(torch::kFloat).device(X.device());
    at::Tensor l0 = at::empty({}, options_fp32);
    at::Tensor l1 = at::empty({}, options_fp32);

    // Same dimensions as the P built by the CUDA kernel (previously used
    // X.size(2) for the column count).
    auto hybrid = c10::make_intrusive<hybrid_sp_t>(m, n, X.device());

    return {out, l0, l1, hybrid, hybrid, hybrid};
}





// Meta-device kernel for sparse_ops::ff_backward_gated: allocates gradients
// matching ff_backward_cuda_gated, which returns bfloat16 tensors (this used to
// report float32). dV mirrors the CUDA kernel's {K.size(0), K.size(1)} shape.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> ff_backward_meta_gated(const at::Tensor& X, const at::Tensor& G, const at::Tensor& K, const at::Tensor& V, HybridSpPtr P, HybridSpPtr R, HybridSpPtr T, const at::Tensor& gR,const at::Tensor& gl0,const at::Tensor& gl1) {
    auto options = at::TensorOptions().dtype(torch::kBFloat16).device(X.device());

    at::Tensor dx = at::empty(X.sizes(), options);
    at::Tensor dg = at::empty(G.sizes(), options);
    at::Tensor dk = at::empty(K.sizes(), options);
    at::Tensor dv = at::empty({K.size(0), K.size(1)}, options);

    return {dx, dg, dk, dv};
}



// Declares the sparse_ops operator namespace.
// hybrid_sp_t is registered as the custom class HybridSp first, because both
// schemas below refer to it as __torch__.torch.classes.sparse_ops.HybridSp.
TORCH_LIBRARY(sparse_ops, m) {
    m.class_<hybrid_sp_t>("HybridSp");

    // Forward: three tensors followed by three HybridSp objects.
    m.def(
        "ff_forward_gated(Tensor X, Tensor G, Tensor K, Tensor V, int out_size) -> "
        "(Tensor, Tensor, Tensor,  "
        "__torch__.torch.classes.sparse_ops.HybridSp, "
        "__torch__.torch.classes.sparse_ops.HybridSp, "
        "__torch__.torch.classes.sparse_ops.HybridSp)");

    // Backward: the dense inputs, the three HybridSp objects and three
    // incoming gradients; returns four tensors.
    m.def(
        "ff_backward_gated(Tensor X, Tensor G, Tensor K, Tensor V, "
        "__torch__.torch.classes.sparse_ops.HybridSp P, "
        "__torch__.torch.classes.sparse_ops.HybridSp R, "
        "__torch__.torch.classes.sparse_ops.HybridSp T, "
        "Tensor gR, Tensor gl0, Tensor gl1) -> (Tensor, Tensor, Tensor, Tensor)");
}



// Backend kernels used when the inputs are CUDA tensors.
TORCH_LIBRARY_IMPL(sparse_ops, CUDA, m) {
    m.impl("ff_backward_gated", ff_backward_cuda_gated);
    m.impl("ff_forward_gated", ff_forward_cuda_gated);
}



// Shape/dtype-only kernels used for Meta-device tensors.
TORCH_LIBRARY_IMPL(sparse_ops, Meta, m) {
    m.impl("ff_backward_gated", ff_backward_meta_gated);
    m.impl("ff_forward_gated", ff_forward_meta_gated);
}






// Autograd wrapper around sparse_ops::ff_forward_gated / ff_backward_gated.
//
// forward(): runs the op below the autograd layer, saves X, G, K, V as tensors
// and the three HybridSp results under the keys "P", "R", "T" in saved_data,
// and exposes only the three tensor outputs to autograd.
// backward(): passes the saved state and the three incoming gradients to
// ff_backward_gated; returns gradients for X, G, K, V and an undefined
// gradient for the integer out_size.
class FFSparseGated : public torch::autograd::Function<FFSparseGated> {
public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& X, const at::Tensor& G, const at::Tensor& K, const at::Tensor& V, int64_t out_size) {
    // Dispatch past the autograd / ADInplaceOrView keys to the backend kernel.
    at::AutoDispatchBelowADInplaceOrView below_autograd;

    static auto forward_handle = torch::Dispatcher::singleton()
      .findSchemaOrThrow("sparse_ops::ff_forward_gated", "")
      .typed<decltype(ff_forward_cuda_gated)>();

    auto [out, l0, l1, P, R, T] = forward_handle.call(X, G, K, V, out_size);

    ctx->save_for_backward({X, G, K, V});
    ctx->saved_data["P"] = P;
    ctx->saved_data["R"] = R;
    ctx->saved_data["T"] = T;

    return {out, l0, l1};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    auto saved = ctx->get_saved_variables();

    static auto backward_handle = torch::Dispatcher::singleton()
      .findSchemaOrThrow("sparse_ops::ff_backward_gated", "")
      .typed<decltype(ff_backward_cuda_gated)>();

    HybridSpPtr P = ctx->saved_data["P"].toCustomClass<hybrid_sp_t>();
    HybridSpPtr R = ctx->saved_data["R"].toCustomClass<hybrid_sp_t>();
    HybridSpPtr T = ctx->saved_data["T"].toCustomClass<hybrid_sp_t>();

    auto [dX, dG, dK, dV] = backward_handle.call(
        saved[0], saved[1], saved[2], saved[3],
        P, R, T,
        grad_output[0], grad_output[1], grad_output[2]);

    // out_size is an int argument and has no gradient.
    return {dX, dG, dK, dV, at::Tensor()};
  }
};





// AutogradCUDA kernel for sparse_ops::ff_forward_gated: routes the call through
// FFSparseGated so that backward is recorded. The real P/R/T live in the
// autograd context; the three HybridSp outputs returned here are one shared,
// default-constructed placeholder.
std::tuple<at::Tensor, at::Tensor, at::Tensor, HybridSpPtr, HybridSpPtr, HybridSpPtr> ff_forward_autograd_gated(const at::Tensor& X, const at::Tensor& G, const at::Tensor& K,  const at::Tensor& V, int64_t out_size) {
    const auto outs = FFSparseGated::apply(X, G, K, V, out_size);
    HybridSpPtr placeholder = c10::make_intrusive<hybrid_sp_t>();
    return std::make_tuple(outs[0], outs[1], outs[2], placeholder, placeholder, placeholder);
}



// Autograd kernel for the forward op only; ff_backward_gated is invoked from
// FFSparseGated::backward rather than differentiated itself.
TORCH_LIBRARY_IMPL(sparse_ops, AutogradCUDA, m) {
    m.impl("ff_forward_gated", ff_forward_autograd_gated);
}



// AutocastCUDA kernel for sparse_ops::ff_forward_gated.
//
// Casts X, G, K, V to bfloat16 through autocast's cast cache and redispatches
// below the autocast layer. The target dtype is fixed to bfloat16 instead of
// at::autocast::get_autocast_dtype(at::kCUDA), for two reasons:
//   - ff_forward_cuda_gated always produces a bfloat16 output.
//   - ff_backward_cuda_gated rejects any other dtype for X, K and V.
// Without the fixed dtype, a float16 autocast region (the CUDA autocast
// default) failed in the backward pass.
std::tuple<at::Tensor, at::Tensor, at::Tensor, HybridSpPtr, HybridSpPtr, HybridSpPtr> ff_forward_ac_gated(c10::DispatchKeySet ks, const at::Tensor& X, const at::Tensor& G, const at::Tensor& K,  const at::Tensor& V, int64_t out_size)  {
    // Keep autocast off for everything dispatched from inside this kernel...
    c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::Autocast);
    // ...and drop the autocast key from the keyset used for the redispatch.
    c10::DispatchKeySet modified_ks = ks.remove(c10::DispatchKey::AutocastCUDA);

    constexpr at::ScalarType kKernelDtype = at::kBFloat16;

    auto Xc = at::autocast::cached_cast(kKernelDtype, X);
    auto Gc = at::autocast::cached_cast(kKernelDtype, G);
    auto Kc = at::autocast::cached_cast(kKernelDtype, K);
    auto Vc = at::autocast::cached_cast(kKernelDtype, V);

    static auto op = torch::Dispatcher::singleton()
      .findSchemaOrThrow("sparse_ops::ff_forward_gated", "")
      .typed<decltype(ff_forward_cuda_gated)>();

    return op.redispatch(modified_ks, Xc, Gc, Kc, Vc, out_size);
}



// Autocast handling for the forward op (inputs are cast, then redispatched).
TORCH_LIBRARY_IMPL(sparse_ops, AutocastCUDA, m) {
    m.impl("ff_forward_gated", ff_forward_ac_gated);
}




// Invokes PERF_REPORT() from perf_instrumentation.h, bracketed by debug lines
// on stderr. The first line states whether ENABLE_PERF_PROFILING was nonzero
// at compile time.
void print_perf_report() {
#if ENABLE_PERF_PROFILING
    static const char kProfilingState[] = "[DEBUG] ENABLE_PERF_PROFILING is 1, profiling is enabled\n";
#else
    static const char kProfilingState[] = "[DEBUG] ENABLE_PERF_PROFILING is 0, profiling is DISABLED\n";
#endif
    fprintf(stderr, "%s", kProfilingState);
    fprintf(stderr, "[DEBUG] Calling PERF_REPORT()\n");
    PERF_REPORT();
    fprintf(stderr, "[DEBUG] PERF_REPORT() returned\n");
}



// Delegates to PERF_RESET() from perf_instrumentation.h (presumably clears the
// statistics gathered by the PERF_* macros — confirm in that header).
void reset_perf_stats() { PERF_RESET(); }



// Schemas for the profiling helpers; neither takes arguments nor returns a value.
TORCH_LIBRARY(sparse_ops_perf, m) {
    m.def("reset_stats() -> ()");
    m.def("print_report() -> ()");
}



// Backend-agnostic implementations of the profiling helpers.
TORCH_LIBRARY_IMPL(sparse_ops_perf, CompositeExplicitAutograd, m) {
    m.impl("reset_stats", reset_perf_stats);
    m.impl("print_report", print_perf_report);
}



