// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>

//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_SYCL_H

namespace Eigen {

/** \class TensorConvolution
 * \ingroup CXX11_Tensor_Module
 *
 * \brief Tensor convolution class.
 *
 *
 */

enum class convolution_type { CONV1D, CONV2D, CONV3D };
template <typename Evaluator, typename CoeffReturnType, typename KernelType, typename Index, typename InputDims,
          typename Kernel_accessor, typename Buffer_accessor, convolution_type Conv_Dim>
struct EigenConvolutionKernel;
template <typename Evaluator, typename CoeffReturnType, typename KernelType, typename Index, typename InputDims,
          typename Kernel_accessor, typename Buffer_accessor>
struct EigenConvolutionKernel<Evaluator, CoeffReturnType, KernelType, Index, InputDims, Kernel_accessor,
                              Buffer_accessor, convolution_type::CONV1D> {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Local_accessor;
  Local_accessor local_acc;
  Evaluator device_evaluator;
  Kernel_accessor kernel_filter;
  Buffer_accessor buffer_acc;
  internal::IndexMapper<Index, InputDims, 1, Evaluator::Layout> indexMapper;
  const size_t kernelSize;
  const cl::sycl::range<2> input_range;
  EigenConvolutionKernel(Local_accessor local_acc_, Evaluator device_evaluator_, Kernel_accessor kernel_filter_,
                         Buffer_accessor buffer_acc_,
                         internal::IndexMapper<Index, InputDims, 1, Evaluator::Layout> indexMapper_,
                         const size_t kernelSize_, const cl::sycl::range<2> input_range_)
      : local_acc(local_acc_),
        device_evaluator(device_evaluator_),
        kernel_filter(kernel_filter_),
        buffer_acc(buffer_acc_),
        indexMapper(indexMapper_),
        kernelSize(kernelSize_),
        input_range(input_range_) {}

  template <typename BooleanDim2>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool boundary_check(const BooleanDim2 boolean_check) {
    return (boolean_check[0] && boolean_check[1]);
  }
  void operator()(cl::sycl::nd_item<2> itemID) {
    auto buffer_ptr = buffer_acc.get_pointer();
    auto kernel_ptr = kernel_filter.get_pointer();
    // the required row to be calculated for the for each plane in shered memory
    const size_t num_input = (itemID.get_local_range()[0] + kernelSize - 1);
    const size_t plane_kernel_offset = itemID.get_local_id(1) * num_input;
    const size_t input_offset = itemID.get_group(0) * itemID.get_local_range()[0];
    const size_t plane_tensor_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(itemID.get_global_id(1));
    /// fill the shared memory
    for (size_t i = itemID.get_local_id(0); i < num_input; i += itemID.get_local_range()[0]) {
      const size_t local_index = i + plane_kernel_offset;
      const size_t tensor_index =
          plane_tensor_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(i + input_offset);

      local_acc[local_index] =
          (((i + input_offset) < (input_range[0] + kernelSize - 1)) && itemID.get_global_id(1) < input_range[1])
              ? device_evaluator.coeff(tensor_index)
              : CoeffReturnType(0);
    }

    itemID.barrier(cl::sycl::access::fence_space::local_space);

    // calculate the convolution // output start x
    const size_t first_output_start = itemID.get_group(0) * (itemID.get_local_range()[0]);
    if (boundary_check(itemID.get_global_id() < input_range)) {
      CoeffReturnType result = static_cast<CoeffReturnType>(0);
      const size_t index = plane_kernel_offset + itemID.get_local_id(0);
      for (size_t k = 0; k < kernelSize; ++k) {
        result += (local_acc[k + index] * kernel_ptr[k]);
      }
      const size_t tensor_index =
          indexMapper.mapGpuOutputPlaneToTensorOutputOffset(itemID.get_global_id(1)) +
          indexMapper.mapGpuOutputKernelToTensorOutputOffset(itemID.get_local_id(0) + first_output_start);
      buffer_ptr[tensor_index] = result;
    }
  }
};

template <typename Evaluator, typename CoeffReturnType, typename KernelType, typename Index, typename InputDims,
          typename Kernel_accessor, typename Buffer_accessor>
struct EigenConvolutionKernel<Evaluator, CoeffReturnType, KernelType, Index, InputDims, Kernel_accessor,
                              Buffer_accessor, convolution_type::CONV2D> {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Local_accessor;
  Local_accessor local_acc;
  Evaluator device_evaluator;
  Kernel_accessor kernel_filter;
  Buffer_accessor buffer_acc;
  internal::IndexMapper<Index, InputDims, 2, Evaluator::Layout> indexMapper;
  const cl::sycl::range<2> kernel_size;
  const cl::sycl::range<3> input_range;
  EigenConvolutionKernel(Local_accessor local_acc_, Evaluator device_evaluator_, Kernel_accessor kernel_filter_,
                         Buffer_accessor buffer_acc_,
                         internal::IndexMapper<Index, InputDims, 2, Evaluator::Layout> indexMapper_,
                         const cl::sycl::range<2> kernel_size_, const cl::sycl::range<3> input_range_)
      : local_acc(local_acc_),
        device_evaluator(device_evaluator_),
        kernel_filter(kernel_filter_),
        buffer_acc(buffer_acc_),
        indexMapper(indexMapper_),
        kernel_size(kernel_size_),
        input_range(input_range_) {}
  template <typename BooleanDim3>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool boundary_check(const BooleanDim3 boolean_check) {
    return (boolean_check[0] && boolean_check[1] && boolean_check[2]);
  }

  void operator()(cl::sycl::nd_item<3> itemID) {
    auto buffer_ptr = buffer_acc.get_pointer();
    auto kernel_ptr = kernel_filter.get_pointer();
    // the required row to be calculated for the for each plane in shered memory
    const auto num_input = cl::sycl::range<2>{
        (cl::sycl::range<2>(itemID.get_local_range()[0], itemID.get_local_range()[1]) + kernel_size - 1)};

    const size_t plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(itemID.get_global_id(2));
    const size_t plane_kernel_offset = itemID.get_local_id(2) * num_input[1];

    const auto input_offset = cl::sycl::range<2>{itemID.get_group(0) * itemID.get_local_range()[0],
                                                 itemID.get_group(1) * itemID.get_local_range()[1]};
      
    // fill the local memory
    bool in_range_dim2 = itemID.get_global_id(2) < input_range[2];
    for (size_t j = itemID.get_local_id(1); j < num_input[1]; j += itemID.get_local_range()[1]) {
      const size_t local_input_offset = num_input[0] * (j + plane_kernel_offset);
      bool in_range_dim1 = ((j + input_offset[1]) < (input_range[1] + kernel_size[1] - 1)); 
      for (size_t i = itemID.get_local_id(0); i < num_input[0]; i += itemID.get_local_range()[0]) {
        const size_t local_index = i + local_input_offset;
        const size_t tensor_index = plane_input_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(
                                                             i + input_offset[0], j + input_offset[1]);
        local_acc[local_index] = (((i + input_offset[0]) < (input_range[0] + kernel_size[0] - 1)) &&
                                  in_range_dim1 && in_range_dim2)
                                     ? device_evaluator.coeff(tensor_index)
                                     : CoeffReturnType(0);
      }
    }

    itemID.barrier(cl::sycl::access::fence_space::local_space);

    // output offset start for each thread
    const auto output_offset = cl::sycl::range<2>{itemID.get_group(0) * itemID.get_local_range()[0],
                                                  itemID.get_group(1) * itemID.get_local_range()[1]};

    if (boundary_check(itemID.get_global_id() < input_range)) {
      CoeffReturnType result = static_cast<CoeffReturnType>(0);

      for (size_t j = 0; j < kernel_size[1]; j++) {
        size_t kernel_offset = kernel_size[0] * j;
        const size_t index =
            (num_input[0] * (plane_kernel_offset + j + itemID.get_local_id(1))) + itemID.get_local_id(0);
        for (size_t i = 0; i < kernel_size[0]; i++) {
          result += (local_acc[i + index] * kernel_ptr[i + kernel_offset]);
        }
      }
      const size_t tensor_index =
          indexMapper.mapGpuOutputPlaneToTensorOutputOffset(itemID.get_global_id(2)) +
          indexMapper.mapGpuOutputKernelToTensorOutputOffset(itemID.get_local_id(0) + output_offset[0],
                                                             itemID.get_local_id(1) + output_offset[1]);

      buffer_ptr[tensor_index] = result;
    }
  }
};

template <typename Evaluator, typename CoeffReturnType, typename KernelType, typename Index, typename InputDims,
          typename Kernel_accessor, typename Buffer_accessor>
struct EigenConvolutionKernel<Evaluator, CoeffReturnType, KernelType, Index, InputDims, Kernel_accessor,
                              Buffer_accessor, convolution_type::CONV3D> {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Local_accessor;
  Local_accessor local_acc;
  Evaluator device_evaluator;
  Kernel_accessor kernel_filter;
  Buffer_accessor buffer_acc;
  internal::IndexMapper<Index, InputDims, 3, Evaluator::Layout> indexMapper;
  const cl::sycl::range<3> kernel_size;
  const cl::sycl::range<3> input_range;
  const size_t numP;

  EigenConvolutionKernel(Local_accessor local_acc_, Evaluator device_evaluator_, Kernel_accessor kernel_filter_,
                         Buffer_accessor buffer_acc_,
                         internal::IndexMapper<Index, InputDims, 3, Evaluator::Layout> indexMapper_,
                         const cl::sycl::range<3> kernel_size_, const cl::sycl::range<3> input_range_,
                         const size_t numP_)
      : local_acc(local_acc_),
        device_evaluator(device_evaluator_),
        kernel_filter(kernel_filter_),
        buffer_acc(buffer_acc_),
        indexMapper(indexMapper_),
        kernel_size(kernel_size_),
        input_range(input_range_),
        numP(numP_) {}
  template <typename BooleanDim3>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool boundary_check(const BooleanDim3 boolean_check) {
    return (boolean_check[0] && boolean_check[1] && boolean_check[2]);
  }
  void operator()(cl::sycl::nd_item<3> itemID) {
    auto buffer_ptr = buffer_acc.get_pointer();
    auto kernel_ptr = kernel_filter.get_pointer();
    const auto num_input = cl::sycl::range<3>{itemID.get_local_range() + kernel_size - 1};

    const auto input_offset = cl::sycl::range<3>{itemID.get_group().get_id() * itemID.get_local_range()};

    const auto output_offset =
          cl::sycl::range<3>{itemID.get_group().get_id() * itemID.get_local_range() + itemID.get_local_id()};

    for (size_t p = 0; p < numP; p++) {
      /// fill the shared memory
      const size_t plane_input_offset = indexMapper.mapGpuInputPlaneToTensorInputOffset(p);
      for (size_t k = itemID.get_local_id(2); k < num_input[2]; k += itemID.get_local_range()[2]) {
        size_t local_index_dim2 = num_input[0] * num_input[1] * k;
        bool cond_k_dim = (k + input_offset[2] < (input_range[2] + kernel_size[2] - 1));
        for (size_t j = itemID.get_local_id(1); j < num_input[1]; j += itemID.get_local_range()[1]) {
          bool cond_j_dim = cond_k_dim && (j + input_offset[1] < (input_range[1] + kernel_size[1] - 1));
          size_t local_index_dim1 = (num_input[0] * j)  + local_index_dim2;
          for (size_t i = itemID.get_local_id(0); i < num_input[0]; i += itemID.get_local_range()[0]) {
            bool conds = cond_j_dim && (i + input_offset[0] < (input_range[0] + kernel_size[0] - 1));
            const size_t local_index = local_index_dim1 + i;
            const size_t tensor_index =
                plane_input_offset + indexMapper.mapGpuInputKernelToTensorInputOffset(
                                         i + input_offset[0], j + input_offset[1], k + input_offset[2]);
            local_acc[local_index] = conds ? device_evaluator.coeff(tensor_index) : CoeffReturnType(0);
          }
        }
      }
      itemID.barrier(cl::sycl::access::fence_space::local_space);

      // calculate the convolution

      if (boundary_check(itemID.get_global_id() < input_range)) {
        CoeffReturnType result = static_cast<CoeffReturnType>(0);
        for (size_t k = 0; k < kernel_size[2]; k++) {
          for (size_t j = 0; j < kernel_size[1]; j++) {
            for (size_t i = 0; i < kernel_size[0]; i++) {
              const size_t kernel_index = i + kernel_size[0] * (j + kernel_size[1] * k);
              const size_t local_index =
                  ((i + itemID.get_local_id(0)) +
                   num_input[0] * ((j + itemID.get_local_id(1)) + num_input[1] * (k + itemID.get_local_id(2))));

              result += (local_acc[local_index] * kernel_ptr[kernel_index]);
            }
          }
        }
        const size_t tensor_index =
            indexMapper.mapGpuOutputPlaneToTensorOutputOffset(p) +
            indexMapper.mapGpuOutputKernelToTensorOutputOffset(output_offset[0], output_offset[1], output_offset[2]);
        buffer_ptr[tensor_index] = result;
      }

      itemID.barrier(cl::sycl::access::fence_space::local_space);
    }
  }
};

template <typename Indices, typename InputArgType, typename KernelArgType>
struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Eigen::SyclDevice> {
  typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;

  static const int NumDims =
      internal::array_size<typename TensorEvaluator<InputArgType, Eigen::SyclDevice>::Dimensions>::value;
  static const int NumKernelDims = internal::array_size<Indices>::value;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename TensorEvaluator<KernelArgType, Eigen::SyclDevice>::Dimensions KernelDimensions;
  typedef const Eigen::SyclDevice Device;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Eigen::SyclDevice>::type PacketReturnType;
  typedef typename InputArgType::Scalar Scalar;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef StorageMemory<CoeffReturnType, Eigen::SyclDevice> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  typedef StorageMemory<const CoeffReturnType, Eigen::SyclDevice> KernelStorage;

  enum {
    IsAligned = TensorEvaluator<InputArgType, Eigen::SyclDevice>::IsAligned &
                TensorEvaluator<KernelArgType, Eigen::SyclDevice>::IsAligned,
    PacketAccess = false,
    BlockAccess = false,
    PreferBlockAccess = false,
    Layout = TensorEvaluator<InputArgType, Eigen::SyclDevice>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType &op, const Eigen::SyclDevice &device)
      : m_inputImpl(op.inputExpression(), device),
        m_kernelArg(op.kernelExpression()),
        m_kernelImpl(op.kernelExpression(), device),
        m_indices(op.indices()),
        m_buf(NULL),
        m_kernel(NULL),
        m_local_kernel(false),
        m_device(device) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<InputArgType, Eigen::SyclDevice>::Layout) ==
                         static_cast<int>(TensorEvaluator<KernelArgType, Eigen::SyclDevice>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    const typename TensorEvaluator<InputArgType, Eigen::SyclDevice>::Dimensions &input_dims = m_inputImpl.dimensions();
    const typename TensorEvaluator<KernelArgType, Eigen::SyclDevice>::Dimensions &kernel_dims =
        m_kernelImpl.dimensions();

    m_dimensions = m_inputImpl.dimensions();
    for (int i = 0; i < NumKernelDims; ++i) {
      const Index index = op.indices()[i];
      const Index input_dim = input_dims[index];
      const Index kernel_dim = kernel_dims[i];
      const Index result_dim = input_dim - kernel_dim + 1;
      m_dimensions[index] = result_dim;
    }
  }

