// Copyright 2022 The Lagrange Ann Iclr2023 Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "convex_opt.h"

#include <algorithm>
#include <tuple>

namespace iclr2023 {

// Reference dynamic program: for every layer i picks one index so that the sum
// of losses[i][idx] + lambda * costs[i][idx] over all layers is minimal. The
// value stored for (layer i, position j) is the best total achievable when
// layer i may use any index k >= j and the previous layer is evaluated at that
// same k, so chosen indices never increase from one layer to the next.
//
// Args:
//   losses, costs: one vector per layer; every vector is indexed up to the
//     size of losses.front(), so all of them must be at least that long, and
//     costs needs at least as many layers as losses.
//   lambda: multiplier applied to the cost term.
//   result_indices: optional; when non-null it is resized to the number of
//     layers and receives the chosen index for each layer.
//
// Returns {total objective - lambda * total cost, total cost} for the chosen
// indices. When there are no layers or no candidate indices, returns {0, 0}
// and clears result_indices.
pair<double, double> MinimizeLagrangianSimple(ConstSpan<vector<double>> losses,
                                              ConstSpan<vector<double>> costs,
                                              const double lambda,
                                              vector<int>* result_indices) {
  const size_t num_layers = losses.size();
  const size_t n = num_layers == 0 ? 0 : losses.front().size();
  if (n == 0) {
    // Nothing to choose from: avoid front() on an empty span and n - 1
    // wrapping around.
    if (result_indices) result_indices->clear();
    return {0.0, 0.0};
  }

  vector<double> old_dp(n), cur_dp(n), old_totcost(n), cur_totcost(n);
  vector<vector<int>> choices(num_layers, vector<int>(n));
  const int last = static_cast<int>(n) - 1;
  for (size_t i = 0; i < num_layers; i++) {
    ConstSpan<double> cur_losses = losses[i], cur_costs = costs[i];
    vector<int>& cur_choices = choices[i];

    // The rightmost position has only one candidate: itself.
    cur_dp[last] = old_dp[last] + cur_losses[last] + lambda * cur_costs[last];
    cur_totcost[last] = old_totcost[last] + cur_costs[last];
    cur_choices[last] = last;
    // Sweep right to left, keeping a suffix minimum over candidates k >= j.
    for (int j = last - 1; j >= 0; j--) {
      const double l = old_dp[j] + cur_losses[j] + lambda * cur_costs[j];
      // Note <= so that if there's a tie, we will pick the cheaper option.
      if (l <= cur_dp[j + 1]) {
        cur_dp[j] = l;
        cur_totcost[j] = old_totcost[j] + cur_costs[j];
        cur_choices[j] = j;
      } else {
        cur_dp[j] = cur_dp[j + 1];
        cur_totcost[j] = cur_totcost[j + 1];
        cur_choices[j] = cur_choices[j + 1];
      }
    }
    // Every entry of cur_* is rewritten before it is read in the next
    // iteration, so swapping buffers is equivalent to copying them.
    old_dp.swap(cur_dp);
    old_totcost.swap(cur_totcost);
  }

  if (result_indices) {
    result_indices->resize(num_layers);
    // Walk back from the last layer, starting at position 0 (the overall
    // minimum), following the recorded choice at each layer.
    int pos = 0;
    for (size_t i = num_layers; i-- > 0;) {
      pos = choices[i][pos];
      (*result_indices)[i] = pos;
    }
  }
  return {old_dp[0] - lambda * old_totcost[0], old_totcost[0]};
}

// Builds per-index prefix sums over layers: cum_losses_[i][j] is the sum of
// losses[0..i-1][j] (row 0 is all zeros), and likewise for cum_costs_. The
// difference of two rows therefore gives the loss/cost accumulated over a
// contiguous range of layers at a fixed index.
//
// m_ is the number of layers and n_ the number of indices per layer, taken
// from losses.front(). Every vector in losses and costs is read at indices
// [0, n_), so each must be at least that long, and costs needs at least m_
// layers. An empty `losses` yields m_ == 0 and n_ == 0 instead of calling
// front() on an empty span.
LagrangeMinimizer::LagrangeMinimizer(ConstSpan<vector<double>> losses,
                                     ConstSpan<vector<double>> costs) {
  m_ = losses.size();
  n_ = losses.size() == 0 ? 0 : losses.front().size();
  cum_losses_ = vector<vector<double>>(m_ + 1, vector<double>(n_));
  cum_costs_ = vector<vector<double>>(m_ + 1, vector<double>(n_));
  for (int i = 0; i < m_; i++) {
    const vector<double>& layer_losses = losses[i];
    const vector<double>& layer_costs = costs[i];
    for (int j = 0; j < n_; j++) {
      cum_losses_[i + 1][j] = cum_losses_[i][j] + layer_losses[j];
      cum_costs_[i + 1][j] = cum_costs_[i][j] + layer_costs[j];
    }
  }
}

// Minimizes (total loss + lambda * total cost) over one index choice per
// layer, using the cumulative sums built by the constructor. Instead of a
// dense table, the running minimum function is kept as a list of hull pieces;
// each layer locates the piece holding the minimum, minimizes within it, and
// replaces everything to the left of the minimum by one flat piece.
//
// Args:
//   lambda: multiplier applied to the cost term.
//   result_indices: optional; when non-null it is resized to m_ and receives
//     the chosen index for every layer.
//
// Returns {objective - lambda * cost, cost} for the selection found.
//
// Requires n_ >= 1.
// NOTE(review): the nested binary searches presumably rely on each layer's
// objective being convex in the index -- confirm against the callers.
pair<double, double> LagrangeMinimizer::OptimizeForLambda(
    const double lambda, vector<int>* result_indices) const {
  // A single candidate index leaves nothing to search: every layer must pick
  // index 0. Handling it here also keeps the generic path safe, because its
  // lone hull piece [0, 0] would be popped after the first layer and leave the
  // hull empty for the next one.
  if (n_ == 1) {
    if (result_indices) result_indices->assign(m_, 0);
    return {cum_losses_[m_][0], cum_costs_[m_][0]};
  }

  // Zero-initialized so the result is well defined even with no layers.
  double best_loss = 0.0, best_cost = 0.0;
  // Hull pieces, sorted so hull[0] is the highest x values
  vector<HullPiece> hull;
  hull.push_back(HullPiece{
      .func_index = 0, .x0 = 0, .x1 = n_ - 1, .y_offset = 0, .cost_offset = 0});
  for (int layer = 0; layer < m_; layer++) {
    // Binary search over pieces for the one containing the leftmost minimum.
    int lo = 0, hi = static_cast<int>(hull.size()) - 1;
    while (lo != hi) {
      const int mid = (lo + hi) / 2;
      const HullPiece& cur_piece = hull[mid];
      const double y0 = EvaluateHull(cur_piece, layer, cur_piece.x0, lambda);
      const double y1 =
          EvaluateHull(cur_piece, layer, cur_piece.x0 + 1, lambda);
      // In equality case, we are at the right-end of the flat portion, so
      // continue searching left for leftmost part of this portion.
      if (y0 <= y1) {
        lo = mid + 1;
        continue;
      }
      const double y2 =
          EvaluateHull(cur_piece, layer, cur_piece.x1 - 1, lambda);
      const double y3 = EvaluateHull(cur_piece, layer, cur_piece.x1, lambda);
      if (y2 <= y3) {
        // y0 > y1 and y2 <= y3 so we've found leftmost portion of minimum.
        lo = mid;
        hi = mid;
        break;
      } else {
        // y0 > y1 and y2 > y3 so entire portion is going down; go more right.
        hi = mid;
      }
    }
    auto [x, loss, cost] = MinimizeHullPiece(hull[lo], layer, lambda);
    // Remove no longer relevant pieces from the hull.
    while (!hull.empty() && hull.back().x1 <= x) hull.pop_back();
    // Update hull piece that goes over the minimum; cut it short.
    if (!hull.empty()) hull.back().x0 = x;
    // Everything left of the minimum becomes one flat piece that starts
    // accumulating from the next layer.
    if (x > 0)
      hull.push_back(HullPiece{.func_index = layer + 1,
                               .x0 = 0,
                               .x1 = x,
                               .y_offset = loss,
                               .cost_offset = cost});

    // Only needed at layer == m_ - 1, but it's ok to do this every iteration.
    best_loss = loss;
    best_cost = cost;
  }
  if (result_indices) {
    result_indices->resize(m_);
    // Pieces are visited from highest x to lowest. All layers below a piece's
    // func_index take the left end of the previous piece; cur_pos starts at
    // the rightmost index.
    int cur_row = 0, cur_pos = n_ - 1;
    for (const HullPiece& h : hull) {
      while (cur_row < h.func_index) result_indices->at(cur_row++) = cur_pos;
      cur_pos = h.x0;
    }
    while (cur_row < m_) result_indices->at(cur_row++) = cur_pos;
  }
  return {best_loss - lambda * best_cost, best_cost};
}

// Value of the function described by `hull` at index x once layers up to and
// including cur_index have been applied: the piece's y_offset plus the loss
// and lambda-weighted cost accumulated at x over layers
// [hull.func_index, cur_index], read off the cumulative-sum tables.
double LagrangeMinimizer::EvaluateHull(const HullPiece& hull, int cur_index,
                                       int x, double lambda) const {
  const auto first_row = hull.func_index;
  const int end_row = cur_index + 1;
  const double loss_delta = cum_losses_[end_row][x] - cum_losses_[first_row][x];
  const double cost_delta = cum_costs_[end_row][x] - cum_costs_[first_row][x];
  return (hull.y_offset + loss_delta) + lambda * cost_delta;
}

// Binary-searches [hull.x0, hull.x1] for the leftmost index at which
// EvaluateHull(hull, cur_index, ., lambda) stops decreasing.
//
// Returns {index, objective value at that index, accumulated cost at that
// index}, where the cost is hull.cost_offset plus the cost added at that index
// over layers [hull.func_index, cur_index].
//
// NOTE(review): the search presumably assumes the evaluated function is convex
// over the piece -- confirm against the callers.
std::tuple<size_t, double, double> LagrangeMinimizer::MinimizeHullPiece(
    const HullPiece& hull, int cur_index, double lambda) const {
  size_t lo = hull.x0, hi = hull.x1;
  while (lo != hi) {
    const size_t mid = (lo + hi) / 2;
    const double y_mid =
        EvaluateHull(hull, cur_index, static_cast<int>(mid), lambda);
    const double y_next =
        EvaluateHull(hull, cur_index, static_cast<int>(mid + 1), lambda);
    // Strictly greater, so that ties lead to going as left as possible.
    if (y_mid > y_next) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  // Evaluate at the final index rather than carrying a value out of the loop:
  // a degenerate piece (x0 == x1) never enters the loop, which previously left
  // the returned value uninitialized.
  const double y = EvaluateHull(hull, cur_index, static_cast<int>(lo), lambda);
  // Compute cost at lo, because EvaluateHull doesn't do that for us.
  const double cost =
      cum_costs_[cur_index + 1][lo] - cum_costs_[hull.func_index][lo];
  return {lo, y, cost + hull.cost_offset};
}

}  // namespace iclr2023
