#include "dcfr.hpp"
#include <cmath>

namespace open_spiel {
namespace algorithms {

// Applies the DCFR regret discount to every information state, then
// recomputes the current policy with regret matching.
//
// Expects cumulative_regrets to already include this iteration's
// instant_regrets. The loop subtracts them back out to recover the previous
// accumulation R_prev, rescales R_prev, and then re-adds the instant regret:
//   R_prev * t^alpha / (t^alpha + 1)   if R_prev > 0
//   R_prev * t^beta  / (t^beta  + 1)   otherwise
// where t = iteration_ - 1. Afterwards instant_regrets is reset to zeros.
//
// NOTE(review): assumes iteration_ >= 1 when this runs (TODO confirm with the
// caller). With iteration_ == 0 the base would be negative, and std::pow
// returns NaN for non-integer exponents.
void DCFRSolver::UpdateStrategy() {
    const double prev_t = static_cast<double>(iteration_ - 1);

    // t^e / (t^e + 1), written as 1 / (1 + t^-e).
    // The direct form evaluates inf/inf = NaN when t^e is infinite. That
    // happens at t == 0 with e < 0, and when t^e overflows.
    // This form instead yields 0, 1/2 or 1 at t == 0 (for e > 0, e == 0,
    // e < 0), and saturates to 1 rather than NaN on overflow.
    const auto discount_factor = [prev_t](double exponent) {
        return 1.0 / (1.0 + std::pow(prev_t, -exponent));
    };
    // Both factors depend only on the iteration, so compute them once.
    const double positive_factor = discount_factor(alpha_);
    const double negative_factor = discount_factor(beta_);

    for (auto &entry : info_states_) {
        auto &is_vals = entry.second;
        for (int aidx = 0; aidx < is_vals.num_actions(); ++aidx) {
            const double instant = is_vals.instant_regrets[aidx];
            // Accumulated regret as it stood before this iteration's update.
            const double previous = is_vals.cumulative_regrets[aidx] - instant;
            const double factor = previous > 0 ? positive_factor : negative_factor;
            is_vals.cumulative_regrets[aidx] = previous * factor + instant;
        }
        is_vals.ApplyRegretMatching(regret_matching_plus_);
        // Zero the per-iteration regrets in place, reusing the allocation.
        is_vals.instant_regrets.assign(is_vals.num_actions(), 0.0);
    }
}

// Adds one contribution to the average-policy numerator for action `aidx`
// of `is_vals`. The contribution is weighted by iteration_^gamma_, the DCFR
// average-strategy weighting.
//
// The weighted value is accumulated into cumulative_policy (+=), not assigned
// over it. Assigning discards every earlier contribution, so the "average"
// would only reflect the most recent call.
// NOTE(review): `value` is presumably reach probability times the current
// action probability, as in OpenSpiel's CFR, and this is presumably called
// once per visit of the information state. Confirm against the caller.
void DCFRSolver::UpdateAveragePolicy(InfoStateValues *is_vals, const int aidx, const double value) {
    is_vals->cumulative_policy[aidx] += std::pow(iteration_, gamma_) * value;
}


}  // namespace algorithms
}  // namespace open_spiel
