#ifndef _GLOBAL_TRAJ_OPT_H_
#define _GLOBAL_TRAJ_OPT_H_

#include <Eigen/Eigen>
#include <ros/ros.h>

#include "poly_traj_utils.hpp"
#include "lbfgsme.hpp"
#include <plan_env/global_map.hpp>
namespace global_traj
{

  using namespace std;
  template <typename gridmap>

  class GlobalTrajOptimizer
  {

    private:
    int  piecenum;
    double gslar_ = 0.0;
    // double esdfWei = 1000000.0;
    // double penaWei = 100000.0;
    double esdfWei = 10000.0;
    double penaWei = 1000.0;
    int traj_res = 10;

    double safeMargin = 0.2;

    double length_per_piece = 1.5;


    gridmap* map_itf_;
    Eigen::MatrixXd headState_, tailState_;
    double collision_cost =  0.0;

  public:
    double vmax = 3.0;
    double amax = 3.0;
    poly_traj::MinJerkOpt jerkOpt_;
    double wei_time_ = 1000.0;
    
    // Constructs an optimizer with an empty body: members that have no in-class
    // initializer (piecenum, map_itf_) stay unset until init() / Optimize*() assign them.
    GlobalTrajOptimizer() {}
    // Empty destructor: nothing is released here, in particular map_itf_ is not deleted.
    ~GlobalTrajOptimizer() {}
    /**
     * @brief Load tuning parameters from the ROS parameter server and attach the map.
     *
     * @param nh       Node handle the parameters are read from.
     * @param env      Map object stored in map_itf_ and later queried for distances.
     * @param isglobal Not used by this function.
     */
    void init(ros::NodeHandle& nh, gridmap* env, bool isglobal = false)
    {
      map_itf_ = env;

      // Dynamic limits and time weight.
      nh.param("manager/max_vel", vmax, 3.0);
      nh.param("manager/max_acc", amax, 3.0);
      nh.param("optimization/weight_time", wei_time_, 500.0);

      // Target arc length of one polynomial piece.
      nh.param("manager/polyTraj_piece_length", length_per_piece, 1.0);

      // Clearance = configured obstacle inflation plus a fixed 0.05 offset.
      double inflation = 0.0;
      nh.param("grid_map/obstacles_inflation", inflation, 0.0);
      safeMargin = inflation + 0.05;
    }

