//
// Created by soraxas on 4/7/21.
//
#pragma once

#include <moveit/ompl_interface/parameterization/joint_space/joint_model_state_space.h>
#include <moveit/ompl_interface/parameterization/joint_space/joint_model_state_space_factory.h>
#include <ompl/base/spaces/RealVectorStateSpace.h>
#include <ompl/util/Exception.h>

#include <EigenRand/EigenRand>

#include "TorchOccMap.hpp"
#include "concurrentqueue.h"

// Short-hands for the OMPL namespaces referenced in this header.
namespace ob = ompl::base;
namespace og = ompl::geometric;

// Number of coordinates of a world-space point (used for the row count of
// EigenWorldSpacePtVector and for the last axis of the Jacobian tensors).
#define WORLDSPACE_NUM_DIM 3
// Number of coordinates of a configuration-space point assumed by the
// file-level Eigen aliases below.
#define CSPACE_NUM_DIM 6

// Fixed-size Eigen aliases at file scope:
//  - EigenCSpacePt:           1 x CSPACE_NUM_DIM array (coefficient-wise ops)
//  - EigenCSpacePtVector:     CSPACE_NUM_DIM x 1 column vector
//  - EigenWorldSpacePtVector: WORLDSPACE_NUM_DIM x 1 column vector
// NOTE: DiffeomorphicStateSampler_Base declares private aliases with the same
// names, sized by its own NumDim template parameter, which hide these inside
// that class.
using EigenCSpacePt           = Eigen::Array<double, 1, CSPACE_NUM_DIM>;
using EigenCSpacePtVector     = Eigen::Matrix<double, CSPACE_NUM_DIM, 1>;
using EigenWorldSpacePtVector = Eigen::Matrix<double, WORLDSPACE_NUM_DIM, 1>;

/**
 * Shared machinery of the diffeomorphic state samplers.
 *
 * The class keeps the per-dimension bounds of the state space, a batched
 * uniform random generator and the statistics that are recorded for every
 * sample handed out by `sampleUniform()` (how many samples were valid, split
 * by whether they came from the morphed or from the plain-uniform path).
 *
 * Usage contract for derived classes:
 *  - implement `_retrieve_dimension_bounds()` so that it fills `bounds_low`
 *    and `bounds_high` with exactly `NumDim` entries each;
 *  - call `init()` once from the derived constructor (it cannot be called
 *    from this constructor because it dispatches to that pure virtual);
 *  - override `drift_states()` if the batched path (`draw_sample_bucket()`)
 *    is used; the default implementation throws.
 *
 * @tparam StateSpaceType - The concrete OMPL state space. Its `StateType`
 *         must expose a `double* values` array.
 * @tparam NumDim - The compile-time dimensionality of the space.
 */
template <typename StateSpaceType, size_t NumDim>
class DiffeomorphicStateSampler_Base : public ob::StateSampler {
  using EigenCSpacePt       = Eigen::Array<double, 1, NumDim>;
  using EigenCSpacePtVector = Eigen::Matrix<double, NumDim, 1>;

 public:
  using MatrixXdRowMajor =
      Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using MatrixXfRowMajor =
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ArrayXdRowMajor =
      Eigen::Array<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  // One sampled configuration per row; row-major so that a row is a
  // contiguous run of NumDim doubles.
  using SampledMatRowMajor =
      Eigen::Array<double, Eigen::Dynamic, NumDim, Eigen::RowMajor>;

  /**
   * @param space - The state space to sample from; must be a `StateSpaceType`
   *        of dimension `NumDim`.
   * @param rand_batch_sample_size - Number of configurations drawn per batch.
   * @param epsilon - Step size of one morph step.
   * @param num_morph - Number of morph steps applied to each batch.
   * @throws ompl::Exception if the space has the wrong dimension or is not a
   *         `StateSpaceType`.
   */
  DiffeomorphicStateSampler_Base(const ob::StateSpace* space,
                                 int rand_batch_sample_size = 50,
                                 double epsilon = 2, uint num_morph = 2)
      : StateSampler(space),
        casted_space_ptr(dynamic_cast<const StateSpaceType*>(space)),
        rand_batch_sample_size(rand_batch_sample_size),
        epsilon(epsilon),
        num_morph(num_morph),
        num_dimensions(space_->getDimension()),
        bounds_high(std::vector<double>()),
        bounds_low(std::vector<double>()) {
    // These used to be plain asserts, which disappear in release builds and
    // leave the fixed-size Eigen members / the casted pointer unusable.
    if (num_dimensions != NumDim) {
      std::cout << "Given state-space dim " << num_dimensions
                << " is different than the defined dimension " << NumDim
                << std::endl;
      throw ompl::Exception(
          "Given state-space dimension is different than the defined "
          "dimension.");
    }
    if (casted_space_ptr == nullptr) {
      throw ompl::Exception("Bad cast of a state-space ptr.");
    }
  }

