#include <plan_manage/planner_manager.h>

#include <algorithm>
#include <cmath>
#include <string>
#include <thread>

#include "visualization_msgs/Marker.h" // zx-todo

namespace ego_planner
{

  // SECTION interfaces for setup and query

  /// Default constructor. Members are default-initialised; all real setup happens in initPlanModules().
  EGOPlannerManager::EGOPlannerManager() = default;

  /// Destructor. Logs teardown to stdout; members are released by their own destructors.
  EGOPlannerManager::~EGOPlannerManager()
  {
    std::cout << "des manager" << '\n' << std::flush;
  }

  namespace
  {
    /// Dynamic limits and model directory selected by the `manager/velocity_type` parameter.
    struct VelocityProfile
    {
      double max_vel;
      double max_acc;
      const char *model_dir; // directory, relative to the ego_planner package, holding the NN models
    };

    /// Maps velocity_type 0/1/2 to its profile. Returns false (profile untouched) for any other value.
    bool lookupVelocityProfile(int velocity_type, VelocityProfile &profile)
    {
      switch (velocity_type)
      {
      case 0:
        profile = {2.0, 3.0, "/models_low"};
        return true;
      case 1:
        profile = {5.0, 6.0, "/models_middle"};
        return true;
      case 2:
        profile = {8.0, 10.0, "/models"};
        return true;
      default:
        return false;
      }
    }
  } // namespace

  /**
   * Reads the `manager/*` parameters and initialises every planning back-end.
   *
   * Back-ends set up here: grid map, local voxel map with its local trajectory optimizer, the
   * hybrid A* front end, the EGO polynomial optimizer, and the MINCO NLP. For planner types 0
   * (NN corridor planner) and 2 (iplanner), the TorchScript models are also loaded onto CUDA
   * and warmed up.
   *
   * The limits read from `manager/max_vel` / `manager/max_acc` are overridden by the profile
   * selected with `manager/velocity_type` (max vel/acc: 0 -> 2/3, 1 -> 5/6, 2 -> 8/10).
   * Profile 2 also lowers the time weight of each optimizer.
   *
   * An unknown velocity_type is logged, and in that case no limits are overridden. It is also
   * logged, and no model is loaded and the GPU warm-up is skipped, because there is nothing to
   * run.
   *
   * @param nh   node handle used for parameters, publishers and timers
   * @param vis  visualisation helper kept for later display calls
   */
  void EGOPlannerManager::initPlanModules(ros::NodeHandle &nh, PlanningVisualization::Ptr vis)
  {
    // Objective time weights applied only with the fast profile (velocity_type == 2).
    const double low_wei_time = 500.0;      // MINCO NLP
    const double low_wei_time_ego = 500.0;  // EGO polynomial optimizer
    const double low_wei_time_fast = 250.0; // local trajectory optimizer (author's note: "250 suc too high")

    nh_ = nh;

    /* read algorithm parameters */
    nh.param("manager/max_vel", pp_.max_vel_, -1.0);
    nh.param("manager/max_acc", pp_.max_acc_, -1.0);
    nh.param("manager/feasibility_tolerance", pp_.feasibility_tolerance_, 0.0);
    nh.param("manager/polyTraj_piece_length", pp_.polyTraj_piece_length, -1.0);
    nh.param("manager/planning_horizon", pp_.planning_horizen_, 5.0);
    nh.param("manager/use_multitopology_trajs", pp_.use_multitopology_trajs, false);
    nh.param("manager/drone_id", pp_.drone_id, -1);
    nh.param("manager/velocity_type", velocity_type, -1);
    nh.param("manager/planner_type", planner_type_, -1);
    ROS_WARN_STREAM("velocity_type: " << velocity_type);
    ROS_WARN_STREAM("planner_type: " << planner_type_);

    grid_map_.reset(new GridMap);
    std::cout << "1111111111111111111111111111\n";
    // initMap's flag is true for planner types 1 (EGO) and 3 (Fast-Planner).
    grid_map_->initMap(nh, planner_type_ == 1 || planner_type_ == 3);

    local_map_.setParam(nh_);
    local_traj_opt_.init(nh_, &local_map_);
    hybrid_path_finder_.reset(new hybrid::KinodynamicAstar<local_map_util::VoxelMapUtil>);
    hybrid_path_finder_->setParam(nh_);
    hybrid_path_finder_->setEnvironment(&local_map_);
    hybrid_path_finder_->init();
    hybrid_path_finder_->reset();

    VelocityProfile profile{};
    const bool profile_ok = lookupVelocityProfile(velocity_type, profile);
    const bool fast_profile = (velocity_type == 2);
    if (!profile_ok)
    {
      ROS_ERROR("wrong velocity type");
    }

    // Limits must be applied after init()/setParam(), which would otherwise overwrite them.
    if (profile_ok)
    {
      pp_.max_vel_ = profile.max_vel;
      pp_.max_acc_ = profile.max_acc;
      local_traj_opt_.vmax = profile.max_vel;
      local_traj_opt_.amax = profile.max_acc;
      hybrid_path_finder_->max_vel_ = profile.max_vel;
      hybrid_path_finder_->max_acc_ = profile.max_acc;
      if (fast_profile)
      {
        local_traj_opt_.wei_time_ = low_wei_time_fast;
      }
    }

    ploy_traj_opt_.reset(new PolyTrajOptimizer);
    ploy_traj_opt_->setParam(nh);
    if (profile_ok)
    {
      ploy_traj_opt_->max_vel_ = profile.max_vel;
      ploy_traj_opt_->max_acc_ = profile.max_acc;
      if (fast_profile)
      {
        ploy_traj_opt_->wei_time_ = low_wei_time_ego;
      }
    }
    ploy_traj_opt_->setEnvironment(grid_map_);

    visualization_ = vis;

    ploy_traj_opt_->setSwarmTrajs(&traj_.swarm_traj);
    ploy_traj_opt_->setDroneId(pp_.drone_id);

    ROS_WARN("FFFFFFFFFFFFFFFFFFFF");

    // NN planner debug / visualisation publishers (advertised for every planner type).
    nn_depth_debug_pub_ = nh.advertise<sensor_msgs::Image>("/nn_depth_debug", 10);
    nn_corridor_pub_ = nh.advertise<visualization_msgs::MarkerArray>("/nn_corridor_vis", 10);
    nn_traj_vis_pub_ = nh.advertise<visualization_msgs::Marker>("/nn_traj_vis", 10);
    nn_debug_pub_ = nh.advertise<ego_planner::NNDebug>("/nn_debug", 10);
    torch::Device nn_device(torch::kCUDA);
    if (planner_type_ == 0 || planner_type_ == 2)
    {
      bool model_loaded = false;
      if (profile_ok)
      {
        const std::string model_dir = ros::package::getPath("ego_planner") + profile.model_dir;
        nn_planner = torch::jit::load(model_dir + "/full_jit_muti_tune.pt", nn_device);
        ROS_WARN_STREAM("Loading model: " << profile.model_dir << "/full_jit_muti_tune.pt");
        nn_planner.to(torch::kFloat32);
        nn_planner.eval();
        if (planner_type_ == 2)
        {
          iplanner = torch::jit::load(model_dir + "/full_jit_iplanner.pt", nn_device);
          iplanner.to(torch::kFloat32);
          iplanner.eval();
        }
        model_loaded = true;
      }
      else
      {
        ROS_ERROR("wrong velocity type: no NN model loaded, skipping GPU warm-up");
      }

      torch::NoGradGuard no_grad;

      // Marker templates reused when publishing NN corridors and trajectories.
      corridor_marker_array_.markers.clear();
      const double scale = 0.12;
      qp_traj.header.frame_id = corridor_marker.header.frame_id = "world";
      qp_traj.header.stamp = corridor_marker.header.stamp = ros::Time::now();
      qp_traj.type = visualization_msgs::Marker::LINE_STRIP;
      corridor_marker.type = visualization_msgs::Marker::SPHERE;
      qp_traj.action = corridor_marker.action = visualization_msgs::Marker::ADD;
      qp_traj.id = 1000;
      qp_traj.pose.orientation.w = 1.0;
      corridor_marker.pose.orientation.w = 1.0;
      corridor_marker.color.a = 0.1;
      corridor_marker.color.r = 0.0;
      corridor_marker.color.g = 1.0;
      corridor_marker.color.b = 0.0;
      qp_traj.color.r = 0.0;
      qp_traj.color.g = 0.0;
      qp_traj.color.b = 1.0;
      qp_traj.color.a = 1;
      qp_traj.scale.x = scale / 2;

      // Forwarding through an unloaded module would throw, so warm up only when a model exists.
      if (model_loaded)
      {
        PRINT_GREEN("[NN Planner Manager] GPU warm up....");
        // Dummy input with the planner's flattened size:
        // 10 depth frames of 120x160 + 10 odometries of 7 values + 18 start/end-state values.
        std::vector<torch::jit::IValue> warmup_inputs;
        warmup_inputs.push_back(torch::ones({1, 10 * (120 * 160 + 7) + 18}).to(torch::kFloat32).to(at::kCUDA));
        for (int i = 0; i < 100; i++)
        {
          nn_planner.forward(warmup_inputs);
        }
        PRINT_GREEN("[NN Planner Manager] Done.");
        // Periodic dummy passes (see warm()) keep the GPU warm between replans.
        cudaWarmTimer = nh.createTimer(ros::Duration(0.01), &EGOPlannerManager::warm, this);
      }
    }

    minco_nlp_.init(nh);
    if (profile_ok)
    {
      minco_nlp_.vmax = profile.max_vel;
      minco_nlp_.accmax = profile.max_acc;
      if (fast_profile)
      {
        minco_nlp_.wei_time_ = low_wei_time;
      }
    }

    prv_nh_ = nh;
  }

