#ifndef _MINCO_NLP_H_
#define _MINCO_NLP_H_

#include <Eigen/Eigen>
#include <ros/ros.h>

#include "utils/minco.hpp"
#include "utils/lbfgsme.hpp"
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>
using namespace std;
using namespace Eigen;
using namespace net_planner;




class MincoNLP
{
public:
// Weight applied to the total trajectory duration in the cost
// (overwritten by init() from "optimization/weight_time").
double wei_time_ = 1000.0;
// Weight of the corridor (position) penalty in addPVAGradCost2CT.
// Not read from the parameter server by init().
double safeWei = 1000.0;
// Weight of the velocity / acceleration limit penalties in addPVAGradCost2CT.
// Not read from the parameter server by init().
double dynamicWei = 1000.0;
// Velocity magnitude limit (overwritten by init() from "manager/max_vel").
double vmax = 5.0;
// Acceleration magnitude limit (overwritten by init() from "manager/max_acc").
double accmax = 6.0;
// Amount subtracted from every corridor radius before the safety check
// (overwritten by init() from "grid_map/obstacles_inflation").
double inflation = 0.2;

private:
// Trajectory representation; regenerated on every cost evaluation.
MinJerkOpt<3> jerkOpt;
// Number of penalty samples per piece; set by solveNLP to cords.rows() / piecenum.
int traj_res = -1;
// Transposes of the head_pva / tail_pva matrices handed to solveNLP.
MatrixXd head_state_, end_state_;
// Initial interior waypoints, 3 x (piecenum - 1), picked from the corridor rows.
MatrixXd waypoints_;
// Copy of the corridor matrix handed to solveNLP: columns 0..2 are read as a
// centre and column 3 as a radius (see addPVAGradCost2CT).
MatrixXd raw_cor;


// Offset added to every real duration by the virtual <-> real time mapping,
// i.e. the lower bound of a piece duration.
double gslar_ = 0.0;
// L-BFGS settings copied into lbfgsme::lbfgs_parameter_t by solveNLP.
// NOTE(review): mem_size, past and max_iterations are stored as double here;
// presumably the solver fields are integers - confirm against lbfgsme.hpp.
double mem_size = 256;
double past = 3; // copied to lbfgs_params.past
double g_epsilon = 1.0e-3;
double min_step = 1.0e-32;
double delta = 1.0e-4;
double max_iterations = 1000;
// Number of trajectory pieces; the decision vector has
// 3 * (piecenum - 1) + piecenum entries.
int piecenum = 5;


public:

// Default construction: every member keeps its in-class initializer;
// call init() afterwards to load the limits from the parameter server.
MincoNLP() {}
// Nothing to release explicitly; members clean up after themselves.
~MincoNLP() {}
/**
 * @brief Load the tunable limits and weights from the ROS parameter server.
 *
 * Each value falls back to the listed default when the parameter is missing.
 * safeWei and dynamicWei are not touched here and keep their in-class values.
 *
 * @param nh Node handle used to resolve the parameter names.
 */
void init(ros::NodeHandle& nh){
    // Dynamic limits.
    nh.param<double>("manager/max_vel", vmax, 5.0);
    nh.param<double>("manager/max_acc", accmax, 6.0);
    // Safety margin removed from every corridor radius.
    nh.param<double>("grid_map/obstacles_inflation", inflation, 0.2);
    // Weight of the total-duration term.
    nh.param<double>("optimization/weight_time", wei_time_, 1000.0);
}
/* main planning API */
/**
 * @brief Optimise the interior waypoints and piece durations of a
 *        piecenum-piece trajectory with L-BFGS.
 *
 * The total duration stored in T1 is first spread uniformly over the pieces,
 * the initial waypoints are taken from every traj_res-th corridor row, and the
 * unconstrained decision vector [waypoints | virtual times] is handed to
 * lbfgsme::lbfgs_optimize with costFunctionCallback as objective.
 *
 * @param head_pva  Start state; its transpose is passed to the trajectory generator.
 * @param tail_pva  End state; its transpose is passed to the trajectory generator.
 * @param cords     Corridor samples, one per row: columns 0..2 are read as a
 *                  centre, column 3 as a radius. Needs at least piecenum rows
 *                  and 4 columns.
 * @param T1        In: durations whose sum is the initial total time.
 *                  Out: optimised duration of each of the piecenum pieces.
 * @param wps       Out: optimised waypoints, 3 * (piecenum - 1) values, column-major.
 * @param cost      Out (only on success): jerk cost + wei_time_ * total duration.
 * @return true when the solver converged, stopped, or hit the iteration limit;
 *         false on malformed input or any other solver status.
 */
inline bool solveNLP(const Eigen::MatrixXd& head_pva,
        const Eigen::MatrixXd& tail_pva,
        const Eigen::MatrixXd& cords,
        VectorXd& T1,
        VectorXd& wps,
        double& cost
        ){
    // Each sample needs a centre and a radius, and each piece needs at least one
    // sample; otherwise the waypoint pick below and the penalty loop would read
    // outside the matrix.
    if(cords.cols() < 4 || cords.rows() < piecenum){
        ROS_ERROR("solveNLP: corridor matrix is %ldx%ld, need at least %dx4.",
                  static_cast<long>(cords.rows()), static_cast<long>(cords.cols()), piecenum);
        return false;
    }

    // Spread the total time uniformly over the pieces. Resizing at the same time
    // guarantees that T1 has exactly piecenum entries, which the layout of the
    // decision vector below relies on.
    const double total_time = T1.sum();
    T1.setConstant(piecenum, total_time / piecenum);

    for(int i = 0; i < T1.size(); i++){
        if(T1[i] <= gslar_){
            ROS_ERROR("piece time <= gslar_");
            T1[i] = gslar_+0.05;
        }
    }

    head_state_ = head_pva.transpose();
    end_state_ = tail_pva.transpose();
    traj_res = int(cords.rows()) / piecenum;

    // Initial waypoints: the corridor centre at the start of pieces 1..piecenum-1.
    raw_cor = cords;
    waypoints_.resize(3, piecenum-1);
    for(int i = 0; i < piecenum-1; i++){
        waypoints_.col(i)  = cords.row((i+1)*traj_res).head(3);
    }

    jerkOpt.reset(piecenum);

    // Decision vector: [ waypoints (column-major 3 x (piecenum-1)) | virtual times (piecenum) ].
    const int wp_scalars = 3 * (piecenum - 1);
    Eigen::VectorXd x(wp_scalars + piecenum);
    Eigen::Map<Eigen::MatrixXd>(x.data(), 3, piecenum - 1) = waypoints_;
    Eigen::Map<Eigen::VectorXd> Vt(x.data() + wp_scalars, piecenum);
    RealT2VirtualT(T1, Vt);

    lbfgsme::lbfgs_parameter_t lbfgs_params;
    lbfgs_params.mem_size = mem_size;
    lbfgs_params.past = past;
    lbfgs_params.g_epsilon = g_epsilon;
    lbfgs_params.min_step = min_step;
    lbfgs_params.delta = delta;
    lbfgs_params.max_iterations = max_iterations;

    double final_cost = 0.0;
    const int result = lbfgsme::lbfgs_optimize(
        x,
        final_cost,
        MincoNLP::costFunctionCallback,
        nullptr,
        nullptr,
        this,
        lbfgs_params);

    // The outputs are refreshed whatever the solver status, as before.
    wps = x.head(wp_scalars);
    VirtualT2RealT(x.tail(piecenum), T1);

    /* ---------- get result ---------- */
    if (result == lbfgsme::LBFGS_CONVERGENCE ||
        result == lbfgsme::LBFGS_STOP||result == lbfgsme::LBFGSERR_MAXIMUMITERATION)
    {
        // Rebuild the trajectory from the returned solution so that the reported
        // cost does not depend on which point the line search evaluated last.
        Eigen::Map<const Eigen::MatrixXd> opt_wps(x.data(), 3, piecenum - 1);
        jerkOpt.generate(head_state_, end_state_, opt_wps, T1);

        // Report smoothness + time only; the penalty terms are left out.
        const double smcost = jerkOpt.getTrajJerkCost();
        const double timecost = wei_time_ * T1.sum();
        cost = smcost + timecost;
        return true;
    }
    else if (result == lbfgsme::LBFGSERR_MAXIMUMLINESEARCH){
        ROS_ERROR("Lbfgs: The line-search routine reaches the maximum number of evaluations.");
        return false;
    }
    else
    {
        ROS_ERROR("Solver error. Return = %d, %s. Skip this planning.", result, lbfgsme::lbfgs_strerror(result));
        return false;
    }
}



private:
/* callbacks by the L-BFGS optimizer */
// static int monitor(void *func_data,const Eigen::VectorXd &x,
//                                 const Eigen::VectorXd &g,
//                                 const double fx,
//                                 const double step,
//                                 const int k,
//                                 const int ls){
// double nowtime = ros::Time::now().toSec();
// PolyTrajOptimizer *opt = reinterpret_cast<PolyTrajOptimizer *>(func_data);
// double budget = 0.075;
// if(opt->non_sinv<0.05){
//     budget = 0.1;
// }

// if(nowtime-(opt->startT) > budget){
//     ROS_WARN("reach budget time");
//     return 1;
// }
// else{
//     return 0;
// }
// }

/**
 * @brief Objective handed to lbfgsme::lbfgs_optimize.
 *
 * @param func_data The MincoNLP instance that started the optimisation.
 * @param x         Decision vector: 3 x (piecenum-1) waypoints (column-major)
 *                  followed by piecenum virtual times.
 * @param grad      Out: gradient of the returned cost, same layout as x.
 * @return jerk cost + time cost + penalty cost at x.
 */
static double costFunctionCallback(void *func_data, const Eigen::VectorXd &x, Eigen::VectorXd &grad){
    MincoNLP *self = static_cast<MincoNLP *>(func_data);
    const int n_piece = self->piecenum;
    const int wp_scalars = 3 * (n_piece - 1);

    // Views on the two halves of the decision vector and of its gradient.
    Eigen::Map<const Eigen::MatrixXd> wp_view(x.data(), 3, n_piece - 1);
    Eigen::Map<const Eigen::VectorXd> tau_view(x.data() + wp_scalars, n_piece);
    Eigen::Map<Eigen::MatrixXd> wp_grad(grad.data(), 3, n_piece - 1);
    Eigen::Map<Eigen::VectorXd> tau_grad(grad.data() + wp_scalars, n_piece);
    wp_grad.setZero();

    // Virtual times -> real piece durations.
    Eigen::VectorXd durations(n_piece);
    self->VirtualT2RealT(tau_view, durations);

    // Build the trajectory and fetch the jerk cost with its gradients
    // w.r.t. the polynomial coefficients and the durations.
    self->jerkOpt.generate(self->head_state_, self->end_state_, wp_view, durations);
    const double jerk_cost = self->jerkOpt.getTrajJerkCost();
    Eigen::MatrixXd grad_coeff;
    Eigen::VectorXd grad_dur;
    self->jerkOpt.calJerkGradCT(grad_coeff, grad_dur);

    // Sampled penalties accumulate into the same gradients.
    const double penalty_cost = self->addPVAGradCost2CT(grad_coeff, grad_dur);

    // Linear time cost.
    const double time_cost = self->wei_time_ * durations.sum();
    grad_dur.array() += self->wei_time_;

    // Propagate the coefficient gradient to waypoints and durations.
    Eigen::MatrixXd grad_wp;
    self->jerkOpt.calGradCTtoQT(grad_coeff, grad_dur, grad_wp);

    // Chain rule back to the virtual times, then publish the waypoint gradient.
    self->Virtual2Grad(tau_view, grad_dur, tau_grad);
    wp_grad = grad_wp;
    return jerk_cost + time_cost + penalty_cost;
}




/**
 * @brief Add the sampled velocity, acceleration and corridor penalties.
 *
 * Every piece is sampled traj_res times at t = j * T_i / traj_res, j = 0..traj_res-1.
 * Sample number cor_id (counted over all pieces) is checked against row cor_id of
 * raw_cor, read as a sphere with centre in columns 0..2 and radius in column 3,
 * shrunk by `inflation`.
 *
 * @param gdC In/out: gradient w.r.t. the polynomial coefficients
 *            (6 rows per piece, 3 columns); penalty gradients are added to it.
 * @param gdT In/out: gradient w.r.t. the piece durations; penalty gradients are added.
 * @return Sum of all penalty values.
 */
double addPVAGradCost2CT(Eigen::MatrixXd& gdC, Eigen::VectorXd& gdT){
    Eigen::Matrix<double, 6, 1> beta0, beta1, beta2, beta3;
    Eigen::Vector3d pos, vel, acc, jerk;
    Eigen::Vector3d gradPos, gradVel, gradAcc;

    const double vmaxSqr = vmax * vmax;
    const double accmaxSqr = accmax * accmax;

    int cor_id = 0;
    double penaCost = 0.0;
    for (int i = 0; i < piecenum; ++i)
    {
        const Eigen::Matrix<double, 6, 3> c = jerkOpt.getCoeffs().block<6, 3>(i * 6, 0);
        const double step = jerkOpt.T1[i] / traj_res; // T_i / k
        for (int j = 0; j < traj_res; ++j)
        {
            // Sample time inside the piece and its fraction of the piece duration.
            // Both come from the same index j, so that d(s1)/d(T_i) == alpha;
            // the previous code sampled at (j-1)*step but used j/traj_res here.
            const double s1 = j * step;
            const double alpha = static_cast<double>(j) / traj_res;

            const double s2 = s1 * s1;
            const double s3 = s2 * s1;
            const double s4 = s2 * s2;
            const double s5 = s4 * s1;
            beta0 << 1.0, s1, s2, s3, s4, s5;
            beta1 << 0.0, 1.0, 2.0 * s1, 3.0 * s2, 4.0 * s3, 5.0 * s4;
            beta2 << 0.0, 0.0, 2.0, 6.0 * s1, 12.0 * s2, 20.0 * s3;
            beta3 << 0.0, 0.0, 0.0, 6.0, 24.0 * s1, 60.0 * s2;

            pos = c.transpose() * beta0;
            vel = c.transpose() * beta1;
            acc = c.transpose() * beta2;
            jerk = c.transpose() * beta3;

            gradPos.setZero();
            gradVel.setZero();
            gradAcc.setZero();

            // Velocity limit: penalise |v|^2 - vmax^2 when positive.
            const double vioVel = vel.squaredNorm() - vmaxSqr;
            if (vioVel > 0)
            {
                double pena, penaD;
                positiveSmoothedL1(vioVel, pena, penaD);
                penaCost += dynamicWei * pena;
                gradVel += dynamicWei * penaD * 2.0 * vel;
            }

            // Acceleration limit: penalise |a|^2 - accmax^2 when positive.
            const double vioAcc = acc.squaredNorm() - accmaxSqr;
            if (vioAcc > 0)
            {
                double pena, penaD;
                positiveSmoothedL1(vioAcc, pena, penaD);
                penaCost += dynamicWei * pena;
                gradAcc += dynamicWei * penaD * 2.0 * acc;
            }

            // Corridor: stay inside the sphere of this sample.
            {
                const Eigen::Vector3d center = raw_cor.row(cor_id).head(3).transpose();
                // A radius smaller than the inflation leaves no free space; clamp to
                // zero so that squaring a negative value cannot widen the sphere.
                const double rRaw = raw_cor(cor_id, 3) - inflation;
                const double r = rRaw > 0.0 ? rRaw : 0.0;
                const double vioSdist = (pos - center).squaredNorm() - r * r;
                if (vioSdist > 0)
                {
                    double pena, penaD;
                    positiveSmoothedL1(vioSdist, pena, penaD);
                    penaCost += safeWei * pena;
                    gradPos += safeWei * penaD * 2.0 * (pos - center);
                }
                cor_id++;
            }

            // Chain rule: penalty gradients -> coefficients and duration of piece i.
            // There is no jerk penalty, hence no beta3 term in gdC.
            gdC.block<6, 3>(i * 6, 0) += beta0 * gradPos.transpose() +
                                         beta1 * gradVel.transpose() +
                                         beta2 * gradAcc.transpose();
            gdT[i] += (gradPos.dot(vel) +
                       gradVel.dot(acc) +
                       gradAcc.dot(jerk)) *
                      alpha;
        }
    }
    return penaCost;
}






/**
 * @brief Smoothed one-sided L1 penalty max(x, 0) with its derivative.
 *
 * - x <= 0      : f = 0, df = 0 (no violation).
 * - 0 < x < pe  : quartic blend, f = x^3/pe^2 - x^4/(2 pe^3).
 * - x >= pe     : linear, f = x - pe/2, df = 1.
 *
 * @param x  Constraint violation.
 * @param f  Out: penalty value.
 * @param df Out: derivative of the penalty w.r.t. x.
 */
void positiveSmoothedL1(const double &x, double &f, double &df)
{
    const double pe = 1.0e-4;
    const double half = 0.5 * pe;
    const double f3c = 1.0 / (pe * pe);
    const double f4c = -0.5 * f3c / pe;
    const double d2c = 3.0 * f3c;
    const double d3c = 4.0 * f4c;

    if (x <= 0.0)
    {
        // The quartic branch is only meaningful on [0, pe); for negative x it
        // would yield a large negative "penalty".
        f = 0.0;
        df = 0.0;
    }
    else if (x < pe)
    {
        f = (f4c * x + f3c) * x * x * x;
        df = (d3c * x + d2c) * x * x;
    }
    else
    {
        f = x - half;
        df = 1.0;
    }
}
/**
 * @brief Cubic penalty: f = x^3 and df = 3 x^2.
 *
 * @param x  Constraint violation.
 * @param f  Out: penalty value.
 * @param df Out: derivative of the penalty w.r.t. x.
 */
void positiveSmoothedL3(const double &x, double &f, double &df){
    const double squared = x * x;
    f = squared * x;
    df = squared * 3.0;
}


/**
 * @brief Map unconstrained virtual times to real durations, element by element.
 *
 * tau > 0  : T = tau^2/2 + tau + 1 + gslar_
 * tau <= 0 : T = 1 / (tau^2/2 - tau + 1) + gslar_
 *
 * @param VT Virtual times.
 * @param RT Out: real durations; must already hold at least VT.size() entries.
 */
template <typename EIGENVEC>
void VirtualT2RealT(const EIGENVEC &VT, Eigen::VectorXd &RT)
{
    for (int k = 0; k < VT.size(); ++k)
    {
        const double tau = VT(k);
        if (tau > 0.0)
        {
            RT(k) = ((0.5 * tau + 1.0) * tau + 1.0) + gslar_;
        }
        else
        {
            RT(k) = 1.0 / ((0.5 * tau - 1.0) * tau + 1.0) + gslar_;
        }
    }
}
/**
 * @brief Scalar version of the virtual-time -> real-duration mapping.
 *
 * @param VT Virtual time.
 * @param RT Out: real duration.
 */
void VirtualT2RealT(const  double & VT, double &RT){
    if (VT > 0.0)
    {
        RT = ((0.5 * VT + 1.0) * VT + 1.0) + gslar_;
        return;
    }
    RT = 1.0 / ((0.5 * VT - 1.0) * VT + 1.0) + gslar_;
}
/**
 * @brief Map real durations back to virtual times, element by element
 *        (inverse of VirtualT2RealT).
 *
 * @param RT Real durations.
 * @param VT Out: virtual times; must already hold at least RT.size() entries.
 */
template <typename EIGENVEC>
inline void RealT2VirtualT(const Eigen::VectorXd &RT, EIGENVEC &VT)
{
    for (int k = 0; k < RT.size(); ++k)
    {
        const double duration = RT(k);
        if (duration > 1.0 + gslar_)
        {
            VT(k) = (std::sqrt(2.0 * duration - 1.0 - 2 * gslar_) - 1.0);
        }
        else
        {
            VT(k) = (1.0 - std::sqrt(2.0 / (duration - gslar_) - 1.0));
        }
    }
}
/**
 * @brief Scalar version of the real-duration -> virtual-time mapping.
 *
 * @param RT Real duration.
 * @param VT Out: virtual time.
 */
inline void RealT2VirtualT(const double &RT, double &VT)
{
    if (RT > 1.0 + gslar_)
    {
        VT = (std::sqrt(2.0 * RT - 1.0 - 2 * gslar_) - 1.0);
        return;
    }
    VT = (1.0 - std::sqrt(2.0 / (RT - gslar_) - 1.0));
}
/**
 * @brief Time cost and gradient w.r.t. the virtual times.
 *
 * Adds wei_time_ to every duration gradient, applies the chain rule through
 * the virtual-time mapping and reports costT = wei_time_ * sum(RT).
 *
 * @param RT    Real durations.
 * @param VT    Virtual times matching RT.
 * @param gdRT  Gradient w.r.t. the real durations (without the time weight).
 * @param gdVT  Out: gradient w.r.t. the virtual times.
 * @param costT Out: time cost.
 */
template <typename EIGENVEC, typename EIGENVECGD>
void VirtualTGradCost(const Eigen::VectorXd &RT, const EIGENVEC &VT,const Eigen::VectorXd &gdRT, EIGENVECGD &gdVT,double &costT)
{
    for (int k = 0; k < VT.size(); ++k)
    {
        const double tau = VT(k);
        // dT/dtau of the virtual-time mapping.
        double dTdTau = tau + 1.0;
        if (!(tau > 0))
        {
            const double den = (0.5 * tau - 1.0) * tau + 1.0;
            dTdTau = (1.0 - tau) / (den * den);
        }
        gdVT(k) = (gdRT(k) + wei_time_) * dTdTau;
    }

    costT = RT.sum() * wei_time_;
}
/**
 * @brief Scalar version of VirtualTGradCost.
 *
 * @param RT    Real duration.
 * @param VT    Virtual time matching RT.
 * @param gdRT  Gradient w.r.t. the real duration (without the time weight).
 * @param gdVT  Out: gradient w.r.t. the virtual time.
 * @param costT Out: wei_time_ * RT.
 */
void VirtualTGradCost(const double &RT, const double &VT, const double &gdRT, double &gdVT, double& costT){
    // dT/dtau of the virtual-time mapping.
    double dTdTau = VT + 1.0;
    if (!(VT > 0))
    {
        const double den = (0.5 * VT - 1.0) * VT + 1.0;
        dTdTau = (1.0 - VT) / (den * den);
    }

    gdVT = (gdRT + wei_time_) * dTdTau;
    costT = RT * wei_time_;
}
/**
 * @brief Chain rule from a real-duration gradient to a virtual-time gradient.
 *
 * @param VT   Virtual time.
 * @param gdRT Gradient w.r.t. the real duration.
 * @param gdVT Out: gradient w.r.t. the virtual time.
 */
void VirtualTGrad2t(const double &VT, const double &gdRT, double &gdVT){
    // dT/dtau of the virtual-time mapping.
    double dTdTau = VT + 1.0;
    if (!(VT > 0))
    {
        const double den = (0.5 * VT - 1.0) * VT + 1.0;
        dTdTau = (1.0 - VT) / (den * den);
    }
    gdVT = gdRT * dTdTau;
}
/**
 * @brief Element-wise chain rule from real-duration gradients to
 *        virtual-time gradients.
 *
 * @param VT   Virtual times.
 * @param gdRT Gradient w.r.t. the real durations.
 * @param gdVT Out: gradient w.r.t. the virtual times.
 */
template <typename EIGENVEC, typename EIGENVECGD>
void Virtual2Grad(const EIGENVEC &VT, const Eigen::VectorXd &gdRT, EIGENVECGD &gdVT){
    for (int k = 0; k < VT.size(); ++k)
    {
        const double tau = VT(k);
        // dT/dtau of the virtual-time mapping.
        double dTdTau = tau + 1.0;
        if (!(tau > 0))
        {
            const double den = (0.5 * tau - 1.0) * tau + 1.0;
            dTdTau = (1.0 - tau) / (den * den);
        }
        gdVT(k) = (gdRT(k)) * dTdTau;
    }
}
/**
 * @brief Smooth positive map of the real line:
 *        t^2/2 + t + 1 for t > 0, 1 / (t^2/2 - t + 1) otherwise.
 */
double expC2(double t) {
    if (t > 0.0) {
        return ((0.5 * t + 1.0) * t + 1.0);
    }
    return 1.0 / ((0.5 * t - 1.0) * t + 1.0);
}
/**
 * @brief Inverse of expC2: sqrt(2T - 1) - 1 for T > 1,
 *        1 - sqrt(2/T - 1) otherwise.
 */
double logC2(double T) {
    if (T > 1.0) {
        return (std::sqrt(2.0 * T - 1.0) - 1.0);
    }
    return (1.0 - std::sqrt(2.0 / T - 1.0));
}
/**
 * @brief Turn M unconstrained values into M + 1 positive durations that sum to sT.
 *
 * vecT(i) = sT * expC2(t(i)) / (1 + sum_k expC2(t(k))) for i < M, and the last
 * entry takes the remaining share.
 *
 * @param t    Unconstrained variables, M entries.
 * @param sT   Total duration.
 * @param vecT Out: durations; must hold M + 1 entries.
 */
void forwardT(const Eigen::Ref<const Eigen::VectorXd>& t, const double& sT, Eigen::Ref<Eigen::VectorXd> vecT) {
    const int free_count = t.size();
    for (int k = 0; k != free_count; ++k) {
        vecT(k) = expC2(t(k));
    }
    // Normalise so that the first M shares plus the remainder add up to one.
    vecT(free_count) = 0.0;
    const double normaliser = 1.0 + vecT.sum();
    vecT /= normaliser;
    vecT(free_count) = 1.0 - vecT.sum();
    // Scale the shares to the requested total.
    vecT *= sT;
}
/**
 * @brief Inverse of forwardT: recover the M unconstrained values from M + 1 durations.
 *
 * forwardT yields vecT(i) / vecT(M) == expC2(t(i)), hence
 * t(i) = logC2(vecT(i) / vecT(M)). The previous code computed that ratio and
 * then overwrote it with logC2(vecT(i)), which is not the inverse.
 *
 * @param vecT Durations, M + 1 entries.
 * @param t    Out: unconstrained variables; its size defines M.
 */
void backwardT(const Eigen::Ref<const Eigen::VectorXd>& vecT, Eigen::Ref<Eigen::VectorXd> t) {
    const int M = static_cast<int>(t.size());
    const double last = vecT(M);
    for (int i = 0; i < M; ++i) {
        t(i) = logC2(vecT(i) / last);
    }
}



public:
/**
 * @brief Solve problem number idx of a batch with the given solver instance.
 *
 * Works on copies of T1[idx] and wps[idx]; the batch entries and cost[idx] are
 * overwritten only when solveNLP succeeds.
 * NOTE(review): on failure cost[idx] keeps whatever the caller stored there;
 * the local 1.0e9 sentinel is never written back - confirm this is intended.
 *
 * @param head_pvas Start state shared by the whole batch.
 * @param tail_pvas End state shared by the whole batch.
 * @param cords     Corridor matrices, one per problem.
 * @param T1        In/out: durations, one vector per problem.
 * @param wps       In/out: waypoints, one vector per problem.
 * @param cost      In/out: cost, one entry per problem.
 * @param idx       Index of the problem to solve.
 * @param ptrObj    Pointer to the MincoNLP instance to run on.
 */
static void solveProblemThreadTBB( 
                                const Eigen::MatrixXd& head_pvas,
                                const Eigen::MatrixXd& tail_pvas,
                                const std::vector<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>& cords,
                                std::vector<VectorXd>& T1,
                                std::vector<VectorXd>& wps,
                                Eigen::VectorXd& cost,
                                int idx, void* ptrObj)
{
    MincoNLP* solver = static_cast<MincoNLP*>(ptrObj);

    // Private working copies, so that a failed solve leaves the batch untouched.
    Eigen::VectorXd durations = T1[idx];
    Eigen::VectorXd waypoints = wps[idx];
    double traj_cost = 1.0e9;

    const bool solved = solver->solveNLP(head_pvas, tail_pvas, cords[idx].cast<double>(),
                                         durations, waypoints, traj_cost);
    if(!solved){
        return;
    }
    T1[idx] = durations;
    wps[idx] = waypoints;
    cost[idx] = traj_cost;
}
/**
 * @brief Solve a batch of problems in parallel with TBB.
 *
 * The batch size is T1.size(). Every problem runs on its own copy of `nh`, so
 * the worker threads share no solver state; each one writes only its own entry
 * of T1, wps and cost.
 *
 * @param head_pva Start state shared by the whole batch.
 * @param tail_pva End state shared by the whole batch.
 * @param cords    Corridor matrices, at least one per problem.
 * @param T1       In/out: durations, one vector per problem.
 * @param wps      In/out: waypoints, at least one vector per problem.
 * @param cost     In/out: cost, at least one entry per problem.
 * @param nh       Configured solver used as prototype for the per-problem copies.
 */
static void solveParallTBB(
    const Eigen::MatrixXd& head_pva,
    const Eigen::MatrixXd& tail_pva,
    const std::vector<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>& cords,
    std::vector<VectorXd>& T1,
    std::vector<VectorXd>& wps,
    Eigen::VectorXd& cost,
    MincoNLP nh
){
    const std::size_t batch = T1.size();

    // Every index handed to a worker must be valid in all per-problem containers;
    // a short container would otherwise be indexed out of bounds inside a thread.
    if(wps.size() < batch || cords.size() < batch ||
       cost.size() < 0 || static_cast<std::size_t>(cost.size()) < batch){
        ROS_ERROR("solveParallTBB: batch containers too small (T1=%lu, wps=%lu, cords=%lu, cost=%ld).",
                  static_cast<unsigned long>(batch),
                  static_cast<unsigned long>(wps.size()),
                  static_cast<unsigned long>(cords.size()),
                  static_cast<long>(cost.size()));
        return;
    }
    if(batch == 0){
        return;
    }

    // One independent solver per problem, allocated in one go.
    std::vector<MincoNLP> mincos(batch, nh);

    const int batch_size = static_cast<int>(batch);
    tbb::parallel_for(tbb::blocked_range<int>(0, batch_size),
                    [&](const tbb::blocked_range<int>& r) {
                        for (int i = r.begin(); i != r.end(); ++i) {
                            solveProblemThreadTBB(head_pva, tail_pva, cords, T1,wps, cost,i, &mincos[i]);
                        }
                    });
}

};

#endif