// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/algorithms/alpha_zero_torch_eval/alpha_zero.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include "open_spiel/abseil-cpp/absl/algorithm/container.h"
#include "open_spiel/abseil-cpp/absl/random/uniform_real_distribution.h"
#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
#include "open_spiel/abseil-cpp/absl/strings/str_join.h"
#include "open_spiel/abseil-cpp/absl/strings/str_split.h"
#include "open_spiel/abseil-cpp/absl/synchronization/mutex.h"
#include "open_spiel/abseil-cpp/absl/time/clock.h"
#include "open_spiel/abseil-cpp/absl/time/time.h"
#include "open_spiel/abseil-cpp/absl/types/optional.h"
#include "open_spiel/algorithms/alpha_zero_torch_eval/device_manager.h"
#include "open_spiel/algorithms/alpha_zero_torch_eval/mcts_eval.h"
#include "open_spiel/algorithms/alpha_zero_torch_eval/vpevaluator.h"
#include "open_spiel/algorithms/alpha_zero_torch_eval/vpnet.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"
#include "open_spiel/utils/circular_buffer.h"
#include "open_spiel/utils/data_logger.h"
#include "open_spiel/utils/file.h"
#include "open_spiel/utils/json.h"
#include "open_spiel/utils/logger.h"
#include "open_spiel/utils/lru_cache.h"
#include "open_spiel/utils/stats.h"
#include "open_spiel/utils/thread.h"
#include "open_spiel/utils/threaded_queue.h"

namespace open_spiel
{
  namespace algorithms
  {
    namespace torch_az_eval
    {

      // Record of one game as produced by PlayGame(): the searched decision
      // points in play order plus the terminal returns.
      struct Trajectory
      {
        // One decision point at which an MCTS search was run.
        struct State
        {
          // state->ObservationTensor() taken before the action was applied.
          std::vector<float> observation;
          // Player to move at this decision point.
          open_spiel::Player current_player = open_spiel::kInvalidPlayer;
          // state->LegalActions() taken before the action was applied.
          std::vector<open_spiel::Action> legal_actions;
          // Action that was actually played.
          open_spiel::Action action = open_spiel::kInvalidAction;
          // Normalized policy derived from the root's child visit counts.
          open_spiel::ActionsAndProbs policy;
          // Root value estimate: root total_reward / root explore_count.
          double value = 0.0;
        };

        // Searched decision points, in the order they were played.
        std::vector<State> states;
        // state->Returns() of the terminal state, indexed by player.
        std::vector<double> returns;
      };

