#ifndef _SCALAR_FW_FW_IMPL_H_
#define _SCALAR_FW_FW_IMPL_H_

#include <limits>
#include <iomanip>

#include "fw.h"
#include "../oracle/lmo.h"

/** constructor
 *
 *  Only forwards the arguments into the corresponding members; the oracle call counter starts at 0 and both the
 *  vertex buffer and the active set are constructed from @p dim. No function or oracle is evaluated here.
 */
template<typename ix_t, typename val_t, class lmo_t>
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::AwayStepFrankWolfe(
   const ix_t              dim,             //!< dimension of the points handled by the solver
   fun_t<ix_t, val_t>&     obj_fun,         //!< objective function
   ip_fun_t<ix_t, val_t>&  grad_fun,        //!< gradient function
   lmo_t&                  lmo,             //!< linear minimization oracle
   bool                    verbose,         //!< whether an iteration log is printed to std::cout
   val_t                   epsilon          //!< numerical tolerance used in the comparisons of the solver
   )
   : m_dim(dim),
     m_obj_fun(obj_fun),
     m_grad_fun(grad_fun),
     m_lmo(lmo),
     m_lmo_calls(0),
     m_v(dim),
     m_as(dim),
     m_verbose(verbose),
     m_epsilon(epsilon)
{
   // nothing else to do: all state is set up by the member initializers above
}


//! destructor (no manual cleanup needed, the members release themselves)
template<typename ix_t, typename val_t, class lmo_t>
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::~AwayStepFrankWolfe() = default;

/** main solution method
 *
 *  Runs the lazified away-step Frank-Wolfe iteration. The starting point is the vertex returned by the oracle for
 *  the gradient evaluated at @p xint. Each iteration calls lazy_afw_step() to select a lazy, away, Frank-Wolfe or
 *  dual step and then performs the corresponding update of the active set.
 *
 *  The loop is left as soon as one of the following happens:
 *   - the norm of the gradient falls below the tolerance (exit code exit_zero_gradient, returns false),
 *   - the membership-based termination test succeeds after a Frank-Wolfe or dual step (exit_term_check),
 *   - the primal value falls below @p primallimit (exit_primal_gap),
 *   - @p max_iters iterations have been performed (exit_max_iters).
 *
 *  If the dual gap at the starting vertex is already below the tolerance, no iteration is performed
 *  (exit code exit_early, returns true).
 *
 *  On return, @p result holds primal value, dual gap, normalized violation, exit code, number of oracle calls and
 *  number of iterations. The solution is available via get_solution() and an optimal vertex via get_vertex().
 *
 *  Returns true if a cutting plane could be produced. Apart from the two special exits above, this is the case iff
 *  the loop stopped before @p max_iters iterations and the final primal value is at least @p primallimit.
 */