  /**
   * Second construction phase; to be called by the derived constructor.
   *
   * Retrieves the bounds through `_retrieve_dimension_bounds()`, seeds the
   * batch random generator and caches the bounds as Eigen arrays.
   *
   * @throws ompl::Exception if the derived class did not provide exactly
   *         `NumDim` lower and upper bounds.
   */
  void init() {
    _retrieve_dimension_bounds();

    if (bounds_high.size() != NumDim || bounds_low.size() != NumDim) {
      throw ompl::Exception(
          "The number of retrieved bounds is different than the defined "
          "dimension.");
    }

    rand_urng = Eigen::Rand::Vmt19937_64{std::random_device{}()};
    // start with an exhausted batch so that the first request draws a new one
    rand_sampled_batch_cur_row = rand_batch_sample_size;
    bounds_lower = EigenCSpacePt::Map(bounds_low.data(), bounds_low.size());
    bounds_upper = EigenCSpacePt::Map(bounds_high.data(), bounds_high.size());
    bounds_diff  = bounds_upper - bounds_lower;
  }

  /// Fill `bounds_low` / `bounds_high` from the concrete state space.
  virtual void _retrieve_dimension_bounds() = 0;

  // `space_` down-casted to the concrete state-space type (never null once
  // the constructor has returned)
  const StateSpaceType* casted_space_ptr;

  // batch size for the random number
  size_t rand_batch_sample_size;
  // random number generator
  Eigen::Rand::Vmt19937_64 rand_urng;
  // the sampled batch
  SampledMatRowMajor rand_sampled_batch;
  std::vector<EigenCSpacePtVector> rand_sampled_batch_as_vec;
  // the current row index of the sampled batch
  size_t rand_sampled_batch_cur_row{};
  // random number generate as uniform random
  Eigen::Rand::UniformRealGen<double> rand_uni_gen;

  // when false, sampleUniform() bypasses the morphed path entirely
  bool use_diff = true;

  ///////////////////////////////////////////////////////////////////////

  //#define PLOT_SAMPLED_PTS

  /**
   * Replace the current batch by `rand_batch_sample_size` fresh uniform
   * samples scaled into the bounds, rewind the row cursor and hand the batch
   * to `drift_states()`.
   */
  void draw_sample_bucket() {
    rand_sampled_batch = rand_uni_gen.generate<SampledMatRowMajor>(
        rand_batch_sample_size, num_dimensions, rand_urng);
    // 0-1  =>  0-diff  =>  lower-upper
    rand_sampled_batch =
        (rand_sampled_batch.rowwise() * bounds_diff).rowwise() + bounds_lower;
    rand_sampled_batch_cur_row = 0;
    drift_states();
  }

  /// Morph `rand_sampled_batch` in place; the default implementation throws.
  virtual void drift_states() { throw std::runtime_error("not implemented"); };

  sxs::Stats& m_stats = sxs::g::get_stats();