      // Plays one game from the initial state and returns its trajectory.
      //
      // Turns of kDefaultPlayerId are played by sampling from the prior that
      // `opp_vp_eval` assigns to the state; those turns are not recorded.
      // Every other turn is handled as follows:
      //   1. `(*bots)[player]` runs MCTSearch.
      //   2. Each root child's visit count is raised to 1/temperature, and the
      //      results are normalized into a policy.
      //   3. The move is chosen. Once at least `temperature_drop` moves (by
      //      either player) have been played, it is the root's BestChild().
      //      Otherwise it is sampled from that policy.
      //   4. The decision point is appended to `states`.
      //
      // `cutoff_value` is currently unused; early termination is disabled.
      // A per-game summary is always logged; per-move lines only when `verbose`.
      Trajectory PlayGame(Logger *logger, int game_num, const open_spiel::Game &game,
                          std::vector<std::unique_ptr<MCTSBot>> *bots,
                          std::shared_ptr<VPNetEvaluator> opp_vp_eval,
                          std::mt19937 *rng, double temperature, int temperature_drop,
                          double cutoff_value, bool verbose = false)
      {
        std::unique_ptr<open_spiel::State> state = game.NewInitialState();
        std::vector<std::string> moves_played;
        Trajectory result;
        for (;;)
        {
          const open_spiel::Player mover = state->CurrentPlayer();
          open_spiel::Action chosen;
          if (mover != kDefaultPlayerId)
          {
            std::unique_ptr<SearchNode> root = (*bots)[mover]->MCTSearch(*state);
            const double exponent = 1.0 / temperature;
            open_spiel::ActionsAndProbs visit_policy;
            visit_policy.reserve(root->children.size());
            for (const SearchNode &child : root->children)
            {
              visit_policy.emplace_back(child.action,
                                        std::pow(child.explore_count, exponent));
            }
            NormalizePolicy(&visit_policy);
            // Greedy play after the temperature drop, sampling before it.
            const bool play_greedily = moves_played.size() >= temperature_drop;
            chosen = play_greedily
                         ? root->BestChild().action
                         : open_spiel::SampleAction(visit_policy, *rng).first;
            const double search_value = root->total_reward / root->explore_count;
            result.states.push_back(Trajectory::State{
                state->ObservationTensor(), mover, state->LegalActions(), chosen,
                std::move(visit_policy), search_value});
          }
          else
          {
            chosen = open_spiel::SampleAction(opp_vp_eval->Prior(*state), *rng).first;
          }

          const std::string move_text = state->ActionToString(mover, chosen);
          moves_played.push_back(move_text);
          state->ApplyAction(chosen);
          if (verbose)
          {
            logger->Print("Player: %d, action: %s", mover, move_text);
          }
          if (state->IsTerminal())
          {
            result.returns = state->Returns();
            break;
          }
          // NOTE: resigning early when |search value| exceeds `cutoff_value`
          // is intentionally disabled here.
        }

        logger->Print("Game %d: Returns: %s; Actions: %s", game_num,
                      absl::StrJoin(result.returns, " "),
                      absl::StrJoin(moves_played, " "));
        return result;
      }

      // Creates a PUCT MCTSBot configured from `config`.
      //
      // Both evaluators are forwarded unchanged to the MCTSBot constructor.
      // When `evaluation` is true, the policy_alpha / policy_epsilon settings
      // (presumably root exploration noise; confirm in mcts_eval.h) are replaced
      // by 0. The bot uses a 10 MB memory budget, no solving and no verbose
      // output.
      std::unique_ptr<MCTSBot> InitAZBot(const AlphaZeroConfig &config,
                                         const open_spiel::Game &game,
                                         std::shared_ptr<Evaluator> evaluator,
                                         std::shared_ptr<Evaluator> opp_evaluator,
                                         bool evaluation,
                                         int seed)
      {
        // `auto` keeps the exact type the conditional expression yields.
        const auto noise_alpha = evaluation ? 0 : config.policy_alpha;
        const auto noise_epsilon = evaluation ? 0 : config.policy_epsilon;
        constexpr int kMaxMemoryMb = 10;
        constexpr bool kSolve = false;
        constexpr bool kVerbose = false;

        auto bot = std::make_unique<MCTSBot>(
            game, std::move(evaluator), std::move(opp_evaluator),
            config.uct_c, config.max_simulations, kMaxMemoryMb, kSolve, seed,
            kVerbose, ChildSelectionPolicy::PUCT, noise_alpha, noise_epsilon,
            config.pw_exp, config.ar_factor, config.ar_exp, config.use_ar,
            config.perturb);
        return bot;
      }