  EIGEN_DEVICE_FUNC const Dimensions &dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
    preloadKernel();
    m_inputImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      executeEval(data);
      return false;
    } else {
      m_buf = (EvaluatorPointerType)m_device.get(
          (Scalar *)m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar)));
      executeEval(m_buf);
      return true;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_inputImpl.cleanup();
    if (m_buf) {
      m_device.deallocate_temp(m_buf);
      m_buf = NULL;
    }
    if (m_local_kernel) {
      m_device.deallocate_temp(m_kernel);
      m_local_kernel = false;
    }
    m_kernel = NULL;
  }
  /// used by sycl in order to build the sycl buffer
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Device &device() const { return m_device; }
  /// used by sycl in order to build the sycl buffer
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return m_buf; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void preloadKernel() {
    // Don't make a local copy of the kernel unless we have to (i.e. it's an
    // expression that needs to be evaluated)
    typename KernelStorage::Type in_place = m_kernelImpl.data();
    if (in_place) {
      m_kernel = in_place;
      m_local_kernel = false;
    } else {
      ptrdiff_t kernel_sz = m_kernelImpl.dimensions().TotalSize() * sizeof(Scalar);
      EvaluatorPointerType local = (EvaluatorPointerType)m_device.get((Scalar *)m_device.allocate_temp(kernel_sz));
      typedef TensorEvalToOp<const KernelArgType> EvalTo;
      EvalTo evalToTmp(m_device.get(local), m_kernelArg);
      const bool PacketAccess = internal::IsVectorizable<Eigen::SyclDevice, KernelArgType>::value;
      internal::TensorExecutor<const EvalTo, Eigen::SyclDevice, PacketAccess>::run(evalToTmp, m_device);
      m_kernel = local;
      m_local_kernel = true;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void executeEval(EvaluatorPointerType data) const {
    typedef TensorEvaluator<InputArgType, Eigen::SyclDevice> InputEvaluator;
    typedef typename InputEvaluator::Dimensions InputDims;
    switch (NumKernelDims) {
      case 1: {
        const size_t numX = dimensions()[m_indices[0]];
        const size_t numP = dimensions().TotalSize() / numX;
        const auto input_dim = std::array<size_t, 2>{numX, numP};
        auto global_range = cl::sycl::range<2>{};
        auto local_range = cl::sycl::range<2>{};
        const size_t kernel_size = m_kernelImpl.dimensions().TotalSize();

        m_device.parallel_for_setup(input_dim, global_range, local_range);
        const size_t local_memory_size = (local_range[0] + kernel_size - 1) * (local_range[1]);
        gpu_assert(static_cast<unsigned long>(local_memory_size) <= m_device.sharedMemPerBlock());
        const array<Index, 1> indices{{m_indices[0]}};
        const array<Index, 1> kernel_dims{{m_kernelImpl.dimensions()[0]}};
        internal::IndexMapper<Index, InputDims, 1, Layout> indexMapper(m_inputImpl.dimensions(), kernel_dims, indices);

        typedef EigenConvolutionKernel<InputEvaluator, CoeffReturnType, Scalar, Index, InputDims,
                                       typename KernelStorage::Type, EvaluatorPointerType, convolution_type::CONV1D>
            ConvKernel;

        m_device.template binary_kernel_launcher<CoeffReturnType, ConvKernel>(
            m_inputImpl, m_kernel, data, cl::sycl::nd_range<2>(global_range, local_range), local_memory_size,
            indexMapper, kernel_size, cl::sycl::range<2>(input_dim[0], input_dim[1]));
        break;
      }

      case 2: {
        auto kernel_index = std::array<size_t, 2>{static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : 1,
                                                  static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 0};
        auto kernel_size = cl::sycl::range<2>{(size_t)m_kernelImpl.dimensions()[kernel_index[0]],
                                              (size_t)m_kernelImpl.dimensions()[kernel_index[1]]};
        const size_t numX = dimensions()[m_indices[kernel_index[0]]];
        const size_t numY = dimensions()[m_indices[kernel_index[1]]];
        const size_t numP = dimensions().TotalSize() / (numX * numY);
        auto input_dim = std::array<size_t, 3>{numX, numY, numP};

        auto global_range = cl::sycl::range<3>{};
        auto local_range = cl::sycl::range<3>{};

        m_device.parallel_for_setup(input_dim, global_range, local_range);

        const size_t local_memory_size =
            (local_range[0] + kernel_size[0] - 1) * (local_range[1] + kernel_size[1] - 1) * local_range[2];
        gpu_assert(static_cast<unsigned long>(local_memory_size) <= m_device.sharedMemPerBlock());
        const array<Index, 2> indices{{m_indices[kernel_index[0]], m_indices[kernel_index[1]]}};
        const array<Index, 2> kernel_dims{
            {m_kernelImpl.dimensions()[kernel_index[0]], m_kernelImpl.dimensions()[kernel_index[1]]}};
        internal::IndexMapper<Index, InputDims, 2, Layout> indexMapper(m_inputImpl.dimensions(), kernel_dims, indices);
        typedef EigenConvolutionKernel<InputEvaluator, CoeffReturnType, Scalar, Index, InputDims,
                                       typename KernelStorage::Type, EvaluatorPointerType, convolution_type::CONV2D>
            ConvKernel;
        m_device.template binary_kernel_launcher<CoeffReturnType, ConvKernel>(
            m_inputImpl, m_kernel, data, cl::sycl::nd_range<3>(global_range, local_range), local_memory_size,
            indexMapper, kernel_size, cl::sycl::range<3>{input_dim[0], input_dim[1], input_dim[2]});
        break;
      }

      case 3: {
        auto kernel_index = std::array<size_t, 3>{static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : 2,
                                                  static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 1,
                                                  static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 0};

        auto kernel_size = cl::sycl::range<3>{(size_t)m_kernelImpl.dimensions()[kernel_index[0]],
                                              (size_t)m_kernelImpl.dimensions()[kernel_index[1]],
                                              (size_t)m_kernelImpl.dimensions()[kernel_index[2]]};

        const size_t numX = dimensions()[m_indices[kernel_index[0]]];
        const size_t numY = dimensions()[m_indices[kernel_index[1]]];
        const size_t numZ = dimensions()[m_indices[kernel_index[2]]];
        auto input_dim = std::array<size_t, 3>{numX, numY, numZ};
        const size_t numP = dimensions().TotalSize() / (numX * numY * numZ);

        const array<Index, 3> indices{
            {m_indices[kernel_index[0]], m_indices[kernel_index[1]], m_indices[kernel_index[2]]}};
        const array<Index, 3> kernel_dims{{m_kernelImpl.dimensions()[kernel_index[0]],
                                           m_kernelImpl.dimensions()[kernel_index[1]],
                                           m_kernelImpl.dimensions()[kernel_index[2]]}};

        internal::IndexMapper<Index, InputDims, 3, Layout> indexMapper(m_inputImpl.dimensions(), kernel_dims, indices);

        auto global_range = cl::sycl::range<3>{};
        auto local_range = cl::sycl::range<3>{};

        m_device.parallel_for_setup(input_dim, global_range, local_range);
        auto local_memory_range = (local_range + kernel_size - 1);
        const size_t local_memory_size = local_memory_range[0] * local_memory_range[1] * local_memory_range[2];

        gpu_assert(static_cast<unsigned long>(local_memory_size) <= m_device.sharedMemPerBlock());
        typedef EigenConvolutionKernel<InputEvaluator, CoeffReturnType, Scalar, Index, InputDims,
                                       typename KernelStorage::Type, EvaluatorPointerType, convolution_type::CONV3D>
            ConvKernel;
        m_device.template binary_kernel_launcher<CoeffReturnType, ConvKernel>(
            m_inputImpl, m_kernel, data, cl::sycl::nd_range<3>(global_range, local_range), local_memory_size,
            indexMapper, kernel_size, cl::sycl::range<3>(input_dim[0], input_dim[1], input_dim[2]), numP);
        break;
      }

      default: {
        EIGEN_STATIC_ASSERT((NumKernelDims >= 1 && NumKernelDims <= 3),
                            THIS_METHOD_IS_ONLY_FOR_OBJECTS_OF_A_SPECIFIC_SIZE);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_buf != NULL);
    eigen_assert(index < m_dimensions.TotalSize());
    return m_buf[index];
  }

  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(const Index index) const {
    eigen_assert(m_buf != NULL);
    eigen_assert(index < m_dimensions.TotalSize());
    return internal::ploadt<PacketReturnType, LoadMode>(m_buf + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    // TODO(rmlarsen): FIXME: For now, this is just a copy of the CPU cost
    // model.
    const double kernel_size = m_kernelImpl.dimensions().TotalSize();
    // We ignore the use of fused multiply-add.
    const double convolve_compute_cost = TensorOpCost::AddCost<Scalar>() + TensorOpCost::MulCost<Scalar>();
    const double firstIndex_compute_cost =
        NumDims *
        (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() + TensorOpCost::DivCost<Index>());
    return TensorOpCost(0, 0, firstIndex_compute_cost, vectorized, PacketSize) +
           kernel_size * (m_inputImpl.costPerCoeff(vectorized) + m_kernelImpl.costPerCoeff(vectorized) +
                          TensorOpCost(0, 0, convolve_compute_cost, vectorized, PacketSize));
  }
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_kernelImpl.bind(cgh);
    m_inputImpl.bind(cgh);
    m_buf.bind(cgh);
    m_kernel.bind(cgh);
  }

 private:
  // No assignment (copies are needed by the kernels)
  TensorEvaluator &operator=(const TensorEvaluator &);
  TensorEvaluator<InputArgType, Eigen::SyclDevice> m_inputImpl;
  KernelArgType m_kernelArg;
  TensorEvaluator<KernelArgType, Eigen::SyclDevice> m_kernelImpl;
  Indices m_indices;
  Dimensions m_dimensions;
  EvaluatorPointerType m_buf;
  typename KernelStorage::Type m_kernel;
  bool m_local_kernel;
  const Eigen::SyclDevice EIGEN_DEVICE_REF m_device;
};  // namespace Eigen

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTION_H
