#include "cfr.hpp"

namespace open_spiel {
namespace algorithms {

// Constructs a solver whose RegretMinimizationAlgorithm base is given the
// default name `alg_name` (declared outside this file).
// `regret_matching_plus` is stored and later (a) forwarded to each info
// state's ApplyRegretMatching call and (b) selects iteration-weighted
// accumulation in UpdateAveragePolicy.
MyCFRSolver::MyCFRSolver(const Game &game, const bool regret_matching_plus)
        : RegretMinimizationAlgorithm(game, alg_name),
          regret_matching_plus_(regret_matching_plus) {
}

// Same as the constructor above, but the caller supplies the name passed to
// the RegretMinimizationAlgorithm base (moved in, so pass by value is cheap
// for rvalue arguments).
MyCFRSolver::MyCFRSolver(const Game &game, std::string name, const bool regret_matching_plus)
        : RegretMinimizationAlgorithm(game, std::move(name)),
          regret_matching_plus_(regret_matching_plus) {
}

// For every entry in info_states_:
//   1. calls ApplyRegretMatching(regret_matching_plus_) on its values
//      (presumably recomputing the current policy from accumulated regrets —
//      NOTE(review): confirm against InfoStateValues), then
//   2. resets `instant_regrets` to a vector of num_actions() zeros, so the
//      next iteration starts accumulating from a clean slate.
void MyCFRSolver::UpdateStrategy() {
    for (auto &entry : info_states_) {
        auto &is_vals = entry.second;
        is_vals.ApplyRegretMatching(regret_matching_plus_);
        is_vals.instant_regrets = std::vector<double>(is_vals.num_actions(), 0.0);
    }
}

// Accumulates `value` into is_vals->cumulative_policy[aidx].
//
// `value` is presumably the caller's reach-weighted probability of taking
// action `aidx` at this info state — TODO confirm against the caller.
//
// With regret_matching_plus_ set, the contribution is weighted by iteration_
// (linear averaging as used by CFR+, so later iterations count more);
// otherwise it is added unweighted.
//
// @param is_vals  info state whose cumulative policy is updated (must be non-null).
// @param aidx     index of the action within is_vals->cumulative_policy.
// @param value    contribution to add for that action.
void MyCFRSolver::UpdateAveragePolicy(InfoStateValues *is_vals, const int aidx, const double value){
    if (regret_matching_plus_) {
        // Must accumulate (+=), not assign: the average policy is a sum over
        // all iterations and all histories in this info state. Plain
        // assignment would discard everything accumulated so far.
        is_vals->cumulative_policy[aidx] += iteration_ * value;
    } else {
        is_vals->cumulative_policy[aidx] += value;
    }
}

}  // namespace algorithms
}  // namespace open_spiel