/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  
    \brief Unit tests for epilogues
*/
#pragma once

#include <fstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"
#include "cutlass/complex.h"

#include "cutlass/epilogue/thread/linear_combination.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace kernel {

template <typename Epilogue>
__global__ void epilogue_with_reduction_threadblock(
  typename Epilogue::ElementVector *ptr_Reduction,
  typename Epilogue::OutputTileIterator::Params params_D,
  typename Epilogue::OutputTileIterator::Element *ptr_D,
  typename Epilogue::OutputTileIterator::Params params_C,
  typename Epilogue::OutputTileIterator::Element *ptr_C,
  typename Epilogue::TensorTileIterator::Params params_Tensor,
  typename Epilogue::TensorTileIterator::Element *ptr_Tensor,
  typename Epilogue::OutputOp::Params params_output_op,
  cutlass::MatrixCoord problem_size,
  cutlass::TensorRef<
    typename Epilogue::WarpMmaOperator::ElementC, 
    typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref,
  int epilogue_count = 1) {

  __shared__ typename Epilogue::SharedStorage shared_storage;

  int thread_idx = threadIdx.x;
  int warp_idx = threadIdx.x / 32;
  int lane_idx = threadIdx.x % 32;

  //
  // Construct the epilogue
  //

  // Tile iterator writing to output tile
  typename Epilogue::OutputTileIterator iterator_D(
    params_D,
    ptr_D,
    problem_size,
    thread_idx
  );

  // Tile iterator writing to output tile
  typename Epilogue::OutputTileIterator iterator_C(
    params_C,
    ptr_C,
    problem_size,
    thread_idx
  );

  // Tile iterator writing to output tile
  typename Epilogue::TensorTileIterator iterator_T(
    params_Tensor,
    ptr_Tensor,
    problem_size,
    thread_idx
  );

  // Epilogue operator
  Epilogue epilogue(
    shared_storage, 
    thread_idx, 
    warp_idx, 
    lane_idx);

  //
  // Initialize the accumulators
  //

  int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN);
  int warp_m = warp_mn % Epilogue::WarpCount::kM;
  int warp_n = warp_mn / Epilogue::WarpCount::kM;

  accumulator_ref.add_coord_offset({
    warp_m * Epilogue::WarpMmaOperator::Shape::kM, 
    warp_n * Epilogue::WarpMmaOperator::Shape::kN});

  typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx);
  
  typename Epilogue::AccumulatorTile accumulators;

  accumulators.clear();
  accumulator_iterator.load(accumulators);

#if 0
  // For debugging, enable this block of code to fill each accumulator element with its
  // source thread ID.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < accumulators.size(); ++i) {
    typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x);
    //typename Epilogue::WarpMmaOperator::ElementC x(i);
    accumulators[i] = x;
  }

  /*
  #pragma unroll 1
  for (int tid = 0; tid < 32; ++tid) {
    if (tid == thread_idx) {
      printf("\nT%d: ", thread_idx);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < accumulators.size(); ++i) {
        printf("%d ", int(accumulators[i]));
      }  
    }
  }

  if (thread_idx == 0) {
    printf("\n\n");  
  }
  */

  __syncthreads();

#endif

  //
  // Perform the epilogue operation
  //

  typename Epilogue::OutputOp output_op(params_output_op);

  // Place the epilogue in a loop
  for (int iter = 0; iter < epilogue_count; ++iter) {
    epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T);
  }
}

} // namespace kernel
} // namespace test


/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Epilogue_
>
class EpilogueWithReductionTestbed {
public:

  using Epilogue = Epilogue_;
  using ElementAccumulator = typename Epilogue::ElementAccumulator;
  using ElementCompute = typename Epilogue::OutputOp::ElementCompute;
  using ElementTensor = typename Epilogue::TensorTileIterator::Element;
  using ElementOutput = typename Epilogue::ElementOutput;
  using OutputOpParams = typename Epilogue::OutputOp::Params;

public:

  //
  // Data members
  //

  cutlass::MatrixCoord quantized_size;
  cutlass::HostTensor<ElementAccumulator, cutlass::layout::RowMajor> accumulator_tensor;
  cutlass::HostTensor<ElementOutput, cutlass::layout::RowMajor> source_tensor;
  cutlass::HostTensor<ElementOutput, cutlass::layout::RowMajor> output_tensor;
  cutlass::HostTensor<ElementTensor, cutlass::layout::RowMajor> additional_tensor;
  cutlass::HostTensor<ElementAccumulator, cutlass::layout::RowMajor> reduction_tensor;


public:

  //
  // Methods
  //

