#include <fstream>
#include <memory>
#include <tuple>
#include <map>
#include <stdlib.h>
#include "reinforce_policy.hh"
#include "ws_client.hh"
#include "json.hpp"
#include "timestamp.hh"

using namespace std;
using json = nlohmann::json;

/*
 * Build the policy on top of ABRAlgo and apply the optional overrides found
 * in the YAML configuration. Recognized keys, all optional:
 *   max_lookahead_horizon, rebuffer_length_coeff, ssim_diff_coeff,
 *   features, optimize_bitrate
 * Keys that are absent leave the corresponding member at its current value.
 */
ReinforcePolicy::ReinforcePolicy(const WebSocketClient &client,
                                 const string &abr_name, const YAML::Node &abr_config)
    : ABRAlgo(client, abr_name)
{
  /* the configured horizon can only lower the current maximum, never raise it */
  if (const auto horizon_node = abr_config["max_lookahead_horizon"])
  {
    const auto requested_horizon = horizon_node.as<size_t>();
    max_lookahead_horizon_ = min(max_lookahead_horizon_, requested_horizon);
  }

  /* weights of the two QoE penalty terms */
  if (const auto rebuffer_node = abr_config["rebuffer_length_coeff"])
  {
    rebuffer_length_coeff_ = rebuffer_node.as<double>();
  }
  if (const auto ssim_diff_node = abr_config["ssim_diff_coeff"])
  {
    ssim_diff_coeff_ = ssim_diff_node.as<double>();
  }

  /* name of the feature set that get_input() puts into the observation */
  if (const auto features_node = abr_config["features"])
  {
    features_ = features_node.as<std::string>();
  }

  /* selects the bitrate-based reward instead of the SSIM-based one */
  if (const auto bitrate_node = abr_config["optimize_bitrate"])
  {
    optimize_bitrate_ = bitrate_node.as<bool>();
  }

  /* cap the discretized buffer range at the client's maximum buffer */
  const auto max_client_buffer = discretize_buffer(WebSocketClient::MAX_BUFFER_S);
  dis_buf_length_ = min(dis_buf_length_, max_client_buffer);
}

double ReinforcePolicy::get_qoe_by_ssim(std::size_t dis_buffer)
{
  if (past_chunks_.size() == 0)
  {
    return 0;
  }

  Chunk curr_chunk = past_chunks_.back();
  double qoe = ssim_db(curr_chunk.ssim);

  if (past_chunks_.size() > 1)
  {
    Chunk prev_chunk = *(past_chunks_.end() - 2);

    qoe -= ssim_diff_coeff_ * fabs(ssim_db(curr_chunk.ssim) - ssim_db(prev_chunk.ssim));
    qoe -= rebuffer_length_coeff_ * max(curr_chunk.trans_time * 0.001 - dis_buffer * unit_buf_length_, 0.0);
  }

  return qoe;
}

double ReinforcePolicy::get_qoe_by_bitrate(std::size_t dis_buffer)
{
  if (past_chunks_.size() == 0)
  {
    return 0;
  }

  Chunk curr_chunk = past_chunks_.back();
  double curr_bitrate = ((curr_chunk.size / 2.002) * 8.0 / 1000) / 1000; // convert chunk size in kB into kbps

  double qoe = curr_bitrate;

  if (past_chunks_.size() > 1)
  {
    Chunk prev_chunk = *(past_chunks_.end() - 2);
    double prev_bitrate = ((prev_chunk.size / 2.002) * 8.0 / 1000) / 1000;

    qoe -= ssim_diff_coeff_ * fabs(curr_bitrate - prev_bitrate);
    qoe -= rebuffer_length_coeff_ * max(curr_chunk.trans_time * 0.001 - dis_buffer * unit_buf_length_, 0.0);
  }

  return qoe;
}

/*
 * Record an acknowledged chunk and report the resulting reward to the
 * policy server.
 *
 * The chunk is appended to past_chunks_ (bounded by max_num_past_chunks_).
 * The most recently recorded decision is then taken from datapoints_, the
 * reward is computed for the buffer level stored with that decision, and a
 * LOG_RETURNS message is posted through client_.sender_.
 *
 * If no decision is recorded, the chunk is still added to the history but no
 * reward is reported, since there is no buffer level to evaluate it against.
 */
void ReinforcePolicy::video_chunk_acked(Chunk &&c)
{
  /* c is an rvalue reference and is not used again below */
  past_chunks_.push_back(std::move(c));
  if (past_chunks_.size() > max_num_past_chunks_)
  {
    past_chunks_.pop_front();
  }

  /* back()/pop_back() on an empty container is undefined behaviour */
  if (datapoints_.empty())
  {
    cerr << "ReinforcePolicy: chunk acked without a recorded decision, "
         << "skipping reward report" << endl;
    return;
  }

  /* NOTE(review): the newest decision is consumed (back/pop_back); if several
   * decisions can be outstanding before an ack arrives this pairs the ack
   * with the latest one - confirm that this is intended. */
  const auto &[observation, buffer_at_decision] = datapoints_.back();
  const auto dis_buf = buffer_at_decision; /* copy before the entry is removed */
  datapoints_.pop_back();

  const double reward = optimize_bitrate_ ? get_qoe_by_bitrate(dis_buf)
                                          : get_qoe_by_ssim(dis_buf);

  json body;
  body["episode_id"] = client_.episode_id_;
  body["command"] = "LOG_RETURNS";
  body["reward"] = static_cast<float>(reward);
  body["done"] = false;
  body["info"] = {};

  client_.sender_->post(body);
}

