// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_EVAL_TORCH_VPNET_H_
#define OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_EVAL_TORCH_VPNET_H_

#include <torch/torch.h>

#include "open_spiel/algorithms/alpha_zero_torch_eval/model.h"
#include "open_spiel/spiel.h"

namespace open_spiel
{
  namespace algorithms
  {
    namespace torch_az_eval
    {

      // To avoid having to change function calls and the flow of the AlphaZero setup,
      // this function is still called, but rather than create a TensorFlow graph
      // definition, it simply creates a struct that the Libtorch model can use to
      // initialize from. This struct is then saved and then loaded again when needed.
      bool CreateGraphDef(const Game &game, double learning_rate, double weight_decay,
                          const std::string &path, const std::string &filename,
                          std::string nn_model, int nn_width, int nn_depth,
                          bool verbose = false);

      class VPNetModel
      {
      public:
        // A class to handle the network's loss.
        class LossInfo
        {
        public:
          LossInfo() {}
          LossInfo(double policy, double value, double l2)
              : policy_(policy), value_(value), l2_(l2), batches_(1) {}

          // Merge another LossInfo into this one.
          LossInfo &operator+=(const LossInfo &other)
          {
            policy_ += other.policy_;
            value_ += other.value_;
            l2_ += other.l2_;
            batches_ += other.batches_;
            return *this;
          }

          // Return the average losses over all merged into this one.
          double Policy() const { return policy_ / batches_; }
          double Value() const { return value_ / batches_; }
          double L2() const { return l2_ / batches_; }
          double Total() const { return Policy() + Value() + L2(); }

        private:
          double policy_ = 0;
          double value_ = 0;
          double l2_ = 0;
          int batches_ = 0;
        };

        // A struct to handle the inputs for inference.
        struct InferenceInputs
        {
          std::vector<Action> legal_actions;
          std::vector<float> observations;

          bool operator==(const InferenceInputs &other) const
          {
            return legal_actions == other.legal_actions &&
                   observations == other.observations;
          }

          template <typename H>
          friend H AbslHashValue(H h, const InferenceInputs &in)
          {
            return H::combine(std::move(h), in.legal_actions, in.observations);
          }
        };

        // A struct to hold the outputs of the inference (value and policy).
        struct InferenceOutputs
        {
          double value;
          ActionsAndProbs policy;
          std::vector<double> repr;
        };

        // A struct to hold the inputs for training.
        struct TrainInputs
        {
          std::vector<Action> legal_actions;
          std::vector<float> observations;
          ActionsAndProbs policy;
          double value;
        };

        VPNetModel(const Game &game, const std::string &path,
                   const std::string &file_name,
                   const std::string &device = "/cpu:0");

        // Move only, not copyable.
        VPNetModel(VPNetModel &&other) = default;
        VPNetModel &operator=(VPNetModel &&other) = default;
        VPNetModel(const VPNetModel &) = delete;
        VPNetModel &operator=(const VPNetModel &) = delete;

        // Inference: Get both at the same time.
        std::vector<InferenceOutputs> Inference(
            const std::vector<InferenceInputs> &inputs);

        // Training: do one (batch) step of neural net training
        LossInfo Learn(const std::vector<TrainInputs> &inputs);

        std::string SaveCheckpoint(int step);
        void LoadCheckpoint(const std::string &path);

        const std::string Device() const { return device_; }

      private:
        std::string device_;
        std::string path_;

        // Store the full model metagraph file
        // for writing python compatible checkpoints.
        std::string model_meta_graph_contents_;

        int flat_input_size_;
        int num_actions_;

        // NOTE:
        // The member model_ takes an already initialized model_config_,
        // and model_optimizer_ takes an already initialized model_
        // parameters and model_config_ learning rate. Therefore, keep the
        // members' (model_config_, model_, model_optimizer_) declaration in
        // the order shown below so the member initialization list works.
        ModelConfig model_config_;
        ResModel model_;
        torch::optim::Adam model_optimizer_;
        torch::Device torch_device_;
      };

    } // namespace torch_az
  }   // namespace algorithms
} // namespace open_spiel

#endif // OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_VPNET_H_