template<typename ix_t, typename val_t, class lmo_t>
bool
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::operator()(
   const vec_t<val_t>& xsep,                //!< (fractional) point to be separated
   const vec_t<val_t>& xint,                //!< relative interior point
   LC_CLOCK*           fwclock,             //!< clock for measuring total FW time
   LC_CLOCK*           oracleclock,         //!< clock for measuring oracle time
   AwayStepFrankWolfe<ix_t, val_t, lmo_t>::solution_info_t& result, //!< result data
   const ix_t max_iters,                    //!< maximal number of iterations
   const val_t primallimit,                 //!< stop FW if primal value falls below this limit
   const ix_t print_ln                      //!< output frequency (0 disables the per-iteration output)
   )
{
   assert( xsep.size() == (unsigned) m_dim );
   assert( xint.size() == (unsigned) m_dim );

   if ( m_verbose )
   {
      std::cout << std::endl;
      std::cout << "-----------------------------------------------------------------------------------------------------------------------------" << std::endl;
      std::cout << std::setw(6) << "Type" <<
         std::setw(14) << "Iteration" <<
         std::setw(15) << "Primal" <<
         std::setw(15) << "Dual" <<
         std::setw(15) << "Dual Gap" <<
         std::setw(15) << "Violation" <<
         std::setw(15) << "Time" <<
         std::setw(15) << "It/Sec" <<
         std::setw(15) << "#ActiveSet" <<
         std::endl;
      std::cout << "-----------------------------------------------------------------------------------------------------------------------------" << std::endl;
   }

   // the clock may already carry time from earlier runs; remember the offset so that only this run is reported
   double starttime = LC_clockGetTime(fwclock);
   LC_clockStart(fwclock);

   // initialization
   m_lmo_calls = 0;

   // compute gradient at relative interior point
   vec_t<val_t> gradient(m_dim);
   m_grad_fun(gradient, xint);

   // compute starting point x0: vertex for optimizing the current gradient
   LC_clockStart(oracleclock);
   (void) m_lmo(-gradient, m_v);   // negative gradient because we are minimizing here
   ++m_lmo_calls;
   LC_clockStop(oracleclock);

   // store vertex
   vec_t<val_t> v = m_v;
   assert( v.size() == (unsigned) m_dim );

   // initialize solution with vertex
   m_as.initialize(v); // this is the starting point x0

   // early exit if we are optimal
   val_t dual_gap = gradient.dot(xint - v);   // val_t (not a fixed floating-point type) to match the template
   assert( dual_gap >= - m_epsilon ); // dual_gap should be nonnegative since v minimizes gradient
   if ( fabs(dual_gap) < m_epsilon )
   {
      LC_clockStop(fwclock);

      // store result
      result.primal = m_obj_fun(v);
      result.dual_gap = dual_gap;
      result.norm_viol = calc_norm_viol(gradient, v, xsep);
      result.exitcode = exit_early;
      result.noraclecalls = (int) m_lmo_calls;
      result.niters = 0;
      assert( m_as.x().size() == (unsigned) m_dim );

      if ( m_verbose )
      {
         std::cout << std::setw(6) << "Last" <<
            std::setw(14) << 0 <<
            std::setw(15) << std::scientific << std::setprecision(6) << result.primal <<
            std::setw(15) << std::scientific << std::setprecision(6) << 0.0 <<
            std::setw(15) << std::scientific << std::setprecision(6) << 0.0 <<
            std::setw(15) << std::scientific << std::setprecision(6) << result.norm_viol <<
            std::setw(15) << std::fixed << std::setprecision(6) << LC_clockGetTime(fwclock) - starttime <<
            std::setw(15) << std::fixed << std::setprecision(6) << 0.0 <<
            std::setw(15) << m_as.size() << std::endl;

         std::cout << "-----------------------------------------------------------------------------------------------------------------------------" << std::endl;
         std::cout << "Early exit." << std::endl;
      }
      return true;
   }

   // init FW parameters
   val_t phi = dual_gap;                                  // progress threshold, only lowered by dual steps
   val_t norm_viol = calc_norm_viol(gradient, v, xsep);
   val_t K = 2.0;                                         // a step must promise a gap of at least phi / K

   val_t primal;
   val_t dual;

   // set when the membership-based termination test fires; an explicit flag avoids deriving the exit reason
   // from values that may never have been computed
   bool term_check_hit = false;

   // step kind chosen by lazy_afw_step(); exactly one of them is set after each call
   bool away_step_taken = false;
   bool fw_step_taken = false;
   bool lazy_step_taken = false;
   bool dual_step_taken = false;

   vec_t<val_t> d(m_dim);
   val_t gamma = 1.0;
   val_t gamma_max;

   ix_t iters = 0;

   // start iterations; current point is m_as.x()
   do
   {
      // compute current gradient
      m_grad_fun(gradient, m_as.x());

      // compute FW step from current iterate m_as.x() with current gradient
      lazy_afw_step(m_as.x(), xsep, gradient, oracleclock, v, phi, norm_viol, d, gamma_max, away_step_taken, fw_step_taken, lazy_step_taken, dual_step_taken, K);

      // separate check in case gradient is exactly 0, which means that the point is exactly on the boundary
      if ( fabs(gradient.norm()) < m_epsilon )
      {
         LC_clockStop(fwclock);

         // store result
         result.primal = 0.0;
         assert( m_obj_fun(m_as.x()) <= 1e-6 );
         result.dual_gap = 0.0;
         result.norm_viol = 0.0;
         result.exitcode = exit_zero_gradient;
         result.noraclecalls = (int) m_lmo_calls;
         result.niters = (int) iters;
         assert( m_as.x().size() == (unsigned) m_dim );

         if ( m_verbose )
         {
            double t = LC_clockGetTime(fwclock) - starttime;
            std::cout << std::setw(6) << "Last" <<
               std::setw(14) << iters <<
               std::setw(15) << std::scientific << std::setprecision(6) << 0.0 <<
               std::setw(15) << std::scientific << std::setprecision(6) << 0.0 <<
               std::setw(15) << std::scientific << std::setprecision(6) << 0.0 <<
               std::setw(15) << std::scientific << std::setprecision(6) << calc_norm_viol(gradient, v, xsep) <<
               std::setw(15) << std::fixed << std::setprecision(6) << t <<
               std::setw(15) << std::fixed << std::setprecision(6) << (iters + 1) / t <<
               std::setw(15) << m_as.size() <<
               std::endl;

            std::cout << "-----------------------------------------------------------------------------------------------------------------------------" << std::endl;
            std::cout << "Exit - gradient is 0." << std::endl;
         }
         return false;
      }

      // recompute early termination criterion; only FW and dual steps provide a fresh oracle vertex v
      if ( fw_step_taken || dual_step_taken )
      {
         const val_t membership_lhs = - v.dot(gradient);
         const val_t membership_rhs = m_obj_fun(m_as.x()) - (m_as.x() - xsep).dot(m_as.x());

         // test termination criterion
         if ( membership_lhs < membership_rhs )
         {
            term_check_hit = true;
            break;
         }
      }

      // compute gaps
      primal = m_obj_fun(m_as.x());
      dual_gap = phi;
      dual = primal - dual_gap;

      // in addition to the termination criteria below, we stop if the primal gap is below primallimit
      if ( primal < primallimit )
         break;

      // dual steps only update phi and leave the iterate untouched
      if ( fw_step_taken || away_step_taken || lazy_step_taken )
      {
         // perform line search: closed-form step length along d, capped by the maximal feasible step
         assert( fabs(d.squaredNorm()) >= m_epsilon );
         gamma = gradient.dot(d) / d.squaredNorm();
         gamma = std::min(gamma, gamma_max);

         // update current iterate point; away steps move weight off v, hence the negative step length
         if ( away_step_taken )
         {
            m_as.update(std::make_pair(-gamma, v), true);
         }
         else
         {
            m_as.update(std::make_pair(gamma, v), true);
         }

         // the following is probably not needed, but we leave it for numerical safety for now
         m_as.renormalize(); /*lint !e523*/
         m_as.cleanup();
         m_as.update_x();
      }

      // print_ln == 0 switches the per-iteration output off (and must not reach the modulo below)
      if ( m_verbose && print_ln != 0 && (iters % print_ln == 0) )
      {
         std::string step_kind = "FW";
         if ( away_step_taken )
            step_kind = "A";
         if ( lazy_step_taken )
            step_kind = "L";
         if ( dual_step_taken )
            step_kind = "LD";

         double t = LC_clockGetTime(fwclock) - starttime;
         std::cout << std::setw(6) << step_kind <<
            std::setw(14) << iters <<
            std::setw(15) << std::scientific << std::setprecision(6) << primal <<
            std::setw(15) << std::scientific << std::setprecision(6) << dual <<
            std::setw(15) << std::scientific << std::setprecision(6) << dual_gap <<
            std::setw(15) << std::scientific << std::setprecision(6) << norm_viol <<
            std::setw(15) << std::fixed << std::setprecision(6) << t <<
            std::setw(15) << std::fixed << std::setprecision(6) << (iters + 1) / t <<
            std::setw(15) << m_as.size() <<
            std::endl;
      }

      ++iters;
   }
   while( iters < max_iters );

   // recompute everything for final verification
   m_as.renormalize(); /*lint !e523*/
   m_as.cleanup();
   m_as.update_x();

   m_grad_fun(gradient, m_as.x());

   LC_clockStart(oracleclock);
   (void) m_lmo(-gradient, m_v);   // negative gradient because we are minimizing here
   ++m_lmo_calls;
   LC_clockStop(oracleclock);

   v = m_v;
   primal = m_obj_fun(m_as.x());
   dual_gap = m_as.x().dot(gradient) - v.dot(gradient);
   dual = primal - dual_gap;
   norm_viol = calc_norm_viol(gradient, v, xsep);

   // store result
   result.primal = primal;
   result.dual_gap = dual_gap;
   result.norm_viol = norm_viol;
   if ( term_check_hit )
      result.exitcode = exit_term_check;
   else if ( primal < primallimit )
      result.exitcode = exit_primal_gap;
   else
   {
      assert( iters >= max_iters );
      result.exitcode = exit_max_iters;
   }
   result.noraclecalls = (int) m_lmo_calls;
   result.niters = (int) iters;
   assert( m_as.x().size() == (unsigned) m_dim );

   LC_clockStop(fwclock);
   if (m_verbose)
   {
      double t = LC_clockGetTime(fwclock) - starttime;
      std::cout << std::setw(6) << "Last" <<
         std::setw(14) << iters <<
         std::setw(15) << std::scientific << std::setprecision(6) << primal <<
         std::setw(15) << std::scientific << std::setprecision(6) << dual <<
         std::setw(15) << std::scientific << std::setprecision(6) << dual_gap <<
         std::setw(15) << std::scientific << std::setprecision(6) << norm_viol <<
         std::setw(15) << std::fixed << std::setprecision(6) << t <<
         std::setw(15) << std::fixed << std::setprecision(6) << iters / t <<
         std::setw(15) << m_as.size() << std::endl;
      std::cout << "-----------------------------------------------------------------------------------------------------------------------------" << std::endl;
   }

   if (iters < max_iters && m_verbose)
      std::cout << "Exit after " << iters << " Iterations." << std::endl;

   return (iters < max_iters && primal >= primallimit );
}

