// AFW implementation
#include "types.h"
#include "fw.impl.h"
#include "active_set.impl.h"
#include "../oracle/lmo.impl.h"

// external interface
#include "fwknapcut.h"
#include "../timer.hpp"

#include <stdexcept>


//! Run the away-step Frank-Wolfe (AFW) algorithm with a knapsack linear minimization oracle in order to
//! generate a cut that separates @p xsep.
//!
//! AFW minimizes f(x) = 0.5 * ||x - xsep||^2 using vertices supplied by KnapsackOracle (constructed from
//! @p weights and @p capacity). If the resulting distance violation (solution.norm_viol) is positive, a cut
//! a^T x <= b is built from the final iterate x and the final vertex v:
//!   a = -grad f(x) = xsep - x,   b = a^T v,   frac_viol = a^T xsep - b.
//!
//! @return pointer to an FWTuple allocated with malloc(); tpl->a is a malloc()'ed array of @p dim entries.
//!         Ownership passes to the caller (presumably released with free() on tpl->a and tpl — confirm
//!         against the caller). Returns nullptr if memory allocation fails.
//!         Fields set: success (1 iff a cut with frac_viol > 0 was produced), frac_viol, a, b,
//!         noraclecalls, niters, exitcode. If no cut is produced, a is all zeros and b is 0.
FWTuple* scalar_afw_knap_gen_cut(
   SCIP*                        scip,        //!< scip instance
   const int                    dim,         //!< dimension of knapsack
   const SCIP_Longint*          weights,     //!< integer weights of knapsack (size dim)
   const SCIP_Longint           capacity,    //!< capacity of knapsack
   const double*                xsep,        //!< point to be separated (size dim)
   const double*                xint,        //!< relative interior point (size dim)
   LC_CLOCK*                    fwclock,     //!< clock for measuring total FW time
   LC_CLOCK*                    oracleclock, //!< clock for measuring oracle time
   double                       primallimit, //!< stop FW if primal value falls below this limit
   uint64_t                     max_iters    //!< maximal number of iterations
   )
{
   // parameters for AFW
   using ix_t = uint64_t;
   using val_t = double;
   using lmo_t = KnapsackOracle<ix_t, val_t>;

   const ix_t step_print = 100;

   // copy the input arrays into vectors owned by this function
   vec_t<SCIP_Longint> Weights(dim);
   vec_t<val_t> Xsep(dim);
   vec_t<val_t> Xint(dim);
   for (int i = 0; i < dim; ++i)
   {
      Weights[i] = weights[i];
      Xsep[i] = xsep[i];
      Xint[i] = xint[i];
   }

   // objective function: half the squared Euclidean distance to Xsep
   fun_t<ix_t, val_t> lc_obj_fun = [&Xsep](const vec_t<val_t>& x)
   {
      return 0.5 * (x - Xsep).squaredNorm();
   };

   // gradient of the objective function, written into the output argument
   ip_fun_t<ix_t, val_t> lc_gradient = [&Xsep](vec_t<val_t>& gradient, const vec_t<val_t>& x)
   {
      gradient = x - Xsep;
   };

   // allocate FWTuple; both allocations are handed to the caller
   FWTuple* tpl = static_cast<FWTuple*>(malloc(sizeof(FWTuple)));
   if ( tpl == nullptr )
      return nullptr;

   tpl->a = static_cast<val_t*>(malloc(static_cast<size_t>(dim) * sizeof(val_t)));
   // malloc(0) may legitimately return NULL, so only treat NULL as a failure when dim > 0
   if ( tpl->a == nullptr && dim > 0 )
   {
      free(tpl);
      return nullptr;
   }

   // initialize every output field so the tuple is fully defined even if no cut is produced
   for (int i = 0; i < dim; ++i)
      tpl->a[i] = 0.0;
   tpl->b = 0.0;
   tpl->success = 0;
   tpl->frac_viol = 0.0;
   tpl->noraclecalls = 0;
   tpl->niters = 0;
   tpl->exitcode = exit_not_run;

   // set up algorithms
   lmo_t lmo(scip, dim, Weights, capacity);
   AwayStepFrankWolfe<ix_t, val_t, lmo_t> afw(dim, lc_obj_fun, lc_gradient, lmo, false); /*lint !e732*/

   // run FW; the success flag is only inspected by the debug assertions below
   typename AwayStepFrankWolfe<ix_t, val_t, lmo_t>::solution_info_t solution;
   const bool afw_success = afw(Xsep, Xint, fwclock, oracleclock, solution, max_iters, primallimit, step_print);
   (void) afw_success;

   if ( solution.norm_viol > 0.0 )
   {
      vec_t<val_t> gradient(dim);
      const vec_t<val_t>& x = afw.get_solution();
      lc_gradient(gradient, x);

      // 1. cut is ax <= b -> violation is ax - b (LHS > RHS for violation)
      // 2. cut evaluated at fractional point xsep
      vec_t<val_t> a = -gradient;
      const vec_t<val_t>& v = afw.get_vertex();
      val_t b = a.dot(v);

      val_t frac_viol = a.dot(Xsep) - b;

#ifndef NDEBUG
      // the following is the rhs from the stopping criterion; b should be smaller, because we move towards the polytope
      val_t termcritrhs = lc_obj_fun(x) - gradient.dot(x) + static_cast<val_t>(1e-6);
      assert( ! afw_success || b <= termcritrhs );

      // double check feasibility:
      // argmax_v < \nabla f(x_t), x_t - v> = argmin < \nabla f(x_t), v>
      // we solve min \nabla f(x_t) with the lmo as lmo(\nabla f(x_t))
      vec_t<val_t> min_dir = gradient;
      vec_t<val_t> V(gradient.rows());
      val_t optval = - lmo(-min_dir, V); // V = argmin; "-" because we are minimizing here

      // we should get the same vertex
      assert( v == V );

      // cut evaluated at v
      const val_t lhs = a.dot(V);
      const val_t lhs_xtilde = a.dot(Xsep);

      assert( REALABS(-optval - b) <= 1e-6 );
      assert( REALABS(lhs - b) <= 1e-6 );

      // V check: cut has to be valid for all points in the subpolytope, in particular for the maximum (here:
      // minimum due to sign change)
      assert( lhs <= termcritrhs || solution.exitcode == exit_max_iters );

      // Xtilde check: Xtilde has to be cut off
      assert( ! (lhs_xtilde <= lhs) || solution.exitcode == exit_max_iters );
#endif

      // return data
      for (int i = 0; i < dim; ++i)
         tpl->a[i] = a[i];
      tpl->b = b;
      tpl->frac_viol = frac_viol;
      if ( frac_viol > 0.0 )
         tpl->success = 1;
   }
   tpl->noraclecalls = solution.noraclecalls;
   tpl->niters = solution.niters;
   tpl->exitcode = solution.exitcode;

   return tpl;
}