      // Actor thread body: repeatedly plays evaluation games and pushes each
      // finished trajectory onto `trajectory_queue`.
      //
      // `ep_pointer` counts trajectories successfully pushed by all actors.
      // This thread returns once that count reaches `config.num_games`, or when
      // `stop` is requested (checked between games). The RNG and both bots are
      // reseeded for every game with 10000 * num + game_num.
      void actor(const open_spiel::Game &game, const AlphaZeroConfig &config, int num,
                 ThreadedQueue<Trajectory> *trajectory_queue,
                 std::shared_ptr<VPNetEvaluator> vp_eval, std::shared_ptr<VPNetEvaluator> opp_vp_eval,
                 std::shared_ptr<std::atomic<int>> ep_pointer,
                 StopToken *stop)
      {
        // Only the first 20 actors log to a file, to limit the number of open files.
        std::unique_ptr<Logger> logger;
        if (num < 20)
        {
          logger = std::make_unique<FileLogger>(config.path, absl::StrCat("actor-", num));
        }
        else
        {
          logger = std::make_unique<NoopLogger>();
        }

        std::mt19937 rng;
        for (int game_num = 1; !stop->StopRequested(); ++game_num)
        {
          const int seed = 10000 * num + game_num;
          rng.seed(seed);

          std::vector<std::unique_ptr<MCTSBot>> bots;
          bots.reserve(2);
          for (int player = 0; player < 2; ++player)
          {
            bots.push_back(InitAZBot(config, game, vp_eval, opp_vp_eval,
                                     /*evaluation=*/true, seed));
          }

          if (!trajectory_queue->Push(
                  PlayGame(logger.get(), game_num, game, &bots, opp_vp_eval, &rng,
                           config.temperature, config.temperature_drop, 25 + 1),
                  absl::Seconds(10)))
          {
            logger->Print("Failed to push a trajectory after 10 seconds.");
          }
          else
          {
            ++(*ep_pointer);
          }

          // Enough games have been collected across all actors.
          if (*ep_pointer >= config.num_games)
          {
            return;
          }
        }
        logger->Print("Got a quit.");
      }

      // class EvalResults
      // {
      // public:
      //   explicit EvalResults(int count, int evaluation_window)
      //   {
      //     results_.reserve(count);
      //     for (int i = 0; i < count; ++i)
      //     {
      //       results_.emplace_back(evaluation_window);
      //     }
      //   }

      //   // How many evals per difficulty.
      //   int EvalCount()
      //   {
      //     absl::MutexLock lock(&m_);
      //     return eval_num_ / results_.size();
      //   }

      //   // Which eval to do next: difficulty, player0.
      //   std::pair<int, bool> Next()
      //   {
      //     absl::MutexLock lock(&m_);
      //     int next = eval_num_ % (results_.size() * 2);
      //     eval_num_ += 1;
      //     return {next / 2, next % 2};
      //   }

      //   void Add(int i, double value)
      //   {
      //     absl::MutexLock lock(&m_);
      //     results_[i].Add(value);
      //   }

      //   std::vector<double> AvgResults()
      //   {
      //     absl::MutexLock lock(&m_);
      //     std::vector<double> out;
      //     out.reserve(results_.size());
      //     for (const auto &result : results_)
      //     {
      //       out.push_back(result.Empty() ? 0
      //                                    : (absl::c_accumulate(result.Data(), 0.0) /
      //                                       result.Size()));
      //     }
      //     return out;
      //   }

      // private:
      //   std::vector<CircularBuffer<double>> results_;
      //   int eval_num_ = 0;
      //   absl::Mutex m_;
      // };

      // // A thread that plays vs standard MCTS.
      // void evaluator(const open_spiel::Game &game, const AlphaZeroConfig &config,
      //                int num, EvalResults *results,
      //                std::shared_ptr<VPNetEvaluator> vp_eval,
      //                std::shared_ptr<VPNetEvaluator> opp_vp_eval, StopToken *stop)
      // {
      //   FileLogger logger(config.path, absl::StrCat("evaluator-", num));
      //   std::mt19937 rng;
      //   auto rand_evaluator = std::make_shared<RandomRolloutEvaluator>(1, num);

      //   for (int game_num = 1; !stop->StopRequested(); ++game_num)
      //   {
      //     auto [difficulty, first] = results->Next();
      //     int az_player = first ? 0 : 1;
      //     int rand_max_simulations =
      //         config.max_simulations * std::pow(10, difficulty / 2.0);
      //     std::vector<std::unique_ptr<MCTSBot>> bots;
      //     bots.reserve(2);
      //     bots.push_back(InitAZBot(config, game, vp_eval, opp_vp_eval, true));
      //     bots.push_back(std::make_unique<MCTSBot>(
      //         game, rand_evaluator, opp_vp_eval, config.uct_c, rand_max_simulations,
      //         /*max_memory_mb=*/1000,
      //         /*solve=*/true,
      //         /*seed=*/num * 1000 + game_num,
      //         /*verbose=*/false, ChildSelectionPolicy::UCT));
      //     if (az_player == 1)
      //     {
      //       std::swap(bots[0], bots[1]);
      //     }