//! return number of calls to linear optimization oracle (the counter is reset at the start of every operator() run)
template<typename ix_t, typename val_t, class lmo_t>
ix_t
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::lmo_calls()
{
   return this->m_lmo_calls;
}

//! get best vertex of last run, i.e., the buffer that was handed to the oracle in its most recent call
template<typename ix_t, typename val_t, class lmo_t>
const vec_t<val_t>&
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::get_vertex()
{
   return this->m_v;
}

//! return solution, i.e., the current point maintained by the active set
template<typename ix_t, typename val_t, class lmo_t>
const vec_t<val_t>&
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::get_solution()
{
   return this->m_as.x();
}

/** perform main step
 *
 *  Selects the kind of step for the current iterate and reports it through exactly one of the four flags:
 *   - lazy step: move towards the active-set vertex minimizing the gradient, without calling the oracle;
 *   - away step: move away from the active-set vertex maximizing the gradient;
 *   - FW step:   call the oracle and move towards the returned vertex;
 *   - dual step: call the oracle, keep the iterate and decrease @p phi.
 *
 *  A lazy or away step is only chosen if its gap is at least phi / K; the oracle is called only if neither
 *  qualifies. The direction is returned in @p d, the maximal step length in @p gamma_max and the vertex belonging
 *  to the step in @p v. @p norm_viol is only recomputed when the oracle was called.
 */
