#include "regret_minimization_algorithm.hpp"

#include <utility>

namespace open_spiel {
namespace algorithms {

// Sets up a solver for `game`.
//
// The constructor keeps shared ownership of the game, creates the root state,
// gives every entry of the root reach-probability vector (one per player plus
// one extra slot indexed by chance_player_) the value 1.0, and records
// NumPlayers() as the index used for the chance player. It aborts through
// SpielFatalError when the game's dynamics are not sequential; otherwise it
// walks the game tree once to create an entry for each information state.
//
// Args:
// - game: The game to solve.
// - name: Label stored in name_.
RegretMinimizationAlgorithm::RegretMinimizationAlgorithm(const Game &game, std::string name)
        : game_(game.shared_from_this()),
          iteration_(0),
          root_state_(game.NewInitialState()),
          // Sized from the `game` argument, exactly like chance_player_ below.
          root_reach_probs_(game.NumPlayers() + 1, 1.0),
          chance_player_(game.NumPlayers()),
          name_(std::move(name)) {
    const auto dynamics = game_->GetType().dynamics;
    if (dynamics != GameType::Dynamics::kSequential) {
        SpielFatalError("This algorithm requires sequential games. If you're trying to run it "
                        "on a simultaneous (or normal-form) game, please first transform it "
                        "using turn_based_simultaneous_game.");
    }

    // Pre-populate info_states_ by traversing the tree from the root.
    InitializeInfostateNodes(*root_state_);
}

void RegretMinimizationAlgorithm::InitializeInfostateNodes(const State &state) {
    if (state.IsTerminal()) {
        return;
    }
    if (state.IsChanceNode()) {
        for (const auto &action_prob : state.ChanceOutcomes()) {
            InitializeInfostateNodes(*state.Child(action_prob.first));
        }
        return;
    }

    int current_player = state.CurrentPlayer();
    std::string info_state = state.InformationStateString(current_player);
    std::vector<Action> legal_actions = state.LegalActions();

    InfoStateValues is_vals(current_player, legal_actions);
    info_states_[info_state] = is_vals;

    for (const Action &action : legal_actions) {
        InitializeInfostateNodes(*state.Child(action));
    }
}

// Runs one iteration: bumps the iteration counter, then for each player in
// turn performs a regret pass from the root that updates only that player,
// followed by a call to UpdateStrategy().
void RegretMinimizationAlgorithm::EvaluateAndUpdatePolicy() {
    ++iteration_;

    const int num_players = game_->NumPlayers();
    for (int updating_player = 0; updating_player < num_players; ++updating_player) {
        // The state values returned by the traversal are not needed here.
        ComputeCounterFactualRegret(*root_state_, updating_player, root_reach_probs_);
        UpdateStrategy();
    }
}

// Computes counterfactual regrets for the subtree rooted at `state`.
// Alternates recursively with ComputeCounterFactualRegretForActionProbs.
//
// Args:
// - state: The state to start the recursion.
// - alternating_player: When set, only this player's information states are
//      updated; when empty, every player's information states are updated.
// - reach_probabilities: The reach probabilities of this state for each
//      player, ending with the chance player.
//
// Returns:
//   The value of the state for each player (excluding the chance player).
//   For a terminal state this is state.Returns(); when every player's reach
//   probability is zero at a decision node it is a vector of zeros.
std::vector<double> RegretMinimizationAlgorithm::ComputeCounterFactualRegret(
        const State &state, const absl::optional<int> &alternating_player,
        const std::vector<double> &reach_probabilities) {
    if (state.IsTerminal()) {
        return state.Returns();
    }
    if (state.IsChanceNode()) {
        // Split the (outcome, probability) pairs into the two parallel vectors
        // expected by ComputeCounterFactualRegretForActionProbs.
        const ActionsAndProbs actions_and_probs = state.ChanceOutcomes();
        std::vector<double> dist;
        std::vector<Action> outcomes;
        dist.reserve(actions_and_probs.size());
        outcomes.reserve(actions_and_probs.size());
        for (const auto &outcome_and_prob : actions_and_probs) {
            outcomes.push_back(outcome_and_prob.first);
            dist.push_back(outcome_and_prob.second);
        }
        return ComputeCounterFactualRegretForActionProbs(
                state, alternating_player, reach_probabilities, chance_player_, dist, outcomes, nullptr);
    }
    if (AllPlayersHaveZeroReachProb(reach_probabilities)) {
        // The value returned is not used: if the reach probability for all players
        // is 0, then the last taken action has probability 0, so the
        // returned value is not impacting the parent node value.
        return std::vector<double>(game_->NumPlayers(), 0.0);
    }

    const int current_player = state.CurrentPlayer();
    const std::string info_state = state.InformationStateString();
    const std::vector<Action> legal_actions = state.LegalActions(current_player);
    const int num_actions = static_cast<int>(legal_actions.size());

    // Load current policy.
    const std::vector<double> info_state_policy = GetPolicy(info_state, current_player, legal_actions);

    // child_utilities[a] receives current_player's value after taking action a.
    std::vector<double> child_utilities;
    child_utilities.reserve(legal_actions.size());
    // Left non-const so the return below can move if copy elision does not apply.
    std::vector<double> state_value = ComputeCounterFactualRegretForActionProbs(
            state, alternating_player, reach_probabilities, current_player,
            info_state_policy, legal_actions, &child_utilities);

    // Perform regret and average strategy updates.
    if (!alternating_player || *alternating_player == current_player) {
        // Update the stored record in place instead of copying it out of the
        // table and assigning it back afterwards.
        // NOTE(review): assumes UpdateAveragePolicy neither inserts into
        // info_states_ nor reads this entry through the table - confirm.
        InfoStateValues &is_vals = info_states_[info_state];
        SPIEL_CHECK_FALSE(is_vals.empty());

        const double self_reach_prob = reach_probabilities[current_player];
        const double cfr_reach_prob = CounterFactualReachProb(reach_probabilities, current_player);
        const double node_value = state_value[current_player];

        for (int aidx = 0; aidx < num_actions; ++aidx) {
            // Update regrets.
            const double cfr_regret = cfr_reach_prob * (child_utilities[aidx] - node_value);
            is_vals.cumulative_regrets[aidx] += cfr_regret;
            is_vals.instant_regrets[aidx] += cfr_regret;

            // Update average policy.
            UpdateAveragePolicy(&is_vals, aidx, self_reach_prob * info_state_policy[aidx]);
        }
        is_vals.cfr_reach_prob += cfr_reach_prob;
        is_vals.need_update = true;
    }

    return state_value;
}

// Compute counterfactual regrets given certain action probabilities.
// Alternates recursively with ComputeCounterFactualRegret.
//
// Args:
// - state: The state to start the recursion.
// - alternating_player: Optionally only update this player.
// - reach_probabilities: The reach probabilities of this state.
// - current_player: Either a player or chance_player_.
// - info_state_policy: The action probabilities to use for this state;
//           entry i is the probability of legal_actions[i].
// - legal_actions: The actions (or chance outcomes) available at this state.
// - child_values_out: optional output parameter; when non-null, the child
//           utility for current_player is appended to it for each action, in
//           the order of legal_actions. Must be null when current_player is
//           chance_player_, because child values have no chance entry.
// Returns:
//   The value of the state for each player (excluding the chance player),
//   i.e. the probability-weighted sum of the child values.
std::vector<double> RegretMinimizationAlgorithm::ComputeCounterFactualRegretForActionProbs(
        const State &state, const absl::optional<int> &alternating_player,
        const std::vector<double> &reach_probabilities, const int current_player,
        const std::vector<double> &info_state_policy, const std::vector<Action> &legal_actions,
        std::vector<double> *child_values_out) {
    std::vector<double> state_value(game_->NumPlayers());
    const int num_values = static_cast<int>(state_value.size());
    const int num_actions = static_cast<int>(legal_actions.size());

    // One scratch copy shared by all actions: only the current_player entry
    // differs between children, so it is overwritten on each iteration rather
    // than allocating a new vector per action.
    std::vector<double> new_reach_probabilities(reach_probabilities);

    for (int aidx = 0; aidx < num_actions; ++aidx) {
        const Action action = legal_actions[aidx];
        const double prob = info_state_policy[aidx];
        const std::unique_ptr<State> new_state = state.Child(action);
        new_reach_probabilities[current_player] = reach_probabilities[current_player] * prob;
        const std::vector<double> child_value =
                ComputeCounterFactualRegret(*new_state, alternating_player, new_reach_probabilities);
        for (int i = 0; i < num_values; ++i) {
            state_value[i] += prob * child_value[i];
        }
        if (child_values_out != nullptr) {
            child_values_out->push_back(child_value[current_player]);
        }
    }
    return state_value;
}

// Reports whether the reach probability of every player is exactly zero.
// Only the first game_->NumPlayers() entries of `reach_probabilities` are
// examined; any entries after them are ignored.
bool RegretMinimizationAlgorithm::AllPlayersHaveZeroReachProb(const std::vector<double> &reach_probabilities) const {
    const int num_players = game_->NumPlayers();

    // Advance past the leading run of zero entries.
    int player = 0;
    while (player < num_players && reach_probabilities[player] == 0.0) {
        ++player;
    }

    // Reaching the end of the range means no non-zero entry was found.
    return player >= num_players;
}

// Returns a copy of the current policy stored for `info_state`.
//
// If the information state has no entry in info_states_ yet, one is created
// from `player` and `legal_actions` before the lookup is repeated. The
// function aborts through SPIEL_CHECK_FALSE if the entry is still missing,
// reports itself as empty, or holds an empty current policy.
//
// Args:
// - info_state: Key of the information state in info_states_.
// - player: Player passed to InfoStateValues when a new entry is created.
// - legal_actions: Actions passed to InfoStateValues when a new entry is created.
std::vector<double> RegretMinimizationAlgorithm::GetPolicy(const std::string &info_state, const int player,
                                          const std::vector<Action> &legal_actions) {
    auto entry = info_states_.find(info_state);
    const bool missing = (entry == info_states_.end());
    if (missing) {
        // First time this information state is seen: create it, then look it up again.
        info_states_[info_state] = InfoStateValues(player, legal_actions);
        entry = info_states_.find(info_state);
    }

    // Sanity checks on the table entry before handing out its policy.
    SPIEL_CHECK_FALSE(entry == info_states_.end());
    SPIEL_CHECK_FALSE(entry->second.empty());
    SPIEL_CHECK_FALSE(entry->second.current_policy.empty());

    const std::vector<double> &stored_policy = entry->second.current_policy;
    return stored_policy;
}

}  // namespace algorithms
}  // namespace open_spiel