      //     logger.Print("Running MCTS with %d simulations", rand_max_simulations);
      //     Trajectory trajectory = PlayGame(
      //         &logger, game_num, game, &bots, opp_vp_eval, &rng, /*temperature=*/1,
      //         /*temperature_drop=*/0, /*cutoff_value=*/game.MaxUtility() + 1);

      //     results->Add(difficulty, trajectory.returns[az_player]);
      //     logger.Print("Game %d: AZ: %5.2f, MCTS: %5.2f, MCTS-sims: %d, length: %d",
      //                  game_num, trajectory.returns[az_player],
      //                  trajectory.returns[1 - az_player], rand_max_simulations,
      //                  trajectory.states.size());
      //   }
      //   logger.Print("Got a quit.");
      // }

      // void learner(const open_spiel::Game &game, const AlphaZeroConfig &config,
      //              DeviceManager *device_manager,
      //              std::shared_ptr<VPNetEvaluator> eval,
      //              ThreadedQueue<Trajectory> *trajectory_queue,
      //              EvalResults *eval_results, StopToken *stop)
      // {
      //   FileLogger logger(config.path, "learner");
      //   DataLoggerJsonLines data_logger(config.path, "learner", true);
      //   std::mt19937 rng;

      //   int device_id = 0; // Do not change, the first device is the learner.
      //   logger.Print("Running the learner on device %d: %s", device_id,
      //                device_manager->Get(0, device_id)->Device());

      //   // CircularBuffer<VPNetModel::TrainInputs> replay_buffer(
      //   //     config.replay_buffer_size);
      //   int learn_rate = config.replay_buffer_size / config.replay_buffer_reuse;
      //   int64_t total_trajectories = 0;

      //   const int stage_count = 7;
      //   std::vector<open_spiel::BasicStats> value_accuracies(stage_count);
      //   std::vector<open_spiel::BasicStats> value_predictions(stage_count);
      //   open_spiel::BasicStats game_lengths;
      //   open_spiel::HistogramNumbered game_lengths_hist(game.MaxGameLength() + 1);

      //   open_spiel::HistogramNamed outcomes({"Player1", "Player2", "Draw"});
      //   // Actor threads have likely been contributing for a while, so put `last` in
      //   // the past to avoid a giant spike on the first step.
      //   absl::Time last = absl::Now() - absl::Seconds(60);
      //   for (int step = 1; !stop->StopRequested() &&
      //                      (config.max_steps == 0 || step <= config.max_steps);
      //        ++step)
      //   {
      //     outcomes.Reset();
      //     game_lengths.Reset();
      //     game_lengths_hist.Reset();
      //     for (auto &value_accuracy : value_accuracies)
      //     {
      //       value_accuracy.Reset();
      //     }
      //     for (auto &value_prediction : value_predictions)
      //     {
      //       value_prediction.Reset();
      //     }

      //     // Collect trajectories
      //     int queue_size = trajectory_queue->Size();
      //     int num_states = 0;
      //     int num_trajectories = 0;
      //     while (!stop->StopRequested() && num_states < learn_rate)
      //     {
      //       absl::optional<Trajectory> trajectory = trajectory_queue->Pop();
      //       if (trajectory)
      //       {
      //         num_trajectories += 1;
      //         total_trajectories += 1;
      //         game_lengths.Add(trajectory->states.size());
      //         game_lengths_hist.Add(trajectory->states.size());

      //         double p1_outcome = trajectory->returns[0];
      //         outcomes.Add(p1_outcome > 0 ? 0 : (p1_outcome < 0 ? 1 : 2));

      //         for (const Trajectory::State &state : trajectory->states)
      //         {
      //           replay_buffer.Add(VPNetModel::TrainInputs{state.legal_actions,
      //                                                     state.observation,
      //                                                     state.policy, p1_outcome});
      //           num_states += 1;
      //         }

