// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/game_transforms/restricted_nash_response.h"

#include <cstddef>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "open_spiel/abseil-cpp/absl/random/uniform_int_distribution.h"
#include "open_spiel/spiel.h"

namespace open_spiel {
namespace {

void SimulateGame(std::mt19937 *rng,
                  const Game &game,
                  std::unique_ptr<State> normal_state,
                  std::unique_ptr<State> rnr_state,
                  bool fixed,
                  Player fixed_player) {
  // Now check that the states are identical via the ToString().
  std::string prefix = fixed ? "fixed_" : "free_";
  while (!normal_state->IsTerminal()) {
    SPIEL_CHECK_EQ(prefix + normal_state->ToString(), rnr_state->ToString());
    if (game.GetType().provides_information_state_string) {
      // Check the information states to each player are consistent.
      for (auto p = Player{0}; p < game.NumPlayers(); p++) {
        SPIEL_CHECK_EQ((p == fixed_player ? prefix : "") + normal_state->InformationStateString(p),
                       rnr_state->InformationStateString(p));
      }
    }

    if (normal_state->IsChanceNode()) {
      SPIEL_CHECK_TRUE(rnr_state->IsChanceNode());

      // Chance node; sample one according to underlying distribution
      std::vector<std::pair<Action, double>> outcomes =
          normal_state->ChanceOutcomes();
      Action action =
          open_spiel::SampleAction(
              outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(*rng))
              .first;

//      std::cout << "sampled outcome: %s\n"
//                << normal_state->ActionToString(kChancePlayerId, action)
//                << std::endl;

      normal_state->ApplyAction(action);
      rnr_state->ApplyAction(action);
    } else if (normal_state->CurrentPlayer() == kSimultaneousPlayerId) {
      SPIEL_CHECK_EQ(rnr_state->CurrentPlayer(), kSimultaneousPlayerId);

      // Players choose simultaneously.
      std::vector<Action> joint_action;

      // Sample an action for each player
      for (auto p = Player{0}; p < game.NumPlayers(); p++) {

        std::vector<Action> actions;
        actions = normal_state->LegalActions(p);
        absl::uniform_int_distribution<> dis(0, actions.size() - 1);
        Action action = actions[dis(*rng)];
        joint_action.push_back(action);
      }

      normal_state->ApplyActions(joint_action);
      rnr_state->ApplyActions(joint_action);
    } else {
      SPIEL_CHECK_EQ(normal_state->CurrentPlayer(),
                     rnr_state->CurrentPlayer());

      Player p = normal_state->CurrentPlayer();

      std::vector<Action> actions;
      actions = normal_state->LegalActions(p);
      absl::uniform_int_distribution<> dis(0, actions.size() - 1);
      Action action = actions[dis(*rng)];

//      std::cout << "player " << p << " chose "
//                << normal_state->ActionToString(p, action) << std::endl;

      normal_state->ApplyAction(action);
      rnr_state->ApplyAction(action);
    }

//    std::cout << "State: " << std::endl << normal_state->ToString() << std::endl;
  }

  SPIEL_CHECK_TRUE(rnr_state->IsTerminal());

  auto sim_returns = normal_state->Returns();
  auto turn_returns = rnr_state->Returns();

  for (auto player = Player{0}; player < sim_returns.size(); player++) {
    double utility = sim_returns[player];
    SPIEL_CHECK_GE(utility, game.MinUtility());
    SPIEL_CHECK_LE(utility, game.MaxUtility());
//    std::cout << "Utility to player " << player << " is " << utility
//              << std::endl;

    double other_utility = turn_returns[player];
    SPIEL_CHECK_EQ(utility, other_utility);
  }
}

void TestBasicCreation() {
  std::mt19937 rng;

  for (const std::string name : {"blotto", "goofspiel", "kuhn_poker", "tiny_hanabi", "phantom_ttt", "matrix_rps",
                                 "leduc_poker"}) {
    std::cout << "TurnBasedSimultaneous: Testing " << name << std::endl;
    for (Player fixed_player = 0; fixed_player < 2; fixed_player++) {
      for (int i = 0; i < 100; ++i) {
        std::shared_ptr<const Game> normal_game_game = LoadGame(name);
        std::shared_ptr<const Game> rnr_game =
            ConvertToRNR(*LoadGame(name), fixed_player, 0.5);
        auto normal_init_fixed = normal_game_game->NewInitialState();
        auto rnr_init_fixed = rnr_game->NewInitialState();
        rnr_init_fixed->ApplyAction(Action(kFixedAction));
        SimulateGame(&rng,
                     *normal_game_game,
                     std::move(normal_init_fixed),
                     std::move(rnr_init_fixed),
                     true,
                     fixed_player);

        auto rnr_init_free = rnr_game->NewInitialState();
        auto normal_init_free = normal_game_game->NewInitialState();
        rnr_init_free->ApplyAction(Action(kFreeAction));
        SimulateGame(&rng,
                     *normal_game_game,
                     std::move(normal_init_free),
                     std::move(rnr_init_free),
                     false,
                     fixed_player);
      }
    }
  }
}
}  // namespace
}  // namespace open_spiel

// Test entry point: runs every restricted-Nash-response check in this file.
// Failures abort via SPIEL_CHECK, so reaching the end means success.
int main(int argc, char **argv) {
  open_spiel::TestBasicCreation();
  return 0;
}
