#ifndef REINFROCE_POLICY_HH
#define REINFROCE_POLICY_HH

#include <cstddef>
#include <cstdint>
#include <deque>
#include <random>
#include <string>
#include <tuple>
#include <vector>

#include <torch/torch.h>

#include "abr_algo.hh"
#include "filesystem.hh"
#include "reinforce.hh"
#include "sender.hh"

/*
 * ReinforcePolicy: an ABRAlgo that chooses the next chunk's video format by
 * dynamic programming over (lookahead step, discretized buffer, format).
 *
 * Invariant: every per-step / per-format table below is a plain fixed-size
 * array dimensioned by the MAX_* constants, so the runtime counterparts
 * (lookahead_horizon_, num_formats_, dis_buf_length_, dis_sending_time_,
 * max_num_past_chunks_, ...) must never exceed those maxima.
 * NOTE(review): nothing in this header enforces the bound -- confirm that the
 * constructor clamps any config-supplied values.
 *
 * The inline tables make an instance roughly 330 KB, so prefer heap
 * allocation (e.g. std::unique_ptr) over constructing one on a small stack.
 */
class ReinforcePolicy : public ABRAlgo
{
public:
  /* Construct the policy from its ABR name and YAML configuration. */
  ReinforcePolicy(const WebSocketClient &client,
                  const std::string &abr_name, const YAML::Node &abr_config);

  /* ABRAlgo hook; presumably invoked for each chunk the client acknowledges. */
  void video_chunk_acked(Chunk &&c) override;

  /* ABRAlgo hook: choose the format for the next chunk. */
  VideoFormat select_video_format() override;

protected:
  /* compile-time capacities; the array members below are sized by these */
  static constexpr size_t MAX_NUM_PAST_CHUNKS = 8;
  static constexpr size_t MAX_LOOKAHEAD_HORIZON = 5;
  static constexpr size_t MAX_DIS_BUF_LENGTH = 100;
  static constexpr double REBUFFER_LENGTH_COEFF = 20;
  static constexpr double SSIM_DIFF_COEFF = 1;
  static constexpr size_t MAX_NUM_FORMATS = 20;
  static constexpr double UNIT_BUF_LENGTH = 0.5;
  static constexpr size_t MAX_DIS_SENDING_TIME = 20;
  static constexpr double ST_PROB_EPS = 1e-5;
  static constexpr size_t TTP_HIDDEN_DIM = 64;
  static constexpr size_t TTP_INPUT_DIM = 62;
  static constexpr double HIGH_SENDING_TIME = 10000;

  /* past chunks and the maximum number of them to keep */
  size_t max_num_past_chunks_{MAX_NUM_PAST_CHUNKS};
  std::deque<Chunk> past_chunks_{};

  /* DP dimensions and tuning knobs; all time durations are in seconds */
  size_t max_lookahead_horizon_{MAX_LOOKAHEAD_HORIZON};
  size_t lookahead_horizon_{};
  size_t dis_chunk_length_{};
  size_t dis_buf_length_{MAX_DIS_BUF_LENGTH};
  size_t dis_sending_time_{MAX_DIS_SENDING_TIME};
  double unit_buf_length_{UNIT_BUF_LENGTH};
  size_t num_formats_{};
  double rebuffer_length_coeff_{REBUFFER_LENGTH_COEFF};
  double ssim_diff_coeff_{SSIM_DIFF_COEFF};
  double st_prob_eps_{ST_PROB_EPS};
  size_t ttp_input_dim_{TTP_INPUT_DIM};

  /* whether the current chunk is the first chunk */
  bool is_init_{};

  /* current buffer length; presumably in the units returned by
   * discretize_buffer() -- TODO confirm */
  size_t curr_buffer_{};

  /* DP value table, indexed [step][discretized buffer][format].
   * NOTE(review): flag_ presumably stamps each v_ entry with the curr_round_
   * in which it was computed, so a new round invalidates the table without
   * clearing it -- confirm in the implementation. */
  uint64_t flag_[MAX_LOOKAHEAD_HORIZON + 1][MAX_DIS_BUF_LENGTH + 1][MAX_NUM_FORMATS]{};
  double v_[MAX_LOOKAHEAD_HORIZON + 1][MAX_DIS_BUF_LENGTH + 1][MAX_NUM_FORMATS]{};

  /* record the current round of DP */
  uint64_t curr_round_{};

  /* SSIM and size of each lookahead chunk, indexed [step][format] */
  double curr_ssims_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]{};
  int curr_sizes_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]{};

  /* estimated sending-time distribution, indexed
   * [step][format][discretized sending time] */
  double sending_time_prob_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]
                           [MAX_DIS_SENDING_TIME + 1]{};

  /* whether the chunk at [step][format] is abandoned */
  bool is_ban_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]{};

  void reinit();

  /* overridable hook, presumably for resetting sending-time estimates;
   * a no-op in this class */
  virtual void reinit_sending_time() {}

  /* calculate the value of corresponding state and return the best strategy */
  size_t update_value(size_t i, size_t curr_buffer, size_t curr_format);

  /* return the qvalue of the given cur state and next action */
  double get_qvalue(size_t i, size_t curr_buffer, size_t curr_format,
                    size_t next_format);

  /* return the value of the given state */
  double get_value(size_t i, size_t curr_buffer, size_t curr_format);

  /* discretize the buffer length */
  size_t discretize_buffer(double buf);

  /* deal with the situation when all formats are banned */
  void deal_all_ban(size_t i);

  /* per-[step][format] hidden vectors of length TTP_HIDDEN_DIM.
   * NOTE(review): presumably the sending-time predictor's hidden state --
   * confirm in the implementation. */
  size_t ttp_hidden_dim_{TTP_HIDDEN_DIM};
  double ttp_hidden_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]
                    [TTP_HIDDEN_DIM]{};

  /* queued (feature vector, index) pairs; presumably shipped via sender_ */
  std::deque<std::tuple<std::vector<double>, std::size_t>> datapoints_;
  Sender sender_{};

  /* QoE of a discretized buffer state under the SSIM / bitrate objectives */
  double get_qoe_by_ssim(std::size_t dis_buffer);
  double get_qoe_by_bitrate(std::size_t dis_buffer);

  /* build the current input feature vector */
  std::vector<double> get_input();

  std::string features_;

  /* presumably selects get_qoe_by_bitrate() over get_qoe_by_ssim() */
  bool optimize_bitrate_{};

  /* per-[step][format] input vectors of length TTP_INPUT_DIM */
  double inputs_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS][TTP_INPUT_DIM]{};

  void compute_sending_time_mpc();

  /* per-chunk sending-time samples spanning past chunks and the lookahead
   * window (MAX_NUM_PAST_CHUNKS + MAX_LOOKAHEAD_HORIZON + 1 slots) */
  double unit_sending_time_[MAX_LOOKAHEAD_HORIZON + 1 + MAX_NUM_PAST_CHUNKS]{};

  /* point estimate of sending time, indexed [step][format] */
  double curr_sending_time_[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS]{};
};

#endif /* REINFROCE_POLICY_HH */