  /**
   * Sample a state, check its validity and update the sampling statistics.
   *
   * The state comes from `sampleUniform_Diffeomorphic()` when `use_diff` is
   * set and from `sampleUniform_PurelyUniformStateSampler()` otherwise.
   *
   * Statistics written (keys of `m_stats`): `total_sample`, `total_valid`,
   * `total_valid_pct` and the `_morphed` / `_uniform` variants of each. A
   * percentage whose sample count is still zero is reported as 0.
   */
  void sampleUniform(ob::State* state) override {
    bool morphed;
    if (this->use_diff) {
      morphed = sampleUniform_Diffeomorphic(state);
    } else {
      sampleUniform_PurelyUniformStateSampler(state);
      morphed = false;
    }

    auto si_ = sxs::g::storage::get<ob::SpaceInformationPtr>("si");
    // Dispatch on the state's own type: only MoveIt joint-model states carry
    // the information that has to be cleared. Casting every state to the
    // MoveIt type (as was done before) is an invalid downcast for the other
    // state spaces.
    clear_cached_state_info(state->as<typename StateSpaceType::StateType>());
    bool is_valid = si_->isValid(state);
    m_stats.of<long>("total_sample") += 1;
    m_stats.of<long>("total_valid") += is_valid;
    m_stats.of<double>("total_valid_pct") =
        valid_ratio(double(m_stats.of<long>("total_valid")),
                    double(m_stats.of<long>("total_sample")));

    m_stats.of<long>("total_valid_morphed") += 0;   // initialise
    m_stats.of<long>("total_sample_morphed") += 0;  // initialise
    m_stats.of<long>("total_valid_uniform") += 0;   // initialise
    m_stats.of<long>("total_sample_uniform") += 0;  // initialise
    if (morphed) {
      m_stats.of<long>("total_valid_morphed") += is_valid;
      m_stats.of<long>("total_sample_morphed") += 1;
    } else {
      m_stats.of<long>("total_valid_uniform") += is_valid;
      m_stats.of<long>("total_sample_uniform") += 1;
    }

    // One of the two categories may not have received any sample yet, hence
    // the guarded division.
    m_stats.of<double>("total_valid_morphed_pct") =
        valid_ratio(double(m_stats.of<long>("total_valid_morphed")),
                    double(m_stats.of<long>("total_sample_morphed")));

    m_stats.of<double>("total_valid_uniform_pct") =
        valid_ratio(double(m_stats.of<long>("total_valid_uniform")),
                    double(m_stats.of<long>("total_sample_uniform")));

    //    auto& timer = sxs::g::get<cppm::pm_timer>("timer");
    //    sxs::g::stats.format_item(timer);
    //    timer.update();
  }

  /// Write an independent uniform value within the bounds to every dimension.
  void sampleUniform_PurelyUniformStateSampler(ob::State* state) {
    const unsigned int dim = space_->getDimension();

    auto* rstate = state->as<typename StateSpaceType::StateType>();
    for (unsigned int i = 0; i < dim; ++i)
      rstate->values[i] = rng_.uniformReal(bounds_low[i], bounds_high[i]);
  }

  /**
   * Copy the next row of the current batch into `state`, drawing a new batch
   * first when the current one is exhausted.
   *
   * @return true, i.e. the sample is accounted as "morphed".
   */
  virtual bool sampleUniform_Diffeomorphic(ob::State* state) {
    auto* rstate          = state->as<typename StateSpaceType::StateType>();
    double* rstate_values = rstate->values;

    if (rand_sampled_batch_cur_row >= rand_batch_sample_size)
      draw_sample_bucket();

    /* assign rstate_values with pointer (the matrix NEEDS to be row-major) */
    // get raw pointer from the matrix
    double* from = &rand_sampled_batch(rand_sampled_batch_cur_row++, 0);
    std::copy(from, from + num_dimensions, rstate_values);
    return true;
  }

  /// Not supported by this sampler; always throws.
  void sampleUniformNear(ob::State*, const ob::State*, const double) override {
    throw ompl::Exception("not implemented");
  }

  /// Not supported by this sampler; always throws.
  void sampleGaussian(ob::State*, const ob::State*, const double) override {
    throw ompl::Exception("not implemented");
  }

  // step size of one morph step
  double epsilon;
  // number of morph steps applied to each batch
  uint num_morph;

 protected:
  /// Clear the information cached on a MoveIt joint-model state (presumably
  /// the remembered validity, so that `isValid()` re-evaluates the freshly
  /// written values - confirm against the MoveIt documentation).
  static void clear_cached_state_info(
      ompl_interface::JointModelStateSpace::StateType* state) {
    state->clearKnownInformation();
  }

  /// Fallback for every other state type: nothing is cached, nothing to do.
  static void clear_cached_state_info(ob::State*) {}

  /// `num_valid / num_total`, or 0 while `num_total` is still zero.
  static double valid_ratio(double num_valid, double num_total) {
    return num_total > 0.0 ? num_valid / num_total : 0.0;
  }

  unsigned int num_dimensions;
  ompl::RNG rng_;