    /* main planning API */
    /**
     * @brief Optimize a global trajectory seeded from a reference polyline.
     *
     * The polyline is split into `piecenum` pieces of equal arc length
     * (roughly `length_per_piece` each, at least 2 pieces). The interior
     * waypoints are interpolated on the polyline at the piece boundaries and,
     * together with a uniform time allocation, form the initial guess handed
     * to the L-BFGS solver. The optimized trajectory is left in `jerkOpt_`.
     *
     * @param headState    Boundary state at the start, forwarded to jerkOpt_.reset().
     * @param tailState    Boundary state at the end, forwarded to jerkOpt_.reset().
     * @param initial_path Ordered points of the reference path.
     * @return false (0) when the input is degenerate (fewer than two points,
     *         zero path length or a non-positive piece length); true (1)
     *         once the solver has been run, whatever its status code.
     */
    inline int OptimizeGlobalTrajectory(
            const Eigen::MatrixXd &headState, 
            const Eigen::MatrixXd &tailState, 
            std::vector<Eigen::Vector3d> initial_path){
      headState_ = headState;
      tailState_ = tailState;

      // size() is unsigned: without this guard `size() - 1` wraps around on an
      // empty path and the loops below would read far out of bounds.
      if (initial_path.size() < 2) {
        ROS_WARN("Global trajectory optimization needs at least two path points. Skip this planning.");
        return false;
      }
      const std::size_t segNum = initial_path.size() - 1;

      double totalLength = 0.0;
      for (std::size_t i = 0; i < segNum; ++i) {
        totalLength += (initial_path[i + 1] - initial_path[i]).norm();
      }
      // A zero-length path (or a non-positive piece length) would produce
      // zero / infinite durations and NaN virtual times further down.
      if (!(totalLength > 0.0) || !(length_per_piece > 0.0)) {
        ROS_WARN("Global trajectory optimization got a degenerate path or piece length. Skip this planning.");
        return false;
      }

      piecenum = std::max(int(totalLength / length_per_piece), 2);
      const double tmpl = totalLength / piecenum; // arc length of one piece

      // Interior waypoints: one at every multiple of tmpl along the polyline.
      Eigen::MatrixXd wps(3, piecenum - 1);
      double curArc = 0.0; // arc length at the start of the current segment
      int index = 0;       // next waypoint column to fill
      for (std::size_t i = 0; i < segNum && index < piecenum - 1; ++i) {
        const Eigen::Vector3d seg = initial_path[i + 1] - initial_path[i];
        const double segLen = seg.norm();
        // A single long segment may contain several stations, so keep emitting
        // waypoints until the next station lies beyond this segment.
        while (index < piecenum - 1 && curArc + segLen > (index + 1) * tmpl) {
          const double ratio = ((index + 1) * tmpl - curArc) / segLen;
          wps.col(index) = initial_path[i] + ratio * seg;
          ++index;
        }
        curArc += segLen;
      }
      // Safety net against round-off: never leave a column uninitialized.
      for (; index < piecenum - 1; ++index) {
        wps.col(index) = initial_path.back();
      }

      // Uniform initial durations, assuming a cruise speed of vmax / 1.2.
      Eigen::VectorXd initialT(piecenum);
      initialT.setConstant(tmpl / (vmax / 1.2));
      jerkOpt_.reset(headState_, tailState_, piecenum);

      // Decision vector layout: [waypoints (3 x (piecenum-1), column-major) | virtual times].
      const int variable_num_ = 3 * (piecenum - 1) + piecenum;
      Eigen::VectorXd x(variable_num_);
      x.head(wps.size()) = Eigen::Map<const Eigen::VectorXd>(wps.data(), wps.size());
      Eigen::Map<Eigen::VectorXd> Vt(x.data() + wps.size(), initialT.size());
      RealT2VirtualT(initialT, Vt);

      std::cout << "before " << initialT.transpose() << std::endl;

      lbfgs_me::lbfgs_parameter_t lbfgs_params;
      lbfgs_params.mem_size = 256;//128
      lbfgs_params.past = 3; //3 
      lbfgs_params.g_epsilon = 1.0e-4;
      lbfgs_params.min_step = 1.0e-32;
      lbfgs_params.delta = 1.0e-4;
      lbfgs_params.max_iterations = 10000;

      double final_cost;
      const double t1 = ros::Time::now().toSec();
      const int result = lbfgs_me::lbfgs_optimize(
        x,
        final_cost,
        GlobalTrajOptimizer::costFunctionCallback,
        nullptr,
        nullptr,
        this,
        lbfgs_params);
      const double t2 = ros::Time::now().toSec();
      ROS_WARN_STREAM("dftpav planning time: "<<1000.0*(t2-t1)<<" ms");

      /* ---------- report the solver status ---------- */
      if (result == lbfgs_me::LBFGS_CONVERGENCE ||
          result == lbfgs_me::LBFGS_CANCELED ||
          result == lbfgs_me::LBFGS_STOP||result == lbfgs_me::LBFGSERR_MAXIMUMITERATION)
      {
        ROS_WARN_STREAM("dif planner worked cost:"<<final_cost);
        ROS_WARN_STREAM("collision cost:"<<collision_cost);
      }
      else if (result == lbfgs_me::LBFGSERR_MAXIMUMLINESEARCH){
        // The former 1000 s sleep here froze the whole planner; the warning is enough.
        ROS_WARN_STREAM("dif planner worked cost:"<<final_cost);
        ROS_WARN_STREAM("collision cost:"<<collision_cost);
        ROS_WARN("Lbfgs: The line-search routine reaches the maximum number of evaluations.");
      }
      else
      {
        ROS_WARN("Solver error. Return = %d, %s. Skip this planning.", result, lbfgs_me::lbfgs_strerror(result));
      }

      std::cout << "after optimized: "<<jerkOpt_.get_T1().transpose() << std::endl;

      // NOTE(review): a solver error still reports success, as before; confirm
      // whether callers would rather receive false in that case.
      return true;
    }

