// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_H

#include "./InternalHeaderCheck.h"

namespace Eigen {

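/** \class TensorConvolution
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor convolution class.
  *
  * Convolves an input tensor with a kernel along a chosen set of dimensions.
  * No padding is applied: along each convolved dimension the result has
  * input_dim - kernel_dim + 1 entries. The kernel is applied without being
  * flipped, so mathematically the operation is a cross-correlation.
  */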
namespace internal {

template <typename Index, typename InputDims, int NumKernelDims, int Layout>
class IndexMapper {
 public:
  IndexMapper(const InputDims& input_dims, const array<Index, NumKernelDims>& kernel_dims,
              const array<Index, NumKernelDims>& indices) {

    array<Index, NumDims> dimensions = input_dims;
    for (int i = 0; i < NumKernelDims; ++i) {
      const Index index = indices[i];
      const Index input_dim = input_dims[index];
      const Index kernel_dim = kernel_dims[i];
      const Index result_dim = input_dim - kernel_dim + 1;
      dimensions[index] = result_dim;
    }

    array<Index, NumDims> inputStrides;
    array<Index, NumDims> outputStrides;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      inputStrides[0] = 1;
      outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        inputStrides[i] = inputStrides[i-1] * input_dims[i-1];
        outputStrides[i] = outputStrides[i-1] * dimensions[i-1];
      }
    } else {
      inputStrides[NumDims - 1] = 1;
      outputStrides[NumDims - 1] = 1;
      for (int i = static_cast<int>(NumDims) - 2; i >= 0; --i) {
        inputStrides[i] = inputStrides[i + 1] * input_dims[i + 1];
        outputStrides[i] = outputStrides[i + 1] * dimensions[i + 1];
      }
    }

    array<Index, NumDims> gpuInputDimensions;
    array<Index, NumDims> gpuOutputDimensions;
    array<Index, NumDims> tmp = dimensions;
    array<Index, NumDims> ordering;
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    for (int i = 0; i < NumKernelDims; ++i) {
      const Index index = i + offset;
      ordering[index] = indices[i];
      tmp[indices[i]] = -1;
      gpuInputDimensions[index] = input_dims[indices[i]];
      gpuOutputDimensions[index] = dimensions[indices[i]];
    }

    int written = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                      ? NumKernelDims
                      : 0;
    for (int i = 0; i < NumDims; ++i) {
      if (tmp[i] >= 0) {
        ordering[written] = i;
        gpuInputDimensions[written] = input_dims[i];
        gpuOutputDimensions[written] = dimensions[i];
        ++written;
      }
    }

    for (int i = 0; i < NumDims; ++i) {
      m_inputStrides[i] = inputStrides[ordering[i]];
      m_outputStrides[i] = outputStrides[ordering[i]];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumDims; ++i) {
        if (i > NumKernelDims) {
          m_gpuInputStrides[i] =
              m_gpuInputStrides[i - 1] * gpuInputDimensions[i - 1];
          m_gpuOutputStrides[i] =
              m_gpuOutputStrides[i - 1] * gpuOutputDimensions[i - 1];
        } else {
          m_gpuInputStrides[i] = 1;
          m_gpuOutputStrides[i] = 1;
        }
      }
    } else {
      for (int i = NumDims - 1; i >= 0; --i) {
        if (static_cast<size_t>(i + 1) < offset) {
          m_gpuInputStrides[i] =
              m_gpuInputStrides[i + 1] * gpuInputDimensions[i + 1];
          m_gpuOutputStrides[i] =
              m_gpuOutputStrides[i + 1] * gpuOutputDimensions[i + 1];
        } else {
          m_gpuInputStrides[i] = 1;
          m_gpuOutputStrides[i] = 1;
        }
      }
    }
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuInputPlaneToTensorInputOffset(Index p) const {
    Index inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int d = NumDims - 1; d > NumKernelDims; --d) {
        const Index idx = p / m_gpuInputStrides[d];
        inputIndex += idx * m_inputStrides[d];
        p -= idx * m_gpuInputStrides[d];
      }
      inputIndex += p * m_inputStrides[NumKernelDims];
    } else {
      std::ptrdiff_t limit = 0;
      if (NumKernelDims < NumDims) {
        limit = NumDims - NumKernelDims - 1;
      }
      for (int d = 0; d < limit; ++d) {
        const Index idx = p / m_gpuInputStrides[d];
        inputIndex += idx * m_inputStrides[d];
        p -= idx * m_gpuInputStrides[d];
      }
      inputIndex += p * m_inputStrides[limit];
    }
    return inputIndex;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuOutputPlaneToTensorOutputOffset(Index p) const {
    Index outputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int d = NumDims - 1; d > NumKernelDims; --d) {
        const Index idx = p / m_gpuOutputStrides[d];
        outputIndex += idx * m_outputStrides[d];
        p -= idx * m_gpuOutputStrides[d];
      }
      outputIndex += p * m_outputStrides[NumKernelDims];
    } else {
      std::ptrdiff_t limit = 0;
      if (NumKernelDims < NumDims) {
        limit = NumDims - NumKernelDims - 1;
      }
      for (int d = 0; d < limit; ++d) {
        const Index idx = p / m_gpuOutputStrides[d];
        outputIndex += idx * m_outputStrides[d];
        p -= idx * m_gpuOutputStrides[d];
      }
      outputIndex += p * m_outputStrides[limit];
    }
    return outputIndex;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuInputKernelToTensorInputOffset(Index i) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_inputStrides[offset];
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuOutputKernelToTensorOutputOffset(Index i) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_outputStrides[offset];
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuInputKernelToTensorInputOffset(Index i, Index j) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_inputStrides[offset] + j * m_inputStrides[offset + 1];
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuOutputKernelToTensorOutputOffset(Index i, Index j) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_outputStrides[offset] + j * m_outputStrides[offset + 1];
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuInputKernelToTensorInputOffset(Index i, Index j, Index k) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_inputStrides[offset] + j * m_inputStrides[offset + 1] +
           k * m_inputStrides[offset + 2];
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Index mapGpuOutputKernelToTensorOutputOffset(Index i, Index j, Index k) const {
    const size_t offset = static_cast<int>(Layout) == static_cast<int>(ColMajor)
                              ? 0
                              : NumDims - NumKernelDims;
    return i * m_outputStrides[offset] + j * m_outputStrides[offset + 1] +
           k * m_outputStrides[offset + 2];
  }

 private:
  static const int NumDims = internal::array_size<InputDims>::value;
  array<Index, NumDims> m_inputStrides;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_gpuInputStrides;
  array<Index, NumDims> m_gpuOutputStrides;
};



template<typename Dimensions, typename InputXprType, typename KernelXprType>
struct traits<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> >
{
  // The scalar, storage kind and index type are promoted from those of the
  // input and the kernel, so mixed-type operands can be combined.
  using Scalar = typename promote_storage_type<typename InputXprType::Scalar,
                                               typename KernelXprType::Scalar>::ret;
  using StorageKind = typename promote_storage_type<typename traits<InputXprType>::StorageKind,
                                                    typename traits<KernelXprType>::StorageKind>::ret;
  using Index = typename promote_index_type<typename traits<InputXprType>::Index,
                                            typename traits<KernelXprType>::Index>::type;

  // How each operand is stored inside the expression, with and without reference.
  using LhsNested = typename InputXprType::Nested;
  using RhsNested = typename KernelXprType::Nested;
  using _LhsNested = typename remove_reference<LhsNested>::type;
  using _RhsNested = typename remove_reference<RhsNested>::type;

  // Rank and layout come from the input expression.
  static const int NumDimensions = traits<InputXprType>::NumDimensions;
  static const int Layout = traits<InputXprType>::Layout;

  // Pointer type of the input if Pointer_type_promotion selects it for the
  // promoted scalar, otherwise the pointer type of the kernel.
  using PointerType = typename conditional<
      Pointer_type_promotion<typename InputXprType::Scalar, Scalar>::val,
      typename traits<InputXprType>::PointerType,
      typename traits<KernelXprType>::PointerType>::type;

  enum { Flags = 0 };
};

template<typename Dimensions, typename InputXprType, typename KernelXprType>
struct eval<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType>, Eigen::Dense>
{
  // The evaluated form of a convolution expression is a const reference to
  // the expression itself.
  using type = const TensorConvolutionOp<Dimensions, InputXprType, KernelXprType>&;
};

template<typename Dimensions, typename InputXprType, typename KernelXprType>
struct nested<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType>, 1, typename eval<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> >::type>
{
  // When nested in another expression, the convolution op is held by value
  // (the nested type is the op type itself, not a reference).
  using type = TensorConvolutionOp<Dimensions, InputXprType, KernelXprType>;
};

}  // end namespace internal



/** \class TensorConvolutionOp
  * Expression node holding an input expression, a kernel expression and the
  * list of input dimensions the kernel is applied along. Evaluation is done by
  * the matching TensorEvaluator specializations.
  */
template<typename Indices, typename InputXprType, typename KernelXprType>
class TensorConvolutionOp : public TensorBase<TensorConvolutionOp<Indices, InputXprType, KernelXprType>, ReadOnlyAccessors>
{
 public:
  using Scalar = typename Eigen::internal::traits<TensorConvolutionOp>::Scalar;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
  using CoeffReturnType = typename internal::promote_storage_type<typename InputXprType::CoeffReturnType,
                                                                 typename KernelXprType::CoeffReturnType>::ret;
  using Nested = typename Eigen::internal::nested<TensorConvolutionOp>::type;
  using StorageKind = typename Eigen::internal::traits<TensorConvolutionOp>::StorageKind;
  using Index = typename Eigen::internal::traits<TensorConvolutionOp>::Index;

  /** \param input  expression to convolve
    * \param kernel convolution kernel expression
    * \param dims   input dimensions the kernel dimensions are applied along
    */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConvolutionOp(const InputXprType& input, const KernelXprType& kernel, const Indices& dims)
      : m_input_xpr(input), m_kernel_xpr(kernel), m_indices(dims) {}

  /** \returns the input dimensions along which the convolution is performed */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const Indices& indices() const { return m_indices; }

  /** \returns the nested input expression */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const typename internal::remove_all<typename InputXprType::Nested>::type&
  inputExpression() const { return m_input_xpr; }

  /** \returns the nested kernel expression */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const typename internal::remove_all<typename KernelXprType::Nested>::type&
  kernelExpression() const { return m_kernel_xpr; }

 protected:
  typename InputXprType::Nested m_input_xpr;
  typename KernelXprType::Nested m_kernel_xpr;
  const Indices m_indices;
};


template<typename Indices, typename InputArgType, typename KernelArgType, typename Device>
struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Device>
{
  typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;

  static const int NumDims = internal::array_size<typename TensorEvaluator<InputArgType, Device>::Dimensions>::value;
  static const int NumKernelDims = internal::array_size<Indices>::value;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumDims> Dimensions;

  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef StorageMemory<Scalar, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = int(TensorEvaluator<InputArgType, Device>::IsAligned) & int(TensorEvaluator<KernelArgType, Device>::IsAligned),
    PacketAccess = int(TensorEvaluator<InputArgType, Device>::PacketAccess) & int(TensorEvaluator<KernelArgType, Device>::PacketAccess),
    BlockAccess = false,
    PreferBlockAccess = false,
    Layout = TensorEvaluator<InputArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_kernelArg(op.kernelExpression()), m_kernel(NULL), m_local_kernel(false), m_device(device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<InputArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<KernelArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);

    const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
    const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions();

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStride[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStride[i] = m_inputStride[i - 1] * input_dims[i - 1];
      }
    } else {
      m_inputStride[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_inputStride[i] = m_inputStride[i + 1] * input_dims[i + 1];
      }
    }

    m_dimensions = m_inputImpl.dimensions();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumKernelDims; ++i) {
        const Index index = op.indices()[i];
        const Index input_dim = input_dims[index];
        const Index kernel_dim = kernel_dims[i];
        const Index result_dim = input_dim - kernel_dim + 1;
        m_dimensions[index] = result_dim;
        if (i > 0) {
          m_kernelStride[i] = m_kernelStride[i - 1] * kernel_dims[i - 1];
        } else {
          m_kernelStride[0] = 1;
        }
        m_indexStride[i] = m_inputStride[index];
      }

      m_outputStride[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStride[i] = m_outputStride[i - 1] * m_dimensions[i - 1];
      }
    } else {
      for (int i = NumKernelDims - 1; i >= 0; --i) {
        const Index index = op.indices()[i];
        const Index input_dim = input_dims[index];
        const Index kernel_dim = kernel_dims[i];
        const Index result_dim = input_dim - kernel_dim + 1;
        m_dimensions[index] = result_dim;
        if (i < NumKernelDims - 1) {
          m_kernelStride[i] = m_kernelStride[i + 1] * kernel_dims[i + 1];
        } else {
          m_kernelStride[NumKernelDims - 1] = 1;
        }
        m_indexStride[i] = m_inputStride[index];
      }

      m_outputStride[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_outputStride[i] = m_outputStride[i + 1] * m_dimensions[i + 1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    m_inputImpl.evalSubExprsIfNeeded(NULL);
    preloadKernel();
    return true;
  }
  EIGEN_STRONG_INLINE void cleanup() {
    m_inputImpl.cleanup();
    if (m_local_kernel) {
      m_device.deallocate((void*)m_kernel);
      m_local_kernel = false;
    }
    m_kernel = NULL;
  }

  void evalTo(typename XprType::Scalar* buffer) {
    evalSubExprsIfNeeded(NULL);
    for (int i = 0; i < dimensions().TotalSize(); ++i) {
      buffer[i] += coeff(i);
    }
    cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    CoeffReturnType result = CoeffReturnType(0);
    convolve(firstInput(index), 0, NumKernelDims-1, result);
    return result;
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(const Index index) const
  {
    Index indices[2] = {index, index+PacketSize-1};
    Index startInputs[2] = {0, 0};
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        const Index idx0 = indices[0] / m_outputStride[i];
        const Index idx1 = indices[1] / m_outputStride[i];
        startInputs[0] += idx0 * m_inputStride[i];
        startInputs[1] += idx1 * m_inputStride[i];
        indices[0] -= idx0 * m_outputStride[i];
        indices[1] -= idx1 * m_outputStride[i];
      }
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        const Index idx0 = indices[0] / m_outputStride[i];
        const Index idx1 = indices[1] / m_outputStride[i];
        startInputs[0] += idx0 * m_inputStride[i];
        startInputs[1] += idx1 * m_inputStride[i];
        indices[0] -= idx0 * m_outputStride[i];
        indices[1] -= idx1 * m_outputStride[i];
      }
    }
    startInputs[0] += indices[0];
    startInputs[1] += indices[1];

    if (startInputs[1]-startInputs[0] == PacketSize-1) {
      PacketReturnType result = internal::pset1<PacketReturnType>(0);
      convolvePacket(startInputs[0], 0, NumKernelDims-1, result);
      return result;
    } else {
      EIGEN_ALIGN_MAX Scalar data[PacketSize];
      data[0] = Scalar(0);
      convolve(startInputs[0], 0, NumKernelDims-1, data[0]);
      for (int i = 1; i < PacketSize-1; ++i) {
        data[i] = Scalar(0);
        convolve(firstInput(index+i), 0, NumKernelDims-1, data[i]);
      }
      data[PacketSize-1] = Scalar(0);
      convolve(startInputs[1], 0, NumKernelDims-1, data[PacketSize-1]);
      return internal::pload<PacketReturnType>(data);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double kernel_size = m_kernelImpl.dimensions().TotalSize();
    // We ignore the use of fused multiply-add.
    const double convolve_compute_cost =
        TensorOpCost::AddCost<Scalar>() + TensorOpCost::MulCost<Scalar>();
    const double firstIndex_compute_cost =
        NumDims *
        (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
         TensorOpCost::DivCost<Index>());
    return TensorOpCost(0, 0, firstIndex_compute_cost, vectorized, PacketSize) +
           kernel_size * (m_inputImpl.costPerCoeff(vectorized) +
                          m_kernelImpl.costPerCoeff(vectorized) +
                          TensorOpCost(0, 0, convolve_compute_cost, vectorized,
                                       PacketSize));
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

 private:
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
    Index startInput = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        const Index idx = index / m_outputStride[i];
        startInput += idx * m_inputStride[i];
        index -= idx * m_outputStride[i];
      }
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        const Index idx = index / m_outputStride[i];
        startInput += idx * m_inputStride[i];
        index -= idx * m_outputStride[i];
      }
    }
    startInput += index;
    return startInput;
  }

  EIGEN_DEVICE_FUNC void convolve(Index firstIndex, Index firstKernel, int DimIndex, CoeffReturnType& accum) const {
    for (int j = 0; j < m_kernelImpl.dimensions()[DimIndex]; ++j) {
      const Index input = firstIndex + j * m_indexStride[DimIndex];
      const Index kernel = firstKernel + j * m_kernelStride[DimIndex];
      if (DimIndex > 0) {
        convolve(input, kernel, DimIndex-1, accum);
      } else {
        accum += m_inputImpl.coeff(input) * m_kernel[kernel];
      }
    }
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC void convolvePacket(Index firstIndex, Index firstKernel, int DimIndex, Packet& accum) const {
    for (int j = 0; j < m_kernelImpl.dimensions()[DimIndex]; ++j) {
      const Index input = firstIndex + j * m_indexStride[DimIndex];
      const Index kernel = firstKernel + j * m_kernelStride[DimIndex];
      if (DimIndex > 0) {
        convolvePacket(input, kernel, DimIndex-1, accum);
      } else {
        accum = internal::pmadd<Packet>(m_inputImpl.template packet<Unaligned>(input), internal::pset1<Packet>(m_kernel[kernel]), accum);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void preloadKernel() {
    // Don't make a local copy of the kernel unless we have to (i.e. it's an
    // expression that needs to be evaluated)
    const Scalar* in_place = m_kernelImpl.data();
    if (in_place) {
      m_kernel = in_place;
      m_local_kernel = false;
    } else {
      size_t kernel_sz = m_kernelImpl.dimensions().TotalSize() * sizeof(Scalar);
      Scalar* local = (Scalar*)m_device.allocate_temp(kernel_sz);
      typedef TensorEvalToOp<const KernelArgType> EvalTo;
      EvalTo evalToTmp(local, m_kernelArg);
      const bool Vectorize = internal::IsVectorizable<Device, KernelArgType>::value;
      internal::TensorExecutor<const EvalTo, Device, Vectorize>::run(evalToTmp, m_device);

      m_kernel = local;
      m_local_kernel = true;
    }
  }

  array<Index, NumDims> m_inputStride;
  array<Index, NumDims> m_outputStride;

  array<Index, NumKernelDims> m_indexStride;
  array<Index, NumKernelDims> m_kernelStride;
  TensorEvaluator<InputArgType, Device> m_inputImpl;
  TensorEvaluator<KernelArgType, Device> m_kernelImpl;
  Dimensions m_dimensions;

  KernelArgType m_kernelArg;
  const Scalar* m_kernel;
  bool m_local_kernel;
  const Device EIGEN_DEVICE_REF m_device;
};




// Use an optimized implementation of the evaluation code for GPUs whenever possible.
#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

/** \internal
  * Kernel length used by the GPU convolution kernels. The primary template
  * returns the compile-time length StaticKernelSize and ignores its argument.
  */
template <int StaticKernelSize>
struct GetKernelSize {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(const int) const {
    return StaticKernelSize;
  }
};

/** \internal
  * Dynamic case: the length is only known at run time, so the argument is
  * returned unchanged.
  */
template <>
struct GetKernelSize<Dynamic> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(const int runtimeSize) const {
    return runtimeSize;
  }
};

/** \internal
  * GPU kernel computing a convolution along one dimension (float data).
  *
  * Each block handles up to \p maxX consecutive output positions along the
  * convolved dimension; blockIdx.x selects the chunk. Each threadIdx.y row of
  * the block owns a slice of num_x_input floats in dynamic shared memory. That
  * slice holds the input window (outputs plus kernel halo) of the plane the
  * row is currently processing. StaticKernelSize is the compile-time kernel
  * length, or Dynamic to use \p kernelSize.
  */
template <typename InputEvaluator, typename Index, typename InputDims,
          int StaticKernelSize>
__global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void EigenConvolutionKernel1D(
    InputEvaluator eval,
    const internal::IndexMapper<Index, InputDims, 1, InputEvaluator::Layout>
        indexMapper,
    const float* __restrict kernel, const int numPlanes, const int numX,
    const int maxX, const int kernelSize, float* buffer) {
#if defined(EIGEN_HIPCC)
  HIP_DYNAMIC_SHARED(float, s)
#else
  extern __shared__ float s[];
#endif

  const int first_x = blockIdx.x * maxX;
  const int last_x = (first_x + maxX < numX ? first_x + maxX : numX) - 1;
  const int num_x_input = last_x - first_x + GetKernelSize<StaticKernelSize>()(kernelSize);
  const int num_x_output = last_x - first_x + 1;

  const int first_plane = blockIdx.y * blockDim.y;
  const int plane_stride = blockDim.y * gridDim.y;
  // Shared-memory slice owned by this thread row.
  const int plane_kernel_offset = threadIdx.y * num_x_input;

  // The loop bound depends only on block-uniform values, so every thread of
  // the block runs the same number of iterations and reaches both
  // __syncthreads() calls. Threads whose plane is out of range stay idle
  // instead of leaving the loop early, which would make the barrier divergent.
  for (int plane_base = first_plane; plane_base < numPlanes; plane_base += plane_stride) {
    const int p = plane_base + threadIdx.y;
    const bool valid_plane = p < numPlanes;

    // Load inputs to shared memory
    if (valid_plane) {
      const int plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(p);
      #pragma unroll
      for (int i = threadIdx.x; i < num_x_input; i += blockDim.x) {
        const int tensor_index = plane_input_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(i+first_x);
        s[i + plane_kernel_offset] = eval.coeff(tensor_index);
      }
    }

    __syncthreads();

    // Compute the convolution
    if (valid_plane) {
      const int plane_output_offset = indexMapper.mapGpuOutputPlaneToTensorOutputOffset(p);

      #pragma unroll
      for (int i = threadIdx.x; i < num_x_output; i += blockDim.x) {
        const int kernel_offset = plane_kernel_offset + i;
        float result = 0.0f;
        #pragma unroll
        for (int k = 0; k < GetKernelSize<StaticKernelSize>()(kernelSize); ++k) {
          result += s[k + kernel_offset] * kernel[k];
        }
        const int tensor_index = plane_output_offset + indexMapper.mapGpuOutputKernelToTensorOutputOffset(i+first_x);
        buffer[tensor_index] = result;
      }
    }

    // Keep the shared slice intact until every thread is done reading it.
    __syncthreads();
  }
}

/** \internal
  * GPU kernel computing a convolution along two dimensions (float data).
  *
  * blockIdx.x / blockIdx.y select a tile of at most maxX x maxY output
  * positions along the two convolved dimensions. Each threadIdx.z slice of the
  * block owns num_y_input rows of num_x_input floats in dynamic shared memory.
  * Those rows hold the input window (tile plus kernel halo) of the plane the
  * slice is processing. The kernel is read as kernel[k + kernelSizeX * l],
  * i.e. x varies fastest. StaticKernelSizeX/Y are compile-time kernel extents,
  * or Dynamic to use the runtime values.
  */
template <typename InputEvaluator, typename Index, typename InputDims,
          int StaticKernelSizeX, int StaticKernelSizeY>
__global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void EigenConvolutionKernel2D(
    InputEvaluator eval,
    const internal::IndexMapper<Index, InputDims, 2, InputEvaluator::Layout>
        indexMapper,
    const float* __restrict kernel, const int numPlanes, const int numX,
    const int maxX, const int numY, const int maxY, const int kernelSizeX,
    const int kernelSizeY, float* buffer) {
#if defined(EIGEN_HIPCC)
  HIP_DYNAMIC_SHARED(float, s)
#else
  extern __shared__ float s[];
#endif

  const int first_x = blockIdx.x * maxX;
  const int last_x = (first_x + maxX < numX ? first_x + maxX : numX) - 1;
  const int num_x_input = last_x - first_x + GetKernelSize<StaticKernelSizeX>()(kernelSizeX);
  const int num_x_output = last_x - first_x + 1;

  const int first_y = blockIdx.y * maxY;
  const int last_y = (first_y + maxY < numY ? first_y + maxY : numY) - 1;
  const int num_y_input = last_y - first_y + GetKernelSize<StaticKernelSizeY>()(kernelSizeY);
  const int num_y_output = last_y - first_y + 1;

  const int first_plane = blockIdx.z * blockDim.z;
  const int plane_stride = blockDim.z * gridDim.z;
  // First shared-memory row owned by this threadIdx.z slice.
  const int plane_kernel_offset = threadIdx.z * num_y_input;

  // The loop bound depends only on block-uniform values, so every thread of
  // the block runs the same number of iterations and reaches both
  // __syncthreads() calls. Threads whose plane is out of range stay idle
  // instead of leaving the loop early, which would make the barrier divergent.
  for (int plane_base = first_plane; plane_base < numPlanes; plane_base += plane_stride) {
    const int p = plane_base + threadIdx.z;
    const bool valid_plane = p < numPlanes;

    // Load inputs to shared memory
    if (valid_plane) {
      const int plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(p);
      #pragma unroll
      for (int j = threadIdx.y; j < num_y_input; j += blockDim.y) {
        const int input_offset = num_x_input * (j + plane_kernel_offset);
        #pragma unroll
        for (int i = threadIdx.x; i < num_x_input; i += blockDim.x) {
          const int tensor_index = plane_input_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(i+first_x, j+first_y);
          s[i + input_offset] = eval.coeff(tensor_index);
        }
      }
    }

    __syncthreads();

    // Convolution
    if (valid_plane) {
      const int plane_output_offset = indexMapper.mapGpuOutputPlaneToTensorOutputOffset(p);

      #pragma unroll
      for (int j = threadIdx.y; j < num_y_output; j += blockDim.y) {
        #pragma unroll
        for (int i = threadIdx.x; i < num_x_output; i += blockDim.x) {
          float result = 0.0f;
          #pragma unroll
          for (int l = 0; l < GetKernelSize<StaticKernelSizeY>()(kernelSizeY); ++l) {
            const int kernel_offset = kernelSizeX * l;
            const int input_offset = i + num_x_input * (j + l + plane_kernel_offset);
            #pragma unroll
            for (int k = 0; k < GetKernelSize<StaticKernelSizeX>()(kernelSizeX); ++k) {
              result += s[k + input_offset] * kernel[k + kernel_offset];
            }
          }
          const int tensor_index = plane_output_offset + indexMapper.mapGpuOutputKernelToTensorOutputOffset(i+first_x, j+first_y);
          buffer[tensor_index] = result;
        }
      }
    }

    // Keep the shared rows intact until every thread is done reading them.
    __syncthreads();
  }
}

/** \internal
  * GPU kernel computing a convolution along three dimensions (float data).
  *
  * blockIdx.x / blockIdx.y / blockIdx.z select a tile of at most
  * maxX x maxY x maxZ output positions along the three convolved dimensions.
  * Every block walks all numPlanes planes itself: the plane loop starts at 0
  * with step 1 for every thread, so the __syncthreads() calls are reached
  * uniformly. For each plane, the input window (tile plus kernel halo) is
  * staged in dynamic shared memory as
  * s[i + num_x_input * (j + num_y_input * k)] and then reduced against the
  * kernel, which is read as kernel[l + kernelSizeX * (m + kernelSizeY * n)]
  * (x varies fastest).
  * NOTE(review): assumes the launcher sizes the dynamic shared memory to hold
  * num_x_input * num_y_input * num_z_input floats; the host code is not
  * visible here, so confirm there.
  * NOTE(review): sizes arrive as size_t but tile coordinates are held in int.
  */
template <typename InputEvaluator, typename Index, typename InputDims>
__global__ EIGEN_HIP_LAUNCH_BOUNDS_1024 void EigenConvolutionKernel3D(
    InputEvaluator eval,
    const internal::IndexMapper<Index, InputDims, 3, InputEvaluator::Layout>
        indexMapper,
    const float* __restrict kernel, const size_t numPlanes, const size_t numX,
    const size_t maxX, const size_t numY, const size_t maxY, const size_t numZ,
    const size_t maxZ, const size_t kernelSizeX, const size_t kernelSizeY,
    const size_t kernelSizeZ, float* buffer) {
#if defined(EIGEN_HIPCC)
  HIP_DYNAMIC_SHARED(float, s)
#else
  extern __shared__ float s[];
#endif

  // Tile bounds along each convolved dimension. The *_input extents include
  // the kernel halo (kernelSize - 1 extra positions).
  const int first_x = blockIdx.x * maxX;
  const int last_x = (first_x + maxX < numX ? first_x + maxX : numX) - 1;
  const int num_x_input = last_x - first_x + kernelSizeX;

  const int first_y = blockIdx.y * maxY;
  const int last_y = (first_y + maxY < numY ? first_y + maxY : numY) - 1;
  const int num_y_input = last_y - first_y + kernelSizeY;

  const int first_z = blockIdx.z * maxZ;
  const int last_z = (first_z + maxZ < numZ ? first_z + maxZ : numZ) - 1;
  const int num_z_input = last_z - first_z + kernelSizeZ;

  for (int p = 0; p < numPlanes; ++p) {

    const int plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(p);
    // The whole shared buffer is used for one plane at a time.
    const int plane_kernel_offset = 0;

    // Load the input window of plane p into shared memory.
    for (int k = threadIdx.z; k < num_z_input; k += blockDim.z) {
      for (int j = threadIdx.y; j < num_y_input; j += blockDim.y) {
        for (int i = threadIdx.x; i < num_x_input; i += blockDim.x) {
          const int tensor_index = plane_input_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(i+first_x, j+first_y, k+first_z);
          s[i + num_x_input * (j + num_y_input * (k + plane_kernel_offset))] = eval.coeff(tensor_index);
        }
      }
    }

    __syncthreads();

    // Convolution
    const int num_z_output = last_z - first_z + 1;
    const int num_y_output = last_y - first_y + 1;
    const int num_x_output = last_x - first_x + 1;
    const int plane_output_offset = indexMapper.mapGpuOutputPlaneToTensorOutputOffset(p);

    for (int k = threadIdx.z; k < num_z_output; k += blockDim.z) {
      for (int j = threadIdx.y; j < num_y_output; j += blockDim.y) {
        for (int i = threadIdx.x; i < num_x_output; i += blockDim.x) {
          float result = 0.0f;
          for (int n = 0; n < kernelSizeZ; ++n) {
            for (int m = 0; m < kernelSizeY; ++m) {
              for (int l = 0; l < kernelSizeX; ++l) {
                result += s[i + l + num_x_input * (j + m + num_y_input * (k + n + plane_kernel_offset))] * kernel[l + kernelSizeX * (m + kernelSizeY * n)];
              }
            }
          }
          const int tensor_index = plane_output_offset + indexMapper.mapGpuOutputKernelToTensorOutputOffset(i+first_x, j+first_y, k+first_z);
          buffer[tensor_index] = result;
        }
      }
    }
    // Do not overwrite shared memory for the next plane until all threads are
    // done reading it.
    __syncthreads();
  }
};



/** \internal
  * \brief GPU evaluator for TensorConvolutionOp.
  *
  * Along every input dimension listed in \c Indices the output has extent
  * input_dim - kernel_dim + 1 (the kernel is only applied where it fits
  * entirely inside the input); the remaining dimensions are copied from the
  * input. The result is always materialized, either into a caller-supplied
  * buffer or into m_buf, by launching one of the EigenConvolutionKernel1D/2D/3D
  * kernels (see executeEval()).
  */
template<typename Indices, typename InputArgType, typename KernelArgType>
struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, GpuDevice>
{
  typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;

  static const int NumDims =  internal::array_size<typename TensorEvaluator<InputArgType, GpuDevice>::Dimensions>::value;
  static const int NumKernelDims = internal::array_size<Indices>::value;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename TensorEvaluator<KernelArgType, GpuDevice>::Dimensions KernelDimensions;

  enum {
    IsAligned = TensorEvaluator<InputArgType, GpuDevice>::IsAligned & TensorEvaluator<KernelArgType, GpuDevice>::IsAligned,
    PacketAccess = false,
    BlockAccess = false,
    PreferBlockAccess = false,
    Layout = TensorEvaluator<InputArgType, GpuDevice>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  // Builds the input and kernel sub-evaluators and computes the output
  // dimensions. The kernel is materialized and the convolution launched later,
  // from evalSubExprsIfNeeded().
  TensorEvaluator(const XprType& op, const GpuDevice& device)
      : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_kernelArg(op.kernelExpression()), m_indices(op.indices()), m_buf(NULL), m_kernel(NULL), m_local_kernel(false), m_device(device)
  {
    // Input and kernel must use the same storage order.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<InputArgType, GpuDevice>::Layout) == static_cast<int>(TensorEvaluator<KernelArgType, GpuDevice>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);

    const typename TensorEvaluator<InputArgType, GpuDevice>::Dimensions& input_dims = m_inputImpl.dimensions();
    const typename TensorEvaluator<KernelArgType, GpuDevice>::Dimensions& kernel_dims = m_kernelImpl.dimensions();

    // Start from the input shape; each convolved dimension i (kernel dim i
    // applied to input dim op.indices()[i]) shrinks by kernel_dim - 1.
    m_dimensions = m_inputImpl.dimensions();
    for (int i = 0; i < NumKernelDims; ++i) {
      const Index index = op.indices()[i];
      const Index input_dim = input_dims[index];
      const Index kernel_dim = kernel_dims[i];
      const Index result_dim = input_dim - kernel_dim + 1;
      m_dimensions[index] = result_dim;
    }
  }

  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
  typedef typename InputArgType::Scalar Scalar;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  // Output dimensions computed by the constructor.
  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_dimensions; }

  // Materializes the kernel if needed, evaluates the input's sub-expressions
  // and runs the convolution.
  // If `data` is non-null the result is written there and false is returned.
  // Otherwise a buffer of dimensions().TotalSize() scalars is allocated into
  // m_buf (freed by cleanup()), the result is written to it and true is
  // returned; coeff()/packet() then read from m_buf.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    preloadKernel();
    m_inputImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      executeEval(data);
      return false;
    } else {
      m_buf = (Scalar*)m_device.allocate(dimensions().TotalSize() * sizeof(Scalar));
      executeEval(m_buf);
      return true;
    }
  }

  // Releases the input evaluator's resources, the result buffer allocated by
  // evalSubExprsIfNeeded(NULL), and the kernel copy made by preloadKernel(),
  // if any.
  EIGEN_STRONG_INLINE void cleanup() {
    m_inputImpl.cleanup();
    if (m_buf) {
      m_device.deallocate(m_buf);
      m_buf = NULL;
    }
    if (m_local_kernel) {
      m_device.deallocate((void*)m_kernel);
      m_local_kernel = false;
    }
    m_kernel = NULL;
  }

  // Makes m_kernel point at the kernel coefficients.
  // Don't make a local copy of the kernel unless we have to (i.e. it's an
  // expression that needs to be evaluated): when the kernel evaluator exposes
  // raw data it is used in place; otherwise the kernel expression is evaluated
  // with a TensorExecutor into a buffer from m_device.allocate(), which
  // cleanup() releases.
  EIGEN_STRONG_INLINE void preloadKernel() {
    const Scalar* in_place = m_kernelImpl.data();
    if (in_place) {
      m_kernel = in_place;
      m_local_kernel = false;
    } else {
      size_t kernel_sz = m_kernelImpl.dimensions().TotalSize() * sizeof(Scalar);
      Scalar* local = (Scalar*)m_device.allocate(kernel_sz);
      typedef TensorEvalToOp<const KernelArgType> EvalTo;
      EvalTo evalToTmp(local, m_kernelArg);
      const bool PacketAccess = internal::IsVectorizable<GpuDevice, KernelArgType>::value;
      internal::TensorExecutor<const EvalTo, GpuDevice, PacketAccess>::run(evalToTmp, m_device);

      m_kernel = local;
      m_local_kernel = true;
    }
  }

  // Integer division rounding up (denom must be non-zero). Written without
  // the usual `num + denom - 1` so it cannot overflow.
  static unsigned int ceil(unsigned int num, unsigned int denom) {
    const unsigned int rounded_toward_zero = num / denom;
    if (num > rounded_toward_zero * denom) {
      return rounded_toward_zero + 1;
    }
    return rounded_toward_zero;
  }

  // Launches the convolution on m_device, writing dimensions().TotalSize()
  // results into `data`. The launch geometry is derived from the device's
  // shared memory per block, thread limits and multiprocessor count. In the
  // 1D and 2D cases, kernel sizes 4 and 7 use compile-time specialized
  // kernels.
  void executeEval(Scalar* data) const {
    typedef typename TensorEvaluator<InputArgType, GpuDevice>::Dimensions InputDims;

    const int maxSharedMem = m_device.sharedMemPerBlock();
    const int maxThreadsPerBlock = m_device.maxGpuThreadsPerBlock();
    const int maxBlocksPerProcessor = m_device.maxGpuThreadsPerMultiProcessor() / maxThreadsPerBlock;
    const int numMultiProcessors = m_device.getNumGpuMultiProcessors();
    const int warpSize = 32;
    // The z extent of a thread block has its own hardware limit, separate from
    // the total threads-per-block limit: CUDA caps it at 64 on every compute
    // capability. Without this clamp, small x/y extents let block_size.z grow
    // past it, and the launch is rejected as an invalid configuration.
    const int maxBlockDimZ = 64;

    switch (NumKernelDims) {
      case 1: {
        const int kernel_size = m_kernelImpl.dimensions().TotalSize();

        const int numX = dimensions()[m_indices[0]];
        const int numP = dimensions().TotalSize() / numX;
        int maxX;
        dim3 block_size;

        const int single_stride_dim =
            static_cast<int>(Layout) == static_cast<int>(ColMajor)
                ? 0
                : m_inputImpl.dimensions().rank() - 1;
        if (m_indices[0] == single_stride_dim) {
          // The convolved dimension is the stride-1 one: maximize data reuse
          // along it, snapping the tile width to a multiple of 32.
          const int inner_dim = ((maxSharedMem / (sizeof(Scalar)) - kernel_size + 1 + 31) / 32) * 32;
          maxX = numext::mini<int>(inner_dim, numX);
          const int maxP = numext::mini<int>(maxSharedMem / ((kernel_size - 1 + maxX) * sizeof(Scalar)), numP);
          block_size.x = numext::mini(maxThreadsPerBlock, maxX);
          block_size.y = numext::mini<int>(maxThreadsPerBlock / block_size.x, maxP);
        }
        else {
          // Read as much as possible along the innermost dimension, which is
          // one of the plane dimensions.
          const int inner_dim = maxSharedMem / ((warpSize + kernel_size) * sizeof(Scalar));
          const int maxP = numext::mini<int>(inner_dim, numP);
          maxX = numext::mini<int>(maxSharedMem / (inner_dim * sizeof(Scalar)) - kernel_size + 1, numX);

          block_size.x = numext::mini(warpSize, maxX);
          block_size.y = numext::mini<int>(maxThreadsPerBlock/block_size.x, maxP);
        }

        // Dynamic shared memory: block_size.y rows of maxX + kernel_size - 1
        // input scalars.
        const int shared_mem = block_size.y * (maxX + kernel_size - 1) * sizeof(Scalar);
        gpu_assert(shared_mem <= maxSharedMem);

        const int num_x_blocks = ceil(numX, maxX);
        const int blocksPerProcessor = numext::mini(maxBlocksPerProcessor, maxSharedMem / shared_mem);
        const int num_y_blocks = ceil(numMultiProcessors * blocksPerProcessor, num_x_blocks);

        dim3 num_blocks(num_x_blocks, numext::mini<int>(num_y_blocks, ceil(numP, block_size.y)));


        //cout << "launching 1D kernel with block_size.x: " << block_size.x << " block_size.y: " << block_size.y << " num_blocks.x: " << num_blocks.x << " num_blocks.y: " << num_blocks.y << " maxX: " << maxX << " shared_mem: " << shared_mem << " in stream " << m_device.stream() << endl;

        const array<Index, 1> indices(m_indices[0]);
        const array<Index, 1> kernel_dims(m_kernelImpl.dimensions()[0]);
        internal::IndexMapper<Index, InputDims, 1, Layout> indexMapper(
            m_inputImpl.dimensions(), kernel_dims, indices);
        switch(kernel_size) {
          case 4: {
            LAUNCH_GPU_KERNEL((EigenConvolutionKernel1D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 4>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, 4, data);
            break;
          }
          case 7: {
            LAUNCH_GPU_KERNEL((EigenConvolutionKernel1D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 7>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, 7, data);
            break;
          }
          default: {
            LAUNCH_GPU_KERNEL((EigenConvolutionKernel1D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, Dynamic>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, kernel_size, data);
          }
        }
        break;
      }

      case 2: {
        const int idxX =
            static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : 1;
        const int idxY =
            static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 0;
        const int kernel_size_x = m_kernelImpl.dimensions()[idxX];
        const int kernel_size_y = m_kernelImpl.dimensions()[idxY];

        const int numX = dimensions()[m_indices[idxX]];
        const int numY = dimensions()[m_indices[idxY]];
        const int numP = dimensions().TotalSize() / (numX*numY);

        const float scaling_factor = sqrtf(static_cast<float>(maxSharedMem) / (sizeof(Scalar) * kernel_size_y * kernel_size_x));

        // Snap maxX to warp size
        int inner_dim = ((static_cast<int>(scaling_factor * kernel_size_x) - kernel_size_x + 1 + 32) / 32) * 32;
        const int maxX = numext::mini<int>(inner_dim, numX);
        const int maxY = numext::mini<int>(maxSharedMem / (sizeof(Scalar) * (maxX + kernel_size_x - 1)) - kernel_size_y + 1, numY);
        const int maxP = numext::mini<int>(maxSharedMem / ((kernel_size_x - 1 + maxX) * (kernel_size_y - 1 + maxY) * sizeof(Scalar)), numP);

        // Use the device's thread limit (as the 1D case does) and keep the z
        // extent within maxBlockDimZ. The grid z extent below is already capped
        // at num_z_blocks, so planes do not map one-to-one onto z-threads; a
        // smaller block_size.z only redistributes planes. Shared memory and the
        // grid are both derived from the clamped value.
        dim3 block_size;
        block_size.x = numext::mini(maxThreadsPerBlock, maxX);
        block_size.y = numext::mini<int>(maxThreadsPerBlock/block_size.x, maxY);
        block_size.z = numext::mini<int>(numext::mini<int>(maxThreadsPerBlock/(block_size.x*block_size.y), maxP), maxBlockDimZ);

        // Dynamic shared memory: block_size.z input tiles of
        // (maxX + kernel_size_x - 1) x (maxY + kernel_size_y - 1) scalars.
        const int shared_mem = block_size.z * (maxX + kernel_size_x - 1) * (maxY + kernel_size_y - 1) * sizeof(Scalar);
        gpu_assert(shared_mem <= maxSharedMem);

        const int num_x_blocks = ceil(numX, maxX);
        const int num_y_blocks = ceil(numY, maxY);
        const int blocksPerProcessor = numext::mini(maxBlocksPerProcessor, maxSharedMem / shared_mem);
        const int num_z_blocks = ceil(numMultiProcessors * blocksPerProcessor, num_x_blocks * num_y_blocks);

        dim3 num_blocks(num_x_blocks, num_y_blocks, numext::mini<int>(num_z_blocks, ceil(numP, block_size.z)));


        //cout << "launching 2D kernel with block_size.x: " << block_size.x << " block_size.y: " << block_size.y  << " block_size.z: " << block_size.z << " num_blocks.x: " << num_blocks.x << " num_blocks.y: " << num_blocks.y << " num_blocks.z: " << num_blocks.z << " maxX: " << maxX << " maxY: " << maxY << " maxP: " << maxP << " shared_mem: " << shared_mem << " in stream " << m_device.stream() << endl;

        const array<Index, 2> indices(m_indices[idxX], m_indices[idxY]);
        const array<Index, 2> kernel_dims(m_kernelImpl.dimensions()[idxX],
                                          m_kernelImpl.dimensions()[idxY]);
        internal::IndexMapper<Index, InputDims, 2, Layout> indexMapper(
            m_inputImpl.dimensions(), kernel_dims, indices);
        switch (kernel_size_x) {
          case 4: {
            switch (kernel_size_y) {
              case 7: {
                LAUNCH_GPU_KERNEL((EigenConvolutionKernel2D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 4, 7>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, 4, 7, data);
                break;
              }
              default: {
                LAUNCH_GPU_KERNEL((EigenConvolutionKernel2D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 4, Dynamic>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, 4, kernel_size_y, data);
                break;
              }
            }
            break;
          }
          case 7: {
            switch (kernel_size_y) {
              case 4: {
                LAUNCH_GPU_KERNEL((EigenConvolutionKernel2D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 7, 4>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, 7, 4, data);
                break;
              }
              default: {
                LAUNCH_GPU_KERNEL((EigenConvolutionKernel2D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, 7, Dynamic>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, 7, kernel_size_y, data);
                break;
              }
            }
            break;
          }
          default: {
            LAUNCH_GPU_KERNEL((EigenConvolutionKernel2D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims, Dynamic, Dynamic>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, kernel_size_x, kernel_size_y, data);
            break;
          }
        }
        break;
      }

      case 3: {
        const int idxX =
            static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : 2;
        // The middle kernel dimension is Y for both layouts.
        const int idxY = 1;
        const int idxZ =
            static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 0;

        const int kernel_size_x = m_kernelImpl.dimensions()[idxX];
        const int kernel_size_y = m_kernelImpl.dimensions()[idxY];
        const int kernel_size_z = m_kernelImpl.dimensions()[idxZ];

        const int numX = dimensions()[m_indices[idxX]];
        const int numY = dimensions()[m_indices[idxY]];
        const int numZ = dimensions()[m_indices[idxZ]];
        const int numP = dimensions().TotalSize() / (numX*numY*numZ);

        const int maxX = numext::mini<int>(128, numext::mini<int>(maxSharedMem / (sizeof(Scalar) * kernel_size_y * kernel_size_z) - kernel_size_x + 1, numX));
        const int maxY = numext::mini<int>(128, numext::mini<int>(maxSharedMem / (sizeof(Scalar) * (maxX + kernel_size_x - 1) * kernel_size_z) - kernel_size_y + 1, numY));
        const int maxZ = numext::mini<int>(128, numext::mini<int>(maxSharedMem / (sizeof(Scalar) * (maxX + kernel_size_x - 1) * (maxY + kernel_size_y - 1)) - kernel_size_z + 1, numZ));

        // maxZ may be up to 128, so block_size.z must be clamped to
        // maxBlockDimZ. The 3D kernel strides over z with blockDim.z, and its
        // shared memory depends only on maxZ, so the clamp does not change the
        // computed result.
        dim3 block_size;
        block_size.x = numext::mini(32, maxX);
        block_size.y = numext::mini(32, maxY);
        block_size.z = numext::mini<int>(numext::mini<int>(maxThreadsPerBlock/(block_size.x*block_size.y), maxZ), maxBlockDimZ);
        dim3 num_blocks(ceil(numX, maxX), ceil(numY, maxY), ceil(numZ, maxZ));

        // Dynamic shared memory: one input tile covering maxX x maxY x maxZ
        // outputs plus the kernel halo.
        const int shared_mem = (maxX + kernel_size_x - 1) * (maxY + kernel_size_y - 1) * (maxZ + kernel_size_z - 1) * sizeof(Scalar);
        gpu_assert(shared_mem <= maxSharedMem);

        //cout << "launching 3D kernel with block_size.x: " << block_size.x << " block_size.y: " << block_size.y  << " block_size.z: " << block_size.z << " num_blocks.x: " << num_blocks.x << " num_blocks.y: " << num_blocks.y << " num_blocks.z: " << num_blocks.z  << " shared_mem: " << shared_mem << " in stream " << m_device.stream() << endl;
        const array<Index, 3> indices(m_indices[idxX], m_indices[idxY],
                                      m_indices[idxZ]);
        const array<Index, 3> kernel_dims(m_kernelImpl.dimensions()[idxX],
                                          m_kernelImpl.dimensions()[idxY],
                                          m_kernelImpl.dimensions()[idxZ]);
        internal::IndexMapper<Index, InputDims, 3, Layout> indexMapper(
            m_inputImpl.dimensions(), kernel_dims, indices);

        LAUNCH_GPU_KERNEL((EigenConvolutionKernel3D<TensorEvaluator<InputArgType, GpuDevice>, Index, InputDims>), num_blocks, block_size, shared_mem, m_device, m_inputImpl, indexMapper, m_kernel, numP, numX, maxX, numY, maxY, numZ, maxZ, kernel_size_x, kernel_size_y, kernel_size_z, data);
        break;
      }

      default: {
        EIGEN_STATIC_ASSERT((NumKernelDims >= 1 && NumKernelDims <= 3), THIS_METHOD_IS_ONLY_FOR_OBJECTS_OF_A_SPECIFIC_SIZE);
      }
    }
  }

  // Reads one output coefficient. Only valid after evalSubExprsIfNeeded(NULL),
  // i.e. when the result was written into m_buf.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    eigen_assert(m_buf);
    eigen_assert(index < m_dimensions.TotalSize());
    return m_buf[index];
  }

  // Packet counterpart of coeff(); same precondition on m_buf.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(const Index index) const
  {
    eigen_assert(m_buf);
    eigen_assert(index < m_dimensions.TotalSize());
    return internal::ploadt<PacketReturnType, LoadMode>(m_buf+index);
  }

  // Per-coefficient cost: index computation plus, for each kernel tap, one
  // input and one kernel access and a multiply-add.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    // TODO(rmlarsen): FIXME: For now, this is just a copy of the CPU cost
    // model.
    const double kernel_size = m_kernelImpl.dimensions().TotalSize();
    // We ignore the use of fused multiply-add.
    const double convolve_compute_cost =
        TensorOpCost::AddCost<Scalar>() + TensorOpCost::MulCost<Scalar>();
    const double firstIndex_compute_cost =
        NumDims *
        (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
         TensorOpCost::DivCost<Index>());
    return TensorOpCost(0, 0, firstIndex_compute_cost, vectorized, PacketSize) +
           kernel_size * (m_inputImpl.costPerCoeff(vectorized) +
                          m_kernelImpl.costPerCoeff(vectorized) +
                          TensorOpCost(0, 0, convolve_compute_cost, vectorized,
                                       PacketSize));
  }

 private:
  // No assignment (copies are needed by the kernels)
  TensorEvaluator& operator = (const TensorEvaluator&);

  TensorEvaluator<InputArgType, GpuDevice> m_inputImpl;
  TensorEvaluator<KernelArgType, GpuDevice> m_kernelImpl;
  // Kernel expression, kept so preloadKernel() can evaluate it into a
  // temporary when it has no raw data.
  KernelArgType m_kernelArg;
  // Input dimensions the kernel dimensions are applied to.
  Indices m_indices;
  // Output dimensions.
  Dimensions m_dimensions;
  // Result buffer owned by this evaluator (non-null only after
  // evalSubExprsIfNeeded(NULL)).
  Scalar* m_buf;
  // Kernel coefficients handed to the GPU kernels.
  const Scalar* m_kernel;
  // True when m_kernel was allocated by preloadKernel() and must be freed.
  bool m_local_kernel;

  const GpuDevice& m_device;
};
#endif


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_H