/*
 * Ask the policy server which video format to send next.
 *
 * Refreshes the per-decision state via reinit(), builds the observation with
 * get_input(), posts a GET_ACTION request through client_.sender_ and maps
 * the returned "action" index onto the channel's list of video formats.
 *
 * On success the observation and the current discretized buffer level are
 * appended to datapoints_ so that video_chunk_acked() can compute the reward.
 *
 * Throws runtime_error if the returned action does not index an existing
 * format; exceptions from reinit(), the request or the JSON conversion
 * propagate unchanged. Nothing is appended to datapoints_ in those cases.
 */
VideoFormat ReinforcePolicy::select_video_format()
{
  reinit();

  auto input = get_input();

  json body;
  body["episode_id"] = client_.episode_id_;
  body["command"] = "GET_ACTION";
  body["observation"] = input;

  json res = client_.sender_->post(body);
  const size_t format = res["action"];

  /* the index comes from a remote service, so validate it before use */
  const auto &channel = client_.channel();
  const auto &vformats = channel->vformats();
  if (format >= vformats.size())
  {
    throw runtime_error("ReinforcePolicy: action " + std::to_string(format) +
                        " is out of range for " +
                        std::to_string(vformats.size()) + " video formats");
  }

  /* record the decision only once a usable action has been obtained */
  datapoints_.push_back(std::make_tuple(std::move(input), curr_buffer_));

  return vformats[format];
}

std::vector<double> ReinforcePolicy::get_input()
{
  std::vector<double> input;

  if (features_ == "input")
  {
    for (size_t i = 1; i <= lookahead_horizon_; i++)
    {
      for (size_t j = 0; j < num_formats_; j++)
      {
        for (size_t k = 0; k < ttp_input_dim_; k++)
        {
          input.push_back(inputs_[i][j][k]);
        }
      }
    }
  }
  else if (features_ == "TTP")
  {
    // avg sending time
    double BINS[MAX_DIS_SENDING_TIME + 2] = {0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.25, 4.75, 5.25, 5.75, 6.25, 6.75, 7.25, 7.75, 8.25, 8.75, 9.25, 9.75, 30};
    std::vector<double> sending_time;
    for (size_t i = 1; i <= lookahead_horizon_; i++)
    {
      for (size_t j = 0; j < num_formats_; j++)
      {
        double expected_trans_time = 0;
        for (size_t k = 0; k < dis_sending_time_; k++)
        {
          if (sending_time_prob_[i][j][k] > st_prob_eps_)
          {
            expected_trans_time += sending_time_prob_[i][j][k] * (BINS[k] + BINS[k + 1]) / 2;
          }
        }

        sending_time.push_back(expected_trans_time);
      }
    }
    input.insert(input.end(), sending_time.begin(), sending_time.end());
  }
  else if (features_ == "HM") 
  {
    for (size_t i = 1; i <= lookahead_horizon_; i++)
    {
      for (size_t j = 0; j < num_formats_; j++)
      {
        input.push_back(curr_sending_time_[i][j]);
      }
    }
  }

  // add current buffer level params
  input.push_back((double)curr_buffer_);

  // add last format ssim
  if (past_chunks_.size() == 0)
  {
    input.push_back(-1.0);
  }
  else
  {
    auto c = past_chunks_.back();
    input.push_back(c.ssim / MAX_SSIM);
  }

  // input.push_back(curr_round_ / 300.0);

  // const auto &channel = client_.channel();
  // const unsigned int vduration = channel->vduration();
  // const uint64_t next_ts = client_.next_vts().value();
  // input.push_back(((double)next_ts/vduration)/300); // curr chunk number

  // add ssim indexes
  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    for (size_t j = 0; j < num_formats_; j++)
    {
      input.push_back(curr_ssims_[i][j] / MAX_SSIM);
    }
  }

  return input;
}

