
#ifndef OPEN_SPIEL_ALGORITHMS_REGRET_MINIMIZATION_ALGORITHM_H_
#define OPEN_SPIEL_ALGORITHMS_REGRET_MINIMIZATION_ALGORITHM_H_

#include <memory>
#include "open_spiel/policy.h"
#include "open_spiel/spiel.h"
#include "algorithm_base.hpp"
#include "info_state_values.hpp"
#include "average_policy.hpp"
#include "current_policy.hpp"

namespace open_spiel {
namespace algorithms {

// Returns the counterfactual reach probability for `player`: the product of
// every entry of `reach_probabilities` except the one at index `player`.
//
// An empty vector yields 1.0. A `player` outside [0, size) excludes nothing,
// so the result is the product of all entries.
//
// Declared `inline` rather than `static`: this lives in a header, and
// `static` would give each including translation unit its own private copy.
// It would also trigger -Wunused-function in units that never call it.
inline double CounterFactualReachProb(const std::vector<double> &reach_probabilities,
                                      const int player) {
    double cfr_reach_prob = 1.0;
    // Compare against an int count to avoid a signed/unsigned mismatch with size().
    const int num_entries = static_cast<int>(reach_probabilities.size());
    for (int i = 0; i < num_entries; ++i) {
        if (i != player) {
            cfr_reach_prob *= reach_probabilities[i];
        }
    }
    return cfr_reach_prob;
}

// Abstract base for tabular regret-minimization solvers.
//
// Owns the per-information-state table `info_states_` and exposes the
// average/current policies computed from it. Concrete algorithms must
// implement UpdateStrategy() and may override UpdateAveragePolicy() and the
// counterfactual-regret computations.
class RegretMinimizationAlgorithm: public AlgorithmBase {
public:
    // Defined out of line. `name` is presumably stored in name_, which is what
    // GetName() returns — confirm in the .cpp.
    RegretMinimizationAlgorithm(const Game &game, std::string name);

    virtual ~RegretMinimizationAlgorithm() = default;

    // Overrides AlgorithmBase; defined out of line.
    void EvaluateAndUpdatePolicy() override;

    // Compute the counterfactual regret and update the average policy for the
    // specified player.
    // NOTE(review): the meaning of the returned vector (presumably one value
    // per player for `state`) and of an empty `alternating_player`
    // (presumably "update all players") should be confirmed against the .cpp.
    virtual std::vector<double> ComputeCounterFactualRegret(
            const State &state, const absl::optional<int> &alternating_player,
            const std::vector<double> &reach_probabilities);

    // Handles a decision node of `current_player`, given that player's policy
    // `info_state_policy` at this information state. The probabilities are
    // ordered like `legal_actions`.
    // NOTE(review): `child_values_out` presumably receives the per-action child
    // values as an output — confirm.
    virtual std::vector<double> ComputeCounterFactualRegretForActionProbs(
            const State &state, const absl::optional<int> &alternating_player,
            const std::vector<double> &reach_probabilities, const int current_player,
            const std::vector<double> &info_state_policy, const std::vector<Action> &legal_actions,
            std::vector<double> *child_values_out);

    // Computes the average policy, containing the policy for all players.
    // The returned policy is built over info_states_ (with no default policy)
    // and should only be used during the lifetime of this
    // RegretMinimizationAlgorithm object.
    std::shared_ptr<Policy> AveragePolicy() const override {
        return std::make_shared<open_spiel::algorithms::AveragePolicy>(info_states_, nullptr);
    }

    // Computes the current policy, containing the policy for all players.
    // The returned policy is built over info_states_ (with no default policy)
    // and should only be used during the lifetime of this
    // RegretMinimizationAlgorithm object.
    std::shared_ptr<Policy> CurrentPolicy() const override {
        return std::make_shared<open_spiel::algorithms::CurrentPolicy>(info_states_, nullptr);
    }

    // Converts the current policy via AsTabular() and returns the result as a
    // newly allocated TabularPolicy.
    std::shared_ptr<open_spiel::TabularPolicy> CurrentTabularPolicy() const override {
        auto current_policy = std::make_shared<open_spiel::algorithms::CurrentPolicy>(info_states_, nullptr);
        return std::make_shared<open_spiel::TabularPolicy>(current_policy->AsTabular());
    }

    std::string GetName() const override { return name_; }

protected:
    std::shared_ptr<const Game> game_;
    // Defaulted so it is never read uninitialized; a mem-initializer in the
    // constructor still takes precedence.
    int iteration_ = 0;
    // Per-information-state values read by the policy views above.
    InfoStateValuesTable info_states_;
    const std::unique_ptr<State> root_state_;
    // NOTE(review): presumably the initial reach probabilities passed to
    // ComputeCounterFactualRegret at the root — confirm.
    const std::vector<double> root_reach_probs_;
    const int chance_player_;
    // Value returned by GetName().
    const std::string name_;

    // Update the current policy for all information states.
    virtual void UpdateStrategy() = 0;

    // Default hook: adds `value` to entry `aidx` of is_vals->cumulative_regrets.
    // NOTE(review): despite the name, this accumulates into cumulative_regrets,
    // not a cumulative-policy field. Confirm that this is intended, or that
    // every subclass overrides it.
    virtual void UpdateAveragePolicy(InfoStateValues *is_vals, const int aidx, const double value) {
        is_vals->cumulative_regrets[aidx] += value;
    }

    // Get the policy at this information state. The probabilities are ordered in
    // the same order as legal_actions.
    std::vector<double> GetPolicy(const std::string &info_state, const int player,
                                  const std::vector<Action> &legal_actions);

    // Defined out of line; judging by the name, tests whether every entry of
    // `reach_probabilities` is zero — confirm.
    bool AllPlayersHaveZeroReachProb(const std::vector<double> &reach_probabilities) const;

private:
    // Defined out of line. Presumably walks the game tree from `state` to
    // populate info_states_ — confirm.
    void InitializeInfostateNodes(const State &state);
};

}  // namespace algorithms
}  // namespace open_spiel

#endif  // OPEN_SPIEL_ALGORITHMS_REGRET_MINIMIZATION_ALGORITHM_H_