template<typename ix_t, typename val_t, class lmo_t>
void
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::lazy_afw_step(
   const vec_t<val_t>& x,                //!< current iterate
   const vec_t<val_t>& xsep,             //!< point to be separated
   const vec_t<val_t>& gradient,         //!< current gradient
   LC_CLOCK*           oracleclock,      //!< clock for measuring oracle time
   vec_t<val_t>&       v,                //!< current vertex for minimizing gradient
   val_t&              phi,              //!< phi value
   val_t&              norm_viol,        //!< violation of point
   vec_t<val_t>&       d,                //!< new direction
   val_t&              gamma_max,        //!< steplength
   bool&               away_step_taken,  //!< whether an away step was taken
   bool&               fw_step_taken,    //!< whether a FW step was taken
   bool&               lazy_step_taken,  //!< whether a lazy step was taken
   bool&               dual_step_taken,  //!< whether a dual step was taken
   const val_t K                         //!< divisor applied to phi to obtain the progress threshold
   )
{
   using entry_t = typename ActiveSet<ix_t, val_t>::entry_t;

   // active-set entries (weight, vertex) with minimal and maximal inner product with the gradient
   const std::pair<entry_t, entry_t> extremes = m_as.argminmax(gradient);

   const vec_t<val_t>& lazy_vertex = extremes.first.second;
   const val_t away_weight = extremes.second.first;
   assert( away_weight > 0.0 );
   const vec_t<val_t>& away_vertex = extremes.second.second;

   const val_t grad_dot_lazy = lazy_vertex.dot(gradient);
   const val_t grad_dot_x = x.dot(gradient);
   const val_t grad_dot_away = away_vertex.dot(gradient);

   // progress promised by the two oracle-free steps and the amount required to accept a step
   const val_t lazy_gap = grad_dot_x - grad_dot_lazy;
   const val_t away_gap = grad_dot_away - grad_dot_x;
   const val_t threshold = phi / K;

#if 0
   std::cout << "phi: " << phi << std::endl;
   std::cout << "gradient: " << gradient.transpose() << std::endl;
   std::cout << "x: " << x.transpose() << std::endl;
   std::cout << "lv: " << lazy_vertex.transpose() << std::endl;
   std::cout << "a: " << away_vertex.transpose() << std::endl;
#endif

   away_step_taken = false;
   fw_step_taken = false;
   lazy_step_taken = false;
   dual_step_taken = false;

   // lazy step: the best active vertex is at least as good as the away direction and promises enough progress
   if ( lazy_gap >= away_gap && lazy_gap >= threshold )
   {
      gamma_max = 1.0;
      d = x - lazy_vertex;
      v = lazy_vertex;
      lazy_step_taken = true;
      return;   // norm_viol and phi stay unchanged
   }

   // away step: moving away from the worst active vertex promises enough progress
   if ( away_gap > lazy_gap && away_gap >= threshold )
   {
      gamma_max = away_weight / (1.0 - away_weight);
      d = away_vertex - x;
      v = away_vertex;
      away_step_taken = true;
      return;   // norm_viol and phi stay unchanged
   }

   // neither oracle-free step qualifies: ask the oracle for a new vertex
   LC_clockStart(oracleclock);
   (void) m_lmo(-gradient, m_v);   // negative gradient because we are minimizing here
   ++m_lmo_calls;
   LC_clockStop(oracleclock);

   // save current vertex
   v = m_v;
   assert( v.size() == (unsigned) m_dim );

   // check whether dual gap promises enough progress
   const val_t fw_gap = grad_dot_x - v.dot(gradient);

   if ( fw_gap >= threshold )
   {
      // regular FW step towards the oracle vertex; phi stays unchanged
      gamma_max = 1.0;
      d = x - v;
      fw_step_taken = true;
   }
   else
   {
      // dual step: no movement, only tighten the progress threshold
      phi = std::min(fw_gap, phi / 2.0);
      gamma_max = 0.0;
      d = vec_t<val_t>::Zero(x.rows());
      dual_step_taken = true;
   }

   // the oracle produced a new vertex, so the violation is refreshed for both step kinds
   norm_viol = calc_norm_viol(gradient, v, xsep);
}

/** compute normalized violation of cut
 *
 *  Returns (v * gradient - xsep * gradient) / ||gradient||, using the Euclidean norm of the gradient.
 *  If the norm of the gradient is exactly 0, the violation is defined to be 0.
 */
template<typename ix_t, typename val_t, class lmo_t>
val_t
AwayStepFrankWolfe<ix_t, val_t, lmo_t>::calc_norm_viol(
   const vec_t<val_t>& gradient,        //!< gradient, i.e., normal vector of cut
   const vec_t<val_t>& v,               //!< optimal vertex
   const vec_t<val_t>& xsep             //!< point to be separated
   )
{
   /* Euclidean norm by default. */
   const val_t grad_norm = gradient.norm();

   if ( grad_norm != 0.0 )
   {
      const val_t grad_dot_xsep = xsep.dot(gradient);
      const val_t grad_dot_v = v.dot(gradient);

      return - (grad_dot_xsep - grad_dot_v) / grad_norm;
   }

   // avoid dividing by zero for a vanishing gradient
   return 0.0;
}

#endif // _SCALAR_FW_FW_IMPL_H_
