#include <math.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <memory>
#include <thread>
#include <tuple>
#include <utility>

#include "exp3_policy.hh"
#include "ws_client.hh"
#include "json.hpp"
#include "timestamp.hh"

using namespace std;
using json = nlohmann::json;

// Builds the policy from its YAML configuration. Every key is optional; a
// key that is absent leaves the corresponding member at whatever value it
// already holds.
//
// client     - websocket client forwarded to the ABRAlgo base.
// abr_name   - algorithm name forwarded to the ABRAlgo base.
// abr_config - YAML node holding the optional overrides read below.
Exp3Policy::Exp3Policy(const WebSocketClient &client,
                       const string &abr_name, const YAML::Node &abr_config)
    : ABRAlgo(client, abr_name)
{
  // The configured horizon can only shrink the current limit, never grow it.
  if (const auto horizon = abr_config["max_lookahead_horizon"]; horizon)
  {
    max_lookahead_horizon_ = min(max_lookahead_horizon_, horizon.as<size_t>());
  }

  // QoE coefficients.
  if (const auto rebuffer_coeff = abr_config["rebuffer_length_coeff"]; rebuffer_coeff)
  {
    rebuffer_length_coeff_ = rebuffer_coeff.as<double>();
  }

  if (const auto ssim_coeff = abr_config["ssim_diff_coeff"]; ssim_coeff)
  {
    ssim_diff_coeff_ = ssim_coeff.as<double>();
  }

  // Feature / context selection switches.
  if (const auto inputs_flag = abr_config["use_inputs"]; inputs_flag)
  {
    use_inputs_ = inputs_flag.as<bool>();
  }

  if (const auto boggart_flag = abr_config["use_boggart"]; boggart_flag)
  {
    use_boggart_ = boggart_flag.as<bool>();
  }

  // Action-space shape.
  if (const auto delta_flag = abr_config["actions_are_delta"]; delta_flag)
  {
    actions_are_delta_ = delta_flag.as<bool>();
  }

  if (const auto action_count = abr_config["num_of_actions"]; action_count)
  {
    num_of_actions_ = action_count.as<double>();
  }

  // A context model is created only when "weights_dir" is present.
  // NOTE(review): the gate tests "weights_dir" but the k-means branch reads
  // "kmeans_dir"; confirm which key is intended. When "weights_dir" is
  // absent neither model is created here.
  if (abr_config["weights_dir"])
  {
    if (use_boggart_)
    {
      boggart_ = new Boggart();
    }
    else
    {
      kmeans_ = new Kmeans(abr_config["kmeans_dir"].as<string>());
    }
  }

  // Cap the discretized buffer length at the client's maximum buffer.
  dis_buf_length_ = min(dis_buf_length_,
                        discretize_buffer(WebSocketClient::MAX_BUFFER_S));
}

// Releases the context models that the constructor may have allocated with
// `new`. At most one of the two is allocated by the constructor.
// NOTE(review): this relies on kmeans_ and boggart_ being initialised to
// nullptr where they are declared; confirm in the class definition.
Exp3Policy::~Exp3Policy()
{
  delete kmeans_;
  delete boggart_;
}

// Computes the QoE of one chunk as
//   quality - ssim_diff_coeff_ * |quality - ssim_db(prev_ssim)|
//           - rebuffer_length_coeff_ * max(transmission time - buffer, 0)
//
// curr_ssim       - SSIM of the current chunk; passed through ssim_db() first
//                   when parse_ssim is true, used as-is otherwise.
// prev_ssim       - SSIM of the previous chunk; always passed through ssim_db().
// curr_trans_time - transmission time of the current chunk; scaled by 0.001
//                   before being compared with the buffer
//                   (presumably milliseconds -> seconds; confirm with caller).
// dis_buffer      - discretized buffer level, in units of unit_buf_length_.
// parse_ssim      - whether curr_ssim still needs the ssim_db() conversion.
double Exp3Policy::calc_qoe(double curr_ssim,
                            double prev_ssim,
                            uint64_t curr_trans_time,
                            std::size_t dis_buffer,
                            bool parse_ssim = true)
{
  const double quality = parse_ssim ? ssim_db(curr_ssim) : curr_ssim;
  const double quality_change = fabs(quality - ssim_db(prev_ssim));

  // Only the part of the transmission that outlasts the buffer is penalised.
  const double stall = max(curr_trans_time * 0.001 - dis_buffer * unit_buf_length_, 0.0);

  return quality - ssim_diff_coeff_ * quality_change - rebuffer_length_coeff_ * stall;
}