    /**
     * @brief Re-optimize a trajectory starting from given waypoints and durations.
     *
     * The number of pieces is `init_rts.size()`. The optimized trajectory is
     * left in `jerkOpt_`.
     *
     * @param headState Boundary state at the start, forwarded to jerkOpt_.reset().
     * @param tailState Boundary state at the end, forwarded to jerkOpt_.reset().
     * @param init_wps  Initial interior waypoints; must hold exactly
     *                  3 * (init_rts.size() - 1) coefficients (read column-major
     *                  as a 3 x (pieces - 1) matrix).
     * @param init_rts  Initial real duration of every piece; each must exceed gslar_.
     * @return false (0) when the inputs are inconsistent or when the collision
     *         cost accumulated by the last cost evaluation exceeds 100;
     *         true (1) otherwise.
     */
    inline int OptimizeLocalTrajectory(
            const Eigen::MatrixXd &headState, 
            const Eigen::MatrixXd &tailState, 
            Eigen::MatrixXd init_wps,
            Eigen::VectorXd init_rts){
      const int pieces = init_rts.size();

      // The decision vector is sized from the piece count; a waypoint matrix of
      // any other size would overflow it or leave part of it uninitialized.
      if (pieces < 1 || init_wps.size() != 3 * (pieces - 1)) {
        ROS_WARN("Local trajectory optimization got inconsistent waypoint / duration sizes. Skip this planning.");
        return false;
      }
      // Durations at or below gslar_ (or NaN) map to inf / NaN virtual times.
      if (!(init_rts.array() > gslar_).all()) {
        ROS_WARN("Local trajectory optimization got a non-positive piece duration. Skip this planning.");
        return false;
      }

      headState_ = headState;
      tailState_ = tailState;
      piecenum = pieces;
      jerkOpt_.reset(headState_, tailState_, piecenum);

      // Decision vector layout: [waypoints (column-major) | virtual times].
      const int variable_num_ = 3 * (piecenum - 1) + piecenum;
      Eigen::VectorXd x(variable_num_);
      x.head(init_wps.size()) = Eigen::Map<const Eigen::VectorXd>(init_wps.data(), init_wps.size());
      Eigen::Map<Eigen::VectorXd> Vt(x.data() + init_wps.size(), init_rts.size());
      RealT2VirtualT(init_rts, Vt);

      lbfgs_me::lbfgs_parameter_t lbfgs_params;
      lbfgs_params.mem_size = 256;//128
      lbfgs_params.past = 3; //3 
      lbfgs_params.g_epsilon = 1.0e-4;
      lbfgs_params.min_step = 1.0e-32;
      lbfgs_params.delta = 1.0e-4;
      lbfgs_params.max_iterations = 10000;

      double final_cost;
      const double t1 = ros::Time::now().toSec();
      const int result = lbfgs_me::lbfgs_optimize(
        x,
        final_cost,
        GlobalTrajOptimizer::costFunctionCallback,
        nullptr,
        nullptr,
        this,
        lbfgs_params);
      const double t2 = ros::Time::now().toSec();
      ROS_WARN_STREAM("dftpav planning time: "<<1000.0*(t2-t1)<<" ms");

      /* ---------- report the solver status ---------- */
      if (result == lbfgs_me::LBFGS_CONVERGENCE ||
          result == lbfgs_me::LBFGS_CANCELED ||
          result == lbfgs_me::LBFGS_STOP||result == lbfgs_me::LBFGSERR_MAXIMUMITERATION)
      {
        ROS_WARN_STREAM("dif planner worked cost:"<<final_cost);
        ROS_WARN_STREAM("collision cost:"<<collision_cost);
      }
      else if (result == lbfgs_me::LBFGSERR_MAXIMUMLINESEARCH){
        ROS_WARN("Lbfgs: The line-search routine reaches the maximum number of evaluations.");
      }
      else
      {
        ROS_WARN("Solver error. Return = %d, %s. Skip this planning.", result, lbfgs_me::lbfgs_strerror(result));
      }

      std::cout << "after optimized: "<<jerkOpt_.get_T1().transpose() << std::endl;

      // collision_cost is rewritten by every cost evaluation; the value seen
      // here is the one from the solver's last evaluation.
      if(collision_cost > 100.0){
        return false;
      }
      return true;
    }


