#include "average_policy.hpp"

#include <utility>

namespace open_spiel {
namespace algorithms {
// Creates an average policy backed by the `info_states` table.
//
// `default_policy` is consulted by GetStatePolicy() for information states
// that are missing from the table. It may be null, in which case such a
// lookup ends in SpielFatalError.
AveragePolicy::AveragePolicy(const InfoStateValuesTable &info_states, std::shared_ptr<Policy> default_policy)
        // `default_policy` is a by-value sink parameter: move it into the member
        // instead of copying it, which avoids an extra reference-count
        // increment/decrement pair.
        : info_states_(info_states), default_policy_(std::move(default_policy)) {}

// Returns the averaged action distribution for `player` at `state`.
//
// The table is keyed by state.InformationStateString(player). On a miss the
// query is forwarded to the default policy when one is set; without a default
// policy the call ends in SpielFatalError.
ActionsAndProbs AveragePolicy::GetStatePolicy(const State &state, Player player) const {
    const auto found = info_states_.find(state.InformationStateString(player));
    const bool is_known = found != info_states_.end();
    if (!is_known) {
        if (!default_policy_) {
            // Nothing to fall back on. This should never get called.
            SpielFatalError("No policy found, and no default policy.");
        }
        return default_policy_->GetStatePolicy(state, player);
    }

    // Table hit: derive the distribution from the stored values.
    ActionsAndProbs result;
    GetStatePolicyFromInformationStateValues(found->second, &result);
    return result;
}

// Returns the averaged action distribution for the information state whose
// table key is `info_state`.
//
// On a miss the query is forwarded to the default policy when one is set;
// without a default policy the call ends in SpielFatalError.
ActionsAndProbs AveragePolicy::GetStatePolicy(const std::string &info_state) const {
    const auto found = info_states_.find(info_state);
    const bool is_known = found != info_states_.end();
    if (!is_known) {
        if (!default_policy_) {
            // Nothing to fall back on. This should never get called.
            SpielFatalError("No policy found, and no default policy.");
        }
        return default_policy_->GetStatePolicy(info_state);
    }

    // Table hit: derive the distribution from the stored values.
    ActionsAndProbs result;
    GetStatePolicyFromInformationStateValues(found->second, &result);
    return result;
}

void AveragePolicy::GetStatePolicyFromInformationStateValues(
        const InfoStateValues &is_vals, ActionsAndProbs *actions_and_probs) const {
    double sum_prob = 0.0;
    for (int aidx = 0; aidx < is_vals.num_actions(); ++aidx) {
        sum_prob += is_vals.cumulative_policy[aidx];
    }

    if (sum_prob == 0.0) {
        // Return a uniform policy at this node
        double prob = 1. / is_vals.num_actions();
        for (Action action : is_vals.legal_actions) {
            actions_and_probs->push_back({action, prob});
        }
        return;
    }

    for (int aidx = 0; aidx < is_vals.num_actions(); ++aidx) {
        actions_and_probs->push_back({is_vals.legal_actions[aidx], is_vals.cumulative_policy[aidx] / sum_prob});
    }
}

// Builds a TabularPolicy holding the averaged distribution of every
// information state currently in the table. The default policy is not
// consulted.
TabularPolicy AveragePolicy::AsTabular() const {
    TabularPolicy tabular;
    for (const auto &table_item : info_states_) {
        const auto &info_state_key = table_item.first;
        const auto &values = table_item.second;

        ActionsAndProbs distribution;
        GetStatePolicyFromInformationStateValues(values, &distribution);
        tabular.SetStatePolicy(info_state_key, distribution);
    }
    return tabular;
}

}  // namespace algorithms
}  // namespace open_spiel