  /**
   * Timer callback, registered in initPlanModules with a 10 ms period.
   *
   * Runs one dummy forward pass through nn_planner so the GPU stays warm between replans.
   * The input is all-ones and has the planner's flattened input size; the output is discarded.
   */
  void EGOPlannerManager::warm(const ros::TimerEvent &)
  {
    // NoGradGuard is thread-local RAII. The guard taken in initPlanModules has ended before this
    // callback fires, so without a local guard every dummy pass would record an autograd graph.
    torch::NoGradGuard no_grad;
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 10 * (120 * 160 + 7) + 18}).to(torch::kFloat32).to(at::kCUDA));
    nn_planner.forward(inputs);
  }










  /**
   * Replans with the learned corridor planner (used when planner_type_ == 0).
   *
   * Steps:
   *  1. Pack the 10 most recent depth frames and odometries, plus the start/end states, into
   *     one flat float tensor. Everything is expressed relative to the newest odometry pose.
   *  2. Run nn_planner. outputs[0] holds, per primitive, corridor sphere centres, radii and
   *     5 segment durations. outputs[1] holds one value per primitive, used as a probability.
   *  3. Solve the MINCO NLP for up to 5 primitives in parallel.
   *  4. Pick the primitive minimising `nlp_cost + 400 * (1 - probability)`.
   *  5. Publish debug depth / corridor / trajectory markers and store the result in
   *     traj_.local_traj.
   *
   * @param start_pt,start_vel,start_acc  start state (world frame)
   * @param end_pt,end_vel,end_acc        target state (world frame)
   * @return true if a new trajectory was stored in traj_.local_traj; false if the input queues
   *         are too short, the network output is malformed, every NLP failed, or the chosen
   *         solution is non-finite or exploding.
   */
  bool EGOPlannerManager::reboundReplanNN(
        const Eigen::Vector3d &start_pt, const Eigen::Vector3d &start_vel,
        const Eigen::Vector3d &start_acc, const Eigen::Vector3d &end_pt,
        const Eigen::Vector3d &end_vel, const Eigen::Vector3d &end_acc)
  {
    torch::NoGradGuard no_grad;
    const ros::Time start_time = ros::Time::now();
    std::vector<nav_msgs::Odometry> odoms = grid_map_->getOdomQueue();
    std::vector<cv::Mat> depths = grid_map_->getDepthImgQueue();
    if (depths.size() < 10 || odoms.size() < 10)
    {
      PRINT_RED("odom or depth queue size is not enough, skip replan");
      return false;
    }

    // Reference frame: the newest odometry. All network inputs are expressed relative to it.
    const Eigen::Vector3f trans(odoms.back().pose.pose.position.x,
                                odoms.back().pose.pose.position.y,
                                odoms.back().pose.pose.position.z);
    const Eigen::Matrix3f R(Eigen::Quaternionf(odoms.back().pose.pose.orientation.w,
                                               odoms.back().pose.pose.orientation.x,
                                               odoms.back().pose.pose.orientation.y,
                                               odoms.back().pose.pose.orientation.z));

    // Flattened input layout (float32, batch of 1):
    //   [0, depth_size)                 10 depth frames of 120x160
    //   [depth_size, depodom_size)      10 relative odometries (x, y, z, qw, qx, qy, qz)
    //   [depodom_size, +9)              start state rows (pos, vel, acc) in the body frame
    //   [depodom_size + 9, +18)         end state rows (pos, vel, acc) in the body frame
    const int per_depth_size = 120 * 160;
    const int depth_size = 10 * per_depth_size;
    const int depodom_size = depth_size + 7 * 10;
    at::Tensor input_tensor = torch::zeros({1, depodom_size + 3 * 3 + 3 * 3}).to(torch::kFloat32);

    // Take the 10 most recent entries so the history matches the reference pose from back().
    const size_t odom_offset = odoms.size() - 10;
    const size_t depth_offset = depths.size() - 10;
    for (int i = 0; i < 10; i++)
    {
      const nav_msgs::Odometry &odom = odoms[odom_offset + i];
      Eigen::Vector3f pos(odom.pose.pose.position.x,
                          odom.pose.pose.position.y,
                          odom.pose.pose.position.z);
      Eigen::Matrix3f rot(Eigen::Quaternionf(odom.pose.pose.orientation.w,
                                             odom.pose.pose.orientation.x,
                                             odom.pose.pose.orientation.y,
                                             odom.pose.pose.orientation.z));
      pos = R.transpose() * (pos - trans);
      Eigen::Quaternionf q(R.transpose() * rot);
      RowVectorXf rodom(7);
      rodom << pos(0), pos(1), pos(2), q.w(), q.x(), q.y(), q.z();
      input_tensor.slice(1, depth_size + 7 * i, depth_size + 7 * i + 7) = torch::from_blob(rodom.data(), {7}, torch::kFloat32);
      // clone() yields a continuous buffer for from_blob; assumes float32 120x160 frames — TODO confirm
      cv::Mat depth_image = depths[depth_offset + i].clone();
      input_tensor.slice(1, per_depth_size * i, per_depth_size * (i + 1)) = torch::from_blob(depth_image.data, {per_depth_size}, torch::kFloat32);
    }

    // Start / end states as rows (pos, vel, acc); only positions are offset by the reference origin.
    RowMatrixXf start_state(3, 3), end_state(3, 3);
    start_state.row(0) = start_pt.cast<float>() - trans;
    start_state.row(1) = start_vel.cast<float>();
    start_state.row(2) = start_acc.cast<float>();
    end_state.row(0) = end_pt.cast<float>() - trans;
    end_state.row(1) = end_vel.cast<float>();
    end_state.row(2) = end_acc.cast<float>();
    // Row vector times R equals (R^T v)^T: rotates every row into the body frame.
    start_state = start_state * R;
    end_state = end_state * R;

    input_tensor.slice(1, depodom_size, depodom_size + 9) = torch::from_blob(start_state.data(), {9}, torch::kFloat32);
    input_tensor.slice(1, depodom_size + 9, depodom_size + 18) = torch::from_blob(end_state.data(), {9}, torch::kFloat32);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input_tensor.to(at::kCUDA));
    cudaDeviceSynchronize();

    auto output_tuple = nn_planner.forward(inputs).toTuple();  // forward pass; returns a tuple of tensors
    std::vector<torch::Tensor> outputs;
    for (const auto &element : output_tuple->elements())
    {
      outputs.push_back(element.toTensor());
    }
    cudaDeviceSynchronize();
    if (outputs.size() < 2)
    {
      PRINT_RED("[NN Planner] network returned fewer than 2 outputs, skip replan");
      return false;
    }

    // Evaluate at most 5 primitives, but never index past what the network actually produced.
    const int primitive_num = std::min(5, static_cast<int>(outputs[1].sizes()[1]));
    std::cout << "primitive_num: " << primitive_num << std::endl;
    if (primitive_num <= 0)
    {
      PRINT_RED("[NN Planner] network produced no primitives, skip replan");
      return false;
    }

    // .contiguous() so the raw-pointer views below match the logical row-major layout.
    at::Tensor prob_tensor = outputs[1].to(at::kCPU).squeeze(0).contiguous();
    std::vector<float> probabilities(prob_tensor.data_ptr<float>(), prob_tensor.data_ptr<float>() + prob_tensor.numel());
    Eigen::Map<Eigen::VectorXf> probabilities_map(probabilities.data(), probabilities.size());
    at::Tensor output_tensors = outputs[0].to(at::kCPU).squeeze(0).contiguous();

    std::vector<VectorXd> T1s(primitive_num);
    std::vector<VectorXd> minco_wpss(primitive_num);
    std::vector<RowMatrixXf> corridors(primitive_num);
    Eigen::VectorXd costs = Eigen::VectorXd::Constant(primitive_num, 1.0e9); // 1e9 marks "not solved"

    const ros::Time qp_start_time = ros::Time::now();
    for (int index = 0; index < primitive_num; index++)
    {
      // Per-primitive layout, as read below: 50x3 sphere centres, then 50 radii (offset 150),
      // then 5 segment durations (offset 200).
      at::Tensor output_tensor = output_tensors[index];
      float *base = output_tensor.data_ptr<float>();
      Eigen::Map<RowMatrixXf> corridor_pos_map(base, 50, 3);
      Eigen::Map<RowMatrixXf> corridor_r_map(base + 150, 50, 1);
      Eigen::Map<RowVectorXf> t_map(base + 200, 5);

      RowMatrixXf corridor(50, 4);
      corridor.leftCols(3) = corridor_pos_map;
      corridor.rightCols(1) = corridor_r_map;
      for (int i = 0; i < 50; i++)
      {
        // Clamp degenerate radii up to 1 cm so the NLP constraints stay well-posed.
        if (corridor(i, 3) < 1e-2f)
        {
          corridor(i, 3) = 1e-2f;
        }
      }

      corridors[index] = corridor;
      T1s[index] = t_map.transpose().cast<double>();
      minco_wpss[index] = VectorXd::Zero(12); // 4 inner waypoints x 3 axes, filled by the solver
    }
    MincoNLP::solveParallTBB(start_state.cast<double>(), end_state.cast<double>(), corridors, T1s, minco_wpss, costs, minco_nlp_);

    if (costs.minCoeff() > 1.0e6)
    {
      ROS_ERROR("cost is too high");
      return false;
    }

    // Select the best primitive: NLP cost plus a penalty for low network probability.
    PRINT_GREEN("probabilities_map: " << probabilities_map.transpose() << std::endl);
    PRINT_GREEN("costs: " << costs.transpose() << std::endl);
    for (int i = 0; i < costs.size(); i++)
    {
      costs[i] += 400.0 * (1 - probabilities_map[i]);
    }
    int min_index = 0;
    costs.minCoeff(&min_index);
    PRINT_GREEN("min index: " << min_index << "\n");
    const VectorXd &T1 = T1s[min_index];
    const VectorXd &minco_wps = minco_wpss[min_index];

    PRINT_GREEN("[NN Planner] solve time: " << (double)(ros::Time::now() - qp_start_time).toSec() * 1000.0 << "ms");
    PRINT_GREEN("[NN Planner] Planning time: " << (double)(ros::Time::now() - start_time).toSec() * 1000.0 << "ms.");
    ROS_WARN_STREAM("vel: " << start_vel.norm());

    // A non-finite or exploding solution (|wp| > 1e5 heuristic) means the QP failed.
    const bool qp_ok = minco_wps.size() >= 12 && minco_wps.allFinite() && T1.allFinite() &&
                       std::fabs(minco_wps[0]) <= 1.0e+5;
    if (!qp_ok)
    {
      PRINT_RED("[NN Planner] QP solusion strange");
      PRINT_YELLOW("start_state:\n" << start_state);
      PRINT_YELLOW("end_state:\n" << end_state);
      PRINT_YELLOW("T1:\n" << T1.transpose());
    }

    // Publish the newest depth frame for debugging (inverse of the input normalisation, presumably).
    cv_bridge::CvImage out_msg;
    out_msg.header.frame_id = "world";
    out_msg.header.stamp = ros::Time::now();
    out_msg.encoding = sensor_msgs::image_encodings::TYPE_32FC1;
    out_msg.image = (depths.back().clone() + 0.5) * 10.0;
    nn_depth_debug_pub_.publish(out_msg);

    // Publish every primitive's corridor spheres, transformed back to the world frame.
    corridor_marker_array_.markers.clear();
    for (int idx = 0; idx < primitive_num; idx++)
    {
      RowMatrixXf corridor = corridors[idx];
      corridor.leftCols(3) = corridor.leftCols(3) * R.transpose();
      for (int i = 0; i < corridor.rows(); ++i)
      {
        corridor_marker.id = i + 100 * idx;
        corridor_marker.pose.position.x = corridor(i, 0) + trans(0);
        corridor_marker.pose.position.y = corridor(i, 1) + trans(1);
        corridor_marker.pose.position.z = corridor(i, 2) + trans(2);
        corridor_marker.pose.orientation = odoms.back().pose.pose.orientation;
        if (corridor(i, 3) < 1e-2)
          corridor(i, 3) = 1e-2;
        corridor_marker.scale.x = corridor(i, 3) * 2.0;
        corridor_marker.scale.y = corridor(i, 3) * 2.0;
        corridor_marker.scale.z = corridor(i, 3) * 2.0;
        corridor_marker_array_.markers.push_back(corridor_marker);
      }
    }
    nn_corridor_pub_.publish(corridor_marker_array_);

    if (!qp_ok)
    {
      PRINT_RED("[NN Planner] QP failed, skip replan");
      return false;
    }

    // Inner waypoints: solver returns them body-frame, packed as [x0 y0 z0 x1 y1 z1 ...].
    RowMatrixXd wps_mat(3, 4);
    for (int j = 0; j < 4; j++)
    {
      for (int i = 0; i < 3; i++)
      {
        wps_mat(i, j) = minco_wps(i + j * 3);
      }
      wps_mat.col(j) = R.cast<double>() * wps_mat.col(j) + trans.cast<double>();
    }

    // Back to the world frame for trajectory construction.
    start_state = start_state * R.transpose();
    end_state = end_state * R.transpose();
    start_state.row(0) += trans;
    end_state.row(0) += trans;

    EgoMinco minco_traj;
    minco_traj.setTraj(start_state.transpose().cast<double>(),
                       end_state.transpose().cast<double>(),
                       wps_mat, T1);

    // Visualise the chosen trajectory sampled every 0.1 s.
    RowMatrixXd traj_points = minco_traj.sampleTrajPoints(0.1);
    geometry_msgs::Point pt;
    qp_traj.points.clear();
    for (int i = 0; i < traj_points.rows(); ++i)
    {
      pt.x = traj_points(i, 0);
      pt.y = traj_points(i, 1);
      pt.z = traj_points(i, 2);
      qp_traj.points.push_back(pt);
    }
    nn_traj_vis_pub_.publish(qp_traj);

    // Store as the new local trajectory.
    std::vector<double> traj_durs;
    std::vector<poly_traj::CoefficientMat> traj_coeff_mats;
    Eigen::VectorXd traj_times = minco_traj.traj.getDurations();
    auto coeffs = minco_traj.traj.getCoeffMats();
    for (int i = 0; i < traj_times.size(); ++i)
    {
      traj_durs.push_back(traj_times(i));
      traj_coeff_mats.push_back(coeffs[i]);
    }
    poly_traj::Trajectory trajectory(traj_durs, traj_coeff_mats);
    traj_.local_traj.drone_id = 0;
    traj_.local_traj.traj_id++;
    traj_.local_traj.duration = minco_traj.traj.getTotalDuration();
    traj_.local_traj.start_pos = minco_traj.traj.getPos(0.0);
    traj_.local_traj.start_time = ros::Time::now().toSec();
    traj_.local_traj.traj = trajectory;
    traj_.local_traj.pts_chk = PtsChk_t();

    return true;
  }

  /**
   * Replans with the end-to-end iplanner network (used when planner_type_ == 2).
   *
   * The 10 most recent depth frames and odometries, plus the start/end states, are packed into
   * one flat float tensor relative to the newest odometry pose. Running `iplanner` yields
   * 4 inner waypoints (3x4), 5 segment durations and a scalar `mu`. These are turned directly
   * into a MINCO trajectory, with no optimisation step.
   *
   * @param start_pt,start_vel,start_acc  start state (world frame)
   * @param end_pt,end_vel,end_acc        target state (world frame)
   * @return false if the queues are too short, the network output is malformed, `mu < 0.01`,
   *         or the predicted waypoints/durations are non-finite or non-positive; true once the
   *         trajectory has been stored in traj_.local_traj.
   */
  bool EGOPlannerManager::reboundReplanIplanner(
    const Eigen::Vector3d &start_pt, const Eigen::Vector3d &start_vel,
    const Eigen::Vector3d &start_acc, const Eigen::Vector3d &end_pt,
    const Eigen::Vector3d &end_vel, const Eigen::Vector3d &end_acc)
  {
    torch::NoGradGuard no_grad;
    std::vector<nav_msgs::Odometry> odoms = grid_map_->getOdomQueue();
    std::vector<cv::Mat> depths = grid_map_->getDepthImgQueue();
    if (depths.size() < 10 || odoms.size() < 10)
    {
      PRINT_RED("odom or depth queue size is not enough, skip replan");
      return false;
    }

    // Reference frame: the newest odometry. All network inputs are expressed relative to it.
    const Eigen::Vector3f trans(odoms.back().pose.pose.position.x,
                                odoms.back().pose.pose.position.y,
                                odoms.back().pose.pose.position.z);
    const Eigen::Matrix3f R(Eigen::Quaternionf(odoms.back().pose.pose.orientation.w,
                                               odoms.back().pose.pose.orientation.x,
                                               odoms.back().pose.pose.orientation.y,
                                               odoms.back().pose.pose.orientation.z));

    // Flattened input: 10 depth frames (120x160), 10 relative odometries (7 values each),
    // then start and end state rows (9 values each).
    const int per_depth_size = 120 * 160;
    const int depth_size = 10 * per_depth_size;
    const int depodom_size = depth_size + 7 * 10;
    at::Tensor input_tensor = torch::zeros({1, depodom_size + 3 * 3 + 3 * 3}).to(torch::kFloat32);

    // Take the 10 most recent entries so the history matches the reference pose from back().
    const size_t odom_offset = odoms.size() - 10;
    const size_t depth_offset = depths.size() - 10;
    for (int i = 0; i < 10; i++)
    {
      const nav_msgs::Odometry &odom = odoms[odom_offset + i];
      Eigen::Vector3f pos(odom.pose.pose.position.x,
                          odom.pose.pose.position.y,
                          odom.pose.pose.position.z);
      Eigen::Matrix3f rot(Eigen::Quaternionf(odom.pose.pose.orientation.w,
                                             odom.pose.pose.orientation.x,
                                             odom.pose.pose.orientation.y,
                                             odom.pose.pose.orientation.z));
      pos = R.transpose() * (pos - trans);
      Eigen::Quaternionf q(R.transpose() * rot);
      RowVectorXf rodom(7);
      rodom << pos(0), pos(1), pos(2), q.w(), q.x(), q.y(), q.z();
      input_tensor.slice(1, depth_size + 7 * i, depth_size + 7 * i + 7) = torch::from_blob(rodom.data(), {7}, torch::kFloat32);
      // clone() yields a continuous buffer for from_blob; assumes float32 120x160 frames — TODO confirm
      cv::Mat depth_image = depths[depth_offset + i].clone();
      input_tensor.slice(1, per_depth_size * i, per_depth_size * (i + 1)) = torch::from_blob(depth_image.data, {per_depth_size}, torch::kFloat32);
    }
    const ros::Time process_done = ros::Time::now();

    // Start / end states as rows (pos, vel, acc), rotated into the body frame (v^T R == (R^T v)^T).
    RowMatrixXf start_state(3, 3), end_state(3, 3);
    start_state.row(0) = start_pt.cast<float>() - trans;
    start_state.row(1) = start_vel.cast<float>();
    start_state.row(2) = start_acc.cast<float>();
    end_state.row(0) = end_pt.cast<float>() - trans;
    end_state.row(1) = end_vel.cast<float>();
    end_state.row(2) = end_acc.cast<float>();
    start_state = start_state * R;
    end_state = end_state * R;

    input_tensor.slice(1, depodom_size, depodom_size + 9) = torch::from_blob(start_state.data(), {9}, torch::kFloat32);
    input_tensor.slice(1, depodom_size + 9, depodom_size + 18) = torch::from_blob(end_state.data(), {9}, torch::kFloat32);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input_tensor.to(at::kCUDA));
    cudaDeviceSynchronize();

    const ros::Time nn_start_time = ros::Time::now();
    PRINT_GREEN("[NN Planner] To GPU time: " << (double)(nn_start_time - process_done).toSec() * 1000.0 << "ms.");

    auto output_tuple = iplanner.forward(inputs).toTuple();  // forward pass; returns a tuple of tensors
    cudaDeviceSynchronize();
    std::vector<torch::Tensor> outputs;
    for (const auto &element : output_tuple->elements())
    {
      outputs.push_back(element.toTensor());
    }
    if (outputs.size() < 3)
    {
      PRINT_RED("[NN Planner] iplanner returned fewer than 3 outputs, skip replan");
      return false;
    }

    // .contiguous() so the raw-pointer Eigen views below match the logical row-major layout.
    at::Tensor wps_tensor = outputs[0].to(at::kCPU).squeeze(0).contiguous(); // 3x4 inner waypoints
    at::Tensor t_tensor = outputs[1].to(at::kCPU).squeeze(0).contiguous();   // 5 segment durations
    const float mu = outputs[2].to(at::kCPU)[0].item<float>();
    if (mu < 0.01)
    {
      return false; // network rejects this plan
    }
    if (wps_tensor.numel() < 12 || t_tensor.numel() < 5)
    {
      PRINT_RED("[NN Planner] iplanner output has unexpected size, skip replan");
      return false;
    }

    Eigen::Map<RowMatrixXf> wps_map(wps_tensor.data_ptr<float>(), 3, 4);
    Eigen::Map<RowVectorXf> t_map(t_tensor.data_ptr<float>(), 5);
    // MINCO needs finite waypoints and strictly positive segment durations.
    if (!wps_map.allFinite() || !t_map.allFinite() || (t_map.array() <= 0.0f).any())
    {
      PRINT_RED("[NN Planner] iplanner output is not finite or has non-positive durations, skip replan");
      return false;
    }

    // Waypoints (one per column) back to the world frame.
    RowMatrixXd wps_mat = wps_map.cast<double>();
    for (int j = 0; j < 4; j++)
    {
      wps_mat.col(j) = R.cast<double>() * wps_mat.col(j) + trans.cast<double>();
    }
    const VectorXd T1 = t_map.transpose().cast<double>();

    start_state = start_state * R.transpose();
    end_state = end_state * R.transpose();
    start_state.row(0) += trans;
    end_state.row(0) += trans;

    EgoMinco minco_traj;
    minco_traj.setTraj(start_state.transpose().cast<double>(),
                       end_state.transpose().cast<double>(),
                       wps_mat, T1);

    // Visualise the trajectory sampled every 0.1 s.
    RowMatrixXd traj_points = minco_traj.sampleTrajPoints(0.1);
    geometry_msgs::Point pt;
    qp_traj.points.clear();
    for (int i = 0; i < traj_points.rows(); ++i)
    {
      pt.x = traj_points(i, 0);
      pt.y = traj_points(i, 1);
      pt.z = traj_points(i, 2);
      qp_traj.points.push_back(pt);
    }
    nn_traj_vis_pub_.publish(qp_traj);

    // Store as the new local trajectory.
    std::vector<double> traj_durs;
    std::vector<poly_traj::CoefficientMat> traj_coeff_mats;
    Eigen::VectorXd traj_times = minco_traj.traj.getDurations();
    auto coeffs = minco_traj.traj.getCoeffMats();
    for (int i = 0; i < traj_times.size(); ++i)
    {
      traj_durs.push_back(traj_times(i));
      traj_coeff_mats.push_back(coeffs[i]);
    }
    poly_traj::Trajectory trajectory(traj_durs, traj_coeff_mats);
    traj_.local_traj.drone_id = 0;
    traj_.local_traj.traj_id++;
    traj_.local_traj.duration = minco_traj.traj.getTotalDuration();
    traj_.local_traj.start_pos = minco_traj.traj.getPos(0.0);
    traj_.local_traj.start_time = ros::Time::now().toSec();
    traj_.local_traj.traj = trajectory;
    traj_.local_traj.pts_chk = PtsChk_t();

    return true;
  }
  bool EGOPlannerManager::reboundReplan(
      const Eigen::Vector3d &start_pt, const Eigen::Vector3d &start_vel,
      const Eigen::Vector3d &start_acc, const Eigen::Vector3d &local_target_pt,
      const Eigen::Vector3d &local_target_vel, const Eigen::Vector3d &end_acc, const bool flag_polyInit,
      const bool flag_randomPolyTraj, const bool touch_goal)
  {
    ros::Time t_start = ros::Time::now();
    ros::Duration t_init, t_opt;

    static int count = 0;
    
    // local_target_vel
    if(planner_type_==0){
      return reboundReplanNN(start_pt, start_vel, start_acc,
                                    local_target_pt, local_target_vel, Eigen::Vector3d(0,0,0));
    }
    
    if(planner_type_==2){
      return reboundReplanIplanner(start_pt, start_vel, start_acc,
                                    local_target_pt, local_target_vel, Eigen::Vector3d(0,0,0));
    }
    poly_traj::MinJerkOpt best_MJO;
    bool flag_success = false;
    poly_traj::MinJerkOpt initMJO;
    Eigen::MatrixXd cstr_pts;
    
    if(planner_type_==3){
      //fast planner
      double up1 = ros::Time::now().toSec();
      local_map_.SetMapBuild(grid_map_->cur_cloud_msg_, grid_map_->cur_pos_);
      double up2 = ros::Time::now().toSec();
      hybrid_path_finder_->reset();
      int flag;
      if(start_vel.norm()<0.2 && start_acc.norm()<0.2){
        flag = hybrid_path_finder_->search(start_pt, start_vel, start_acc, local_target_pt, local_target_vel, false);
      }
      else{
        flag=  hybrid_path_finder_->search(start_pt, start_vel, start_acc, local_target_pt, local_target_vel);
      }
      if(flag==hybrid::KinodynamicAstar<map_util::VoxelMapUtil>::NO_PATH){
        ROS_ERROR("NO PATH");
        return false;
      }
      double up3 = ros::Time::now().toSec();
      double delta_time = 0.01;
      std::vector<Eigen::Vector3d> kinopath = hybrid_path_finder_->getKinoTraj(delta_time);
      Eigen::Matrix3d is, es;
      is << start_pt, start_vel, start_acc;
      int pn_me =5;
      // if(velocity_type==2){
      //   pn_me = 8;
      // }
      es << local_target_pt, local_target_vel, Eigen::Vector3d(0,0,0);
      double up4 = ros::Time::now().toSec();
      Eigen::MatrixXd init_wps;
      Eigen::VectorXd init_rts;
      init_wps.resize(3, pn_me-1);
      init_rts.resize(pn_me);
      double init_dt = kinopath.size()*delta_time/pn_me;
      double totalLength = 0;
      for(int i = 0; i < kinopath.size()-1; i++){
        totalLength += (kinopath[i]-kinopath[i+1]).norm();
      }
      init_dt = (totalLength / pn_me)/((start_vel.norm()+local_target_vel.norm())/2.4);
      init_rts.setConstant(init_dt);
      int delta =  int(kinopath.size() / pn_me);
      for(int i = 0; i < pn_me-1; i++){
        init_wps.col(i) << kinopath[(i+1)*delta];
      }
      initMJO.reset(is, es, pn_me);
      initMJO.generate(init_wps, init_rts);
      cstr_pts = initMJO.getInitConstraintPoints(ploy_traj_opt_->get_cps_num_prePiece_());
      t_init = ros::Time::now() - t_start;
      std::vector<Eigen::Vector3d> point_set;
      for (int i = 0; i < cstr_pts.cols(); ++i)
        point_set.push_back(cstr_pts.col(i));
      // visualization_->displayInitPathList(point_set, 0.2, 0);
      visualization_->displayInitPathList(kinopath, 0.2, 0);
      double up5 = ros::Time::now().toSec();
      /*** STEP 2: OPTIMIZE ***/
      vector<vector<Eigen::Vector3d>> vis_trajs;
      poly_traj::Trajectory initTraj = initMJO.getTraj();
      Eigen::MatrixXd innerPts, headState, tailState;
      Eigen::VectorXd initts;
      innerPts.resize(3, pn_me-1);
      headState.resize(3, 3);
      tailState.resize(3, 3);
      headState << initTraj.getPos(0), initTraj.getVel(0), initTraj.getAcc(0);
      tailState << initTraj.getPos(initTraj.getTotalDuration()), initTraj.getVel(initTraj.getTotalDuration()), initTraj.getAcc(initTraj.getTotalDuration());
      initts.resize(pn_me); initts.setConstant(initTraj.getTotalDuration()/pn_me);
      for(int i = 0; i < pn_me-1; i++){
        innerPts.col(i) << initTraj.getPos(1.0*(i+1) / pn_me * initTraj.getTotalDuration());
      }
      flag_success = local_traj_opt_.OptimizeLocalTrajectory(headState, tailState, innerPts, initts);
      best_MJO = local_traj_opt_.jerkOpt_;
      double up6 = ros::Time::now().toSec();
      ROS_WARN_STREAM("up2 - up1: "<<1000.0*(up2-up1) <<" ms");
      ROS_WARN_STREAM("up3 - up2: "<<1000.0*(up3-up2) <<" ms");
      ROS_WARN_STREAM("up4 - up3: "<<1000.0*(up4-up3) <<" ms");
      ROS_WARN_STREAM("up5 - up4: "<<1000.0*(up5-up4) <<" ms");
      ROS_WARN_STREAM("up6 - up5: "<<1000.0*(up6-up5) <<" ms");


    }
    if(planner_type_==1){
      //ego planner
      double t1  = ros::Time::now().toSec();
      ploy_traj_opt_->setIfTouchGoal(touch_goal);
      double ts = pp_.polyTraj_piece_length / pp_.max_vel_;
      if (!computeInitState(start_pt, start_vel, start_acc, local_target_pt, local_target_vel,
                            flag_polyInit, flag_randomPolyTraj, ts, initMJO))
      {
        ROS_ERROR("!computeInitState");
        return false;
      }
      cstr_pts = initMJO.getInitConstraintPoints(ploy_traj_opt_->get_cps_num_prePiece_());
      vector<std::pair<int, int>> segments;
      if (ploy_traj_opt_->finelyCheckAndSetConstraintPoints(segments, initMJO, true) == PolyTrajOptimizer::CHK_RET::ERR)
      {
        ROS_ERROR("ploy_traj_opt_->finelyCheckAndSetConstraintPoints(segments, initMJO, true) == PolyTrajOptimizer::CHK_RET::ERR");
        return false;
      }
      
      std::vector<Eigen::Vector3d> point_set;
      for (int i = 0; i < cstr_pts.cols(); ++i)
        point_set.push_back(cstr_pts.col(i));
      visualization_->displayInitPathList(point_set, 0.2, 0);
      poly_traj::Trajectory initTraj = initMJO.getTraj();
      int PN = initTraj.getPieceNum();
      Eigen::MatrixXd all_pos = initTraj.getPositions();
      Eigen::MatrixXd innerPts = all_pos.block(0, 1, 3, PN - 1);
      Eigen::Matrix<double, 3, 3> headState, tailState;
      headState << initTraj.getJuncPos(0), initTraj.getJuncVel(0), initTraj.getJuncAcc(0);
      tailState << initTraj.getJuncPos(PN), initTraj.getJuncVel(PN), initTraj.getJuncAcc(PN);
      double final_cost;
      flag_success = ploy_traj_opt_->optimizeTrajectory(headState, tailState,
                                                        innerPts, initTraj.getDurations(), final_cost);//me
      best_MJO = ploy_traj_opt_->getMinJerkOpt();
      double t2  = ros::Time::now().toSec();
    }





    if (flag_success)
    {
      static double sum_time = 0;
      static int count_success = 0;
      sum_time += (t_init + t_opt).toSec();
      count_success++;
      printf("Time:\033[42m%.3fms,\033[0m init:%.3fms, optimize:%.3fms, avg=%.3fms\n",
             (t_init + t_opt).toSec() * 1000, t_init.toSec() * 1000, t_opt.toSec() * 1000, sum_time / count_success * 1000);
      // cout << "total time:\033[42m" << (t_init + t_opt).toSec()
      //      << "\033[0m,init:" << t_init.toSec()
      //      << ",optimize:" << t_opt.toSec()
      //      << ",avg_time=" << sum_time / count_success << endl;

      setLocalTrajFromOpt(best_MJO, touch_goal);
      cstr_pts = best_MJO.getInitConstraintPoints(ploy_traj_opt_->get_cps_num_prePiece_());
      visualization_->displayOptimalList(cstr_pts, 0);
      
      continous_failures_count_ = 0;
    }
    else
    {
      cstr_pts = best_MJO.getInitConstraintPoints(ploy_traj_opt_->get_cps_num_prePiece_());
      visualization_->displayFailedList(cstr_pts, 0);
      continous_failures_count_++;
      
    }
      // if(!flag_success)
      //   ros::Duration(10000.0).sleep();
    ROS_WARN_STREAM("max vel: "<<best_MJO.getTraj().getMaxVelRate());
    ROS_WARN_STREAM("max acc: "<<best_MJO.getTraj().getMaxAccRate());
    return flag_success;
  }

  /**
   * @brief Build the initial guess (initMJO) for the local trajectory optimizer.
   *
   * Case 1 (first call in the process, or when flag_polyInit is set):
   *   a 2 s "init of init" min-jerk trajectory from the start state to the local
   *   target is generated (one straight piece, or two pieces through a randomly
   *   perturbed midpoint when flag_randomPolyTraj is set). It is then resampled at
   *   evenly spaced times to obtain the inner points of the real initial
   *   trajectory, whose pieces all last `ts`.
   * Case 2 (otherwise):
   *   inner points are sampled from the remaining part of the previous local
   *   trajectory and, past its end, from the global trajectory.
   *
   * The piece count of the resulting trajectory is fixed: 8 when
   * velocity_type == 2, otherwise 5.
   *
   * @param start_pt,start_vel,start_acc       Head state.
   * @param local_target_pt,local_target_vel   Tail position / velocity (tail acceleration is zero).
   * @param flag_polyInit       Force case 1 even if a previous trajectory exists.
   * @param flag_randomPolyTraj In case 1, bend the init-of-init through a random midpoint.
   * @param ts                  Per-piece duration of the case-1 trajectory.
   * @param initMJO             [out] Receives the generated initial trajectory.
   * @return false when case 2 has no usable previous trajectory; true otherwise.
   */
  bool EGOPlannerManager::computeInitState(
      const Eigen::Vector3d &start_pt, const Eigen::Vector3d &start_vel, const Eigen::Vector3d &start_acc,
      const Eigen::Vector3d &local_target_pt, const Eigen::Vector3d &local_target_vel,
      const bool flag_polyInit, const bool flag_randomPolyTraj, const double &ts,
      poly_traj::MinJerkOpt &initMJO)
  {
    // NOTE(review): function-local static, so "first call" is shared by every
    // EGOPlannerManager instance in the process.
    static bool flag_first_call = true;

    // The number of pieces is deliberately fixed instead of being derived from
    // pp_.polyTraj_piece_length.
    const int piece_nums = (velocity_type == 2) ? 8 : 5;

    Eigen::Matrix3d headState, tailState;
    headState << start_pt, start_vel, start_acc;
    tailState << local_target_pt, local_target_vel, Eigen::Vector3d::Zero();

    if (flag_first_call || flag_polyInit) /*** case 1: polynomial initialization ***/
    {
      flag_first_call = false;

      constexpr double init_of_init_totaldur = 2.0;
      Eigen::MatrixXd innerPs; // stays empty (0x0) for the single-piece case
      Eigen::VectorXd piece_dur_vec;
      int init_piece_nums;

      /* determined or random inner point */
      if (!flag_randomPolyTraj)
      {
        init_piece_nums = 1;
        piece_dur_vec = Eigen::VectorXd::Constant(1, init_of_init_totaldur);
      }
      else
      {
        const Eigen::Vector3d to_start = start_pt - local_target_pt;
        const Eigen::Vector3d horizen_dir = (to_start.cross(Eigen::Vector3d(0, 0, 1))).normalized();
        const Eigen::Vector3d vertical_dir = (to_start.cross(horizen_dir)).normalized();
        const double dist = to_start.norm();
        // Perturbation scale: near zero with no previous failures, approaching
        // 0.989 as continous_failures_count_ grows.
        const double scale = -0.978 / (continous_failures_count_ + 0.989) + 0.989;
        // Draw the two random numbers in a defined order (operand evaluation
        // order inside a single expression is unspecified in C++).
        const double rand_h = static_cast<double>(rand()) / RAND_MAX;
        const double rand_v = static_cast<double>(rand()) / RAND_MAX;
        innerPs = (start_pt + local_target_pt) / 2 +
                  (rand_h - 0.5) * dist * horizen_dir * 0.8 * scale +
                  (rand_v - 0.5) * dist * vertical_dir * 0.4 * scale;

        init_piece_nums = 2;
        piece_dur_vec = Eigen::VectorXd::Constant(2, init_of_init_totaldur / 2);
      }

      /* generate the init of init trajectory */
      initMJO.reset(headState, tailState, init_piece_nums);
      initMJO.generate(innerPs, piece_dur_vec);
      poly_traj::Trajectory initTraj = initMJO.getTraj();

      /* generate the real init trajectory: sample piece_nums - 1 evenly spaced
         inner points on the init-of-init trajectory */
      const double piece_dur = init_of_init_totaldur / (double)piece_nums;
      innerPs.resize(3, piece_nums - 1);
      for (int i = 0; i < piece_nums - 1; ++i)
      {
        // Integer indexing guarantees exactly piece_nums - 1 samples, so a
        // floating-point drift can never write past the last column.
        innerPs.col(i) = initTraj.getPos((i + 1) * piece_dur);
      }
      // Each piece of the real trajectory lasts ts, independent of the 2 s
      // time base used for sampling above.
      piece_dur_vec = Eigen::VectorXd::Constant(piece_nums, ts);
      initMJO.reset(headState, tailState, piece_nums);
      initMJO.generate(innerPs, piece_dur_vec);
    }
    else /*** case 2: initialize from previous optimal trajectory ***/
    {
      if (traj_.global_traj.last_glb_t_of_lc_tgt < 0.0)
      {
        ROS_ERROR("You are initialzing a trajectory from a previous optimal trajectory, but no previous trajectories up to now.");
        return false;
      }

      /* the trajectory time system is a little bit complicated... */
      const double passed_t_on_lctraj = ros::Time::now().toSec() - traj_.local_traj.start_time;
      const double t_to_lc_end = traj_.local_traj.duration - passed_t_on_lctraj;
      if (t_to_lc_end < 0)
      {
        ROS_INFO("t_to_lc_end < 0, exit and wait for another call.");
        return false;
      }
      // Time from now until the new local target: the rest of the local
      // trajectory plus the advance of the local target along the global one.
      const double t_to_lc_tgt = t_to_lc_end +
                                 (traj_.global_traj.glb_t_of_lc_tgt - traj_.global_traj.last_glb_t_of_lc_tgt);

      Eigen::MatrixXd innerPs(3, piece_nums - 1);
      Eigen::VectorXd piece_dur_vec = Eigen::VectorXd::Constant(piece_nums, t_to_lc_tgt / piece_nums);

      double t = piece_dur_vec(0);
      for (int i = 0; i < piece_nums - 1; ++i)
      {
        if (t < t_to_lc_end)
        {
          // Still on the previous local trajectory.
          innerPs.col(i) = traj_.local_traj.traj.getPos(t + passed_t_on_lctraj);
        }
        else if (t <= t_to_lc_tgt)
        {
          // Past the local trajectory's end: continue on the global trajectory.
          const double glb_t = t - t_to_lc_end + traj_.global_traj.last_glb_t_of_lc_tgt - traj_.global_traj.global_start_time;
          innerPs.col(i) = traj_.global_traj.traj.getPos(glb_t);
        }
        else
        {
          ROS_ERROR("Should not happen! x_x 0x88 t=%.2f, t_to_lc_end=%.2f, t_to_lc_tgt=%.2f", t, t_to_lc_end, t_to_lc_tgt);
          // Never hand an uninitialized column to the optimizer.
          innerPs.col(i) = local_target_pt;
        }

        t += piece_dur_vec(i + 1);
      }

      initMJO.reset(headState, tailState, piece_nums);
      initMJO.generate(innerPs, piece_dur_vec);

      // An alternative variant (formerly kept here commented out) extrapolated the
      // points past the local trajectory's end along a straight line towards
      // local_target_pt instead of sampling the global trajectory.
    }

    return true;
  }
  /**
   * @brief Straight-line local target selection (no global trajectory involved).
   *
   * If the global goal lies within planning_horizen of start_pt, the goal itself
   * becomes the local target with zero velocity and touch_goal is set. Otherwise,
   * the target is the point at distance planning_horizen along the straight line
   * from start_pt to the goal, with a velocity pointing at the goal of magnitude
   * pp_.max_vel_ / 1.2. The target acceleration is always zero.
   */
  void EGOPlannerManager::getLocalTarget_me(
      const double planning_horizen, const Eigen::Vector3d &start_pt,
      const Eigen::Vector3d &global_end_pt, Eigen::Vector3d &local_target_pos,
      Eigen::Vector3d &local_target_vel,
      Eigen::Vector3d &local_target_acc,
      bool &touch_goal)
  {
    const Eigen::Vector3d to_goal = global_end_pt - start_pt;
    const double dist_to_goal = to_goal.norm();

    if (dist_to_goal <= planning_horizen)
    {
      // Goal reachable within the horizon: aim at it and come to rest.
      touch_goal = true;
      local_target_pos = global_end_pt;
      local_target_vel.setZero();
      local_target_acc.setZero();
      return;
    }

    local_target_pos = planning_horizen * to_goal / dist_to_goal + start_pt;
    local_target_vel = to_goal / dist_to_goal * pp_.max_vel_ / 1.2;
    local_target_acc.setZero();
    touch_goal = false;
  }
  /**
   * @brief Select the local planning target on the global trajectory.
   *
   * Starting at the previous local-target time (traj_.global_traj.glb_t_of_lc_tgt,
   * an absolute time), walk the global trajectory in steps of
   * planning_horizen / 20 / max_vel. The first sample at least planning_horizen
   * away from start_pt becomes the local target. If the end of the global
   * trajectory is reached first, the global goal is used and touch_goal is set.
   *
   * Side effects: updates traj_.global_traj.last_glb_t_of_lc_tgt,
   * traj_.global_traj.glb_t_of_lc_tgt and start_global_time_ (the time,
   * relative to the global trajectory start and at 0.1 s resolution, of the
   * global sample closest to start_pt).
   *
   * local_target_vel / local_target_acc are zero when the target is within
   * the braking distance max_vel^2 / (2 max_acc) of the goal. Otherwise they
   * are taken from the global trajectory at the target time.
   */
  void EGOPlannerManager::getLocalTarget(
      const double planning_horizen, const Eigen::Vector3d &start_pt,
      const Eigen::Vector3d &global_end_pt, Eigen::Vector3d &local_target_pos,
      Eigen::Vector3d &local_target_vel,
      Eigen::Vector3d &local_target_acc,
      bool &touch_goal)
  {
    touch_goal = false;

    traj_.global_traj.last_glb_t_of_lc_tgt = traj_.global_traj.glb_t_of_lc_tgt;

    // Loop invariants: neither is modified below.
    const double glb_start_t = traj_.global_traj.global_start_time;
    const double glb_end_t = glb_start_t + traj_.global_traj.duration;

    // NOTE(review): assumes pp_.max_vel_ > 0; with a non-positive max_vel the
    // step would not advance t towards glb_end_t.
    const double t_step = planning_horizen / 20 / pp_.max_vel_;
    double t = traj_.global_traj.glb_t_of_lc_tgt;
    for (; t < glb_end_t; t += t_step)
    {
      const Eigen::Vector3d pos_t = traj_.global_traj.traj.getPos(t - glb_start_t);
      if ((pos_t - start_pt).norm() >= planning_horizen)
      {
        local_target_pos = pos_t;
        traj_.global_traj.glb_t_of_lc_tgt = t;
        break;
      }
    }

    // Time of the global sample closest to start_pt (first minimum wins).
    // Each sample is evaluated once instead of twice.
    const double glb_total_dur = traj_.global_traj.traj.getTotalDuration();
    double min_dist = 9999999.0;
    for (double sample_t = 0.0; sample_t <= glb_total_dur; sample_t += 0.1)
    {
      const double dist = (traj_.global_traj.traj.getPos(sample_t) - start_pt).norm();
      if (min_dist > dist)
      {
        min_dist = dist;
        start_global_time_ = sample_t;
      }
    }

    if ((t - glb_start_t) >= traj_.global_traj.duration - 1e-5) // Last global point
    {
      local_target_pos = global_end_pt;
      traj_.global_traj.glb_t_of_lc_tgt = glb_end_t;
      touch_goal = true;
    }

    const double braking_dist = (pp_.max_vel_ * pp_.max_vel_) / (2 * pp_.max_acc_);
    if ((global_end_pt - local_target_pos).norm() < braking_dist)
    {
      local_target_vel = Eigen::Vector3d::Zero();
      local_target_acc = Eigen::Vector3d::Zero();
    }
    else
    {
      local_target_vel = traj_.global_traj.traj.getVel(t - glb_start_t);
      local_target_acc = traj_.global_traj.traj.getAcc(t - glb_start_t);
    }
  }

  /**
   * @brief Publish the trajectory held by @p opt as the current local trajectory.
   *
   * For planner_type_ == 3, the trajectory is stored directly in traj_.local_traj
   * with an empty point-check list, and the function returns true. Otherwise,
   * points to check are computed from the constraint points of the first two
   * thirds of the trajectory (all of it when touch_goal is set). The trajectory
   * is stored only if that succeeds and yields a non-empty last group.
   *
   * @return true for planner type 3; otherwise the result of computePointsToCheck.
   */
  bool EGOPlannerManager::setLocalTrajFromOpt(const poly_traj::MinJerkOpt &opt, const bool touch_goal)
  {
    poly_traj::Trajectory traj = opt.getTraj();

    if (planner_type_ == 3)
    {
      // set local traj
      // NOTE(review): drone_id is hard-coded to 0 here — confirm against pp_.drone_id.
      traj_.local_traj.drone_id = 0;
      traj_.local_traj.traj_id++;
      traj_.local_traj.duration = traj.getTotalDuration();
      traj_.local_traj.start_pos = traj.getPos(0.0);
      traj_.local_traj.start_time = ros::Time::now().toSec();
      traj_.local_traj.traj = traj;
      traj_.local_traj.pts_chk = PtsChk_t();
      return true;
    }

    // Constraint points are only needed on this path, so they are computed
    // after the early return above.
    Eigen::MatrixXd cps = opt.getInitConstraintPoints(getCpsNumPrePiece());
    PtsChk_t pts_to_check;
    const bool ret = ploy_traj_opt_->computePointsToCheck(
        traj, ConstraintPoints::two_thirds_id(cps, touch_goal), pts_to_check);

    if (ret && !pts_to_check.empty() && !pts_to_check.back().empty())
    {
      traj_.setLocalTraj(traj, pts_to_check, ros::Time::now().toSec());
    }

    return ret;
  }

  /**
   * @brief Replace the local trajectory with a "hold position" trajectory.
   *
   * Builds a two-piece (1 s + 1 s) min-jerk trajectory whose head, inner and tail
   * points are all stop_pos with zero velocity and acceleration, then installs
   * it via setLocalTrajFromOpt (touch_goal = false). The result of
   * setLocalTrajFromOpt is ignored; this function always returns true.
   */
  bool EGOPlannerManager::EmergencyStop(Eigen::Vector3d stop_pos)
  {
    // Explicit type instead of `auto`: avoids holding an Eigen expression template.
    const Eigen::Vector3d zero_vec = Eigen::Vector3d::Zero();

    Eigen::Matrix<double, 3, 3> hold_state;
    hold_state << stop_pos, zero_vec, zero_vec;

    poly_traj::MinJerkOpt hold_mjo;
    hold_mjo.reset(hold_state, hold_state, 2);
    hold_mjo.generate(stop_pos, Eigen::Vector2d(1.0, 1.0));

    setLocalTrajFromOpt(hold_mjo, false);
    return true;
  }

  /**
   * @brief Check whether this drone's local trajectory comes too close to drone @p drone_id's.
   *
   * Only the overlap of the first two thirds of our local trajectory with the
   * other drone's trajectory is checked, sampled every 0.03 s in absolute time.
   * A collision is reported when the distance drops below the sum of our swarm
   * clearance and the other drone's desired clearance.
   *
   * @return false if we have not planned yet, if @p drone_id has no entry in
   *         traj_.swarm_traj, or if that entry is invalid; otherwise whether a
   *         sample violates the clearance.
   */
  bool EGOPlannerManager::checkCollision(int drone_id)
  {
    if (traj_.local_traj.start_time < 1e9) // It means my first planning has not started
      return false;
    // Guard the index: an unknown id must not read outside swarm_traj.
    if (drone_id < 0 || drone_id >= static_cast<int>(traj_.swarm_traj.size()))
      return false;

    auto &other = traj_.swarm_traj[drone_id];
    if (other.drone_id != drone_id) // The trajectory is invalid
      return false;

    const double my_traj_start_time = traj_.local_traj.start_time;
    const double other_traj_start_time = other.start_time;

    const double t_start = std::max(my_traj_start_time, other_traj_start_time);
    const double t_end = std::min(my_traj_start_time + traj_.local_traj.duration * 2 / 3,
                                  other_traj_start_time + other.duration);

    // Loop-invariant separation threshold.
    const double min_separation = getSwarmClearance() + other.des_clearance;

    for (double t = t_start; t < t_end; t += 0.03)
    {
      const double separation = (traj_.local_traj.traj.getPos(t - my_traj_start_time) -
                                 other.traj.getPos(t - other_traj_start_time))
                                    .norm();
      if (separation < min_separation)
      {
        return true;
      }
    }

    return false;
  }

  /**
   * @brief Plan a min-jerk global trajectory from the start state through @p waypoints.
   *
   * Every waypoint except the last becomes an inner point; the last one is the
   * tail position (with end_vel / end_acc). Piece durations come from
   * straight-line distance divided by a desired speed, initially max_vel / 1.5.
   * If the result exceeds max_vel (and neither boundary velocity already does),
   * the desired speed is divided by 1.5 and the trajectory is regenerated once.
   * If it is still too fast, a warning is printed and the trajectory is used
   * anyway. The result is stored via traj_.setGlobalTraj with the current time.
   *
   * @return false if @p waypoints is empty, true otherwise.
   */
  bool EGOPlannerManager::planGlobalTrajWaypoints(
      const Eigen::Vector3d &start_pos, const Eigen::Vector3d &start_vel,
      const Eigen::Vector3d &start_acc, const std::vector<Eigen::Vector3d> &waypoints,
      const Eigen::Vector3d &end_vel, const Eigen::Vector3d &end_acc)
  {
    // waypoints.back() below would be undefined behaviour on an empty list.
    if (waypoints.empty())
    {
      ROS_ERROR("planGlobalTrajWaypoints: empty waypoint list.");
      return false;
    }

    poly_traj::MinJerkOpt globalMJO;
    Eigen::Matrix<double, 3, 3> headState, tailState;
    headState << start_pos, start_vel, start_acc;
    tailState << waypoints.back(), end_vel, end_acc;

    // astar_path_finder_->reset();
    // astar_path_finder_->search(start_pos, waypoints.back());
    // astar_path_finder_->visPath();
    // global_traj_opt_.OptimizeGlobalTrajectory(headState, tailState, astar_path_finder_->getPath());
    // traj_.setGlobalTraj(global_traj_opt_.jerkOpt_.getTraj(),ros::Time::now().toSec());
    // return true;

    const int piece_num = static_cast<int>(waypoints.size());

    // Inner points: all waypoints but the last (left empty for a single waypoint).
    Eigen::MatrixXd innerPts;
    if (piece_num > 1)
    {
      innerPts.resize(3, piece_num - 1);
      for (int i = 0; i < piece_num - 1; ++i)
      {
        innerPts.col(i) = waypoints[i];
      }
    }

    globalMJO.reset(headState, tailState, piece_num);

    // NOTE(review): a waypoint equal to its predecessor (or the first one equal
    // to start_pos) produces a zero-duration piece; confirm MinJerkOpt tolerates it.
    constexpr int kMaxAttempts = 2;
    double des_vel = pp_.max_vel_ / 1.5;
    Eigen::VectorXd time_vec(piece_num);

    for (int attempt = 0; attempt < kMaxAttempts; ++attempt)
    {
      for (int i = 0; i < piece_num; ++i)
      {
        const Eigen::Vector3d &seg_begin = (i == 0) ? start_pos : waypoints[i - 1];
        time_vec(i) = (waypoints[i] - seg_begin).norm() / des_vel;
      }

      globalMJO.generate(innerPts, time_vec);

      if (globalMJO.getTraj().getMaxVelRate() < pp_.max_vel_ ||
          start_vel.norm() > pp_.max_vel_ ||
          end_vel.norm() > pp_.max_vel_)
      {
        break;
      }

      // Last attempt and still too fast: report it (the trajectory is kept).
      if (attempt == kMaxAttempts - 1)
      {
        ROS_WARN("Global traj MaxVel = %f > set_max_vel", globalMJO.getTraj().getMaxVelRate());
        std::cout << "headState=" << std::endl
                  << headState << std::endl;
        std::cout << "tailState=" << std::endl
                  << tailState << std::endl;
      }

      des_vel /= 1.5;
    }

    auto time_now = ros::Time::now();
    traj_.setGlobalTraj(globalMJO.getTraj(), time_now.toSec());

    return true;
  }

} // namespace ego_planner