  // dftpav::DifTrajectory getOptTraj(){
  //   dftpav::DifTrajectory optTraj;
  //   for(int i = 0; i < trajnum; i++){
  //     dftpav::Trajectory polytraj = jerkOpt_container[i].getTraj(singual_[i]);
  //     optTraj.Traj_container.push_back(polytraj);
  //     optTraj.etas.push_back(singual_[i]);
  //   }
  //   return optTraj;
  // }





  private:

    /**
     * @brief Smoothed one-sided L1 penalty and its derivative.
     *
     * For x >= pe (pe = 1e-4) the penalty is the linear function x - pe/2.
     * For 0 < x < pe it is a quartic polynomial that meets the linear part at
     * x = pe with matching value and slope, and has zero value and slope at 0.
     * For x <= 0 the penalty and its derivative are 0; previously the quartic
     * was evaluated there too, which yields large negative "penalties".
     *
     * @param x  Constraint violation.
     * @param f  Output: penalty value.
     * @param df Output: derivative of the penalty with respect to x.
     */
    void positiveSmoothedL1(const double &x, double &f, double &df)
    {
      const double pe = 1.0e-4;

      if (x <= 0.0)
      {
        // No violation: no cost and no gradient.
        f = 0.0;
        df = 0.0;
        return;
      }

      if (x < pe)
      {
        // Polynomial blend f(x) = f3c * x^3 + f4c * x^4 on (0, pe).
        const double f3c = 1.0 / (pe * pe);
        const double f4c = -0.5 * f3c / pe;
        const double d2c = 3.0 * f3c;
        const double d3c = 4.0 * f4c;
        f = (f4c * x + f3c) * x * x * x;
        df = (d3c * x + d2c) * x * x;
        return;
      }

      // Linear part, shifted so that it is continuous with the blend at pe.
      f = x - 0.5 * pe;
      df = 1.0;
    }
    /**
     * @brief Objective callback handed to the L-BFGS solver.
     *
     * The decision vector x is laid out as
     * [interior waypoints, 3 x (piecenum-1) column-major | piecenum virtual times].
     *
     * @param func_data Pointer to the GlobalTrajOptimizer that owns the problem.
     * @param x         Current decision vector.
     * @param grad      Output: gradient of the cost, same layout as x.
     * @return Sum of the smoothness, penalty and time costs.
     */
    static double costFunctionCallback(void *func_data, const Eigen::VectorXd &x, Eigen::VectorXd &grad){
      GlobalTrajOptimizer *self = static_cast<GlobalTrajOptimizer *>(func_data);
      const int pieces = self->piecenum;
      const int wpCoeffs = 3 * (pieces - 1);

      // Views onto the waypoint part of x / grad.
      Eigen::Map<const Eigen::MatrixXd> wpts(x.data(), 3, pieces - 1);
      Eigen::Map<Eigen::MatrixXd> gradWpts(grad.data(), 3, pieces - 1);
      gradWpts.setZero();

      // Views onto the virtual-time part of x / grad.
      Eigen::Map<const Eigen::VectorXd> virtTimes(x.data() + wpCoeffs, pieces);
      Eigen::Map<Eigen::VectorXd> gradVirtTimes(grad.data() + wpCoeffs, pieces);

      Eigen::VectorXd realTimes(pieces);
      Eigen::VectorXd gradRealTimes = Eigen::VectorXd::Zero(pieces);
      self->VirtualT2RealT(virtTimes, realTimes);
      self->jerkOpt_.generate(wpts, realTimes); // build the trajectory from {P, T}

      double smoothCost = 0.0;
      double penaltyCost = 0.0;
      double timeCost = 0.0;

      self->jerkOpt_.initGradCost(gradRealTimes, smoothCost);   // smoothness term
      self->collision_cost = 0.0;                               // re-accumulated by the penalty pass
      self->addPVAJGradCost2CT(gradRealTimes, penaltyCost);     // sampled penalties
      self->jerkOpt_.getGrad2TP(gradRealTimes, gradWpts);       // propagate to times / waypoints

      // Every duration receives the mean of the duration gradients.
      // NOTE(review): presumably this keeps the piece durations equal - confirm.
      gradRealTimes.setConstant(gradRealTimes.sum() / gradRealTimes.size());

      // Chain rule from real durations back to virtual times, plus the time cost.
      self->VirtualTGradCost(realTimes, virtTimes, gradRealTimes, gradVirtTimes, timeCost);

      return smoothCost + penaltyCost + timeCost;
    }
    /**
     * @brief Add sampled velocity / acceleration / obstacle penalties and their gradients.
     *
     * Every piece i is sampled at traj_res instants s = 0, T_i/traj_res, ...,
     * (traj_res-1) * T_i/traj_res. At each sample the squared-norm velocity and
     * acceleration limits (vmax, amax) and the clearance to obstacles
     * (safeMargin, distance from map_itf_) are penalized. Gradients with respect
     * to the polynomial coefficients are accumulated into jerkOpt_.get_gdC() and
     * the direct gradients with respect to the piece durations into gdT.
     * Obstacle penalties of samples closer than 4.0 to headState_.col(0) are also
     * summed into collision_cost.
     *
     * @param gdT  In/out: gradient with respect to the piece durations.
     * @param cost In/out: accumulated penalty cost.
     */
    void addPVAJGradCost2CT(Eigen::VectorXd &gdT, double& cost){
        Eigen::Vector3d sigma, dsigma, ddsigma, dddsigma, ddddsigma;
        Eigen::Matrix<double, 6, 1> beta0, beta1, beta2, beta3, beta4;
        double s1, s2, s3, s4, s5;
        double step, alpha;
        Eigen::Vector3d gradPos, gradVel, gradAcc, gradJerk;

        for (int i = 0; i < piecenum; ++i)
        {
          const Eigen::Matrix<double, 6, 3> &c = jerkOpt_.get_b().block<6, 3>(i * 6, 0);
          step = jerkOpt_.get_T1()(i) / traj_res; // T_i /k
          s1 = 0.0;
          for (int j = 1; j <= traj_res; ++j)
          {
            // Monomial basis and its derivatives at the sample time s1.
            s2 = s1 * s1;
            s3 = s2 * s1;
            s4 = s2 * s2;
            s5 = s4 * s1;
            beta0 << 1.0, s1, s2, s3, s4, s5;
            beta1 << 0.0, 1.0, 2.0 * s1, 3.0 * s2, 4.0 * s3, 5.0 * s4;
            beta2 << 0.0, 0.0, 2.0, 6.0 * s1, 12.0 * s2, 20.0 * s3;
            beta3 << 0.0, 0.0, 0.0, 6.0, 24.0 * s1, 60.0 * s2;
            beta4 << 0.0, 0.0, 0.0, 0.0, 24.0, 120 * s1;

            // Sample j sits at s1 = (j-1) * T_i / traj_res, so ds1/dT_i is
            // (j-1)/traj_res. Using j/traj_res here made the duration gradient
            // inconsistent with the sampled cost.
            alpha = 1.0 / traj_res * (j - 1);

            // Advance s1 for the next iteration (the betas above are already built).
            s1 += step;

            sigma = c.transpose() * beta0;
            dsigma = c.transpose() * beta1;
            ddsigma = c.transpose() * beta2;
            dddsigma = c.transpose() * beta3;
            ddddsigma = c.transpose() * beta4;
            gradPos.setZero();
            gradVel.setZero();
            gradAcc.setZero();
            gradJerk.setZero();
            Eigen::Vector3d pos,vel,acc,jerk;
            pos = sigma;
            vel = dsigma;
            acc = ddsigma;
            jerk = dddsigma;

            // Velocity limit: penalize |v|^2 - vmax^2 > 0.
            {
              double vioVel = vel.squaredNorm() - vmax * vmax;
              if(vioVel > 0){
                  double pena, penaD;
                  positiveSmoothedL1(vioVel, pena, penaD);
                  cost += penaWei * pena;
                  gradVel += penaWei * penaD *  2.0 * dsigma;
              }
            }

            // Acceleration limit: penalize |a|^2 - amax^2 > 0.
            {
              double vioAcc = acc.squaredNorm() - amax * amax;
              if(vioAcc > 0){
                  double pena, penaD;
                  positiveSmoothedL1(vioAcc, pena, penaD);
                  cost += penaWei * pena;
                  gradAcc += penaWei * penaD * 2.0 * ddsigma;
              }
            }

            // Obstacle clearance: penalize safeMargin - distance > 0.
            {
                  Eigen::Vector3d gradViolaSdpos;
                  double dis = map_itf_->getDistGrad(pos, gradViolaSdpos);
                  double vioSdist = (-dis + safeMargin);
                  if(vioSdist > 0){
                      double pena, penaD;
                      positiveSmoothedL1(vioSdist, pena, penaD);
                      cost += esdfWei * pena;
                      gradPos += esdfWei * penaD * (-1.0) * gradViolaSdpos;
                      // Only samples near the start state count towards collision_cost.
                      if((pos-headState_.col(0)).norm()<4.0){
                        collision_cost+=esdfWei * pena;
                      }
                  }
            }

            // Chain rule to the coefficients of piece i ...
            jerkOpt_.get_gdC().block<6, 3>(i * 6, 0) += beta0 * gradPos.transpose() +
                                              beta1 * gradVel.transpose() +
                                              beta2 * gradAcc.transpose() +
                                              beta3 * gradJerk.transpose();
            // ... and to its duration, through the motion of the sample time.
            gdT(i)  += (gradPos.dot(vel) +
                            gradVel.dot(acc) +
                            gradAcc.dot(jerk) +
                            gradJerk.dot(ddddsigma)) *
                              alpha;
          }
        }
    }







