#pragma once

#include <eigen3/Eigen/Eigen>
#include <eigen3/Eigen/Dense>
#include <vector>
#include <cmath>
#include <iostream>
#include <fstream>
#include <string.h>
#include <algorithm>
#include <numeric>
#include "utils/minco.hpp" 
#include <OsqpEigen/OsqpEigen.h>

// Console logging helpers. PRINTF_WHITE streams STRING to std::cout with no
// colour and no trailing newline. The PRINT_* variants wrap STRING in an ANSI
// SGR colour escape (32 = green, 31 = red, 33 = yellow), reset the colour with
// "\033[m" and end the line with '\n'. STRING may be any `<<`-chain operand.
#define PRINTF_WHITE(STRING) std::cout<<STRING
#define PRINT_GREEN(STRING) std::cout<<"\033[32m"<<STRING<<"\033[m\n"
#define PRINT_RED(STRING) std::cout<<"\033[31m"<<STRING<<"\033[m\n"
#define PRINT_YELLOW(STRING) std::cout<<"\033[33m"<<STRING<<"\033[m\n"

// NOTE(review): namespace-wide using-directives in a header leak into every
// file that includes it. They are kept because the code below relies on
// unqualified std / Eigen / net_planner names.
using namespace std;
using namespace Eigen;
// presumably provides PolyTrajectory, MinJerkOpt and CoefficientMat (via
// utils/minco.hpp) — TODO confirm.
using namespace net_planner;
// Row-major Eigen aliases (RowVectorXf is a 1 x n float row vector; the others
// are dynamic row-major float / double / int matrices).
using RowVectorXf = Eigen::Matrix<float, 1, Eigen::Dynamic, Eigen::RowMajor>;
using RowMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using RowMatrixXd = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using RowMatrixXi = Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Number of quintic segments assumed by MincoQP (6 coefficients per segment,
// N - 1 free intermediate waypoints per axis). NOTE(review): a one-letter
// global in a header is easy to shadow accidentally.
constexpr int N = 5;
class MincoQP
{
private:
    Vector3d max_v = Vector3d(2.0, 2.0, 2.0);
    Vector3d max_a = Vector3d(6.0, 6.0, 6.0);
    // Vector3d max_v = Vector3d(1000.0, 1000.0, 10000.5);
    // Vector3d max_a = Vector3d(1000.0, 1000.0, 10002.0);
    vector<Array2i> idxs;

    //solver
    OsqpEigen::Solver qp_solver;

public:
    bool has_solution = false;

    MincoQP()
    {
        MatrixXi qget = MatrixXi::Zero(1, 6*N);
        for (int i=0; i<N-1; i++)
            qget(0, 6*i+5) = 1;
        MatrixXi Qget = qget.transpose() * qget;
        for (int i=0; i<6*N; i++)
            for (int j=0; j<6*N; j++)
                if (Qget(i, j) == 1)
                    idxs.push_back(Array2i(i, j));
    }
    ~MincoQP() {};

    bool solveQP(const RowMatrixXd& head_pva, 
                const RowMatrixXd& tail_pva, 
                const RowMatrixXd& corridor, 
                const VectorXd& T1,
                Ref<VectorXd> minco_wps)
    {
        VectorXd T2 = T1.cwiseProduct(T1);
        VectorXd T3 = T2.cwiseProduct(T1);
        VectorXd T4 = T2.cwiseProduct(T2);
        VectorXd T5 = T4.cwiseProduct(T1);

        MatrixXd M;
        M.resize(N*6, N*6);
        M.setZero();
        M(0, 0) = 1.0;
        M(1, 1) = 1.0;
        M(2, 2) = 2.0;
        for (int j=0; j<N-1; j++)
        {
            // jerk continuity
            M(6*j+3, 6*j+3) = 6.0;
            M(6*j+3, 6*j+4) = 24.0 * T1(j);
            M(6*j+3, 6*j+5) = 60.0 * T2(j);
            M(6*j+3, 6*j+9) = -6.0;
            // snap continuity
            M(6*j+4, 6*j+4) = 24.0;
            M(6*j+4, 6*j+5) = 120.0 * T1(j);
            M(6*j+4, 6*j+10) = -24.0;
            // pos setting
            M(6*j+5, 6*j) = 1.0;
            M(6*j+5, 6*j+1) = T1(j);
            M(6*j+5, 6*j+2) = T2(j);
            M(6*j+5, 6*j+3) = T3(j);
            M(6*j+5, 6*j+4) = T4(j);
            M(6*j+5, 6*j+5) = T5(j);
            // pos continutiy
            M(6*j+6, 6*j) = 1.0;
            M(6*j+6, 6*j+1) = T1(j);
            M(6*j+6, 6*j+2) = T2(j);
            M(6*j+6, 6*j+3) = T3(j);
            M(6*j+6, 6*j+4) = T4(j);
            M(6*j+6, 6*j+5) = T5(j);
            M(6*j+6, 6*j+6) = -1.0;
            // vel continutiy
            M(6*j+7, 6*j+1) = 1.0;
            M(6*j+7, 6*j+2) = 2.0 * T1(j);
            M(6*j+7, 6*j+3) = 3.0 * T2(j);
            M(6*j+7, 6*j+4) = 4.0 * T3(j);
            M(6*j+7, 6*j+5) = 5.0 * T4(j);
            M(6*j+7, 6*j+7) = -1.0;
            // acc continutiy
            M(6*j+8, 6*j+2) = 2.0;
            M(6*j+8, 6*j+3) = 6.0 * T1(j);
            M(6*j+8, 6*j+4) = 12.0 * T2(j);
            M(6*j+8, 6*j+5) = 20.0 * T3(j);
            M(6*j+8, 6*j+8) = -2.0;
        }
        // tail pva setting
        M(6*N-3, 6*(N-1)) = 1.0;
        M(6*N-3, 6*(N-1)+1) = T1(N-1);
        M(6*N-3, 6*(N-1)+2) = T2(N-1);
        M(6*N-3, 6*(N-1)+3) = T3(N-1);
        M(6*N-3, 6*(N-1)+4) = T4(N-1);
        M(6*N-3, 6*(N-1)+5) = T5(N-1);
        M(6*N-2, 6*(N-1)+1) = 1.0;
        M(6*N-2, 6*(N-1)+2) = 2.0 * T1(N-1);
        M(6*N-2, 6*(N-1)+3) = 3.0 * T2(N-1);
        M(6*N-2, 6*(N-1)+4) = 4.0 * T3(N-1);
        M(6*N-2, 6*(N-1)+5) = 5.0 * T4(N-1);
        M(6*N-1, 6*(N-1)+2) = 2.0;
        M(6*N-1, 6*(N-1)+3) = 6.0 * T1(N-1);
        M(6*N-1, 6*(N-1)+4) = 12.0 * T2(N-1);
        M(6*N-1, 6*(N-1)+5) = 20.0 * T3(N-1);

        MatrixXd Minv = M.inverse();

        // jerk matrix
        MatrixXd Qc = MatrixXd::Zero(6*N, 6*N);
        MatrixXd Jmat = MatrixXd::Zero(3, 3);
        Jmat << 3.0, 6.0, 10.0, 
                6.0, 16.0, 30.0, 
                10.0, 30.0, 60.0;
        for (int j=0; j<N; j++)
        {
            MatrixXd Tmat = MatrixXd::Zero(3, 3);
            Tmat <<T1(j), T2(j), T3(j), 
                T2(j), T3(j), T4(j), 
                T3(j), T4(j), T5(j);
            Qc.block<3, 3>(6*j+3, 6*j+3) = 12.0 * Jmat.cwiseProduct(Tmat);
        }
        MatrixXd Qall = 2.0 * Minv.transpose() * Qc * Minv;

        // Q, p
        const int dim = 3;
        const int nx = dim*(N-1);
        Eigen::SparseMatrix<double> Q;
        Q.resize(nx, nx);
        Eigen::VectorXd p = Eigen::VectorXd::Zero(nx);
        for (int i=0; i<dim; i++)
        {
            int cnt = 0;
            for (int j=0; j<N-1; j++)
            {
                for (int k=0; k<N-1; k++)
                {
                    Array2i idx = idxs[cnt];
                    Q.insert(i*(N-1)+j, i*(N-1)+k) = Qall(idx[0], idx[1]);
                    cnt++;
                }
                double hv = (head_pva.block<3, 1>(0, i).transpose() * Qall.block<3, 1>(0, 6*j+5))(0, 0);
                double tv = (tail_pva.block<3, 1>(0, i).transpose() * Qall.block<3, 1>(6*N-3, 6*j+5))(0, 0);
                p(i*(N-1)+j) = hv + tv;
            }
        }
        
        // low<=Gx<=up
        int penalty_num = corridor.rows() / N;
        int nc = dim*N*penalty_num*3;
        Eigen::SparseMatrix<double> G;
        G.resize(nc, nx);
        Eigen::VectorXd lowerBound = VectorXd::Zero(nc);
        Eigen::VectorXd upperBound = VectorXd::Zero(nc);
        int cnt = 0;
        VectorXd zero_temp = VectorXd::Zero(penalty_num);
        VectorXd ones_temp = VectorXd::Ones(penalty_num);
        VectorXd penal_temp = VectorXd::Zero(penalty_num);
        for (int k=0; k<penalty_num; k++)
            penal_temp(k) = 1.0 * (k+1) / penalty_num;
        for (int j=0; j<N; j++)
        {
            VectorXd s = penal_temp * T1(j);
            VectorXd s2 = s.cwiseProduct(s);
            VectorXd s3 = s2.cwiseProduct(s);
            VectorXd s4 = s2.cwiseProduct(s2);
            VectorXd s5 = s4.cwiseProduct(s);
            MatrixXd pmat = MatrixXd::Zero(penalty_num, 6);
            MatrixXd vmat = MatrixXd::Zero(penalty_num, 6);
            MatrixXd amat = MatrixXd::Zero(penalty_num, 6);
            vmat.col(0) = amat.col(0) = amat.col(1) = zero_temp;
            pmat.col(0) = vmat.col(1) = amat.col(2) = ones_temp;
            pmat.col(1) = vmat.col(2) = amat.col(3) = s;
            pmat.col(2) = vmat.col(3) = amat.col(4) = s2;
            pmat.col(3) = vmat.col(4) = amat.col(5) = s3;
            pmat.col(4) = vmat.col(5) = s4;
            pmat.col(5) = s5;
            pmat = pmat * Minv.block<6, 6*N>(6*j, 0);
            vmat = vmat * Minv.block<6, 6*N>(6*j, 0);
            amat = amat * Minv.block<6, 6*N>(6*j, 0);
            MatrixXd hp = pmat.leftCols(3) * head_pva + pmat.rightCols(3) * tail_pva;
            MatrixXd hv = vmat.leftCols(3) * head_pva + vmat.rightCols(3) * tail_pva;
            MatrixXd ha = amat.leftCols(3) * head_pva + amat.rightCols(3) * tail_pva;
            for (int k=0; k<dim; k++)
            {
                for (int i=0; i<penalty_num; i++)
                {
                    for (int l=0; l<N-1; l++)
                    {
                        G.insert(cnt, k*(N-1)+l) = pmat(i, 6*l+5);
                        G.insert(cnt+1, k*(N-1)+l) = vmat(i, 6*l+5);
                        G.insert(cnt+2, k*(N-1)+l) = amat(i, 6*l+5);
                    }
                    double inflate = 0.1;
                    lowerBound(cnt) = - hp(i, k) - corridor(j*penalty_num+i, 3) + corridor(j*penalty_num+i, k) - inflate;
                    upperBound(cnt) = - hp(i, k) + corridor(j*penalty_num+i, 3) + corridor(j*penalty_num+i, k) + inflate;
                    lowerBound(cnt+1) = - hv(i, k) - max_v(k);
                    upperBound(cnt+1) = - hv(i, k) + max_v(k);
                    lowerBound(cnt+2) = - ha(i, k) - max_a(k);
                    upperBound(cnt+2) = - ha(i, k) + max_a(k);
                    cnt+=3;
                }
            } 
        }

        // MatrixXd dQ(Q);
        // MatrixXd dG(G);
        // std::cout<<"dq:\n"<<dQ<<std::endl;
        // std::cout<<"dg:\n"<<dG<<std::endl;
        // std::cout<<"p:\n"<<p.transpose()<<std::endl;
        // std::cout<<"lb:\n"<<lowerBound.transpose()<<std::endl;
        // std::cout<<"ub:\n"<<upperBound.transpose()<<std::endl;

        // set the initial data and solve the problem
        bool flag_work = true;
        if (has_solution)
        {
            PRINT_GREEN("[NN Planner] QP receding horizon optimizing...");
            if (!flag_work || !qp_solver.updateHessianMatrix(Q)) flag_work = false;
            if (!flag_work || !qp_solver.updateGradient(p)) flag_work = false;
            if (!flag_work || !qp_solver.updateLinearConstraintsMatrix(G)) flag_work = false;
            if (!flag_work || !qp_solver.updateBounds(lowerBound, upperBound)) flag_work = false;
            if (!flag_work || qp_solver.solveProblem() != OsqpEigen::ErrorExitFlag::NoError) flag_work = false;
        }
        
        if (!flag_work || !has_solution)
        {
            PRINT_YELLOW("[NN Planner] QP solver resetting...");
            qp_solver.clearSolver();
            qp_solver.data()->clearHessianMatrix();
            qp_solver.data()->clearLinearConstraintsMatrix();
            // setting
            qp_solver.settings()->setVerbosity(false);
            qp_solver.settings()->setWarmStart(false);
            qp_solver.settings()->setAbsoluteTolerance(1e-4);
            qp_solver.settings()->setMaxIteration(10000);
            qp_solver.settings()->setRelativeTolerance(1e-4);
            qp_solver.data()->setNumberOfVariables(nx);
            qp_solver.data()->setNumberOfConstraints(nc);
            // data setting
            if(!qp_solver.data()->setHessianMatrix(Q)) return false;
            if(!qp_solver.data()->setGradient(p)) return false;
            if(!qp_solver.data()->setLinearConstraintsMatrix(G)) return false;
            if(!qp_solver.data()->setLowerBound(lowerBound)) return false;
            if(!qp_solver.data()->setUpperBound(upperBound)) return false;
            if(!qp_solver.initSolver()) return false;
            if(qp_solver.solveProblem() != OsqpEigen::ErrorExitFlag::NoError) return false;
            has_solution = true;
        }
        
        minco_wps = qp_solver.getSolution();
        
        return true;
    }
};

class EgoMinco
{
    public:
        bool traj_setted = false;
        PolyTrajectory<3, 5> traj;
        MinJerkOpt<3> minco;

    public:
        double getTotalDuration() const
        {
            return traj.getTotalDuration();
        }

        Eigen::VectorXd getMaxVelAxis() const
        {
            Eigen::VectorXd mv;
            mv.resize(3);
            mv(0) = mv(1) = mv(2) = -1.0;
            double dt = 0.01;
            for (double t=0; t<=traj.getTotalDuration(); t+=0.01)
            {
                Eigen::Vector3d vel = traj.getVel(t);
                for (int i=0; i<3; i++)
                {
                    if (fabs(vel(i)) > mv(i))
                        mv(i) = fabs(vel(i));
                }
            }
            return mv;
        }



        

        Eigen::VectorXd getMaxAccAxis() const
        {
            Eigen::VectorXd ma;
            ma.resize(3);
            ma(0) = ma(1) = ma(2) = -1.0;
            double dt = 0.01;
            for (double t=0; t<=traj.getTotalDuration(); t+=0.01)
            {
                Eigen::Vector3d acc = traj.getAcc(t);
                for (int i=0; i<3; i++)
                {
                    if (fabs(acc(i)) > ma(i))
                        ma(i) = fabs(acc(i));
                }
            }
            return ma;
        }

        Eigen::VectorXd getState(double t) const
        {
            Eigen::VectorXd state;
            state = traj.getPos(t);
            return state;
        }

        Eigen::VectorXd getDurations() const
        {
            return traj.getDurations();
        }

        void setTraj(const RowMatrixXd start_state,
                    const RowMatrixXd end_state,
                    const RowMatrixXd inner_pts,
                    const Eigen::VectorXd durations)
        {
            minco.reset(durations.size());
            minco.generate(Eigen::MatrixXd(start_state),
                           Eigen::MatrixXd(end_state),
                           Eigen::MatrixXd(inner_pts),durations);
            traj = minco.getTraj();
            return;
        }

        RowMatrixXd getCoefficients() const
        {
            std::vector<Eigen::MatrixXd> coeffs = traj.getCoeffMats();
            RowMatrixXd row_coeffs;
            row_coeffs.resize(coeffs.size()*6, 3);
            for (size_t i = 0; i < coeffs.size(); i++)
            {
                row_coeffs.block<6, 3>(6*i, 0) = coeffs[i].rowwise().reverse().transpose();
            }
            return row_coeffs;
        }

        void setTrajFromCoeff(Eigen::VectorXd times, RowMatrixXd coeffs)
        {
            std::vector<CoefficientMat<3, 5>> coeffs_mat;
            std::vector<double> times_vec(times.data(), times.data() + times.size());
            for (int i = 0; i < times.size(); i++)
            {
                coeffs_mat.push_back(CoefficientMat<3, 5>(coeffs.block<6, 3>(6*i, 0)
                                                            .transpose()
                                                            .rowwise()
                                                            .reverse()));
            }
            traj = PolyTrajectory<3, 5>(times_vec, coeffs_mat);
        }

        RowMatrixXd sampleTrajPoints(double dt) const
        {
            RowMatrixXd points;
            std::vector<Eigen::VectorXd> points_vec;
            for (double t = 0; t < getTotalDuration(); t += dt)
                points_vec.push_back(getState(t));
            points.resize(points_vec.size(), 3);
            for (size_t i = 0; i < points_vec.size(); i++)
                points.row(i) = points_vec[i];
            return points;
        }

        RowMatrixXd samplePenalPoints(int K) const
        {
            RowMatrixXd points;
            std::vector<Eigen::VectorXd> points_vec;
            Eigen::VectorXd ts = traj.getDurations();
            int N = ts.size();
            double nowt = 0.0;
            for (int i=0; i < N; i++)
            {
                for (int j = 1; j < K+1; j++)
                {
                    points_vec.push_back(getState(nowt+j*ts[i]/K));
                }
                nowt += ts[i];
            }
            points.resize(points_vec.size(), 3);
            for (size_t i = 0; i < points_vec.size(); i++)
                points.row(i) = points_vec[i];
            return points;
        }

        RowMatrixXd sampleArcPoints(int num) const
        {
            // get ds
            double length = 0.0;
            double dt = 0.01;
            Eigen::VectorXd xlast = getState(0.0);
            for (double t = dt; t <= getTotalDuration(); t += dt)
            {
                Eigen::VectorXd x = getState(t);
                length += (x-xlast).norm();
                xlast = x;
            }
            double ds = length / (num+1);
            // push points
            RowMatrixXd points;
            std::vector<Eigen::VectorXd> points_vec;
            points_vec.push_back(getState(0.0));
            double temp_len = 0.0;
            xlast = points_vec.back();
            int k = 1;
            for (double t = dt; t < getTotalDuration(); t += dt)
            {
                Eigen::VectorXd x = getState(t);
                temp_len += (x-xlast).norm();
                if (temp_len > ds*k)
                {
                    points_vec.push_back(x);
                    k++;
                }
                xlast = x;
                if (points_vec.size() == num+1)
                    break;
            }
            points_vec.push_back(getState(getTotalDuration()));
            assert(points_vec.size() == num+2);
            points.resize(points_vec.size(), 3);
            for (size_t i = 0; i < points_vec.size(); i++)
                points.row(i) = points_vec[i];
            return points;
        }
};