  EpilogueWithReductionTestbed(): 
    quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN),
    accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
    source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
    output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
    additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
    reduction_tensor({1, Epilogue::Shape::kN}) {

    //
    // Initialize problem space
    //

    uint64_t seed = 2019;

    cutlass::reference::host::TensorFillRandomUniform(
      accumulator_tensor.host_view(), 
      seed, 
      20, 
      -20, 
      0);

    cutlass::reference::host::TensorFillRandomUniform(
      source_tensor.host_view(),
      seed + 2018, 
      20, 
      -20, 
      0);

    cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1));
  }

  bool run_all() {
   
    /*
    double alpha_values[] = {1, 0, 2.25};
    double beta_values[] = {0, 1, -1.25};

    // Test runtime explodes if we tried to test every case exhaustively. This tests the full
    // output tile and several smaller sizes to stress predication.
    for (int m_idx = 0; m_idx < 3; ++m_idx) {
      for (int n_idx = 0; n_idx < 3; ++n_idx) {

        int m = quantized_size.row() - m_idx * 3;
        int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess;

        for (double const &alpha : alpha_values) {
          for (double const &beta : beta_values) {

            bool passed = run({m, n}, {cutlass::from_real<ElementCompute>(alpha), cutlass::from_real<ElementCompute>(beta)});

            if (!passed) {
              return false;
            }
          }
        }
      }
    }
    return true;
    */

    double alpha = 1;
    double beta = 0;

    return run(
      {quantized_size.row(), quantized_size.column()},
      {cutlass::from_real<ElementCompute>(alpha), cutlass::from_real<ElementCompute>(beta)});
  }

  /// Runs the test
  bool run(
    cutlass::MatrixCoord problem_size,
    OutputOpParams output_params) { 

    //
    // Initialize problem space
    //

    ElementOutput default_output = ElementOutput(-127);
    ElementAccumulator default_reduction = ElementAccumulator();

    cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output);
    cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction);

    accumulator_tensor.sync_device();
    output_tensor.sync_device();
    source_tensor.sync_device();
    additional_tensor.sync_device();
    reduction_tensor.sync_device();

    //
    // Initialize epilogue parameters
    //

    typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout());
    typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout());
    typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout());

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(Epilogue::WarpCount::kCount * 32, 1);

    test::kernel::epilogue_with_reduction_threadblock<Epilogue><<< grid, block >>>(
      reduction_tensor.device_data(),
      params_D,
      output_tensor.device_data(),
      params_C,
      source_tensor.device_data(),
      params_T,
      additional_tensor.device_data(),
      output_params,
      problem_size, 
      accumulator_tensor.device_view());

    cudaError_t result = cudaDeviceSynchronize();

    if (result != cudaSuccess) {
      std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl;
      return false;
    }

    //
    // Verify results
    //
    output_tensor.sync_host();
    reduction_tensor.sync_host();

    int errors = 0;
    int const kMaxErrors = 5;

    //
    // The output has two parts:
    //  - GEMM tensor epilogue in canonical layout
    //  - partial reduction in canonical row-major layout
    //

    // Verify the GEMM tensor output
    for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) {
      for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) {

        cutlass::MatrixCoord coord{r, c};
        ElementOutput got = output_tensor.at(coord);
        
        ElementOutput expected;
        if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) {

          expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + 
            output_params.beta * ElementCompute(source_tensor.at(coord)));
        }
        else {
          expected = default_output;
        }

        if (expected != got) {

          using OutputIO = cutlass::ScalarIO<ElementOutput>;

          EXPECT_TRUE(false)
            << "-------\n"
            << "Error - output element (" << coord << ") - expected: " 
            << OutputIO(expected) 
            << ",  got: " << OutputIO(got) << std::endl;

          ++errors;
        }
      }
    }

    // Verify the partial reduction
    for (int c = 0; c < quantized_size.column(); ++c) {

      ElementAccumulator reduction_acc = ElementAccumulator();

      for (int r = 0; r < quantized_size.row(); ++r) {
        reduction_acc += accumulator_tensor.at({r, c});
      }

      ElementAccumulator expected = default_reduction;
      ElementAccumulator got = reduction_tensor.at({0, c});

      if (c < problem_size.column()) {
        expected = reduction_acc;
      }
      else {
        expected = default_reduction;
      }

      if (expected != got) {
        
        using OutputIO = cutlass::ScalarIO<ElementAccumulator>;

        EXPECT_TRUE(false)
          << "-------\n"
          << "Error - reduction element (" << c << ") - expected: " 
          << OutputIO(expected) 
          << ", got: " << OutputIO(got) << std::endl;
      }
    }

    //
    // Report results on error
    //

    if (errors) {
      std::stringstream ss;
      ss 
        << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" 
        << Epilogue::WarpTileIterator::WarpShape::kM << "x" 
        << Epilogue::WarpTileIterator::WarpShape::kN 
        << "_slice_" << Epilogue::WarpCount::kK << ".csv"; 

      std::ofstream output_file(ss.str()); 
      output_file << output_tensor.host_view(); 
    }

    return !errors;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
