// Copyright 2022 The Lagrange Ann Iclr2023 Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "convex_opt.h"

#include <ostream>
#include <random>
#include <tuple>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/random/random.h"

namespace iclr2023 {
namespace {

// Hand-checked 3x5 instance. For each lambda, the reference solver
// MinimizeLagrangianSimple (with and without index output) and
// LagrangeMinimizer::OptimizeForLambda must all report the same
// (loss, cost) pair; the reference solver must also report the chosen
// per-row indices.
TEST(ConvexOptTest, TestMinimizeLagrangeSimple) {
  vector<vector<double>> losses = {
      {12, 8, 5, 2, 0}, {5, 3, 2, 1, 0}, {5, 4, 3, 2, 2}};
  vector<vector<double>> costs = {
      {1, 2, 3, 4, 5}, {2, 4, 6, 8, 10}, {3, 6, 9, 12, 15}};
  LagrangeMinimizer optimizer(losses, costs);

  // Expected outcome of minimizing sum(loss) + lambda * sum(cost).
  struct ExpectedSolution {
    double lambda;
    double loss;
    double cost;
    vector<int> indices;
  };
  const ExpectedSolution kCases[] = {
      // lambda = 0: the penalized values are just the losses.
      // 12 8 5 2 0
      // 5  3 2 1 0
      // 5  4 3 2 2
      // DP result:
      // 0 0 0 0 0
      // 0 0 0 0 0
      // 2 2 2 2 2
      {0, 2, 5 + 10 + 12, {4, 4, 3}},
      // lambda = 1: the penalized values are loss + cost.
      // 13 10 8 6 5
      // 7 7 8 9 10
      // 8 10 12 14 17
      // DP result:
      // 5 5 5 5 5
      // 12 12 13 14 15
      // 20 22 25 28 32
      {1, 20 - (5 + 2 + 3), 5 + 2 + 3, {4, 0, 0}},
      // Lambda is so big it should take the cheapest on every level now.
      {1000, 12 + 5 + 5, 1 + 2 + 3, {0, 0, 0}},
  };

  // A single out-vector is reused for every call, as in a sequential test.
  vector<int> result_indices;
  for (const ExpectedSolution& expected : kCases) {
    // Works for any pair/tuple-like (loss, cost) result.
    auto expect_loss_and_cost = [&expected](const auto& solution) {
      EXPECT_EQ(std::get<0>(solution), expected.loss);
      EXPECT_EQ(std::get<1>(solution), expected.cost);
    };

    expect_loss_and_cost(
        MinimizeLagrangianSimple(losses, costs, expected.lambda));
    expect_loss_and_cost(MinimizeLagrangianSimple(
        losses, costs, expected.lambda, &result_indices));
    EXPECT_THAT(result_indices,
                testing::ElementsAreArray(expected.indices));
    expect_loss_and_cost(optimizer.OptimizeForLambda(expected.lambda));
  }
}

// Parameters for one instance of the randomized cross-check test below.
struct RandomTestParams {
  // Seed fed (via std::seed_seq) to the absl::BitGen generating the data.
  size_t rng_seed = 0;
  // Number of columns (candidate points) in each row of losses/costs.
  size_t num_points = 0;
  // Number of rows of losses/costs; each row gets its own cost slope.
  size_t num_quants = 0;

  // Lets GoogleTest print readable parameters in failure messages and
  // --gtest_list_tests output instead of a raw byte dump of the struct.
  friend void PrintTo(const RandomTestParams& params, std::ostream* os) {
    *os << "{rng_seed=" << params.rng_seed
        << ", num_points=" << params.num_points
        << ", num_quants=" << params.num_quants << "}";
  }
};

using ConvexOptRandomTest = ::testing::TestWithParam<RandomTestParams>;

// Cross-checks LagrangeMinimizer::OptimizeForLambda against the reference
// MinimizeLagrangianSimple on random data: for several lambdas, both must
// agree on the total loss, total cost and the chosen per-row indices.
// Each row gets a linear cost curve with a random (exponential) slope and a
// loss curve built as the double suffix-sum of non-negative uniform noise,
// which makes it non-increasing and convex in the column index.
TEST_P(ConvexOptRandomTest, TestRandomData) {
  const RandomTestParams& params = GetParam();
  absl::BitGen gen(std::seed_seq{params.rng_seed});
  vector<vector<double>> losses(params.num_quants,
                                vector<double>(params.num_points));
  vector<vector<double>> costs(params.num_quants,
                               vector<double>(params.num_points));
  // Signed column count so the backward loops below stay well-defined even
  // when num_points < 2 (num_points - 2 in size_t would wrap around).
  const int num_points = static_cast<int>(params.num_points);
  for (size_t i = 0; i < params.num_quants; ++i) {
    vector<double>& row_losses = losses[i];
    vector<double>& row_costs = costs[i];
    const double quant_cost = absl::Exponential(gen, 1.1);
    for (int j = 0; j < num_points; ++j) {
      row_losses[j] = absl::Uniform<double>(gen, 0, 4);
      row_costs[j] = quant_cost * (j + 1);
    }
    // First integral: this makes losses decreasing
    for (int j = num_points - 2; j >= 0; --j) {
      row_losses[j] += row_losses[j + 1];
    }
    // Second integral: now losses is convex
    for (int j = num_points - 2; j >= 0; --j) {
      row_losses[j] += row_losses[j + 1];
    }
  }
  LagrangeMinimizer optimizer(losses, costs);
  for (double lambda : vector<double>{0, 0.2, 0.4, 5, 77, 12345.67}) {
    // Attach the lambda to any assertion failure below instead of printing
    // it to stdout on every iteration.
    SCOPED_TRACE(testing::Message() << "lambda = " << lambda);
    vector<int> res_idx1, res_idx2;
    auto [loss1, cost1] =
        MinimizeLagrangianSimple(losses, costs, lambda, &res_idx1);
    auto [loss2, cost2] = optimizer.OptimizeForLambda(lambda, &res_idx2);
    EXPECT_NEAR(loss1, loss2, 1e-5);
    EXPECT_NEAR(cost1, cost2, 1e-5);
    EXPECT_THAT(res_idx1, testing::ContainerEq(res_idx2));
  }
}

INSTANTIATE_TEST_SUITE_P(
    ConvexOptRandomTests, ConvexOptRandomTest,
    testing::Values(
        RandomTestParams{.rng_seed = 29, .num_points = 91, .num_quants = 10},
        RandomTestParams{.rng_seed = 910, .num_points = 200, .num_quants = 10},
        RandomTestParams{.rng_seed = 212, .num_points = 200, .num_quants = 55},
        RandomTestParams{.rng_seed = 19, .num_points = 10000, .num_quants = 10},
        RandomTestParams{.rng_seed = 420, .num_points = 10000, .num_quants = 3},
        RandomTestParams{
            .rng_seed = 518, .num_points = 10000, .num_quants = 10},
        RandomTestParams{
            .rng_seed = 44, .num_points = 20000, .num_quants = 10}));

}  // namespace
}  // namespace iclr2023
