// Copyright 2022 The Lagrange Ann Iclr2023 Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef LAGRANGE_ANN_ICLR2023_CONVEX_OPT_H_
#define LAGRANGE_ANN_ICLR2023_CONVEX_OPT_H_
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>

#include "absl/types/span.h"

namespace iclr2023 {

// NOTE: these using-declarations and the alias below are declared inside
// namespace iclr2023, so every file that includes this header sees
// iclr2023::pair, iclr2023::vector and iclr2023::ConstSpan.
using std::pair;
using std::vector;

// Read-only, non-owning view over a contiguous sequence of `T`
// (an absl::Span whose elements cannot be modified through the view).
template <class T>
using ConstSpan = absl::Span<T const>;

// One piece of a hull.
//
// Evaluating a hull piece at x returns:
//   y_offset + sum, from func_index to the current index, of f(x)
// NOTE(review): presumably f ranges over the per-index functions that
// LagrangeMinimizer::EvaluateHull combines for a given lambda; confirm against
// the implementation.
//
// All members default to zero. A default-initialized HullPiece is therefore
// well defined, and value-initialization (`HullPiece{}`) gives the same
// all-zero result as before. Brace-initialization remains aggregate
// initialization.
struct HullPiece {
  // Index of the first function included in this piece.
  // func_index = -1 means it includes no functions at all.
  int func_index = 0;
  // Domain of this hull piece.
  // NOTE(review): whether [x0, x1] is closed or half-open is not visible
  // here; confirm against the implementation.
  size_t x0 = 0;
  size_t x1 = 0;
  // Constant added to the summed functions when evaluating this piece.
  double y_offset = 0.0;
  // NOTE(review): presumably the cost-side counterpart of y_offset; confirm.
  double cost_offset = 0.0;
};

// Minimizes the Lagrangian L + lambda * C and returns the resulting
// (loss, cost) pair.
//
// NOTE(review): presumably losses[i][j] and costs[i][j] are the loss and cost
// of selecting option j for index i; confirm against the implementation.
//
// If `result_indices` is non-null, the selected indices are stored in it.
//
// (`lambda` is taken by value; a top-level `const` on it would have no effect
// on the declared signature, so it is omitted here.)
pair<double, double> MinimizeLagrangianSimple(
    ConstSpan<vector<double>> losses, ConstSpan<vector<double>> costs,
    double lambda, vector<int>* result_indices = nullptr);

// Minimizes the Lagrangian L + lambda * C for a fixed set of losses and costs.
//
// The object is built once from `losses` and `costs`. OptimizeForLambda is a
// const member, so it can then be called for any number of lambda values
// without changing the object.
class LagrangeMinimizer {
 public:
  // NOTE(review): presumably `losses` and `costs` use the same layout as the
  // arguments of MinimizeLagrangianSimple; confirm against the implementation.
  LagrangeMinimizer(ConstSpan<vector<double>> losses,
                    ConstSpan<vector<double>> costs);

  // Returns the (loss, cost) pair for minimizing L + lambda * C.
  // If `result_indices` is non-null, the selected indices are stored in it.
  // NOTE(review): assumed to follow the same contract as
  // MinimizeLagrangianSimple; confirm.
  pair<double, double> OptimizeForLambda(
      double lambda, vector<int>* result_indices = nullptr) const;

 private:
  // Evaluates `hull` at `x` for the given `cur_index` and `lambda`. See the
  // HullPiece comment for the formula.
  // NOTE(review): `x` is an int while HullPiece's domain bounds are size_t;
  // the signature is kept as-is so it still matches the out-of-line
  // definition.
  double EvaluateHull(const HullPiece& hull, int cur_index, int x,
                      double lambda) const;

  // Minimizes `hull` over its domain for the given `cur_index` and `lambda`.
  // NOTE(review): presumably returns (argmin x, minimal value, associated
  // cost); confirm against the implementation.
  std::tuple<size_t, double, double> MinimizeHullPiece(const HullPiece& hull,
                                                       int cur_index,
                                                       double lambda) const;

  // Dimensions of the input.
  // Declaration order matches the original so that the constructor's member
  // initialization order is unchanged.
  size_t n_ = 0;
  size_t m_ = 0;
  // NOTE(review): presumably cumulative (prefix) sums of `losses` and
  // `costs`; confirm against the constructor.
  vector<vector<double>> cum_losses_;
  vector<vector<double>> cum_costs_;
};

}  // namespace iclr2023

#endif  // LAGRANGE_ANN_ICLR2023_CONVEX_OPT_H_