  // per-dimension bounds as filled by _retrieve_dimension_bounds()
  std::vector<double> bounds_high;
  std::vector<double> bounds_low;

  // the same bounds (and their difference) as fixed-size Eigen row arrays,
  // set by init()
  EigenCSpacePt bounds_lower;
  EigenCSpacePt bounds_upper;
  EigenCSpacePt bounds_diff;
};

/**
 * Primary template of the diffeomorphic sampler.
 *
 * It adds nothing to `DiffeomorphicStateSampler_Base` and therefore stays
 * abstract (the base declares `_retrieve_dimension_bounds()` as pure
 * virtual). The usable samplers are the specialisations on a concrete
 * state-space type that follow in this file.
 *
 * @tparam StateSpaceType - The state space type to be specialised on (e.g.
 *         Real Vector or MoveIt Joint Space)
 * @tparam NumDim - The dimensionality of the space. TODO: make this generic
 */
template <typename StateSpaceType, size_t NumDim>
class DiffeomorphicStateSampler
    : public DiffeomorphicStateSampler_Base<StateSpaceType, NumDim> {};

/**
 * Specialisation for `RealVectorStateSpace`
 *
 * Only the retrieval of the bounds is specific to this space. Note that
 * `drift_states()` is not overridden here, so the base implementation (which
 * throws) is what runs whenever a new batch is drawn.
 *
 * @tparam NumDim - The dimensionality of the space
 */
template <size_t NumDim>
class DiffeomorphicStateSampler<ob::RealVectorStateSpace, NumDim>
    : public DiffeomorphicStateSampler_Base<ob::RealVectorStateSpace, NumDim> {
 public:
  /// Forward every argument to the base constructor, then run the second
  /// construction phase (`init()`), which needs the virtual defined below.
  template <typename... Args>
  explicit DiffeomorphicStateSampler(Args... args)
      : DiffeomorphicStateSampler_Base<ob::RealVectorStateSpace, NumDim>(
            args...) {
    this->init();
  }

  /// Copy the low/high bounds of the real-vector space into the base class.
  void _retrieve_dimension_bounds() override {
    // bind by reference: there is no need to copy the whole bounds object
    const auto& bounds = this->casted_space_ptr->getBounds();
    // assign (rather than append) so that calling this more than once cannot
    // grow the vectors beyond NumDim entries
    this->bounds_high.assign(bounds.high.begin(), bounds.high.end());
    this->bounds_low.assign(bounds.low.begin(), bounds.low.end());
  }
};

#ifdef DIFFEOMORPHIC_SAMPLER_FOR_MOVEIT
/**
 * Specialisation for `JointModelStateSpace`
 *
 * Morphed samples are produced ahead of time by a background worker (see
 * `start_sampling()`): it repeatedly draws a uniform batch of joint values,
 * applies `num_morph` steps of size `epsilon` against the quantity returned
 * by the occupancy map's `grad2` (mapped into joint space through the link
 * Jacobians) and pushes the resulting rows into a concurrent queue.
 * `sampleUniform_Diffeomorphic()` pops from that queue and falls back to
 * plain uniform sampling whenever the queue is empty.
 *
 * The worker is stopped by `finish_sampling()` or by the destructor.
 *
 * @tparam NumDim - The dimensionality of the joint space
 */