double Exp3Policy::get_qoe(std::size_t dis_buffer)
{
  if (past_chunks_.size() < 2)
  {
    return 0;
  }

  Chunk curr_chunk = past_chunks_.back();
  Chunk prev_chunk = *(past_chunks_.end() - 2);
  return calc_qoe(curr_chunk.ssim, prev_chunk.ssim, curr_chunk.trans_time, dis_buffer);
}

// Records an acknowledged chunk and reports the resulting reward for the
// most recent action through client_.sender_.
//
// c - the acknowledged chunk; it is moved into the bounded chunk history.
//
// The reward is the QoE from get_qoe(), evaluated with the buffer level that
// was stored alongside the most recent action in datapoints_. When no action
// has been recorded yet there is nothing to attribute a reward to, so the
// chunk is still stored but no report is posted.
void Exp3Policy::video_chunk_acked(Chunk &&c)
{
  // `c` is a named rvalue reference, so it must be moved explicitly;
  // otherwise push_back would copy it.
  past_chunks_.push_back(std::move(c));
  if (past_chunks_.size() > max_num_past_chunks_)
  {
    past_chunks_.pop_front();
  }

  // Calling back() on an empty container is undefined behaviour.
  if (datapoints_.empty())
  {
    return;
  }

  // Only the buffer level of the last datapoint is needed for the reward.
  const auto &[context_idx, format, dis_buf] = datapoints_.back();
  const double reward = get_qoe(dis_buf);

  json body;
  body["episode_id"] = client_.episode_id_;
  body["command"] = "LOG_RETURNS";
  body["reward"] = static_cast<float>(reward);
  body["done"] = false;
  body["info"] = {};

  client_.sender_->post(body);
}

// Chooses the format of the next chunk.
//
// Steps:
//   1. reinit() refreshes the per-decision state (formats, SSIMs, buffer,
//      action set, formats sorted by SSIM).
//   2. A context index is derived either from the pair returned by
//      find_optimal_actions() (use_boggart_) or from the k-means cluster of
//      the feature vector returned by get_input().
//   3. The context is posted with a GET_ACTION command and the "action" field
//      of the reply is used as an index into actions_.
//   4. The action is turned into an index into sorted_vformats_: either an
//      offset from the previously selected index, clamped to the valid range
//      (actions_are_delta_), or the index itself.
//
// Throws runtime_error when the k-means path is selected but no k-means model
// was loaded; whatever reinit() throws; and out_of_range from the .at() calls
// when the returned action or resulting index is invalid.
VideoFormat Exp3Policy::select_video_format()
{
  reinit();

  int action_idx = 0;
  int context_idx = 0;

  if (use_boggart_)
  {
    const auto [max_conservative, max_risk] = find_optimal_actions(curr_buffer_);
    // Encode the pair as a single index: row = conservative, column = risky.
    context_idx = static_cast<int>(max_conservative * actions_.size() + max_risk);
  }
  else
  {
    // The constructor only allocates kmeans_ when the weights are configured,
    // so fail loudly here instead of dereferencing a null pointer.
    if (kmeans_ == nullptr)
    {
      throw runtime_error("Exp3Policy: kmeans model is not loaded");
    }

    vector<double> feature_vector = get_input();
    context_idx = kmeans_->find_optimal_cluster(feature_vector);
  }

  json body;
  body["episode_id"] = client_.episode_id_;
  body["command"] = "GET_ACTION";
  body["observation"] = context_idx;

  json res = client_.sender_->post(body);
  action_idx = res["action"];

  // Remember what was decided so the reward can be attributed on ack.
  datapoints_.push_back(std::make_tuple(context_idx, action_idx, curr_buffer_));

  int selected_format_in_sorted;
  if (actions_are_delta_)
  {
    const int highest_sorted_idx = static_cast<int>(sorted_vformats_.size()) - 1;
    selected_format_in_sorted = last_sorted_format_idx_ + actions_.at(action_idx);
    selected_format_in_sorted = std::max(0, std::min(selected_format_in_sorted, highest_sorted_idx));
    last_sorted_format_idx_ = selected_format_in_sorted;
  }
  else
  {
    selected_format_in_sorted = actions_.at(action_idx);
  }

  return sorted_vformats_.at(selected_format_in_sorted);
}

