// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
#define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H

namespace Eigen {

namespace internal {


// perform a pseudo in-place sparse * sparse product assuming all matrices are col major
template<typename Lhs, typename Rhs, typename ResultType>
static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
{
  // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);

  typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
  typedef typename remove_all<ResultType>::type::Scalar ResScalar;
  typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;

  // make sure to call innerSize/outerSize since we fake the storage order.
  Index rows = lhs.innerSize();
  Index cols = rhs.outerSize();
  //Index size = lhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  // allocate a temporary buffer
  AmbiVector<ResScalar,StorageIndex> tempVector(rows);

  // mimics a resizeByInnerOuter:
  if(ResultType::IsRowMajor)
    res.resize(cols, rows);
  else
    res.resize(rows, cols);

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  // estimate the number of non zero entries
  // given a rhs column containing Y non zeros, we assume that the respective Y columns
  // of the lhs differs in average of one non zeros, thus the number of non zeros for
  // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
  // per column of the lhs.
  // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
  Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();

  res.reserve(estimated_nnz_prod);
  double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
  for (Index j=0; j<cols; ++j)
  {
    // FIXME:
    //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
    // let's do a more accurate determination of the nnz ratio for the current column j of res
    tempVector.init(ratioColRes);
    tempVector.setZero();
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
    {
      // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
      tempVector.restart();
      RhsScalar x = rhsIt.value();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
      {
        tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
      }
    }
    res.startVec(j);
    for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
      res.insertBackByOuterInner(j,it.index()) = it.value();
  }
  res.finalize();
}

template<typename Lhs, typename Rhs, typename ResultType,
  int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
  int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
  int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
struct sparse_sparse_product_with_pruning_selector;

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
{
  typedef typename ResultType::RealScalar RealScalar;

  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typename remove_all<ResultType>::type _res(res.rows(), res.cols());
    internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
    res.swap(_res);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    // we need a col-major matrix to hold the result
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
    SparseTemporaryType _res(res.rows(), res.cols());
    internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
    res = _res;
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    // let's transpose the product to get a column x column product
    typename remove_all<ResultType>::type _res(res.rows(), res.cols());
    internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
    res.swap(_res);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
    typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
    ColMajorMatrixLhs colLhs(lhs);
    ColMajorMatrixRhs colRhs(rhs);
    internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);

    // let's transpose the product to get a column x column product
//     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
//     SparseTemporaryType _res(res.cols(), res.rows());
//     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
//     res = _res.transpose();
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
    RowMajorMatrixLhs rowLhs(lhs);
    sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
    RowMajorMatrixRhs rowRhs(rhs);
    sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
    ColMajorMatrixRhs colRhs(rhs);
    internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
{
  typedef typename ResultType::RealScalar RealScalar;
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  {
    typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
    ColMajorMatrixLhs colLhs(lhs);
    internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
