// Copyright 2023 The Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "approximate_fairness_algorithm.h"

#include <memory>
#include <string>
#include <vector>
#include <iostream>
#include <cassert>

#include "fairness_constraint.h"
#include "matroid.h"
#include "matroid_intersection.h"
#include "submodular_function.h"

// WARNING: ONLY IMPLEMENTED FOR PARTITION MATROIDS!

// Constructs the algorithm with the given `epsilon`, which is stored in
// `epsilon_`. The value is later forwarded unchanged to
// ApproximateFairSubmodularMaximization() by GetSolutionValue() and is also
// embedded in the string returned by GetAlgorithmName().
ApproximateFairnessAlgorithm::ApproximateFairnessAlgorithm(double epsilon)
    : epsilon_(epsilon) {}

// Prepares the algorithm for a fresh run.
//
// Forwards all three arguments to the base-class Algorithm::Init(), then
// discards any previously stored solution and resets the matroid held in
// `matroid_`.
//
// Args:
//   sub_func_f: the submodular objective, forwarded to Algorithm::Init().
//   fairness:   the fairness constraint, forwarded to Algorithm::Init().
//   matroid:    the matroid constraint, forwarded to Algorithm::Init().
void ApproximateFairnessAlgorithm::Init(
    const SubmodularFunction& sub_func_f,
    const FairnessConstraint& fairness,
    const Matroid& matroid) {
  // The base-class initialisation must run first: `matroid_` is dereferenced
  // right below.
  Algorithm::Init(sub_func_f, fairness, matroid);

  // Drop whatever a previous run left behind.
  solution_.clear();
  matroid_->Reset();
}

// Records `element` by appending it to `universe_elements_`. No other work is
// done here; the solution itself is only computed in GetSolutionValue().
void ApproximateFairnessAlgorithm::Insert(int element) {
  universe_elements_.emplace_back(element);
}

// WARNING: ONLY IMPLEMENTED FOR PARTITION MATROIDS!
double ApproximateFairnessAlgorithm::GetSolutionValue() {
  // Run the approximate fairness algorithm.
  assert(matroid_ != nullptr);
  ApproximateFairSubmodularMaximization(static_cast<PartitionMatroid*>(matroid_.get()),
    fairness_.get(),
    sub_func_f_.get(),
    epsilon_
  );
  solution_ = matroid_->GetCurrent();
  return sub_func_f_->ObjectiveAndIncreaseOracleCall(solution_);
}

// Returns a copy of `solution_`, i.e. the set stored by the most recent call
// to GetSolutionValue() (empty right after Init()).
std::vector<int> ApproximateFairnessAlgorithm::GetSolutionVector() {
  return std::vector<int>(solution_);
}

// Returns a human-readable name for this algorithm that includes the
// configured epsilon, e.g. "Approximate Fairness Algorithm (epsilon=0.100000)".
// The epsilon is formatted by std::to_string, i.e. fixed notation with six
// digits after the decimal point.
std::string ApproximateFairnessAlgorithm::GetAlgorithmName() const {
  std::string name = "Approximate Fairness Algorithm (epsilon=";
  name += std::to_string(epsilon_);
  name += ")";
  return name;
}
