#include "pcfr.hpp"

namespace open_spiel {
namespace algorithms {

// Runs one solver iteration.
//
// For each player in turn, the game tree is traversed from the root with that
// player passed as the `alternating_player` argument of
// ComputeCounterFactualRegret, and UpdateStrategy() is called right after that
// player's pass (presumably recomputing the current policy from the updated
// regrets — TODO confirm).
//
// When update_anchoring_interval_ is positive, every
// update_anchoring_interval_-th iteration also calls UpdateAnchoringStrategy(),
// which copies each information state's current policy into its anchoring
// strategy. A non-positive interval disables the anchoring refresh.
void PCFRSolver::EvaluateAndUpdatePolicy() {
    iteration_ += 1;

    const int num_players = game_->NumPlayers();
    for (int p = 0; p < num_players; ++p) {
        ComputeCounterFactualRegret(*root_state_, p, root_reach_probs_);
        UpdateStrategy();
    }

    // Guard clause: anchoring refresh is switched off for non-positive intervals
    // (this also prevents a modulo by zero below).
    if (update_anchoring_interval_ <= 0) {
        return;
    }
    const bool refresh_due = (iteration_ % update_anchoring_interval_) == 0;
    if (refresh_due) {
        UpdateAnchoringStrategy();
    }
}

// Expands every legal action at a decision node of `current_player`, recursing
// into each child via ComputeCounterFactualRegret. Returns, for every player,
// the expected value of `state` under `info_state_policy`.
//
// PCFR perturbation: the acting player's child values are perturbed when both
// of these hold:
//   * the acting player is the one being updated, or `alternating_player` is
//     empty;
//   * the acting player is not asym_player_id_.
// Each such child value is shifted by
//     mutation_rate_ * (anchoring_strategy[a] - info_state_policy[a]).
// This raises the value of actions that the anchoring strategy weights more
// heavily than the current policy, and lowers the others. Other players' values
// are never perturbed.
//
// Parameters:
//   state               - the decision node being expanded.
//   alternating_player  - player whose values are perturbed; empty means the
//                         acting player is always eligible.
//   reach_probabilities - per-player reach probabilities of `state`; the entry
//                         for `current_player` is scaled by each action's
//                         probability before recursing.
//   current_player      - the player acting at `state`. It is used directly as
//                         a vector index, so it must be a valid player id.
//   info_state_policy   - probability of each action, aligned with
//                         `legal_actions`.
//   legal_actions       - actions to expand, in policy order.
//   child_values_out    - if non-null, one entry per action is appended: the
//                         (possibly perturbed) child value for
//                         `current_player`.
// Returns: a vector of size game_->NumPlayers() holding the policy-weighted sum
// of the (possibly perturbed) child values.
std::vector<double> PCFRSolver::ComputeCounterFactualRegretForActionProbs(
        const State &state, const absl::optional<int> &alternating_player,
        const std::vector<double> &reach_probabilities, const int current_player,
        const std::vector<double> &info_state_policy, const std::vector<Action> &legal_actions,
        std::vector<double> *child_values_out) {
    const int num_actions = static_cast<int>(legal_actions.size());
    std::vector<double> state_value(game_->NumPlayers());

    // Only the acting player is ever perturbed, so a single per-action vector
    // suffices. (The previous per-player layout filled only
    // `1 - current_player` besides the acting player. That is out of bounds for
    // one-player games and leaves empty vectors for games with 3+ players.)
    std::vector<double> perturbation(num_actions, 0.0);

    // TODO: revisit the simultaneous-update case (`alternating_player` is
    // nullopt). It currently perturbs the acting player exactly as if that
    // player were the one being updated.
    const bool perturb =
            (!alternating_player || *alternating_player == current_player) &&
            current_player != asym_player_id_;
    if (perturb) {
        // Bind by reference: a copy would duplicate every vector in the
        // info-state record on each visit. The reference is used only before
        // the recursion below, which may modify info_states_.
        const auto &is_vals = info_states_[state.InformationStateString()];
        for (int aidx = 0; aidx < num_actions; ++aidx) {
            perturbation[aidx] =
                    mutation_rate_ * (is_vals.anchoring_strategy[aidx] - info_state_policy[aidx]);
        }
    }

    for (int aidx = 0; aidx < num_actions; ++aidx) {
        const double prob = info_state_policy[aidx];
        const std::unique_ptr<State> new_state = state.Child(legal_actions[aidx]);
        std::vector<double> new_reach_probabilities(reach_probabilities);
        new_reach_probabilities[current_player] *= prob;
        std::vector<double> child_value =
                ComputeCounterFactualRegret(*new_state, alternating_player, new_reach_probabilities);
        if (perturb) {
            child_value[current_player] += perturbation[aidx];
        }
        for (std::size_t i = 0; i < state_value.size(); ++i) {
            state_value[i] += prob * child_value[i];
        }
        if (child_values_out != nullptr) {
            child_values_out->push_back(child_value[current_player]);
        }
    }
    return state_value;
}

// Copies each information state's current_policy into its anchoring_strategy,
// entry by entry, for the first num_actions() actions.
//
// The copy is element-wise rather than a vector assignment, so the size of
// anchoring_strategy is left unchanged. This assumes anchoring_strategy already
// holds at least num_actions() entries — TODO confirm where it is initialized.
void PCFRSolver::UpdateAnchoringStrategy() {
    for (auto &kv : info_states_) {
        auto &values = kv.second;
        const int action_count = values.num_actions();
        int a = 0;
        while (a < action_count) {
            values.anchoring_strategy[a] = values.current_policy[a];
            ++a;
        }
    }
}


}  // namespace algorithms
}  // namespace open_spiel