std::tuple<std::size_t, std::size_t> Exp3Policy::find_optimal_actions(std::size_t dis_buffer)
{
  // calc expected transmission time for each vformat
  double expected_trans_time[MAX_LOOKAHEAD_HORIZON + 1][MAX_NUM_FORMATS];
  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    for (size_t j = 0; j < num_formats_; j++)
    {
      expected_trans_time[i][j] = 0;
      for (size_t k = 0; k < dis_sending_time_; k++)
      {
        if (sending_time_prob_[i][j][k] < st_prob_eps_)
        {
          continue;
        }

        expected_trans_time[i][j] += sending_time_prob_[i][j][k] * (BINS[k] + BINS[k + 1]) / 2;
      }
    }
  }

  size_t max_conservative = 0;
  size_t max_risk = 0;

  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    for (std::size_t action_idx = 0; action_idx < actions_.size(); action_idx++)
    {
      int format_idx;

      if (actions_are_delta_)
      {
        int action = actions_.at(action_idx);
        int sorted_format_idx = last_sorted_format_idx_ + action;
        sorted_format_idx = std::max(0, std::min(sorted_format_idx, (int)sorted_vformats_.size() - 1));
        auto &vformats = client_.channel()->vformats();
        auto it = find(vformats.begin(), vformats.end(), sorted_vformats_.at(sorted_format_idx));
        format_idx = it - vformats.begin();
      }
      else
      {
        format_idx = action_idx;
      }

      if (is_ban_[i][format_idx] == true)
      {
        continue;
      }

      // get highest format chunk can be downloaded less than chunk duration
      if (expected_trans_time[i][format_idx] <= dis_chunk_length_ * unit_buf_length_)
      {
        if (curr_ssims_[i][max_conservative] < curr_ssims_[i][format_idx])
        {
          max_conservative = action_idx;
        }
      }

      // get highest format chunk can be downloaded less than (chunk duration +)[?] buffer size
      if (expected_trans_time[i][format_idx] <= dis_buffer * unit_buf_length_)
      {
        if (curr_ssims_[i][max_risk] < curr_ssims_[i][format_idx])
        {
          max_risk = action_idx;
        }
      }
    }
  }

  if (dis_chunk_length_ > dis_buffer && curr_ssims_[1][max_conservative] > curr_ssims_[1][max_risk])
  {
    // only here the max_risk allowed to be less than max_conservative because the buffer is almost empty
    max_risk = max_conservative;
  }

  // TODO: there is no bug here because the indexes are action (delta actions) and and not the bitrate themselves
  // TODO: who gets the update in case of multiple actions are fit
  // assert(curr_ssims_[1][max_conservative] <= curr_ssims_[1][max_risk]);

  return std::make_tuple(max_conservative, max_risk);
}

std::vector<double> Exp3Policy::get_input()
{
  std::vector<double> input;

  if (use_inputs_)
  {
    for (size_t i = 1; i <= lookahead_horizon_; i++)
    {
      for (size_t j = 0; j < num_formats_; j++)
      {
        for (size_t k = 0; k < ttp_input_dim_; k++)
        {
          input.push_back(raw_input_[i][j][k]);
        }
      }
    }
  }
  else
  {
    // avg sending time
    double BINS[MAX_DIS_SENDING_TIME + 2] = {0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.25, 4.75, 5.25, 5.75, 6.25, 6.75, 7.25, 7.75, 8.25, 8.75, 9.25, 9.75, 30};
    std::vector<double> sending_time;
    for (size_t i = 1; i <= lookahead_horizon_; i++)
    {
      for (size_t j = 0; j < num_formats_; j++)
      {
        double expected_trans_time = 0;
        for (size_t k = 0; k < dis_sending_time_; k++)
        {
          if (sending_time_prob_[i][j][k] > st_prob_eps_)
          {
            expected_trans_time += sending_time_prob_[i][j][k] * (BINS[k] + BINS[k + 1]) / 2;
          }
        }

        sending_time.push_back(expected_trans_time);
      }
    }
    input.insert(input.end(), sending_time.begin(), sending_time.end());
  }

  // add current buffer level params
  input.push_back((double)curr_buffer_);

  // add last format ssim
  if (past_chunks_.size() == 0)
  {
    input.push_back(-1.0);
  }
  else
  {
    auto c = past_chunks_.back();
    input.push_back(c.ssim / MAX_SSIM);
  }

  // add ssim indexes
  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    for (size_t j = 0; j < num_formats_; j++)
    {
      input.push_back(curr_ssims_[i][j] / MAX_SSIM);
    }
  }

  return input;
}