template <size_t NumDim>
class DiffeomorphicStateSampler<ompl_interface::JointModelStateSpace, NumDim>
    : public DiffeomorphicStateSampler_Base<
          ompl_interface::JointModelStateSpace, NumDim> {
 public:
  using SampledMatRowMajor = typename DiffeomorphicStateSampler_Base<
      ompl_interface ::JointModelStateSpace, NumDim>::SampledMatRowMajor;
  using MatrixXdRowMajor = typename DiffeomorphicStateSampler_Base<
      ompl_interface ::JointModelStateSpace, NumDim>::MatrixXdRowMajor;
  using MatrixXfRowMajor = typename DiffeomorphicStateSampler_Base<
      ompl_interface ::JointModelStateSpace, NumDim>::MatrixXfRowMajor;

  using ArrayXdRowMajor = typename DiffeomorphicStateSampler_Base<
      ompl_interface ::JointModelStateSpace, NumDim>::ArrayXdRowMajor;

  /// Forward every argument to the base constructor, then run the second
  /// construction phase (`init()`). The background worker is NOT started
  /// here; call `start_sampling()` for that.
  template <typename... Args>
  DiffeomorphicStateSampler(Args... args)
      : DiffeomorphicStateSampler_Base<ompl_interface ::JointModelStateSpace,
                                       NumDim>(args...),
        //        m_thread_num(4),
        //        m_thread_pool(m_thread_num),
        // (initialisers listed in member declaration order)
        m_difted_samples_Q(200, omp_get_max_threads(), omp_get_max_threads()),
        m_finish_sampling(false) {
    this->init();
  }

  /// Stop the background worker (if any) and wait for it to exit.
  ~DiffeomorphicStateSampler() {
    // Request the stop ourselves: joining a worker that was never told to
    // finish would block forever.
    m_finish_sampling.store(true);
    if (background_sampling_th) {
      if (background_sampling_th->joinable()) background_sampling_th->join();
      background_sampling_th.reset();
    }
  }

  /// Collect the per-variable joint limits and cache the robot-model handles
  /// needed later for the Jacobian computation.
  void _retrieve_dimension_bounds() override {
    // start from empty vectors so that a repeated call cannot duplicate bounds
    this->bounds_high.clear();
    this->bounds_low.clear();
    for (auto&& joint_bound : this->casted_space_ptr->getJointsBounds()) {
      // each joint bound might have more than one variable
      for (auto&& bound : *joint_bound) {
        this->bounds_high.push_back(bound.max_position_);
        this->bounds_low.push_back(bound.min_position_);
      }
    }
    ////////////////////////////////////////
    ///// Get robot model for retrieving jacobian
    kinematic_state = std::make_shared<moveit::core::RobotState>(
        this->casted_space_ptr->getRobotModel());
    kinematic_state->setToDefaultValues();
    //
    joint_model_group = this->casted_space_ptr->getJointModelGroup();
    linkModels        = joint_model_group->getLinkModels();

    reference_point_position = Eigen::Vector3d(0.0, 0.0, 0.0);
  }

  ////////////////////////////////////////////////////////////
  moveit::core::RobotStatePtr kinematic_state;
  const moveit::core::JointModelGroup* joint_model_group{nullptr};
  std::vector<const moveit::core::LinkModel*> linkModels;
  // point (in the link frame) at which the Jacobians are evaluated
  Eigen::Vector3d reference_point_position;
  //  Eigen::MatrixXd _jacobian;
  ////////////////////////////////////////////////////////////
  //  size_t m_thread_num;
  //  mutable ThreadPool m_thread_pool;
  ////////////////////////////////////////////////////////////

  // morphed samples produced by the background worker, one configuration per
  // element
  moodycamel::ConcurrentQueue<std::array<double, NumDim>> m_difted_samples_Q;
  // the background worker; null until start_sampling() is called
  std::unique_ptr<std::thread> background_sampling_th;
  // set to true to make the worker leave its sampling loop
  std::atomic<bool> m_finish_sampling;
  // offset of the extra body points placed around each link position
  double m_radius_of_joint{0.05};
  ////////////////////////////////////////////////////////////

  /// Takes effect for workers started afterwards (the offsets are built once
  /// at the beginning of `_start_sampling()`).
  void set_radius_of_joint(double r) { m_radius_of_joint = r; }

  /**
   * Pop one pre-computed morphed sample into `state`.
   *
   * @return true if a morphed sample was available; false if the queue was
   *         empty and `state` was filled by plain uniform sampling instead.
   */
  bool sampleUniform_Diffeomorphic(ob::State* state) override {
    std::array<double, NumDim> vals;
    if (!m_difted_samples_Q.try_dequeue(vals)) {
      this->sampleUniform_PurelyUniformStateSampler(state);
      return false;
    }
    // copy the dequeued joint values into the state's value array
    double* rstate_values =
        state->as<ompl_interface::JointModelStateSpace::StateType>()->values;
    std::copy(std::begin(vals), std::end(vals), rstate_values);
    return true;
  }

  /// Ask the background worker to stop after its current batch.
  void finish_sampling() { m_finish_sampling.store(true); }

  /**
   * Start the background worker (no-op when `use_diff` is false).
   *
   * Calling this while a worker is running is a no-op. If a stop was
   * requested before, the old worker is joined and a new one is started.
   */
  void start_sampling() {
    if (!this->use_diff) return;
    if (background_sampling_th && background_sampling_th->joinable()) {
      // still running and nobody asked it to stop: keep the existing worker
      if (!m_finish_sampling.load()) return;
      // Replacing a joinable std::thread would call std::terminate, so wait
      // for the old worker to exit before restarting.
      background_sampling_th->join();
      m_finish_sampling.store(false);
    }
    background_sampling_th.reset(
        new std::thread([this] { _start_sampling(); }));
  }

  /**
   * Body of the background worker; runs until `m_finish_sampling` is set.
   *
   * Each OpenMP thread loops over: draw a uniform batch, morph it
   * `num_morph` times, enqueue the rows of the batch.
   */
  void _start_sampling() {
    // The links [linkModel_start, linkModel_end) of the joint model group are
    // the ones used for the body points.
    const size_t Actual_linkModel_start = 1;
    const size_t linkModel_start        = Actual_linkModel_start + 1;
    const size_t linkModel_end          = Actual_linkModel_start + 6;

    const size_t extra_bodypts_per_joint = 6;

    // +/- m_radius_of_joint along each axis: the 6 body points of one link
    auto body_points_offset_mat = torch::tensor({{-m_radius_of_joint, 0., 0.},
                                                 {+m_radius_of_joint, 0., 0.},
                                                 {0., -m_radius_of_joint, 0.},
                                                 {0., +m_radius_of_joint, 0.},
                                                 {0., 0., -m_radius_of_joint},
                                                 {0., 0., +m_radius_of_joint}},
                                                torch::kFloat);
    // repeat this for all joints and qs
    body_points_offset_mat = body_points_offset_mat.repeat(
        {this->rand_batch_sample_size * (linkModel_end - linkModel_start), 1});

    const MatrixXfRowMajor body_points_offset_matXXX = MatrixXfRowMajor::Map(
        body_points_offset_mat.data_ptr<float>(),
        body_points_offset_mat.size(0), body_points_offset_mat.size(1));

    // the following is the number of worldspace point per q
    // i.e., the number of joints X the number of body pts per joint
    const size_t size_per_q =
        (linkModel_end - linkModel_start) * extra_bodypts_per_joint;

    TorchOccMapManager::mutable_model_path() =
        "occmap_divider_torchmodel.pt";
    TorchOccMapManager::init();

    // NOTE(review): the thread count is hard-coded while the queue was sized
    // with omp_get_max_threads() - confirm this is intended.
#pragma omp parallel num_threads(9)
    {
      TorchOccMapManager::TorchOccMapPtr occ_map =
          TorchOccMapManager::get_occmap(omp_get_thread_num());

      // every thread works on its own robot state
      const auto robot_model = this->casted_space_ptr->getRobotModel();
      moveit::core::RobotState robot_state(robot_model);
      std::vector<double> joint_values(NumDim);

      thread_local Eigen::Rand::Vmt19937_64 rand_urng{std::random_device{}()};
      Eigen::Rand::UniformRealGen<double> rand_uni_gen;
      while (!m_finish_sampling) {
        SampledMatRowMajor batch =
            rand_uni_gen.template generate<SampledMatRowMajor>(
                this->rand_batch_sample_size, NumDim, rand_urng);
        // 0-1  =>  0-diff  =>  lower-upper
        batch = (batch.rowwise() * this->bounds_diff).rowwise() +
                this->bounds_lower;

        /* The Jacobians are stored per sampled q and per link, i.e. in the
         * order
         *   first link  (of q1)
         *   second link (of q1)
         *   ...
         *   last link   (of q1)
         *   first link  (of q2)
         *   ...
         *   last link   (of qn)
         */
        // this tensor of jacobians is reused within each morph
        //        std::vector<Eigen::MatrixXd> Jacobians;
        //        Jacobians.resize(this->rand_batch_sample_size *
        //                         (linkModel_end - linkModel_start));

        torch::Tensor Jacobians2 = torch::empty(
            {this->rand_batch_sample_size, (linkModel_end - linkModel_start),
             NumDim, WORLDSPACE_NUM_DIM},
            torch::kDouble);
        // ^^ we store the transposed J, hence the swapped
        // [NumDim, WORLDSPACE_NUM_DIM] position (ie. 6,3 instead of 3,6)

        for (size_t _drift = 0; _drift < this->num_morph; ++_drift) {
          // collects all body points
          MatrixXfRowMajor pts_to_get_grad(
              this->rand_batch_sample_size * (linkModel_end - linkModel_start) *
                  extra_bodypts_per_joint,
              3);

          for (size_t i = 0; i < this->rand_batch_sample_size; ++i) {
            double* _from = &batch(i, 0);
            // set robot state to the desire values
            robot_state.setJointGroupPositions(joint_model_group, _from);
            for (size_t j = 0; j < linkModel_end - linkModel_start; ++j) {
              size_t link_idx = j + linkModel_start;  // link index

              Eigen::MatrixXd _jacobian;
              robot_state.getJacobian(joint_model_group, linkModels[link_idx],
                                      reference_point_position, _jacobian);

              // write the transposed positional part (top 3 rows) of this
              // local jacobian into its slot of the assembled tensor
              MatrixXdRowMajor::Map(
                  Jacobians2
                      .index({static_cast<int>(i), static_cast<int>(j), "..."})
                      .data_ptr<double>(),
                  Jacobians2.size(-2), Jacobians2.size(-1)) =
                  _jacobian.topRows(3).transpose();

              Eigen::Vector3d pos =
                  robot_state.getGlobalLinkTransform(linkModels[link_idx])
                      .translation();

              // insert the pos 6 times for body points

              int joint_offset = i * ((linkModel_end - linkModel_start) *
                                      extra_bodypts_per_joint) +
                                 j * extra_bodypts_per_joint;

              pts_to_get_grad.middleRows(joint_offset, extra_bodypts_per_joint)
                  .rowwise() = pos.transpose().cast<float>();

            }  // for each joint
          }    // for each sampled q

          // link positions -> body points around each link position
          pts_to_get_grad += body_points_offset_matXXX;

          torch::Tensor pts_to_get_grad2 = torch::from_blob(
              pts_to_get_grad.data(),
              {pts_to_get_grad.rows(), pts_to_get_grad.cols()});
          auto grad2_as_tensor =
              occ_map->grad2(pts_to_get_grad2.to(occ_map->device)) * 1e-3;

          // average over the body points of each link:
          // [batch, links, WORLDSPACE_NUM_DIM]
          torch::Tensor reshaped_tensor =
              grad2_as_tensor
                  .reshape({this->rand_batch_sample_size,
                            (linkModel_end - linkModel_start),
                            extra_bodypts_per_joint, WORLDSPACE_NUM_DIM})
                  .mean(2);

          // back to original shape, sum across joints.
          // Only the trailing singleton axis left by the matmul is squeezed:
          // a bare squeeze() would also drop the batch axis when the batch
          // holds a single sample, and size(1) below would then fail.
          auto summed_grad =
              Jacobians2.to(occ_map->device)
                  .matmul(reshaped_tensor.unsqueeze(-1).to(torch::kDouble))
                  .sum(1)
                  .squeeze(-1)
                  .to(torch::kCPU)
                  .contiguous();

          // take epsilon step
          batch.array() -=
              this->epsilon *
              ArrayXdRowMajor::Map(summed_grad.data_ptr<double>(),
                                   summed_grad.size(0), summed_grad.size(1));

          /////////////////////////////////////////
          // clamp to boundary (the following two joints are specific to
          // jaco)
          batch.col(1) = batch.col(1)
                             .cwiseMin(this->bounds_high[1])
                             .cwiseMax(this->bounds_low[1]);
          batch.col(2) = batch.col(2)
                             .cwiseMin(this->bounds_high[2])
                             .cwiseMax(this->bounds_low[2]);
          /////////////////////////////////////////

        }  // for each morph

        // TODO operate in-place for the queue
        // we are able to reinterpret the pointer as array pointer of fixed
        // size because the eigen matrix is contiguous along row
        m_difted_samples_Q.enqueue_bulk(
            reinterpret_cast<std::array<double, NumDim>*>(&batch(0, 0)),
            this->rand_batch_sample_size);
      }
    }
    std::cout << "background sampling stopped" << std::endl;
  }

  /// The batched base-class path is not used by this specialisation.
  void drift_states() override { throw std::runtime_error("not in-use"); }
};
#endif