      //         for (int stage = 0; stage < stage_count; ++stage)
      //         {
      //           // Scale for the length of the game
      //           int index = (trajectory->states.size() - 1) *
      //                       static_cast<double>(stage) / (stage_count - 1);
      //           const Trajectory::State &s = trajectory->states[index];
      //           value_accuracies[stage].Add(
      //               (s.value >= 0) == (trajectory->returns[s.current_player] >= 0));
      //           value_predictions[stage].Add(abs(s.value));
      //         }
      //       }
      //     }
      //     absl::Time now = absl::Now();
      //     double seconds = absl::ToDoubleSeconds(now - last);

      //     logger.Print("Step: %d", step);
      //     logger.Print(
      //         "Collected %5d states from %3d games, %.1f states/s; "
      //         "%.1f states/(s*actor), game length: %.1f",
      //         num_states, num_trajectories, num_states / seconds,
      //         num_states / (config.actors * seconds),
      //         static_cast<double>(num_states) / num_trajectories);
      //     logger.Print("Queue size: %d. Buffer size: %d. States seen: %d", queue_size,
      //                  replay_buffer.Size(), replay_buffer.TotalAdded());

      //     if (stop->StopRequested())
      //     {
      //       break;
      //     }

      //     last = now;

      //     VPNetModel::LossInfo losses;
      //     { // Extra scope to return the device for use for inference asap.
      //       DeviceManager::DeviceLoan learn_model =
      //           device_manager->Get(config.train_batch_size, device_id);

      //       // Let the device manager know that the first device is now
      //       // off-limits for inference and should only be used for learning
      //       // (if config.explicit_learning == true).
      //       device_manager->SetLearning(config.explicit_learning);

      //       // Learn from them.
      //       for (int i = 0; i < replay_buffer.Size() / config.train_batch_size; i++)
      //       {
      //         losses += learn_model->Learn(
      //             replay_buffer.Sample(&rng, config.train_batch_size));
      //       }

      //       // The device manager can now once again use the first device for
      //       // inference (if it could not before).
      //       device_manager->SetLearning(false);
      //     }

      //     // Always save a checkpoint, either for keeping or for loading the weights
      //     // to the other sessions. It only allows numbers, so use -1 as "latest".
      //     std::string checkpoint_path =
      //         device_manager->Get(0, device_id)
      //             ->SaveCheckpoint(step % config.checkpoint_freq == 0 ? step : -1);
      //     if (device_manager->Count() > 0)
      //     {
      //       for (int i = 0; i < device_manager->Count(); ++i)
      //       {
      //         if (i != device_id)
      //         {
      //           device_manager->Get(0, i)->LoadCheckpoint(checkpoint_path);
      //         }
      //       }
      //     }
      //     logger.Print("Checkpoint saved: %s", checkpoint_path);

      //     DataLogger::Record record = {
      //         {"step", step},
      //         {"total_states", replay_buffer.TotalAdded()},
      //         {"states_per_s", num_states / seconds},
      //         {"states_per_s_actor", num_states / (config.actors * seconds)},
      //         {"total_trajectories", total_trajectories},
      //         {"trajectories_per_s", num_trajectories / seconds},
      //         {"queue_size", queue_size},
      //         {"game_length", game_lengths.ToJson()},
      //         {"game_length_hist", game_lengths_hist.ToJson()},
      //         {"outcomes", outcomes.ToJson()},
      //         {"value_accuracy",
      //          json::TransformToArray(value_accuracies,
      //                                 [](auto v) { return v.ToJson(); })},
      //         {"value_prediction",
      //          json::TransformToArray(value_predictions,
      //                                 [](auto v) { return v.ToJson(); })},
      //         {"eval", json::Object({
      //                      {"count", eval_results->EvalCount()},
      //                      {"results", json::CastToArray(eval_results->AvgResults())},
      //                  })},
      //         {"batch_size", eval->BatchSizeStats().ToJson()},
      //         {"batch_size_hist", eval->BatchSizeHistogram().ToJson()},
      //         {"loss", json::Object({
      //                      {"policy", losses.Policy()},
      //                      {"value", losses.Value()},
      //                      {"l2reg", losses.L2()},
      //                      {"sum", losses.Total()},
      //                  })},
      //     };
      //     eval->ResetBatchSizeStats();
      //     logger.Print("Losses: policy: %.4f, value: %.4f, l2: %.4f, sum: %.4f",
      //                  losses.Policy(), losses.Value(), losses.L2(), losses.Total());