void Exp3Policy::reinit()
{
  curr_round_++;

  const auto &channel = client_.channel();
  const auto &vformats = channel->vformats();
  const unsigned int vduration = channel->vduration();
  const uint64_t next_ts = client_.next_vts().value();

  dis_chunk_length_ = discretize_buffer((double)vduration / channel->timescale());
  num_formats_ = vformats.size();

  /* initialization failed if there is no ready chunk ahead */
  if (channel->vready_frontier().value() < next_ts || num_formats_ == 0)
  {
    throw runtime_error("no ready chunk ahead");
  }

  lookahead_horizon_ = min(
      max_lookahead_horizon_,
      (channel->vready_frontier().value() - next_ts) / vduration + 1);

  curr_buffer_ = min(dis_buf_length_,
                     discretize_buffer(client_.video_playback_buf()));

  /* init curr_ssims */
  if (past_chunks_.size() > 0)
  {
    is_init_ = false;
    curr_ssims_[0][0] = ssim_db(past_chunks_.back().ssim);
  }
  else
  {
    is_init_ = true;
  }

  for (size_t i = 1; i <= lookahead_horizon_; i++)
  {
    const auto &data_map = channel->vdata(next_ts + vduration * (i - 1));

    for (size_t j = 0; j < num_formats_; j++)
    {
      try
      {
        curr_ssims_[i][j] = ssim_db(
            channel->vssim(vformats[j], next_ts + vduration * (i - 1)));
      }
      catch (const exception &e)
      {
        cerr << "Error occurs when getting the ssim of "
             << next_ts + vduration * (i - 1) << " " << vformats[j] << endl;
        curr_ssims_[i][j] = MIN_SSIM;
      }

      try
      {
        curr_sizes_[i][j] = get<1>(data_map.at(vformats[j]));
      }
      catch (const exception &e)
      {
        cerr << "Error occurs when getting the sizes of "
             << next_ts + vduration * (i - 1) << " " << vformats[j] << endl;
        curr_sizes_[i][j] = -1;
      }
    }
  }

  /* init sending time */
  reinit_sending_time();

  int vformats_size = (int)num_formats_;
  if (actions_are_delta_)
  {
    if (num_of_actions_ == 7)
    {
      actions_ = {-(vformats_size - 1), -(int)floor(sqrt(vformats_size - 1)), -1, 0, 1, (int)floor(sqrt(vformats_size - 1)), (vformats_size - 1)};
    }
    else if (num_of_actions_ == 5)
    {
      actions_ = {-(int)floor(sqrt(vformats_size - 1)), -1, 0, 1, (int)floor(sqrt(vformats_size - 1))};
    }
    else if (num_of_actions_ == 3)
    {
      actions_ = {-1, 0, 1};
    }
  }
  else
  {
    for (int i = 0; i < vformats_size; ++i)
    {
      actions_.push_back(i);
    }
  }

  uint64_t curr_vts = client_.next_vts().value();
  sorted_vformats_ = client_.channel()->vformats();
  std::sort(sorted_vformats_.begin(),
            sorted_vformats_.end(),
            [&](VideoFormat format1, VideoFormat format2)
            {
              return client_.channel()->vssim(curr_vts).at(format1) < client_.channel()->vssim(curr_vts).at(format2);
            });
}

// Un-bans one format at lookahead step i: the format with the smallest
// positive size in curr_sizes_[i], or format 0 when no format has a positive
// size. The chosen format's sending-time distribution is then replaced by
// all zeros in the first dis_sending_time_ bins and probability 1 in the bin
// at index dis_sending_time_.
//
// i - lookahead step to process.
void Exp3Policy::deal_all_ban(size_t i)
{
  // num_formats_ doubles as the "nothing found yet" marker.
  size_t smallest_id = num_formats_;
  double smallest_size = 0;

  for (size_t fmt = 0; fmt < num_formats_; ++fmt)
  {
    const double size = curr_sizes_[i][fmt];
    if (!(size > 0))
    {
      continue;
    }

    if (smallest_id == num_formats_ || smallest_size > size)
    {
      smallest_size = size;
      smallest_id = fmt;
    }
  }

  if (smallest_id == num_formats_)
  {
    smallest_id = 0;
  }

  is_ban_[i][smallest_id] = false;

  auto &distribution = sending_time_prob_[i][smallest_id];
  for (size_t bin = 0; bin < dis_sending_time_; ++bin)
  {
    distribution[bin] = 0;
  }
  distribution[dis_sending_time_] = 1;
}

// Converts a buffer length into a count of unit_buf_length_ units. Half a
// unit is added before dividing and the conversion to size_t truncates, so
// non-negative lengths are rounded to the nearest unit.
//
// buf - buffer length, expressed in the same unit as unit_buf_length_.
size_t Exp3Policy::discretize_buffer(double buf)
{
  const auto half_unit = unit_buf_length_ * 0.5;
  return static_cast<size_t>((buf + half_unit) / unit_buf_length_);
}