void ReinforcePolicy::compute_sending_time_mpc()
{
  const auto &channel = client_.channel();
  const auto &vformats = channel->vformats();
  const unsigned int vduration = channel->vduration();
  const uint64_t next_ts = client_.next_vts().value();

  /* init curr_sending_time */
  size_t num_past_chunks = past_chunks_.size();
  auto it = past_chunks_.begin();
  double max_err = 0;

  // for (size_t i = 1; it != past_chunks_.end(); it++, i++) {
  //   unit_sending_time_[i] = (double) it->trans_time / it->size / 1000;
  //   max_err = max(max_err, it->pred_err);
  // }

  // if (not is_robust_) {
  //   max_err = 0;
  // }

  for (size_t i = 1; i <= lookahead_horizon_; i++) {
    double tmp = 0;
    for (size_t j = 0; j < num_past_chunks; j++) {
      tmp += unit_sending_time_[i + j];
    }

    if (num_past_chunks != 0) {
      double unit_st = tmp / num_past_chunks;

      // if (i == 1) {
      //   last_tp_pred_ = 1 / unit_st;
      // }

      unit_sending_time_[i + num_past_chunks] = unit_st * (1 + max_err);
    } else {
      /* set the sending time to be a default hight value */
      unit_sending_time_[i + num_past_chunks] = HIGH_SENDING_TIME;
    }

    const auto & data_map = channel->vdata(next_ts + vduration * (i - 1));

    for (size_t j = 0; j < num_formats_; j++) {
      try {
        curr_sending_time_[i][j] = get<1>(data_map.at(vformats[j]))
                                   * unit_sending_time_[i + num_past_chunks];
      } catch (const exception & e) {
        cerr << "Error occurs when getting the video size of "
             << next_ts + vduration * (i - 1) << " " << vformats[j] << endl;
        curr_sending_time_[i][j] = HIGH_SENDING_TIME;
      }
    }
  }  
}

void ReinforcePolicy::reinit()
{
  curr_round_++;

  const auto &channel = client_.channel();
  const auto &vformats = channel->vformats();
  const unsigned int vduration = channel->vduration();
  const uint64_t next_ts = client_.next_vts().value();

  dis_chunk_length_ = discretize_buffer((double)vduration / channel->timescale());
  num_formats_ = vformats.size();

  /* initialization failed if there is no ready chunk ahead */
  if (channel->vready_frontier().value() < next_ts || num_formats_ == 0)
  {
    throw runtime_error("no ready chunk ahead");
  }

  lookahead_horizon_ = min(
      max_lookahead_horizon_,
      (channel->vready_frontier().value() - next_ts) / vduration + 1);

  curr_buffer_ = min(dis_buf_length_,
                     discretize_buffer(client_.video_playback_buf()));

  /* init curr_ssims */
  if (past_chunks_.size() > 0)
  {
    is_init_ = false;
    curr_ssims_[0][0] = ssim_db(past_chunks_.back().ssim);
  }
  else
  {
    is_init_ = true;
  }

  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    const auto &data_map = channel->vdata(next_ts + vduration * (i - 1));

    for (size_t j = 0; j < num_formats_; j++)
    {

      try
      {
        curr_ssims_[i][j] = ssim_db(
            channel->vssim(vformats[j], next_ts + vduration * (i - 1)));
      }
      catch (const exception &e)
      {
        cerr << "Error occurs when getting the ssim of "
             << next_ts + vduration * (i - 1) << " " << vformats[j] << endl;
        curr_ssims_[i][j] = MIN_SSIM;
      }

      try
      {
        curr_sizes_[i][j] = get<1>(data_map.at(vformats[j]));
      }
      catch (const exception &e)
      {
        cerr << "Error occurs when getting the sizes of "
             << next_ts + vduration * (i - 1) << " " << vformats[j] << endl;
        curr_sizes_[i][j] = -1;
      }
    }
  }

  /* init sending time */
  reinit_sending_time();

  compute_sending_time_mpc();
}

/*
 * Un-ban one format at lookahead step i.
 *
 * The format with the smallest positive entry in curr_sizes_[i] is chosen
 * (the first one on ties); if no entry is positive, format 0 is used. Its
 * is_ban_ flag is cleared and its sending time distribution is replaced by
 * one that puts all probability on the bin at index dis_sending_time_.
 */
void ReinforcePolicy::deal_all_ban(size_t i)
{
  /* num_formats_ doubles as the "nothing found yet" marker */
  size_t chosen = num_formats_;
  double smallest = 0;

  for (size_t fmt = 0; fmt < num_formats_; ++fmt)
  {
    const double size = curr_sizes_[i][fmt];
    if (!(size > 0))
    {
      continue; /* entries that are not positive are skipped */
    }

    if (chosen == num_formats_ || smallest > size)
    {
      smallest = size;
      chosen = fmt;
    }
  }

  if (chosen == num_formats_)
  {
    chosen = 0;
  }

  is_ban_[i][chosen] = false;

  /* zero everywhere except the bin at index dis_sending_time_ */
  for (size_t bin = 0; bin <= dis_sending_time_; ++bin)
  {
    sending_time_prob_[i][chosen][bin] = (bin == dis_sending_time_) ? 1 : 0;
  }
}

/*
 * Convert a buffer length into a number of unit_buf_length_ steps, rounding
 * to the nearest step (halves round up).
 *
 * Returns 0 when the rounded value is not positive, which also covers NaN.
 * Converting a floating point value of -1 or below to an unsigned integer is
 * undefined behaviour, so such inputs must not reach the cast.
 */
size_t ReinforcePolicy::discretize_buffer(double buf)
{
  const double steps = (buf + unit_buf_length_ * 0.5) / unit_buf_length_;

  /* the cast truncates toward zero, exactly like the implicit conversion did */
  return steps > 0 ? static_cast<size_t>(steps) : 0;
}
