// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_EVAL_MCTS_EVAL_H_
#define OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_EVAL_MCTS_EVAL_H_

#include <stdint.h>
#include <torch/torch.h>

#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "open_spiel/spiel.h"
#include "open_spiel/spiel_bots.h"

// A vanilla Monte Carlo Tree Search algorithm.
//
// This algorithm searches the game tree from the given state.
// At the leaf, the evaluator is called if the game state is not terminal.
// A total of max_simulations states are explored.
//
// At every node, the algorithm chooses the action with the highest value under
// the configured ChildSelectionPolicy (UCT or PUCT). The PUCT value is
// defined as: `Q/N + c * prior * sqrt(parent_N) / N`, where Q is the total
// reward after the action, and N is the number of times the action was
// explored in this position. The input parameter c controls the balance
// between exploration and exploitation; higher values of c encourage
// exploration of under-explored nodes. Unseen actions are always explored
// first.
//
// At the end of the search, the chosen action is the action that has been
// explored most often. This is the action that is returned.
//
// This implementation supports sequential n-player games, with or without
// chance nodes. All players maximize their own reward and ignore the other
// players' rewards. This corresponds to max^n for n-player games. It is the
// norm for zero-sum games, but doesn't have any special handling for
// non-zero-sum games. It doesn't have any special handling for imperfect
// information games.
//
// The implementation also supports backing up solved states, i.e. MCTS-Solver.
// The implementation is general in that it is based on a max^n backup (each
// player greedily chooses their maximum among proven children values, or there
// exists one child whose proven value is Game::MaxUtility()), so it will work
// for multiplayer, general-sum, and arbitrary payoff games (not just win/loss/
// draw games). Also chance nodes are considered proven only if all children
// have the same value.
//
// Some references:
// - Sturtevant, An Analysis of UCT in Multi-Player Games, 2008,
//   https://web.cs.du.edu/~sturtevant/papers/multi-player_UCT.pdf
// - Nijssen, Monte-Carlo Tree Search for Multi-Player Games, 2013,
//   https://project.dke.maastrichtuniversity.nl/games/files/phd/Nijssen_thesis.pdf
// - Silver, AlphaGo Zero: Starting from scratch, 2017
//   https://deepmind.com/blog/article/alphago-zero-starting-scratch
// - Winands, Bjornsson, and Saito, Monte-Carlo Tree Search Solver, 2008.
//   https://dke.maastrichtuniversity.nl/m.winands/documents/uctloa.pdf

namespace open_spiel
{
  namespace algorithms
  {

    namespace torch_az_eval
    {
      // Rule used to score children when descending the search tree.
      enum class ChildSelectionPolicy
      {
        UCT,  // Plain UCT score (the MCTSBot default).
        PUCT  // Prior-weighted score, see the PUCT formula in the overview.
      };

      // Abstract interface for the evaluation function used by the search.
      //
      // An evaluator takes an intermediate state of the game and provides:
      //   * a value estimate of that state for every player (Evaluate),
      //   * a prior policy over the current player's actions (Prior),
      //   * a numeric representation of the state (Repr).
      class Evaluator
      {
      public:
        virtual ~Evaluator() = default;

        // Returns a value of `state` for each player. The values should
        // correlate with each player's chances of winning from `state`.
        // NOTE(review): presumably indexed by Player id — confirm in callers.
        virtual std::vector<double> Evaluate(const State &state) = 0;

        // Returns a policy: the probability of the current player playing each
        // action in `state`.
        virtual ActionsAndProbs Prior(const State &state) = 0;

        // Returns a numeric representation of `state`.
        // NOTE(review): presumably the value stored in SearchNode::repr —
        // confirm against the implementation.
        virtual std::vector<double> Repr(const State &state) = 0;
      };

      class RandomRolloutEvaluator : public Evaluator
      {
      public:
        explicit RandomRolloutEvaluator(int n_rollouts, int seed)
            : n_rollouts_(n_rollouts), rng_(seed) {}

        // Runs random games, returning the average returns.
        std::vector<double> Evaluate(const State &state) override;

        // Returns equal probability for each action.
        ActionsAndProbs Prior(const State &state) override;

      private:
        int n_rollouts_;
        std::mt19937 rng_;
      };

      // A node in the search tree for MCTS.
      struct SearchNode
      {
        Action action = 0;                // The action taken to get to this node.
        double prior = 0;                 // The prior probability of playing this action.
        Player player = 0;                // Which player gets to make this action.
        int explore_count = 0;            // Number of times this node was explored.
        int sample_count = 0;             // Number of times this node was sampled.
        double total_reward = 0;          // Total reward passing through this node.
        std::vector<double> outcome;      // The reward if each player plays perfectly.
        std::vector<SearchNode> children; // The successors to this state.
        std::vector<double> repr;         // The representation of the state.

        SearchNode() = default;

        // `repr_` is a sink parameter: it is taken by value and moved into
        // `repr`, so callers passing an rvalue pay for no vector copy.
        SearchNode(Action action_, Player player_, double prior_, std::vector<double> repr_)
            : action(action_), prior(prior_), player(player_), repr(std::move(repr_)) {}

        // The value as returned by the UCT formula.
        double UCTValue(int parent_explore_count, double uct_c) const;

        // The value as returned by the PUCT formula.
        double PUCTValue(int parent_explore_count, double uct_c) const;

        // The sort order used by BestChild.
        bool CompareFinal(const SearchNode &b) const;

        // Returns the child chosen at the end of the search.
        // NOTE(review): per the file overview, presumably the most-explored
        // child — confirm in the .cc.
        const SearchNode &BestChild() const;