      //     LRUCacheInfo cache_info = eval->CacheInfo();
      //     if (cache_info.size > 0)
      //     {
      //       logger.Print(absl::StrFormat(
      //           "Cache size: %d/%d: %.1f%%, hits: %d, misses: %d, hit rate: %.3f%%",
      //           cache_info.size, cache_info.max_size, 100.0 * cache_info.Usage(),
      //           cache_info.hits, cache_info.misses, 100.0 * cache_info.HitRate()));
      //       eval->ClearCache();
      //     }
      //     record.emplace("cache",
      //                    json::Object({
      //                        {"size", cache_info.size},
      //                        {"max_size", cache_info.max_size},
      //                        {"usage", cache_info.Usage()},
      //                        {"requests", cache_info.Total()},
      //                        {"requests_per_s", cache_info.Total() / seconds},
      //                        {"hits", cache_info.hits},
      //                        {"misses", cache_info.misses},
      //                        {"misses_per_s", cache_info.misses / seconds},
      //                        {"hit_rate", cache_info.HitRate()},
      //                    }));

      //     data_logger.Write(record);
      //     logger.Print("");
      //   }
      // }

      // Entry point of the evaluation run.
      //
      // 1. Validates the game and writes `config.json` to `config.path`.
      // 2. Loads the same checkpoint into two device managers: one backs the
      //    searching bots, the other backs the prior-sampling opponent.
      // 3. Starts `config.actors` actor threads.
      // 4. While waiting, drains finished games from the queue and writes
      //    player 1's return for each one to `<config.path>/results.csv`,
      //    one per line.
      //
      // The run ends once `config.num_games` trajectories have been queued or
      // `stop` is requested.
      //
      // Returns false on configuration or filesystem errors, before any actor
      // thread is started. Unsupported games abort via SpielFatalError.
      bool AlphaZero(AlphaZeroConfig config, StopToken *stop)
      {
        // Hard-coded location of the network used by both players.
        const std::string model_dir = "/home/fixed-opponent-self-play-logs";
        const std::string checkpoint_path = model_dir + "/checkpoint-50";

        std::shared_ptr<const open_spiel::Game> game =
            open_spiel::LoadGame(config.game);

        open_spiel::GameType game_type = game->GetType();
        if (game->NumPlayers() != 2)
          open_spiel::SpielFatalError("AlphaZero can only handle 2-player games.");
        if (game_type.reward_model != open_spiel::GameType::RewardModel::kTerminal)
          open_spiel::SpielFatalError("Game must have terminal rewards.");
        if (game_type.dynamics != open_spiel::GameType::Dynamics::kSequential)
          open_spiel::SpielFatalError("Game must have sequential turns.");
        if (game_type.chance_mode != open_spiel::GameType::ChanceMode::kDeterministic)
          open_spiel::SpielFatalError("Game must be deterministic.");

        file::Mkdirs(config.path);
        if (!file::IsDirectory(config.path))
        {
          std::cerr << config.path << " is not a directory." << std::endl;
          return false;
        }

        std::cout << "Logging directory: " << config.path << std::endl;

        std::cout << "Playing game: " << config.game << std::endl;

        config.inference_batch_size = std::max(
            1,
            std::min(config.inference_batch_size, config.actors + config.evaluators));

        config.inference_threads =
            std::max(1, std::min(config.inference_threads,
                                 (1 + config.actors + config.evaluators) / 2));

        {
          file::File fd(config.path + "/config.json", "w");
          fd.Write(json::ToString(config.ToJson(), true) + "\n");
        }

        // Without actors the game counter can never advance, so the wait loop
        // below would never finish.
        if (config.actors <= 0 && config.num_games > 0)
        {
          std::cerr << "At least one actor is required to play games." << std::endl;
          return false;
        }

        // Open the output before starting any work so a bad path fails fast.
        const std::string results_path = config.path + "/results.csv";
        std::ofstream results_file(results_path);
        if (!results_file)
        {
          std::cerr << "Could not open " << results_path << " for writing."
                    << std::endl;
          return false;
        }

        DeviceManager opp_device_manager;
        DeviceManager device_manager;
        for (const absl::string_view &device : absl::StrSplit(config.devices, ','))
        {
          opp_device_manager.AddDevice(
              VPNetModel(*game, model_dir, "vpnet.pb", std::string(device)));
          device_manager.AddDevice(
              VPNetModel(*game, model_dir, "vpnet.pb", std::string(device)));
        }

        if (device_manager.Count() == 0)
        {
          std::cerr << "No devices specified?" << std::endl;
          return false;
        }

        // The explicit_learning option should only be used when multiple
        // devices are available (so that inference can continue while
        // also undergoing learning).
        if (device_manager.Count() <= 1 && config.explicit_learning)
        {
          std::cerr << "Explicit learning can only be used with multiple devices."
                    << std::endl;
          return false;
        }

        for (int i = 0; i < opp_device_manager.Count(); ++i)
        {
          opp_device_manager.Get(0, i)->LoadCheckpoint(checkpoint_path);
        }
        for (int i = 0; i < device_manager.Count(); ++i)
        {
          device_manager.Get(0, i)->LoadCheckpoint(checkpoint_path);
        }

        auto eval = std::make_shared<VPNetEvaluator>(
            &device_manager, config.inference_batch_size, config.inference_threads,
            config.inference_cache, (config.actors + config.evaluators) / 16);

        auto opp_eval = std::make_shared<VPNetEvaluator>(
            &opp_device_manager, config.inference_batch_size, config.inference_threads,
            config.inference_cache, (config.actors + config.evaluators) / 16);

        ThreadedQueue<Trajectory> trajectory_queue(10000);
        auto ep_pointer = std::make_shared<std::atomic<int>>(0);
        std::vector<Thread> actors;
        actors.reserve(config.actors);
        for (int i = 0; i < config.actors; ++i)
        {
          actors.emplace_back([&, i, ep_pointer]() {
            actor(*game, config, i, &trajectory_queue, eval, opp_eval, ep_pointer, stop);
          });
        }

        // Records player 1's final return for one finished game.
        auto write_result = [&results_file](const Trajectory &trajectory) {
          results_file << trajectory.returns[1] << '\n';
        };

        // Consume trajectories while waiting. If nothing popped them, the queue
        // would fill up and every Push would time out. The game counter would
        // then stall, and all finished games would also stay in memory. Also
        // honor an external stop request, after which actors no longer advance
        // the counter.
        while (*ep_pointer < config.num_games && !stop->StopRequested())
        {
          absl::optional<Trajectory> trajectory =
              trajectory_queue.Pop(absl::Milliseconds(200));
          if (trajectory)
          {
            write_result(*trajectory);
          }
        }

        if (!stop->StopRequested())
        {
          stop->Stop();
        }

        // Reject further pushes so that actors can exit; trajectories that are
        // already queued are drained below.
        trajectory_queue.BlockNewValues();

        std::cout << "Joining all the threads." << std::endl;
        for (auto &t : actors)
        {
          t.join();
        }
        std::cout << *ep_pointer << std::endl;
        std::cout << trajectory_queue.Size() << std::endl;

        // All producers have joined and this is the only consumer, so each
        // Pop() below is expected to yield a value.
        while (trajectory_queue.Size() > 0)
        {
          absl::optional<Trajectory> trajectory = trajectory_queue.Pop();
          if (!trajectory)
          {
            break;
          }
          write_result(*trajectory);
        }
        results_file.close();
        std::cout << "Exiting cleanly." << std::endl;
        return true;
      }

    } // namespace torch_az_eval
  }   // namespace algorithms
} // namespace open_spiel
