#include "current_policy.hpp"

#include <utility>

namespace open_spiel {
namespace algorithms {

// Builds a policy view over a table of per-information-state values.
//
// info_states:    table consulted by GetStatePolicy / AsTabular.
//                 NOTE(review): if info_states_ is declared as a reference in
//                 the header, the table must outlive this object — confirm.
// default_policy: fallback used for information states missing from the
//                 table; may be null, in which case a missing entry is a
//                 fatal error.
CurrentPolicy::CurrentPolicy(const InfoStateValuesTable &info_states, std::shared_ptr<Policy> default_policy)
        : info_states_(info_states),
          // The parameter is a by-value sink: move it so the shared_ptr is not
          // copied a second time (avoids an extra atomic refcount inc/dec).
          default_policy_(std::move(default_policy)) {}

// Returns the current policy for `player` at `state`.
//
// The state's information-state string (from `player`'s point of view) is
// looked up in the stored table. If present, its legal actions are paired with
// their current-policy probabilities. Otherwise the default policy answers,
// given the same (state, player). With no default policy, a missing entry
// raises SpielFatalError.
ActionsAndProbs CurrentPolicy::GetStatePolicy(const State &state, Player player) const {
    const auto found = info_states_.find(state.InformationStateString(player));
    if (found != info_states_.end()) {
        ActionsAndProbs result;
        return GetStatePolicyFromInformationStateValues(found->second, result);
    }
    // Unknown information state: only the fallback policy can answer.
    if (!default_policy_) {
        SpielFatalError("No policy found, and no default policy.");
    }
    return default_policy_->GetStatePolicy(state, player);
}

// Returns the current policy for the information state keyed by `info_state`.
//
// A key found in the stored table yields its (action, current-policy
// probability) pairs. A missing key is forwarded to the default policy's
// string-keyed GetStatePolicy. A missing key with no default policy raises
// SpielFatalError.
ActionsAndProbs CurrentPolicy::GetStatePolicy(const std::string &info_state) const {
    const auto slot = info_states_.find(info_state);
    if (slot == info_states_.end()) {
        // Nothing tabulated for this key; the fallback is mandatory here.
        if (!default_policy_) {
            SpielFatalError("No policy found, and no default policy.");
        }
        return default_policy_->GetStatePolicy(info_state);
    }
    ActionsAndProbs policy_entries;
    GetStatePolicyFromInformationStateValues(slot->second, policy_entries);
    return policy_entries;
}

// Appends one (legal action, current-policy probability) pair per action in
// `is_vals` to `actions_and_probs`, in index order
// (legal_actions[i] with current_policy[i] for i in [0, num_actions())).
//
// Existing contents of `actions_and_probs` are kept; new pairs are appended
// after them. The return value is a copy of the filled vector, because the
// out-parameter is a reference and cannot be moved from implicitly.
ActionsAndProbs CurrentPolicy::GetStatePolicyFromInformationStateValues(
        const InfoStateValues &is_vals, ActionsAndProbs &actions_and_probs) const {
    const auto num_actions = is_vals.num_actions();
    // The final size is known up front: allocate once instead of growing
    // repeatedly inside the loop.
    actions_and_probs.reserve(actions_and_probs.size() + num_actions);
    for (int aidx = 0; aidx < num_actions; ++aidx) {
        actions_and_probs.push_back({is_vals.legal_actions[aidx], is_vals.current_policy[aidx]});
    }
    return actions_and_probs;
}

// Builds a TabularPolicy snapshot covering every information state in the
// stored table. Each key maps to the value returned by that entry's
// GetCurrentPolicy(). The default policy is not consulted.
TabularPolicy CurrentPolicy::AsTabular() const {
    TabularPolicy tabular;
    for (const auto &key_and_values : info_states_) {
        const auto &info_state = key_and_values.first;
        const auto &values = key_and_values.second;
        tabular.SetStatePolicy(info_state, values.GetCurrentPolicy());
    }
    return tabular;
}

}  // namespace algorithms
}  // namespace open_spiel