        // Return a string representation of this node, or all its children.
        // The state is needed to convert the action to a string.
        std::string ToString(const State &state) const;
        std::string ChildrenStr(const State &state) const;
      };

      // A SpielBot that uses the MCTS algorithm as its policy.
      //
      // Step / StepWithPolicy run a search from the given state and return the
      // chosen action; MCTSearch exposes the resulting search tree directly.
      class MCTSBot : public Bot
      {
      public:
        // The evaluator is passed as a shared pointer to make it explicit that
        // the same evaluator instance can be passed to multiple bots and to
        // make the MCTSBot Python interface work regardless of the scope of the
        // Python evaluator object.
        //
        // Args:
        //   game: The game being played.
        //   evaluator: Evaluator used by the search.
        //   opp_evaluator: A second evaluator. NOTE(review): presumably used in
        //     place of `evaluator` for the opponent's decisions — confirm in .cc.
        //   uct_c: The exploration constant c from the overview at the top of
        //     this file.
        //   max_simulations: Number of states explored per search.
        //   max_memory_mb: Max memory use in megabytes.
        //   solve: Whether to back up solved states (MCTS-Solver).
        //   seed: NOTE(review): presumably seeds rng_ — confirm.
        //   verbose: NOTE(review): presumably enables printing search info.
        //   child_selection_policy: How children are scored (UCT by default).
        //   dirichlet_alpha, dirichlet_epsilon: NOTE(review): presumably the
        //     concentration and mixing weight for Dirichlet noise on the priors
        //     (see dirichlet_noise) — confirm.
        //   pw_exp, ar_factor, ar_exp, use_ar, perturb: Extensions specific to
        //     this implementation; their semantics live in the .cc and are not
        //     documented here. TODO: document them (pw_exp is presumably a
        //     progressive-widening exponent).
        //
        // NOTE(review): The inherited TODO(author5) stated that the second
        // parameter had to be a const reference because taking a
        // std::shared_ptr<Evaluator> broke the Julia API test. This signature
        // does take std::shared_ptr<Evaluator> by value, so that TODO is stale
        // here; confirm whether the Julia concern still applies.
        MCTSBot(
            const Game &game, std::shared_ptr<Evaluator> evaluator,
            std::shared_ptr<Evaluator> opp_evaluator,
            double uct_c, int max_simulations,
            int64_t max_memory_mb, // Max memory use in megabytes.
            bool solve,            // Whether to back up solved states.
            int seed, bool verbose,
            ChildSelectionPolicy child_selection_policy = ChildSelectionPolicy::UCT,
            double dirichlet_alpha = 0, double dirichlet_epsilon = 0,
            double pw_exp = 1, double ar_factor = 0, double ar_exp = 1, bool use_ar = false, bool perturb = false);
        ~MCTSBot() = default;

        // Both restart hooks are intentionally empty inline no-ops.
        void Restart() override {}
        void RestartAt(const State &state) override {}
        // Run MCTS for one step, choosing the action, and printing some information.
        Action Step(const State &state) override;

        // Implements StepWithPolicy. This is equivalent to calling Step, but wraps
        // the action as an ActionsAndProbs with 100% probability assigned to the
        // lone action.
        std::pair<ActionsAndProbs, Action> StepWithPolicy(
            const State &state) override;

        // Run MCTS on a given state, and return the resulting search tree.
        // The caller owns the returned root node.
        std::unique_ptr<SearchNode> MCTSearch(const State &state);

      private:
        // Applies the tree policy to play the game until reaching a leaf node.
        //
        // A leaf node is defined as a node that is terminal or has not been evaluated
        // yet. If it reaches a node that has been evaluated before but hasn't been
        // expanded, then expand its children and continue.
        //
        // Args:
        //   root: The root node in the search tree.
        //   state: The state of the game at the root node.
        //   visit_path: A vector of nodes to be filled in descending from the root
        //     node to a leaf node.
        //
        // Returns: The state of the game at the leaf node.
        std::unique_ptr<State> ApplyTreePolicy(SearchNode *root, const State &state,
                                               std::vector<SearchNode *> *visit_path);

        // NOTE(review): presumably prunes the subtree under `node` to keep the
        // tree within max_nodes_ — confirm in the .cc.
        void GarbageCollect(SearchNode *node);

        // NOTE(review): Members are initialized in declaration order regardless
        // of the order written in the constructor's initializer list; keep this
        // order in sync with the .cc when adding or moving fields.
        double uct_c_;
        int max_simulations_;
        int max_nodes_; // Max nodes allowed in the tree
        int nodes_;     // Nodes used in the tree.
        int gc_limit_;  // NOTE(review): presumably the GarbageCollect threshold.
        bool verbose_;
        bool solve_;
        double max_utility_; // NOTE(review): presumably Game::MaxUtility().
        double dirichlet_alpha_;
        double dirichlet_epsilon_;
        std::mt19937 rng_;
        const ChildSelectionPolicy child_selection_policy_;
        std::shared_ptr<Evaluator> evaluator_;
        std::shared_ptr<Evaluator> opp_evaluator_;
        double pw_exp_;
        double ar_factor_;
        double ar_exp_;
        bool use_ar_;
        bool perturb_;
      };

      // Returns a vector of noise sampled from a symmetric Dirichlet distribution
      // with concentration parameter `alpha`, using `rng` as the source of
      // randomness. NOTE(review): presumably `count` values long — confirm in
      // the .cc. See:
      // https://en.wikipedia.org/wiki/Dirichlet_distribution
      std::vector<double> dirichlet_noise(int count, double alpha, std::mt19937 *rng);
    }
  } // namespace algorithms
} // namespace open_spiel

#endif // OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_EVAL_MCTS_EVAL_H_