    /**
     * @brief Cubic penalty f = x^3 with derivative df = 3 x^2.
     *
     * @param x  Input value.
     * @param f  Output: x cubed.
     * @param df Output: three times x squared.
     */
    void positiveSmoothedL3(const double &x, double &f, double &df){
          const double squared = x * x;
          f = squared * x;
          df = squared * 3.0;
      }


    /**
     * @brief Map unconstrained virtual times to real durations, element by element.
     *
     * Positive virtual times use 0.5 v^2 + v + 1, the others use
     * 1 / (0.5 v^2 - v + 1); gslar_ is added to either result.
     *
     * @param VT Virtual times.
     * @param RT Output: real durations; must already have VT.size() entries.
     */
    template <typename EIGENVEC>
    void VirtualT2RealT(const EIGENVEC &VT, Eigen::VectorXd &RT)
    {
      for (int k = 0; k < VT.size(); ++k)
      {
        if (VT(k) > 0.0)
        {
          RT(k) = ((0.5 * VT(k) + 1.0) * VT(k) + 1.0) + gslar_;
        }
        else
        {
          RT(k) = 1.0 / ((0.5 * VT(k) - 1.0) * VT(k) + 1.0) + gslar_;
        }
      }
    }
    /**
     * @brief Scalar version of the virtual-time to real-duration map.
     *
     * @param VT Virtual time.
     * @param RT Output: real duration, offset by gslar_.
     */
    void VirtualT2RealT(const  double & VT, double &RT){
        if (VT > 0.0)
        {
          RT = ((0.5 * VT + 1.0) * VT + 1.0) + gslar_;
          return;
        }
        RT = 1.0 / ((0.5 * VT - 1.0) * VT + 1.0) + gslar_;
    }
    /**
     * @brief Map real durations to virtual times, element by element.
     *
     * Durations above 1 + gslar_ use sqrt(2 T - 1 - 2 gslar_) - 1, the others
     * use 1 - sqrt(2 / (T - gslar_) - 1).
     *
     * @param RT Real durations.
     * @param VT Output: virtual times; must already have RT.size() entries.
     */
    template <typename EIGENVEC>
    inline void RealT2VirtualT(const Eigen::VectorXd &RT, EIGENVEC &VT)
    {
      for (int k = 0; k < RT.size(); ++k)
      {
        if (RT(k) > 1.0 + gslar_)
        {
          VT(k) = sqrt(2.0 * RT(k) - 1.0 - 2 * gslar_) - 1.0;
        }
        else
        {
          VT(k) = 1.0 - sqrt(2.0 / (RT(k)-gslar_) - 1.0);
        }
      }
    }
    /**
     * @brief Scalar version of the real-duration to virtual-time map.
     *
     * @param RT Real duration.
     * @param VT Output: virtual time.
     */
    inline void RealT2VirtualT(const double &RT, double &VT)
    {
      if (RT > 1.0 + gslar_)
      {
        VT = sqrt(2.0 * RT - 1.0 - 2 * gslar_) - 1.0;
        return;
      }
      VT = 1.0 - sqrt(2.0 / (RT-gslar_) - 1.0);
    }
    /**
     * @brief Time cost and gradient with respect to the virtual times.
     *
     * @param RT    Real durations.
     * @param VT    Virtual times the durations were computed from.
     * @param gdRT  Gradient of the other cost terms with respect to the real durations.
     * @param gdVT  Output: (gdRT + wei_time_) multiplied by d(real)/d(virtual).
     * @param costT Output: wei_time_ times the sum of the real durations.
     */
    template <typename EIGENVEC, typename EIGENVECGD>
    void VirtualTGradCost(const Eigen::VectorXd &RT, const EIGENVEC &VT,const Eigen::VectorXd &gdRT, EIGENVECGD &gdVT,double &costT)
    {
        for (int k = 0; k < VT.size(); ++k)
        {
            // Derivative of the real duration with respect to the virtual time.
            double dRealdVirt;
            if (VT(k) > 0)
            {
                dRealdVirt = VT(k) + 1.0;
            }
            else
            {
                const double den = (0.5 * VT(k) - 1.0) * VT(k) + 1.0;
                dRealdVirt = (1.0 - VT(k)) / (den * den);
            }
            gdVT(k) = (gdRT(k) + wei_time_) * dRealdVirt;
        }

        costT = RT.sum() * wei_time_;
    }
    /**
     * @brief Scalar version of the time cost and its virtual-time gradient.
     *
     * @param RT    Real duration.
     * @param VT    Virtual time the duration was computed from.
     * @param gdRT  Gradient of the other cost terms with respect to the real duration.
     * @param gdVT  Output: (gdRT + wei_time_) multiplied by d(real)/d(virtual).
     * @param costT Output: wei_time_ times the real duration.
     */
    void VirtualTGradCost(const double &RT, const double &VT, const double &gdRT, double &gdVT, double& costT){
        double dRealdVirt;
        if (VT > 0)
        {
            dRealdVirt = VT + 1.0;
        }
        else
        {
            const double den = (0.5 * VT - 1.0) * VT + 1.0;
            dRealdVirt = (1.0 - VT) / (den * den);
        }

        gdVT = (gdRT + wei_time_) * dRealdVirt;
        costT = RT * wei_time_;
    }        
    /**
     * @brief Convert a gradient with respect to a real duration into one with
     *        respect to the matching virtual time (no time cost added).
     *
     * @param VT   Virtual time.
     * @param gdRT Gradient with respect to the real duration.
     * @param gdVT Output: gdRT multiplied by d(real)/d(virtual).
     */
    void VirtualTGrad2t(const double &VT, const double &gdRT, double &gdVT){
        double dRealdVirt;
        if (VT > 0)
        {
            dRealdVirt = VT + 1.0;
        }
        else
        {
            const double den = (0.5 * VT - 1.0) * VT + 1.0;
            dRealdVirt = (1.0 - VT) / (den * den);
        }
        gdVT = gdRT * dRealdVirt;
    }    
    /**
     * @brief Element-wise conversion of real-duration gradients into
     *        virtual-time gradients (no time cost added).
     *
     * @param VT   Virtual times.
     * @param gdRT Gradient with respect to the real durations.
     * @param gdVT Output: gdRT multiplied element-wise by d(real)/d(virtual).
     */
    template <typename EIGENVEC, typename EIGENVECGD>
    void Virtual2Grad(const EIGENVEC &VT, const Eigen::VectorXd &gdRT, EIGENVECGD &gdVT){
        for (int k = 0; k < VT.size(); ++k)
        {
            double dRealdVirt;
            if (VT(k) > 0)
            {
                dRealdVirt = VT(k) + 1.0;
            }
            else
            {
                const double den = (0.5 * VT(k) - 1.0) * VT(k) + 1.0;
                dRealdVirt = (1.0 - VT(k)) / (den * den);
            }
            gdVT(k) = (gdRT(k)) * dRealdVirt;
        }
    }        
    /**
     * @brief Smooth positive map: 0.5 t^2 + t + 1 for t > 0, otherwise
     *        1 / (0.5 t^2 - t + 1).
     */
    double expC2(double t) {
        if (t > 0.0) {
            return (0.5 * t + 1.0) * t + 1.0;
        }
        return 1.0 / ((0.5 * t - 1.0) * t + 1.0);
    }
    /**
     * @brief Inverse of expC2: sqrt(2 T - 1) - 1 for T > 1, otherwise
     *        1 - sqrt(2 / T - 1).
     */
    double logC2(double T) {
        if (T > 1.0) {
            return sqrt(2.0 * T - 1.0) - 1.0;
        }
        return 1.0 - sqrt(2.0 / T - 1.0);
    }
    /**
     * @brief Turn M unconstrained variables into M+1 durations that sum to sT.
     *
     * @param t    Unconstrained variables (size M).
     * @param sT   Total duration the result is scaled to.
     * @param vecT Output: durations; entry M is written, so it needs M+1 entries.
     */
    void forwardT(const Eigen::Ref<const Eigen::VectorXd>& t, const double& sT, Eigen::Ref<Eigen::VectorXd> vecT) {
        const int count = t.size();
        int k = 0;
        while (k < count) {
            vecT(k) = expC2(t(k));
            ++k;
        }
        // The last slot takes whatever is left after normalization.
        vecT(count) = 0.0;
        vecT /= 1.0 + vecT.sum();
        vecT(count) = 1.0 - vecT.sum();
        vecT *= sT;
    }
    /**
     * @brief Inverse of forwardT: recover the M unconstrained variables from
     *        M+1 durations.
     *
     * forwardT produces vecT(i) / vecT(M) == expC2(t(i)), so each variable is
     * logC2 of that ratio. The previous code computed the ratio and then
     * overwrote it with logC2 of the raw duration, discarding the normalization.
     *
     * @param vecT Durations (at least M+1 entries; entry M is the divisor).
     * @param t    Output: unconstrained variables (size M).
     */
    void backwardT(const Eigen::Ref<const Eigen::VectorXd>& vecT, Eigen::Ref<Eigen::VectorXd> t) {
        const int M = t.size();
        t = vecT.head(M) / vecT(M);
        for (int i = 0; i < M; ++i) {
            t(i) = logC2(t(i));
        }
    }
    




  public:
    typedef unique_ptr<GlobalTrajOptimizer> Ptr;

  };

} // namespace global_traj
#endif