// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape

#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "single_thread_gemm.h"

namespace gemmlowp {

template <typename T>
struct TransposeImpl {
  typedef T DstType;
  static T Run(const T& t) { return t; }
};

template <typename T>
using TransposeType = typename TransposeImpl<T>::DstType;

template <typename T>
TransposeType<T> Transpose(const T& t) {
  return TransposeImpl<T>::Run(t);
}

template <MapOrder Order>
struct TransposeMapOrder {
  static constexpr MapOrder Value =
      Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
};

template <VectorShape Shape>
struct TransposeVectorShape {
  static constexpr VectorShape Value =
      Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
};

template <typename Scalar, VectorShape Shape>
struct TransposeImpl<VectorMap<Scalar, Shape>> {
  typedef VectorMap<Scalar, Shape> SrcType;
  static constexpr VectorShape TransposedShape =
      TransposeVectorShape<Shape>::Value;
  typedef VectorMap<Scalar, TransposedShape> DstType;
  static DstType Run(const SrcType& src) {
    return DstType(src.data(), src.size());
  }
};

template <typename Scalar, MapOrder Order>
struct TransposeImpl<MatrixMap<Scalar, Order>> {
  typedef MatrixMap<Scalar, Order> SrcType;
  static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
  typedef MatrixMap<Scalar, TransposedOrder> DstType;
  static DstType Run(const SrcType& src) {
    return DstType(src.data(), src.cols(), src.rows(), src.stride());
  }
};

template <VectorShape Shape>
struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
  static DstType Run(const SrcType& src) {
    DstType dst;
    dst.result_shift = src.result_shift;
    dst.result_offset = Transpose(src.result_offset);
    dst.result_mult_int = Transpose(src.result_mult_int);
    return dst;
  }
};

template <VectorShape Shape>
struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
  typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
  typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
      DstType;
  static DstType Run(const SrcType& src) {
    DstType dst;
    dst.result_fixedpoint_multiplier =
        Transpose(src.result_fixedpoint_multiplier);
    dst.result_exponent = Transpose(src.result_exponent);
    dst.result_offset_after_shift = src.result_offset_after_shift;
    return dst;
  }
};

template <typename VectorMapType>
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
  typedef OutputStageBiasAddition<VectorMapType> SrcType;
  typedef TransposeType<VectorMapType> TransposedVectorMapType;
  typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
  static DstType Run(const SrcType& src) {
    DstType dst;
    dst.bias_vector = Transpose(src.bias_vector);
    return dst;
  }
};

// TODO(benoitjacob) - does anyone understand C++ variadic templates?
// How to use them to implement TransposeTuple? Note: there are lots
// of answers on StackOverflow but they seem to all involve either
// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }

template <typename T0>
std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)));
}

template <typename T0, typename T1>
std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
    const std::tuple<T0, T1>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
}

template <typename T0, typename T1, typename T2>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
TransposeTuple(const std::tuple<T0, T1, T2>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
                         Transpose(std::get<2>(t)));
}

template <typename T0, typename T1, typename T2, typename T3>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
           TransposeType<T3>>
TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
}

template <typename T0, typename T1, typename T2, typename T3, typename T4>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
           TransposeType<T3>, TransposeType<T4>>
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
                         Transpose(std::get<4>(t)));
}

template <typename T0, typename T1, typename T2, typename T3, typename T4,
          typename T5>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
           TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
                         Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
}

template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
          typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
          typename GemmContextType>
void DispatchGemmShape(GemmContextType* context,
                       const MatrixMap<const InputScalar, LhsOrder>& lhs,
                       const MatrixMap<const InputScalar, RhsOrder>& rhs,
                       MatrixMap<OutputScalar, ResultOrder>* result,
                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                       const OutputPipelineType& output_pipeline) {
  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  if (rows == 0 || cols == 0 || depth == 0) {
    // Vacuous GEMM, return early to avoid having to deal with
    // zero sizes below.
    return;
  }

  if (rows < cols) {
    auto transposed_result_map = Transpose(*result);
    return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
        context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
        Transpose(rhs_offset), Transpose(lhs_offset),
        TransposeTuple(output_pipeline));
  }

  typedef DefaultKernel<BitDepthParams> Kernel;

  //MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
  //                BitDepthParams>(context, Kernel(), lhs, rhs, result,
  //                                lhs_offset, rhs_offset, output_pipeline);
  SingleThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
                  BitDepthParams>(context, Kernel(), lhs, rhs, result,
                                  lhs_offset, rhs_offset, output_pipeline);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
